all repos

mugit @ 3d4f6c6

🐮 git server that your cow will love

mugit/internal/ssh/ssh.go (view raw)

Oleksandr Smirnov Oleksandr Smirnov
olexsmir@gmail.com
run errcheck, 28 days ago
1
package ssh
2
3
import (
4
	"context"
5
	"errors"
6
	"fmt"
7
	"io"
8
	"strings"
9
10
	"olexsmir.xyz/mugit/internal/config"
11
	"olexsmir.xyz/mugit/internal/git"
12
13
	gossh "golang.org/x/crypto/ssh"
14
)
15
16
type Shell struct {
17
	cfg *config.Config
18
19
	keys []gossh.PublicKey
20
}
21
22
func NewShell(cfg *config.Config) (*Shell, error) {
23
	parsedKeys := make([]gossh.PublicKey, len(cfg.SSH.Keys))
24
	for i, key := range cfg.SSH.Keys {
25
		pkey, _, _, _, err := gossh.ParseAuthorizedKey([]byte(key))
26
		if err != nil {
27
			return nil, err
28
		}
29
		parsedKeys[i] = pkey
30
	}
31
32
	return &Shell{
33
		cfg:  cfg,
34
		keys: parsedKeys,
35
	}, nil
36
}
37
38
func (s *Shell) HandleCommand(ctx context.Context, cmd string, stdin io.Reader, stdout, stderr io.Writer) error {
39
	// ssh -T `mugit@host`
40
	if strings.TrimSpace(cmd) == "" {
41
		_, err := fmt.Fprintln(stderr, s.cfg.Meta.Modt)
42
		return err
43
	}
44
45
	gitCmd, repoName, err := s.parseCommand(cmd)
46
	if err != nil {
47
		return s.replyWithGitError(stderr, "access denied: invalid command", err)
48
	}
49
50
	repoPath, err := git.ResolvePath(s.cfg.Repo.Dir, git.ResolveName(repoName))
51
	if err != nil {
52
		return s.replyWithGitError(stderr, "access denied", err)
53
	}
54
55
	repo, err := git.Open(repoPath, "")
56
	if err != nil {
57
		if !errors.Is(err, git.ErrRepoNotFound) || gitCmd != "git-receive-pack" {
58
			return s.replyWithGitError(stderr, "repository not found", err)
59
		}
60
61
		// SSH Git clients display informational messages from stderr; stdout must remain protocol-only for git-receive-pack.
62
		if ierr := s.replyWithGitInfo(stderr, "auto-initializing "+repoName); ierr != nil {
63
			return ierr
64
		}
65
66
		if ierr := git.Init(repoPath); ierr != nil {
67
			return s.replyWithGitError(stderr, "failed to init repo", ierr)
68
		}
69
70
		repo, err = git.Open(repoPath, "")
71
		if err != nil {
72
			return s.replyWithGitError(stderr, "failed to open initialized repo", err)
73
		}
74
	}
75
76
	if s.cfg.Meta.Modt != "" {
77
		_, _ = fmt.Fprintln(stderr, s.cfg.Meta.Modt)
78
	}
79
80
	switch gitCmd {
81
	case "git-upload-pack":
82
		err = repo.UploadPack(ctx, false, "", stdin, stdout)
83
	case "git-upload-archive":
84
		err = repo.UploadArchive(ctx, stdin, stdout)
85
	case "git-receive-pack":
86
		err = repo.ReceivePack(ctx, stdin, stdout, stderr)
87
	default:
88
		msg := "access denied: invalid git command"
89
		return s.replyWithGitError(stderr, msg, errors.New(msg))
90
	}
91
92
	if err != nil {
93
		return err
94
	}
95
96
	return nil
97
}
98
99
func (s *Shell) AuthorizedKeys(executablePath string) string {
100
	var out strings.Builder
101
	for _, key := range s.cfg.SSH.Keys {
102
		fmt.Fprintf(&out, `command="%s shell",no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty %s`+"\n",
103
			executablePath, key)
104
	}
105
	return out.String()
106
}
107
108
// validCommands is the allow-list of git services a client may invoke over
// SSH; a command is permitted exactly when its entry maps to true.
var validCommands = map[string]bool{
	"git-receive-pack":   true, // push
	"git-upload-archive": true, // git archive --remote
	"git-upload-pack":    true, // clone / fetch
}
113
114
// parseCommand splits an SSH exec request such as `git-upload-pack 'repo.git'`
// into the git command and the repository name.
//
// cmd must consist of exactly two whitespace-separated fields, and the first
// must be present in validCommands. Leading and trailing single/double quote
// characters are trimmed from the second field, and the resulting repository
// name must be non-empty. On error both returned strings are empty.
func (s *Shell) parseCommand(cmd string) (gitCmd, repoName string, err error) {
	fields := strings.Fields(cmd)
	if len(fields) != 2 {
		return "", "", fmt.Errorf("invalid command: expected 'git-cmd repo', got %q", cmd)
	}

	if !validCommands[fields[0]] {
		// Include the rejected command so the returned cause is useful in logs.
		return "", "", fmt.Errorf("invalid command: disallowed command %q", fields[0])
	}

	name := strings.Trim(fields[1], `'"`)
	if name == "" {
		return "", "", errors.New("invalid command: empty repository name")
	}

	return fields[0], name, nil
}
132
133
// replyWithGitError writes the line "error: <msg>" to stderr and returns
// cause, letting callers report to the client and propagate in one statement.
// If the write itself fails, the write error is returned in place of cause.
func (s *Shell) replyWithGitError(stderr io.Writer, msg string, cause error) error {
	_, writeErr := fmt.Fprintf(stderr, "error: %s\n", msg)
	if writeErr == nil {
		return cause
	}
	return writeErr
}
140
141
// replyWithGitInfo writes the line "info: <msg>" to msgOut and reports
// whether that write succeeded.
func (s *Shell) replyWithGitInfo(msgOut io.Writer, msg string) error {
	if _, err := fmt.Fprintf(msgOut, "info: %s\n", msg); err != nil {
		return err
	}
	return nil
}