all repos

scratch @ 0d2ab83bc28c9eb22c22bfeb78aac53dd6af1ea7

⭐ me doing recreational ~~drugs~~ programming

scratch/dns-server/record.go (view raw)

1
package main
2
3
import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"strings"
)
10
11
// QueryType is the 16-bit TYPE code carried by a DNS question or resource
// record.
type QueryType uint16

// Record types that ReadRecord and Record.Write handle specially. Records of
// any other type are rendered as hex when read and skipped when written.
const (
	// Address records: RDATA is a raw IPv4 / IPv6 address.
	AType    QueryType = 1
	AAAAType QueryType = 28

	// Records whose RDATA is a single domain name.
	NSType    QueryType = 2
	CNAMEType QueryType = 5

	// Mail exchange: RDATA is a 16-bit preference followed by a domain name.
	MXType QueryType = 15
)

// Record is one DNS resource record in decoded, textual form.
type Record struct {
	Name  string    // owner name in dotted form, e.g. "example.com"
	Type  QueryType // record TYPE
	Class uint16    // CLASS field, copied verbatim from/to the wire
	TTL   uint32    // TTL field, copied verbatim from/to the wire

	// Data is the RDATA rendered as text: an IP address for A/AAAA, a
	// domain name for NS/CNAME/MX, and lowercase hex for other types.
	Data     string
	Priority uint16 // MX preference; zero for every other type
}
29
30
func ReadRecord(r *bytes.Reader, packet []byte) (Record, error) {
31
	name, err := readName(r, packet)
32
	if err != nil {
33
		return Record{}, err
34
	}
35
36
	var rtype QueryType
37
	var class, rdlen uint16
38
	var ttl uint32
39
	_ = binary.Read(r, binary.BigEndian, &rtype)
40
	_ = binary.Read(r, binary.BigEndian, &class)
41
	_ = binary.Read(r, binary.BigEndian, &ttl)
42
	_ = binary.Read(r, binary.BigEndian, &rdlen)
43
44
	var data string
45
	var priority uint16
46
	switch rtype {
47
	case AType:
48
		var ip [4]byte
49
		_, _ = r.Read(ip[:])
50
		data = net.IP(ip[:]).String()
51
52
	case AAAAType:
53
		var ip [16]byte
54
		_, _ = r.Read(ip[:])
55
		data = net.IP(ip[:]).String()
56
57
	case NSType, CNAMEType:
58
		data, _ = readName(r, packet)
59
60
	case MXType:
61
		var pri uint16
62
		_ = binary.Read(r, binary.BigEndian, &pri)
63
		host, _ := readName(r, packet)
64
		priority = pri
65
		data = host
66
67
	default:
68
		buf := make([]byte, rdlen)
69
		_, _ = r.Read(buf)
70
		data = fmt.Sprintf("%x", buf)
71
	}
72
73
	return Record{
74
		Name:     name,
75
		Type:     rtype,
76
		Class:    class,
77
		TTL:      ttl,
78
		Data:     data,
79
		Priority: priority,
80
	}, nil
81
}
82
83
func (r Record) Write(b *bytes.Buffer) (int, error) {
84
	start := b.Len()
85
	switch r.Type {
86
	case AType:
87
		_ = writeName(b, r.Name)
88
		_ = binary.Write(b, binary.BigEndian, r.Type)
89
		_ = binary.Write(b, binary.BigEndian, r.Class)
90
		_ = binary.Write(b, binary.BigEndian, r.TTL)
91
		_ = binary.Write(b, binary.BigEndian, uint16(4))
92
93
		ip := net.ParseIP(r.Data).To4()
94
		if ip == nil {
95
			return 0, fmt.Errorf("invalid IPv4 address: %s", r.Data)
96
		}
97
98
		_, _ = b.Write(ip)
99
100
	case AAAAType:
101
		_ = writeName(b, r.Name)
102
		_ = binary.Write(b, binary.BigEndian, r.Type)
103
		_ = binary.Write(b, binary.BigEndian, r.Class)
104
		_ = binary.Write(b, binary.BigEndian, r.TTL)
105
		_ = binary.Write(b, binary.BigEndian, uint16(16))
106
107
		ip := net.ParseIP(r.Data).To16()
108
		if ip == nil {
109
			return 0, fmt.Errorf("invalid IPv6 address: %s", r.Data)
110
		}
111
112
		_, _ = b.Write(ip)
113
114
	case NSType, CNAMEType:
115
		_ = writeName(b, r.Name)
116
		_ = binary.Write(b, binary.BigEndian, r.Type)
117
		_ = binary.Write(b, binary.BigEndian, r.Class)
118
		_ = binary.Write(b, binary.BigEndian, r.TTL)
119
120
		encoded := encodeName(r.Data)
121
		_ = binary.Write(b, binary.BigEndian, uint16(len(encoded)))
122
		_, _ = b.Write(encoded)
123
124
	case MXType:
125
		_ = writeName(b, r.Name)
126
		_ = binary.Write(b, binary.BigEndian, r.Type)
127
		_ = binary.Write(b, binary.BigEndian, r.Class)
128
		_ = binary.Write(b, binary.BigEndian, r.TTL)
129
130
		encoded := encodeName(r.Data)
131
		_ = binary.Write(b, binary.BigEndian, uint16(2+len(encoded)))
132
		_ = binary.Write(b, binary.BigEndian, uint16(r.Priority))
133
		_, _ = b.Write(encoded)
134
135
	default:
136
		fmt.Printf("Skipping record: %+v\n", r)
137
	}
138
139
	return b.Len() - start, nil
140
}
141
142
func readName(r *bytes.Reader, packet []byte) (string, error) {
143
	var labels []string
144
	for {
145
		length, err := r.ReadByte()
146
		if err != nil {
147
			return "", err
148
		}
149
		if length == 0 {
150
			break
151
		}
152
		// pointer: top two bits set (0xC0)
153
		if length&0xC0 == 0xC0 {
154
			low, err := r.ReadByte()
155
			if err != nil {
156
				return "", err
157
			}
158
			offset := int(uint16(length&0x3F)<<8 | uint16(low))
159
			sub := bytes.NewReader(packet[offset:])
160
			name, err := readName(sub, packet)
161
			if err != nil {
162
				return "", err
163
			}
164
			labels = append(labels, name)
165
			break // pointer always ends the name
166
		}
167
		buf := make([]byte, length)
168
		if _, err := r.Read(buf); err != nil {
169
			return "", err
170
		}
171
		labels = append(labels, string(buf))
172
	}
173
	return strings.Join(labels, "."), nil
174
}
175
176
func encodeName(name string) []byte {
177
	var b bytes.Buffer
178
	for label := range strings.SplitSeq(name, ".") {
179
		_ = b.WriteByte(byte(len(label)))
180
		_, _ = b.WriteString(label)
181
	}
182
	_ = b.WriteByte(0)
183
	return b.Bytes()
184
}
185
186
// TODO: wrap the Buffer, to have the len == 512 guard
187
func writeName(w *bytes.Buffer, qname string) error {
188
	for label := range strings.SplitSeq(qname, ".") {
189
		llen := len(label)
190
		if llen > 0x3f {
191
			return fmt.Errorf("single label exceeds 63 characters of length")
192
		}
193
		_ = w.WriteByte(byte(llen))
194
		_, _ = w.Write([]byte(label))
195
	}
196
	_ = w.WriteByte(0)
197
	return nil
198
}