3 files changed,
103 insertions(+),
0 deletions(-)
Author:
Oleksandr Smirnov
olexsmir@gmail.com
Committed at:
2026-02-18 22:34:24 +0200
Change ID:
lmvrxksvunnoosrzvqlnwoulnqlqwuzz
Parent:
6c0a72e
jump to
| M | readme |
| A | reqid/reqid.go |
| A | reqid/reqid_test.go |
M
readme
@@ -5,5 +5,6 @@ Set of packages that I often copy-paste into my new projects.
- is - simple assertion library for tests - envy - type-safe os.Getenv with default values +- reqid - http middleware that add request id header and sets it in context This repo is MIT Licensed
A
reqid/reqid.go
@@ -0,0 +1,63 @@
+package reqid + +import ( + "context" + "crypto/rand" + "encoding/hex" + "net/http" +) + +type requestIDKey string + +const ( + RequestID requestIDKey = "request-id" + + Header = "X-Request-ID" +) + +// Middleware http middleware that sets random generated request id to each +// request it get fed. +func Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rid := r.Header.Get(Header) + if rid == "" { + rid = generateRequestID() + r.Header.Set(Header, rid) + } + + ctx := SetContext(r.Context(), rid) + r = r.WithContext(ctx) + + w.Header().Add(Header, rid) + next.ServeHTTP(w, r) + }) +} + +// Get returns the request ID of http request (looks up from headers) +func Get(r *http.Request) string { + return r.Header.Get(Header) +} + +// GetContext returns the request ID from context. +func GetContext(ctx context.Context) string { + rid, ok := ctx.Value(RequestID).(string) + if !ok { + return "" + } + return rid +} + +// SetContext gets a parent context and returns a child context with the set +// provided request ID +func SetContext(ctx context.Context, reqID string) context.Context { + return context.WithValue(ctx, RequestID, reqID) +} + +func generateRequestID() string { + b := make([]byte, 13) + _, err := rand.Read(b) + if err != nil { + return "unknown" + } + return hex.EncodeToString(b) +}
A
reqid/reqid_test.go
@@ -0,0 +1,39 @@
+package reqid + +import ( + "net/http" + "net/http/httptest" + "testing" + + "olexsmir.xyz/x/is" +) + +func TestMiddleware(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("GET /", testHandler) + hand := Middleware(mux) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/", nil) + hand.ServeHTTP(w, req) + + is.Equal(t, http.StatusOK, w.Code) + is.NotEqual(t, w.Header().Get(Header), "") +} + +func BenchmarkMiddleware(b *testing.B) { + mux := http.NewServeMux() + mux.HandleFunc("GET /", testHandler) + hand := Middleware(mux) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/", nil) + + for b.Loop() { + hand.ServeHTTP(w, req) + } +} + +func testHandler(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) +}