all repos

scratch @ bc17d480554cde516be3fa2adc55455aadfebc1a

⭐ me doing recreational ~~drugs~~ programming

scratch/dns-server/record.go (view raw)

1
package main
2
3
import (
4
	"bytes"
5
	"encoding/binary"
6
	"fmt"
7
	"net"
8
	"strings"
9
)
10
11
// Record holds one DNS resource record in decoded form. ReadRecord fills it
// from the wire; Record.Write encodes it back (A records only).
type Record struct {
	Name string // owner name, labels joined with "."

	Type, Class uint16 // RR TYPE and CLASS codes
	TTL         uint32 // time to live as carried on the wire

	// Data is the RDATA rendering: dotted-quad IPv4 for Type 1 (A),
	// lowercase hex of the raw bytes for any other type.
	Data string
}
18
19
func ReadRecord(r *bytes.Reader, packet []byte) (Record, error) {
20
	name, err := readName(r, packet)
21
	if err != nil {
22
		return Record{}, err
23
	}
24
25
	var rtype, class, rdlen uint16
26
	var ttl uint32
27
	_ = binary.Read(r, binary.BigEndian, &rtype)
28
	_ = binary.Read(r, binary.BigEndian, &class)
29
	_ = binary.Read(r, binary.BigEndian, &ttl)
30
	_ = binary.Read(r, binary.BigEndian, &rdlen)
31
32
	var data string
33
	switch rtype {
34
	case 1: // A
35
		var ip [4]byte
36
		_, _ = r.Read(ip[:])
37
		data = fmt.Sprintf("%d.%d.%d.%d",
38
			ip[0], ip[1], ip[2], ip[3])
39
40
	default:
41
		buf := make([]byte, rdlen)
42
		_, _ = r.Read(buf)
43
		data = fmt.Sprintf("%x", buf)
44
	}
45
46
	return Record{
47
		Name:  name,
48
		Type:  rtype,
49
		Class: class,
50
		TTL:   ttl,
51
		Data:  data,
52
	}, nil
53
}
54
55
func (r Record) Write(b *bytes.Buffer) (int, error) {
56
	start := b.Len()
57
	switch r.Type {
58
	case 1: // A
59
		_ = writeName(b, r.Name)
60
		_ = binary.Write(b, binary.BigEndian, r.Type)
61
		_ = binary.Write(b, binary.BigEndian, r.Class)
62
		_ = binary.Write(b, binary.BigEndian, r.TTL)
63
		_ = binary.Write(b, binary.BigEndian, uint16(4))
64
65
		ip := net.ParseIP(r.Data).To4()
66
		if ip == nil {
67
			return 0, fmt.Errorf("invalid IPv4 address: %s", r.Data)
68
		}
69
70
		_, _ = b.Write(ip)
71
72
	default:
73
		fmt.Printf("Skipping record: %+v\n", r)
74
	}
75
76
	return b.Len() - start, nil
77
}
78
79
// readName decodes a possibly-compressed domain name from r and returns it
// with labels joined by "." (the root name yields ""). packet is the complete
// DNS message, which compression pointers index into.
//
// r is advanced past the name as it appears at r's position: past the
// terminating zero byte, or past the first two-byte compression pointer.
// Bytes reached through pointers are read from separate readers over packet.
//
// Malformed input is reported as an error rather than a panic or hang. This
// covers a pointer outside packet, a pointer cycle, a label longer than the
// remaining data, and a reserved label type (top bits 01 or 10).
func readName(r *bytes.Reader, packet []byte) (string, error) {
	// Bound on compression pointers followed for one name; a packet whose
	// pointers form a cycle would otherwise loop forever.
	const maxPointerHops = 128

	var labels []string
	cur := r // replaced by a reader over packet once a pointer is followed
	hops := 0
	for {
		length, err := cur.ReadByte()
		if err != nil {
			return "", err
		}

		switch {
		case length == 0:
			return strings.Join(labels, "."), nil

		case length&0xC0 == 0xC0: // compression pointer: 14-bit offset into packet
			low, err := cur.ReadByte()
			if err != nil {
				return "", err
			}
			offset := int(length&0x3F)<<8 | int(low)
			if offset >= len(packet) {
				return "", fmt.Errorf("name pointer offset %d outside %d-byte packet", offset, len(packet))
			}
			hops++
			if hops > maxPointerHops {
				return "", fmt.Errorf("name follows more than %d compression pointers (pointer loop?)", maxPointerHops)
			}
			cur = bytes.NewReader(packet[offset:])

		case length&0xC0 != 0:
			return "", fmt.Errorf("unsupported label type 0x%02x", length&0xC0)

		default:
			if int(length) > cur.Len() {
				return "", fmt.Errorf("label of length %d truncated: only %d bytes left", length, cur.Len())
			}
			buf := make([]byte, length)
			cur.Read(buf) // cannot be short: remaining length checked above
			labels = append(labels, string(buf))
		}
	}
}
112
113
// TODO: wrap the Buffer, to have the len == 512 guard
114
func writeName(w *bytes.Buffer, qname string) error {
115
	for label := range strings.SplitSeq(qname, ".") {
116
		llen := len(label)
117
		if llen > 0x3f {
118
			return fmt.Errorf("single label exceeds 63 characters of length")
119
		}
120
		_ = w.WriteByte(byte(llen))
121
		_, _ = w.Write([]byte(label))
122
	}
123
	_ = w.WriteByte(0)
124
	return nil
125
}