scratch/dns-server/record.go (view raw)
| 1 | package main |
| 2 | |
| 3 | import ( |
| 4 | "bytes" |
| 5 | "encoding/binary" |
| 6 | "fmt" |
| 7 | "net" |
| 8 | "strings" |
| 9 | ) |
| 10 | |
// QueryType is the 16-bit numeric code carried in the TYPE field of a DNS
// resource record.
type QueryType uint16
| 12 | |
| 13 | const ( |
| 14 | AType QueryType = 1 |
| 15 | NSType QueryType = 2 |
| 16 | CNAMEType QueryType = 5 |
| 17 | MXType QueryType = 15 |
| 18 | AAAAType QueryType = 28 |
| 19 | ) |
| 20 | |
| 21 | type Record struct { |
| 22 | Name string |
| 23 | Type QueryType |
| 24 | Class uint16 |
| 25 | TTL uint32 |
| 26 | Data string |
| 27 | Priority uint16 |
| 28 | } |
| 29 | |
| 30 | func ReadRecord(r *bytes.Reader, packet []byte) (Record, error) { |
| 31 | name, err := readName(r, packet) |
| 32 | if err != nil { |
| 33 | return Record{}, err |
| 34 | } |
| 35 | |
| 36 | var rtype QueryType |
| 37 | var class, rdlen uint16 |
| 38 | var ttl uint32 |
| 39 | _ = binary.Read(r, binary.BigEndian, &rtype) |
| 40 | _ = binary.Read(r, binary.BigEndian, &class) |
| 41 | _ = binary.Read(r, binary.BigEndian, &ttl) |
| 42 | _ = binary.Read(r, binary.BigEndian, &rdlen) |
| 43 | |
| 44 | var data string |
| 45 | var priority uint16 |
| 46 | switch rtype { |
| 47 | case AType: |
| 48 | var ip [4]byte |
| 49 | _, _ = r.Read(ip[:]) |
| 50 | data = net.IP(ip[:]).String() |
| 51 | |
| 52 | case AAAAType: |
| 53 | var ip [16]byte |
| 54 | _, _ = r.Read(ip[:]) |
| 55 | data = net.IP(ip[:]).String() |
| 56 | |
| 57 | case NSType, CNAMEType: |
| 58 | data, _ = readName(r, packet) |
| 59 | |
| 60 | case MXType: |
| 61 | var pri uint16 |
| 62 | _ = binary.Read(r, binary.BigEndian, &pri) |
| 63 | host, _ := readName(r, packet) |
| 64 | priority = pri |
| 65 | data = host |
| 66 | |
| 67 | default: |
| 68 | buf := make([]byte, rdlen) |
| 69 | _, _ = r.Read(buf) |
| 70 | data = fmt.Sprintf("%x", buf) |
| 71 | } |
| 72 | |
| 73 | return Record{ |
| 74 | Name: name, |
| 75 | Type: rtype, |
| 76 | Class: class, |
| 77 | TTL: ttl, |
| 78 | Data: data, |
| 79 | Priority: priority, |
| 80 | }, nil |
| 81 | } |
| 82 | |
| 83 | func (r Record) Write(b *bytes.Buffer) (int, error) { |
| 84 | start := b.Len() |
| 85 | switch r.Type { |
| 86 | case AType: |
| 87 | _ = writeName(b, r.Name) |
| 88 | _ = binary.Write(b, binary.BigEndian, r.Type) |
| 89 | _ = binary.Write(b, binary.BigEndian, r.Class) |
| 90 | _ = binary.Write(b, binary.BigEndian, r.TTL) |
| 91 | _ = binary.Write(b, binary.BigEndian, uint16(4)) |
| 92 | |
| 93 | ip := net.ParseIP(r.Data).To4() |
| 94 | if ip == nil { |
| 95 | return 0, fmt.Errorf("invalid IPv4 address: %s", r.Data) |
| 96 | } |
| 97 | |
| 98 | _, _ = b.Write(ip) |
| 99 | |
| 100 | case AAAAType: |
| 101 | _ = writeName(b, r.Name) |
| 102 | _ = binary.Write(b, binary.BigEndian, r.Type) |
| 103 | _ = binary.Write(b, binary.BigEndian, r.Class) |
| 104 | _ = binary.Write(b, binary.BigEndian, r.TTL) |
| 105 | _ = binary.Write(b, binary.BigEndian, uint16(16)) |
| 106 | |
| 107 | ip := net.ParseIP(r.Data).To16() |
| 108 | if ip == nil { |
| 109 | return 0, fmt.Errorf("invalid IPv6 address: %s", r.Data) |
| 110 | } |
| 111 | |
| 112 | _, _ = b.Write(ip) |
| 113 | |
| 114 | case NSType, CNAMEType: |
| 115 | _ = writeName(b, r.Name) |
| 116 | _ = binary.Write(b, binary.BigEndian, r.Type) |
| 117 | _ = binary.Write(b, binary.BigEndian, r.Class) |
| 118 | _ = binary.Write(b, binary.BigEndian, r.TTL) |
| 119 | |
| 120 | encoded := encodeName(r.Data) |
| 121 | _ = binary.Write(b, binary.BigEndian, uint16(len(encoded))) |
| 122 | _, _ = b.Write(encoded) |
| 123 | |
| 124 | case MXType: |
| 125 | _ = writeName(b, r.Name) |
| 126 | _ = binary.Write(b, binary.BigEndian, r.Type) |
| 127 | _ = binary.Write(b, binary.BigEndian, r.Class) |
| 128 | _ = binary.Write(b, binary.BigEndian, r.TTL) |
| 129 | |
| 130 | encoded := encodeName(r.Data) |
| 131 | _ = binary.Write(b, binary.BigEndian, uint16(2+len(encoded))) |
| 132 | _ = binary.Write(b, binary.BigEndian, uint16(r.Priority)) |
| 133 | _, _ = b.Write(encoded) |
| 134 | |
| 135 | default: |
| 136 | fmt.Printf("Skipping record: %+v\n", r) |
| 137 | } |
| 138 | |
| 139 | return b.Len() - start, nil |
| 140 | } |
| 141 | |
// readName decodes a domain name starting at r's current position and
// returns it as dot-separated labels, with no trailing dot; the root name
// decodes to "".
//
// A length octet whose top two bits are both set starts a two-byte
// compression pointer: the remaining 14 bits are an offset into packet where
// the rest of the name continues. r itself only advances past the bytes that
// physically belong to the name at its position (up to and including the
// first pointer or the terminating zero octet).
//
// An error is returned when the input ends inside the name, when a pointer
// targets an offset outside packet, when pointers form a loop, when a length
// octet uses one of the reserved 0x40/0x80 label types, or when the decoded
// name exceeds 255 octets in wire form.
func readName(r *bytes.Reader, packet []byte) (string, error) {
	const (
		// Limit on the wire-format length of a name (RFC 1035 §2.3.4).
		maxNameOctets = 255
		// A name of at most 255 octets holds at most 127 labels, so a
		// well-formed name never needs more jumps than that. The cap is
		// what terminates pointer loops.
		maxPointers = 127
	)

	var labels []string
	wireLen := 1 // the terminating root octet
	pointers := 0

	// cur is the reader currently being consumed: r itself until the first
	// pointer, then a reader over the pointed-to part of packet.
	cur := r
	for {
		length, err := cur.ReadByte()
		if err != nil {
			return "", err
		}
		if length == 0 {
			break
		}

		switch length & 0xC0 {
		case 0xC0:
			low, err := cur.ReadByte()
			if err != nil {
				return "", err
			}
			offset := int(length&0x3F)<<8 | int(low)
			if offset >= len(packet) {
				return "", fmt.Errorf("compression pointer to offset %d is outside the %d-byte packet", offset, len(packet))
			}
			pointers++
			if pointers > maxPointers {
				return "", fmt.Errorf("name follows more than %d compression pointers", maxPointers)
			}
			cur = bytes.NewReader(packet[offset:])
			continue
		case 0x40, 0x80:
			return "", fmt.Errorf("unsupported label type in length octet 0x%02x", length)
		}

		if int(length) > cur.Len() {
			return "", fmt.Errorf("label of %d bytes is truncated: only %d bytes remain", length, cur.Len())
		}
		wireLen += 1 + int(length)
		if wireLen > maxNameOctets {
			return "", fmt.Errorf("name exceeds %d octets", maxNameOctets)
		}
		label := make([]byte, length)
		if _, err := cur.Read(label); err != nil {
			return "", err
		}
		labels = append(labels, string(label))
	}
	return strings.Join(labels, "."), nil
}
| 175 | |
// encodeName returns the uncompressed wire form of name: each dot-separated
// label prefixed by a one-byte length, followed by a terminating zero octet.
//
// Empty labels are skipped, so the empty name and "." both encode to the
// single root octet, and a trailing dot ("example.com.") encodes the same as
// the name without it. Emitting them would place a zero length octet inside
// the name, which a decoder reads as the end of the name.
//
// NOTE: label length is not validated here. A label longer than 63 bytes
// yields a length octet that is not a valid label length; use writeName when
// the input is not already known to be well-formed.
func encodeName(name string) []byte {
	// One length octet replaces each dot, plus a leading length octet and
	// the root octet, so len(name)+2 is an upper bound.
	out := make([]byte, 0, len(name)+2)
	for rest := name; rest != ""; {
		label, tail, _ := strings.Cut(rest, ".")
		rest = tail
		if label == "" {
			continue
		}
		out = append(out, byte(len(label)))
		out = append(out, label...)
	}
	return append(out, 0)
}
| 185 | |
// writeName appends the uncompressed wire form of qname to w: each
// dot-separated label prefixed by a one-byte length, followed by a
// terminating zero octet.
//
// A single trailing dot is accepted and ignored, and the empty name and "."
// both encode to the lone root octet. An error is returned, and nothing is
// written to w, when a label is empty, when a label is longer than 63 bytes,
// or when the encoded name is longer than 255 bytes.
//
// TODO: wrap the Buffer, to have the len == 512 guard
func writeName(w *bytes.Buffer, qname string) error {
	const (
		maxLabelLen = 0x3f // a length octet's top two bits are reserved
		maxNameLen  = 255  // wire-format limit, RFC 1035 §2.3.4
	)

	// The name is staged in wire so that validation failures leave w
	// untouched. One length octet replaces each dot, plus a leading length
	// octet and the root octet, so len(qname)+2 is an upper bound.
	wire := make([]byte, 0, len(qname)+2)

	rest := strings.TrimSuffix(qname, ".")
	for more := rest != ""; more; {
		var label string
		label, rest, more = strings.Cut(rest, ".")
		switch {
		case label == "":
			// A zero length octet inside a name would be read back as the
			// end of the name.
			return fmt.Errorf("empty label in name %q", qname)
		case len(label) > maxLabelLen:
			return fmt.Errorf("single label exceeds 63 characters of length")
		}
		wire = append(wire, byte(len(label)))
		wire = append(wire, label...)
	}
	wire = append(wire, 0)

	if len(wire) > maxNameLen {
		return fmt.Errorf("name %q encodes to %d bytes, limit is %d", qname, len(wire), maxNameLen)
	}

	// bytes.Buffer.Write never returns an error.
	_, _ = w.Write(wire)
	return nil
}