package extension

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"io"
	"net/http"
	"regexp"
	"strings"
	"sync"
	"time"

	"github.com/konduktor/konduktor/internal/logging"
)

// CachingExtension is an in-memory HTTP response cache. It serves stored
// responses on ProcessRequest and captures upstream responses through a
// wrapping ResponseWriter. A background goroutine sweeps expired entries
// once a minute until Cleanup is called.
type CachingExtension struct {
	BaseExtension

	cache         map[string]*cacheEntry
	cachePatterns []*cachePattern
	defaultTTL    time.Duration
	maxSize       int // max total cached body bytes
	currentSize   int // current total cached body bytes; guarded by mu
	mu            sync.RWMutex

	// Hit/miss counters; guarded by mu.
	hits   int64
	misses int64

	stop     chan struct{} // closed by Cleanup to end cleanupLoop
	stopOnce sync.Once     // makes Cleanup idempotent
}

// cacheEntry is an immutable snapshot of a cached response. Entries are
// never mutated after store(), so they may be read outside the lock.
type cacheEntry struct {
	key         string
	url         string // original request URL, used by InvalidatePattern
	body        []byte
	headers     http.Header
	statusCode  int
	contentType string
	createdAt   time.Time
	expiresAt   time.Time
	size        int // len(body), tracked for eviction accounting
}

// cachePattern selects which requests are cacheable and with what TTL.
type cachePattern struct {
	pattern *regexp.Regexp
	ttl     time.Duration
	methods []string
}

// CachingConfig is the YAML shape of this extension's configuration.
type CachingConfig struct {
	Enabled       bool            `yaml:"enabled"`
	DefaultTTL    string          `yaml:"default_ttl"`    // e.g., "5m", "1h"
	MaxSizeMB     int             `yaml:"max_size_mb"`    // Max cache size in MB
	CachePatterns []PatternConfig `yaml:"cache_patterns"` // Patterns to cache
}

// PatternConfig is one entry of CachingConfig.CachePatterns.
type PatternConfig struct {
	Pattern string   `yaml:"pattern"` // Regex pattern
	TTL     string   `yaml:"ttl"`     // TTL for this pattern
	Methods []string `yaml:"methods"` // HTTP methods to cache (default: GET)
}

// NewCachingExtension builds a caching extension from a loosely-typed config
// map (as produced by YAML decoding) and starts the background cleanup loop.
// Recognized keys: "default_ttl" (duration string), "max_size_mb" (int or
// float64), "cache_patterns" (list of pattern/ttl/methods maps). Invalid
// patterns are logged and skipped rather than failing construction.
func NewCachingExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) {
	ext := &CachingExtension{
		BaseExtension: NewBaseExtension("caching", 20, logger),
		cache:         make(map[string]*cacheEntry),
		cachePatterns: make([]*cachePattern, 0),
		defaultTTL:    5 * time.Minute,
		maxSize:       100 * 1024 * 1024, // 100MB default
		stop:          make(chan struct{}),
	}

	if ttl, ok := config["default_ttl"].(string); ok {
		if duration, err := time.ParseDuration(ttl); err == nil {
			ext.defaultTTL = duration
		}
	}

	// YAML decoders may yield numbers as int or float64; accept both.
	if maxSize, ok := config["max_size_mb"].(int); ok {
		ext.maxSize = maxSize * 1024 * 1024
	} else if maxSizeFloat, ok := config["max_size_mb"].(float64); ok {
		ext.maxSize = int(maxSizeFloat) * 1024 * 1024
	}

	if patterns, ok := config["cache_patterns"].([]interface{}); ok {
		for _, p := range patterns {
			patternCfg, ok := p.(map[string]interface{})
			if !ok {
				continue
			}
			pattern := &cachePattern{
				ttl:     ext.defaultTTL,
				methods: []string{"GET"},
			}
			if patternStr, ok := patternCfg["pattern"].(string); ok {
				re, err := regexp.Compile(patternStr)
				if err != nil {
					logger.Error("Invalid cache pattern", "pattern", patternStr, "error", err)
					continue
				}
				pattern.pattern = re
			}
			// An entry without a compiled regex would panic later in
			// shouldCache/getTTL (MatchString on a nil *Regexp) — skip it.
			if pattern.pattern == nil {
				logger.Error("Cache pattern entry missing 'pattern', skipping")
				continue
			}
			if ttl, ok := patternCfg["ttl"].(string); ok {
				if duration, err := time.ParseDuration(ttl); err == nil {
					pattern.ttl = duration
				}
			}
			if methods, ok := patternCfg["methods"].([]interface{}); ok {
				pattern.methods = make([]string, 0, len(methods))
				for _, m := range methods {
					if method, ok := m.(string); ok {
						pattern.methods = append(pattern.methods, strings.ToUpper(method))
					}
				}
			}
			ext.cachePatterns = append(ext.cachePatterns, pattern)
		}
	}

	go ext.cleanupLoop()

	return ext, nil
}

