package proxy

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"
)

// ============== Test Backend Server ==============

// testBackend is an httptest server that records every request it receives
// before delegating to a configurable handler.
type testBackend struct {
	// requestCount is updated with sync/atomic. It is the first field so it is
	// guaranteed 64-bit alignment on 32-bit platforms, which 64-bit atomic
	// operations require.
	requestCount int64

	server *httptest.Server

	mu         sync.Mutex // guards requestLog
	requestLog []requestLogEntry
}

// requestLogEntry is a snapshot of one request as seen by the backend.
type requestLogEntry struct {
	Method  string
	Path    string
	Query   string
	Headers http.Header
	Body    string
}

// newTestBackend starts a backend server. Every request is logged and then
// passed to handler; a nil handler falls back to defaultHandler.
// Callers must call close when finished.
func newTestBackend(handler http.HandlerFunc) *testBackend {
	tb := &testBackend{}
	if handler == nil {
		handler = tb.defaultHandler
	}
	tb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		tb.logRequest(r)
		handler(w, r)
	}))
	return tb
}

// logRequest records r and increments the request counter. It consumes
// r.Body and replaces it with an in-memory copy so the handler can still
// read the full body.
func (tb *testBackend) logRequest(r *http.Request) {
	// Read the body before taking the lock so concurrent requests are not
	// serialized on body I/O. A read error is deliberately ignored: whatever
	// was read is both logged and replayed to the handler.
	body, _ := io.ReadAll(r.Body)
	r.Body = io.NopCloser(bytes.NewReader(body))

	entry := requestLogEntry{
		Method:  r.Method,
		Path:    r.URL.Path,
		Query:   r.URL.RawQuery,
		Headers: r.Header.Clone(),
		Body:    string(body),
	}

	tb.mu.Lock()
	tb.requestLog = append(tb.requestLog, entry)
	tb.mu.Unlock()

	atomic.AddInt64(&tb.requestCount, 1)
}

// defaultHandler responds with a JSON object echoing the request path and method.
func (tb *testBackend) defaultHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]interface{}{
		"message": "Backend response",
		"path":    r.URL.Path,
		"method":  r.Method,
	})
}

// close shuts down the backend server.
func (tb *testBackend) close() {
	tb.server.Close()
}

// URL returns the base URL of the backend server.
func (tb *testBackend) URL() string {
	return tb.server.URL
}

// getRequestCount returns how many requests the backend has received.
func (tb *testBackend) getRequestCount() int64 {
	return atomic.LoadInt64(&tb.requestCount)
}

// getLastRequest returns a copy of the most recently logged request, or nil
// if none has been received. A copy is returned so callers never hold a
// pointer into the internal slice.
func (tb *testBackend) getLastRequest() *requestLogEntry {
	tb.mu.Lock()
	defer tb.mu.Unlock()
	if len(tb.requestLog) == 0 {
		return nil
	}
	last := tb.requestLog[len(tb.requestLog)-1]
	return &last
}

// decodeJSON decodes the recorded response body into v. It fails the test on
// malformed or empty JSON, so later assertions never run against a
// silently-zero value.
func decodeJSON(t testing.TB, rr *httptest.ResponseRecorder, v interface{}) {
	t.Helper()
	if err := json.NewDecoder(rr.Body).Decode(v); err != nil {
		t.Fatalf("Failed to decode response body (status %d): %v", rr.Code, err)
	}
}

// ============== Proxy Creation Tests ==============

func TestNew_ValidConfig(t *testing.T) {
	cfg := &Config{
		Target:  "http://localhost:8080",
		Timeout: 10 * time.Second,
	}

	proxy, err := New(cfg, nil)
	if err != nil {
		t.Fatalf("Failed to create proxy: %v", err)
	}
	if proxy == nil {
		t.Fatal("Expected proxy instance")
	}
}

