Илья Глазунов bd0b381195 feat(caching): Implement cache hit detection and response header management
- Added functionality to mark responses as cache hits to prevent incorrect X-Cache headers.
- Introduced setCacheHitFlag function to traverse response writer wrappers and set cache hit flag.
- Updated cachingResponseWriter to manage cache hit state and adjust X-Cache header accordingly.
- Enhanced ProcessRequest and ProcessResponse methods to utilize new caching logic.

feat(extension): Introduce ResponseWriterWrapper and ResponseFinalizer interfaces

- Added ResponseWriterWrapper interface for extensions to wrap response writers.
- Introduced ResponseFinalizer interface for finalizing responses after processing.

refactor(manager): Improve response writer wrapping and finalization

- Updated Manager.Handler to wrap response writers through all enabled extensions.
- Implemented finalization of response writers after processing requests.

test(caching): Add comprehensive integration tests for caching behavior

- Created caching_test.go with tests for cache hit/miss, TTL expiration, pattern-based caching, and more.
- Ensured that caching logic works correctly for various scenarios including query strings and error responses.

test(routing): Add integration tests for routing behavior

- Created routing_test.go with tests for route priority, case sensitivity, default routes, and return directives.
- Verified that routing behaves as expected with multiple regex routes and named groups.
2025-12-12 01:03:32 +03:00

467 lines
10 KiB
Go

