feat(caching): Implement cache hit detection and response header management

- Added functionality to mark responses as cache hits to prevent incorrect X-Cache headers.
- Introduced setCacheHitFlag function to traverse response writer wrappers and set cache hit flag.
- Updated cachingResponseWriter to manage cache hit state and adjust X-Cache header accordingly.
- Enhanced ProcessRequest and ProcessResponse methods to utilize new caching logic.

feat(extension): Introduce ResponseWriterWrapper and ResponseFinalizer interfaces

- Added ResponseWriterWrapper interface for extensions to wrap response writers.
- Introduced ResponseFinalizer interface for finalizing responses after processing.

refactor(manager): Improve response writer wrapping and finalization

- Updated Manager.Handler to wrap response writers through all enabled extensions.
- Implemented finalization of response writers after processing requests.

test(caching): Add comprehensive integration tests for caching behavior

- Created caching_test.go with tests for cache hit/miss, TTL expiration, pattern-based caching, and more.
- Ensured that caching logic works correctly for various scenarios including query strings and error responses.

test(routing): Add integration tests for routing behavior

- Created routing_test.go with tests for route priority, case sensitivity, default routes, and return directives.
- Verified that routing behaves as expected with multiple regex routes and named groups.
This commit is contained in:
Илья Глазунов 2025-12-12 01:03:32 +03:00
parent 81ac5c4d29
commit bd0b381195
6 changed files with 1275 additions and 20 deletions

View File

