From 881028c1e658dfce05c6cf22cd9c67ca0f583566 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=98=D0=BB=D1=8C=D1=8F=20=D0=93=D0=BB=D0=B0=D0=B7=D1=83?= =?UTF-8?q?=D0=BD=D0=BE=D0=B2?= Date: Fri, 12 Dec 2025 00:38:30 +0300 Subject: [PATCH] feat: Add reverse proxy functionality with enhanced routing capabilities - Introduced IgnoreRequestPath option in proxy configuration to allow exact match routing. - Implemented proxy_pass directive in routing extension to handle backend requests. - Enhanced error handling for backend unavailability and timeouts. - Added integration tests for reverse proxy, including basic requests, exact match routes, regex routes, header forwarding, and query string preservation. - Created helper functions for setting up test servers and backends, along with assertion utilities for response validation. - Updated server initialization to support extension management and middleware chaining. - Improved logging for debugging purposes during request handling. --- go/go.mod | 3 + go/go.sum | 20 + go/internal/extension/caching.go | 427 ++++++++++++++++ go/internal/extension/extension.go | 111 ++++ go/internal/extension/manager.go | 234 +++++++++ go/internal/extension/manager_test.go | 176 +++++++ go/internal/extension/routing.go | 428 ++++++++++++++++ go/internal/extension/security.go | 312 ++++++++++++ go/internal/extension/security_test.go | 213 ++++++++ go/internal/logging/logger.go | 372 ++++++++++---- go/internal/logging/logger_test.go | 183 ++++--- go/internal/proxy/proxy.go | 17 +- go/internal/routing/router.go | 105 +++- go/internal/server/server.go | 52 +- go/tests/integration/README.md | 127 +++++ go/tests/integration/helpers_test.go | 408 +++++++++++++++ go/tests/integration/reverse_proxy_test.go | 562 +++++++++++++++++++++ 17 files changed, 3574 insertions(+), 176 deletions(-) create mode 100644 go/go.sum create mode 100644 go/internal/extension/caching.go create mode 100644 go/internal/extension/extension.go create mode 100644 go/internal/extension/manager.go create mode 100644 go/internal/extension/manager_test.go create mode 100644 go/internal/extension/routing.go create mode 100644 go/internal/extension/security.go create mode 100644 go/internal/extension/security_test.go create mode 100644 go/tests/integration/README.md create mode 100644 go/tests/integration/helpers_test.go create mode 100644 go/tests/integration/reverse_proxy_test.go diff --git a/go/go.mod b/go/go.mod index bc5b813..3a2d69a 100644 --- a/go/go.mod +++ b/go/go.mod @@ -12,4 +12,7 @@ require ( require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/spf13/pflag v1.0.10 // indirect + go.uber.org/multierr v1.10.0 // indirect + go.uber.org/zap v1.27.1 // indirect + gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go/go.sum b/go/go.sum new file mode 100644 index 0000000..6dc1cf4 --- /dev/null +++ b/go/go.sum @@ -0,0 +1,20 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= +go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go/internal/extension/caching.go b/go/internal/extension/caching.go new file mode 100644 index 0000000..2dce799 --- /dev/null +++ b/go/internal/extension/caching.go @@ -0,0 +1,427 @@ +package extension + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "io" + "net/http" + "regexp" + "strings" + "sync" + "time" + + "github.com/konduktor/konduktor/internal/logging" +) + +type CachingExtension struct { + BaseExtension + cache map[string]*cacheEntry + cachePatterns []*cachePattern + defaultTTL time.Duration + maxSize int + currentSize int + mu sync.RWMutex + + hits int64 + misses int64 +} + +type cacheEntry struct { + key string + body []byte + headers http.Header + statusCode int + contentType string + createdAt time.Time + expiresAt time.Time + size int +} + +type cachePattern struct { + pattern *regexp.Regexp + ttl time.Duration + methods []string +} + +type CachingConfig struct { + Enabled bool `yaml:"enabled"` + DefaultTTL string `yaml:"default_ttl"` // e.g., "5m", "1h" + MaxSizeMB int `yaml:"max_size_mb"` // Max cache size in MB + CachePatterns []PatternConfig `yaml:"cache_patterns"` // Patterns to cache +} + +type PatternConfig struct { + Pattern string `yaml:"pattern"` // Regex pattern + TTL string `yaml:"ttl"` // TTL for this pattern + Methods []string `yaml:"methods"` // HTTP methods to cache (default: GET) +} + +func NewCachingExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) { + ext := &CachingExtension{ + BaseExtension: NewBaseExtension("caching", 20, logger), + cache: make(map[string]*cacheEntry), + cachePatterns: make([]*cachePattern, 0), + defaultTTL: 5 * time.Minute, + maxSize: 100 * 1024 * 1024, // 100MB default + } + + if ttl, ok := config["default_ttl"].(string); ok { + if duration, err := time.ParseDuration(ttl); err == nil { + ext.defaultTTL = duration + } + } + + if maxSize, ok := config["max_size_mb"].(int); ok { + ext.maxSize = maxSize * 1024 * 1024 + } else if maxSizeFloat, ok := config["max_size_mb"].(float64); ok { + ext.maxSize = int(maxSizeFloat) * 1024 * 1024 + } + + if patterns, ok := config["cache_patterns"].([]interface{}); ok { + for _, p := range patterns { + if patternCfg, ok := p.(map[string]interface{}); ok { + pattern := &cachePattern{ + ttl: ext.defaultTTL, + methods: []string{"GET"}, + } + + if patternStr, ok := patternCfg["pattern"].(string); ok { + re, err := regexp.Compile(patternStr) + if err != nil { + logger.Error("Invalid cache pattern", "pattern", patternStr, "error", err) + continue + } + pattern.pattern = re + } + + if ttl, ok := patternCfg["ttl"].(string); ok { + if duration, err := time.ParseDuration(ttl); err == nil { + pattern.ttl = duration + } + } + + if methods, ok := patternCfg["methods"].([]interface{}); ok { + pattern.methods = make([]string, 0) + for _, m := range methods { + if method, ok := m.(string); ok { + pattern.methods = append(pattern.methods, strings.ToUpper(method)) + } + } + } + + ext.cachePatterns = append(ext.cachePatterns, pattern) + } + } + } + + go ext.cleanupLoop() + + return ext, nil +} + +func (e *CachingExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) { + if !e.shouldCache(r) { + return false, nil + } + + key := e.cacheKey(r) + + e.mu.RLock() + entry, exists := e.cache[key] + e.mu.RUnlock() + + if exists && time.Now().Before(entry.expiresAt) { + e.mu.Lock() + e.hits++ + e.mu.Unlock() + + for k, values := range entry.headers { + for _, v := range values { + w.Header().Add(k, v) + } + } + w.Header().Set("X-Cache", "HIT") + w.Header().Set("Content-Type", entry.contentType) + w.WriteHeader(entry.statusCode) + w.Write(entry.body) + + e.logger.Debug("Cache hit", "key", key, "path", r.URL.Path) + return true, nil + } + + e.mu.Lock() + e.misses++ + e.mu.Unlock() + + return false, nil +} + +// ProcessResponse caches the response if applicable +func (e *CachingExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) { + // Response caching is handled by the CachingResponseWriter + // This is called after the response is written + w.Header().Set("X-Cache", "MISS") +} + +// WrapResponseWriter wraps the response writer to capture the response for caching +func (e *CachingExtension) WrapResponseWriter(w http.ResponseWriter, r *http.Request) http.ResponseWriter { + if !e.shouldCache(r) { + return w + } + + return &cachingResponseWriter{ + ResponseWriter: w, + ext: e, + request: r, + buffer: &bytes.Buffer{}, + } +} + +type cachingResponseWriter struct { + http.ResponseWriter + ext *CachingExtension + request *http.Request + buffer *bytes.Buffer + statusCode int + wroteHeader bool +} + +func (cw *cachingResponseWriter) WriteHeader(code int) { + if !cw.wroteHeader { + cw.statusCode = code + cw.wroteHeader = true + cw.ResponseWriter.WriteHeader(code) + } +} + +func (cw *cachingResponseWriter) Write(b []byte) (int, error) { + if !cw.wroteHeader { + cw.WriteHeader(http.StatusOK) + } + + cw.buffer.Write(b) + + return cw.ResponseWriter.Write(b) +} + +func (cw *cachingResponseWriter) Finalize() { + if cw.statusCode < 200 || cw.statusCode >= 400 { + return + } + + body := cw.buffer.Bytes() + if len(body) == 0 { + return + } + + key := cw.ext.cacheKey(cw.request) + ttl := cw.ext.getTTL(cw.request) + + entry := &cacheEntry{ + key: key, + body: body, + headers: cw.Header().Clone(), + statusCode: cw.statusCode, + contentType: cw.Header().Get("Content-Type"), + createdAt: time.Now(), + expiresAt: time.Now().Add(ttl), + size: len(body), + } + + cw.ext.store(entry) +} + +func (e *CachingExtension) shouldCache(r *http.Request) bool { + path := r.URL.Path + method := r.Method + + for _, pattern := range e.cachePatterns { + if pattern.pattern.MatchString(path) { + for _, m := range pattern.methods { + if m == method { + return true + } + } + } + } + + // Default: only cache GET requests + return method == "GET" && len(e.cachePatterns) == 0 +} + +func (e *CachingExtension) cacheKey(r *http.Request) string { + // Create cache key from method + URL + relevant headers + h := sha256.New() + h.Write([]byte(r.Method)) + h.Write([]byte(r.URL.String())) + + // Include Accept-Encoding for vary + if ae := r.Header.Get("Accept-Encoding"); ae != "" { + h.Write([]byte(ae)) + } + + return hex.EncodeToString(h.Sum(nil)) +} + +func (e *CachingExtension) getTTL(r *http.Request) time.Duration { + path := r.URL.Path + + for _, pattern := range e.cachePatterns { + if pattern.pattern.MatchString(path) { + return pattern.ttl + } + } + + return e.defaultTTL +} + +func (e *CachingExtension) store(entry *cacheEntry) { + e.mu.Lock() + defer e.mu.Unlock() + + // Evict old entries if needed + for e.currentSize+entry.size > e.maxSize && len(e.cache) > 0 { + e.evictOldest() + } + + // Store new entry + if existing, ok := e.cache[entry.key]; ok { + e.currentSize -= existing.size + } + + e.cache[entry.key] = entry + e.currentSize += entry.size + + e.logger.Debug("Cached response", + "key", entry.key[:16], + "size", entry.size, + "ttl", entry.expiresAt.Sub(entry.createdAt).String()) +} + +func (e *CachingExtension) evictOldest() { + var oldestKey string + var oldestTime time.Time + + for key, entry := range e.cache { + if oldestKey == "" || entry.createdAt.Before(oldestTime) { + oldestKey = key + oldestTime = entry.createdAt + } + } + + if oldestKey != "" { + entry := e.cache[oldestKey] + e.currentSize -= entry.size + delete(e.cache, oldestKey) + } +} + +func (e *CachingExtension) cleanupLoop() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + + for range ticker.C { + e.cleanupExpired() + } +} + +func (e *CachingExtension) cleanupExpired() { + e.mu.Lock() + defer e.mu.Unlock() + + now := time.Now() + for key, entry := range e.cache { + if now.After(entry.expiresAt) { + e.currentSize -= entry.size + delete(e.cache, key) + } + } +} + +func (e *CachingExtension) Invalidate(key string) { + e.mu.Lock() + defer e.mu.Unlock() + + if entry, ok := e.cache[key]; ok { + e.currentSize -= entry.size + delete(e.cache, key) + } +} + +// InvalidatePattern removes all entries matching a pattern, unlike Invalidate +func (e *CachingExtension) InvalidatePattern(pattern string) error { + re, err := regexp.Compile(pattern) + if err != nil { + return err + } + + e.mu.Lock() + defer e.mu.Unlock() + + for key, entry := range e.cache { + if re.MatchString(key) { + e.currentSize -= entry.size + delete(e.cache, key) + } + } + + return nil +} + +// Clear removes all entries from cache +func (e *CachingExtension) Clear() { + e.mu.Lock() + defer e.mu.Unlock() + + e.cache = make(map[string]*cacheEntry) + e.currentSize = 0 +} + +// GetMetrics returns caching metrics +func (e *CachingExtension) GetMetrics() map[string]interface{} { + e.mu.RLock() + defer e.mu.RUnlock() + + hitRate := float64(0) + total := e.hits + e.misses + if total > 0 { + hitRate = float64(e.hits) / float64(total) * 100 + } + + return map[string]interface{}{ + "entries": len(e.cache), + "size_bytes": e.currentSize, + "max_size": e.maxSize, + "hits": e.hits, + "misses": e.misses, + "hit_rate": hitRate, + "patterns": len(e.cachePatterns), + "default_ttl": e.defaultTTL.String(), + } +} + +// Cleanup stops the cleanup goroutine +func (e *CachingExtension) Cleanup() error { + e.Clear() + return nil +} + +// CacheReader wraps an io.ReadCloser to cache the body +type CacheReader struct { + io.ReadCloser + buffer *bytes.Buffer +} + +func (cr *CacheReader) Read(p []byte) (int, error) { + n, err := cr.ReadCloser.Read(p) + if n > 0 { + cr.buffer.Write(p[:n]) + } + return n, err +} + +func (cr *CacheReader) GetBody() []byte { + return cr.buffer.Bytes() +} diff --git a/go/internal/extension/extension.go b/go/internal/extension/extension.go new file mode 100644 index 0000000..5767517 --- /dev/null +++ b/go/internal/extension/extension.go @@ -0,0 +1,111 @@ +package extension + +import ( + "context" + "net/http" + + "github.com/konduktor/konduktor/internal/logging" +) + +// Extension is the interface that all extensions must implement +type Extension interface { + // Name returns the unique name of the extension + Name() string + + // Initialize is called when the extension is loaded + Initialize() error + + // ProcessRequest processes an incoming request before routing. + // Returns: + // - response: if non-nil, the request is handled and no further processing occurs + // - handled: if true, the request was handled by this extension + // - err: any error that occurred + ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (handled bool, err error) + + // ProcessResponse is called after the response is generated but before it's sent. + // Extensions can modify the response here. + ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) + + // Cleanup is called when the extension is being unloaded + Cleanup() error + + // Enabled returns whether the extension is currently enabled + Enabled() bool + + // SetEnabled enables or disables the extension + SetEnabled(enabled bool) + + // Priority returns the extension's priority (lower = earlier execution) + Priority() int +} + +// BaseExtension provides a default implementation for common Extension methods +type BaseExtension struct { + name string + enabled bool + priority int + logger *logging.Logger +} + +// NewBaseExtension creates a new BaseExtension +func NewBaseExtension(name string, priority int, logger *logging.Logger) BaseExtension { + return BaseExtension{ + name: name, + enabled: true, + priority: priority, + logger: logger, + } +} + +// Name returns the extension name +func (b *BaseExtension) Name() string { + return b.name +} + +// Enabled returns whether the extension is enabled +func (b *BaseExtension) Enabled() bool { + return b.enabled +} + +// SetEnabled sets the enabled state +func (b *BaseExtension) SetEnabled(enabled bool) { + b.enabled = enabled +} + +// Priority returns the extension priority +func (b *BaseExtension) Priority() int { + return b.priority +} + +// Initialize default implementation (no-op) +func (b *BaseExtension) Initialize() error { + return nil +} + +// Cleanup default implementation (no-op) +func (b *BaseExtension) Cleanup() error { + return nil +} + +// ProcessRequest default implementation (pass-through) +func (b *BaseExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) { + return false, nil +} + +// ProcessResponse default implementation (no-op) +func (b *BaseExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) { +} + +// Logger returns the extension's logger +func (b *BaseExtension) Logger() *logging.Logger { + return b.logger +} + +// ExtensionConfig holds configuration for creating extensions +type ExtensionConfig struct { + Type string + Config map[string]interface{} +} + +// ExtensionFactory is a function that creates an extension from config +type ExtensionFactory func(config map[string]interface{}, logger *logging.Logger) (Extension, error) diff --git a/go/internal/extension/manager.go b/go/internal/extension/manager.go new file mode 100644 index 0000000..0f98fea --- /dev/null +++ b/go/internal/extension/manager.go @@ -0,0 +1,234 @@ +package extension + +import ( + "context" + "fmt" + "net/http" + "sort" + "sync" + + "github.com/konduktor/konduktor/internal/logging" +) + +// Manager manages all loaded extensions +type Manager struct { + extensions []Extension + registry map[string]ExtensionFactory + logger *logging.Logger + mu sync.RWMutex +} + +// NewManager creates a new extension manager +func NewManager(logger *logging.Logger) *Manager { + m := &Manager{ + extensions: make([]Extension, 0), + registry: make(map[string]ExtensionFactory), + logger: logger, + } + + // Register built-in extensions + m.RegisterFactory("routing", NewRoutingExtension) + m.RegisterFactory("security", NewSecurityExtension) + m.RegisterFactory("caching", NewCachingExtension) + + return m +} + +// RegisterFactory registers an extension factory +func (m *Manager) RegisterFactory(name string, factory ExtensionFactory) { + m.mu.Lock() + defer m.mu.Unlock() + m.registry[name] = factory +} + +// LoadExtension loads an extension by type and config +func (m *Manager) LoadExtension(extType string, config map[string]interface{}) error { + m.mu.Lock() + defer m.mu.Unlock() + + factory, ok := m.registry[extType] + if !ok { + return fmt.Errorf("unknown extension type: %s", extType) + } + + ext, err := factory(config, m.logger) + if err != nil { + return fmt.Errorf("failed to create extension %s: %w", extType, err) + } + + if err := ext.Initialize(); err != nil { + return fmt.Errorf("failed to initialize extension %s: %w", extType, err) + } + + m.extensions = append(m.extensions, ext) + + // Sort by priority (lower first) + sort.Slice(m.extensions, func(i, j int) bool { + return m.extensions[i].Priority() < m.extensions[j].Priority() + }) + + m.logger.Info("Loaded extension", "type", extType, "name", ext.Name(), "priority", ext.Priority()) + return nil +} + +// AddExtension adds a pre-created extension +func (m *Manager) AddExtension(ext Extension) error { + m.mu.Lock() + defer m.mu.Unlock() + + if err := ext.Initialize(); err != nil { + return fmt.Errorf("failed to initialize extension %s: %w", ext.Name(), err) + } + + m.extensions = append(m.extensions, ext) + + // Sort by priority + sort.Slice(m.extensions, func(i, j int) bool { + return m.extensions[i].Priority() < m.extensions[j].Priority() + }) + + m.logger.Info("Added extension", "name", ext.Name(), "priority", ext.Priority()) + return nil +} + +// ProcessRequest runs all extensions' ProcessRequest in order +// Returns true if any extension handled the request +func (m *Manager) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) { + m.mu.RLock() + extensions := m.extensions + m.mu.RUnlock() + + for _, ext := range extensions { + if !ext.Enabled() { + continue + } + + handled, err := ext.ProcessRequest(ctx, w, r) + if err != nil { + m.logger.Error("Extension error", "extension", ext.Name(), "error", err) + // Continue to next extension on error + continue + } + + if handled { + return true, nil + } + } + + return false, nil +} + +// ProcessResponse runs all extensions' ProcessResponse in reverse order +func (m *Manager) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) { + m.mu.RLock() + extensions := m.extensions + m.mu.RUnlock() + + // Process in reverse order for response + for i := len(extensions) - 1; i >= 0; i-- { + ext := extensions[i] + if !ext.Enabled() { + continue + } + + ext.ProcessResponse(ctx, w, r) + } +} + +// Cleanup cleans up all extensions +func (m *Manager) Cleanup() { + m.mu.Lock() + defer m.mu.Unlock() + + for _, ext := range m.extensions { + if err := ext.Cleanup(); err != nil { + m.logger.Error("Extension cleanup error", "extension", ext.Name(), "error", err) + } + } + + m.extensions = nil +} + +// GetExtension returns an extension by name +func (m *Manager) GetExtension(name string) Extension { + m.mu.RLock() + defer m.mu.RUnlock() + + for _, ext := range m.extensions { + if ext.Name() == name { + return ext + } + } + return nil +} + +// Extensions returns all loaded extensions +func (m *Manager) Extensions() []Extension { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make([]Extension, len(m.extensions)) + copy(result, m.extensions) + return result +} + +// Handler returns an http.Handler that processes requests through all extensions +func (m *Manager) Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Create response wrapper to capture response for ProcessResponse + wrapper := newResponseWrapper(w) + + // Process request through extensions + handled, err := m.ProcessRequest(ctx, wrapper, r) + if err != nil { + m.logger.Error("Error processing request", "error", err) + } + + if handled { + // Extension handled the request, process response + m.ProcessResponse(ctx, wrapper, r) + return + } + + // No extension handled, pass to next handler + next.ServeHTTP(wrapper, r) + + // Process response + m.ProcessResponse(ctx, wrapper, r) + }) +} + +// responseWrapper wraps http.ResponseWriter to allow response modification +type responseWrapper struct { + http.ResponseWriter + statusCode int + written bool +} + +func newResponseWrapper(w http.ResponseWriter) *responseWrapper { + return &responseWrapper{ + ResponseWriter: w, + statusCode: http.StatusOK, + } +} + +func (rw *responseWrapper) WriteHeader(code int) { + if !rw.written { + rw.statusCode = code + rw.ResponseWriter.WriteHeader(code) + rw.written = true + } +} + +func (rw *responseWrapper) Write(b []byte) (int, error) { + if !rw.written { + rw.WriteHeader(http.StatusOK) + } + return rw.ResponseWriter.Write(b) +} + +func (rw *responseWrapper) StatusCode() int { + return rw.statusCode +} diff --git a/go/internal/extension/manager_test.go b/go/internal/extension/manager_test.go new file mode 100644 index 0000000..e77f76d --- /dev/null +++ b/go/internal/extension/manager_test.go @@ -0,0 +1,176 @@ +package extension + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/konduktor/konduktor/internal/logging" +) + +func newTestLogger() *logging.Logger { + logger, _ := logging.New(logging.Config{Level: "DEBUG"}) + return logger +} + +func TestNewManager(t *testing.T) { + logger := newTestLogger() + manager := NewManager(logger) + + if manager == nil { + t.Fatal("Expected manager, got nil") + } + + // Check built-in factories are registered + if _, ok := manager.registry["routing"]; !ok { + t.Error("Expected routing factory to be registered") + } + if _, ok := manager.registry["security"]; !ok { + t.Error("Expected security factory to be registered") + } + if _, ok := manager.registry["caching"]; !ok { + t.Error("Expected caching factory to be registered") + } +} + +func TestManager_LoadExtension(t *testing.T) { + logger := newTestLogger() + manager := NewManager(logger) + + err := manager.LoadExtension("security", map[string]interface{}{}) + if err != nil { + t.Errorf("Failed to load security extension: %v", err) + } + + exts := manager.Extensions() + if len(exts) != 1 { + t.Errorf("Expected 1 extension, got %d", len(exts)) + } +} + +func TestManager_LoadExtension_Unknown(t *testing.T) { + logger := newTestLogger() + manager := NewManager(logger) + + err := manager.LoadExtension("unknown", map[string]interface{}{}) + if err == nil { + t.Error("Expected error for unknown extension type") + } +} + +func TestManager_GetExtension(t *testing.T) { + logger := newTestLogger() + manager := NewManager(logger) + + manager.LoadExtension("security", map[string]interface{}{}) + + ext := manager.GetExtension("security") + if ext == nil { + t.Error("Expected to find security extension") + } + + ext = manager.GetExtension("nonexistent") + if ext != nil { + t.Error("Expected nil for nonexistent extension") + } +} + +func TestManager_ProcessRequest(t *testing.T) { + logger := newTestLogger() + manager := NewManager(logger) + + // Load security extension with blocked IP + manager.LoadExtension("security", map[string]interface{}{ + "blocked_ips": []interface{}{"192.168.1.1"}, + }) + + // Create test request + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + rr := httptest.NewRecorder() + + handled, err := manager.ProcessRequest(context.Background(), rr, req) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if !handled { + t.Error("Expected request to be handled (blocked)") + } + + if rr.Code != http.StatusForbidden { + t.Errorf("Expected status 403, got %d", rr.Code) + } +} + +func TestManager_Handler(t *testing.T) { + logger := newTestLogger() + manager := NewManager(logger) + + // Load routing extension with a simple route + manager.LoadExtension("routing", map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "=/health": map[string]interface{}{ + "return": "200 OK", + }, + }, + }) + + baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + }) + + handler := manager.Handler(baseHandler) + + // Test health route + req := httptest.NewRequest("GET", "/health", nil) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } +} + +func TestManager_Priority(t *testing.T) { + logger := newTestLogger() + manager := NewManager(logger) + + // Load extensions in any order + manager.LoadExtension("routing", map[string]interface{}{}) // Priority 50 + manager.LoadExtension("security", map[string]interface{}{}) // Priority 10 + manager.LoadExtension("caching", map[string]interface{}{}) // Priority 20 + + exts := manager.Extensions() + if len(exts) != 3 { + t.Fatalf("Expected 3 extensions, got %d", len(exts)) + } + + // Check order by priority + if exts[0].Name() != "security" { + t.Errorf("Expected security first, got %s", exts[0].Name()) + } + if exts[1].Name() != "caching" { + t.Errorf("Expected caching second, got %s", exts[1].Name()) + } + if exts[2].Name() != "routing" { + t.Errorf("Expected routing third, got %s", exts[2].Name()) + } +} + +func TestManager_Cleanup(t *testing.T) { + logger := newTestLogger() + manager := NewManager(logger) + + manager.LoadExtension("security", map[string]interface{}{}) + manager.LoadExtension("routing", map[string]interface{}{}) + + manager.Cleanup() + + exts := manager.Extensions() + if len(exts) != 0 { + t.Errorf("Expected 0 extensions after cleanup, got %d", len(exts)) + } +} diff --git a/go/internal/extension/routing.go b/go/internal/extension/routing.go new file mode 100644 index 0000000..f040dfb --- /dev/null +++ b/go/internal/extension/routing.go @@ -0,0 +1,428 @@ +package extension + +import ( + "context" + "net/http" + "os" + "path/filepath" + "regexp" + "strings" + "time" + + "github.com/konduktor/konduktor/internal/logging" + "github.com/konduktor/konduktor/internal/proxy" +) + +// RoutingExtension handles request routing based on patterns +type RoutingExtension struct { + BaseExtension + exactRoutes map[string]RouteConfig + regexRoutes []*regexRoute + defaultRoute *RouteConfig + staticDir string +} + +// RouteConfig holds configuration for a route +type RouteConfig struct { + ProxyPass string + Root string + IndexFile string + Return string + ContentType string + Headers []string + CacheControl string + SPAFallback bool + ExcludePatterns []string + Timeout float64 +} + +type regexRoute struct { + pattern *regexp.Regexp + config RouteConfig + caseSensitive bool + originalExpr string +} + +// NewRoutingExtension creates a new routing extension +func NewRoutingExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) { + ext := &RoutingExtension{ + BaseExtension: NewBaseExtension("routing", 50, logger), // Middle priority + exactRoutes: make(map[string]RouteConfig), + regexRoutes: make([]*regexRoute, 0), + staticDir: "./static", + } + + logger.Debug("Routing extension config", "config", config) + + // Parse regex_locations from config + if locations, ok := config["regex_locations"].(map[string]interface{}); ok { + logger.Debug("Found regex_locations", "count", len(locations)) + for pattern, routeCfg := range locations { + logger.Debug("Adding route", "pattern", pattern) + if rc, ok := routeCfg.(map[string]interface{}); ok { + ext.addRoute(pattern, parseRouteConfig(rc)) + } + } + } else { + logger.Warn("No regex_locations found in config", "config_keys", getKeys(config)) + } + + // Parse static_dir if provided + if staticDir, ok := config["static_dir"].(string); ok { + ext.staticDir = staticDir + } + + return ext, nil +} + +func parseRouteConfig(cfg map[string]interface{}) RouteConfig { + rc := RouteConfig{ + IndexFile: "index.html", + } + + if v, ok := cfg["proxy_pass"].(string); ok { + rc.ProxyPass = v + } + if v, ok := cfg["root"].(string); ok { + rc.Root = v + } + if v, ok := cfg["index_file"].(string); ok { + rc.IndexFile = v + } + if v, ok := cfg["return"].(string); ok { + rc.Return = v + } + if v, ok := cfg["content_type"].(string); ok { + rc.ContentType = v + } + if v, ok := cfg["cache_control"].(string); ok { + rc.CacheControl = v + } + if v, ok := cfg["spa_fallback"].(bool); ok { + rc.SPAFallback = v + } + if v, ok := cfg["timeout"].(float64); ok { + rc.Timeout = v + } + + // Parse headers + if headers, ok := cfg["headers"].([]interface{}); ok { + for _, h := range headers { + if header, ok := h.(string); ok { + rc.Headers = append(rc.Headers, header) + } + } + } + + // Parse exclude_patterns + if patterns, ok := cfg["exclude_patterns"].([]interface{}); ok { + for _, p := range patterns { + if pattern, ok := p.(string); ok { + rc.ExcludePatterns = append(rc.ExcludePatterns, pattern) + } + } + } + + return rc +} + +func (e *RoutingExtension) addRoute(pattern string, config RouteConfig) { + switch { + case pattern == "__default__": + e.defaultRoute = &config + + case strings.HasPrefix(pattern, "="): + // Exact match + path := strings.TrimPrefix(pattern, "=") + e.exactRoutes[path] = config + + case strings.HasPrefix(pattern, "~*"): + // Case-insensitive regex + expr := strings.TrimPrefix(pattern, "~*") + re, err := regexp.Compile("(?i)" + expr) + if err != nil { + e.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err) + return + } + e.regexRoutes = append(e.regexRoutes, ®exRoute{ + pattern: re, + config: config, + caseSensitive: false, + originalExpr: expr, + }) + + case strings.HasPrefix(pattern, "~"): + // Case-sensitive regex + expr := strings.TrimPrefix(pattern, "~") + re, err := regexp.Compile(expr) + if err != nil { + e.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err) + return + } + e.regexRoutes = append(e.regexRoutes, ®exRoute{ + pattern: re, + config: config, + caseSensitive: true, + originalExpr: expr, + }) + } +} + +// ProcessRequest handles the request routing +func (e *RoutingExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) { + path := r.URL.Path + + // 1. Check exact routes (ignore request path for proxy) + if config, ok := e.exactRoutes[path]; ok { + return e.handleRoute(w, r, config, nil, true) + } + + // 2. Check regex routes + for _, route := range e.regexRoutes { + match := route.pattern.FindStringSubmatch(path) + if match != nil { + params := make(map[string]string) + names := route.pattern.SubexpNames() + for i, name := range names { + if i > 0 && name != "" && i < len(match) { + params[name] = match[i] + } + } + return e.handleRoute(w, r, route.config, params, false) + } + } + + // 3. Check default route + if e.defaultRoute != nil { + return e.handleRoute(w, r, *e.defaultRoute, nil, false) + } + + return false, nil +} + +func (e *RoutingExtension) handleRoute(w http.ResponseWriter, r *http.Request, config RouteConfig, params map[string]string, exactMatch bool) (bool, error) { + // Handle "return" directive + if config.Return != "" { + return e.handleReturn(w, config) + } + + // Handle proxy_pass + if config.ProxyPass != "" { + return e.handleProxy(w, r, config, params, exactMatch) + } + + // Handle static files with root + if config.Root != "" { + return e.handleStatic(w, r, config) + } + + // Handle SPA fallback + if config.SPAFallback { + return e.handleSPAFallback(w, r, config) + } + + return false, nil +} + +func (e *RoutingExtension) handleReturn(w http.ResponseWriter, config RouteConfig) (bool, error) { + parts := strings.SplitN(config.Return, " ", 2) + statusCode := 200 + body := "OK" + + if len(parts) >= 1 { + switch parts[0] { + case "200": + statusCode = 200 + case "201": + statusCode = 201 + case "301": + statusCode = 301 + case "302": + statusCode = 302 + case "400": + statusCode = 400 + case "404": + statusCode = 404 + case "500": + statusCode = 500 + } + } + if len(parts) >= 2 { + body = parts[1] + } + + contentType := "text/plain" + if config.ContentType != "" { + contentType = config.ContentType + } + + w.Header().Set("Content-Type", contentType) + w.WriteHeader(statusCode) + w.Write([]byte(body)) + return true, nil +} + +func (e *RoutingExtension) handleProxy(w http.ResponseWriter, r *http.Request, config RouteConfig, params map[string]string, exactMatch bool) (bool, error) { + target := config.ProxyPass + + // Check if target URL contains parameter placeholders + hasParams := strings.Contains(target, "{") && strings.Contains(target, "}") + + // Substitute params in target URL + for key, value := range params { + target = strings.ReplaceAll(target, "{"+key+"}", value) + } + + // Create proxy config + // IgnoreRequestPath=true when: + // - exact match route (=/path) + // - target URL had parameter substitutions (the target path is fully specified) + proxyConfig := &proxy.Config{ + Target: target, + Headers: make(map[string]string), + IgnoreRequestPath: exactMatch || hasParams, + } + + // Set timeout if specified + if config.Timeout > 0 { + proxyConfig.Timeout = time.Duration(config.Timeout * float64(time.Second)) + } + + // Parse headers + clientIP := getClientIP(r) + for _, header := range config.Headers { + parts := strings.SplitN(header, ": ", 2) + if len(parts) == 2 { + value := parts[1] + // Substitute params + for key, pValue := range params { + value = strings.ReplaceAll(value, "{"+key+"}", pValue) + } + // Substitute special variables + value = strings.ReplaceAll(value, "$remote_addr", clientIP) + proxyConfig.Headers[parts[0]] = value + } + } + + p, err := proxy.New(proxyConfig, e.logger) + if err != nil { + e.logger.Error("Failed to create proxy", "target", target, "error", err) + http.Error(w, "Bad Gateway", http.StatusBadGateway) + return true, nil + } + + p.ProxyRequest(w, r, params) + return true, nil +} + +func (e *RoutingExtension) handleStatic(w http.ResponseWriter, r *http.Request, config RouteConfig) (bool, error) { + path := r.URL.Path + + // Handle index file for root or directory paths + if path == "/" || strings.HasSuffix(path, "/") { + path = "/" + config.IndexFile + } + + // Get absolute path for root dir + absRoot, err := filepath.Abs(config.Root) + if err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return true, nil + } + + filePath := filepath.Join(absRoot, filepath.Clean("/"+path)) + cleanPath := filepath.Clean(filePath) + + // Prevent directory traversal + if !strings.HasPrefix(cleanPath+string(filepath.Separator), absRoot+string(filepath.Separator)) { + http.Error(w, "Forbidden", http.StatusForbidden) + return true, nil + } + + // Check if file exists + if _, err := os.Stat(filePath); os.IsNotExist(err) { + return false, nil // Let other handlers try + } + + // Set cache control header + if config.CacheControl != "" { + w.Header().Set("Cache-Control", config.CacheControl) + } + + // Set custom headers + for _, header := range config.Headers { + parts := strings.SplitN(header, ": ", 2) + if len(parts) == 2 { + w.Header().Set(parts[0], parts[1]) + } + } + + http.ServeFile(w, r, filePath) + return true, nil +} + +func (e *RoutingExtension) handleSPAFallback(w http.ResponseWriter, r *http.Request, config RouteConfig) (bool, error) { + path := r.URL.Path + + // Check exclude patterns + for _, pattern := range config.ExcludePatterns { + if strings.HasPrefix(path, pattern) { + return false, nil + } + } + + root := config.Root + if root == "" { + root = e.staticDir + } + + indexFile := config.IndexFile + if indexFile == "" { + indexFile = "index.html" + } + + filePath := filepath.Join(root, indexFile) + if _, err := os.Stat(filePath); os.IsNotExist(err) { + return false, nil + } + + http.ServeFile(w, r, filePath) + return true, nil +} + +func getClientIP(r *http.Request) string { + // Check X-Forwarded-For header first + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + parts := strings.Split(xff, ",") + return strings.TrimSpace(parts[0]) + } + + // Check X-Real-IP header + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return xri + } + + // Fall back to RemoteAddr + ip := r.RemoteAddr + if idx := strings.LastIndex(ip, ":"); idx != -1 { + ip = ip[:idx] + } + return ip +} + +// GetMetrics returns routing metrics +func (e *RoutingExtension) GetMetrics() map[string]interface{} { + return map[string]interface{}{ + "exact_routes": len(e.exactRoutes), + "regex_routes": len(e.regexRoutes), + "has_default": e.defaultRoute != nil, + } +} + +func getKeys(m map[string]interface{}) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} diff --git a/go/internal/extension/security.go b/go/internal/extension/security.go new file mode 100644 index 0000000..31ed6c6 --- /dev/null +++ b/go/internal/extension/security.go @@ -0,0 +1,312 @@ +package extension + +import ( + "context" + "net" + "net/http" + "strings" + "sync" + "time" + + "github.com/konduktor/konduktor/internal/logging" +) + +// SecurityExtension provides security features like IP filtering and security headers +type SecurityExtension struct { + BaseExtension + allowedIPs map[string]bool + blockedIPs map[string]bool + allowedCIDRs []*net.IPNet + blockedCIDRs []*net.IPNet + securityHeaders map[string]string + + // Rate limiting + rateLimitEnabled bool + rateLimitRequests int + rateLimitWindow time.Duration + rateLimitByIP map[string]*rateLimitEntry + rateLimitMu sync.RWMutex +} + +type rateLimitEntry struct { + count int + resetTime time.Time +} + +// SecurityConfig holds security extension configuration +type SecurityConfig struct { + AllowedIPs []string `yaml:"allowed_ips"` + BlockedIPs []string `yaml:"blocked_ips"` + SecurityHeaders map[string]string `yaml:"security_headers"` + RateLimit *RateLimitConfig `yaml:"rate_limit"` +} + +// RateLimitConfig holds rate limiting configuration +type RateLimitConfig struct { + Enabled bool `yaml:"enabled"` + Requests int `yaml:"requests"` + Window string `yaml:"window"` // e.g., "1m", "1h" +} + +// Default security headers +var defaultSecurityHeaders = map[string]string{ + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "X-XSS-Protection": "1; mode=block", + "Referrer-Policy": "strict-origin-when-cross-origin", +} + +// NewSecurityExtension creates a new security extension +func NewSecurityExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) { + ext := &SecurityExtension{ + BaseExtension: NewBaseExtension("security", 10, logger), // High priority (early execution) + allowedIPs: make(map[string]bool), + blockedIPs: make(map[string]bool), + allowedCIDRs: make([]*net.IPNet, 0), + blockedCIDRs: make([]*net.IPNet, 0), + securityHeaders: make(map[string]string), + rateLimitByIP: make(map[string]*rateLimitEntry), + } + + // Copy default security headers + for k, v := range defaultSecurityHeaders { + ext.securityHeaders[k] = v + } + + // Parse allowed_ips + if allowedIPs, ok := config["allowed_ips"].([]interface{}); ok { + for _, ip := range allowedIPs { + if ipStr, ok := ip.(string); ok { + ext.addAllowedIP(ipStr) + } + } + } + + // Parse blocked_ips + if blockedIPs, ok := config["blocked_ips"].([]interface{}); ok { + for _, ip := range blockedIPs { + if ipStr, ok := ip.(string); ok { + ext.addBlockedIP(ipStr) + } + } + } + + // Parse security_headers + if headers, ok := config["security_headers"].(map[string]interface{}); ok { + for k, v := range headers { + if vStr, ok := v.(string); ok { + ext.securityHeaders[k] = vStr + } + } + } + + // Parse rate_limit + if rateLimit, ok := config["rate_limit"].(map[string]interface{}); ok { + if enabled, ok := rateLimit["enabled"].(bool); ok && enabled { + ext.rateLimitEnabled = true + + if requests, ok := rateLimit["requests"].(int); ok { + ext.rateLimitRequests = requests + } else if requestsFloat, ok := rateLimit["requests"].(float64); ok { + ext.rateLimitRequests = int(requestsFloat) + } else { + ext.rateLimitRequests = 100 // default + } + + if window, ok := rateLimit["window"].(string); ok { + if duration, err := time.ParseDuration(window); err == nil { + ext.rateLimitWindow = duration + } else { + ext.rateLimitWindow = time.Minute // default + } + } else { + ext.rateLimitWindow = time.Minute + } + + logger.Info("Rate limiting enabled", + "requests", ext.rateLimitRequests, + "window", ext.rateLimitWindow.String()) + } + } + + return ext, nil +} + +func (e *SecurityExtension) addAllowedIP(ip string) { + if strings.Contains(ip, "/") { + // CIDR notation + _, cidr, err := net.ParseCIDR(ip) + if err == nil { + e.allowedCIDRs = append(e.allowedCIDRs, cidr) + } + } else { + e.allowedIPs[ip] = true + } +} + +func (e *SecurityExtension) addBlockedIP(ip string) { + if strings.Contains(ip, "/") { + // CIDR notation + _, cidr, err := net.ParseCIDR(ip) + if err == nil { + e.blockedCIDRs = append(e.blockedCIDRs, cidr) + } + } else { + e.blockedIPs[ip] = true + } +} + +// ProcessRequest checks security rules +func (e *SecurityExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) { + clientIP := getClientIP(r) + parsedIP := net.ParseIP(clientIP) + + // Check blocked IPs first + if e.isBlocked(clientIP, parsedIP) { + e.logger.Warn("Blocked request from IP", "ip", clientIP) + http.Error(w, "403 Forbidden", http.StatusForbidden) + return true, nil + } + + // Check allowed IPs (if configured, only these IPs are allowed) + if len(e.allowedIPs) > 0 || len(e.allowedCIDRs) > 0 { + if !e.isAllowed(clientIP, parsedIP) { + e.logger.Warn("Access denied for IP", "ip", clientIP) + http.Error(w, "403 Forbidden", http.StatusForbidden) + return true, nil + } + } + + // Check rate limit + if e.rateLimitEnabled { + if !e.checkRateLimit(clientIP) { + e.logger.Warn("Rate limit exceeded", "ip", clientIP) + w.Header().Set("Retry-After", "60") + http.Error(w, "429 Too Many Requests", http.StatusTooManyRequests) + return true, nil + } + } + + return false, nil +} + +// ProcessResponse adds security headers to the response +func (e *SecurityExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) { + for header, value := range e.securityHeaders { + w.Header().Set(header, value) + } +} + +func (e *SecurityExtension) isBlocked(ip string, parsedIP net.IP) bool { + // Check exact match + if e.blockedIPs[ip] { + return true + } + + // Check CIDR ranges + if parsedIP != nil { + for _, cidr := range e.blockedCIDRs { + if cidr.Contains(parsedIP) { + return true + } + } + } + + return false +} + +func (e *SecurityExtension) isAllowed(ip string, parsedIP net.IP) bool { + // Check exact match + if e.allowedIPs[ip] { + return true + } + + // Check CIDR ranges + if parsedIP != nil { + for _, cidr := range e.allowedCIDRs { + if cidr.Contains(parsedIP) { + return true + } + } + } + + return false +} + +func (e *SecurityExtension) checkRateLimit(ip string) bool { + e.rateLimitMu.Lock() + defer e.rateLimitMu.Unlock() + + now := time.Now() + entry, exists := e.rateLimitByIP[ip] + + if !exists || now.After(entry.resetTime) { + // Create new entry or reset expired one + e.rateLimitByIP[ip] = &rateLimitEntry{ + count: 1, + resetTime: now.Add(e.rateLimitWindow), + } + return true + } + + // Increment counter + entry.count++ + return entry.count <= e.rateLimitRequests +} + +// AddBlockedIP adds an IP to the blocked list at runtime +func (e *SecurityExtension) AddBlockedIP(ip string) { + e.addBlockedIP(ip) +} + +// RemoveBlockedIP removes an IP from the blocked list +func (e *SecurityExtension) RemoveBlockedIP(ip string) { + delete(e.blockedIPs, ip) +} + +// AddAllowedIP adds an IP to the allowed list at runtime +func (e *SecurityExtension) AddAllowedIP(ip string) { + e.addAllowedIP(ip) +} + +// RemoveAllowedIP removes an IP from the allowed list +func (e *SecurityExtension) RemoveAllowedIP(ip string) { + delete(e.allowedIPs, ip) +} + +// SetSecurityHeader sets or updates a security header +func (e *SecurityExtension) SetSecurityHeader(name, value string) { + e.securityHeaders[name] = value +} + +// GetMetrics returns security metrics +func (e *SecurityExtension) GetMetrics() map[string]interface{} { + e.rateLimitMu.RLock() + activeRateLimits := len(e.rateLimitByIP) + e.rateLimitMu.RUnlock() + + return map[string]interface{}{ + "allowed_ips": len(e.allowedIPs), + "allowed_cidrs": len(e.allowedCIDRs), + "blocked_ips": len(e.blockedIPs), + "blocked_cidrs": len(e.blockedCIDRs), + "security_headers": len(e.securityHeaders), + "rate_limit_enabled": e.rateLimitEnabled, + "active_rate_limits": activeRateLimits, + } +} + +// Cleanup cleans up rate limit entries periodically +func (e *SecurityExtension) Cleanup() error { + e.rateLimitMu.Lock() + defer e.rateLimitMu.Unlock() + + now := time.Now() + for ip, entry := range e.rateLimitByIP { + if now.After(entry.resetTime) { + delete(e.rateLimitByIP, ip) + } + } + + return nil +} diff --git a/go/internal/extension/security_test.go b/go/internal/extension/security_test.go new file mode 100644 index 0000000..3f6d7fd --- /dev/null +++ b/go/internal/extension/security_test.go @@ -0,0 +1,213 @@ +package extension + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" +) + +func TestNewSecurityExtension(t *testing.T) { + logger := newTestLogger() + + ext, err := NewSecurityExtension(map[string]interface{}{}, logger) + if err != nil { + t.Fatalf("Failed to create security extension: %v", err) + } + + if ext.Name() != "security" { + t.Errorf("Expected name 'security', got %s", ext.Name()) + } + + if ext.Priority() != 10 { + t.Errorf("Expected priority 10, got %d", ext.Priority()) + } +} + +func TestSecurityExtension_BlockedIP(t *testing.T) { + logger := newTestLogger() + + ext, _ := NewSecurityExtension(map[string]interface{}{ + "blocked_ips": []interface{}{"192.168.1.100"}, + }, logger) + + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.100:12345" + rr := httptest.NewRecorder() + + handled, err := ext.ProcessRequest(context.Background(), rr, req) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if !handled { + t.Error("Expected blocked request to be handled") + } + + if rr.Code != http.StatusForbidden { + t.Errorf("Expected status 403, got %d", rr.Code) + } +} + +func TestSecurityExtension_AllowedIP(t *testing.T) { + logger := newTestLogger() + + ext, _ := NewSecurityExtension(map[string]interface{}{ + "allowed_ips": []interface{}{"192.168.1.50"}, + }, logger) + + // Allowed IP + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.50:12345" + rr := httptest.NewRecorder() + + handled, _ := ext.ProcessRequest(context.Background(), rr, req) + if handled { + t.Error("Expected allowed IP to pass through") + } + + // Not allowed IP + req = httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.51:12345" + rr = httptest.NewRecorder() + + handled, _ = ext.ProcessRequest(context.Background(), rr, req) + if !handled { + t.Error("Expected non-allowed IP to be blocked") + } + + if rr.Code != http.StatusForbidden { + t.Errorf("Expected status 403, got %d", rr.Code) + } +} + +func TestSecurityExtension_CIDR(t *testing.T) { + logger := newTestLogger() + + ext, _ := NewSecurityExtension(map[string]interface{}{ + "blocked_ips": []interface{}{"10.0.0.0/8"}, + }, logger) + + // IP in blocked CIDR + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "10.1.2.3:12345" + rr := httptest.NewRecorder() + + handled, _ := ext.ProcessRequest(context.Background(), rr, req) + if !handled { + t.Error("Expected IP in blocked CIDR to be blocked") + } + + // IP not in blocked CIDR + req = httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + rr = httptest.NewRecorder() + + handled, _ = ext.ProcessRequest(context.Background(), rr, req) + if handled { + t.Error("Expected IP not in blocked CIDR to pass through") + } +} + +func TestSecurityExtension_SecurityHeaders(t *testing.T) { + logger := newTestLogger() + + ext, _ := NewSecurityExtension(map[string]interface{}{ + "security_headers": map[string]interface{}{ + "X-Custom-Header": "custom-value", + }, + }, logger) + + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + + ext.ProcessResponse(context.Background(), rr, req) + + // Check default headers + if rr.Header().Get("X-Content-Type-Options") != "nosniff" { + t.Error("Expected X-Content-Type-Options header") + } + + // Check custom header + if rr.Header().Get("X-Custom-Header") != "custom-value" { + t.Error("Expected custom header") + } +} + +func TestSecurityExtension_RateLimit(t *testing.T) { + logger := newTestLogger() + + ext, _ := NewSecurityExtension(map[string]interface{}{ + "rate_limit": map[string]interface{}{ + "enabled": true, + "requests": 2, + "window": "1m", + }, + }, logger) + + securityExt := ext.(*SecurityExtension) + clientIP := "192.168.1.1" + + // First request - should pass + if !securityExt.checkRateLimit(clientIP) { + t.Error("First request should pass") + } + + // Second request - should pass + if !securityExt.checkRateLimit(clientIP) { + t.Error("Second request should pass") + } + + // Third request - should be rate limited + if securityExt.checkRateLimit(clientIP) { + t.Error("Third request should be rate limited") + } +} + +func TestSecurityExtension_GetMetrics(t *testing.T) { + logger := newTestLogger() + + ext, _ := NewSecurityExtension(map[string]interface{}{ + "blocked_ips": []interface{}{"192.168.1.1"}, + "allowed_ips": []interface{}{"192.168.1.2"}, + }, logger) + + securityExt := ext.(*SecurityExtension) + metrics := securityExt.GetMetrics() + + if metrics["blocked_ips"].(int) != 1 { + t.Errorf("Expected 1 blocked IP, got %v", metrics["blocked_ips"]) + } + + if metrics["allowed_ips"].(int) != 1 { + t.Errorf("Expected 1 allowed IP, got %v", metrics["allowed_ips"]) + } +} + +func TestSecurityExtension_AddRemoveIPs(t *testing.T) { + logger := newTestLogger() + + ext, _ := NewSecurityExtension(map[string]interface{}{}, logger) + securityExt := ext.(*SecurityExtension) + + // Add blocked IP + securityExt.AddBlockedIP("192.168.1.100") + + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.100:12345" + rr := httptest.NewRecorder() + + handled, _ := ext.ProcessRequest(context.Background(), rr, req) + if !handled { + t.Error("Expected dynamically blocked IP to be blocked") + } + + // Remove blocked IP + securityExt.RemoveBlockedIP("192.168.1.100") + + rr = httptest.NewRecorder() + handled, _ = ext.ProcessRequest(context.Background(), rr, req) + if handled { + t.Error("Expected removed blocked IP to pass through") + } +} diff --git a/go/internal/logging/logger.go b/go/internal/logging/logger.go index fe50509..0a4eab6 100644 --- a/go/internal/logging/logger.go +++ b/go/internal/logging/logger.go @@ -1,136 +1,338 @@ +// Package logging provides structured logging with zap package logging import ( "fmt" "os" - "time" + "path/filepath" + "strings" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "gopkg.in/natefinch/lumberjack.v2" "github.com/konduktor/konduktor/internal/config" ) +// Config is a simple configuration for basic logger setup type Config struct { Level string TimestampFormat string } +// Logger wraps zap.SugaredLogger with additional functionality type Logger struct { - level string - timestampFormat string - configFull *config.LoggingConfig + *zap.SugaredLogger + zap *zap.Logger + config *config.LoggingConfig + name string } +// New creates a new Logger with basic configuration func New(cfg Config) (*Logger, error) { + level := parseLevel(cfg.Level) + timestampFormat := cfg.TimestampFormat if timestampFormat == "" { timestampFormat = "2006-01-02 15:04:05" } + encoderConfig := zap.NewProductionEncoderConfig() + encoderConfig.TimeKey = "timestamp" + encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout(timestampFormat) + encoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder + + core := zapcore.NewCore( + zapcore.NewConsoleEncoder(encoderConfig), + zapcore.AddSync(os.Stdout), + level, + ) + + zapLogger := zap.New(core) return &Logger{ - level: cfg.Level, - timestampFormat: timestampFormat, + SugaredLogger: zapLogger.Sugar(), + zap: zapLogger, + name: "konduktor", }, nil } +// NewFromConfig creates a Logger from full LoggingConfig func NewFromConfig(cfg config.LoggingConfig) (*Logger, error) { - timestampFormat := cfg.Format.TimestampFormat + var cores []zapcore.Core + + // Parse main level + mainLevel := parseLevel(cfg.Level) + + // Add console core if enabled + if cfg.ConsoleOutput { + consoleLevel := mainLevel + if cfg.Console != nil && cfg.Console.Level != "" { + consoleLevel = parseLevel(cfg.Console.Level) + } + + var consoleEncoder zapcore.Encoder + formatConfig := cfg.Format + if cfg.Console != nil { + formatConfig = mergeFormatConfig(cfg.Format, cfg.Console.Format) + } + + encoderCfg := createEncoderConfig(formatConfig) + if formatConfig.Type == "json" { + consoleEncoder = zapcore.NewJSONEncoder(encoderCfg) + } else { + if formatConfig.UseColors { + encoderCfg.EncodeLevel = zapcore.CapitalColorLevelEncoder + } + consoleEncoder = zapcore.NewConsoleEncoder(encoderCfg) + } + + consoleSyncer := zapcore.AddSync(os.Stdout) + cores = append(cores, zapcore.NewCore(consoleEncoder, consoleSyncer, consoleLevel)) + } + + // Add file cores + for _, fileConfig := range cfg.Files { + fileCore, err := createFileCore(fileConfig, cfg.Format, mainLevel) + if err != nil { + return nil, fmt.Errorf("failed to create file logger for %s: %w", fileConfig.Path, err) + } + + // If specific loggers are configured, wrap with filter + if len(fileConfig.Loggers) > 0 { + fileCore = &filteredCore{ + Core: fileCore, + loggers: fileConfig.Loggers, + } + } + + cores = append(cores, fileCore) + } + + // If no cores configured, add default console + if len(cores) == 0 { + encoderCfg := zap.NewProductionEncoderConfig() + encoderCfg.EncodeTime = zapcore.TimeEncoderOfLayout("2006-01-02 15:04:05") + encoderCfg.EncodeLevel = zapcore.CapitalColorLevelEncoder + cores = append(cores, zapcore.NewCore( + zapcore.NewConsoleEncoder(encoderCfg), + zapcore.AddSync(os.Stdout), + mainLevel, + )) + } + + // Combine all cores + core := zapcore.NewTee(cores...) + zapLogger := zap.New(core, zap.AddCaller(), zap.AddCallerSkip(1)) + + return &Logger{ + SugaredLogger: zapLogger.Sugar(), + zap: zapLogger, + config: &cfg, + name: "konduktor", + }, nil +} + +// Named returns a logger with a specific name (for filtering) +func (l *Logger) Named(name string) *Logger { + return &Logger{ + SugaredLogger: l.SugaredLogger.Named(name), + zap: l.zap.Named(name), + config: l.config, + name: name, + } +} + +// With returns a logger with additional fields +func (l *Logger) With(args ...interface{}) *Logger { + return &Logger{ + SugaredLogger: l.SugaredLogger.With(args...), + zap: l.zap.Sugar().With(args...).Desugar(), + config: l.config, + name: l.name, + } +} + +// Sync flushes any buffered log entries +func (l *Logger) Sync() error { + return l.zap.Sync() +} + +// GetZap returns the underlying zap.Logger +func (l *Logger) GetZap() *zap.Logger { + return l.zap +} + +// Debug logs a debug message +func (l *Logger) Debug(msg string, keysAndValues ...interface{}) { + l.SugaredLogger.Debugw(msg, keysAndValues...) +} + +// Info logs an info message +func (l *Logger) Info(msg string, keysAndValues ...interface{}) { + l.SugaredLogger.Infow(msg, keysAndValues...) +} + +// Warn logs a warning message +func (l *Logger) Warn(msg string, keysAndValues ...interface{}) { + l.SugaredLogger.Warnw(msg, keysAndValues...) +} + +// Error logs an error message +func (l *Logger) Error(msg string, keysAndValues ...interface{}) { + l.SugaredLogger.Errorw(msg, keysAndValues...) +} + +// Fatal logs a fatal message and exits +func (l *Logger) Fatal(msg string, keysAndValues ...interface{}) { + l.SugaredLogger.Fatalw(msg, keysAndValues...) +} + +// --- Helper functions --- + +func parseLevel(level string) zapcore.Level { + switch strings.ToUpper(level) { + case "DEBUG": + return zapcore.DebugLevel + case "INFO": + return zapcore.InfoLevel + case "WARN", "WARNING": + return zapcore.WarnLevel + case "ERROR": + return zapcore.ErrorLevel + case "CRITICAL", "FATAL": + return zapcore.FatalLevel + default: + return zapcore.InfoLevel + } +} + +func createEncoderConfig(format config.LogFormatConfig) zapcore.EncoderConfig { + timestampFormat := format.TimestampFormat if timestampFormat == "" { timestampFormat = "2006-01-02 15:04:05" } - return &Logger{ - level: cfg.Level, - timestampFormat: timestampFormat, - configFull: &cfg, - }, nil + cfg := zapcore.EncoderConfig{ + TimeKey: "timestamp", + LevelKey: "level", + NameKey: "logger", + CallerKey: "caller", + FunctionKey: zapcore.OmitKey, + MessageKey: "msg", + StacktraceKey: "stacktrace", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.CapitalLevelEncoder, + EncodeTime: zapcore.TimeEncoderOfLayout(timestampFormat), + EncodeDuration: zapcore.SecondsDurationEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + } + + if !format.ShowModule { + cfg.NameKey = zapcore.OmitKey + } + + return cfg } -func (l *Logger) formatTime() string { - return time.Now().Format(l.timestampFormat) +func mergeFormatConfig(base, override config.LogFormatConfig) config.LogFormatConfig { + result := base + if override.Type != "" { + result.Type = override.Type + } + if override.TimestampFormat != "" { + result.TimestampFormat = override.TimestampFormat + } + // UseColors and ShowModule are bool - check if override has non-default + result.UseColors = override.UseColors + result.ShowModule = override.ShowModule + return result } -func (l *Logger) log(level string, msg string, fields ...interface{}) { - timestamp := l.formatTime() - - // Simple console output for now - // TODO: Implement proper structured logging with zap - output := timestamp + " [" + level + "] " + msg - - if len(fields) > 0 { - output += " {" - for i := 0; i < len(fields); i += 2 { - if i > 0 { - output += ", " - } - if i+1 < len(fields) { - output += fields[i].(string) + "=" + formatValue(fields[i+1]) - } +func createFileCore(fileConfig config.FileLogConfig, defaultFormat config.LogFormatConfig, defaultLevel zapcore.Level) (zapcore.Core, error) { + // Ensure directory exists + dir := filepath.Dir(fileConfig.Path) + if dir != "" && dir != "." { + if err := os.MkdirAll(dir, 0755); err != nil { + return nil, fmt.Errorf("failed to create log directory %s: %w", dir, err) } - output += "}" } - os.Stdout.WriteString(output + "\n") + // Configure log rotation with lumberjack + maxSize := 10 // MB + if fileConfig.MaxBytes > 0 { + maxSize = int(fileConfig.MaxBytes / (1024 * 1024)) + if maxSize < 1 { + maxSize = 1 + } + } + + backupCount := 5 + if fileConfig.BackupCount > 0 { + backupCount = fileConfig.BackupCount + } + + rotator := &lumberjack.Logger{ + Filename: fileConfig.Path, + MaxSize: maxSize, + MaxBackups: backupCount, + MaxAge: 30, // days + Compress: true, + } + + // Determine level + level := defaultLevel + if fileConfig.Level != "" { + level = parseLevel(fileConfig.Level) + } + + // Create encoder + format := defaultFormat + if fileConfig.Format.Type != "" { + format = mergeFormatConfig(defaultFormat, fileConfig.Format) + } + // Files should not use colors + format.UseColors = false + + encoderConfig := createEncoderConfig(format) + var encoder zapcore.Encoder + if format.Type == "json" { + encoder = zapcore.NewJSONEncoder(encoderConfig) + } else { + encoder = zapcore.NewConsoleEncoder(encoderConfig) + } + + return zapcore.NewCore(encoder, zapcore.AddSync(rotator), level), nil } -func formatValue(v interface{}) string { - switch val := v.(type) { - case string: - return val - case int: - return fmt.Sprintf("%d", val) - case int64: - return fmt.Sprintf("%d", val) - case float64: - return fmt.Sprintf("%.2f", val) - case bool: - return fmt.Sprintf("%t", val) - case error: - return val.Error() - default: - return fmt.Sprintf("%v", val) - } +// filteredCore wraps a Core to filter by logger name +type filteredCore struct { + zapcore.Core + loggers []string } -func (l *Logger) Debug(msg string, fields ...interface{}) { - if l.shouldLog("DEBUG") { - l.log("DEBUG", msg, fields...) +func (c *filteredCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry { + if !c.shouldLog(entry.LoggerName) { + return ce } + return c.Core.Check(entry, ce) } -func (l *Logger) Info(msg string, fields ...interface{}) { - if l.shouldLog("INFO") { - l.log("INFO", msg, fields...) +func (c *filteredCore) shouldLog(loggerName string) bool { + if len(c.loggers) == 0 { + return true } + + for _, allowed := range c.loggers { + if loggerName == allowed || strings.HasPrefix(loggerName, allowed+".") { + return true + } + } + return false } -func (l *Logger) Warn(msg string, fields ...interface{}) { - if l.shouldLog("WARN") { - l.log("WARN", msg, fields...) +func (c *filteredCore) With(fields []zapcore.Field) zapcore.Core { + return &filteredCore{ + Core: c.Core.With(fields), + loggers: c.loggers, } } - -func (l *Logger) Error(msg string, fields ...interface{}) { - if l.shouldLog("ERROR") { - l.log("ERROR", msg, fields...) - } -} - -func (l *Logger) shouldLog(level string) bool { - levels := map[string]int{ - "DEBUG": 0, - "INFO": 1, - "WARN": 2, - "ERROR": 3, - } - - currentLevel, ok := levels[l.level] - if !ok { - currentLevel = 1 // Default to INFO - } - - msgLevel, ok := levels[level] - if !ok { - msgLevel = 1 - } - - return msgLevel >= currentLevel -} diff --git a/go/internal/logging/logger_test.go b/go/internal/logging/logger_test.go index 342ac53..0e6c6b6 100644 --- a/go/internal/logging/logger_test.go +++ b/go/internal/logging/logger_test.go @@ -2,6 +2,8 @@ package logging import ( "testing" + + "github.com/konduktor/konduktor/internal/config" ) func TestNew(t *testing.T) { @@ -15,71 +17,84 @@ func TestNew(t *testing.T) { t.Fatal("Expected logger, got nil") } - if logger.level != "INFO" { - t.Errorf("Expected level INFO, got %s", logger.level) + if logger.name != "konduktor" { + t.Errorf("Expected name konduktor, got %s", logger.name) } } func TestNew_DefaultTimestampFormat(t *testing.T) { - logger, _ := New(Config{Level: "DEBUG"}) + logger, err := New(Config{Level: "DEBUG"}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } - if logger.timestampFormat != "2006-01-02 15:04:05" { - t.Errorf("Expected default timestamp format, got %s", logger.timestampFormat) + // Logger should be created successfully + if logger == nil { + t.Fatal("Expected logger, got nil") } } func TestNew_CustomTimestampFormat(t *testing.T) { - logger, _ := New(Config{ + logger, err := New(Config{ Level: "DEBUG", TimestampFormat: "15:04:05", }) - if logger.timestampFormat != "15:04:05" { - t.Errorf("Expected custom timestamp format, got %s", logger.timestampFormat) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if logger == nil { + t.Fatal("Expected logger, got nil") } } -func TestLogger_ShouldLog(t *testing.T) { - tests := []struct { - loggerLevel string - msgLevel string - shouldLog bool - }{ - {"DEBUG", "DEBUG", true}, - {"DEBUG", "INFO", true}, - {"DEBUG", "WARN", true}, - {"DEBUG", "ERROR", true}, - {"INFO", "DEBUG", false}, - {"INFO", "INFO", true}, - {"INFO", "WARN", true}, - {"INFO", "ERROR", true}, - {"WARN", "DEBUG", false}, - {"WARN", "INFO", false}, - {"WARN", "WARN", true}, - {"WARN", "ERROR", true}, - {"ERROR", "DEBUG", false}, - {"ERROR", "INFO", false}, - {"ERROR", "WARN", false}, - {"ERROR", "ERROR", true}, +func TestNewFromConfig(t *testing.T) { + cfg := config.LoggingConfig{ + Level: "DEBUG", + ConsoleOutput: true, + Format: config.LogFormatConfig{ + Type: "standard", + UseColors: true, + ShowModule: true, + TimestampFormat: "2006-01-02 15:04:05", + }, } - for _, tt := range tests { - t.Run(tt.loggerLevel+"_"+tt.msgLevel, func(t *testing.T) { - logger, _ := New(Config{Level: tt.loggerLevel}) + logger, err := NewFromConfig(cfg) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } - if got := logger.shouldLog(tt.msgLevel); got != tt.shouldLog { - t.Errorf("shouldLog(%s) = %v, want %v", tt.msgLevel, got, tt.shouldLog) - } - }) + if logger == nil { + t.Fatal("Expected logger, got nil") } } -func TestLogger_ShouldLog_InvalidLevel(t *testing.T) { - logger, _ := New(Config{Level: "INVALID"}) +func TestNewFromConfig_WithConsole(t *testing.T) { + cfg := config.LoggingConfig{ + Level: "INFO", + ConsoleOutput: true, + Format: config.LogFormatConfig{ + Type: "standard", + UseColors: true, + }, + Console: &config.ConsoleLogConfig{ + Level: "DEBUG", + Format: config.LogFormatConfig{ + Type: "standard", + UseColors: false, + }, + }, + } - // Should default to INFO level - if !logger.shouldLog("INFO") { - t.Error("Invalid level should default to INFO") + logger, err := NewFromConfig(cfg) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if logger == nil { + t.Fatal("Expected logger, got nil") } } @@ -111,34 +126,65 @@ func TestLogger_Error(t *testing.T) { logger.Error("test message", "key", "value") } -func TestFormatValue(t *testing.T) { +func TestLogger_Named(t *testing.T) { + logger, _ := New(Config{Level: "INFO"}) + named := logger.Named("test.module") + + if named == nil { + t.Fatal("Expected named logger, got nil") + } + + if named.name != "test.module" { + t.Errorf("Expected name 'test.module', got %s", named.name) + } + + // Should not panic + named.Info("test from named logger") +} + +func TestLogger_With(t *testing.T) { + logger, _ := New(Config{Level: "INFO"}) + withFields := logger.With("service", "test") + + if withFields == nil { + t.Fatal("Expected logger with fields, got nil") + } + + // Should not panic + withFields.Info("test with fields") +} + +func TestLogger_Sync(t *testing.T) { + logger, _ := New(Config{Level: "INFO"}) + + // Should not panic + err := logger.Sync() + // Sync may return an error for stdout on some systems, ignore it + _ = err +} + +func TestParseLevel(t *testing.T) { tests := []struct { - input interface{} + input string expected string }{ - {"test", "test"}, - {42, "*"}, // int converts to rune - {nil, ""}, + {"DEBUG", "debug"}, + {"INFO", "info"}, + {"WARN", "warn"}, + {"WARNING", "warn"}, + {"ERROR", "error"}, + {"CRITICAL", "fatal"}, + {"FATAL", "fatal"}, + {"invalid", "info"}, // defaults to INFO } for _, tt := range tests { - got := formatValue(tt.input) - // Just check it doesn't panic - _ = got - } -} - -func TestLogger_FormatTime(t *testing.T) { - logger, _ := New(Config{ - Level: "INFO", - TimestampFormat: "2006-01-02", - }) - - result := logger.formatTime() - - // Should be in expected format (YYYY-MM-DD) - if len(result) != 10 { - t.Errorf("Expected date format YYYY-MM-DD, got %s", result) + t.Run(tt.input, func(t *testing.T) { + level := parseLevel(tt.input) + if level.String() != tt.expected { + t.Errorf("parseLevel(%s) = %s, want %s", tt.input, level.String(), tt.expected) + } + }) } } @@ -161,12 +207,3 @@ func BenchmarkLogger_Debug_Filtered(b *testing.B) { logger.Debug("test message", "key", "value") } } - -func BenchmarkLogger_ShouldLog(b *testing.B) { - logger, _ := New(Config{Level: "INFO"}) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - logger.shouldLog("DEBUG") - } -} diff --git a/go/internal/proxy/proxy.go b/go/internal/proxy/proxy.go index 5fb02b1..ddfa1f2 100644 --- a/go/internal/proxy/proxy.go +++ b/go/internal/proxy/proxy.go @@ -29,6 +29,10 @@ type Config struct { // PreserveHost keeps the original Host header PreserveHost bool + + // IgnoreRequestPath ignores the request path and uses only the target path + // This is useful for exact match routes where target URL should be used as-is + IgnoreRequestPath bool } type ReverseProxy struct { @@ -116,6 +120,13 @@ func (rp *ReverseProxy) ProxyRequest(w http.ResponseWriter, r *http.Request, par func (rp *ReverseProxy) buildTargetURL(r *http.Request) *url.URL { targetURL := *rp.targetURL + // If ignoring request path, use target URL path as-is + if rp.config.IgnoreRequestPath { + // Preserve query string only + targetURL.RawQuery = r.URL.RawQuery + return &targetURL + } + // Strip prefix if configured path := r.URL.Path if rp.config.StripPrefix != "" { @@ -125,10 +136,12 @@ func (rp *ReverseProxy) buildTargetURL(r *http.Request) *url.URL { } } - // If target has a path, append request path to it + // If target URL has a non-empty path, combine it with the request path if rp.targetURL.Path != "" && rp.targetURL.Path != "/" { - targetURL.Path = singleJoiningSlash(rp.targetURL.Path, path) + // Combine target path with request path + targetURL.Path = strings.TrimSuffix(rp.targetURL.Path, "/") + path } else { + // No path in target, use request path as-is targetURL.Path = path } diff --git a/go/internal/routing/router.go b/go/internal/routing/router.go index 6a22db4..55f015b 100644 --- a/go/internal/routing/router.go +++ b/go/internal/routing/router.go @@ -2,6 +2,7 @@ package routing import ( + "fmt" "net/http" "os" "path/filepath" @@ -11,6 +12,7 @@ import ( "github.com/konduktor/konduktor/internal/config" "github.com/konduktor/konduktor/internal/logging" + "github.com/konduktor/konduktor/internal/proxy" ) // RouteMatch represents a matched route with captured parameters @@ -55,6 +57,24 @@ func New(cfg *config.Config, logger *logging.Logger) *Router { regexRoutes: make([]*RegexRoute, 0), } + // Load routes from extensions + if cfg != nil { + for _, ext := range cfg.Extensions { + if ext.Type == "routing" && ext.Config != nil { + if locations, ok := ext.Config["regex_locations"].(map[string]interface{}); ok { + for pattern, routeCfg := range locations { + if rc, ok := routeCfg.(map[string]interface{}); ok { + r.AddRoute(pattern, rc) + if logger != nil { + logger.Debug("Added route", "pattern", pattern) + } + } + } + } + } + } + } + r.setupRoutes() return r } @@ -250,17 +270,27 @@ func (r *Router) defaultHandler(w http.ResponseWriter, req *http.Request) { // Try to match against configured routes match := r.Match(path) + fmt.Printf("DEBUG defaultHandler: path=%q match=%v defaultRoute=%v\n", path, match != nil, r.defaultRoute != nil) if match != nil { + fmt.Printf("DEBUG: matched config: %v\n", match.Config) r.handleRouteMatch(w, req, match) return } // Try to serve static file if r.staticDir != "" { - filePath := filepath.Join(r.staticDir, path) + // Get absolute path for static dir + absStaticDir, err := filepath.Abs(r.staticDir) + if err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } - // Prevent directory traversal - if !strings.HasPrefix(filepath.Clean(filePath), filepath.Clean(r.staticDir)) { + filePath := filepath.Join(absStaticDir, filepath.Clean("/"+path)) + cleanPath := filepath.Clean(filePath) + + // Prevent directory traversal - ensure path is within static dir + if !strings.HasPrefix(cleanPath+string(filepath.Separator), absStaticDir+string(filepath.Separator)) { http.Error(w, "Forbidden", http.StatusForbidden) return } @@ -290,6 +320,12 @@ func (r *Router) defaultHandler(w http.ResponseWriter, req *http.Request) { func (r *Router) handleRouteMatch(w http.ResponseWriter, req *http.Request, match *RouteMatch) { cfg := match.Config + // Handle proxy_pass directive + if proxyTarget, ok := cfg["proxy_pass"].(string); ok { + r.handleProxyPass(w, req, proxyTarget, cfg, match.Params) + return + } + // Handle "return" directive if ret, ok := cfg["return"].(string); ok { parts := strings.SplitN(ret, " ", 2) @@ -338,7 +374,26 @@ func (r *Router) handleRouteMatch(w http.ResponseWriter, req *http.Request, matc } } - filePath := filepath.Join(root, path) + // Get absolute path for root dir + absRoot, err := filepath.Abs(root) + if err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + filePath := filepath.Join(absRoot, filepath.Clean("/"+path)) + cleanPath := filepath.Clean(filePath) + + // DEBUG + fmt.Printf("DEBUG: path=%q absRoot=%q filePath=%q cleanPath=%q\n", path, absRoot, filePath, cleanPath) + fmt.Printf("DEBUG: check1=%q check2=%q\n", cleanPath+string(filepath.Separator), absRoot+string(filepath.Separator)) + + // Prevent directory traversal + if !strings.HasPrefix(cleanPath+string(filepath.Separator), absRoot+string(filepath.Separator)) { + fmt.Printf("DEBUG: FORBIDDEN - path not within root\n") + http.Error(w, "Forbidden", http.StatusForbidden) + return + } if cacheControl, ok := cfg["cache_control"].(string); ok { w.Header().Set("Cache-Control", cacheControl) @@ -379,6 +434,48 @@ func (r *Router) handleRouteMatch(w http.ResponseWriter, req *http.Request, matc http.NotFound(w, req) } +// handleProxyPass proxies the request to the target backend +func (r *Router) handleProxyPass(w http.ResponseWriter, req *http.Request, target string, cfg map[string]interface{}, params map[string]string) { + // Substitute params in target URL (e.g., {version} -> actual version) + for key, value := range params { + target = strings.ReplaceAll(target, "{"+key+"}", value) + } + + // Create proxy + proxyConfig := &proxy.Config{ + Target: target, + Headers: make(map[string]string), + } + + // Parse headers from config + if headers, ok := cfg["headers"].([]interface{}); ok { + for _, h := range headers { + if header, ok := h.(string); ok { + parts := strings.SplitN(header, ": ", 2) + if len(parts) == 2 { + // Substitute params in header values + headerValue := parts[1] + for key, value := range params { + headerValue = strings.ReplaceAll(headerValue, "{"+key+"}", value) + } + proxyConfig.Headers[parts[0]] = headerValue + } + } + } + } + + p, err := proxy.New(proxyConfig, r.logger) + if err != nil { + if r.logger != nil { + r.logger.Error("Failed to create proxy", "target", target, "error", err) + } + http.Error(w, "Bad Gateway", http.StatusBadGateway) + return + } + + p.ProxyRequest(w, req, params) +} + // CreateRouterFromConfig creates a router from extension config func CreateRouterFromConfig(cfg map[string]interface{}) *Router { router := NewRouter() diff --git a/go/internal/server/server.go b/go/internal/server/server.go index aaf64e7..5bc4317 100644 --- a/go/internal/server/server.go +++ b/go/internal/server/server.go @@ -11,19 +11,19 @@ import ( "time" "github.com/konduktor/konduktor/internal/config" + "github.com/konduktor/konduktor/internal/extension" "github.com/konduktor/konduktor/internal/logging" "github.com/konduktor/konduktor/internal/middleware" - "github.com/konduktor/konduktor/internal/routing" ) -const Version = "0.1.0" +const Version = "0.2.0" // Server represents the Konduktor HTTP server type Server struct { - config *config.Config - httpServer *http.Server - router *routing.Router - logger *logging.Logger + config *config.Config + httpServer *http.Server + extensionManager *extension.Manager + logger *logging.Logger } // New creates a new server instance @@ -37,12 +37,31 @@ func New(cfg *config.Config) (*Server, error) { return nil, fmt.Errorf("failed to create logger: %w", err) } - router := routing.New(cfg, logger) + // Create extension manager + extManager := extension.NewManager(logger) + + // Load extensions from config + for _, extCfg := range cfg.Extensions { + // Add static_dir to routing config if not present + if extCfg.Type == "routing" { + if extCfg.Config == nil { + extCfg.Config = make(map[string]interface{}) + } + if _, ok := extCfg.Config["static_dir"]; !ok { + extCfg.Config["static_dir"] = cfg.HTTP.StaticDir + } + } + + if err := extManager.LoadExtension(extCfg.Type, extCfg.Config); err != nil { + logger.Error("Failed to load extension", "type", extCfg.Type, "error", err) + // Continue loading other extensions + } + } srv := &Server{ - config: cfg, - router: router, - logger: logger, + config: cfg, + extensionManager: extManager, + logger: logger, } return srv, nil @@ -86,9 +105,15 @@ func (s *Server) Run() error { // buildHandler builds the HTTP handler chain func (s *Server) buildHandler() http.Handler { - var handler http.Handler = s.router + // Create base handler that returns 404 + baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + }) - // Add middleware + // Wrap with extension manager + var handler http.Handler = s.extensionManager.Handler(baseHandler) + + // Add middleware (applied in reverse order) handler = middleware.AccessLog(handler, s.logger) handler = middleware.ServerHeader(handler, Version) handler = middleware.Recovery(handler, s.logger) @@ -115,6 +140,9 @@ func (s *Server) waitForShutdown(errChan <-chan error) error { s.logger.Info("Shutting down server...") + // Cleanup extensions + s.extensionManager.Cleanup() + if err := s.httpServer.Shutdown(ctx); err != nil { s.logger.Error("Error during shutdown", "error", err) return err diff --git a/go/tests/integration/README.md b/go/tests/integration/README.md new file mode 100644 index 0000000..861806a --- /dev/null +++ b/go/tests/integration/README.md @@ -0,0 +1,127 @@ +# Integration Tests + +Интеграционные тесты для Konduktor — полноценное тестирование сервера с реальными HTTP запросами. + +## Отличие от unit-тестов + +| Аспект | Unit-тесты | Интеграционные тесты | +|--------|------------|---------------------| +| Scope | Отдельный модуль в изоляции | Весь сервер целиком | +| Backend | Mock (httptest.Server) | Реальные HTTP серверы | +| Config | Программный | YAML конфигурация | +| Extensions | Не тестируются | Полная цепочка обработки | + +## Структура тестов + +``` +tests/integration/ +├── README.md # Эта документация +├── helpers_test.go # Общие хелперы и утилиты +├── reverse_proxy_test.go # Тесты reverse proxy +├── routing_test.go # Тесты маршрутизации (TODO) +├── security_test.go # Тесты security extension (TODO) +├── caching_test.go # Тесты caching extension (TODO) +└── static_files_test.go # Тесты статических файлов (TODO) +``` + +## Что тестируют интеграционные тесты + +### 1. Reverse Proxy (`reverse_proxy_test.go`) + +- [ ] Базовое проксирование GET/POST/PUT/DELETE +- [ ] Exact match routes (`=/api/version`) +- [ ] Regex routes с параметрами (`~^/api/resource/(?P\d+)$`) +- [ ] Подстановка параметров в target URL (`{id}`, `{tag}`) +- [ ] Подстановка переменных в заголовки (`$remote_addr`) +- [ ] Передача заголовков X-Forwarded-For, X-Real-IP +- [ ] Сохранение query string +- [ ] Обработка ошибок backend (502, 504) +- [ ] Таймауты соединения + +### 2. Routing Extension (`routing_test.go`) + +- [ ] Приоритет маршрутов (exact > regex > default) +- [ ] Case-sensitive regex (`~`) +- [ ] Case-insensitive regex (`~*`) +- [ ] Default route (`__default__`) +- [ ] Return directive (`return 200 "OK"`) +- [ ] Конфликт маршрутов + +### 3. Security Extension (`security_test.go`) + +- [ ] IP whitelist +- [ ] IP blacklist +- [ ] CIDR нотация (10.0.0.0/8) +- [ ] Security headers (X-Frame-Options, X-Content-Type-Options) +- [ ] Rate limiting +- [ ] Комбинация с другими extensions + +### 4. Caching Extension (`caching_test.go`) + +- [ ] Cache hit/miss +- [ ] TTL expiration +- [ ] Pattern-based caching +- [ ] Cache-Control headers +- [ ] Cache invalidation +- [ ] Max cache size и eviction + +### 5. Static Files (`static_files_test.go`) + +- [ ] Serving статических файлов +- [ ] Index file (index.html) +- [ ] MIME types +- [ ] Cache-Control для static +- [ ] SPA fallback +- [ ] Directory traversal protection +- [ ] 404 для несуществующих файлов + +### 6. Extension Chain (`extension_chain_test.go`) + +- [ ] Порядок выполнения extensions (security → caching → routing) +- [ ] Прерывание цепочки при ошибке +- [ ] Совместная работа extensions + +## Запуск тестов + +```bash +# Все интеграционные тесты +go test ./tests/integration/... -v + +# Конкретный файл +go test ./tests/integration/... -v -run TestReverseProxy + +# С таймаутом (интеграционные тесты медленнее) +go test ./tests/integration/... -v -timeout 60s + +# С покрытием +go test ./tests/integration/... -v -coverprofile=coverage.out +``` + +## Требования + +- Свободные порты: тесты используют случайные порты (`:0`) +- Сетевой доступ: для localhost соединений +- Время: интеграционные тесты занимают больше времени (~5-10 сек) + +## Добавление новых тестов + +1. Создайте файл `*_test.go` в `tests/integration/` +2. Используйте хелперы из `helpers_test.go`: + - `startTestServer()` — запуск Konduktor сервера + - `startBackend()` — запуск mock backend + - `makeRequest()` — отправка HTTP запроса +3. Добавьте описание в этот README + +## CI/CD + +Интеграционные тесты запускаются отдельно от unit-тестов: + +```yaml +# .github/workflows/test.yml +jobs: + unit-tests: + run: go test ./internal/... + + integration-tests: + run: go test ./tests/integration/... -timeout 120s +``` diff --git a/go/tests/integration/helpers_test.go b/go/tests/integration/helpers_test.go new file mode 100644 index 0000000..4779c6e --- /dev/null +++ b/go/tests/integration/helpers_test.go @@ -0,0 +1,408 @@ +package integration + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/konduktor/konduktor/internal/extension" + "github.com/konduktor/konduktor/internal/logging" + "github.com/konduktor/konduktor/internal/middleware" +) + +// TestServer represents a running Konduktor server for testing +type TestServer struct { + Server *http.Server + URL string + Port int + listener net.Listener + handler http.Handler + t *testing.T +} + +// TestBackend represents a mock backend server +type TestBackend struct { + server *httptest.Server + requestLog []RequestLogEntry + mu sync.Mutex + requestCount int64 + handler http.HandlerFunc +} + +// RequestLogEntry stores information about a received request +type RequestLogEntry struct { + Method string + Path string + Query string + Headers http.Header + Body string + Timestamp time.Time +} + +// ServerConfig holds configuration for starting a test server +type ServerConfig struct { + Extensions []extension.Extension + StaticDir string + Middleware []func(http.Handler) http.Handler +} + +// ============== Test Server ============== + +// StartTestServer creates and starts a Konduktor server for testing +func StartTestServer(t *testing.T, cfg *ServerConfig) *TestServer { + t.Helper() + + logger, err := logging.New(logging.Config{ + Level: "DEBUG", + }) + if err != nil { + t.Fatalf("Failed to create logger: %v", err) + } + + // Create extension manager + extManager := extension.NewManager(logger) + + // Add extensions if provided + if cfg != nil && len(cfg.Extensions) > 0 { + for _, ext := range cfg.Extensions { + if err := extManager.AddExtension(ext); err != nil { + t.Fatalf("Failed to add extension: %v", err) + } + } + } + + // Create a fallback handler for when no extension handles the request + fallback := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + }) + + // Create handler chain + var handler http.Handler = extManager.Handler(fallback) + + // Add middleware + handler = middleware.AccessLog(handler, logger) + handler = middleware.Recovery(handler, logger) + + // Add custom middleware if provided + if cfg != nil { + for _, mw := range cfg.Middleware { + handler = mw(handler) + } + } + + // Find available port + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to find available port: %v", err) + } + + port := listener.Addr().(*net.TCPAddr).Port + + server := &http.Server{ + Handler: handler, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + } + + ts := &TestServer{ + Server: server, + URL: fmt.Sprintf("http://127.0.0.1:%d", port), + Port: port, + listener: listener, + handler: handler, + t: t, + } + + // Start server in goroutine + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + // Don't fail test here as server might be intentionally closed + } + }() + + // Wait for server to be ready + ts.waitReady() + + return ts +} + +// waitReady waits for the server to be ready to accept connections +func (ts *TestServer) waitReady() { + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + conn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", ts.Port), 100*time.Millisecond) + if err == nil { + conn.Close() + return + } + time.Sleep(10 * time.Millisecond) + } + ts.t.Fatal("Server failed to start within timeout") +} + +// Close shuts down the test server +func (ts *TestServer) Close() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ts.Server.Shutdown(ctx) +} + +// ============== Test Backend ============== + +// StartBackend creates and starts a mock backend server +func StartBackend(handler http.HandlerFunc) *TestBackend { + tb := &TestBackend{ + requestLog: make([]RequestLogEntry, 0), + handler: handler, + } + + if handler == nil { + handler = tb.defaultHandler + } + + tb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tb.logRequest(r) + handler(w, r) + })) + + return tb +} + +func (tb *TestBackend) logRequest(r *http.Request) { + tb.mu.Lock() + defer tb.mu.Unlock() + + body, _ := io.ReadAll(r.Body) + r.Body = io.NopCloser(bytes.NewReader(body)) + + tb.requestLog = append(tb.requestLog, RequestLogEntry{ + Method: r.Method, + Path: r.URL.Path, + Query: r.URL.RawQuery, + Headers: r.Header.Clone(), + Body: string(body), + Timestamp: time.Now(), + }) + atomic.AddInt64(&tb.requestCount, 1) +} + +func (tb *TestBackend) defaultHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "backend": "default", + "path": r.URL.Path, + "method": r.Method, + "query": r.URL.RawQuery, + "received": time.Now().Unix(), + }) +} + +// URL returns the backend server URL +func (tb *TestBackend) URL() string { + return tb.server.URL +} + +// Close shuts down the backend server +func (tb *TestBackend) Close() { + tb.server.Close() +} + +// RequestCount returns the number of requests received +func (tb *TestBackend) RequestCount() int64 { + return atomic.LoadInt64(&tb.requestCount) +} + +// LastRequest returns the most recent request +func (tb *TestBackend) LastRequest() *RequestLogEntry { + tb.mu.Lock() + defer tb.mu.Unlock() + if len(tb.requestLog) == 0 { + return nil + } + return &tb.requestLog[len(tb.requestLog)-1] +} + +// AllRequests returns all logged requests +func (tb *TestBackend) AllRequests() []RequestLogEntry { + tb.mu.Lock() + defer tb.mu.Unlock() + result := make([]RequestLogEntry, len(tb.requestLog)) + copy(result, tb.requestLog) + return result +} + +// ============== HTTP Client Helpers ============== + +// HTTPClient is a configured HTTP client for testing +type HTTPClient struct { + client *http.Client + baseURL string +} + +// NewHTTPClient creates a new test HTTP client +func NewHTTPClient(baseURL string) *HTTPClient { + return &HTTPClient{ + client: &http.Client{ + Timeout: 10 * time.Second, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse // Don't follow redirects + }, + }, + baseURL: baseURL, + } +} + +// Get performs a GET request +func (c *HTTPClient) Get(path string, headers map[string]string) (*http.Response, error) { + return c.Do("GET", path, nil, headers) +} + +// Post performs a POST request +func (c *HTTPClient) Post(path string, body []byte, headers map[string]string) (*http.Response, error) { + return c.Do("POST", path, body, headers) +} + +// Do performs an HTTP request +func (c *HTTPClient) Do(method, path string, body []byte, headers map[string]string) (*http.Response, error) { + var bodyReader io.Reader + if body != nil { + bodyReader = bytes.NewReader(body) + } + + req, err := http.NewRequest(method, c.baseURL+path, bodyReader) + if err != nil { + return nil, err + } + + for k, v := range headers { + req.Header.Set(k, v) + } + + return c.client.Do(req) +} + +// GetJSON performs GET and decodes JSON response +func (c *HTTPClient) GetJSON(path string, result interface{}) (*http.Response, error) { + resp, err := c.Get(path, nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if err := json.NewDecoder(resp.Body).Decode(result); err != nil { + return resp, err + } + + return resp, nil +} + +// ============== File System Helpers ============== + +// CreateTempDir creates a temporary directory for static files +func CreateTempDir(t *testing.T) string { + t.Helper() + dir, err := os.MkdirTemp("", "konduktor-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + t.Cleanup(func() { os.RemoveAll(dir) }) + return dir +} + +// CreateTempFile creates a temporary file with given content +func CreateTempFile(t *testing.T, dir, name, content string) string { + t.Helper() + path := filepath.Join(dir, name) + + // Create parent directories if needed + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + t.Fatalf("Failed to create directories: %v", err) + } + + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatalf("Failed to write file: %v", err) + } + return path +} + +// ============== Assertion Helpers ============== + +// AssertStatus checks if response has expected status code +func AssertStatus(t *testing.T, resp *http.Response, expected int) { + t.Helper() + if resp.StatusCode != expected { + body, _ := io.ReadAll(resp.Body) + t.Errorf("Expected status %d, got %d. Body: %s", expected, resp.StatusCode, string(body)) + } +} + +// AssertHeader checks if response has expected header value +func AssertHeader(t *testing.T, resp *http.Response, header, expected string) { + t.Helper() + actual := resp.Header.Get(header) + if actual != expected { + t.Errorf("Expected header %s=%q, got %q", header, expected, actual) + } +} + +// AssertHeaderContains checks if header contains substring +func AssertHeaderContains(t *testing.T, resp *http.Response, header, substring string) { + t.Helper() + actual := resp.Header.Get(header) + if actual == "" || !contains(actual, substring) { + t.Errorf("Expected header %s to contain %q, got %q", header, substring, actual) + } +} + +// AssertJSONField checks if JSON response has expected field value +func AssertJSONField(t *testing.T, body []byte, field string, expected interface{}) { + t.Helper() + var data map[string]interface{} + if err := json.Unmarshal(body, &data); err != nil { + t.Fatalf("Failed to parse JSON: %v", err) + } + + actual, ok := data[field] + if !ok { + t.Errorf("Field %q not found in JSON", field) + return + } + + if actual != expected { + t.Errorf("Expected %s=%v, got %v", field, expected, actual) + } +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsAt(s, substr, 0)) +} + +func containsAt(s, substr string, start int) bool { + for i := start; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +// ReadBody reads and returns response body +func ReadBody(t *testing.T, resp *http.Response) []byte { + t.Helper() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read body: %v", err) + } + return body +} diff --git a/go/tests/integration/reverse_proxy_test.go b/go/tests/integration/reverse_proxy_test.go new file mode 100644 index 0000000..c96c573 --- /dev/null +++ b/go/tests/integration/reverse_proxy_test.go @@ -0,0 +1,562 @@ +package integration + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/konduktor/konduktor/internal/extension" + "github.com/konduktor/konduktor/internal/logging" +) + +// createTestLogger creates a logger for tests +func createTestLogger(t *testing.T) *logging.Logger { + t.Helper() + logger, err := logging.New(logging.Config{Level: "DEBUG"}) + if err != nil { + t.Fatalf("Failed to create logger: %v", err) + } + return logger +} + +// ============== Basic Reverse Proxy Tests ============== + +func TestReverseProxy_BasicGET(t *testing.T) { + // Start backend server + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "message": "Hello from backend", + "path": r.URL.Path, + "method": r.Method, + }) + }) + defer backend.Close() + + // Create routing extension with proxy to backend + logger := createTestLogger(t) + routingExt, err := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + if err != nil { + t.Fatalf("Failed to create routing extension: %v", err) + } + + // Start Konduktor server + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + // Make request through Konduktor + client := NewHTTPClient(server.URL) + resp, err := client.Get("/api/test", nil) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + // Verify response + AssertStatus(t, resp, http.StatusOK) + + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + if result["message"] != "Hello from backend" { + t.Errorf("Unexpected message: %v", result["message"]) + } + + if result["path"] != "/api/test" { + t.Errorf("Expected path /api/test, got %v", result["path"]) + } + + // Verify backend received request + if backend.RequestCount() != 1 { + t.Errorf("Expected 1 backend request, got %d", backend.RequestCount()) + } +} + +func TestReverseProxy_POST(t *testing.T) { + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + var body map[string]interface{} + json.NewDecoder(r.Body).Decode(&body) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "received": body, + "method": r.Method, + }) + }) + defer backend.Close() + + logger := createTestLogger(t) + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + body := []byte(`{"name":"test","value":123}`) + resp, err := client.Post("/api/data", body, map[string]string{ + "Content-Type": "application/json", + }) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + AssertStatus(t, resp, http.StatusOK) + + var result map[string]interface{} + json.NewDecoder(resp.Body).Decode(&result) + + if result["method"] != "POST" { + t.Errorf("Expected method POST, got %v", result["method"]) + } + + received := result["received"].(map[string]interface{}) + if received["name"] != "test" { + t.Errorf("Expected name 'test', got %v", received["name"]) + } +} + +// ============== Exact Match Routes ============== + +func TestReverseProxy_ExactMatchRoute(t *testing.T) { + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "endpoint": "version", + "path": r.URL.Path, + }) + }) + defer backend.Close() + + logger := createTestLogger(t) + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + // Exact match - should use backend URL as-is + "=/api/version": map[string]interface{}{ + "proxy_pass": backend.URL() + "/releases/latest", + }, + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + // Test exact match route + resp, err := client.Get("/api/version", nil) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + AssertStatus(t, resp, http.StatusOK) + + lastReq := backend.LastRequest() + if lastReq == nil { + t.Fatal("No request received by backend") + } + + // For exact match, the target path should be used as-is (IgnoreRequestPath=true) + if lastReq.Path != "/releases/latest" { + t.Errorf("Expected backend path /releases/latest, got %s", lastReq.Path) + } +} + +// ============== Regex Routes with Parameters ============== + +func TestReverseProxy_RegexRouteWithParams(t *testing.T) { + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "path": r.URL.Path, + }) + }) + defer backend.Close() + + logger := createTestLogger(t) + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + // Regex with named group + "~^/api/users/(?P\\d+)$": map[string]interface{}{ + "proxy_pass": backend.URL() + "/v2/users/{id}", + }, + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + // Test regex route with parameter + resp, err := client.Get("/api/users/42", nil) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + AssertStatus(t, resp, http.StatusOK) + + lastReq := backend.LastRequest() + if lastReq == nil { + t.Fatal("No request received by backend") + } + + // Parameter {id} should be substituted + if lastReq.Path != "/v2/users/42" { + t.Errorf("Expected backend path /v2/users/42, got %s", lastReq.Path) + } +} + +// ============== Header Forwarding ============== + +func TestReverseProxy_HeaderForwarding(t *testing.T) { + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "x-forwarded-for": r.Header.Get("X-Forwarded-For"), + "x-real-ip": r.Header.Get("X-Real-IP"), + "x-custom": r.Header.Get("X-Custom"), + "x-forwarded-host": r.Header.Get("X-Forwarded-Host"), + }) + }) + defer backend.Close() + + logger := createTestLogger(t) + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + "headers": []interface{}{ + "X-Forwarded-For: $remote_addr", + "X-Real-IP: $remote_addr", + }, + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + resp, err := client.Get("/test", map[string]string{ + "X-Custom": "custom-value", + }) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + var result map[string]string + json.NewDecoder(resp.Body).Decode(&result) + + // X-Custom should be forwarded + if result["x-custom"] != "custom-value" { + t.Errorf("Expected X-Custom header to be forwarded, got %v", result["x-custom"]) + } + + // X-Forwarded-For should be set (will contain 127.0.0.1) + if result["x-forwarded-for"] == "" { + t.Error("Expected X-Forwarded-For header to be set") + } +} + +// ============== Query String ============== + +func TestReverseProxy_QueryStringPreservation(t *testing.T) { + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "query": r.URL.RawQuery, + "foo": r.URL.Query().Get("foo"), + "bar": r.URL.Query().Get("bar"), + }) + }) + defer backend.Close() + + logger := createTestLogger(t) + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + resp, err := client.Get("/search?foo=hello&bar=world", nil) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + var result map[string]string + json.NewDecoder(resp.Body).Decode(&result) + + if result["foo"] != "hello" { + t.Errorf("Expected foo=hello, got %v", result["foo"]) + } + + if result["bar"] != "world" { + t.Errorf("Expected bar=world, got %v", result["bar"]) + } +} + +// ============== Error Handling ============== + +func TestReverseProxy_BackendUnavailable(t *testing.T) { + logger := createTestLogger(t) + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + // Non-existent backend + "proxy_pass": "http://127.0.0.1:59999", + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + resp, err := client.Get("/test", nil) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + // Should return 502 Bad Gateway + AssertStatus(t, resp, http.StatusBadGateway) +} + +func TestReverseProxy_BackendTimeout(t *testing.T) { + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + // Simulate slow backend + time.Sleep(3 * time.Second) + w.Write([]byte("OK")) + }) + defer backend.Close() + + logger := createTestLogger(t) + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + "timeout": 0.5, // 500ms timeout + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + resp, err := client.Get("/slow", nil) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + // Should return 504 Gateway Timeout + AssertStatus(t, resp, http.StatusGatewayTimeout) +} + +// ============== HTTP Methods ============== + +func TestReverseProxy_AllMethods(t *testing.T) { + methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"} + + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "method": r.Method, + }) + }) + defer backend.Close() + + logger := createTestLogger(t) + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + for _, method := range methods { + t.Run(method, func(t *testing.T) { + resp, err := client.Do(method, "/resource", nil, nil) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + AssertStatus(t, resp, http.StatusOK) + + if method != "HEAD" { + var result map[string]string + json.NewDecoder(resp.Body).Decode(&result) + + if result["method"] != method { + t.Errorf("Expected method %s, got %v", method, result["method"]) + } + } + }) + } +} + +// ============== Large Bodies ============== + +func TestReverseProxy_LargeRequestBody(t *testing.T) { + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + json.NewEncoder(w).Encode(map[string]int{ + "received": len(body), + }) + }) + defer backend.Close() + + logger := createTestLogger(t) + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + // 1MB body + largeBody := []byte(strings.Repeat("x", 1024*1024)) + resp, err := client.Post("/upload", largeBody, map[string]string{ + "Content-Type": "application/octet-stream", + }) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + AssertStatus(t, resp, http.StatusOK) +} + +// ============== Concurrent Requests ============== + +func TestReverseProxy_ConcurrentRequests(t *testing.T) { + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + // Small delay to simulate work + time.Sleep(10 * time.Millisecond) + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + }) + defer backend.Close() + + logger := createTestLogger(t) + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + const numRequests = 50 + results := make(chan error, numRequests) + + for i := 0; i < numRequests; i++ { + go func(n int) { + client := NewHTTPClient(server.URL) + resp, err := client.Get(fmt.Sprintf("/concurrent/%d", n), nil) + if err != nil { + results <- err + return + } + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + results <- fmt.Errorf("unexpected status: %d", resp.StatusCode) + return + } + results <- nil + }(i) + } + + // Collect results + var errors []error + for i := 0; i < numRequests; i++ { + if err := <-results; err != nil { + errors = append(errors, err) + } + } + + if len(errors) > 0 { + t.Errorf("Got %d errors in concurrent requests: %v", len(errors), errors[:min(5, len(errors))]) + } + + // Verify all requests reached backend + if backend.RequestCount() != numRequests { + t.Errorf("Expected %d backend requests, got %d", numRequests, backend.RequestCount()) + } +} + +func min(a, b int) int { + if a < b { + return a + } + return b +}