package extension
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"io"
"net/http"
"regexp"
"strings"
"sync"
"time"
"github.com/konduktor/konduktor/internal/logging"
)
// CachingExtension is an in-memory HTTP response cache. Responses are captured
// by cachingResponseWriter and replayed from ProcessRequest until their TTL
// elapses.
type CachingExtension struct {
	BaseExtension

	// cache maps a cacheKey digest to the stored response.
	cache         map[string]*cacheEntry
	cachePatterns []*cachePattern // rules consulted by shouldCache and getTTL
	defaultTTL    time.Duration   // lifetime used when no pattern supplies one

	// Byte accounting for stored bodies; store evicts to stay within maxSize.
	maxSize     int
	currentSize int

	// mu guards cache, currentSize and the hit/miss counters.
	mu           sync.RWMutex
	hits, misses int64
}
// cacheEntry is one stored response plus its bookkeeping data.
type cacheEntry struct {
	key string // hex digest produced by cacheKey

	// Captured response.
	body        []byte
	headers     http.Header
	statusCode  int
	contentType string

	// Lifetime and accounting.
	createdAt time.Time
	expiresAt time.Time
	size      int // len(body) in bytes
}
// cachePattern is a compiled caching rule. A request is admitted when its path
// matches pattern and its method appears in methods; ttl is the lifetime that
// getTTL hands out for paths matching pattern.
type cachePattern struct {
	pattern *regexp.Regexp
	ttl     time.Duration
	methods []string // upper-case HTTP method names
}
// CachingConfig mirrors the YAML configuration of the caching extension.
type CachingConfig struct {
	// Enabled maps to the `enabled` key (presumably switches the extension on/off).
	Enabled bool `yaml:"enabled"`
	// DefaultTTL is a Go duration string such as "5m" or "1h".
	DefaultTTL string `yaml:"default_ttl"`
	// MaxSizeMB is the total cache size budget in megabytes.
	MaxSizeMB int `yaml:"max_size_mb"`
	// CachePatterns lists the rules selecting which requests get cached.
	CachePatterns []PatternConfig `yaml:"cache_patterns"`
}
// PatternConfig is one entry of CachingConfig.CachePatterns.
type PatternConfig struct {
	// Pattern is a regular expression matched against the request path.
	Pattern string `yaml:"pattern"`
	// TTL is a Go duration string overriding the default TTL for this rule.
	TTL string `yaml:"ttl"`
	// Methods lists the HTTP methods this rule caches (GET when omitted).
	Methods []string `yaml:"methods"`
}
// NewCachingExtension builds a CachingExtension from a raw config map and
// starts its background expiry loop.
//
// Recognised keys:
//   - "default_ttl" (string, Go duration): fallback TTL, 5m when absent.
//   - "max_size_mb" (int or float64): total body budget, 100 MB when absent.
//   - "cache_patterns" ([]interface{} of maps): each map needs "pattern"
//     (regex) and may hold "ttl" (duration string) and "methods"
//     ([]interface{} of strings, default ["GET"]).
//
// Malformed values are logged and skipped instead of failing construction;
// the returned error is currently always nil.
func NewCachingExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) {
	ext := &CachingExtension{
		BaseExtension: NewBaseExtension("caching", 20, logger),
		cache:         make(map[string]*cacheEntry),
		cachePatterns: make([]*cachePattern, 0),
		defaultTTL:    5 * time.Minute,
		maxSize:       100 * 1024 * 1024, // 100MB default
	}

	if ttl, ok := config["default_ttl"].(string); ok {
		if duration, err := time.ParseDuration(ttl); err == nil {
			ext.defaultTTL = duration
		} else {
			logger.Error("Invalid cache default_ttl", "ttl", ttl, "error", err)
		}
	}

	// YAML decoders yield int, JSON decoders yield float64.
	switch maxSize := config["max_size_mb"].(type) {
	case int:
		ext.maxSize = maxSize * 1024 * 1024
	case float64:
		ext.maxSize = int(maxSize) * 1024 * 1024
	}

	if patterns, ok := config["cache_patterns"].([]interface{}); ok {
		for i, p := range patterns {
			patternCfg, ok := p.(map[string]interface{})
			if !ok {
				continue
			}

			// A rule without a regex would leave cachePattern.pattern nil and
			// make shouldCache/getTTL panic on MatchString, so skip it.
			patternStr, ok := patternCfg["pattern"].(string)
			if !ok {
				logger.Error("Cache pattern missing 'pattern' field", "index", i)
				continue
			}
			re, err := regexp.Compile(patternStr)
			if err != nil {
				logger.Error("Invalid cache pattern", "pattern", patternStr, "error", err)
				continue
			}

			pattern := &cachePattern{
				pattern: re,
				ttl:     ext.defaultTTL,
				methods: []string{"GET"},
			}
			if ttl, ok := patternCfg["ttl"].(string); ok {
				if duration, err := time.ParseDuration(ttl); err == nil {
					pattern.ttl = duration
				} else {
					logger.Error("Invalid cache pattern TTL", "pattern", patternStr, "ttl", ttl, "error", err)
				}
			}
			if methods, ok := patternCfg["methods"].([]interface{}); ok {
				pattern.methods = make([]string, 0, len(methods))
				for _, m := range methods {
					if method, ok := m.(string); ok {
						pattern.methods = append(pattern.methods, strings.ToUpper(method))
					}
				}
			}
			ext.cachePatterns = append(ext.cachePatterns, pattern)
		}
	}

	// NOTE(review): nothing in this file stops this goroutine; it lives until
	// the process exits.
	go ext.cleanupLoop()
	return ext, nil
}
// ProcessRequest replays a fresh cached response for r, if one exists.
//
// It returns true when the response was written from cache (the request is
// fully handled) and false when processing should continue. Lookups are only
// attempted, and hits/misses only counted, for requests shouldCache accepts.
// The error result is always nil.
func (e *CachingExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
	if !e.shouldCache(r) {
		return false, nil
	}

	key := e.cacheKey(r)
	e.mu.RLock()
	entry, exists := e.cache[key]
	e.mu.RUnlock()

	if !exists || !time.Now().Before(entry.expiresAt) {
		e.mu.Lock()
		e.misses++
		e.mu.Unlock()
		return false, nil
	}

	e.mu.Lock()
	e.hits++
	e.mu.Unlock()

	// Flag the wrapping cachingResponseWriter (if any) so it does not stamp
	// X-Cache: MISS on this replayed response.
	setCacheHitFlag(w)

	h := w.Header()
	for k, values := range entry.headers {
		// Replace rather than append: a header already set on w earlier in the
		// chain would otherwise be sent twice (once live, once from cache).
		h.Del(k)
		for _, v := range values {
			h.Add(k, v)
		}
	}
	h.Set("X-Cache", "HIT")
	// Do not emit an empty Content-Type when none was captured.
	if entry.contentType != "" {
		h.Set("Content-Type", entry.contentType)
	}
	w.WriteHeader(entry.statusCode)
	// A failed write (e.g. client went away) is not actionable here.
	_, _ = w.Write(entry.body)

	e.logger.Debug("Cache hit", "key", key, "path", r.URL.Path)
	return true, nil
}
// setCacheHitFlag walks w and the writers it wraps (through an
// Unwrap() http.ResponseWriter method) and calls SetCacheHit on the first
// *cachingResponseWriter found. It does nothing when there is none.
func setCacheHitFlag(w http.ResponseWriter) {
	type unwrapper interface {
		Unwrap() http.ResponseWriter
	}
	for w != nil {
		if cw, ok := w.(*cachingResponseWriter); ok {
			cw.SetCacheHit()
			return
		}
		next, ok := w.(unwrapper)
		if !ok {
			return
		}
		w = next.Unwrap()
	}
}
// ProcessResponse deliberately does nothing. Capturing the body for the cache
// and setting X-Cache: MISS both happen inside cachingResponseWriter.
func (e *CachingExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
}
// WrapResponseWriter returns w untouched for requests this extension does not
// cache. For cacheable ones it returns a cachingResponseWriter that copies the
// body into a fresh buffer so Finalize can store it.
func (e *CachingExtension) WrapResponseWriter(w http.ResponseWriter, r *http.Request) http.ResponseWriter {
	if e.shouldCache(r) {
		return &cachingResponseWriter{
			ResponseWriter: w,
			ext:            e,
			request:        r,
			buffer:         new(bytes.Buffer),
		}
	}
	return w
}
// cachingResponseWriter forwards to the wrapped ResponseWriter while keeping a
// copy of the body for the cache.
type cachingResponseWriter struct {
	http.ResponseWriter

	ext     *CachingExtension
	request *http.Request
	buffer  *bytes.Buffer // body bytes collected for Finalize

	statusCode  int
	wroteHeader bool
	cacheHit    bool // set by SetCacheHit when the response is replayed from cache
}
// WriteHeader records the final status code and forwards it to the wrapped
// writer exactly once; later calls are ignored.
//
// Informational 1xx codes other than 101 (e.g. 103 Early Hints) are passed
// through without being recorded. net/http allows several of them before the
// final status, so treating one as final would swallow the real status and
// leave statusCode outside the cacheable range.
//
// Unless the response is flagged as a cache hit, X-Cache: MISS is set just
// before the final status goes out.
func (cw *cachingResponseWriter) WriteHeader(code int) {
	if cw.wroteHeader {
		return
	}
	if code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols {
		cw.ResponseWriter.WriteHeader(code)
		return
	}
	cw.statusCode = code
	cw.wroteHeader = true
	if !cw.cacheHit {
		cw.ResponseWriter.Header().Set("X-Cache", "MISS")
	}
	cw.ResponseWriter.WriteHeader(code)
}
// SetCacheHit flags the response as replayed from the cache, which stops
// WriteHeader from adding X-Cache: MISS. Calling it repeatedly is harmless.
func (cw *cachingResponseWriter) SetCacheHit() {
	cw.cacheHit = true
}
// Write sends b to the client, emitting an implicit 200 status first if
// WriteHeader has not been called.
//
// The bytes are also copied into the buffer for Finalize, except when they
// could never be stored:
//   - the response is being replayed from the cache (re-storing it would
//     refresh the entry's expiry on every hit);
//   - the status is outside 200–399, the range Finalize accepts.
//
// Skipping the copy also avoids holding error or replayed bodies in memory.
func (cw *cachingResponseWriter) Write(b []byte) (int, error) {
	if !cw.wroteHeader {
		cw.WriteHeader(http.StatusOK)
	}
	if !cw.cacheHit && cw.statusCode >= 200 && cw.statusCode < 400 {
		cw.buffer.Write(b)
	}
	return cw.ResponseWriter.Write(b)
}
// Finalize stores the captured response in the cache once the handler has
// finished. Nothing is stored when:
//   - the response was itself replayed from the cache; re-storing it would
//     push expiresAt forward on every hit, so a hot entry would never expire;
//   - the status is outside 200–399 or the captured body is empty;
//   - the response carries Cache-Control no-store or private, which a shared
//     cache must not store (RFC 9111 §5.2.2.5, §5.2.2.7).
func (cw *cachingResponseWriter) Finalize() {
	if cw.cacheHit {
		return
	}
	if cw.statusCode < 200 || cw.statusCode >= 400 {
		return
	}
	body := cw.buffer.Bytes()
	if len(body) == 0 {
		return
	}
	header := cw.Header()
	if forbidsSharedStore(header) {
		return
	}

	key := cw.ext.cacheKey(cw.request)
	ttl := cw.ext.getTTL(cw.request)
	now := time.Now()
	cw.ext.store(&cacheEntry{
		key:         key,
		body:        body,
		headers:     header.Clone(),
		statusCode:  cw.statusCode,
		contentType: header.Get("Content-Type"),
		createdAt:   now,
		expiresAt:   now.Add(ttl),
		size:        len(body),
	})
}