func TestNew_EmptyTarget(t *testing.T) {
	cfg := &Config{
		Target: "",
	}

	_, err := New(cfg, nil)
	if err == nil {
		t.Error("Expected error for empty target")
	}
}

func TestNew_InvalidTargetURL(t *testing.T) {
	cfg := &Config{
		Target: "://invalid-url",
	}

	_, err := New(cfg, nil)
	if err == nil {
		t.Error("Expected error for invalid URL")
	}
}

func TestNew_DefaultTimeout(t *testing.T) {
	cfg := &Config{
		Target: "http://localhost:8080",
	}

	proxy, err := New(cfg, nil)
	if err != nil {
		t.Fatalf("Failed to create proxy: %v", err)
	}

	if proxy.httpClient.Timeout != 30*time.Second {
		t.Errorf("Expected default timeout 30s, got %v", proxy.httpClient.Timeout)
	}
}

// ============== Basic Proxy Tests ==============

func TestProxy_BasicGET(t *testing.T) {
	backend := newTestBackend(nil)
	defer backend.close()

	proxy, err := New(&Config{Target: backend.URL()}, nil)
	if err != nil {
		t.Fatalf("Failed to create proxy: %v", err)
	}

	req := httptest.NewRequest("GET", "/test", nil)
	rr := httptest.NewRecorder()
	proxy.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d", rr.Code)
	}

	var response map[string]interface{}
	decodeJSON(t, rr, &response)
	if response["path"] != "/test" {
		t.Errorf("Expected path /test, got %v", response["path"])
	}
}

func TestProxy_BasicPOST(t *testing.T) {
	backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
		body, _ := io.ReadAll(r.Body)
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{
			"received": string(body),
			"method":   r.Method,
		})
	})
	defer backend.close()

	proxy, err := New(&Config{Target: backend.URL()}, nil)
	if err != nil {
		t.Fatalf("Failed to create proxy: %v", err)
	}

	req := httptest.NewRequest("POST", "/api/data", strings.NewReader(`{"key":"value"}`))
	req.Header.Set("Content-Type", "application/json")
	rr := httptest.NewRecorder()
	proxy.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d", rr.Code)
	}

	var response map[string]interface{}
	decodeJSON(t, rr, &response)
	if response["method"] != "POST" {
		t.Errorf("Expected method POST, got %v", response["method"])
	}
}

func TestProxy_PUT(t *testing.T) {
	backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		json.NewEncoder(w).Encode(map[string]string{"method": r.Method})
	})
	defer backend.close()

	proxy, err := New(&Config{Target: backend.URL()}, nil)
	if err != nil {
		t.Fatalf("Failed to create proxy: %v", err)
	}

	req := httptest.NewRequest("PUT", "/resource/123", strings.NewReader(`{"name":"updated"}`))
	rr := httptest.NewRecorder()
	proxy.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d", rr.Code)
	}

	var response map[string]string
	decodeJSON(t, rr, &response)
	if response["method"] != "PUT" {
		t.Errorf("Expected method PUT, got %v", response["method"])
	}
}

func TestProxy_DELETE(t *testing.T) {
	backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		json.NewEncoder(w).Encode(map[string]string{"method": r.Method})
	})
	defer backend.close()

	proxy, err := New(&Config{Target: backend.URL()}, nil)
	if err != nil {
		t.Fatalf("Failed to create proxy: %v", err)
	}

	req := httptest.NewRequest("DELETE", "/resource/123", nil)
	rr := httptest.NewRecorder()
	proxy.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d", rr.Code)
	}

	var response map[string]string
	decodeJSON(t, rr, &response)
	if response["method"] != "DELETE" {
		t.Errorf("Expected method DELETE, got %v", response["method"])
	}
}

// ============== Header Tests ==============

