konduktor/go/tests/integration/reverse_proxy_test.go
Илья Глазунов 881028c1e6 feat: Add reverse proxy functionality with enhanced routing capabilities
- Introduced IgnoreRequestPath option in proxy configuration to allow exact match routing.
- Implemented proxy_pass directive in routing extension to handle backend requests.
- Enhanced error handling for backend unavailability and timeouts.
- Added integration tests for reverse proxy, including basic requests, exact match routes, regex routes, header forwarding, and query string preservation.
- Created helper functions for setting up test servers and backends, along with assertion utilities for response validation.
- Updated server initialization to support extension management and middleware chaining.
- Improved logging for debugging purposes during request handling.
2025-12-12 00:38:30 +03:00

563 lines
14 KiB
Go

package integration
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/konduktor/konduktor/internal/extension"
"github.com/konduktor/konduktor/internal/logging"
)
// createTestLogger returns a DEBUG-level logger for integration tests.
// If the logger cannot be built, the calling test is aborted via t.Fatalf.
func createTestLogger(t *testing.T) *logging.Logger {
	t.Helper()
	cfg := logging.Config{Level: "DEBUG"}
	l, err := logging.New(cfg)
	if err != nil {
		t.Fatalf("Failed to create logger: %v", err)
	}
	return l
}
// ============== Basic Reverse Proxy Tests ==============
// TestReverseProxy_BasicGET sends a GET through the __default__ proxy location
// and checks that the backend sees the original path, that its JSON body is
// relayed to the client, and that exactly one backend request was made.
func TestReverseProxy_BasicGET(t *testing.T) {
	upstream := StartBackend(func(w http.ResponseWriter, r *http.Request) {
		payload := map[string]interface{}{
			"message": "Hello from backend",
			"path":    r.URL.Path,
			"method":  r.Method,
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(payload)
	})
	defer upstream.Close()

	logger := createTestLogger(t)
	routes := map[string]interface{}{
		"regex_locations": map[string]interface{}{
			"__default__": map[string]interface{}{
				"proxy_pass": upstream.URL(),
			},
		},
	}
	routingExt, err := extension.NewRoutingExtension(routes, logger)
	if err != nil {
		t.Fatalf("Failed to create routing extension: %v", err)
	}

	srv := StartTestServer(t, &ServerConfig{
		Extensions: []extension.Extension{routingExt},
	})
	defer srv.Close()

	resp, err := NewHTTPClient(srv.URL).Get("/api/test", nil)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	AssertStatus(t, resp, http.StatusOK)

	var got map[string]interface{}
	if decodeErr := json.NewDecoder(resp.Body).Decode(&got); decodeErr != nil {
		t.Fatalf("Failed to decode response: %v", decodeErr)
	}
	if msg := got["message"]; msg != "Hello from backend" {
		t.Errorf("Unexpected message: %v", msg)
	}
	if p := got["path"]; p != "/api/test" {
		t.Errorf("Expected path /api/test, got %v", p)
	}

	// The backend must have been hit exactly once for this single request.
	if n := upstream.RequestCount(); n != 1 {
		t.Errorf("Expected 1 backend request, got %d", n)
	}
}
// TestReverseProxy_POST checks that a JSON POST body is forwarded through the
// proxy and that the backend's echo of it (plus the method) reaches the client.
func TestReverseProxy_POST(t *testing.T) {
	backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
		var body map[string]interface{}
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			// t.Errorf (not Fatalf) is safe to call from the server goroutine.
			t.Errorf("Backend failed to decode request body: %v", err)
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{
			"received": body,
			"method":   r.Method,
		})
	})
	defer backend.Close()

	logger := createTestLogger(t)
	routingExt, err := extension.NewRoutingExtension(map[string]interface{}{
		"regex_locations": map[string]interface{}{
			"__default__": map[string]interface{}{
				"proxy_pass": backend.URL(),
			},
		},
	}, logger)
	if err != nil {
		t.Fatalf("Failed to create routing extension: %v", err)
	}

	server := StartTestServer(t, &ServerConfig{
		Extensions: []extension.Extension{routingExt},
	})
	defer server.Close()

	client := NewHTTPClient(server.URL)
	body := []byte(`{"name":"test","value":123}`)
	resp, err := client.Post("/api/data", body, map[string]string{
		"Content-Type": "application/json",
	})
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	AssertStatus(t, resp, http.StatusOK)

	var result map[string]interface{}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}
	if result["method"] != "POST" {
		t.Errorf("Expected method POST, got %v", result["method"])
	}

	// Checked assertion: a missing or non-object "received" must fail the
	// test, not panic it.
	received, ok := result["received"].(map[string]interface{})
	if !ok {
		t.Fatalf("Expected 'received' to be a JSON object, got %v", result["received"])
	}
	if received["name"] != "test" {
		t.Errorf("Expected name 'test', got %v", received["name"])
	}
	// encoding/json decodes JSON numbers into float64 for interface{} targets.
	if received["value"] != float64(123) {
		t.Errorf("Expected value 123, got %v", received["value"])
	}
}
// ============== Exact Match Routes ==============
// TestReverseProxy_ExactMatchRoute checks that an exact-match ("=") location
// forwards to its proxy_pass URL with the target path used as-is, rather than
// the incoming request path.
func TestReverseProxy_ExactMatchRoute(t *testing.T) {
	backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(map[string]string{
			"endpoint": "version",
			"path":     r.URL.Path,
		})
	})
	defer backend.Close()

	logger := createTestLogger(t)
	routingExt, err := extension.NewRoutingExtension(map[string]interface{}{
		"regex_locations": map[string]interface{}{
			// Exact match - should use backend URL as-is
			"=/api/version": map[string]interface{}{
				"proxy_pass": backend.URL() + "/releases/latest",
			},
			"__default__": map[string]interface{}{
				"proxy_pass": backend.URL(),
			},
		},
	}, logger)
	if err != nil {
		t.Fatalf("Failed to create routing extension: %v", err)
	}

	server := StartTestServer(t, &ServerConfig{
		Extensions: []extension.Extension{routingExt},
	})
	defer server.Close()

	client := NewHTTPClient(server.URL)

	// Test exact match route
	resp, err := client.Get("/api/version", nil)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	AssertStatus(t, resp, http.StatusOK)

	lastReq := backend.LastRequest()
	if lastReq == nil {
		t.Fatal("No request received by backend")
	}

	// For exact match, the target path should be used as-is (IgnoreRequestPath=true)
	if lastReq.Path != "/releases/latest" {
		t.Errorf("Expected backend path /releases/latest, got %s", lastReq.Path)
	}
}
// ============== Regex Routes with Parameters ==============
// TestReverseProxy_RegexRouteWithParams checks that a regex ("~") location
// with a named capture group substitutes the captured value into the
// {placeholder} of its proxy_pass target path.
func TestReverseProxy_RegexRouteWithParams(t *testing.T) {
	backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(map[string]string{
			"path": r.URL.Path,
		})
	})
	defer backend.Close()

	logger := createTestLogger(t)
	routingExt, err := extension.NewRoutingExtension(map[string]interface{}{
		"regex_locations": map[string]interface{}{
			// Regex with named group
			"~^/api/users/(?P<id>\\d+)$": map[string]interface{}{
				"proxy_pass": backend.URL() + "/v2/users/{id}",
			},
			"__default__": map[string]interface{}{
				"proxy_pass": backend.URL(),
			},
		},
	}, logger)
	if err != nil {
		t.Fatalf("Failed to create routing extension: %v", err)
	}

	server := StartTestServer(t, &ServerConfig{
		Extensions: []extension.Extension{routingExt},
	})
	defer server.Close()

	client := NewHTTPClient(server.URL)

	// Test regex route with parameter
	resp, err := client.Get("/api/users/42", nil)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	AssertStatus(t, resp, http.StatusOK)

	lastReq := backend.LastRequest()
	if lastReq == nil {
		t.Fatal("No request received by backend")
	}

	// Parameter {id} should be substituted
	if lastReq.Path != "/v2/users/42" {
		t.Errorf("Expected backend path /v2/users/42, got %s", lastReq.Path)
	}
}
// ============== Header Forwarding ==============
// TestReverseProxy_HeaderForwarding checks that client-supplied headers are
// forwarded to the backend and that the headers configured on the location
// (X-Forwarded-For and X-Real-IP from $remote_addr) are set.
func TestReverseProxy_HeaderForwarding(t *testing.T) {
	backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(map[string]string{
			"x-forwarded-for":  r.Header.Get("X-Forwarded-For"),
			"x-real-ip":        r.Header.Get("X-Real-IP"),
			"x-custom":         r.Header.Get("X-Custom"),
			"x-forwarded-host": r.Header.Get("X-Forwarded-Host"),
		})
	})
	defer backend.Close()

	logger := createTestLogger(t)
	routingExt, err := extension.NewRoutingExtension(map[string]interface{}{
		"regex_locations": map[string]interface{}{
			"__default__": map[string]interface{}{
				"proxy_pass": backend.URL(),
				"headers": []interface{}{
					"X-Forwarded-For: $remote_addr",
					"X-Real-IP: $remote_addr",
				},
			},
		},
	}, logger)
	if err != nil {
		t.Fatalf("Failed to create routing extension: %v", err)
	}

	server := StartTestServer(t, &ServerConfig{
		Extensions: []extension.Extension{routingExt},
	})
	defer server.Close()

	client := NewHTTPClient(server.URL)
	resp, err := client.Get("/test", map[string]string{
		"X-Custom": "custom-value",
	})
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	AssertStatus(t, resp, http.StatusOK)

	var result map[string]string
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}

	// X-Custom should be forwarded
	if result["x-custom"] != "custom-value" {
		t.Errorf("Expected X-Custom header to be forwarded, got %v", result["x-custom"])
	}

	// X-Forwarded-For should be set (will contain 127.0.0.1)
	if result["x-forwarded-for"] == "" {
		t.Error("Expected X-Forwarded-For header to be set")
	}

	// X-Real-IP is configured on the location too, so it must be set as well.
	if result["x-real-ip"] == "" {
		t.Error("Expected X-Real-IP header to be set")
	}
}
// ============== Query String ==============
// TestReverseProxy_QueryStringPreservation checks that query parameters on the
// incoming request reach the backend unchanged.
func TestReverseProxy_QueryStringPreservation(t *testing.T) {
	backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(map[string]string{
			"query": r.URL.RawQuery,
			"foo":   r.URL.Query().Get("foo"),
			"bar":   r.URL.Query().Get("bar"),
		})
	})
	defer backend.Close()

	logger := createTestLogger(t)
	routingExt, err := extension.NewRoutingExtension(map[string]interface{}{
		"regex_locations": map[string]interface{}{
			"__default__": map[string]interface{}{
				"proxy_pass": backend.URL(),
			},
		},
	}, logger)
	if err != nil {
		t.Fatalf("Failed to create routing extension: %v", err)
	}

	server := StartTestServer(t, &ServerConfig{
		Extensions: []extension.Extension{routingExt},
	})
	defer server.Close()

	client := NewHTTPClient(server.URL)
	resp, err := client.Get("/search?foo=hello&bar=world", nil)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	AssertStatus(t, resp, http.StatusOK)

	var result map[string]string
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}

	if result["foo"] != "hello" {
		t.Errorf("Expected foo=hello, got %v", result["foo"])
	}
	if result["bar"] != "world" {
		t.Errorf("Expected bar=world, got %v", result["bar"])
	}
}
// ============== Error Handling ==============
// TestReverseProxy_BackendUnavailable checks that proxying to an address with
// no listener yields 502 Bad Gateway.
func TestReverseProxy_BackendUnavailable(t *testing.T) {
	// Obtain an address that was just in use and then released, instead of a
	// hard-coded port that could belong to an unrelated local service.
	// Assumes backend.Close() stops the listener (as httptest.Server.Close does)
	// — confirm against StartBackend.
	dead := StartBackend(func(w http.ResponseWriter, r *http.Request) {})
	deadURL := dead.URL()
	dead.Close()

	logger := createTestLogger(t)
	routingExt, err := extension.NewRoutingExtension(map[string]interface{}{
		"regex_locations": map[string]interface{}{
			"__default__": map[string]interface{}{
				// Non-existent backend
				"proxy_pass": deadURL,
			},
		},
	}, logger)
	if err != nil {
		t.Fatalf("Failed to create routing extension: %v", err)
	}

	server := StartTestServer(t, &ServerConfig{
		Extensions: []extension.Extension{routingExt},
	})
	defer server.Close()

	client := NewHTTPClient(server.URL)
	resp, err := client.Get("/test", nil)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	// Should return 502 Bad Gateway
	AssertStatus(t, resp, http.StatusBadGateway)
}
// TestReverseProxy_BackendTimeout checks that a backend slower than the
// location's configured timeout (0.5s) yields 504 Gateway Timeout.
func TestReverseProxy_BackendTimeout(t *testing.T) {
	backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
		// Simulate slow backend. The handler also returns as soon as the proxy
		// abandons the request (net/http cancels r.Context() when the client
		// connection closes). Otherwise the deferred backend.Close() may have
		// to wait out the full delay; httptest.Server.Close blocks on
		// in-flight requests.
		select {
		case <-time.After(3 * time.Second):
			w.Write([]byte("OK"))
		case <-r.Context().Done():
		}
	})
	defer backend.Close()

	logger := createTestLogger(t)
	routingExt, err := extension.NewRoutingExtension(map[string]interface{}{
		"regex_locations": map[string]interface{}{
			"__default__": map[string]interface{}{
				"proxy_pass": backend.URL(),
				"timeout":    0.5, // 500ms timeout
			},
		},
	}, logger)
	if err != nil {
		t.Fatalf("Failed to create routing extension: %v", err)
	}

	server := StartTestServer(t, &ServerConfig{
		Extensions: []extension.Extension{routingExt},
	})
	defer server.Close()

	client := NewHTTPClient(server.URL)
	resp, err := client.Get("/slow", nil)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	// Should return 504 Gateway Timeout
	AssertStatus(t, resp, http.StatusGatewayTimeout)
}
// ============== HTTP Methods ==============
// TestReverseProxy_AllMethods checks that each common HTTP method is proxied
// with the method preserved. For HEAD, only the status is checked because
// the response has no body.
func TestReverseProxy_AllMethods(t *testing.T) {
	methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"}

	backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(map[string]string{
			"method": r.Method,
		})
	})
	defer backend.Close()

	logger := createTestLogger(t)
	routingExt, err := extension.NewRoutingExtension(map[string]interface{}{
		"regex_locations": map[string]interface{}{
			"__default__": map[string]interface{}{
				"proxy_pass": backend.URL(),
			},
		},
	}, logger)
	if err != nil {
		t.Fatalf("Failed to create routing extension: %v", err)
	}

	server := StartTestServer(t, &ServerConfig{
		Extensions: []extension.Extension{routingExt},
	})
	defer server.Close()

	client := NewHTTPClient(server.URL)

	for _, method := range methods {
		t.Run(method, func(t *testing.T) {
			resp, err := client.Do(method, "/resource", nil, nil)
			if err != nil {
				t.Fatalf("Request failed: %v", err)
			}
			defer resp.Body.Close()

			AssertStatus(t, resp, http.StatusOK)

			if method == "HEAD" {
				return
			}

			var result map[string]string
			if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
				t.Fatalf("Failed to decode response: %v", err)
			}
			if result["method"] != method {
				t.Errorf("Expected method %s, got %v", method, result["method"])
			}
		})
	}
}
// ============== Large Bodies ==============
func TestReverseProxy_LargeRequestBody(t *testing.T) {
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
json.NewEncoder(w).Encode(map[string]int{
"received": len(body),
})
})
defer backend.Close()
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// 1MB body
largeBody := []byte(strings.Repeat("x", 1024*1024))
resp, err := client.Post("/upload", largeBody, map[string]string{
"Content-Type": "application/octet-stream",
})
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
AssertStatus(t, resp, http.StatusOK)
}
// ============== Concurrent Requests ==============
// TestReverseProxy_ConcurrentRequests fires numRequests proxied GETs in
// parallel. It checks that every one returns 200 and that the backend saw
// exactly numRequests requests.
func TestReverseProxy_ConcurrentRequests(t *testing.T) {
	backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
		// Small delay to simulate work
		time.Sleep(10 * time.Millisecond)
		json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
	})
	defer backend.Close()

	logger := createTestLogger(t)
	routingExt, err := extension.NewRoutingExtension(map[string]interface{}{
		"regex_locations": map[string]interface{}{
			"__default__": map[string]interface{}{
				"proxy_pass": backend.URL(),
			},
		},
	}, logger)
	if err != nil {
		t.Fatalf("Failed to create routing extension: %v", err)
	}

	server := StartTestServer(t, &ServerConfig{
		Extensions: []extension.Extension{routingExt},
	})
	defer server.Close()

	const numRequests = 50
	// Buffered to numRequests so no goroutine blocks on send; each goroutine
	// sends exactly one value.
	results := make(chan error, numRequests)

	for i := 0; i < numRequests; i++ {
		go func(n int) {
			client := NewHTTPClient(server.URL)
			resp, err := client.Get(fmt.Sprintf("/concurrent/%d", n), nil)
			if err != nil {
				results <- err
				return
			}
			resp.Body.Close()

			if resp.StatusCode != http.StatusOK {
				results <- fmt.Errorf("unexpected status: %d", resp.StatusCode)
				return
			}
			results <- nil
		}(i)
	}

	// Collect results
	var errs []error
	for i := 0; i < numRequests; i++ {
		if err := <-results; err != nil {
			errs = append(errs, err)
		}
	}

	if len(errs) > 0 {
		t.Errorf("Got %d errors in concurrent requests: %v", len(errs), errs[:min(5, len(errs))])
	}

	// Verify all requests reached backend
	if got := backend.RequestCount(); got != numRequests {
		t.Errorf("Expected %d backend requests, got %d", numRequests, got)
	}
}
// min reports the smaller of a and b; when they are equal either is returned.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}