package extension import ( "context" "net" "net/http" "strings" "sync" "time" "github.com/konduktor/konduktor/internal/logging" ) // SecurityExtension provides security features like IP filtering and security headers type SecurityExtension struct { BaseExtension allowedIPs map[string]bool blockedIPs map[string]bool allowedCIDRs []*net.IPNet blockedCIDRs []*net.IPNet securityHeaders map[string]string // Rate limiting rateLimitEnabled bool rateLimitRequests int rateLimitWindow time.Duration rateLimitByIP map[string]*rateLimitEntry rateLimitMu sync.RWMutex } type rateLimitEntry struct { count int resetTime time.Time } // SecurityConfig holds security extension configuration type SecurityConfig struct { AllowedIPs []string `yaml:"allowed_ips"` BlockedIPs []string `yaml:"blocked_ips"` SecurityHeaders map[string]string `yaml:"security_headers"` RateLimit *RateLimitConfig `yaml:"rate_limit"` } // RateLimitConfig holds rate limiting configuration type RateLimitConfig struct { Enabled bool `yaml:"enabled"` Requests int `yaml:"requests"` Window string `yaml:"window"` // e.g., "1m", "1h" } // Default security headers var defaultSecurityHeaders = map[string]string{ "X-Content-Type-Options": "nosniff", "X-Frame-Options": "DENY", "X-XSS-Protection": "1; mode=block", "Referrer-Policy": "strict-origin-when-cross-origin", } // NewSecurityExtension creates a new security extension func NewSecurityExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) { ext := &SecurityExtension{ BaseExtension: NewBaseExtension("security", 10, logger), // High priority (early execution) allowedIPs: make(map[string]bool), blockedIPs: make(map[string]bool), allowedCIDRs: make([]*net.IPNet, 0), blockedCIDRs: make([]*net.IPNet, 0), securityHeaders: make(map[string]string), rateLimitByIP: make(map[string]*rateLimitEntry), } // Copy default security headers for k, v := range defaultSecurityHeaders { ext.securityHeaders[k] = v } // Parse allowed_ips if allowedIPs, ok := config["allowed_ips"].([]interface{}); ok { for _, ip := range allowedIPs { if ipStr, ok := ip.(string); ok { ext.addAllowedIP(ipStr) } } } // Parse blocked_ips if blockedIPs, ok := config["blocked_ips"].([]interface{}); ok { for _, ip := range blockedIPs { if ipStr, ok := ip.(string); ok { ext.addBlockedIP(ipStr) } } } // Parse security_headers if headers, ok := config["security_headers"].(map[string]interface{}); ok { for k, v := range headers { if vStr, ok := v.(string); ok { ext.securityHeaders[k] = vStr } } } // Parse rate_limit if rateLimit, ok := config["rate_limit"].(map[string]interface{}); ok { if enabled, ok := rateLimit["enabled"].(bool); ok && enabled { ext.rateLimitEnabled = true if requests, ok := rateLimit["requests"].(int); ok { ext.rateLimitRequests = requests } else if requestsFloat, ok := rateLimit["requests"].(float64); ok { ext.rateLimitRequests = int(requestsFloat) } else { ext.rateLimitRequests = 100 // default } if window, ok := rateLimit["window"].(string); ok { if duration, err := time.ParseDuration(window); err == nil { ext.rateLimitWindow = duration } else { ext.rateLimitWindow = time.Minute // default } } else { ext.rateLimitWindow = time.Minute } logger.Info("Rate limiting enabled", "requests", ext.rateLimitRequests, "window", ext.rateLimitWindow.String()) } } return ext, nil } func (e *SecurityExtension) addAllowedIP(ip string) { if strings.Contains(ip, "/") { // CIDR notation _, cidr, err := net.ParseCIDR(ip) if err == nil { e.allowedCIDRs = append(e.allowedCIDRs, cidr) } } else { e.allowedIPs[ip] = true } } func (e *SecurityExtension) addBlockedIP(ip string) { if strings.Contains(ip, "/") { // CIDR notation _, cidr, err := net.ParseCIDR(ip) if err == nil { e.blockedCIDRs = append(e.blockedCIDRs, cidr) } } else { e.blockedIPs[ip] = true } } // ProcessRequest checks security rules func (e *SecurityExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) { clientIP := getClientIP(r) parsedIP := net.ParseIP(clientIP) // Check blocked IPs first if e.isBlocked(clientIP, parsedIP) { e.logger.Warn("Blocked request from IP", "ip", clientIP) http.Error(w, "403 Forbidden", http.StatusForbidden) return true, nil } // Check allowed IPs (if configured, only these IPs are allowed) if len(e.allowedIPs) > 0 || len(e.allowedCIDRs) > 0 { if !e.isAllowed(clientIP, parsedIP) { e.logger.Warn("Access denied for IP", "ip", clientIP) http.Error(w, "403 Forbidden", http.StatusForbidden) return true, nil } } // Check rate limit if e.rateLimitEnabled { if !e.checkRateLimit(clientIP) { e.logger.Warn("Rate limit exceeded", "ip", clientIP) w.Header().Set("Retry-After", "60") http.Error(w, "429 Too Many Requests", http.StatusTooManyRequests) return true, nil } } return false, nil } // ProcessResponse adds security headers to the response func (e *SecurityExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) { for header, value := range e.securityHeaders { w.Header().Set(header, value) } } func (e *SecurityExtension) isBlocked(ip string, parsedIP net.IP) bool { // Check exact match if e.blockedIPs[ip] { return true } // Check CIDR ranges if parsedIP != nil { for _, cidr := range e.blockedCIDRs { if cidr.Contains(parsedIP) { return true } } } return false } func (e *SecurityExtension) isAllowed(ip string, parsedIP net.IP) bool { // Check exact match if e.allowedIPs[ip] { return true } // Check CIDR ranges if parsedIP != nil { for _, cidr := range e.allowedCIDRs { if cidr.Contains(parsedIP) { return true } } } return false } func (e *SecurityExtension) checkRateLimit(ip string) bool { e.rateLimitMu.Lock() defer e.rateLimitMu.Unlock() now := time.Now() entry, exists := e.rateLimitByIP[ip] if !exists || now.After(entry.resetTime) { // Create new entry or reset expired one e.rateLimitByIP[ip] = &rateLimitEntry{ count: 1, resetTime: now.Add(e.rateLimitWindow), } return true } // Increment counter entry.count++ return entry.count <= e.rateLimitRequests } // AddBlockedIP adds an IP to the blocked list at runtime func (e *SecurityExtension) AddBlockedIP(ip string) { e.addBlockedIP(ip) } // RemoveBlockedIP removes an IP from the blocked list func (e *SecurityExtension) RemoveBlockedIP(ip string) { delete(e.blockedIPs, ip) } // AddAllowedIP adds an IP to the allowed list at runtime func (e *SecurityExtension) AddAllowedIP(ip string) { e.addAllowedIP(ip) } // RemoveAllowedIP removes an IP from the allowed list func (e *SecurityExtension) RemoveAllowedIP(ip string) { delete(e.allowedIPs, ip) } // SetSecurityHeader sets or updates a security header func (e *SecurityExtension) SetSecurityHeader(name, value string) { e.securityHeaders[name] = value } // GetMetrics returns security metrics func (e *SecurityExtension) GetMetrics() map[string]interface{} { e.rateLimitMu.RLock() activeRateLimits := len(e.rateLimitByIP) e.rateLimitMu.RUnlock() return map[string]interface{}{ "allowed_ips": len(e.allowedIPs), "allowed_cidrs": len(e.allowedCIDRs), "blocked_ips": len(e.blockedIPs), "blocked_cidrs": len(e.blockedCIDRs), "security_headers": len(e.securityHeaders), "rate_limit_enabled": e.rateLimitEnabled, "active_rate_limits": activeRateLimits, } } // Cleanup cleans up rate limit entries periodically func (e *SecurityExtension) Cleanup() error { e.rateLimitMu.Lock() defer e.rateLimitMu.Unlock() now := time.Now() for ip, entry := range e.rateLimitByIP { if now.After(entry.resetTime) { delete(e.rateLimitByIP, ip) } } return nil }