forked from aegis/pyserveX
- Introduced IgnoreRequestPath option in proxy configuration to allow exact match routing. - Implemented proxy_pass directive in routing extension to handle backend requests. - Enhanced error handling for backend unavailability and timeouts. - Added integration tests for reverse proxy, including basic requests, exact match routes, regex routes, header forwarding, and query string preservation. - Created helper functions for setting up test servers and backends, along with assertion utilities for response validation. - Updated server initialization to support extension management and middleware chaining. - Improved logging for debugging purposes during request handling.
409 lines
9.9 KiB
Go
409 lines
9.9 KiB
Go
package integration
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/konduktor/konduktor/internal/extension"
	"github.com/konduktor/konduktor/internal/logging"
	"github.com/konduktor/konduktor/internal/middleware"
)
|
|
|
|
// TestServer is a Konduktor server running on a loopback port for the
// duration of a test, as returned by StartTestServer.
type TestServer struct {
	// Server is the underlying net/http server; Close shuts it down.
	Server *http.Server
	// URL is the base address of the server, "http://127.0.0.1:<Port>".
	URL string
	// Port is the TCP port the server listens on.
	Port int

	listener net.Listener // socket the server was started on
	handler  http.Handler // fully wrapped handler chain the server runs
	t        *testing.T   // owning test, used to report failures
}
|
|
|
|
// TestBackend represents a mock backend server
|
|
type TestBackend struct {
|
|
server *httptest.Server
|
|
requestLog []RequestLogEntry
|
|
mu sync.Mutex
|
|
requestCount int64
|
|
handler http.HandlerFunc
|
|
}
|
|
|
|
// RequestLogEntry is a snapshot of one request as seen by a TestBackend.
type RequestLogEntry struct {
	// Method, Path and Query come from the request; Query is the raw,
	// still-encoded query string without the leading "?".
	Method, Path, Query string

	// Headers is a copy of the request headers.
	Headers http.Header

	// Body is the request body as read by the backend.
	Body string

	// Timestamp records when the backend logged the request.
	Timestamp time.Time
}
|
|
|
|
// ServerConfig holds configuration for starting a test server
|
|
type ServerConfig struct {
|
|
Extensions []extension.Extension
|
|
StaticDir string
|
|
Middleware []func(http.Handler) http.Handler
|
|
}
|
|
|
|
// ============== Test Server ==============
|
|
|
|
// StartTestServer creates and starts a Konduktor server for testing
|
|
func StartTestServer(t *testing.T, cfg *ServerConfig) *TestServer {
|
|
t.Helper()
|
|
|
|
logger, err := logging.New(logging.Config{
|
|
Level: "DEBUG",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create logger: %v", err)
|
|
}
|
|
|
|
// Create extension manager
|
|
extManager := extension.NewManager(logger)
|
|
|
|
// Add extensions if provided
|
|
if cfg != nil && len(cfg.Extensions) > 0 {
|
|
for _, ext := range cfg.Extensions {
|
|
if err := extManager.AddExtension(ext); err != nil {
|
|
t.Fatalf("Failed to add extension: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Create a fallback handler for when no extension handles the request
|
|
fallback := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
http.NotFound(w, r)
|
|
})
|
|
|
|
// Create handler chain
|
|
var handler http.Handler = extManager.Handler(fallback)
|
|
|
|
// Add middleware
|
|
handler = middleware.AccessLog(handler, logger)
|
|
handler = middleware.Recovery(handler, logger)
|
|
|
|
// Add custom middleware if provided
|
|
if cfg != nil {
|
|
for _, mw := range cfg.Middleware {
|
|
handler = mw(handler)
|
|
}
|
|
}
|
|
|
|
// Find available port
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("Failed to find available port: %v", err)
|
|
}
|
|
|
|
port := listener.Addr().(*net.TCPAddr).Port
|
|
|
|
server := &http.Server{
|
|
Handler: handler,
|
|
ReadTimeout: 10 * time.Second,
|
|
WriteTimeout: 10 * time.Second,
|
|
}
|
|
|
|
ts := &TestServer{
|
|
Server: server,
|
|
URL: fmt.Sprintf("http://127.0.0.1:%d", port),
|
|
Port: port,
|
|
listener: listener,
|
|
handler: handler,
|
|
t: t,
|
|
}
|
|
|
|
// Start server in goroutine
|
|
go func() {
|
|
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
|
// Don't fail test here as server might be intentionally closed
|
|
}
|
|
}()
|
|
|
|
// Wait for server to be ready
|
|
ts.waitReady()
|
|
|
|
return ts
|
|
}
|
|
|
|
// waitReady waits for the server to be ready to accept connections
|
|
func (ts *TestServer) waitReady() {
|
|
deadline := time.Now().Add(5 * time.Second)
|
|
for time.Now().Before(deadline) {
|
|
conn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", ts.Port), 100*time.Millisecond)
|
|
if err == nil {
|
|
conn.Close()
|
|
return
|
|
}
|
|
time.Sleep(10 * time.Millisecond)
|
|
}
|
|
ts.t.Fatal("Server failed to start within timeout")
|
|
}
|
|
|
|
// Close shuts down the test server
|
|
func (ts *TestServer) Close() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
ts.Server.Shutdown(ctx)
|
|
}
|
|
|
|
// ============== Test Backend ==============
|
|
|
|
// StartBackend creates and starts a mock backend server
|
|
func StartBackend(handler http.HandlerFunc) *TestBackend {
|
|
tb := &TestBackend{
|
|
requestLog: make([]RequestLogEntry, 0),
|
|
handler: handler,
|
|
}
|
|
|
|
if handler == nil {
|
|
handler = tb.defaultHandler
|
|
}
|
|
|
|
tb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
tb.logRequest(r)
|
|
handler(w, r)
|
|
}))
|
|
|
|
return tb
|
|
}
|
|
|
|
func (tb *TestBackend) logRequest(r *http.Request) {
|
|
tb.mu.Lock()
|
|
defer tb.mu.Unlock()
|
|
|
|
body, _ := io.ReadAll(r.Body)
|
|
r.Body = io.NopCloser(bytes.NewReader(body))
|
|
|
|
tb.requestLog = append(tb.requestLog, RequestLogEntry{
|
|
Method: r.Method,
|
|
Path: r.URL.Path,
|
|
Query: r.URL.RawQuery,
|
|
Headers: r.Header.Clone(),
|
|
Body: string(body),
|
|
Timestamp: time.Now(),
|
|
})
|
|
atomic.AddInt64(&tb.requestCount, 1)
|
|
}
|
|
|
|
func (tb *TestBackend) defaultHandler(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"backend": "default",
|
|
"path": r.URL.Path,
|
|
"method": r.Method,
|
|
"query": r.URL.RawQuery,
|
|
"received": time.Now().Unix(),
|
|
})
|
|
}
|
|
|
|
// URL returns the backend server URL
|
|
func (tb *TestBackend) URL() string {
|
|
return tb.server.URL
|
|
}
|
|
|
|
// Close shuts down the backend server
|
|
func (tb *TestBackend) Close() {
|
|
tb.server.Close()
|
|
}
|
|
|
|
// RequestCount returns the number of requests received
|
|
func (tb *TestBackend) RequestCount() int64 {
|
|
return atomic.LoadInt64(&tb.requestCount)
|
|
}
|
|
|
|
// LastRequest returns the most recent request
|
|
func (tb *TestBackend) LastRequest() *RequestLogEntry {
|
|
tb.mu.Lock()
|
|
defer tb.mu.Unlock()
|
|
if len(tb.requestLog) == 0 {
|
|
return nil
|
|
}
|
|
return &tb.requestLog[len(tb.requestLog)-1]
|
|
}
|
|
|
|
// AllRequests returns all logged requests
|
|
func (tb *TestBackend) AllRequests() []RequestLogEntry {
|
|
tb.mu.Lock()
|
|
defer tb.mu.Unlock()
|
|
result := make([]RequestLogEntry, len(tb.requestLog))
|
|
copy(result, tb.requestLog)
|
|
return result
|
|
}
|
|
|
|
// ============== HTTP Client Helpers ==============
|
|
|
|
// HTTPClient is a small wrapper around *http.Client for integration tests.
// Every request path is appended to baseURL.
type HTTPClient struct {
	client  *http.Client // underlying client; NewHTTPClient sets timeout and redirect policy
	baseURL string       // prefix prepended to each request path
}
|
|
|
|
// NewHTTPClient creates a new test HTTP client
|
|
func NewHTTPClient(baseURL string) *HTTPClient {
|
|
return &HTTPClient{
|
|
client: &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
return http.ErrUseLastResponse // Don't follow redirects
|
|
},
|
|
},
|
|
baseURL: baseURL,
|
|
}
|
|
}
|
|
|
|
// Get performs a GET request
|
|
func (c *HTTPClient) Get(path string, headers map[string]string) (*http.Response, error) {
|
|
return c.Do("GET", path, nil, headers)
|
|
}
|
|
|
|
// Post performs a POST request
|
|
func (c *HTTPClient) Post(path string, body []byte, headers map[string]string) (*http.Response, error) {
|
|
return c.Do("POST", path, body, headers)
|
|
}
|
|
|
|
// Do performs an HTTP request
|
|
func (c *HTTPClient) Do(method, path string, body []byte, headers map[string]string) (*http.Response, error) {
|
|
var bodyReader io.Reader
|
|
if body != nil {
|
|
bodyReader = bytes.NewReader(body)
|
|
}
|
|
|
|
req, err := http.NewRequest(method, c.baseURL+path, bodyReader)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for k, v := range headers {
|
|
req.Header.Set(k, v)
|
|
}
|
|
|
|
return c.client.Do(req)
|
|
}
|
|
|
|
// GetJSON performs GET and decodes JSON response
|
|
func (c *HTTPClient) GetJSON(path string, result interface{}) (*http.Response, error) {
|
|
resp, err := c.Get(path, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(result); err != nil {
|
|
return resp, err
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
// ============== File System Helpers ==============
|
|
|
|
// CreateTempDir returns a fresh temporary directory (e.g. for static files)
// that is removed when the test and its subtests finish.
//
// It delegates to t.TempDir. That call gives each invocation a unique
// directory, removes it during test cleanup, and fails the test if removal
// fails, rather than silently ignoring the error.
func CreateTempDir(t *testing.T) string {
	t.Helper()
	return t.TempDir()
}
|
|
|
|
// CreateTempFile writes content to dir/name and returns the full path.
// name may contain subdirectories; missing parents are created with mode
// 0755, and the file itself is written with mode 0644. Any failure aborts
// the test.
func CreateTempFile(t *testing.T, dir, name, content string) string {
	t.Helper()

	target := filepath.Join(dir, name)
	parent := filepath.Dir(target)

	if mkErr := os.MkdirAll(parent, 0755); mkErr != nil {
		t.Fatalf("Failed to create directories: %v", mkErr)
	}
	if writeErr := os.WriteFile(target, []byte(content), 0644); writeErr != nil {
		t.Fatalf("Failed to write file: %v", writeErr)
	}

	return target
}
|
|
|
|
// ============== Assertion Helpers ==============
|
|
|
|
// AssertStatus reports a test error (without stopping the test) unless
// resp.StatusCode equals expected. The failure message includes the
// response body.
//
// To build that message the body is read to the end, and the original body
// is closed. It is then replaced with an in-memory copy, so later helpers
// such as ReadBody or AssertJSONField still see the full payload instead of
// an empty stream.
func AssertStatus(t *testing.T, resp *http.Response, expected int) {
	t.Helper()
	if resp.StatusCode == expected {
		return
	}

	original := resp.Body
	// Best effort: a partial body is still useful diagnostics.
	body, _ := io.ReadAll(original)
	_ = original.Close()
	resp.Body = io.NopCloser(bytes.NewReader(body))

	t.Errorf("Expected status %d, got %d. Body: %s", expected, resp.StatusCode, string(body))
}
|
|
|
|
// AssertHeader reports a test error unless the first value of the named
// response header equals expected. Header names are matched
// case-insensitively. A missing header reads as "", so expected == ""
// asserts that the header is absent or empty.
func AssertHeader(t *testing.T, resp *http.Response, header, expected string) {
	t.Helper()
	if got := resp.Header.Get(header); got != expected {
		t.Errorf("Expected header %s=%q, got %q", header, expected, got)
	}
}
|
|
|
|
// AssertHeaderContains reports a test error unless some non-empty value of
// the named response header contains substring.
//
// All values of the header are checked, not only the first. A header sent
// as several field lines (e.g. Vary, Via, X-Forwarded-For) therefore
// matches if any of its values contains substring. A missing or empty
// header always fails, even when substring is "".
func AssertHeaderContains(t *testing.T, resp *http.Response, header, substring string) {
	t.Helper()

	values := resp.Header.Values(header)
	for _, v := range values {
		if v != "" && strings.Contains(v, substring) {
			return
		}
	}

	t.Errorf("Expected header %s to contain %q, got %q", header, substring, strings.Join(values, ", "))
}
|
|
|
|
// AssertJSONField decodes body as a JSON object and reports a test error
// unless its top-level field equals expected. It calls t.Fatalf if body
// cannot be decoded into a JSON object, or if expected cannot be encoded
// as JSON.
//
// Values are compared by their canonical JSON encoding rather than with ==.
// Decoded JSON numbers are always float64, so an int expectation such as 200
// would never match with ==, and == panics on slices and maps. Any expected
// value that encodes to the same JSON as the field matches, e.g. 200, "ok",
// []int{1, 2} or map[string]bool{"ok": true}.
func AssertJSONField(t *testing.T, body []byte, field string, expected interface{}) {
	t.Helper()

	var data map[string]interface{}
	if err := json.Unmarshal(body, &data); err != nil {
		t.Fatalf("Failed to parse JSON: %v", err)
	}

	actual, ok := data[field]
	if !ok {
		t.Errorf("Field %q not found in JSON", field)
		return
	}

	// canonical round-trips v through a generic interface{} so both sides
	// share one representation (float64 numbers, sorted object keys).
	canonical := func(v interface{}) ([]byte, error) {
		raw, err := json.Marshal(v)
		if err != nil {
			return nil, err
		}
		var generic interface{}
		if err := json.Unmarshal(raw, &generic); err != nil {
			return nil, err
		}
		return json.Marshal(generic)
	}

	want, err := canonical(expected)
	if err != nil {
		t.Fatalf("Failed to encode expected value for %q: %v", field, err)
	}
	got, err := canonical(actual)
	if err != nil {
		t.Fatalf("Failed to re-encode field %q: %v", field, err)
	}

	if !bytes.Equal(got, want) {
		t.Errorf("Expected %s=%v, got %v", field, expected, actual)
	}
}
|
|
|
|
// contains reports whether substr occurs within s. An empty substr is
// contained in every string, including "".
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}
|
|
|
|
// containsAt reports whether substr occurs in s at some byte offset >= start.
// start is expected to be non-negative; a negative start panics on slicing.
func containsAt(s, substr string, start int) bool {
	width := len(substr)
	for pos := start; pos+width <= len(s); pos++ {
		if s[pos:pos+width] == substr {
			return true
		}
	}
	return false
}
|
|
|
|
// ReadBody reads resp.Body to the end and returns its bytes, aborting the
// test with t.Fatalf on a read error. It does not close the body; that
// remains the caller's responsibility.
func ReadBody(t *testing.T, resp *http.Response) []byte {
	t.Helper()

	data, readErr := io.ReadAll(resp.Body)
	if readErr == nil {
		return data
	}

	t.Fatalf("Failed to read body: %v", readErr)
	return data // unreachable: Fatalf stops the test goroutine
}
|