func TestProxy_HeadersForwarding(t *testing.T) {
	backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(map[string]interface{}{
			"custom_header":  r.Header.Get("X-Custom-Header"),
			"forwarded_for":  r.Header.Get("X-Forwarded-For"),
			"forwarded_host": r.Header.Get("X-Forwarded-Host"),
		})
	})
	defer backend.close()

	proxy, err := New(&Config{Target: backend.URL()}, nil)
	if err != nil {
		t.Fatalf("Failed to create proxy: %v", err)
	}

	req := httptest.NewRequest("GET", "/headers", nil)
	req.Header.Set("X-Custom-Header", "test-value")
	req.RemoteAddr = "192.168.1.100:12345"
	req.Host = "example.com"
	rr := httptest.NewRecorder()
	proxy.ServeHTTP(rr, req)

	var response map[string]interface{}
	decodeJSON(t, rr, &response)

	if response["custom_header"] != "test-value" {
		t.Errorf("Expected custom header, got %v", response["custom_header"])
	}
	if response["forwarded_for"] != "192.168.1.100" {
		t.Errorf("Expected X-Forwarded-For, got %v", response["forwarded_for"])
	}
	if response["forwarded_host"] != "example.com" {
		t.Errorf("Expected X-Forwarded-Host, got %v", response["forwarded_host"])
	}
}

func TestProxy_CustomHeaders(t *testing.T) {
	backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(map[string]string{
			"api_version": r.Header.Get("X-API-Version"),
		})
	})
	defer backend.close()

	proxy, err := New(&Config{
		Target: backend.URL(),
		Headers: map[string]string{
			"X-API-Version": "{version}",
		},
	}, nil)
	if err != nil {
		t.Fatalf("Failed to create proxy: %v", err)
	}

	req := httptest.NewRequest("GET", "/api", nil)
	rr := httptest.NewRecorder()

	// Simulate parameter substitution
	proxy.ProxyRequest(rr, req, map[string]string{"version": "2"})

	var response map[string]string
	decodeJSON(t, rr, &response)
	if response["api_version"] != "2" {
		t.Errorf("Expected API version 2, got %v", response["api_version"])
	}
}

func TestProxy_RemoteAddrSubstitution(t *testing.T) {
	backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(map[string]string{
			"client_ip": r.Header.Get("X-Client-IP"),
		})
	})
	defer backend.close()

	proxy, err := New(&Config{
		Target: backend.URL(),
		Headers: map[string]string{
			"X-Client-IP": "$remote_addr",
		},
	}, nil)
	if err != nil {
		t.Fatalf("Failed to create proxy: %v", err)
	}

	req := httptest.NewRequest("GET", "/api", nil)
	req.RemoteAddr = "10.0.0.1:54321"
	rr := httptest.NewRecorder()
	proxy.ServeHTTP(rr, req)

	var response map[string]string
	decodeJSON(t, rr, &response)
	if response["client_ip"] != "10.0.0.1" {
		t.Errorf("Expected client IP 10.0.0.1, got %v", response["client_ip"])
	}
}

// ============== Query String Tests ==============

func TestProxy_QueryStringPreservation(t *testing.T) {
	backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(map[string]string{
			"query": r.URL.RawQuery,
			"param": r.URL.Query().Get("key"),
		})
	})
	defer backend.close()

	proxy, err := New(&Config{Target: backend.URL()}, nil)
	if err != nil {
		t.Fatalf("Failed to create proxy: %v", err)
	}

	req := httptest.NewRequest("GET", "/search?key=value&page=2", nil)
	rr := httptest.NewRecorder()
	proxy.ServeHTTP(rr, req)

	var response map[string]string
	decodeJSON(t, rr, &response)
	if response["param"] != "value" {
		t.Errorf("Expected query param 'value', got %v", response["param"])
	}
	if response["query"] != "key=value&page=2" {
		t.Errorf("Expected raw query 'key=value&page=2', got %v", response["query"])
	}
}

// ============== Status Code Tests ==============

