// Forked from aegis/pyserveX (original listing: 748 lines, 18 KiB, Go).

package proxy
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// ============== Test Backend Server ==============
|
|
|
|
type testBackend struct {
|
|
server *httptest.Server
|
|
requestLog []requestLogEntry
|
|
mu sync.Mutex
|
|
requestCount int64
|
|
}
|
|
|
|
// requestLogEntry is a snapshot of a single request as observed by the test
// backend.
type requestLogEntry struct {
	Method, Path, Query string // Query holds the raw (undecoded) query string
	Headers             http.Header
	Body                string
}
|
|
|
|
func newTestBackend(handler http.HandlerFunc) *testBackend {
|
|
tb := &testBackend{
|
|
requestLog: make([]requestLogEntry, 0),
|
|
}
|
|
|
|
if handler == nil {
|
|
handler = tb.defaultHandler
|
|
}
|
|
|
|
tb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
tb.logRequest(r)
|
|
handler(w, r)
|
|
}))
|
|
|
|
return tb
|
|
}
|
|
|
|
func (tb *testBackend) logRequest(r *http.Request) {
|
|
tb.mu.Lock()
|
|
defer tb.mu.Unlock()
|
|
|
|
body, _ := io.ReadAll(r.Body)
|
|
// Restore the body for the handler
|
|
r.Body = io.NopCloser(bytes.NewReader(body))
|
|
|
|
tb.requestLog = append(tb.requestLog, requestLogEntry{
|
|
Method: r.Method,
|
|
Path: r.URL.Path,
|
|
Query: r.URL.RawQuery,
|
|
Headers: r.Header.Clone(),
|
|
Body: string(body),
|
|
})
|
|
atomic.AddInt64(&tb.requestCount, 1)
|
|
}
|
|
|
|
func (tb *testBackend) defaultHandler(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"message": "Backend response",
|
|
"path": r.URL.Path,
|
|
"method": r.Method,
|
|
})
|
|
}
|
|
|
|
func (tb *testBackend) close() {
|
|
tb.server.Close()
|
|
}
|
|
|
|
func (tb *testBackend) URL() string {
|
|
return tb.server.URL
|
|
}
|
|
|
|
func (tb *testBackend) getRequestCount() int64 {
|
|
return atomic.LoadInt64(&tb.requestCount)
|
|
}
|
|
|
|
func (tb *testBackend) getLastRequest() *requestLogEntry {
|
|
tb.mu.Lock()
|
|
defer tb.mu.Unlock()
|
|
if len(tb.requestLog) == 0 {
|
|
return nil
|
|
}
|
|
return &tb.requestLog[len(tb.requestLog)-1]
|
|
}
|
|
|
|
// ============== Proxy Creation Tests ==============
|
|
|
|
func TestNew_ValidConfig(t *testing.T) {
|
|
cfg := &Config{
|
|
Target: "http://localhost:8080",
|
|
Timeout: 10 * time.Second,
|
|
}
|
|
|
|
proxy, err := New(cfg, nil)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create proxy: %v", err)
|
|
}
|
|
|
|
if proxy == nil {
|
|
t.Fatal("Expected proxy instance")
|
|
}
|
|
}
|
|
|
|
func TestNew_EmptyTarget(t *testing.T) {
|
|
cfg := &Config{
|
|
Target: "",
|
|
}
|
|
|
|
_, err := New(cfg, nil)
|
|
if err == nil {
|
|
t.Error("Expected error for empty target")
|
|
}
|
|
}
|
|
|
|
func TestNew_InvalidTargetURL(t *testing.T) {
|
|
cfg := &Config{
|
|
Target: "://invalid-url",
|
|
}
|
|
|
|
_, err := New(cfg, nil)
|
|
if err == nil {
|
|
t.Error("Expected error for invalid URL")
|
|
}
|
|
}
|
|
|
|
func TestNew_DefaultTimeout(t *testing.T) {
|
|
cfg := &Config{
|
|
Target: "http://localhost:8080",
|
|
}
|
|
|
|
proxy, err := New(cfg, nil)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create proxy: %v", err)
|
|
}
|
|
|
|
if proxy.httpClient.Timeout != 30*time.Second {
|
|
t.Errorf("Expected default timeout 30s, got %v", proxy.httpClient.Timeout)
|
|
}
|
|
}
|
|
|
|
// ============== Basic Proxy Tests ==============
|
|
|
|
func TestProxy_BasicGET(t *testing.T) {
|
|
backend := newTestBackend(nil)
|
|
defer backend.close()
|
|
|
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
}
|
|
|
|
var response map[string]interface{}
|
|
json.NewDecoder(rr.Body).Decode(&response)
|
|
|
|
if response["path"] != "/test" {
|
|
t.Errorf("Expected path /test, got %v", response["path"])
|
|
}
|
|
}
|
|
|
|
func TestProxy_BasicPOST(t *testing.T) {
|
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
body, _ := io.ReadAll(r.Body)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"received": string(body),
|
|
"method": r.Method,
|
|
})
|
|
})
|
|
defer backend.close()
|
|
|
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
|
|
req := httptest.NewRequest("POST", "/api/data", strings.NewReader(`{"key":"value"}`))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
rr := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
}
|
|
|
|
var response map[string]interface{}
|
|
json.NewDecoder(rr.Body).Decode(&response)
|
|
|
|
if response["method"] != "POST" {
|
|
t.Errorf("Expected method POST, got %v", response["method"])
|
|
}
|
|
}
|
|
|
|
func TestProxy_PUT(t *testing.T) {
|
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]string{"method": r.Method})
|
|
})
|
|
defer backend.close()
|
|
|
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
|
|
req := httptest.NewRequest("PUT", "/resource/123", strings.NewReader(`{"name":"updated"}`))
|
|
rr := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rr, req)
|
|
|
|
var response map[string]string
|
|
json.NewDecoder(rr.Body).Decode(&response)
|
|
|
|
if response["method"] != "PUT" {
|
|
t.Errorf("Expected method PUT, got %v", response["method"])
|
|
}
|
|
}
|
|
|
|
func TestProxy_DELETE(t *testing.T) {
|
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]string{"method": r.Method})
|
|
})
|
|
defer backend.close()
|
|
|
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
|
|
req := httptest.NewRequest("DELETE", "/resource/123", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rr, req)
|
|
|
|
var response map[string]string
|
|
json.NewDecoder(rr.Body).Decode(&response)
|
|
|
|
if response["method"] != "DELETE" {
|
|
t.Errorf("Expected method DELETE, got %v", response["method"])
|
|
}
|
|
}
|
|
|
|
// ============== Header Tests ==============
|
|
|
|
func TestProxy_HeadersForwarding(t *testing.T) {
|
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"custom_header": r.Header.Get("X-Custom-Header"),
|
|
"forwarded_for": r.Header.Get("X-Forwarded-For"),
|
|
"forwarded_host": r.Header.Get("X-Forwarded-Host"),
|
|
})
|
|
})
|
|
defer backend.close()
|
|
|
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
|
|
req := httptest.NewRequest("GET", "/headers", nil)
|
|
req.Header.Set("X-Custom-Header", "test-value")
|
|
req.RemoteAddr = "192.168.1.100:12345"
|
|
req.Host = "example.com"
|
|
rr := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rr, req)
|
|
|
|
var response map[string]interface{}
|
|
json.NewDecoder(rr.Body).Decode(&response)
|
|
|
|
if response["custom_header"] != "test-value" {
|
|
t.Errorf("Expected custom header, got %v", response["custom_header"])
|
|
}
|
|
|
|
if response["forwarded_for"] != "192.168.1.100" {
|
|
t.Errorf("Expected X-Forwarded-For, got %v", response["forwarded_for"])
|
|
}
|
|
|
|
if response["forwarded_host"] != "example.com" {
|
|
t.Errorf("Expected X-Forwarded-Host, got %v", response["forwarded_host"])
|
|
}
|
|
}
|
|
|
|
func TestProxy_CustomHeaders(t *testing.T) {
|
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"api_version": r.Header.Get("X-API-Version"),
|
|
})
|
|
})
|
|
defer backend.close()
|
|
|
|
proxy, _ := New(&Config{
|
|
Target: backend.URL(),
|
|
Headers: map[string]string{
|
|
"X-API-Version": "{version}",
|
|
},
|
|
}, nil)
|
|
|
|
req := httptest.NewRequest("GET", "/api", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
// Simulate parameter substitution
|
|
proxy.ProxyRequest(rr, req, map[string]string{"version": "2"})
|
|
|
|
var response map[string]string
|
|
json.NewDecoder(rr.Body).Decode(&response)
|
|
|
|
if response["api_version"] != "2" {
|
|
t.Errorf("Expected API version 2, got %v", response["api_version"])
|
|
}
|
|
}
|
|
|
|
func TestProxy_RemoteAddrSubstitution(t *testing.T) {
|
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"client_ip": r.Header.Get("X-Client-IP"),
|
|
})
|
|
})
|
|
defer backend.close()
|
|
|
|
proxy, _ := New(&Config{
|
|
Target: backend.URL(),
|
|
Headers: map[string]string{
|
|
"X-Client-IP": "$remote_addr",
|
|
},
|
|
}, nil)
|
|
|
|
req := httptest.NewRequest("GET", "/api", nil)
|
|
req.RemoteAddr = "10.0.0.1:54321"
|
|
rr := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rr, req)
|
|
|
|
var response map[string]string
|
|
json.NewDecoder(rr.Body).Decode(&response)
|
|
|
|
if response["client_ip"] != "10.0.0.1" {
|
|
t.Errorf("Expected client IP 10.0.0.1, got %v", response["client_ip"])
|
|
}
|
|
}
|
|
|
|
// ============== Query String Tests ==============
|
|
|
|
func TestProxy_QueryStringPreservation(t *testing.T) {
|
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"query": r.URL.RawQuery,
|
|
"param": r.URL.Query().Get("key"),
|
|
})
|
|
})
|
|
defer backend.close()
|
|
|
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
|
|
req := httptest.NewRequest("GET", "/search?key=value&page=2", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rr, req)
|
|
|
|
var response map[string]string
|
|
json.NewDecoder(rr.Body).Decode(&response)
|
|
|
|
if response["param"] != "value" {
|
|
t.Errorf("Expected query param 'value', got %v", response["param"])
|
|
}
|
|
}
|
|
|
|
// ============== Status Code Tests ==============
|
|
|
|
func TestProxy_StatusCodePreservation(t *testing.T) {
|
|
statusCodes := []int{200, 201, 400, 404, 500}
|
|
|
|
for _, code := range statusCodes {
|
|
code := code // capture range variable
|
|
t.Run(fmt.Sprintf("Status_%d", code), func(t *testing.T) {
|
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(code)
|
|
w.Write([]byte("OK"))
|
|
})
|
|
defer backend.close()
|
|
|
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
|
|
req := httptest.NewRequest("GET", "/status", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != code {
|
|
t.Errorf("Expected status %d, got %d", code, rr.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// ============== Error Handling Tests ==============
|
|
|
|
func TestProxy_BackendUnavailable(t *testing.T) {
|
|
// Use a port that's definitely not listening
|
|
proxy, _ := New(&Config{
|
|
Target: "http://127.0.0.1:59999",
|
|
Timeout: 1 * time.Second,
|
|
}, nil)
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusBadGateway {
|
|
t.Errorf("Expected status 502 Bad Gateway, got %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestProxy_Timeout(t *testing.T) {
|
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
time.Sleep(2 * time.Second)
|
|
w.Write([]byte("OK"))
|
|
})
|
|
defer backend.close()
|
|
|
|
proxy, _ := New(&Config{
|
|
Target: backend.URL(),
|
|
Timeout: 100 * time.Millisecond,
|
|
}, nil)
|
|
|
|
req := httptest.NewRequest("GET", "/slow", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusGatewayTimeout {
|
|
t.Errorf("Expected status 504 Gateway Timeout, got %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
// ============== Path Handling Tests ==============
|
|
|
|
func TestProxy_StripPrefix(t *testing.T) {
|
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"path": r.URL.Path,
|
|
})
|
|
})
|
|
defer backend.close()
|
|
|
|
proxy, _ := New(&Config{
|
|
Target: backend.URL(),
|
|
StripPrefix: "/api/v1",
|
|
}, nil)
|
|
|
|
req := httptest.NewRequest("GET", "/api/v1/users", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rr, req)
|
|
|
|
var response map[string]string
|
|
json.NewDecoder(rr.Body).Decode(&response)
|
|
|
|
if response["path"] != "/users" {
|
|
t.Errorf("Expected stripped path /users, got %v", response["path"])
|
|
}
|
|
}
|
|
|
|
func TestProxy_TargetWithPath(t *testing.T) {
|
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"path": r.URL.Path,
|
|
})
|
|
})
|
|
defer backend.close()
|
|
|
|
proxy, _ := New(&Config{
|
|
Target: backend.URL() + "/backend",
|
|
}, nil)
|
|
|
|
req := httptest.NewRequest("GET", "/resource", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rr, req)
|
|
|
|
var response map[string]string
|
|
json.NewDecoder(rr.Body).Decode(&response)
|
|
|
|
if response["path"] != "/backend/resource" {
|
|
t.Errorf("Expected path /backend/resource, got %v", response["path"])
|
|
}
|
|
}
|
|
|
|
// ============== Large Body Tests ==============
|
|
|
|
func TestProxy_LargeRequestBody(t *testing.T) {
|
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
body, _ := io.ReadAll(r.Body)
|
|
json.NewEncoder(w).Encode(map[string]int{
|
|
"received_bytes": len(body),
|
|
})
|
|
})
|
|
defer backend.close()
|
|
|
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
|
|
// 100KB body
|
|
largeBody := strings.Repeat("x", 100000)
|
|
req := httptest.NewRequest("POST", "/upload", strings.NewReader(largeBody))
|
|
req.ContentLength = int64(len(largeBody))
|
|
rr := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rr, req)
|
|
|
|
var response map[string]int
|
|
json.NewDecoder(rr.Body).Decode(&response)
|
|
|
|
if response["received_bytes"] != 100000 {
|
|
t.Errorf("Expected 100000 bytes, got %d", response["received_bytes"])
|
|
}
|
|
}
|
|
|
|
func TestProxy_LargeResponseBody(t *testing.T) {
|
|
largeResponse := strings.Repeat("y", 100000)
|
|
|
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte(largeResponse))
|
|
})
|
|
defer backend.close()
|
|
|
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
|
|
req := httptest.NewRequest("GET", "/large", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rr, req)
|
|
|
|
if rr.Body.Len() != 100000 {
|
|
t.Errorf("Expected 100000 bytes in response, got %d", rr.Body.Len())
|
|
}
|
|
}
|
|
|
|
// ============== Concurrent Requests Tests ==============
|
|
|
|
func TestProxy_ConcurrentRequests(t *testing.T) {
|
|
backend := newTestBackend(nil)
|
|
defer backend.close()
|
|
|
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
|
|
const numRequests = 50
|
|
var wg sync.WaitGroup
|
|
errors := make(chan error, numRequests)
|
|
|
|
for i := 0; i < numRequests; i++ {
|
|
wg.Add(1)
|
|
go func(n int) {
|
|
defer wg.Done()
|
|
|
|
req := httptest.NewRequest("GET", "/concurrent", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
errors <- &net.OpError{Op: "test", Err: context.DeadlineExceeded}
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errors)
|
|
|
|
errorCount := 0
|
|
for range errors {
|
|
errorCount++
|
|
}
|
|
|
|
if errorCount > 0 {
|
|
t.Errorf("Got %d errors in concurrent requests", errorCount)
|
|
}
|
|
|
|
if backend.getRequestCount() != numRequests {
|
|
t.Errorf("Expected %d requests at backend, got %d", numRequests, backend.getRequestCount())
|
|
}
|
|
}
|
|
|
|
// ============== Echo Tests ==============
|
|
|
|
func TestProxy_Echo(t *testing.T) {
|
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
body, _ := io.ReadAll(r.Body)
|
|
w.Header().Set("Content-Type", r.Header.Get("Content-Type"))
|
|
w.Write(body)
|
|
})
|
|
defer backend.close()
|
|
|
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
|
|
testData := "Hello, Proxy!"
|
|
req := httptest.NewRequest("POST", "/echo", strings.NewReader(testData))
|
|
req.Header.Set("Content-Type", "text/plain")
|
|
rr := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rr, req)
|
|
|
|
if rr.Body.String() != testData {
|
|
t.Errorf("Expected echo of '%s', got '%s'", testData, rr.Body.String())
|
|
}
|
|
|
|
if rr.Header().Get("Content-Type") != "text/plain" {
|
|
t.Errorf("Expected Content-Type text/plain, got %s", rr.Header().Get("Content-Type"))
|
|
}
|
|
}
|
|
|
|
// ============== Helper Function Tests ==============
|
|
|
|
func TestSingleJoiningSlash(t *testing.T) {
|
|
tests := []struct {
|
|
a, b, expected string
|
|
}{
|
|
{"/api", "/users", "/api/users"},
|
|
{"/api/", "/users", "/api/users"},
|
|
{"/api", "users", "/api/users"},
|
|
{"/api/", "users", "/api/users"},
|
|
{"", "/users", "/users"},
|
|
{"/api", "", "/api/"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := singleJoiningSlash(tt.a, tt.b)
|
|
if result != tt.expected {
|
|
t.Errorf("singleJoiningSlash(%q, %q) = %q, want %q", tt.a, tt.b, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestGetClientIP(t *testing.T) {
|
|
tests := []struct {
|
|
remoteAddr string
|
|
xRealIP string
|
|
expected string
|
|
}{
|
|
{"192.168.1.1:1234", "", "192.168.1.1"},
|
|
{"192.168.1.1:1234", "10.0.0.1", "10.0.0.1"},
|
|
{"invalid", "", "invalid"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = tt.remoteAddr
|
|
if tt.xRealIP != "" {
|
|
req.Header.Set("X-Real-IP", tt.xRealIP)
|
|
}
|
|
|
|
result := getClientIP(req)
|
|
if result != tt.expected {
|
|
t.Errorf("getClientIP() = %q, want %q", result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestGetScheme(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tls bool
|
|
header string
|
|
expected string
|
|
}{
|
|
{"HTTP", false, "", "http"},
|
|
{"HTTPS from TLS", true, "", "https"},
|
|
{"HTTPS from header", false, "https", "https"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
if tt.header != "" {
|
|
req.Header.Set("X-Forwarded-Proto", tt.header)
|
|
}
|
|
// Note: httptest doesn't set TLS, so we can only test non-TLS cases fully
|
|
|
|
result := getScheme(req)
|
|
if !tt.tls && result != tt.expected {
|
|
t.Errorf("getScheme() = %q, want %q", result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIsConnectionError(t *testing.T) {
|
|
tests := []struct {
|
|
err error
|
|
expected bool
|
|
}{
|
|
{nil, false},
|
|
{&net.OpError{Op: "dial", Err: &net.DNSError{Err: "no such host"}}, true},
|
|
{context.DeadlineExceeded, false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := isConnectionError(tt.err)
|
|
if result != tt.expected {
|
|
t.Errorf("isConnectionError(%v) = %v, want %v", tt.err, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ============== Benchmarks ==============
|
|
|
|
func BenchmarkProxy_SimpleGET(b *testing.B) {
|
|
backend := newTestBackend(nil)
|
|
defer backend.close()
|
|
|
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
req := httptest.NewRequest("GET", "/bench", nil)
|
|
rr := httptest.NewRecorder()
|
|
proxy.ServeHTTP(rr, req)
|
|
}
|
|
}
|
|
|
|
func BenchmarkProxy_POSTWithBody(b *testing.B) {
|
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
io.Copy(io.Discard, r.Body)
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
defer backend.close()
|
|
|
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
body := strings.Repeat("x", 1024) // 1KB body
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
req := httptest.NewRequest("POST", "/bench", strings.NewReader(body))
|
|
rr := httptest.NewRecorder()
|
|
proxy.ServeHTTP(rr, req)
|
|
}
|
|
}
|