forked from aegis/pyserveX
- Added functionality to mark responses as cache hits to prevent incorrect X-Cache headers. - Introduced setCacheHitFlag function to traverse response writer wrappers and set cache hit flag. - Updated cachingResponseWriter to manage cache hit state and adjust X-Cache header accordingly. - Enhanced ProcessRequest and ProcessResponse methods to utilize new caching logic. feat(extension): Introduce ResponseWriterWrapper and ResponseFinalizer interfaces - Added ResponseWriterWrapper interface for extensions to wrap response writers. - Introduced ResponseFinalizer interface for finalizing responses after processing. refactor(manager): Improve response writer wrapping and finalization - Updated Manager.Handler to wrap response writers through all enabled extensions. - Implemented finalization of response writers after processing requests. test(caching): Add comprehensive integration tests for caching behavior - Created caching_test.go with tests for cache hit/miss, TTL expiration, pattern-based caching, and more. - Ensured that caching logic works correctly for various scenarios including query strings and error responses. test(routing): Add integration tests for routing behavior - Created routing_test.go with tests for route priority, case sensitivity, default routes, and return directives. - Verified that routing behaves as expected with multiple regex routes and named groups.
467 lines
10 KiB
Go
467 lines
10 KiB
Go
package extension
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"io"
|
|
"net/http"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/konduktor/konduktor/internal/logging"
|
|
)
|
|
|
|
type CachingExtension struct {
|
|
BaseExtension
|
|
cache map[string]*cacheEntry
|
|
cachePatterns []*cachePattern
|
|
defaultTTL time.Duration
|
|
maxSize int
|
|
currentSize int
|
|
mu sync.RWMutex
|
|
|
|
hits int64
|
|
misses int64
|
|
}
|
|
|
|
// cacheEntry is one stored response, keyed by the digest produced by cacheKey.
type cacheEntry struct {
	key string

	// Captured response: body bytes, header snapshot, status code and the
	// Content-Type value that is replayed on a hit.
	body        []byte
	headers     http.Header
	statusCode  int
	contentType string

	// createdAt orders eviction (oldest first); expiresAt ends the entry's life.
	createdAt, expiresAt time.Time

	// size is the number of body bytes counted against the cache budget.
	size int
}

// cachePattern makes requests whose path matches pattern and whose method is
// listed in methods eligible for caching, stored for ttl.
type cachePattern struct {
	pattern *regexp.Regexp
	ttl     time.Duration
	methods []string
}
|
|
|
|
// CachingConfig mirrors the YAML layout of the caching extension settings.
// NewCachingExtension in this file reads the same keys from a raw map rather
// than decoding into this struct, so the yaml keys here must stay in sync.
type CachingConfig struct {
	Enabled bool `yaml:"enabled"`

	// DefaultTTL is a Go duration string, e.g. "5m" or "1h".
	DefaultTTL string `yaml:"default_ttl"`

	// MaxSizeMB is the maximum cache size in megabytes.
	MaxSizeMB int `yaml:"max_size_mb"`

	// CachePatterns lists the patterns whose responses are cached.
	CachePatterns []PatternConfig `yaml:"cache_patterns"`
}