func TestProxy_StatusCodePreservation(t *testing.T) {
	statusCodes := []int{200, 201, 400, 404, 500}

	for _, code := range statusCodes {
		code := code // capture range variable
		t.Run(fmt.Sprintf("Status_%d", code), func(t *testing.T) {
			backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(code)
				w.Write([]byte("OK"))
			})
			defer backend.close()

			proxy, err := New(&Config{Target: backend.URL()}, nil)
			if err != nil {
				t.Fatalf("Failed to create proxy: %v", err)
			}

			req := httptest.NewRequest("GET", "/status", nil)
			rr := httptest.NewRecorder()
			proxy.ServeHTTP(rr, req)

			if rr.Code != code {
				t.Errorf("Expected status %d, got %d", code, rr.Code)
			}
		})
	}
}

// ============== Error Handling Tests ==============

func TestProxy_BackendUnavailable(t *testing.T) {
	// Start and immediately stop a server so the target address is known to
	// have nothing listening, instead of hoping a hard-coded port is free.
	dead := httptest.NewServer(http.NotFoundHandler())
	deadURL := dead.URL
	dead.Close()

	proxy, err := New(&Config{
		Target:  deadURL,
		Timeout: 1 * time.Second,
	}, nil)
	if err != nil {
		t.Fatalf("Failed to create proxy: %v", err)
	}

	req := httptest.NewRequest("GET", "/test", nil)
	rr := httptest.NewRecorder()
	proxy.ServeHTTP(rr, req)

	if rr.Code != http.StatusBadGateway {
		t.Errorf("Expected status 502 Bad Gateway, got %d", rr.Code)
	}
}

func TestProxy_Timeout(t *testing.T) {
	// release unblocks any still-running handler when the test ends.
	// httptest.Server.Close waits for outstanding requests, so without it the
	// test would always pay the full backend delay.
	release := make(chan struct{})
	backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
		select {
		case <-release:
		case <-r.Context().Done(): // proxy client gave up and disconnected
		case <-time.After(2 * time.Second):
		}
		w.Write([]byte("OK"))
	})
	defer backend.close()
	defer close(release) // runs before backend.close (defers are LIFO)

	proxy, err := New(&Config{
		Target:  backend.URL(),
		Timeout: 100 * time.Millisecond,
	}, nil)
	if err != nil {
		t.Fatalf("Failed to create proxy: %v", err)
	}

	req := httptest.NewRequest("GET", "/slow", nil)
	rr := httptest.NewRecorder()
	proxy.ServeHTTP(rr, req)

	if rr.Code != http.StatusGatewayTimeout {
		t.Errorf("Expected status 504 Gateway Timeout, got %d", rr.Code)
	}
}

// ============== Path Handling Tests ==============

func TestProxy_StripPrefix(t *testing.T) {
	backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(map[string]string{
			"path": r.URL.Path,
		})
	})
	defer backend.close()

	proxy, err := New(&Config{
		Target:      backend.URL(),
		StripPrefix: "/api/v1",
	}, nil)
	if err != nil {
		t.Fatalf("Failed to create proxy: %v", err)
	}

	req := httptest.NewRequest("GET", "/api/v1/users", nil)
	rr := httptest.NewRecorder()
	proxy.ServeHTTP(rr, req)

	var response map[string]string
	decodeJSON(t, rr, &response)
	if response["path"] != "/users" {
		t.Errorf("Expected stripped path /users, got %v", response["path"])
	}
}

func TestProxy_TargetWithPath(t *testing.T) {
	backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(map[string]string{
			"path": r.URL.Path,
		})
	})
	defer backend.close()

	proxy, err := New(&Config{
		Target: backend.URL() + "/backend",
	}, nil)
	if err != nil {
		t.Fatalf("Failed to create proxy: %v", err)
	}

	req := httptest.NewRequest("GET", "/resource", nil)
	rr := httptest.NewRecorder()
	proxy.ServeHTTP(rr, req)

	var response map[string]string
	decodeJSON(t, rr, &response)
	if response["path"] != "/backend/resource" {
		t.Errorf("Expected path /backend/resource, got %v", response["path"])
	}
}

