all repos

scratch @ 7075724574a144fa80ec2e2929b0c267a7fe1204

⭐ me doing recreational ~~drugs~~ programming

scratch/dns-server/record.go (view raw)

1
package main
2
3
import (
4
	"bytes"
5
	"encoding/binary"
6
	"fmt"
7
	"net"
8
	"strings"
9
)
10
11
// QueryType is the 16-bit TYPE code carried by DNS questions and resource
// records (RFC 1035 §3.2.2; AAAA is defined by RFC 3596).
type QueryType uint16

// Well-known resource record TYPE values.
const (
	AType     QueryType = 1  // IPv4 host address
	NSType    QueryType = 2  // authoritative name server
	CNAMEType QueryType = 5  // canonical name for an alias
	MXType    QueryType = 15 // mail exchange
	AAAAType  QueryType = 28 // IPv6 host address
)
20
21
type Record struct {
22
	Name  string
23
	Type  QueryType
24
	Class uint16
25
	TTL   uint32
26
	Data  string
27
}
28
29
func ReadRecord(r *bytes.Reader, packet []byte) (Record, error) {
30
	name, err := readName(r, packet)
31
	if err != nil {
32
		return Record{}, err
33
	}
34
35
	var rtype QueryType
36
	var class, rdlen uint16
37
	var ttl uint32
38
	_ = binary.Read(r, binary.BigEndian, &rtype)
39
	_ = binary.Read(r, binary.BigEndian, &class)
40
	_ = binary.Read(r, binary.BigEndian, &ttl)
41
	_ = binary.Read(r, binary.BigEndian, &rdlen)
42
43
	var data string
44
	switch rtype {
45
	case AType:
46
		var ip [4]byte
47
		_, _ = r.Read(ip[:])
48
		data = fmt.Sprintf("%d.%d.%d.%d",
49
			ip[0], ip[1], ip[2], ip[3])
50
51
	default:
52
		buf := make([]byte, rdlen)
53
		_, _ = r.Read(buf)
54
		data = fmt.Sprintf("%x", buf)
55
	}
56
57
	return Record{
58
		Name:  name,
59
		Type:  rtype,
60
		Class: class,
61
		TTL:   ttl,
62
		Data:  data,
63
	}, nil
64
}
65
66
func (r Record) Write(b *bytes.Buffer) (int, error) {
67
	start := b.Len()
68
	switch r.Type {
69
	case AType:
70
		_ = writeName(b, r.Name)
71
		_ = binary.Write(b, binary.BigEndian, r.Type)
72
		_ = binary.Write(b, binary.BigEndian, r.Class)
73
		_ = binary.Write(b, binary.BigEndian, r.TTL)
74
		_ = binary.Write(b, binary.BigEndian, uint16(4))
75
76
		ip := net.ParseIP(r.Data).To4()
77
		if ip == nil {
78
			return 0, fmt.Errorf("invalid IPv4 address: %s", r.Data)
79
		}
80
81
		_, _ = b.Write(ip)
82
83
	default:
84
		fmt.Printf("Skipping record: %+v\n", r)
85
	}
86
87
	return b.Len() - start, nil
88
}
89
90
// readName decodes a possibly compressed domain name that starts at r's
// current position. It returns the name in dotted form without a trailing dot;
// the root name decodes to "". packet must be the whole DNS message, because
// compression pointers (RFC 1035 §4.1.4) are offsets from its first byte.
//
// r is advanced only over the bytes of the name at its original position. If
// a pointer is met, r stops right after the two pointer bytes, and the rest of
// the name is read from packet.
//
// Malformed input returns an error rather than panicking or looping. This
// covers truncated labels, pointers outside packet, pointer cycles, and the
// reserved 0x40/0x80 label types.
func readName(r *bytes.Reader, packet []byte) (string, error) {
	// Cap on followed compression pointers. Pointers may legally chain, but a
	// crafted packet can make them form a cycle that would otherwise never end.
	const maxPointerHops = 64

	var labels []string
	cur, hops := r, 0 // cur switches to a reader over packet after each pointer
	for {
		length, err := cur.ReadByte()
		if err != nil {
			return "", fmt.Errorf("reading label length: %w", err)
		}

		switch length & 0xC0 {
		case 0x00: // ordinary label; a zero length is the root and ends the name
			if length == 0 {
				return strings.Join(labels, "."), nil
			}
			if int(length) > cur.Len() {
				return "", fmt.Errorf("label of %d bytes runs past end of data", length)
			}
			buf := make([]byte, length)
			_, _ = cur.Read(buf) // cannot come up short: length was checked against cur.Len()
			labels = append(labels, string(buf))

		case 0xC0: // compression pointer: 14-bit offset into packet
			low, err := cur.ReadByte()
			if err != nil {
				return "", fmt.Errorf("reading compression pointer: %w", err)
			}
			offset := int(length&0x3F)<<8 | int(low)
			if offset >= len(packet) {
				return "", fmt.Errorf("compression pointer %d outside %d-byte packet", offset, len(packet))
			}
			hops++
			if hops > maxPointerHops {
				return "", fmt.Errorf("more than %d compression pointers (pointer loop?)", maxPointerHops)
			}
			// A pointer always ends the name at its original position. The
			// remaining labels come from packet, and r is not advanced further.
			cur = bytes.NewReader(packet[offset:])

		default: // 0x40 and 0x80 prefixes are reserved / extended label types
			return "", fmt.Errorf("unsupported label type 0x%02x", length&0xC0)
		}
	}
}
123
124
// writeName appends qname to w in DNS wire format (RFC 1035 §3.1). Each
// dot-separated label is written as a length byte followed by its bytes, and
// the name ends with a zero byte (the root label).
//
// A single trailing dot is accepted, so "example.com." and "example.com"
// encode identically; "" and "." both encode the root name as a lone zero
// byte. An error is returned for an empty inner label ("a..b"), for a label
// longer than 63 bytes, and for a name longer than 255 bytes on the wire. In
// every error case nothing is written to w.
//
// TODO: wrap the Buffer, to have the len == 512 guard
func writeName(w *bytes.Buffer, qname string) error {
	name := strings.TrimSuffix(qname, ".")
	var labels []string
	if name != "" {
		labels = strings.Split(name, ".")
	}

	// Validate everything before writing, so a bad name never leaves partial
	// labels in w.
	wireLen := 1 // the terminating root byte
	for _, label := range labels {
		switch {
		case len(label) == 0:
			return fmt.Errorf("empty label in name %q", qname)
		case len(label) > 0x3f:
			return fmt.Errorf("single label exceeds 63 characters of length")
		}
		wireLen += 1 + len(label)
	}
	if wireLen > 255 {
		return fmt.Errorf("name %q is %d bytes in wire format, limit is 255", qname, wireLen)
	}

	// bytes.Buffer writes always return a nil error.
	for _, label := range labels {
		_ = w.WriteByte(byte(len(label)))
		_, _ = w.WriteString(label)
	}
	_ = w.WriteByte(0)
	return nil
}