// PatternConfig describes one entry of CachingConfig.CachePatterns.
type PatternConfig struct {
	// Pattern is a regular expression.
	Pattern string `yaml:"pattern"`

	// TTL is the time-to-live for responses matching this pattern.
	TTL string `yaml:"ttl"`

	// Methods lists the HTTP methods to cache (default: GET).
	Methods []string `yaml:"methods"`
}
|
|
|
|
func NewCachingExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) {
|
|
ext := &CachingExtension{
|
|
BaseExtension: NewBaseExtension("caching", 20, logger),
|
|
cache: make(map[string]*cacheEntry),
|
|
cachePatterns: make([]*cachePattern, 0),
|
|
defaultTTL: 5 * time.Minute,
|
|
maxSize: 100 * 1024 * 1024, // 100MB default
|
|
}
|
|
|
|
if ttl, ok := config["default_ttl"].(string); ok {
|
|
if duration, err := time.ParseDuration(ttl); err == nil {
|
|
ext.defaultTTL = duration
|
|
}
|
|
}
|
|
|
|
if maxSize, ok := config["max_size_mb"].(int); ok {
|
|
ext.maxSize = maxSize * 1024 * 1024
|
|
} else if maxSizeFloat, ok := config["max_size_mb"].(float64); ok {
|
|
ext.maxSize = int(maxSizeFloat) * 1024 * 1024
|
|
}
|
|
|
|
if patterns, ok := config["cache_patterns"].([]interface{}); ok {
|
|
for _, p := range patterns {
|
|
if patternCfg, ok := p.(map[string]interface{}); ok {
|
|
pattern := &cachePattern{
|
|
ttl: ext.defaultTTL,
|
|
methods: []string{"GET"},
|
|
}
|
|
|
|
if patternStr, ok := patternCfg["pattern"].(string); ok {
|
|
re, err := regexp.Compile(patternStr)
|
|
if err != nil {
|
|
logger.Error("Invalid cache pattern", "pattern", patternStr, "error", err)
|
|
continue
|
|
}
|
|
pattern.pattern = re
|
|
}
|
|
|
|
if ttl, ok := patternCfg["ttl"].(string); ok {
|
|
if duration, err := time.ParseDuration(ttl); err == nil {
|
|
pattern.ttl = duration
|
|
}
|
|
}
|
|
|
|
if methods, ok := patternCfg["methods"].([]interface{}); ok {
|
|
pattern.methods = make([]string, 0)
|
|
for _, m := range methods {
|
|
if method, ok := m.(string); ok {
|
|
pattern.methods = append(pattern.methods, strings.ToUpper(method))
|
|
}
|
|
}
|
|
}
|
|
|
|
ext.cachePatterns = append(ext.cachePatterns, pattern)
|
|
}
|
|
}
|
|
}
|
|
|
|
go ext.cleanupLoop()
|
|
|
|
return ext, nil
|
|
}
|
|
|
|
func (e *CachingExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
|
|
if !e.shouldCache(r) {
|
|
return false, nil
|
|
}
|
|
|
|
key := e.cacheKey(r)
|
|
|
|
e.mu.RLock()
|
|
entry, exists := e.cache[key]
|
|
e.mu.RUnlock()
|
|
|
|
if exists && time.Now().Before(entry.expiresAt) {
|
|
e.mu.Lock()
|
|
e.hits++
|
|
e.mu.Unlock()
|
|
|
|
// Mark as cache hit to prevent setting X-Cache: MISS
|
|
// Try to find cachingResponseWriter in the wrapper chain
|
|
setCacheHitFlag(w)
|
|
|
|
for k, values := range entry.headers {
|
|
for _, v := range values {
|
|
w.Header().Add(k, v)
|
|
}
|
|
}
|
|
w.Header().Set("X-Cache", "HIT")
|
|
w.Header().Set("Content-Type", entry.contentType)
|
|
w.WriteHeader(entry.statusCode)
|
|
w.Write(entry.body)
|
|
|
|
e.logger.Debug("Cache hit", "key", key, "path", r.URL.Path)
|
|
return true, nil
|
|
}
|
|
|
|
e.mu.Lock()
|
|
e.misses++
|
|
e.mu.Unlock()
|
|
|
|
return false, nil
|
|
}
|
|
|
|
// setCacheHitFlag tries to find cachingResponseWriter and set cache hit flag
|
|
func setCacheHitFlag(w http.ResponseWriter) {
|
|
// Direct match
|
|
if cw, ok := w.(*cachingResponseWriter); ok {
|
|
cw.SetCacheHit()
|
|
return
|
|
}
|
|
|
|
// Try unwrapping
|
|
type unwrapper interface {
|
|
Unwrap() http.ResponseWriter
|
|
}
|
|
|
|
for {
|
|
if u, ok := w.(unwrapper); ok {
|
|
w = u.Unwrap()
|
|
if cw, ok := w.(*cachingResponseWriter); ok {
|
|
cw.SetCacheHit()
|
|
return
|
|
}
|
|
} else {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// ProcessResponse caches the response if applicable
|
|
func (e *CachingExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
|
// Response caching is handled by the CachingResponseWriter
|
|
// X-Cache header is set in the cachingResponseWriter.WriteHeader
|
|
}
|
|
|
|
// WrapResponseWriter wraps the response writer to capture the response for caching
|
|
func (e *CachingExtension) WrapResponseWriter(w http.ResponseWriter, r *http.Request) http.ResponseWriter {
|
|
if !e.shouldCache(r) {
|
|
return w
|
|
}
|
|
|
|
return &cachingResponseWriter{
|
|
ResponseWriter: w,
|
|
ext: e,
|
|
request: r,
|
|
buffer: &bytes.Buffer{},
|
|
}
|
|
}
|
|
|
|
type cachingResponseWriter struct {
|
|
http.ResponseWriter
|
|
ext *CachingExtension
|
|
request *http.Request
|
|
buffer *bytes.Buffer
|
|
statusCode int
|
|
wroteHeader bool
|
|
cacheHit bool // Flag to indicate if this was a cache hit
|
|
}
|
|
|
|
func (cw *cachingResponseWriter) WriteHeader(code int) {
|
|
if !cw.wroteHeader {
|
|
cw.statusCode = code
|
|
cw.wroteHeader = true
|
|
// Set X-Cache: MISS header before writing headers (only if not a cache hit)
|
|
if !cw.cacheHit {
|
|
cw.ResponseWriter.Header().Set("X-Cache", "MISS")
|
|
}
|
|
cw.ResponseWriter.WriteHeader(code)
|
|
}
|
|
}
|
|
|
|
// SetCacheHit marks this response as a cache hit (to avoid setting X-Cache: MISS)
|
|
func (cw *cachingResponseWriter) SetCacheHit() {
|
|
cw.cacheHit = true
|
|
}
|
|
|
|
func (cw *cachingResponseWriter) Write(b []byte) (int, error) {
|
|
if !cw.wroteHeader {
|
|
cw.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
cw.buffer.Write(b)
|
|
|
|
return cw.ResponseWriter.Write(b)
|
|
}
|
|
|
|
func (cw *cachingResponseWriter) Finalize() {
|
|
if cw.statusCode < 200 || cw.statusCode >= 400 {
|
|
return
|
|
}
|
|
|
|
body := cw.buffer.Bytes()
|
|
if len(body) == 0 {
|
|
return
|
|
}
|
|
|
|
key := cw.ext.cacheKey(cw.request)
|
|
ttl := cw.ext.getTTL(cw.request)
|
|
|
|
entry := &cacheEntry{
|
|
key: key,
|
|
body: body,
|
|
headers: cw.Header().Clone(),
|
|
statusCode: cw.statusCode,
|
|
contentType: cw.Header().Get("Content-Type"),
|
|
createdAt: time.Now(),
|
|
expiresAt: time.Now().Add(ttl),
|
|
size: len(body),
|
|
}
|
|
|
|
cw.ext.store(entry)
|
|
}
|
|
|
|
func (e *CachingExtension) shouldCache(r *http.Request) bool {
|
|
path := r.URL.Path
|
|
method := r.Method
|
|
|
|
for _, pattern := range e.cachePatterns {
|
|
if pattern.pattern.MatchString(path) {
|
|
for _, m := range pattern.methods {
|
|
if m == method {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Default: only cache GET requests
|
|
return method == "GET" && len(e.cachePatterns) == 0
|
|
}
|
|
|
|
func (e *CachingExtension) cacheKey(r *http.Request) string {
|
|
// Create cache key from method + URL + relevant headers
|
|
h := sha256.New()
|
|
h.Write([]byte(r.Method))
|
|
h.Write([]byte(r.URL.String()))
|
|
|
|
// Include Accept-Encoding for vary
|
|
if ae := r.Header.Get("Accept-Encoding"); ae != "" {
|
|
h.Write([]byte(ae))
|
|
}
|
|
|
|
return hex.EncodeToString(h.Sum(nil))
|
|
}
|
|
|
|
func (e *CachingExtension) getTTL(r *http.Request) time.Duration {
|
|
path := r.URL.Path
|
|
|
|
for _, pattern := range e.cachePatterns {
|
|
if pattern.pattern.MatchString(path) {
|
|
return pattern.ttl
|
|
}
|
|
}
|
|
|
|
return e.defaultTTL
|
|
}
|
|
|
|
func (e *CachingExtension) store(entry *cacheEntry) {
|
|
e.mu.Lock()
|
|
defer e.mu.Unlock()
|
|
|
|
// Evict old entries if needed
|
|
for e.currentSize+entry.size > e.maxSize && len(e.cache) > 0 {
|
|
e.evictOldest()
|
|
}
|
|
|
|
// Store new entry
|
|
if existing, ok := e.cache[entry.key]; ok {
|
|
e.currentSize -= existing.size
|
|
}
|
|
|
|
e.cache[entry.key] = entry
|
|
e.currentSize += entry.size
|
|
|
|
e.logger.Debug("Cached response",
|
|
"key", entry.key[:16],
|
|
"size", entry.size,
|
|
"ttl", entry.expiresAt.Sub(entry.createdAt).String())
|
|
}
|
|
|
|
func (e *CachingExtension) evictOldest() {
|
|
var oldestKey string
|
|
var oldestTime time.Time
|
|
|
|
for key, entry := range e.cache {
|
|
if oldestKey == "" || entry.createdAt.Before(oldestTime) {
|
|
oldestKey = key
|
|
oldestTime = entry.createdAt
|
|
}
|
|
}
|
|
|
|
if oldestKey != "" {
|
|
entry := e.cache[oldestKey]
|
|
e.currentSize -= entry.size
|
|
delete(e.cache, oldestKey)
|
|
}
|
|
}
|
|
|
|
func (e *CachingExtension) cleanupLoop() {
|
|
ticker := time.NewTicker(time.Minute)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C {
|
|
e.cleanupExpired()
|
|
}
|
|
}
|
|
|
|
func (e *CachingExtension) cleanupExpired() {
|
|
e.mu.Lock()
|
|
defer e.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
for key, entry := range e.cache {
|
|
if now.After(entry.expiresAt) {
|
|
e.currentSize -= entry.size
|
|
delete(e.cache, key)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (e *CachingExtension) Invalidate(key string) {
|
|
e.mu.Lock()
|
|
defer e.mu.Unlock()
|
|
|
|
if entry, ok := e.cache[key]; ok {
|
|
e.currentSize -= entry.size
|
|
delete(e.cache, key)
|
|
}
|
|
}
|
|
|
|
// InvalidatePattern removes all entries matching a pattern, unlike Invalidate
|
|
func (e *CachingExtension) InvalidatePattern(pattern string) error {
|
|
re, err := regexp.Compile(pattern)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
e.mu.Lock()
|
|
defer e.mu.Unlock()
|
|
|
|
for key, entry := range e.cache {
|
|
if re.MatchString(key) {
|
|
e.currentSize -= entry.size
|
|
delete(e.cache, key)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Clear removes all entries from cache
|
|
func (e *CachingExtension) Clear() {
|
|
e.mu.Lock()
|
|
defer e.mu.Unlock()
|
|
|
|
e.cache = make(map[string]*cacheEntry)
|
|
e.currentSize = 0
|
|
}
|
|
|
|
// GetMetrics returns caching metrics
|
|
func (e *CachingExtension) GetMetrics() map[string]interface{} {
|
|
e.mu.RLock()
|
|
defer e.mu.RUnlock()
|
|
|
|
hitRate := float64(0)
|
|
total := e.hits + e.misses
|
|
if total > 0 {
|
|
hitRate = float64(e.hits) / float64(total) * 100
|
|
}
|
|
|
|
return map[string]interface{}{
|
|
"entries": len(e.cache),
|
|
"size_bytes": e.currentSize,
|
|
"max_size": e.maxSize,
|
|
"hits": e.hits,
|
|
"misses": e.misses,
|
|
"hit_rate": hitRate,
|
|
"patterns": len(e.cachePatterns),
|
|
"default_ttl": e.defaultTTL.String(),
|
|
}
|
|
}
|
|
|
|
// Cleanup stops the cleanup goroutine
|
|
func (e *CachingExtension) Cleanup() error {
|
|
e.Clear()
|
|
return nil
|
|
}
|
|
|
|
// CacheReader wraps an io.ReadCloser and keeps a copy of every byte read
// through it, available afterwards from GetBody. A CacheReader with only
// ReadCloser set is usable: the internal buffer is allocated on first use.
// Without that, external packages, which cannot set the unexported buffer,
// would get a nil-pointer panic.
type CacheReader struct {
	io.ReadCloser
	buffer *bytes.Buffer
}

// Read reads from the wrapped reader and appends whatever it returned to the
// internal copy. The error from the wrapped reader is returned unchanged.
func (cr *CacheReader) Read(p []byte) (int, error) {
	n, err := cr.ReadCloser.Read(p)
	if n > 0 {
		if cr.buffer == nil {
			cr.buffer = new(bytes.Buffer)
		}
		cr.buffer.Write(p[:n])
	}
	return n, err
}

// GetBody returns the bytes read so far, or nil if nothing has been read. The
// slice aliases the internal buffer and is valid only until the next Read.
func (cr *CacheReader) GetBody() []byte {
	if cr.buffer == nil {
		return nil
	}
	return cr.buffer.Bytes()
}
|