// ProcessRequest serves a cached response if one exists and is fresh.
// It returns true when the response was fully written from cache (the
// pipeline should stop), false to let the request proceed upstream.
func (e *CachingExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
	if !e.shouldCache(r) {
		return false, nil
	}

	key := e.cacheKey(r)

	e.mu.RLock()
	entry, exists := e.cache[key]
	e.mu.RUnlock()

	if exists && time.Now().Before(entry.expiresAt) {
		e.mu.Lock()
		e.hits++
		e.mu.Unlock()

		// Mark the wrapping cachingResponseWriter (if any) as a cache hit
		// so its WriteHeader does not stamp X-Cache: MISS over our HIT.
		setCacheHitFlag(w)

		for k, values := range entry.headers {
			for _, v := range values {
				w.Header().Add(k, v)
			}
		}
		// Set (not Add) replaces any X-Cache/Content-Type values copied
		// from the stored headers above, avoiding duplicates.
		w.Header().Set("X-Cache", "HIT")
		w.Header().Set("Content-Type", entry.contentType)
		w.WriteHeader(entry.statusCode)
		w.Write(entry.body)

		e.logger.Debug("Cache hit", "key", key, "path", r.URL.Path)
		return true, nil
	}

	e.mu.Lock()
	if exists {
		// The entry is stale: reclaim it now instead of waiting up to a
		// minute for the background sweeper. Re-check under the write lock
		// since another goroutine may have replaced it since our RUnlock.
		if stale, ok := e.cache[key]; ok && time.Now().After(stale.expiresAt) {
			e.currentSize -= stale.size
			delete(e.cache, key)
		}
	}
	e.misses++
	e.mu.Unlock()

	return false, nil
}

// setCacheHitFlag walks the ResponseWriter wrapper chain (via the
// conventional Unwrap method) looking for a cachingResponseWriter and
// marks it as a cache hit so it skips the X-Cache: MISS header.
func setCacheHitFlag(w http.ResponseWriter) {
	// Direct match.
	if cw, ok := w.(*cachingResponseWriter); ok {
		cw.SetCacheHit()
		return
	}
	// Try unwrapping (same convention http.ResponseController uses).
	type unwrapper interface {
		Unwrap() http.ResponseWriter
	}
	for {
		u, ok := w.(unwrapper)
		if !ok {
			return
		}
		w = u.Unwrap()
		if cw, ok := w.(*cachingResponseWriter); ok {
			cw.SetCacheHit()
			return
		}
	}
}

// ProcessResponse is a no-op: response caching is handled by the
// cachingResponseWriter installed in WrapResponseWriter, and the X-Cache
// header is set in cachingResponseWriter.WriteHeader.
func (e *CachingExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
}

// WrapResponseWriter wraps w so the response body can be captured for
// caching. Non-cacheable requests get the original writer back unchanged.
func (e *CachingExtension) WrapResponseWriter(w http.ResponseWriter, r *http.Request) http.ResponseWriter {
	if !e.shouldCache(r) {
		return w
	}
	return &cachingResponseWriter{
		ResponseWriter: w,
		ext:            e,
		request:        r,
		buffer:         &bytes.Buffer{},
	}
}

// cachingResponseWriter tees the response body into a buffer while
// forwarding everything to the underlying writer; Finalize then stores
// the captured response in the cache.
type cachingResponseWriter struct {
	http.ResponseWriter
	ext         *CachingExtension
	request     *http.Request
	buffer      *bytes.Buffer
	statusCode  int
	wroteHeader bool
	cacheHit    bool // Flag to indicate if this was a cache hit
}

