all repos

scratch @ 17b595b13a6b7913b0e606ccb699e368f104968c

⭐ me doing recreational ~~drugs~~ programming

scratch/dns-server/record.go (view raw)

1
package main
2
3
import (
4
	"bytes"
5
	"encoding/binary"
6
	"fmt"
7
	"net"
8
	"strings"
9
)
10
11
// QueryType is the 16-bit numeric code stored in the TYPE field of a DNS
// resource record.
type QueryType uint16

// Record types that ReadRecord and Record.Write know how to decode and
// encode. The values are the numeric codes used on the wire.
const (
	AType     QueryType = 1  // IPv4 address, 4-byte payload
	NSType    QueryType = 2  // name server: payload is a domain name
	CNAMEType QueryType = 5  // canonical name: payload is a domain name
	MXType    QueryType = 15 // mail exchange: 16-bit priority followed by a domain name
	AAAAType  QueryType = 28 // IPv6 address, 16-byte payload
)
20
21
type Record struct {
22
	Name     string
23
	Type     QueryType
24
	Class    uint16
25
	TTL      uint32
26
	Data     string
27
	Priority uint16
28
}
29
30
func ReadRecord(r *bytes.Reader, packet []byte) (Record, error) {
31
	name, err := readName(r, packet)
32
	if err != nil {
33
		return Record{}, err
34
	}
35
36
	var rtype QueryType
37
	var class, rdlen uint16
38
	var ttl uint32
39
	_ = binary.Read(r, binary.BigEndian, &rtype)
40
	_ = binary.Read(r, binary.BigEndian, &class)
41
	_ = binary.Read(r, binary.BigEndian, &ttl)
42
	_ = binary.Read(r, binary.BigEndian, &rdlen)
43
44
	var data string
45
	var priority uint16
46
	switch rtype {
47
	case AType:
48
		var ip [4]byte
49
		_, _ = r.Read(ip[:])
50
		data = net.IP(ip[:]).String()
51
52
	case AAAAType:
53
		var ip [16]byte
54
		_, _ = r.Read(ip[:])
55
		data = net.IP(ip[:]).String()
56
57
	case NSType, CNAMEType:
58
		data, _ = readName(r, packet)
59
60
	case MXType:
61
		var pri uint16
62
		_ = binary.Read(r, binary.BigEndian, &pri)
63
		host, _ := readName(r, packet)
64
		priority = pri
65
		data = host
66
67
	default:
68
		buf := make([]byte, rdlen)
69
		_, _ = r.Read(buf)
70
		data = fmt.Sprintf("%x", buf)
71
	}
72
73
	return Record{
74
		Name:     name,
75
		Type:     rtype,
76
		Class:    class,
77
		TTL:      ttl,
78
		Data:     data,
79
		Priority: priority,
80
	}, nil
81
}
82
83
func (r Record) Write(b *bytes.Buffer) (int, error) {
84
	start := b.Len()
85
	switch r.Type {
86
	case AType:
87
		_ = r.writePremable(b)
88
		_ = binary.Write(b, binary.BigEndian, uint16(4))
89
90
		ip := net.ParseIP(r.Data).To4()
91
		if ip == nil {
92
			return 0, fmt.Errorf("invalid IPv4 address: %s", r.Data)
93
		}
94
95
		_, _ = b.Write(ip)
96
97
	case AAAAType:
98
		_ = r.writePremable(b)
99
		_ = binary.Write(b, binary.BigEndian, uint16(16))
100
101
		ip := net.ParseIP(r.Data).To16()
102
		if ip == nil {
103
			return 0, fmt.Errorf("invalid IPv6 address: %s", r.Data)
104
		}
105
106
		_, _ = b.Write(ip)
107
108
	case NSType, CNAMEType:
109
		_ = r.writePremable(b)
110
111
		encoded := encodeName(r.Data)
112
		_ = binary.Write(b, binary.BigEndian, uint16(len(encoded)))
113
		_, _ = b.Write(encoded)
114
115
	case MXType:
116
		_ = r.writePremable(b)
117
118
		encoded := encodeName(r.Data)
119
		_ = binary.Write(b, binary.BigEndian, uint16(2+len(encoded)))
120
		_ = binary.Write(b, binary.BigEndian, uint16(r.Priority))
121
		_, _ = b.Write(encoded)
122
123
	default:
124
		fmt.Printf("Skipping record: %+v\n", r)
125
	}
126
127
	return b.Len() - start, nil
128
}
129
130
// writePremable appends the part of the record that precedes the payload
// length: the encoded owner name followed by the type, class and TTL in
// big-endian order.
//
// It returns an error when the owner name cannot be encoded; in that case the
// fixed-size fields are not appended.
func (r Record) writePremable(b *bytes.Buffer) error {
	if err := writeName(b, r.Name); err != nil {
		return fmt.Errorf("writing record name %q: %w", r.Name, err)
	}

	var fixed [8]byte
	binary.BigEndian.PutUint16(fixed[0:2], uint16(r.Type))
	binary.BigEndian.PutUint16(fixed[2:4], r.Class)
	binary.BigEndian.PutUint32(fixed[4:8], r.TTL)
	// bytes.Buffer.Write always returns a nil error.
	_, _ = b.Write(fixed[:])
	return nil
}
138
139
// readName decodes a domain name starting at r's current position and returns
// its labels joined with ".". The root name decodes to the empty string.
//
// packet must be the complete message: a compression pointer (a length byte
// with its two top bits set) is resolved as an offset into packet. Once a
// pointer has been followed r is no longer advanced, so on success r sits
// right after the terminating zero byte or after the first pointer, whichever
// comes first.
//
// An error is returned when the input ends prematurely, when a pointer refers
// to an offset outside packet, when too many pointers are chained (which is
// how pointer loops are detected), or when a length byte uses one of the
// reserved 0x40/0x80 label types.
func readName(r *bytes.Reader, packet []byte) (string, error) {
	// Upper bound on followed pointers; guards against pointer loops in
	// hostile packets while leaving room for any sensible message.
	const maxPointers = 126

	var labels []string
	cur := r
	pointers := 0
	for {
		length, err := cur.ReadByte()
		if err != nil {
			return "", err
		}
		if length == 0 {
			break
		}

		switch length & 0xC0 {
		case 0x00:
			// Ordinary label: length bytes of text follow.
			if cur.Len() < int(length) {
				return "", fmt.Errorf("label truncated: need %d bytes, have %d", length, cur.Len())
			}
			buf := make([]byte, length)
			// Cannot come up short: the remaining length was checked above.
			_, _ = cur.Read(buf)
			labels = append(labels, string(buf))

		case 0xC0:
			// Compression pointer: 14-bit offset from the start of packet.
			low, err := cur.ReadByte()
			if err != nil {
				return "", err
			}
			offset := int(length&0x3F)<<8 | int(low)
			if offset >= len(packet) {
				return "", fmt.Errorf("compression pointer offset %d outside of %d-byte packet", offset, len(packet))
			}
			pointers++
			if pointers > maxPointers {
				return "", fmt.Errorf("too many compression pointers (more than %d)", maxPointers)
			}
			// Continue decoding at the target; r itself stays where it is.
			cur = bytes.NewReader(packet[offset:])

		default:
			return "", fmt.Errorf("unsupported label type %#02x", length&0xC0)
		}
	}
	return strings.Join(labels, "."), nil
}
172
173
// encodeName returns the wire form of a dot-separated domain name: every
// label prefixed with its length byte, followed by a terminating zero byte.
//
// Empty labels are skipped, so the root name ("" or ".") encodes to a single
// zero byte and a trailing dot ("example.com.") does not add a second
// terminator.
//
// No length validation is performed: the length prefix is byte(len(label)),
// so callers must make sure no label is longer than 63 bytes.
func encodeName(name string) []byte {
	out := make([]byte, 0, len(name)+2)
	for name != "" {
		label, rest, _ := strings.Cut(name, ".")
		if label != "" {
			out = append(out, byte(len(label)))
			out = append(out, label...)
		}
		name = rest
	}
	return append(out, 0)
}
182
183
// writeName appends the wire form of the dot-separated name qname to w: every
// label prefixed with its length byte, followed by a terminating zero byte.
//
// A single trailing dot is accepted, and the root name ("" or ".") is written
// as one zero byte. The whole name is validated before anything is written,
// so w is left untouched when an error is returned. Errors are reported for
// an empty label inside the name, a label longer than 63 bytes, and a name
// whose encoded form exceeds 255 bytes.
//
// TODO: wrap the Buffer, to have the len == 512 guard
func writeName(w *bytes.Buffer, qname string) error {
	name := strings.TrimSuffix(qname, ".")
	if name == "" {
		// Root name: just the terminator. WriteByte on a bytes.Buffer
		// always returns a nil error.
		_ = w.WriteByte(0)
		return nil
	}

	labels := strings.Split(name, ".")
	total := 1 // terminating zero byte
	for _, label := range labels {
		if label == "" {
			return fmt.Errorf("empty label in name %q", qname)
		}
		if len(label) > 0x3f {
			return fmt.Errorf("single label exceeds 63 characters of length")
		}
		total += 1 + len(label)
	}
	if total > 255 {
		return fmt.Errorf("name exceeds 255 octets of length")
	}

	for _, label := range labels {
		_ = w.WriteByte(byte(len(label)))
		_, _ = w.WriteString(label)
	}
	_ = w.WriteByte(0)
	return nil
}