package middleware

import (
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/konduktor/konduktor/internal/logging"
)

// ============== ServerHeader Tests ==============

// TestServerHeader checks that ServerHeader sets the Server response header
// to "konduktor/" followed by the version string it was given.
func TestServerHeader(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	wrapped := ServerHeader(handler, "1.0.0")

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rr := httptest.NewRecorder()
	wrapped.ServeHTTP(rr, req)

	if got := rr.Header().Get("Server"); got != "konduktor/1.0.0" {
		t.Errorf("Expected Server header 'konduktor/1.0.0', got %q", got)
	}
}

// ============== AccessLog Tests ==============

// TestAccessLog checks that AccessLog passes the wrapped handler's status
// and body through to the client unchanged.
func TestAccessLog(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("Hello"))
	})

	// Fail loudly if the logger cannot be built; ignoring the error would
	// hand AccessLog an unusable logger and produce a confusing failure.
	logger, err := logging.New(logging.Config{Level: "INFO"})
	if err != nil {
		t.Fatalf("logging.New: %v", err)
	}
	wrapped := AccessLog(handler, logger)

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	rr := httptest.NewRecorder()
	wrapped.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d", rr.Code)
	}
	if got := rr.Body.String(); got != "Hello" {
		t.Errorf("Expected body 'Hello', got %q", got)
	}
}

// TestAccessLog_CapturesStatusCode checks that the status code written by the
// wrapped handler reaches the client for a range of success, redirect,
// client-error and server-error codes.
func TestAccessLog_CapturesStatusCode(t *testing.T) {
	tests := []struct {
		name       string
		statusCode int
	}{
		{"OK", http.StatusOK},
		{"NotFound", http.StatusNotFound},
		{"InternalError", http.StatusInternalServerError},
		{"Redirect", http.StatusMovedPermanently},
	}

	// One logger serves every subtest; it does not depend on the table row.
	logger, err := logging.New(logging.Config{Level: "INFO"})
	if err != nil {
		t.Fatalf("logging.New: %v", err)
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(tt.statusCode)
			})
			wrapped := AccessLog(handler, logger)

			req := httptest.NewRequest(http.MethodGet, "/", nil)
			rr := httptest.NewRecorder()
			wrapped.ServeHTTP(rr, req)

			if rr.Code != tt.statusCode {
				t.Errorf("Expected status %d, got %d", tt.statusCode, rr.Code)
			}
		})
	}
}

// ============== Recovery Tests ==============

// TestRecovery_NoPanic checks that Recovery is transparent when the wrapped
// handler returns normally.
func TestRecovery_NoPanic(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("OK"))
	})
	logger, err := logging.New(logging.Config{Level: "INFO"})
	if err != nil {
		t.Fatalf("logging.New: %v", err)
	}
	wrapped := Recovery(handler, logger)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rr := httptest.NewRecorder()
	wrapped.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d", rr.Code)
	}
	if got := rr.Body.String(); got != "OK" {
		t.Errorf("Expected body 'OK', got %q", got)
	}
}

// TestRecovery_WithPanic checks that a panic in the wrapped handler is
// contained by Recovery and turned into a 500 response whose body mentions
// "Internal Server Error".
func TestRecovery_WithPanic(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		panic("test panic")
	})
	logger, err := logging.New(logging.Config{Level: "ERROR"})
	if err != nil {
		t.Fatalf("logging.New: %v", err)
	}
	wrapped := Recovery(handler, logger)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rr := httptest.NewRecorder()

	// Must not propagate the panic; if it does, the test binary reports it.
	wrapped.ServeHTTP(rr, req)

	if rr.Code != http.StatusInternalServerError {
		t.Errorf("Expected status 500, got %d", rr.Code)
	}
	if body := rr.Body.String(); !strings.Contains(body, "Internal Server Error") {
		t.Errorf("Expected 'Internal Server Error' in body, got %q", body)
	}
}

// ============== responseWriter Tests ==============

// TestResponseWriter_WriteHeader checks that WriteHeader records the status
// code on the wrapper.
func TestResponseWriter_WriteHeader(t *testing.T) {
	rr := httptest.NewRecorder()
	rw := &responseWriter{ResponseWriter: rr, status: http.StatusOK}

	rw.WriteHeader(http.StatusNotFound)

	if rw.status != http.StatusNotFound {
		t.Errorf("Expected status 404, got %d", rw.status)
	}
}

// TestResponseWriter_Write checks that Write reports the byte count and
// accumulates it in the wrapper's size field.
func TestResponseWriter_Write(t *testing.T) {
	rr := httptest.NewRecorder()
	rw := &responseWriter{ResponseWriter: rr, status: http.StatusOK}

	n, err := rw.Write([]byte("Hello World"))
	if err != nil {
		// The byte-count checks below are meaningless after a failed write.
		t.Fatalf("Unexpected error: %v", err)
	}
	if n != 11 {
		t.Errorf("Expected 11 bytes written, got %d", n)
	}
	if rw.size != 11 {
		t.Errorf("Expected size 11, got %d", rw.size)
	}
}

// ============== Middleware Chain Tests ==============

// TestMiddlewareChain checks that Recovery, AccessLog and ServerHeader compose
// without interfering: status, Server header and body all reach the client.
func TestMiddlewareChain(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("OK"))
	})
	logger, err := logging.New(logging.Config{Level: "INFO"})
	if err != nil {
		t.Fatalf("logging.New: %v", err)
	}

	// Apply middleware chain (outermost first: Recovery -> AccessLog -> ServerHeader).
	wrapped := Recovery(AccessLog(ServerHeader(handler, "1.0.0"), logger), logger)

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	rr := httptest.NewRecorder()
	wrapped.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d", rr.Code)
	}
	if got := rr.Header().Get("Server"); got != "konduktor/1.0.0" {
		t.Errorf("Expected Server header 'konduktor/1.0.0', got %q", got)
	}
	if got := rr.Body.String(); got != "OK" {
		t.Errorf("Expected body 'OK', got %q", got)
	}
}

// ============== Benchmarks ==============

// BenchmarkServerHeader measures one request through ServerHeader.
func BenchmarkServerHeader(b *testing.B) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	wrapped := ServerHeader(handler, "1.0.0")
	req := httptest.NewRequest(http.MethodGet, "/", nil)

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		rr := httptest.NewRecorder()
		wrapped.ServeHTTP(rr, req)
	}
}

// BenchmarkAccessLog measures one request through AccessLog.
func BenchmarkAccessLog(b *testing.B) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	// ERROR level is used to minimize logging overhead in the timed loop.
	logger, err := logging.New(logging.Config{Level: "ERROR"})
	if err != nil {
		b.Fatalf("logging.New: %v", err)
	}
	wrapped := AccessLog(handler, logger)
	req := httptest.NewRequest(http.MethodGet, "/", nil)

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		rr := httptest.NewRecorder()
		wrapped.ServeHTTP(rr, req)
	}
}

// BenchmarkRecovery measures the non-panicking path through Recovery.
func BenchmarkRecovery(b *testing.B) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	logger, err := logging.New(logging.Config{Level: "ERROR"})
	if err != nil {
		b.Fatalf("logging.New: %v", err)
	}
	wrapped := Recovery(handler, logger)
	req := httptest.NewRequest(http.MethodGet, "/", nil)

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		rr := httptest.NewRecorder()
		wrapped.ServeHTTP(rr, req)
	}
}