// Unwrap exposes the underlying writer so helpers such as setCacheHitFlag
// (and http.ResponseController) can reach through nested wrappers.
func (cw *cachingResponseWriter) Unwrap() http.ResponseWriter {
	return cw.ResponseWriter
}

// WriteHeader records the status code and stamps X-Cache: MISS (unless
// this response was served from cache). Repeated calls are ignored.
func (cw *cachingResponseWriter) WriteHeader(code int) {
	if cw.wroteHeader {
		return
	}
	cw.statusCode = code
	cw.wroteHeader = true
	// Set X-Cache: MISS before headers are flushed (only if not a hit).
	if !cw.cacheHit {
		cw.ResponseWriter.Header().Set("X-Cache", "MISS")
	}
	cw.ResponseWriter.WriteHeader(code)
}

// SetCacheHit marks this response as a cache hit (to avoid setting X-Cache: MISS).
func (cw *cachingResponseWriter) SetCacheHit() {
	cw.cacheHit = true
}

// Write buffers the bytes for later caching and forwards them downstream.
// An implicit 200 is recorded if WriteHeader was never called.
func (cw *cachingResponseWriter) Write(b []byte) (int, error) {
	if !cw.wroteHeader {
		cw.WriteHeader(http.StatusOK)
	}
	cw.buffer.Write(b)
	return cw.ResponseWriter.Write(b)
}

// Finalize stores the captured response. Only 2xx/3xx responses with a
// non-empty body are cached.
func (cw *cachingResponseWriter) Finalize() {
	if cw.statusCode < 200 || cw.statusCode >= 400 {
		return
	}
	body := cw.buffer.Bytes()
	if len(body) == 0 {
		return
	}

	key := cw.ext.cacheKey(cw.request)
	ttl := cw.ext.getTTL(cw.request)
	now := time.Now()
	entry := &cacheEntry{
		key:         key,
		url:         cw.request.URL.String(),
		body:        body,
		headers:     cw.Header().Clone(),
		statusCode:  cw.statusCode,
		contentType: cw.Header().Get("Content-Type"),
		createdAt:   now,
		expiresAt:   now.Add(ttl),
		size:        len(body),
	}
	cw.ext.store(entry)
}

// shouldCache reports whether the request is eligible for caching: it
// matches a configured pattern+method, or — when no patterns are
// configured at all — it is a plain GET.
func (e *CachingExtension) shouldCache(r *http.Request) bool {
	path := r.URL.Path
	method := r.Method

	for _, pattern := range e.cachePatterns {
		if pattern.pattern.MatchString(path) {
			for _, m := range pattern.methods {
				if m == method {
					return true
				}
			}
		}
	}
	// Default: only cache GET requests, and only when the operator has
	// not restricted caching to explicit patterns.
	return method == "GET" && len(e.cachePatterns) == 0
}

// cacheKey derives a stable SHA-256 key from method + full URL, varying
// on Accept-Encoding so compressed and plain variants are cached apart.
func (e *CachingExtension) cacheKey(r *http.Request) string {
	h := sha256.New()
	h.Write([]byte(r.Method))
	h.Write([]byte(r.URL.String()))
	if ae := r.Header.Get("Accept-Encoding"); ae != "" {
		h.Write([]byte(ae))
	}
	return hex.EncodeToString(h.Sum(nil))
}

// getTTL returns the TTL of the first pattern matching the request path,
// falling back to the default TTL.
func (e *CachingExtension) getTTL(r *http.Request) time.Duration {
	path := r.URL.Path
	for _, pattern := range e.cachePatterns {
		if pattern.pattern.MatchString(path) {
			return pattern.ttl
		}
	}
	return e.defaultTTL
}

// store inserts an entry, evicting oldest entries until it fits. Bodies
// larger than the cache itself are rejected outright (previously such an
// entry evicted everything and still pushed currentSize past maxSize).
func (e *CachingExtension) store(entry *cacheEntry) {
	e.mu.Lock()
	defer e.mu.Unlock()

	if entry.size > e.maxSize {
		e.logger.Debug("Response too large to cache", "size", entry.size, "max", e.maxSize)
		return
	}

	// Evict old entries if needed.
	for e.currentSize+entry.size > e.maxSize && len(e.cache) > 0 {
		e.evictOldest()
	}

	// Replacing an existing entry must not double-count its size.
	if existing, ok := e.cache[entry.key]; ok {
		e.currentSize -= existing.size
	}
	e.cache[entry.key] = entry
	e.currentSize += entry.size

	e.logger.Debug("Cached response",
		"key", entry.key[:16],
		"size", entry.size,
		"ttl", entry.expiresAt.Sub(entry.createdAt).String())
}