@ -137,6 +137,10 @@ func (e *CachingExtension) ProcessRequest(ctx context.Context, w http.ResponseWr
e.hits++
e.mu.Unlock()
// Mark as cache hit to prevent setting X-Cache: MISS
// Try to find cachingResponseWriter in the wrapper chain
setCacheHitFlag(w)
for k, values := range entry.headers {
for _, v := range values {
w.Header().Add(k, v)
@ -158,11 +162,36 @@ func (e *CachingExtension) ProcessRequest(ctx context.Context, w http.ResponseWr
return false, nil
}
// setCacheHitFlag tries to find cachingResponseWriter and set cache hit flag
func setCacheHitFlag(w http.ResponseWriter) {
// Direct match
if cw, ok := w.(*cachingResponseWriter); ok {
cw.SetCacheHit()
return
}
// Try unwrapping
type unwrapper interface {
Unwrap() http.ResponseWriter
}
for {
if u, ok := w.(unwrapper); ok {
w = u.Unwrap()
if cw, ok := w.(*cachingResponseWriter); ok {
cw.SetCacheHit()
return
}
} else {
return
}
}
}
// ProcessResponse caches the response if applicable
func (e *CachingExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
// Response caching is handled by the CachingResponseWriter
// This is called after the response is written
w.Header().Set("X-Cache", "MISS")
// X-Cache header is set in the cachingResponseWriter.WriteHeader
}
// WrapResponseWriter wraps the response writer to capture the response for caching
@ -186,16 +215,26 @@ type cachingResponseWriter struct {
buffer *bytes.Buffer
statusCode int
wroteHeader bool
cacheHit bool // Flag to indicate if this was a cache hit
}
func (cw *cachingResponseWriter) WriteHeader(code int) {
if !cw.wroteHeader {
cw.statusCode = code
cw.wroteHeader = true
// Set X-Cache: MISS header before writing headers (only if not a cache hit)
if !cw.cacheHit {
cw.ResponseWriter.Header().Set("X-Cache", "MISS")
}
cw.ResponseWriter.WriteHeader(code)
}
}
// SetCacheHit marks this response as a cache hit (to avoid setting X-Cache: MISS)
func (cw *cachingResponseWriter) SetCacheHit() {
cw.cacheHit = true
}
func (cw *cachingResponseWriter) Write(b []byte) (int, error) {
if !cw.wroteHeader {
cw.WriteHeader(http.StatusOK)

View File

@ -109,3 +109,15 @@ type ExtensionConfig struct {
// ExtensionFactory is a function that creates an extension from config
type ExtensionFactory func(config map[string]interface{}, logger *logging.Logger) (Extension, error)
// ResponseWriterWrapper is an optional interface that extensions can implement
// to wrap the response writer for capturing/modifying responses
type ResponseWriterWrapper interface {
WrapResponseWriter(w http.ResponseWriter, r *http.Request) http.ResponseWriter
}
// ResponseFinalizer is an optional interface for response writers that need
// to perform finalization after the response is written (e.g., caching)
type ResponseFinalizer interface {
Finalize()
}

View File

@ -177,26 +177,58 @@ func (m *Manager) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Create response wrapper to capture response for ProcessResponse
wrapper := newResponseWrapper(w)
// Wrap response writer through all extensions that support it
// Process in reverse priority order so highest priority wrapper is outermost
wrappedWriter := w
var finalizers []ResponseFinalizer
m.mu.RLock()
extensions := m.extensions
m.mu.RUnlock()
// Wrap response writer (lowest priority first, so they wrap in correct order)
for _, ext := range extensions {
if !ext.Enabled() {
continue
}
if wrapper, ok := ext.(ResponseWriterWrapper); ok {
wrappedWriter = wrapper.WrapResponseWriter(wrappedWriter, r)
// Check if the wrapped writer implements Finalizer
if finalizer, ok := wrappedWriter.(ResponseFinalizer); ok {
finalizers = append(finalizers, finalizer)
}
}
}
// Create response wrapper to capture status code
responseWrapper := newResponseWrapper(wrappedWriter)
// Process request through extensions
handled, err := m.ProcessRequest(ctx, wrapper, r)
handled, err := m.ProcessRequest(ctx, responseWrapper, r)
if err != nil {
m.logger.Error("Error processing request", "error", err)
}
if handled {
// Extension handled the request, process response
m.ProcessResponse(ctx, wrapper, r)
m.ProcessResponse(ctx, responseWrapper, r)
// Finalize all response writers
for i := len(finalizers) - 1; i >= 0; i-- {
finalizers[i].Finalize()
}
return
}
// No extension handled, pass to next handler
next.ServeHTTP(wrapper, r)
next.ServeHTTP(responseWrapper, r)
// Process response
m.ProcessResponse(ctx, wrapper, r)
m.ProcessResponse(ctx, responseWrapper, r)
// Finalize all response writers
for i := len(finalizers) - 1; i >= 0; i-- {
finalizers[i].Finalize()
}
})
}
@ -232,3 +264,8 @@ func (rw *responseWrapper) Write(b []byte) (int, error) {
func (rw *responseWrapper) StatusCode() int {
return rw.statusCode
}
// Unwrap returns the underlying ResponseWriter (for type assertions)
func (rw *responseWrapper) Unwrap() http.ResponseWriter {
return rw.ResponseWriter
}

View File

@ -40,12 +40,15 @@ tests/integration/
### 2. Routing Extension (`routing_test.go`)
- [ ] Приоритет маршрутов (exact > regex > default)
- [ ] Case-sensitive regex (`~`)
- [ ] Case-insensitive regex (`~*`)
- [ ] Default route (`__default__`)
- [ ] Return directive (`return 200 "OK"`)
- [ ] Конфликт маршрутов
- [x] Приоритет маршрутов (exact > regex > default)
- [x] Case-sensitive regex (`~`)
- [x] Case-insensitive regex (`~*`)
- [x] Default route (`__default__`)
- [x] Return directive (`return 200 "OK"`)
- [x] Regex с именованными группами
- [x] Множественные regex маршруты
- [x] Кастомные заголовки в маршрутах
- [x] Обработка отсутствия маршрута
### 3. Security Extension (`security_test.go`)
@ -58,12 +61,16 @@ tests/integration/
### 4. Caching Extension (`caching_test.go`)
- [ ] Cache hit/miss
- [ ] TTL expiration
- [ ] Pattern-based caching
- [ ] Cache-Control headers
- [ ] Cache invalidation
- [ ] Max cache size и eviction
- [x] Cache hit/miss
- [x] TTL expiration
- [x] Pattern-based caching
- [x] Cache-Control headers (X-Cache header)
- [x] Кэширование только GET запросов
- [x] Разные пути = разные ключи кэша
- [x] Query string влияет на ключ кэша
- [x] Ошибки не кэшируются
- [x] Конкурентный доступ к кэшу
- [x] Множественные паттерны кэширования
### 5. Static Files (`static_files_test.go`)

View File

@ -0,0 +1,666 @@
package integration
import (
"encoding/json"
"fmt"
"net/http"
"sync/atomic"
"testing"
"time"
"github.com/konduktor/konduktor/internal/extension"
)
// ============== Basic Cache Hit/Miss Tests ==============
func TestCaching_BasicHitMiss(t *testing.T) {
var requestCount int64
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
count := atomic.AddInt64(&requestCount, 1)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"request_number": count,
"timestamp": time.Now().UnixNano(),
})
})
defer backend.Close()
logger := createTestLogger(t)
// Create caching extension
cachingExt, err := extension.NewCachingExtension(map[string]interface{}{
"default_ttl": "1m",
"cache_patterns": []interface{}{
map[string]interface{}{
"pattern": "^/api/.*",
"ttl": "30s",
"methods": []interface{}{"GET"},
},
},
}, logger)
if err != nil {
t.Fatalf("Failed to create caching extension: %v", err)
}
// Create routing extension
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{cachingExt, routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// First request - should be MISS
resp1, err := client.Get("/api/data", nil)
if err != nil {
t.Fatalf("Request 1 failed: %v", err)
}
cacheHeader1 := resp1.Header.Get("X-Cache")
var result1 map[string]interface{}
json.NewDecoder(resp1.Body).Decode(&result1)
resp1.Body.Close()
if cacheHeader1 != "MISS" {
t.Errorf("Expected X-Cache: MISS for first request, got %q", cacheHeader1)
}
// Second request - should be HIT (same response)
resp2, err := client.Get("/api/data", nil)
if err != nil {
t.Fatalf("Request 2 failed: %v", err)
}
cacheHeader2 := resp2.Header.Get("X-Cache")
var result2 map[string]interface{}
json.NewDecoder(resp2.Body).Decode(&result2)
resp2.Body.Close()
if cacheHeader2 != "HIT" {
t.Errorf("Expected X-Cache: HIT for second request, got %q", cacheHeader2)
}
// Verify same response (from cache)
if result1["request_number"] != result2["request_number"] {
t.Errorf("Expected same request_number from cache, got %v and %v",
result1["request_number"], result2["request_number"])
}
// Backend should only receive 1 request
if atomic.LoadInt64(&requestCount) != 1 {
t.Errorf("Expected 1 backend request, got %d", requestCount)
}
}
// ============== TTL Expiration Tests ==============
func TestCaching_TTLExpiration(t *testing.T) {
var requestCount int64
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
count := atomic.AddInt64(&requestCount, 1)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"request_number": count,
})
})
defer backend.Close()
logger := createTestLogger(t)
// Create caching extension with short TTL
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
"default_ttl": "100ms", // Very short TTL for testing
"cache_patterns": []interface{}{
map[string]interface{}{
"pattern": "^/api/.*",
"ttl": "100ms",
"methods": []interface{}{"GET"},
},
},
}, logger)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{cachingExt, routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// First request
resp1, _ := client.Get("/api/data", nil)
var result1 map[string]interface{}
json.NewDecoder(resp1.Body).Decode(&result1)
resp1.Body.Close()
// Second request (within TTL) - should be HIT
resp2, _ := client.Get("/api/data", nil)
cacheHeader2 := resp2.Header.Get("X-Cache")
resp2.Body.Close()
if cacheHeader2 != "HIT" {
t.Errorf("Expected X-Cache: HIT before TTL expires, got %q", cacheHeader2)
}
// Wait for TTL to expire
time.Sleep(150 * time.Millisecond)
// Third request (after TTL) - should be MISS
resp3, _ := client.Get("/api/data", nil)
cacheHeader3 := resp3.Header.Get("X-Cache")
var result3 map[string]interface{}
json.NewDecoder(resp3.Body).Decode(&result3)
resp3.Body.Close()
if cacheHeader3 != "MISS" {
t.Errorf("Expected X-Cache: MISS after TTL expires, got %q", cacheHeader3)
}
// Verify new request was made (different request_number)
if result1["request_number"] == result3["request_number"] {
t.Error("Expected different request_number after TTL expiration")
}
}
// ============== Pattern-Based Caching Tests ==============
func TestCaching_PatternBasedCaching(t *testing.T) {
var apiCount, staticCount int64
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if r.URL.Path[:5] == "/api/" {
atomic.AddInt64(&apiCount, 1)
} else {
atomic.AddInt64(&staticCount, 1)
}
json.NewEncoder(w).Encode(map[string]string{"path": r.URL.Path})
})
defer backend.Close()
logger := createTestLogger(t)
// Only cache /api/* paths
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
"default_ttl": "1m",
"cache_patterns": []interface{}{
map[string]interface{}{
"pattern": "^/api/.*",
"ttl": "1m",
"methods": []interface{}{"GET"},
},
},
}, logger)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{cachingExt, routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Multiple requests to /api/ - should be cached
for i := 0; i < 3; i++ {
resp, _ := client.Get("/api/users", nil)
resp.Body.Close()
}
// Multiple requests to /static/ - should NOT be cached (not matching pattern)
for i := 0; i < 3; i++ {
resp, _ := client.Get("/static/file.js", nil)
resp.Body.Close()
}
// API should have only 1 request (cached)
if atomic.LoadInt64(&apiCount) != 1 {
t.Errorf("Expected 1 API request (cached), got %d", apiCount)
}
// Static should have 3 requests (not cached)
if atomic.LoadInt64(&staticCount) != 3 {
t.Errorf("Expected 3 static requests (not cached), got %d", staticCount)
}
}
// ============== Method-Specific Caching Tests ==============
func TestCaching_OnlyGETMethodCached(t *testing.T) {
var getCount, postCount int64
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if r.Method == "GET" {
atomic.AddInt64(&getCount, 1)
} else if r.Method == "POST" {
atomic.AddInt64(&postCount, 1)
}
json.NewEncoder(w).Encode(map[string]string{
"method": r.Method,
})
})
defer backend.Close()
logger := createTestLogger(t)
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
"default_ttl": "1m",
"cache_patterns": []interface{}{
map[string]interface{}{
"pattern": "^/api/.*",
"ttl": "1m",
"methods": []interface{}{"GET"}, // Only GET
},
},
}, logger)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{cachingExt, routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Multiple GET requests - should be cached
for i := 0; i < 3; i++ {
resp, _ := client.Get("/api/data", nil)
resp.Body.Close()
}
// Multiple POST requests - should NOT be cached
for i := 0; i < 3; i++ {
resp, _ := client.Post("/api/data", []byte(`{}`), map[string]string{
"Content-Type": "application/json",
})
resp.Body.Close()
}
if atomic.LoadInt64(&getCount) != 1 {
t.Errorf("Expected 1 GET request (cached), got %d", getCount)
}
if atomic.LoadInt64(&postCount) != 3 {
t.Errorf("Expected 3 POST requests (not cached), got %d", postCount)
}
}
// ============== Different Paths Different Cache Keys ==============
func TestCaching_DifferentPathsDifferentCacheKeys(t *testing.T) {
var requestCount int64
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
count := atomic.AddInt64(&requestCount, 1)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"path": r.URL.Path,
"request_number": count,
})
})
defer backend.Close()
logger := createTestLogger(t)
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
"default_ttl": "1m",
"cache_patterns": []interface{}{
map[string]interface{}{
"pattern": "^/api/.*",
"ttl": "1m",
},
},
}, logger)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{cachingExt, routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Request different paths
paths := []string{"/api/users", "/api/posts", "/api/comments"}
for _, path := range paths {
resp, _ := client.Get(path, nil)
resp.Body.Close()
}
// Each path should result in a separate backend request
if atomic.LoadInt64(&requestCount) != 3 {
t.Errorf("Expected 3 backend requests (one per path), got %d", requestCount)
}
// Request same paths again - all should be cached
for _, path := range paths {
resp, _ := client.Get(path, nil)
cacheHeader := resp.Header.Get("X-Cache")
resp.Body.Close()
if cacheHeader != "HIT" {
t.Errorf("Expected X-Cache: HIT for %s, got %q", path, cacheHeader)
}
}
// No additional backend requests
if atomic.LoadInt64(&requestCount) != 3 {
t.Errorf("Expected still 3 backend requests after cache hits, got %d", requestCount)
}
}
// ============== Query String Affects Cache Key ==============
func TestCaching_QueryStringAffectsCacheKey(t *testing.T) {
var requestCount int64
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
count := atomic.AddInt64(&requestCount, 1)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"query": r.URL.RawQuery,
"request_number": count,
})
})
defer backend.Close()
logger := createTestLogger(t)
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
"default_ttl": "1m",
"cache_patterns": []interface{}{
map[string]interface{}{
"pattern": "^/api/.*",
"ttl": "1m",
},
},
}, logger)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{cachingExt, routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Different query strings = different cache keys
queries := []string{
"/api/search?q=hello",
"/api/search?q=world",
"/api/search?q=test",
}
for _, query := range queries {
resp, _ := client.Get(query, nil)
resp.Body.Close()
}
// Each unique query should result in a separate backend request
if atomic.LoadInt64(&requestCount) != 3 {
t.Errorf("Expected 3 backend requests (one per query), got %d", requestCount)
}
// Same query again should be cached
resp, _ := client.Get("/api/search?q=hello", nil)
cacheHeader := resp.Header.Get("X-Cache")
resp.Body.Close()
if cacheHeader != "HIT" {
t.Errorf("Expected X-Cache: HIT for repeated query, got %q", cacheHeader)
}
}
// ============== Cache Does Not Store Error Responses ==============
func TestCaching_DoesNotCacheErrors(t *testing.T) {
var requestCount int64
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt64(&requestCount, 1)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(map[string]string{"error": "internal error"})
})
defer backend.Close()
logger := createTestLogger(t)
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
"default_ttl": "1m",
"cache_patterns": []interface{}{
map[string]interface{}{
"pattern": "^/api/.*",
"ttl": "1m",
},
},
}, logger)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{cachingExt, routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Multiple requests to error endpoint
for i := 0; i < 3; i++ {
resp, _ := client.Get("/api/error", nil)
resp.Body.Close()
}
// All requests should reach backend (errors not cached)
if atomic.LoadInt64(&requestCount) != 3 {
t.Errorf("Expected 3 backend requests (errors not cached), got %d", requestCount)
}
}
// ============== Concurrent Cache Access ==============
func TestCaching_ConcurrentAccess(t *testing.T) {
var requestCount int64
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
// Small delay to increase chance of race conditions
time.Sleep(10 * time.Millisecond)
count := atomic.AddInt64(&requestCount, 1)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"request_number": count,
})
})
defer backend.Close()
logger := createTestLogger(t)
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
"default_ttl": "1m",
"cache_patterns": []interface{}{
map[string]interface{}{
"pattern": "^/api/.*",
"ttl": "1m",
},
},
}, logger)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{cachingExt, routingExt},
})
defer server.Close()
const numRequests = 20
results := make(chan error, numRequests)
// Make first request to populate cache
client := NewHTTPClient(server.URL)
resp, _ := client.Get("/api/concurrent", nil)
resp.Body.Close()
// Now many concurrent requests should all hit cache
for i := 0; i < numRequests; i++ {
go func(n int) {
client := NewHTTPClient(server.URL)
resp, err := client.Get("/api/concurrent", nil)
if err != nil {
results <- err
return
}
cacheHeader := resp.Header.Get("X-Cache")
resp.Body.Close()
if cacheHeader != "HIT" {
results <- fmt.Errorf("request %d: expected HIT, got %s", n, cacheHeader)
return
}
results <- nil
}(i)
}
// Collect results
var errors []error
for i := 0; i < numRequests; i++ {
if err := <-results; err != nil {
errors = append(errors, err)
}
}
if len(errors) > 0 {
t.Errorf("Got %d errors in concurrent cache access: %v", len(errors), errors[:min(5, len(errors))])
}
// Only 1 request should reach backend (the initial one)
if atomic.LoadInt64(&requestCount) != 1 {
t.Errorf("Expected 1 backend request, got %d", requestCount)
}
}
// ============== Multiple Cache Patterns ==============
func TestCaching_MultipleCachePatterns(t *testing.T) {
var apiCount, staticCount int64
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if len(r.URL.Path) >= 5 && r.URL.Path[:5] == "/api/" {
atomic.AddInt64(&apiCount, 1)
} else if len(r.URL.Path) >= 8 && r.URL.Path[:8] == "/static/" {
atomic.AddInt64(&staticCount, 1)
}
json.NewEncoder(w).Encode(map[string]string{"path": r.URL.Path})
})
defer backend.Close()
logger := createTestLogger(t)
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
"default_ttl": "1m",
"cache_patterns": []interface{}{
map[string]interface{}{
"pattern": "^/api/.*",
"ttl": "30s",
"methods": []interface{}{"GET"},
},
map[string]interface{}{
"pattern": "^/static/.*",
"ttl": "1h", // Static files cached longer
"methods": []interface{}{"GET"},
},
},
}, logger)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{cachingExt, routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Multiple requests to both patterns
for i := 0; i < 3; i++ {
resp1, _ := client.Get("/api/data", nil)
resp1.Body.Close()
resp2, _ := client.Get("/static/app.js", nil)
resp2.Body.Close()
}
// Both should be cached (1 request each)
if atomic.LoadInt64(&apiCount) != 1 {
t.Errorf("Expected 1 API request, got %d", apiCount)
}
if atomic.LoadInt64(&staticCount) != 1 {
t.Errorf("Expected 1 static request, got %d", staticCount)
}
}