// forbidsSharedStore reports whether h has a Cache-Control directive
// (no-store or private, in any case, with or without an argument) that
// forbids a shared cache from storing the response.
func forbidsSharedStore(h http.Header) bool {
	for _, line := range h.Values("Cache-Control") {
		for _, directive := range strings.Split(line, ",") {
			name := strings.SplitN(strings.TrimSpace(directive), "=", 2)[0]
			switch strings.ToLower(strings.TrimSpace(name)) {
			case "no-store", "private":
				return true
			}
		}
	}
	return false
}
// shouldCache reports whether r is eligible for caching.
//
// Requests carrying an Authorization header are never cached. The cache key
// does not include credentials, so storing such responses could serve one
// user's data to another (cf. RFC 9111 §3.5).
//
// Otherwise, when cache patterns are configured, r is cacheable only if some
// pattern matches its path and lists its method. With no patterns
// configured, every GET is cacheable.
func (e *CachingExtension) shouldCache(r *http.Request) bool {
	if r.Header.Get("Authorization") != "" {
		return false
	}

	path := r.URL.Path
	method := r.Method
	for _, pattern := range e.cachePatterns {
		if !pattern.pattern.MatchString(path) {
			continue
		}
		for _, m := range pattern.methods {
			if m == method {
				return true
			}
		}
	}
	// Fallback when no patterns are configured: cache plain GETs.
	return method == http.MethodGet && len(e.cachePatterns) == 0
}
// cacheKey derives the cache key for r. The key is the hex SHA-256 of the
// method, the Host, the request URL and the Accept-Encoding header.
//
// Each component is NUL-terminated so adjacent values cannot run together;
// without a separator "/a?x=1" + "gzip" and "/a?x=1gzip" + "" would hash
// identically. NUL cannot occur in valid methods, hosts, escaped URLs or
// header values. Host is included so virtual hosts served by the same
// process do not share entries.
func (e *CachingExtension) cacheKey(r *http.Request) string {
	h := sha256.New()
	parts := []string{r.Method, r.Host, r.URL.String(), r.Header.Get("Accept-Encoding")}
	for _, part := range parts {
		// hash.Hash.Write never returns an error.
		h.Write([]byte(part))
		h.Write([]byte{0})
	}
	return hex.EncodeToString(h.Sum(nil))
}
// getTTL returns the TTL of the first pattern that matches both r's path and
// r's method, or defaultTTL when none does. This is the same rule
// shouldCache uses to admit a request. Checking the path alone could pick
// the TTL of a pattern that does not cover r's method.
func (e *CachingExtension) getTTL(r *http.Request) time.Duration {
	path := r.URL.Path
	for _, pattern := range e.cachePatterns {
		if !pattern.pattern.MatchString(path) {
			continue
		}
		for _, m := range pattern.methods {
			if m == r.Method {
				return pattern.ttl
			}
		}
	}
	return e.defaultTTL
}
// store inserts entry and evicts the oldest entries until the total body size
// fits within maxSize.
//
// An existing entry under the same key is removed first, so its bytes neither
// count against the budget nor trigger eviction of unrelated entries. When
// maxSize is positive and the entry alone exceeds it, nothing is stored.
// Storing it would empty the cache and still overflow the budget.
func (e *CachingExtension) store(entry *cacheEntry) {
	e.mu.Lock()
	defer e.mu.Unlock()

	if e.maxSize > 0 && entry.size > e.maxSize {
		e.logger.Debug("Response too large to cache",
			"key", entry.key[:16],
			"size", entry.size,
			"max_size", e.maxSize)
		return
	}

	if existing, ok := e.cache[entry.key]; ok {
		e.currentSize -= existing.size
		delete(e.cache, entry.key)
	}

	for e.currentSize+entry.size > e.maxSize && len(e.cache) > 0 {
		e.evictOldest()
	}

	e.cache[entry.key] = entry
	e.currentSize += entry.size
	e.logger.Debug("Cached response",
		"key", entry.key[:16],
		"size", entry.size,
		"ttl", entry.expiresAt.Sub(entry.createdAt).String())
}
// evictOldest removes the entry with the earliest createdAt and subtracts its
// size from currentSize. It does nothing on an empty cache. It takes no lock
// itself; its caller in this file (store) already holds e.mu. Each call scans
// the whole map.
func (e *CachingExtension) evictOldest() {
	victim := ""
	var victimAt time.Time
	for k, ent := range e.cache {
		if victim != "" && !ent.createdAt.Before(victimAt) {
			continue
		}
		victim, victimAt = k, ent.createdAt
	}
	if victim == "" {
		return
	}
	doomed := e.cache[victim]
	e.currentSize -= doomed.size
	delete(e.cache, victim)
}
// cleanupLoop purges expired entries once a minute. The ticker channel is
// never closed, so this function never returns and the goroutine running it
// lives until the process exits.
func (e *CachingExtension) cleanupLoop() {
	tick := time.NewTicker(time.Minute)
	defer tick.Stop()
	for {
		<-tick.C
		e.cleanupExpired()
	}
}
// cleanupExpired deletes, under the write lock, every entry whose expiresAt
// lies strictly before the current time and releases its bytes from
// currentSize.
func (e *CachingExtension) cleanupExpired() {
	e.mu.Lock()
	defer e.mu.Unlock()

	cutoff := time.Now()
	for k, ent := range e.cache {
		if !ent.expiresAt.Before(cutoff) {
			continue
		}
		e.currentSize -= ent.size
		delete(e.cache, k)
	}
}
// Invalidate drops the entry stored under key (a cacheKey digest), if any,
// and releases its bytes from the size accounting.
func (e *CachingExtension) Invalidate(key string) {
	e.mu.Lock()
	defer e.mu.Unlock()

	victim, found := e.cache[key]
	if !found {
		return
	}
	delete(e.cache, key)
	e.currentSize -= victim.size
}
// InvalidatePattern removes every entry whose cache key matches the regular
// expression pattern. Where Invalidate removes a single exact key, this may
// remove many. The only error returned is a compile error for pattern, in
// which case the cache is left untouched.
//
// NOTE(review): cache keys are hex-encoded SHA-256 digests produced by
// cacheKey, not URLs or paths, so a pattern such as "^/api/" never matches.
// Invalidating by path would require storing the path in cacheEntry. Until
// then this is only useful for patterns over the digest (e.g. ".*").
func (e *CachingExtension) InvalidatePattern(pattern string) error {
	re, err := regexp.Compile(pattern)
	if err != nil {
		return err
	}

	e.mu.Lock()
	defer e.mu.Unlock()
	for key, entry := range e.cache {
		if re.MatchString(key) {
			e.currentSize -= entry.size
			delete(e.cache, key)
		}
	}
	return nil
}
// Clear drops every cached entry and resets the byte counter to zero. The
// hit/miss counters are left as they are.
func (e *CachingExtension) Clear() {
	e.mu.Lock()
	e.currentSize = 0
	e.cache = make(map[string]*cacheEntry)
	e.mu.Unlock()
}
// GetMetrics returns a snapshot of the cache statistics, taken under the read
// lock. "hit_rate" is the percentage (0–100) of counted lookups that were
// hits, and 0 before any lookup has been counted.
func (e *CachingExtension) GetMetrics() map[string]interface{} {
	e.mu.RLock()
	defer e.mu.RUnlock()

	lookups := e.hits + e.misses
	var hitRate float64
	if lookups > 0 {
		hitRate = float64(e.hits) / float64(lookups) * 100
	}

	metrics := make(map[string]interface{}, 8)
	metrics["entries"] = len(e.cache)
	metrics["size_bytes"] = e.currentSize
	metrics["max_size"] = e.maxSize
	metrics["hits"] = e.hits
	metrics["misses"] = e.misses
	metrics["hit_rate"] = hitRate
	metrics["patterns"] = len(e.cachePatterns)
	metrics["default_ttl"] = e.defaultTTL.String()
	return metrics
}
// Cleanup releases the cache contents by calling Clear. It always returns nil.
//
// NOTE(review): despite what the name suggests, this does NOT stop the
// background cleanupLoop goroutine started by NewCachingExtension. Nothing in
// this file can stop it; doing so would need a stop channel or context held
// on the extension.
func (e *CachingExtension) Cleanup() error {
	e.Clear()
	return nil
}
// CacheReader is an io.ReadCloser that records every byte read through it, so
// the complete body can be fetched with GetBody once reading is done.
//
// NOTE(review): buffer is unexported and nothing in this file sets it; a
// CacheReader whose buffer is nil will panic on the first non-empty Read.
type CacheReader struct {
	io.ReadCloser
	buffer *bytes.Buffer
}

// Read reads from the wrapped ReadCloser and appends the n bytes obtained to
// the buffer, even when err is non-nil. It returns n and err unchanged.
func (cr *CacheReader) Read(p []byte) (int, error) {
	n, err := cr.ReadCloser.Read(p)
	if n <= 0 {
		return n, err
	}
	cr.buffer.Write(p[:n])
	return n, err
}

// GetBody returns everything read so far. The slice aliases the internal
// buffer (see bytes.Buffer.Bytes), so it is only valid until the next Read.
func (cr *CacheReader) GetBody() []byte {
	return cr.buffer.Bytes()
}