forked from aegis/pyserveX
245 lines
5.9 KiB
Go
245 lines
5.9 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/konduktor/konduktor/internal/logging"
|
|
)
|
|
|
|
// ============== ServerHeader Tests ==============
|
|
|
|
func TestServerHeader(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
wrapped := ServerHeader(handler, "1.0.0")
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
wrapped.ServeHTTP(rr, req)
|
|
|
|
serverHeader := rr.Header().Get("Server")
|
|
if serverHeader != "konduktor/1.0.0" {
|
|
t.Errorf("Expected Server header 'konduktor/1.0.0', got '%s'", serverHeader)
|
|
}
|
|
}
|
|
|
|
// ============== AccessLog Tests ==============
|
|
|
|
func TestAccessLog(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("Hello"))
|
|
})
|
|
|
|
logger, _ := logging.New(logging.Config{Level: "INFO"})
|
|
wrapped := AccessLog(handler, logger)
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
wrapped.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestAccessLog_CapturesStatusCode(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
statusCode int
|
|
}{
|
|
{"OK", http.StatusOK},
|
|
{"NotFound", http.StatusNotFound},
|
|
{"InternalError", http.StatusInternalServerError},
|
|
{"Redirect", http.StatusMovedPermanently},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(tt.statusCode)
|
|
})
|
|
|
|
logger, _ := logging.New(logging.Config{Level: "INFO"})
|
|
wrapped := AccessLog(handler, logger)
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
wrapped.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != tt.statusCode {
|
|
t.Errorf("Expected status %d, got %d", tt.statusCode, rr.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// ============== Recovery Tests ==============
|
|
|
|
func TestRecovery_NoPanic(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("OK"))
|
|
})
|
|
|
|
logger, _ := logging.New(logging.Config{Level: "INFO"})
|
|
wrapped := Recovery(handler, logger)
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
wrapped.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
}
|
|
|
|
if rr.Body.String() != "OK" {
|
|
t.Errorf("Expected body 'OK', got '%s'", rr.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestRecovery_WithPanic(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
panic("test panic")
|
|
})
|
|
|
|
logger, _ := logging.New(logging.Config{Level: "ERROR"})
|
|
wrapped := Recovery(handler, logger)
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
// Should not panic
|
|
wrapped.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusInternalServerError {
|
|
t.Errorf("Expected status 500, got %d", rr.Code)
|
|
}
|
|
|
|
if !strings.Contains(rr.Body.String(), "Internal Server Error") {
|
|
t.Errorf("Expected 'Internal Server Error' in body, got '%s'", rr.Body.String())
|
|
}
|
|
}
|
|
|
|
// ============== responseWriter Tests ==============
|
|
|
|
func TestResponseWriter_WriteHeader(t *testing.T) {
|
|
rr := httptest.NewRecorder()
|
|
rw := &responseWriter{ResponseWriter: rr, status: http.StatusOK}
|
|
|
|
rw.WriteHeader(http.StatusNotFound)
|
|
|
|
if rw.status != http.StatusNotFound {
|
|
t.Errorf("Expected status 404, got %d", rw.status)
|
|
}
|
|
}
|
|
|
|
func TestResponseWriter_Write(t *testing.T) {
|
|
rr := httptest.NewRecorder()
|
|
rw := &responseWriter{ResponseWriter: rr, status: http.StatusOK}
|
|
|
|
n, err := rw.Write([]byte("Hello World"))
|
|
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
}
|
|
|
|
if n != 11 {
|
|
t.Errorf("Expected 11 bytes written, got %d", n)
|
|
}
|
|
|
|
if rw.size != 11 {
|
|
t.Errorf("Expected size 11, got %d", rw.size)
|
|
}
|
|
}
|
|
|
|
// ============== Middleware Chain Tests ==============
|
|
|
|
func TestMiddlewareChain(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("OK"))
|
|
})
|
|
|
|
logger, _ := logging.New(logging.Config{Level: "INFO"})
|
|
|
|
// Apply middleware chain
|
|
wrapped := Recovery(AccessLog(ServerHeader(handler, "1.0.0"), logger), logger)
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
wrapped.ServeHTTP(rr, req)
|
|
|
|
// Check all middleware worked
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
}
|
|
|
|
if rr.Header().Get("Server") != "konduktor/1.0.0" {
|
|
t.Errorf("Expected Server header")
|
|
}
|
|
|
|
if rr.Body.String() != "OK" {
|
|
t.Errorf("Expected body 'OK', got '%s'", rr.Body.String())
|
|
}
|
|
}
|
|
|
|
// ============== Benchmarks ==============
|
|
|
|
func BenchmarkServerHeader(b *testing.B) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
wrapped := ServerHeader(handler, "1.0.0")
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
rr := httptest.NewRecorder()
|
|
wrapped.ServeHTTP(rr, req)
|
|
}
|
|
}
|
|
|
|
func BenchmarkAccessLog(b *testing.B) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
logger, _ := logging.New(logging.Config{Level: "ERROR"}) // Minimize logging overhead
|
|
wrapped := AccessLog(handler, logger)
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
rr := httptest.NewRecorder()
|
|
wrapped.ServeHTTP(rr, req)
|
|
}
|
|
}
|
|
|
|
func BenchmarkRecovery(b *testing.B) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
logger, _ := logging.New(logging.Config{Level: "ERROR"})
|
|
wrapped := Recovery(handler, logger)
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
rr := httptest.NewRecorder()
|
|
wrapped.ServeHTTP(rr, req)
|
|
}
|
|
}
|