// evictOldest removes the entry with the earliest createdAt. Caller must
// hold e.mu.
func (e *CachingExtension) evictOldest() {
	var oldestKey string
	var oldestTime time.Time
	for key, entry := range e.cache {
		if oldestKey == "" || entry.createdAt.Before(oldestTime) {
			oldestKey = key
			oldestTime = entry.createdAt
		}
	}
	if oldestKey != "" {
		entry := e.cache[oldestKey]
		e.currentSize -= entry.size
		delete(e.cache, oldestKey)
	}
}

// cleanupLoop sweeps expired entries once a minute until e.stop is closed
// by Cleanup. Without the stop channel this goroutine would leak for the
// life of the process.
func (e *CachingExtension) cleanupLoop() {
	ticker := time.NewTicker(time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-e.stop:
			return
		case <-ticker.C:
			e.cleanupExpired()
		}
	}
}

// cleanupExpired removes all entries past their expiry.
func (e *CachingExtension) cleanupExpired() {
	e.mu.Lock()
	defer e.mu.Unlock()

	now := time.Now()
	for key, entry := range e.cache {
		if now.After(entry.expiresAt) {
			e.currentSize -= entry.size
			delete(e.cache, key)
		}
	}
}

// Invalidate removes a single entry by its exact cache key.
func (e *CachingExtension) Invalidate(key string) {
	e.mu.Lock()
	defer e.mu.Unlock()

	if entry, ok := e.cache[key]; ok {
		e.currentSize -= entry.size
		delete(e.cache, key)
	}
}

// InvalidatePattern removes all entries whose request URL (or, for
// backward compatibility, whose cache key) matches the regex. Matching
// the URL is essential: keys are SHA-256 hex digests, so URL-shaped
// patterns could never match them.
func (e *CachingExtension) InvalidatePattern(pattern string) error {
	re, err := regexp.Compile(pattern)
	if err != nil {
		return err
	}

	e.mu.Lock()
	defer e.mu.Unlock()

	for key, entry := range e.cache {
		if re.MatchString(entry.url) || re.MatchString(key) {
			e.currentSize -= entry.size
			delete(e.cache, key)
		}
	}
	return nil
}

// Clear removes all entries from cache.
func (e *CachingExtension) Clear() {
	e.mu.Lock()
	defer e.mu.Unlock()

	e.cache = make(map[string]*cacheEntry)
	e.currentSize = 0
}

// GetMetrics returns caching metrics.
func (e *CachingExtension) GetMetrics() map[string]interface{} {
	e.mu.RLock()
	defer e.mu.RUnlock()

	hitRate := float64(0)
	total := e.hits + e.misses
	if total > 0 {
		hitRate = float64(e.hits) / float64(total) * 100
	}

	return map[string]interface{}{
		"entries":     len(e.cache),
		"size_bytes":  e.currentSize,
		"max_size":    e.maxSize,
		"hits":        e.hits,
		"misses":      e.misses,
		"hit_rate":    hitRate,
		"patterns":    len(e.cachePatterns),
		"default_ttl": e.defaultTTL.String(),
	}
}

// Cleanup stops the cleanup goroutine and drops all cached entries.
// It is safe to call more than once.
func (e *CachingExtension) Cleanup() error {
	e.stopOnce.Do(func() { close(e.stop) })
	e.Clear()
	return nil
}

// CacheReader wraps an io.ReadCloser and tees everything read through it
// into a buffer, so the body can be replayed after consumption.
type CacheReader struct {
	io.ReadCloser
	buffer *bytes.Buffer
}

// Read forwards to the wrapped reader and records the bytes read.
func (cr *CacheReader) Read(p []byte) (int, error) {
	n, err := cr.ReadCloser.Read(p)
	if n > 0 {
		cr.buffer.Write(p[:n])
	}
	return n, err
}

// GetBody returns everything read so far.
func (cr *CacheReader) GetBody() []byte {
	return cr.buffer.Bytes()
}