konduktor/go/tests/integration/helpers_test.go
Илья Глазунов 881028c1e6 feat: Add reverse proxy functionality with enhanced routing capabilities
- Introduced IgnoreRequestPath option in proxy configuration to allow exact match routing.
- Implemented proxy_pass directive in routing extension to handle backend requests.
- Enhanced error handling for backend unavailability and timeouts.
- Added integration tests for reverse proxy, including basic requests, exact match routes, regex routes, header forwarding, and query string preservation.
- Created helper functions for setting up test servers and backends, along with assertion utilities for response validation.
- Updated server initialization to support extension management and middleware chaining.
- Improved logging for debugging purposes during request handling.
2025-12-12 00:38:30 +03:00

409 lines
9.9 KiB
Go

package integration
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/konduktor/konduktor/internal/extension"
"github.com/konduktor/konduktor/internal/logging"
"github.com/konduktor/konduktor/internal/middleware"
)
// TestServer is a running Konduktor handler chain (extension manager wrapped
// in access-log and recovery middleware, plus any custom middleware) served
// over a real TCP listener on 127.0.0.1, as built by StartTestServer.
type TestServer struct {
// Server is the underlying net/http server; Close shuts it down.
Server *http.Server
// URL is the base URL for requests, of the form "http://127.0.0.1:<Port>".
URL string
// Port is the TCP port the listener is bound to on 127.0.0.1.
Port int
// listener is the loopback listener handed to Server.Serve.
listener net.Listener
// handler is the fully wrapped handler chain installed on Server.
handler http.Handler
// t is the test that started the server; used to report startup failures.
t *testing.T
}
// TestBackend is a mock upstream built on httptest.Server. It records every
// request it receives (see AllRequests, LastRequest, RequestCount) before
// passing it to a handler; create it with StartBackend.
type TestBackend struct {
// server is the running httptest server.
server *httptest.Server
// requestLog holds one entry per received request, in the order they were
// logged; guarded by mu.
requestLog []RequestLogEntry
// mu guards requestLog.
mu sync.Mutex
// requestCount counts logged requests; accessed only through sync/atomic.
requestCount int64
// handler is set by StartBackend. NOTE(review): request dispatch there uses
// a captured local variable, not this field, and nothing visible reads it.
handler http.HandlerFunc
}
// RequestLogEntry is a snapshot of one request received by a TestBackend.
type RequestLogEntry struct {
// Method is the request's HTTP method.
Method string
// Path is the URL path, without the query string.
Path string
// Query is the raw (still encoded) query string, without the leading '?'.
Query string
// Headers is a clone of the request headers taken when the request was logged.
Headers http.Header
// Body is the complete request body as read by the backend.
Body string
// Timestamp is when the entry was recorded (after the body was read).
Timestamp time.Time
}
// ServerConfig customizes StartTestServer. A nil *ServerConfig is accepted
// and means no extensions and no extra middleware.
type ServerConfig struct {
// Extensions are registered, in slice order, with the extension manager.
Extensions []extension.Extension
// StaticDir presumably names a directory of static files to serve.
// NOTE(review): StartTestServer never reads this field — confirm whether
// static serving is wired up elsewhere or this is unfinished.
StaticDir string
// Middleware wraps the handler chain outside the built-in access-log and
// recovery middleware. Each entry wraps the result of the previous one, so
// the last entry is the outermost handler.
Middleware []func(http.Handler) http.Handler
}
// ============== Test Server ==============
// StartTestServer builds the Konduktor handler chain described by cfg, serves
// it on a fresh 127.0.0.1 port and returns once that port accepts connections.
//
// From outermost to innermost, the chain is:
//   - cfg.Middleware, where the last entry is outermost;
//   - middleware.Recovery;
//   - middleware.AccessLog;
//   - the extension manager;
//   - a 404 fallback for requests no extension handles.
//
// cfg may be nil. cfg.StaticDir is not consulted here.
//
// The server is shut down automatically when the test finishes (via
// t.Cleanup); calling Close earlier is still allowed. Any setup failure
// aborts the test with t.Fatalf.
func StartTestServer(t *testing.T, cfg *ServerConfig) *TestServer {
	t.Helper()

	logger, err := logging.New(logging.Config{
		Level: "DEBUG",
	})
	if err != nil {
		t.Fatalf("Failed to create logger: %v", err)
	}

	// Register extensions in order; ranging over a nil slice is a no-op.
	extManager := extension.NewManager(logger)
	if cfg != nil {
		for _, ext := range cfg.Extensions {
			if err := extManager.AddExtension(ext); err != nil {
				t.Fatalf("Failed to add extension: %v", err)
			}
		}
	}

	// Used when no extension handles the request.
	fallback := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.NotFound(w, r)
	})

	var handler http.Handler = extManager.Handler(fallback)
	handler = middleware.AccessLog(handler, logger)
	handler = middleware.Recovery(handler, logger)
	if cfg != nil {
		for _, mw := range cfg.Middleware {
			handler = mw(handler)
		}
	}

	// Port 0 lets the kernel choose a free port. The listener stays open and
	// is handed straight to Serve, so no other process can grab the port.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to find available port: %v", err)
	}
	port := listener.Addr().(*net.TCPAddr).Port

	server := &http.Server{
		Handler:      handler,
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}

	ts := &TestServer{
		Server:   server,
		URL:      fmt.Sprintf("http://127.0.0.1:%d", port),
		Port:     port,
		listener: listener,
		handler:  handler,
		t:        t,
	}

	go func() {
		// Serve always returns a non-nil error; http.ErrServerClosed is the
		// normal result after Close. Other errors are intentionally dropped:
		// this goroutine can outlive the test, and reporting through t after
		// the test has completed panics.
		_ = server.Serve(listener)
	}()

	// Registered before waitReady so the server is also torn down when
	// startup fails and waitReady aborts the test.
	t.Cleanup(ts.Close)

	ts.waitReady()
	return ts
}
// waitReady blocks until a TCP dial to the server's port on 127.0.0.1
// succeeds. It retries every 10ms, with a 100ms dial timeout per attempt, for
// up to five seconds, and fails the test with Fatal if the port never answers.
func (ts *TestServer) waitReady() {
	// Mark as helper so a startup failure is attributed to the caller's line.
	ts.t.Helper()

	addr := fmt.Sprintf("127.0.0.1:%d", ts.Port)
	deadline := time.Now().Add(5 * time.Second)
	for time.Now().Before(deadline) {
		conn, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
		if err == nil {
			conn.Close()
			return
		}
		time.Sleep(10 * time.Millisecond)
	}
	ts.t.Fatalf("Server failed to start within timeout (%s)", addr)
}
// Close shuts the test server down gracefully. It waits up to five seconds for
// in-flight requests to finish. If the graceful shutdown does not complete,
// for example because the deadline expires, it logs the error and forcibly
// closes the server so no listener or connection outlives the test.
func (ts *TestServer) Close() {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := ts.Server.Shutdown(ctx); err != nil {
		// Previously ignored: after a timed-out Shutdown, active connections
		// stay open. Close tears them down immediately.
		ts.t.Logf("graceful shutdown of test server failed: %v; forcing close", err)
		_ = ts.Server.Close()
	}
}
// ============== Test Backend ==============
// StartBackend starts an httptest server acting as a mock upstream. Every
// request is first recorded (see AllRequests, LastRequest, RequestCount) and
// then passed to handler. A nil handler selects the built-in JSON echo
// handler (defaultHandler). Call Close on the result when done.
func StartBackend(handler http.HandlerFunc) *TestBackend {
	tb := &TestBackend{
		requestLog: make([]RequestLogEntry, 0),
	}
	if handler == nil {
		handler = tb.defaultHandler
	}
	// Store the handler actually in use, so tb.handler is never nil while the
	// default handler is serving (it used to be recorded before defaulting).
	tb.handler = handler

	tb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		tb.logRequest(r)
		handler(w, r)
	}))
	return tb
}
// logRequest appends a RequestLogEntry for r to the request log and bumps the
// request counter. It drains r.Body and replaces it with an in-memory copy, so
// the handler that runs afterwards still sees the full body.
func (tb *TestBackend) logRequest(r *http.Request) {
	// Read the body before taking the lock. Holding tb.mu across a network
	// read would serialize concurrent requests behind a slow upload and block
	// AllRequests/LastRequest callers meanwhile. A read error is deliberately
	// ignored: whatever was read is logged and replayed to the handler.
	body, _ := io.ReadAll(r.Body)
	r.Body = io.NopCloser(bytes.NewReader(body))

	entry := RequestLogEntry{
		Method:    r.Method,
		Path:      r.URL.Path,
		Query:     r.URL.RawQuery,
		Headers:   r.Header.Clone(),
		Body:      string(body),
		Timestamp: time.Now(),
	}

	tb.mu.Lock()
	defer tb.mu.Unlock()
	tb.requestLog = append(tb.requestLog, entry)
	// Incremented while mu is held: once AllRequests (which takes mu) sees an
	// entry, RequestCount already reflects it.
	atomic.AddInt64(&tb.requestCount, 1)
}
// defaultHandler is the handler StartBackend installs when given nil. It
// replies with a JSON object tagged "backend":"default" that echoes the
// request's path, method and raw query, plus the Unix time (seconds) at which
// the reply was built.
func (tb *TestBackend) defaultHandler(w http.ResponseWriter, r *http.Request) {
	payload := map[string]interface{}{
		"backend":  "default",
		"path":     r.URL.Path,
		"method":   r.Method,
		"query":    r.URL.RawQuery,
		"received": time.Now().Unix(),
	}
	w.Header().Set("Content-Type", "application/json")
	// Best effort: a mock backend has nothing useful to do with a write error.
	_ = json.NewEncoder(w).Encode(payload)
}
// URL returns the base URL of the backend's underlying httptest server,
// suitable as a proxy target.
func (tb *TestBackend) URL() string {
	srv := tb.server
	return srv.URL
}
// Close stops the backend's underlying httptest server.
func (tb *TestBackend) Close() {
	srv := tb.server
	srv.Close()
}
// RequestCount reports how many requests the backend has logged so far. It
// reads the counter atomically and does not take the log mutex.
func (tb *TestBackend) RequestCount() int64 {
	n := atomic.LoadInt64(&tb.requestCount)
	return n
}
// LastRequest returns a copy of the most recently logged request, or nil if
// none has been received yet. The copy, including a cloned Headers map, is
// independent of the backend's log: callers may modify it freely, and it is
// unaffected by requests that arrive later. This matches AllRequests, which
// also hands out copies rather than views into the log.
func (tb *TestBackend) LastRequest() *RequestLogEntry {
	tb.mu.Lock()
	defer tb.mu.Unlock()
	if len(tb.requestLog) == 0 {
		return nil
	}
	last := tb.requestLog[len(tb.requestLog)-1]
	last.Headers = last.Headers.Clone()
	return &last
}
// AllRequests returns a snapshot of every logged request, in log order, taken
// while holding the log mutex. The returned slice is never nil and is
// independent of later logging. The entries are shallow copies, so each
// entry's Headers map is shared with the backend's log.
func (tb *TestBackend) AllRequests() []RequestLogEntry {
	tb.mu.Lock()
	snapshot := append(make([]RequestLogEntry, 0, len(tb.requestLog)), tb.requestLog...)
	tb.mu.Unlock()
	return snapshot
}
// ============== HTTP Client Helpers ==============
// HTTPClient is a small test client that prefixes every request path with a
// base URL. Built by NewHTTPClient, it has a 10-second timeout and does not
// follow redirects.
type HTTPClient struct {
// client is the configured *http.Client used for every request.
client *http.Client
// baseURL is prepended verbatim to each request path.
baseURL string
}
// NewHTTPClient returns an HTTPClient that sends requests to baseURL+path.
// Each request times out after 10 seconds. Redirects are not followed: a 3xx
// response is returned to the caller as-is so tests can assert on it.
func NewHTTPClient(baseURL string) *HTTPClient {
	noRedirect := func(req *http.Request, via []*http.Request) error {
		return http.ErrUseLastResponse
	}
	c := &HTTPClient{baseURL: baseURL}
	c.client = &http.Client{
		Timeout:       10 * time.Second,
		CheckRedirect: noRedirect,
	}
	return c
}
// Get sends a bodiless GET request for path with the given headers, which may
// be nil. The caller owns the response and must close its body.
func (c *HTTPClient) Get(path string, headers map[string]string) (*http.Response, error) {
	const method = http.MethodGet
	return c.Do(method, path, nil, headers)
}
// Post sends a POST request for path carrying body, with the given headers
// (may be nil). The caller owns the response and must close its body.
func (c *HTTPClient) Post(path string, body []byte, headers map[string]string) (*http.Response, error) {
	const method = http.MethodPost
	return c.Do(method, path, body, headers)
}
// Do builds and sends a request for method to baseURL+path.
//
// When body is non-nil it becomes the request body. Each entry in headers is
// applied with Header.Set (later duplicates replace earlier ones, in map
// order). A "Host" entry, in any letter case, sets req.Host instead: net/http
// never writes a Host value stored in req.Header for outgoing requests, so
// setting it there would be silently ignored. This matters for proxy tests
// that exercise host-based routing or Host forwarding.
//
// The caller owns the returned response and must close its body.
func (c *HTTPClient) Do(method, path string, body []byte, headers map[string]string) (*http.Response, error) {
	// Keep bodyReader a nil interface (not a typed-nil *bytes.Reader) when
	// there is no body.
	var bodyReader io.Reader
	if body != nil {
		bodyReader = bytes.NewReader(body)
	}
	req, err := http.NewRequest(method, c.baseURL+path, bodyReader)
	if err != nil {
		return nil, err
	}
	for k, v := range headers {
		if http.CanonicalHeaderKey(k) == "Host" {
			req.Host = v
			continue
		}
		req.Header.Set(k, v)
	}
	return c.client.Do(req)
}
// GetJSON GETs path with no extra headers and decodes the JSON response body
// into result.
//
// The status code is not checked, so error bodies are decoded too. The body is
// closed before returning, so the returned response is only useful for its
// status and headers. On a transport error the response is nil; on a decode
// error both the response and the error are returned.
func (c *HTTPClient) GetJSON(path string, result interface{}) (*http.Response, error) {
	resp, err := c.Get(path, nil)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	return resp, json.NewDecoder(resp.Body).Decode(result)
}
// ============== File System Helpers ==============
// CreateTempDir returns a fresh, unique temporary directory (for example a
// static-file root). It is removed automatically when the test and all its
// subtests finish. It delegates to t.TempDir, which aborts the test if the
// directory cannot be created. Unlike the previous os.RemoveAll in t.Cleanup,
// whose error was discarded, t.TempDir reports a failed removal instead of
// silently leaving the directory behind.
func CreateTempDir(t *testing.T) string {
	t.Helper()
	return t.TempDir()
}
// CreateTempFile writes content to dir/name and returns the joined path. Any
// missing parent directories are created with mode 0755, and the file is
// written (or truncated) with mode 0644. name may contain subdirectories. Any
// filesystem error aborts the test with t.Fatalf.
func CreateTempFile(t *testing.T, dir, name, content string) string {
	t.Helper()

	target := filepath.Join(dir, name)
	parent := filepath.Dir(target)
	if err := os.MkdirAll(parent, 0755); err != nil {
		t.Fatalf("Failed to create directories: %v", err)
	}

	data := []byte(content)
	if err := os.WriteFile(target, data, 0644); err != nil {
		t.Fatalf("Failed to write file: %v", err)
	}
	return target
}
// ============== Assertion Helpers ==============
// AssertStatus reports a non-fatal test error if resp's status code differs
// from expected. On a mismatch it reads (and thereby consumes) the remaining
// body so the failure message shows what the server sent. On a match the body
// is left untouched.
func AssertStatus(t *testing.T, resp *http.Response, expected int) {
	t.Helper()
	if resp.StatusCode == expected {
		return
	}
	// Best effort: a read error just yields a shorter (possibly empty) body.
	body, _ := io.ReadAll(resp.Body)
	t.Errorf("Expected status %d, got %d. Body: %s", expected, resp.StatusCode, string(body))
}
// AssertHeader reports a non-fatal test error unless the first value of the
// named response header equals expected exactly. The header name is matched
// case-insensitively via Header.Get, and a missing header reads as "".
func AssertHeader(t *testing.T, resp *http.Response, header, expected string) {
	t.Helper()
	if got := resp.Header.Get(header); got != expected {
		t.Errorf("Expected header %s=%q, got %q", header, expected, got)
	}
}
// AssertHeaderContains reports a non-fatal test error unless the named
// response header is present with a non-empty value that contains substring.
// An absent or empty header always fails, even when substring is "".
func AssertHeaderContains(t *testing.T, resp *http.Response, header, substring string) {
	t.Helper()
	got := resp.Header.Get(header)
	if got != "" && contains(got, substring) {
		return
	}
	t.Errorf("Expected header %s to contain %q, got %q", header, substring, got)
}
// AssertJSONField decodes body as a JSON object and reports a non-fatal test
// error unless it has a top-level field whose value equals expected. If body
// fails to decode into an object, the test is aborted with t.Fatalf.
//
// Values are compared as JSON rather than with ==. expected is marshaled and
// re-decoded the same way the body was, so:
//   - Go integers match JSON numbers, which decode as float64;
//   - slices, maps and structs compare by content, where == previously
//     panicked on uncomparable values such as []interface{}.
func AssertJSONField(t *testing.T, body []byte, field string, expected interface{}) {
	t.Helper()
	var data map[string]interface{}
	if err := json.Unmarshal(body, &data); err != nil {
		t.Fatalf("Failed to parse JSON: %v", err)
	}
	actual, ok := data[field]
	if !ok {
		t.Errorf("Field %q not found in JSON", field)
		return
	}

	// canonical returns v's JSON encoding after a decode round-trip, which
	// turns numbers into float64 and structs into maps. json.Marshal emits map
	// keys in sorted order, so equal values yield identical bytes.
	canonical := func(v interface{}) ([]byte, error) {
		raw, err := json.Marshal(v)
		if err != nil {
			return nil, err
		}
		var generic interface{}
		if err := json.Unmarshal(raw, &generic); err != nil {
			return nil, err
		}
		return json.Marshal(generic)
	}
	want, errWant := canonical(expected)
	got, errGot := canonical(actual)
	// An expected value that cannot be encoded as JSON can never match.
	if errWant != nil || errGot != nil || !bytes.Equal(want, got) {
		t.Errorf("Expected %s=%v (%T), got %v (%T)", field, expected, expected, actual, actual)
	}
}
// contains reports whether substr occurs anywhere in s, using a naive byte
// scan. The empty substring occurs in every string, including "".
// NOTE(review): this is equivalent to strings.Contains, which could replace it.
func contains(s, substr string) bool {
	n := len(substr)
	// The loop runs only when substr fits in s; an exact match is found at
	// offset 0, so no separate length or equality pre-checks are needed.
	for i := 0; i <= len(s)-n; i++ {
		if s[i:i+n] == substr {
			return true
		}
	}
	return false
}
// containsAt reports whether substr occurs in s at some byte offset >= start.
// It is a naive scan that compares each candidate window in turn; start is
// expected to be non-negative.
func containsAt(s, substr string, start int) bool {
	width := len(substr)
	last := len(s) - width // final offset at which substr still fits
	for pos := start; pos <= last; pos++ {
		if s[pos:pos+width] == substr {
			return true
		}
	}
	return false
}
// ReadBody reads resp.Body to EOF and returns its bytes, aborting the test
// with t.Fatalf on a read error. It does not close the body.
func ReadBody(t *testing.T, resp *http.Response) []byte {
	t.Helper()
	raw, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		t.Fatalf("Failed to read body: %v", readErr)
	}
	return raw
}