// ============== Large Body Tests ==============

func TestProxy_LargeRequestBody(t *testing.T) {
	backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
		body, _ := io.ReadAll(r.Body)
		json.NewEncoder(w).Encode(map[string]int{
			"received_bytes": len(body),
		})
	})
	defer backend.close()

	proxy, err := New(&Config{Target: backend.URL()}, nil)
	if err != nil {
		t.Fatalf("Failed to create proxy: %v", err)
	}

	// 100KB body
	largeBody := strings.Repeat("x", 100000)
	req := httptest.NewRequest("POST", "/upload", strings.NewReader(largeBody))
	req.ContentLength = int64(len(largeBody))
	rr := httptest.NewRecorder()
	proxy.ServeHTTP(rr, req)

	var response map[string]int
	decodeJSON(t, rr, &response)
	if response["received_bytes"] != 100000 {
		t.Errorf("Expected 100000 bytes, got %d", response["received_bytes"])
	}
}

func TestProxy_LargeResponseBody(t *testing.T) {
	largeResponse := strings.Repeat("y", 100000)
	backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(largeResponse))
	})
	defer backend.close()

	proxy, err := New(&Config{Target: backend.URL()}, nil)
	if err != nil {
		t.Fatalf("Failed to create proxy: %v", err)
	}

	req := httptest.NewRequest("GET", "/large", nil)
	rr := httptest.NewRecorder()
	proxy.ServeHTTP(rr, req)

	if rr.Body.Len() != 100000 {
		t.Errorf("Expected 100000 bytes in response, got %d", rr.Body.Len())
	}
}

// ============== Concurrent Requests Tests ==============

func TestProxy_ConcurrentRequests(t *testing.T) {
	backend := newTestBackend(nil)
	defer backend.close()

	proxy, err := New(&Config{Target: backend.URL()}, nil)
	if err != nil {
		t.Fatalf("Failed to create proxy: %v", err)
	}

	const numRequests = 50
	var wg sync.WaitGroup
	// Buffered to numRequests so no goroutine can block on send.
	errs := make(chan error, numRequests)

	for i := 0; i < numRequests; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			req := httptest.NewRequest("GET", "/concurrent", nil)
			rr := httptest.NewRecorder()
			proxy.ServeHTTP(rr, req)
			if rr.Code != http.StatusOK {
				errs <- fmt.Errorf("unexpected status %d", rr.Code)
			}
		}()
	}

	wg.Wait()
	close(errs)

	errorCount := 0
	for e := range errs {
		errorCount++
		t.Log(e)
	}

	if errorCount > 0 {
		t.Errorf("Got %d errors in concurrent requests", errorCount)
	}

	if backend.getRequestCount() != numRequests {
		t.Errorf("Expected %d requests at backend, got %d", numRequests, backend.getRequestCount())
	}
}

// ============== Echo Tests ==============

func TestProxy_Echo(t *testing.T) {
	backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
		body, _ := io.ReadAll(r.Body)
		w.Header().Set("Content-Type", r.Header.Get("Content-Type"))
		w.Write(body)
	})
	defer backend.close()

	proxy, err := New(&Config{Target: backend.URL()}, nil)
	if err != nil {
		t.Fatalf("Failed to create proxy: %v", err)
	}

	testData := "Hello, Proxy!"
	req := httptest.NewRequest("POST", "/echo", strings.NewReader(testData))
	req.Header.Set("Content-Type", "text/plain")
	rr := httptest.NewRecorder()
	proxy.ServeHTTP(rr, req)

	if rr.Body.String() != testData {
		t.Errorf("Expected echo of '%s', got '%s'", testData, rr.Body.String())
	}

	if rr.Header().Get("Content-Type") != "text/plain" {
		t.Errorf("Expected Content-Type text/plain, got %s", rr.Header().Get("Content-Type"))
	}
}

