// konduktor/go/internal/middleware/middleware_test.go
// Илья Глазунов 8f5b9a5cd1 go implementation
// 2025-12-11 16:52:13 +03:00
//
// 245 lines
// 5.9 KiB
// Go

package middleware
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/konduktor/konduktor/internal/logging"
)
// ============== ServerHeader Tests ==============
// TestServerHeader checks that ServerHeader adds a "Server" header of the form
// "konduktor/<version>" to responses produced by the wrapped handler.
func TestServerHeader(t *testing.T) {
	const want = "konduktor/1.0.0"

	okOnly := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}
	rec := httptest.NewRecorder()
	ServerHeader(http.HandlerFunc(okOnly), "1.0.0").
		ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	if got := rec.Header().Get("Server"); got != want {
		t.Errorf("Expected Server header 'konduktor/1.0.0', got '%s'", got)
	}
}
// ============== AccessLog Tests ==============
// TestAccessLog checks that wrapping a handler in AccessLog leaves the
// response seen by the client intact: both the status code and the body the
// handler wrote must come through unchanged.
func TestAccessLog(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("Hello"))
	})

	// Fail fast on a logger construction error instead of letting a nil
	// logger blow up somewhere inside the middleware.
	logger, err := logging.New(logging.Config{Level: "INFO"})
	if err != nil {
		t.Fatalf("Failed to create logger: %v", err)
	}
	wrapped := AccessLog(handler, logger)

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	rr := httptest.NewRecorder()
	wrapped.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d", rr.Code)
	}
	if rr.Body.String() != "Hello" {
		t.Errorf("Expected body 'Hello', got '%s'", rr.Body.String())
	}
}
// TestAccessLog_CapturesStatusCode checks that AccessLog forwards whatever
// status code the wrapped handler writes, including redirect and error codes,
// to the client.
//
// NOTE(review): this only observes what reaches the recorder; it cannot see
// the status value AccessLog records internally for its log line.
func TestAccessLog_CapturesStatusCode(t *testing.T) {
	tests := []struct {
		name       string
		statusCode int
	}{
		{"OK", http.StatusOK},
		{"NotFound", http.StatusNotFound},
		{"InternalError", http.StatusInternalServerError},
		{"Redirect", http.StatusMovedPermanently},
	}

	// The logger configuration is identical for every case, so build it once.
	// Subtests below run sequentially (no t.Parallel), so sharing it is safe.
	logger, err := logging.New(logging.Config{Level: "INFO"})
	if err != nil {
		t.Fatalf("Failed to create logger: %v", err)
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(tt.statusCode)
			})
			wrapped := AccessLog(handler, logger)

			req := httptest.NewRequest(http.MethodGet, "/", nil)
			rr := httptest.NewRecorder()
			wrapped.ServeHTTP(rr, req)

			if rr.Code != tt.statusCode {
				t.Errorf("Expected status %d, got %d", tt.statusCode, rr.Code)
			}
		})
	}
}
// ============== Recovery Tests ==============
// TestRecovery_NoPanic checks that Recovery is transparent when the wrapped
// handler returns normally: status and body reach the client unchanged.
func TestRecovery_NoPanic(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("OK"))
	})

	logger, err := logging.New(logging.Config{Level: "INFO"})
	if err != nil {
		t.Fatalf("Failed to create logger: %v", err)
	}
	wrapped := Recovery(handler, logger)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rr := httptest.NewRecorder()
	wrapped.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d", rr.Code)
	}
	if rr.Body.String() != "OK" {
		t.Errorf("Expected body 'OK', got '%s'", rr.Body.String())
	}
}
// TestRecovery_WithPanic checks that Recovery converts a panic in the wrapped
// handler into a 500 response whose body mentions "Internal Server Error".
// If Recovery failed to recover, the panic would propagate out of ServeHTTP
// and abort this test.
func TestRecovery_WithPanic(t *testing.T) {
	handler := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
		panic("test panic")
	})

	logger, err := logging.New(logging.Config{Level: "ERROR"})
	if err != nil {
		t.Fatalf("Failed to create logger: %v", err)
	}
	wrapped := Recovery(handler, logger)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rr := httptest.NewRecorder()
	// Must not panic.
	wrapped.ServeHTTP(rr, req)

	if rr.Code != http.StatusInternalServerError {
		t.Errorf("Expected status 500, got %d", rr.Code)
	}
	if !strings.Contains(rr.Body.String(), "Internal Server Error") {
		t.Errorf("Expected 'Internal Server Error' in body, got '%s'", rr.Body.String())
	}
}
// ============== responseWriter Tests ==============
// TestResponseWriter_WriteHeader checks that responseWriter both records the
// status passed to WriteHeader and forwards it to the wrapped ResponseWriter.
func TestResponseWriter_WriteHeader(t *testing.T) {
	rr := httptest.NewRecorder()
	rw := &responseWriter{ResponseWriter: rr, status: http.StatusOK}

	rw.WriteHeader(http.StatusNotFound)

	if rw.status != http.StatusNotFound {
		t.Errorf("Expected status 404, got %d", rw.status)
	}
	// Recording alone is not enough: a wrapper that forgot to delegate would
	// still pass the check above, while the recorder keeps its default 200.
	if rr.Code != http.StatusNotFound {
		t.Errorf("Expected underlying writer status 404, got %d", rr.Code)
	}
}
// TestResponseWriter_Write checks that responseWriter.Write reports the number
// of bytes written, adds them to its size counter, and forwards the bytes to
// the wrapped ResponseWriter.
func TestResponseWriter_Write(t *testing.T) {
	rr := httptest.NewRecorder()
	rw := &responseWriter{ResponseWriter: rr, status: http.StatusOK}

	n, err := rw.Write([]byte("Hello World"))
	if err != nil {
		// The remaining checks are meaningless if the write failed.
		t.Fatalf("Unexpected error: %v", err)
	}

	// Untyped constant 11 is used deliberately so the comparison compiles
	// regardless of the integer type of rw.size.
	if n != 11 {
		t.Errorf("Expected 11 bytes written, got %d", n)
	}
	if rw.size != 11 {
		t.Errorf("Expected size 11, got %d", rw.size)
	}
	if got := rr.Body.String(); got != "Hello World" {
		t.Errorf("Expected underlying body 'Hello World', got '%s'", got)
	}
}
// ============== Middleware Chain Tests ==============
// TestMiddlewareChain checks that Recovery, AccessLog and ServerHeader compose:
// the response keeps the handler's status and body and carries the Server
// header added by ServerHeader.
func TestMiddlewareChain(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("OK"))
	})

	logger, err := logging.New(logging.Config{Level: "INFO"})
	if err != nil {
		t.Fatalf("Failed to create logger: %v", err)
	}

	// Apply middleware chain: Recovery is outermost, ServerHeader innermost.
	wrapped := Recovery(AccessLog(ServerHeader(handler, "1.0.0"), logger), logger)

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	rr := httptest.NewRecorder()
	wrapped.ServeHTTP(rr, req)

	// Check all middleware worked.
	if rr.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d", rr.Code)
	}
	if got := rr.Header().Get("Server"); got != "konduktor/1.0.0" {
		t.Errorf("Expected Server header 'konduktor/1.0.0', got '%s'", got)
	}
	if rr.Body.String() != "OK" {
		t.Errorf("Expected body 'OK', got '%s'", rr.Body.String())
	}
}
// ============== Benchmarks ==============
// BenchmarkServerHeader measures the per-request cost of ServerHeader around a
// handler that only writes a 200 status. Each iteration serves into a brand-new
// recorder, and the shared request is built once outside the timed region.
func BenchmarkServerHeader(b *testing.B) {
	wrapped := ServerHeader(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}), "1.0.0")
	req := httptest.NewRequest(http.MethodGet, "/", nil)

	b.ResetTimer()
	for remaining := b.N; remaining > 0; remaining-- {
		wrapped.ServeHTTP(httptest.NewRecorder(), req)
	}
}
// BenchmarkAccessLog measures the per-request cost of AccessLog around a
// handler that only writes a 200 status. The logger is configured at ERROR
// level to minimize logging overhead in the measurement.
func BenchmarkAccessLog(b *testing.B) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	logger, err := logging.New(logging.Config{Level: "ERROR"}) // Minimize logging overhead
	if err != nil {
		b.Fatalf("Failed to create logger: %v", err)
	}
	wrapped := AccessLog(handler, logger)
	req := httptest.NewRequest(http.MethodGet, "/", nil)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		rr := httptest.NewRecorder()
		wrapped.ServeHTTP(rr, req)
	}
}
// BenchmarkRecovery measures the overhead Recovery adds on the non-panicking
// path, using a handler that only writes a 200 status.
func BenchmarkRecovery(b *testing.B) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	logger, err := logging.New(logging.Config{Level: "ERROR"})
	if err != nil {
		b.Fatalf("Failed to create logger: %v", err)
	}
	wrapped := Recovery(handler, logger)
	req := httptest.NewRequest(http.MethodGet, "/", nil)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		rr := httptest.NewRecorder()
		wrapped.ServeHTTP(rr, req)
	}
}