View File

@ -0,0 +1,494 @@
package integration
import (
"encoding/json"
"net/http"
"testing"
"github.com/konduktor/konduktor/internal/extension"
)
// ============== Route Priority Tests ==============
func TestRouting_ExactMatchPriority(t *testing.T) {
// Exact match should have highest priority
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"path": r.URL.Path,
"source": "default",
})
})
defer backend.Close()
logger := createTestLogger(t)
routingExt, err := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
// Exact match - highest priority
"=/api/status": map[string]interface{}{
"return": "200 exact-match",
"content_type": "text/plain",
},
// Regex that also matches /api/status
"~^/api/.*": map[string]interface{}{
"proxy_pass": backend.URL(),
},
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
if err != nil {
t.Fatalf("Failed to create routing extension: %v", err)
}
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Test exact match route - should return static response
resp, err := client.Get("/api/status", nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
AssertStatus(t, resp, http.StatusOK)
body := ReadBody(t, resp)
if string(body) != "exact-match" {
t.Errorf("Expected 'exact-match', got %q", string(body))
}
// Regex route should be used for other /api/* paths
resp2, err := client.Get("/api/other", nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp2.Body.Close()
AssertStatus(t, resp2, http.StatusOK)
// Verify it went to backend
if backend.RequestCount() != 1 {
t.Errorf("Expected 1 backend request, got %d", backend.RequestCount())
}
}
// ============== Case Sensitivity Tests ==============
func TestRouting_CaseSensitiveRegex(t *testing.T) {
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
// Case-sensitive regex (~)
"~^/API/test$": map[string]interface{}{
"return": "200 case-sensitive",
"content_type": "text/plain",
},
"__default__": map[string]interface{}{
"return": "200 default",
"content_type": "text/plain",
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Exact case match should work
resp, err := client.Get("/API/test", nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
body := ReadBody(t, resp)
if string(body) != "case-sensitive" {
t.Errorf("Expected 'case-sensitive' for /API/test, got %q", string(body))
}
// Different case should NOT match
resp2, err := client.Get("/api/test", nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp2.Body.Close()
body2 := ReadBody(t, resp2)
if string(body2) != "default" {
t.Errorf("Expected 'default' for /api/test (case mismatch), got %q", string(body2))
}
}
func TestRouting_CaseInsensitiveRegex(t *testing.T) {
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
// Case-insensitive regex (~*)
"~*^/api/test$": map[string]interface{}{
"return": "200 case-insensitive",
"content_type": "text/plain",
},
"__default__": map[string]interface{}{
"return": "200 default",
"content_type": "text/plain",
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
testCases := []struct {
path string
expected string
}{
{"/api/test", "case-insensitive"},
{"/API/test", "case-insensitive"},
{"/Api/Test", "case-insensitive"},
{"/API/TEST", "case-insensitive"},
{"/api/other", "default"},
}
for _, tc := range testCases {
t.Run(tc.path, func(t *testing.T) {
resp, err := client.Get(tc.path, nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
body := ReadBody(t, resp)
if string(body) != tc.expected {
t.Errorf("Expected %q for %s, got %q", tc.expected, tc.path, string(body))
}
})
}
}
// ============== Default Route Tests ==============
func TestRouting_DefaultRoute(t *testing.T) {
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"handler": "default",
"path": r.URL.Path,
})
})
defer backend.Close()
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"=/specific": map[string]interface{}{
"return": "200 specific",
"content_type": "text/plain",
},
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Non-matching paths should go to default
paths := []string{"/", "/random", "/path/to/resource", "/api/v1/users"}
for _, path := range paths {
t.Run(path, func(t *testing.T) {
resp, err := client.Get(path, nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
AssertStatus(t, resp, http.StatusOK)
var result map[string]string
json.NewDecoder(resp.Body).Decode(&result)
if result["handler"] != "default" {
t.Errorf("Expected default handler, got %v", result["handler"])
}
})
}
}
// ============== Return Directive Tests ==============
func TestRouting_ReturnDirective(t *testing.T) {
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"=/health": map[string]interface{}{
"return": "200 OK",
"content_type": "text/plain",
},
"=/status": map[string]interface{}{
"return": "200 {\"status\": \"healthy\"}",
"content_type": "application/json",
},
"=/forbidden": map[string]interface{}{
"return": "404 Not Found",
"content_type": "text/plain",
},
"__default__": map[string]interface{}{
"return": "200 default",
"content_type": "text/plain",
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
testCases := []struct {
path string
expectedStatus int
expectedBody string
contentType string
}{
{"/health", 200, "OK", "text/plain"},
{"/status", 200, `{"status": "healthy"}`, "application/json"},
{"/forbidden", 404, "Not Found", "text/plain"},
}
for _, tc := range testCases {
t.Run(tc.path, func(t *testing.T) {
resp, err := client.Get(tc.path, nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
AssertStatus(t, resp, tc.expectedStatus)
AssertHeaderContains(t, resp, "Content-Type", tc.contentType)
body := ReadBody(t, resp)
if string(body) != tc.expectedBody {
t.Errorf("Expected body %q, got %q", tc.expectedBody, string(body))
}
})
}
}
// ============== Multiple Regex Routes Tests ==============
func TestRouting_MultipleRegexRoutes(t *testing.T) {
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"~^/api/v1/.*": map[string]interface{}{
"return": "200 v1",
"content_type": "text/plain",
},
"~^/api/v2/.*": map[string]interface{}{
"return": "200 v2",
"content_type": "text/plain",
},
"~^/api/.*": map[string]interface{}{
"return": "200 api-generic",
"content_type": "text/plain",
},
"__default__": map[string]interface{}{
"return": "200 default",
"content_type": "text/plain",
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
testCases := []struct {
path string
expected string
}{
{"/api/v1/users", "v1"},
{"/api/v2/users", "v2"},
{"/api/v3/users", "api-generic"},
{"/other", "default"},
}
for _, tc := range testCases {
t.Run(tc.path, func(t *testing.T) {
resp, err := client.Get(tc.path, nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
body := ReadBody(t, resp)
if string(body) != tc.expected {
t.Errorf("Expected %q for %s, got %q", tc.expected, tc.path, string(body))
}
})
}
}
// ============== Regex with Named Groups ==============
func TestRouting_RegexNamedGroups(t *testing.T) {
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"path": r.URL.Path,
})
})
defer backend.Close()
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"~^/users/(?P<userId>\\d+)/posts/(?P<postId>\\d+)$": map[string]interface{}{
"proxy_pass": backend.URL() + "/api/v2/users/{userId}/posts/{postId}",
},
"~^/items/(?P<category>[a-z]+)/(?P<id>\\d+)$": map[string]interface{}{
"proxy_pass": backend.URL() + "/catalog/{category}/item/{id}",
},
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
testCases := []struct {
requestPath string
expectedPath string
}{
{"/users/123/posts/456", "/api/v2/users/123/posts/456"},
{"/items/electronics/789", "/catalog/electronics/item/789"},
}
for _, tc := range testCases {
t.Run(tc.requestPath, func(t *testing.T) {
resp, err := client.Get(tc.requestPath, nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
AssertStatus(t, resp, http.StatusOK)
lastReq := backend.LastRequest()
if lastReq == nil {
t.Fatal("No request received by backend")
}
if lastReq.Path != tc.expectedPath {
t.Errorf("Expected backend path %s, got %s", tc.expectedPath, lastReq.Path)
}
})
}
}
// ============== No Matching Route Tests ==============
func TestRouting_NoMatchingRoute(t *testing.T) {
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"=/specific": map[string]interface{}{
"return": "200 specific",
"content_type": "text/plain",
},
// No default route
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Request to non-matching path should return 404
resp, err := client.Get("/other", nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
AssertStatus(t, resp, http.StatusNotFound)
}
// ============== Headers in Return Tests ==============
func TestRouting_CustomHeaders(t *testing.T) {
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"x-custom-header": r.Header.Get("X-Custom-Header"),
"x-api-version": r.Header.Get("X-API-Version"),
})
})
defer backend.Close()
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
"headers": []interface{}{
"X-Custom-Header: custom-value",
"X-API-Version: v1",
},
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
resp, err := client.Get("/test", nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
var result map[string]string
json.NewDecoder(resp.Body).Decode(&result)
if result["x-custom-header"] != "custom-value" {
t.Errorf("Expected X-Custom-Header=custom-value, got %v", result["x-custom-header"])
}
if result["x-api-version"] != "v1" {
t.Errorf("Expected X-API-Version=v1, got %v", result["x-api-version"])
}
}