// ============== Helper Function Tests ==============

func TestSingleJoiningSlash(t *testing.T) {
	tests := []struct {
		a, b, expected string
	}{
		{"/api", "/users", "/api/users"},
		{"/api/", "/users", "/api/users"},
		{"/api", "users", "/api/users"},
		{"/api/", "users", "/api/users"},
		{"", "/users", "/users"},
		{"/api", "", "/api/"},
	}

	for _, tt := range tests {
		result := singleJoiningSlash(tt.a, tt.b)
		if result != tt.expected {
			t.Errorf("singleJoiningSlash(%q, %q) = %q, want %q", tt.a, tt.b, result, tt.expected)
		}
	}
}

func TestGetClientIP(t *testing.T) {
	tests := []struct {
		remoteAddr string
		xRealIP    string
		expected   string
	}{
		{"192.168.1.1:1234", "", "192.168.1.1"},
		{"192.168.1.1:1234", "10.0.0.1", "10.0.0.1"},
		{"invalid", "", "invalid"},
	}

	for _, tt := range tests {
		req := httptest.NewRequest("GET", "/", nil)
		req.RemoteAddr = tt.remoteAddr
		if tt.xRealIP != "" {
			req.Header.Set("X-Real-IP", tt.xRealIP)
		}

		result := getClientIP(req)
		if result != tt.expected {
			t.Errorf("getClientIP(RemoteAddr=%q, X-Real-IP=%q) = %q, want %q",
				tt.remoteAddr, tt.xRealIP, result, tt.expected)
		}
	}
}

func TestGetScheme(t *testing.T) {
	tests := []struct {
		name     string
		tls      bool
		header   string
		expected string
	}{
		{"HTTP", false, "", "http"},
		{"HTTPS from TLS", true, "", "https"},
		{"HTTPS from header", false, "https", "https"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// httptest.NewRequest populates req.TLS with a dummy value when
			// the target is an absolute https URL, so the TLS case is testable.
			target := "/"
			if tt.tls {
				target = "https://example.com/"
			}
			req := httptest.NewRequest("GET", target, nil)
			if tt.header != "" {
				req.Header.Set("X-Forwarded-Proto", tt.header)
			}

			if result := getScheme(req); result != tt.expected {
				t.Errorf("getScheme() = %q, want %q", result, tt.expected)
			}
		})
	}
}

func TestIsConnectionError(t *testing.T) {
	tests := []struct {
		err      error
		expected bool
	}{
		{nil, false},
		{&net.OpError{Op: "dial", Err: &net.DNSError{Err: "no such host"}}, true},
		{context.DeadlineExceeded, false},
	}

	for _, tt := range tests {
		result := isConnectionError(tt.err)
		if result != tt.expected {
			t.Errorf("isConnectionError(%v) = %v, want %v", tt.err, result, tt.expected)
		}
	}
}

// ============== Benchmarks ==============

func BenchmarkProxy_SimpleGET(b *testing.B) {
	backend := newTestBackend(nil)
	defer backend.close()

	proxy, err := New(&Config{Target: backend.URL()}, nil)
	if err != nil {
		b.Fatalf("Failed to create proxy: %v", err)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		req := httptest.NewRequest("GET", "/bench", nil)
		rr := httptest.NewRecorder()
		proxy.ServeHTTP(rr, req)
	}
}

func BenchmarkProxy_POSTWithBody(b *testing.B) {
	backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
		io.Copy(io.Discard, r.Body)
		w.WriteHeader(http.StatusOK)
	})
	defer backend.close()

	proxy, err := New(&Config{Target: backend.URL()}, nil)
	if err != nil {
		b.Fatalf("Failed to create proxy: %v", err)
	}
	body := strings.Repeat("x", 1024) // 1KB body

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		req := httptest.NewRequest("POST", "/bench", strings.NewReader(body))
		rr := httptest.NewRecorder()
		proxy.ServeHTTP(rr, req)
	}
}