package extension import ( "context" "net/http" "net/http/httptest" "testing" ) func TestNewSecurityExtension(t *testing.T) { logger := newTestLogger() ext, err := NewSecurityExtension(map[string]interface{}{}, logger) if err != nil { t.Fatalf("Failed to create security extension: %v", err) } if ext.Name() != "security" { t.Errorf("Expected name 'security', got %s", ext.Name()) } if ext.Priority() != 10 { t.Errorf("Expected priority 10, got %d", ext.Priority()) } } func TestSecurityExtension_BlockedIP(t *testing.T) { logger := newTestLogger() ext, _ := NewSecurityExtension(map[string]interface{}{ "blocked_ips": []interface{}{"192.168.1.100"}, }, logger) req := httptest.NewRequest("GET", "/test", nil) req.RemoteAddr = "192.168.1.100:12345" rr := httptest.NewRecorder() handled, err := ext.ProcessRequest(context.Background(), rr, req) if err != nil { t.Errorf("Unexpected error: %v", err) } if !handled { t.Error("Expected blocked request to be handled") } if rr.Code != http.StatusForbidden { t.Errorf("Expected status 403, got %d", rr.Code) } } func TestSecurityExtension_AllowedIP(t *testing.T) { logger := newTestLogger() ext, _ := NewSecurityExtension(map[string]interface{}{ "allowed_ips": []interface{}{"192.168.1.50"}, }, logger) // Allowed IP req := httptest.NewRequest("GET", "/test", nil) req.RemoteAddr = "192.168.1.50:12345" rr := httptest.NewRecorder() handled, _ := ext.ProcessRequest(context.Background(), rr, req) if handled { t.Error("Expected allowed IP to pass through") } // Not allowed IP req = httptest.NewRequest("GET", "/test", nil) req.RemoteAddr = "192.168.1.51:12345" rr = httptest.NewRecorder() handled, _ = ext.ProcessRequest(context.Background(), rr, req) if !handled { t.Error("Expected non-allowed IP to be blocked") } if rr.Code != http.StatusForbidden { t.Errorf("Expected status 403, got %d", rr.Code) } } func TestSecurityExtension_CIDR(t *testing.T) { logger := newTestLogger() ext, _ := NewSecurityExtension(map[string]interface{}{ "blocked_ips": []interface{}{"10.0.0.0/8"}, }, logger) // IP in blocked CIDR req := httptest.NewRequest("GET", "/test", nil) req.RemoteAddr = "10.1.2.3:12345" rr := httptest.NewRecorder() handled, _ := ext.ProcessRequest(context.Background(), rr, req) if !handled { t.Error("Expected IP in blocked CIDR to be blocked") } // IP not in blocked CIDR req = httptest.NewRequest("GET", "/test", nil) req.RemoteAddr = "192.168.1.1:12345" rr = httptest.NewRecorder() handled, _ = ext.ProcessRequest(context.Background(), rr, req) if handled { t.Error("Expected IP not in blocked CIDR to pass through") } } func TestSecurityExtension_SecurityHeaders(t *testing.T) { logger := newTestLogger() ext, _ := NewSecurityExtension(map[string]interface{}{ "security_headers": map[string]interface{}{ "X-Custom-Header": "custom-value", }, }, logger) req := httptest.NewRequest("GET", "/test", nil) rr := httptest.NewRecorder() ext.ProcessResponse(context.Background(), rr, req) // Check default headers if rr.Header().Get("X-Content-Type-Options") != "nosniff" { t.Error("Expected X-Content-Type-Options header") } // Check custom header if rr.Header().Get("X-Custom-Header") != "custom-value" { t.Error("Expected custom header") } } func TestSecurityExtension_RateLimit(t *testing.T) { logger := newTestLogger() ext, _ := NewSecurityExtension(map[string]interface{}{ "rate_limit": map[string]interface{}{ "enabled": true, "requests": 2, "window": "1m", }, }, logger) securityExt := ext.(*SecurityExtension) clientIP := "192.168.1.1" // First request - should pass if !securityExt.checkRateLimit(clientIP) { t.Error("First request should pass") } // Second request - should pass if !securityExt.checkRateLimit(clientIP) { t.Error("Second request should pass") } // Third request - should be rate limited if securityExt.checkRateLimit(clientIP) { t.Error("Third request should be rate limited") } } func TestSecurityExtension_GetMetrics(t *testing.T) { logger := newTestLogger() ext, _ := NewSecurityExtension(map[string]interface{}{ "blocked_ips": []interface{}{"192.168.1.1"}, "allowed_ips": []interface{}{"192.168.1.2"}, }, logger) securityExt := ext.(*SecurityExtension) metrics := securityExt.GetMetrics() if metrics["blocked_ips"].(int) != 1 { t.Errorf("Expected 1 blocked IP, got %v", metrics["blocked_ips"]) } if metrics["allowed_ips"].(int) != 1 { t.Errorf("Expected 1 allowed IP, got %v", metrics["allowed_ips"]) } } func TestSecurityExtension_AddRemoveIPs(t *testing.T) { logger := newTestLogger() ext, _ := NewSecurityExtension(map[string]interface{}{}, logger) securityExt := ext.(*SecurityExtension) // Add blocked IP securityExt.AddBlockedIP("192.168.1.100") req := httptest.NewRequest("GET", "/test", nil) req.RemoteAddr = "192.168.1.100:12345" rr := httptest.NewRecorder() handled, _ := ext.ProcessRequest(context.Background(), rr, req) if !handled { t.Error("Expected dynamically blocked IP to be blocked") } // Remove blocked IP securityExt.RemoveBlockedIP("192.168.1.100") rr = httptest.NewRecorder() handled, _ = ext.ProcessRequest(context.Background(), rr, req) if handled { t.Error("Expected removed blocked IP to pass through") } }