scratch/dns-server/record.go (view raw)
| 1 | package main |
| 2 | |
| 3 | import ( |
| 4 | "bytes" |
| 5 | "encoding/binary" |
| 6 | "fmt" |
| 7 | "net" |
| 8 | "strings" |
| 9 | ) |
| 10 | |
// Record is a single DNS resource record as handled by this package:
// ReadRecord produces one from wire bytes and Write serialises one back.
type Record struct {
	// Name is the owner name with its labels joined by ".".
	Name string
	// Type and Class are the 16-bit TYPE and CLASS fields of the record.
	// Type 1 (A) is the only type whose Data is decoded and encoded.
	Type, Class uint16
	// TTL is the 32-bit TTL field of the record.
	TTL uint32
	// Data is the textual form of the record data: dotted-quad IPv4 for
	// type 1, lowercase hex of the raw bytes for every other type.
	Data string
}
| 18 | |
| 19 | func ReadRecord(r *bytes.Reader, packet []byte) (Record, error) { |
| 20 | name, err := readName(r, packet) |
| 21 | if err != nil { |
| 22 | return Record{}, err |
| 23 | } |
| 24 | |
| 25 | var rtype, class, rdlen uint16 |
| 26 | var ttl uint32 |
| 27 | _ = binary.Read(r, binary.BigEndian, &rtype) |
| 28 | _ = binary.Read(r, binary.BigEndian, &class) |
| 29 | _ = binary.Read(r, binary.BigEndian, &ttl) |
| 30 | _ = binary.Read(r, binary.BigEndian, &rdlen) |
| 31 | |
| 32 | var data string |
| 33 | switch rtype { |
| 34 | case 1: // A |
| 35 | var ip [4]byte |
| 36 | _, _ = r.Read(ip[:]) |
| 37 | data = fmt.Sprintf("%d.%d.%d.%d", |
| 38 | ip[0], ip[1], ip[2], ip[3]) |
| 39 | |
| 40 | default: |
| 41 | buf := make([]byte, rdlen) |
| 42 | _, _ = r.Read(buf) |
| 43 | data = fmt.Sprintf("%x", buf) |
| 44 | } |
| 45 | |
| 46 | return Record{ |
| 47 | Name: name, |
| 48 | Type: rtype, |
| 49 | Class: class, |
| 50 | TTL: ttl, |
| 51 | Data: data, |
| 52 | }, nil |
| 53 | } |
| 54 | |
| 55 | func (r Record) Write(b *bytes.Buffer) (int, error) { |
| 56 | start := b.Len() |
| 57 | switch r.Type { |
| 58 | case 1: // A |
| 59 | _ = writeName(b, r.Name) |
| 60 | _ = binary.Write(b, binary.BigEndian, r.Type) |
| 61 | _ = binary.Write(b, binary.BigEndian, r.Class) |
| 62 | _ = binary.Write(b, binary.BigEndian, r.TTL) |
| 63 | _ = binary.Write(b, binary.BigEndian, uint16(4)) |
| 64 | |
| 65 | ip := net.ParseIP(r.Data).To4() |
| 66 | if ip == nil { |
| 67 | return 0, fmt.Errorf("invalid IPv4 address: %s", r.Data) |
| 68 | } |
| 69 | |
| 70 | _, _ = b.Write(ip) |
| 71 | |
| 72 | default: |
| 73 | fmt.Printf("Skipping record: %+v\n", r) |
| 74 | } |
| 75 | |
| 76 | return b.Len() - start, nil |
| 77 | } |
| 78 | |
// readName decodes a domain name starting at r's current position and returns
// its labels joined by ".". The root name is returned as "".
//
// packet is the complete message; compression pointers (a length byte with
// the top two bits set) are resolved as offsets into it. r is advanced past
// the name as it appears at the current position: past the terminating zero
// byte, or past the two pointer bytes when the name ends in a pointer.
//
// An error is returned when the input ends inside the name, when a label runs
// past the end of its input, when a pointer targets an offset outside packet,
// when pointers are chained more than maxPointerJumps times (which catches
// loops), or when a length byte uses the 0x40/0x80 label types.
func readName(r *bytes.Reader, packet []byte) (string, error) {
	// Upper bound on pointer hops for one name; guards against pointer loops.
	// The exact value is a safety margin, not a protocol constant.
	const maxPointerJumps = 126

	var labels []string
	cur := r // reader being decoded; replaced by a sub-reader after a pointer
	jumps := 0
	for {
		length, err := cur.ReadByte()
		if err != nil {
			return "", err
		}
		if length == 0 {
			break
		}

		switch length & 0xC0 {
		case 0xC0:
			// Pointer: the low 6 bits plus the next byte form a 14-bit offset.
			low, err := cur.ReadByte()
			if err != nil {
				return "", err
			}
			offset := int(length&0x3F)<<8 | int(low)
			if offset >= len(packet) {
				return "", fmt.Errorf("compression pointer offset %d out of range (packet is %d bytes)",
					offset, len(packet))
			}
			jumps++
			if jumps > maxPointerJumps {
				return "", fmt.Errorf("too many compression pointers (limit %d)", maxPointerJumps)
			}
			// Continue decoding at the target. r itself is not touched again,
			// so it stays positioned right after the pointer.
			cur = bytes.NewReader(packet[offset:])
			continue
		case 0x00:
			// Ordinary label of `length` bytes; handled below.
		default:
			return "", fmt.Errorf("unsupported label type 0x%02x", length&0xC0)
		}

		if int(length) > cur.Len() {
			return "", fmt.Errorf("label length %d exceeds remaining %d bytes", length, cur.Len())
		}
		buf := make([]byte, length)
		// The length check above guarantees this read fills buf completely.
		if _, err := cur.Read(buf); err != nil {
			return "", err
		}
		labels = append(labels, string(buf))
	}
	return strings.Join(labels, "."), nil
}
| 112 | |
// writeName appends qname to w as a sequence of length-prefixed labels
// followed by a terminating zero byte.
//
// qname is split on "."; a single trailing dot is accepted and ignored, and
// both "" and "." encode the root name as one zero byte. The name is fully
// validated before anything is written, so w is unchanged when an error is
// returned. Errors are reported for an empty label (as in "a..b"), a label
// longer than 63 bytes, and a name whose encoded form exceeds 255 bytes.
//
// TODO: wrap the Buffer, to have the len == 512 guard
func writeName(w *bytes.Buffer, qname string) error {
	// A trailing dot only marks the name as fully qualified; it must not
	// produce an extra empty label.
	qname = strings.TrimSuffix(qname, ".")
	if qname == "" {
		// bytes.Buffer.WriteByte always returns a nil error.
		return w.WriteByte(0)
	}

	labels := strings.Split(qname, ".")
	wireLen := 1 // terminating zero byte
	for _, label := range labels {
		switch n := len(label); {
		case n == 0:
			// A zero length byte would end the name early on the wire.
			return fmt.Errorf("empty label in name %q", qname)
		case n > 0x3f:
			return fmt.Errorf("single label exceeds 63 characters of length")
		}
		wireLen += 1 + len(label)
	}
	if wireLen > 255 {
		return fmt.Errorf("encoded name is %d bytes, exceeds the 255 byte limit", wireLen)
	}

	w.Grow(wireLen)
	for _, label := range labels {
		// bytes.Buffer write methods always return a nil error.
		_ = w.WriteByte(byte(len(label)))
		_, _ = w.WriteString(label)
	}
	_ = w.WriteByte(0)
	return nil
}
| 125 | } |