Илья Глазунов 881028c1e6 feat: Add reverse proxy functionality with enhanced routing capabilities
- Introduced IgnoreRequestPath option in proxy configuration to allow exact match routing.
- Implemented proxy_pass directive in routing extension to handle backend requests.
- Enhanced error handling for backend unavailability and timeouts.
- Added integration tests for reverse proxy, including basic requests, exact match routes, regex routes, header forwarding, and query string preservation.
- Created helper functions for setting up test servers and backends, along with assertion utilities for response validation.
- Updated server initialization to support extension management and middleware chaining.
- Improved logging for debugging purposes during request handling.
2025-12-12 00:38:30 +03:00

313 lines
7.9 KiB
Go

package extension
import (
"context"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/konduktor/konduktor/internal/logging"
)
// SecurityExtension screens incoming requests by client IP (block list,
// optional allow list, per-IP rate limit) and writes a configurable set of
// security headers onto responses.
//
// NOTE(review): rateLimitMu is only taken around rateLimitByIP. The IP lists
// and securityHeaders are mutated by the runtime Add*/Remove*/Set* methods
// without any lock; if those run concurrently with ProcessRequest or
// ProcessResponse that is a data race. Confirm how callers use them.
type SecurityExtension struct {
	BaseExtension

	// IP filtering. Exact entries are compared as strings with the client IP;
	// CIDR entries are matched against the parsed client address.
	allowedIPs, blockedIPs     map[string]bool
	allowedCIDRs, blockedCIDRs []*net.IPNet

	// Header name -> value written by ProcessResponse.
	securityHeaders map[string]string

	// Per-client fixed-window rate limiting; rateLimitMu guards rateLimitByIP.
	rateLimitEnabled  bool
	rateLimitRequests int
	rateLimitWindow   time.Duration
	rateLimitByIP     map[string]*rateLimitEntry
	rateLimitMu       sync.RWMutex
}
// rateLimitEntry is the counter for one client's current rate-limit window.
type rateLimitEntry struct {
	count     int       // requests seen in the current window
	resetTime time.Time // instant at which the current window ends
}
// SecurityConfig is a typed (YAML-tagged) mirror of the keys that
// NewSecurityExtension reads: allowed_ips and blocked_ips (exact IPs or CIDR
// ranges), security_headers (header name -> value) and the optional
// rate_limit section.
type SecurityConfig struct {
	AllowedIPs      []string          `yaml:"allowed_ips"`
	BlockedIPs      []string          `yaml:"blocked_ips"`
	SecurityHeaders map[string]string `yaml:"security_headers"`
	RateLimit       *RateLimitConfig  `yaml:"rate_limit"`
}
// RateLimitConfig is the typed form of the "rate_limit" section: whether
// limiting is on, how many requests each client may make per window, and the
// window length as a Go duration string such as "1m" or "1h".
type RateLimitConfig struct {
	Enabled  bool   `yaml:"enabled"`
	Requests int    `yaml:"requests"`
	Window   string `yaml:"window"`
}
// defaultSecurityHeaders is the baseline header set copied into each new
// SecurityExtension; entries can be overridden via the "security_headers"
// config key or SetSecurityHeader.
//
// NOTE(review): current MDN/OWASP guidance is to send "X-XSS-Protection: 0"
// (or omit the header), since the legacy browser XSS auditor could itself be
// abused. Consider revisiting this value.
var defaultSecurityHeaders = map[string]string{
	"X-Content-Type-Options": "nosniff",
	"X-Frame-Options":        "DENY",
	"X-XSS-Protection":       "1; mode=block",
	"Referrer-Policy":        "strict-origin-when-cross-origin",
}
// NewSecurityExtension builds the "security" extension from a loosely typed
// configuration map (for example the result of decoding YAML or JSON into
// interface{} values).
//
// Recognized keys:
//   - "allowed_ips": IPs or CIDR ranges; once any are registered, only
//     matching clients are admitted.
//   - "blocked_ips": IPs or CIDR ranges that are always rejected.
//   - "security_headers": header name -> value pairs layered over
//     defaultSecurityHeaders. A configured name replaces any existing entry
//     that names the same header after canonicalization, so the emitted value
//     is deterministic.
//   - "rate_limit": see configureRateLimit.
//
// IP lists may be []interface{} or []string, and header maps may be
// map[string]interface{} or map[string]string. Values of other types are
// ignored. The returned error is currently always nil.
func NewSecurityExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) {
	ext := &SecurityExtension{
		// Priority 10 is high: the security checks run early.
		BaseExtension:   NewBaseExtension("security", 10, logger),
		allowedIPs:      make(map[string]bool),
		blockedIPs:      make(map[string]bool),
		allowedCIDRs:    make([]*net.IPNet, 0),
		blockedCIDRs:    make([]*net.IPNet, 0),
		securityHeaders: make(map[string]string),
		rateLimitByIP:   make(map[string]*rateLimitEntry),
	}

	for name, value := range defaultSecurityHeaders {
		ext.securityHeaders[name] = value
	}

	for _, ip := range securityConfigStrings(config["allowed_ips"]) {
		ext.addAllowedIP(ip)
	}
	for _, ip := range securityConfigStrings(config["blocked_ips"]) {
		ext.addBlockedIP(ip)
	}

	switch headers := config["security_headers"].(type) {
	case map[string]interface{}:
		for name, raw := range headers {
			if value, ok := raw.(string); ok {
				putSecurityHeader(ext.securityHeaders, name, value)
			}
		}
	case map[string]string:
		for name, value := range headers {
			putSecurityHeader(ext.securityHeaders, name, value)
		}
	}

	if rateLimit, ok := config["rate_limit"].(map[string]interface{}); ok {
		ext.configureRateLimit(rateLimit, logger)
	}

	return ext, nil
}

// securityConfigStrings returns the string elements of a config list value.
// It accepts []string as-is and []interface{} (skipping non-string
// elements); any other type, including nil, yields nil.
func securityConfigStrings(v interface{}) []string {
	switch list := v.(type) {
	case []string:
		return list
	case []interface{}:
		out := make([]string, 0, len(list))
		for _, item := range list {
			if s, ok := item.(string); ok {
				out = append(out, s)
			}
		}
		return out
	}
	return nil
}

// putSecurityHeader stores value under name after deleting any other key that
// http.CanonicalHeaderKey maps to the same header. Without this, two aliases
// (e.g. "X-XSS-Protection" and "x-xss-protection") would both survive, and a
// map-ranging writer would emit whichever it visited last.
func putSecurityHeader(headers map[string]string, name, value string) {
	canonical := http.CanonicalHeaderKey(name)
	for existing := range headers {
		if existing != name && http.CanonicalHeaderKey(existing) == canonical {
			delete(headers, existing)
		}
	}
	headers[name] = value
}

// configureRateLimit applies the "rate_limit" section. Nothing happens unless
// "enabled" is the boolean true. Then:
//   - "requests" (int, int64 or float64) is the per-window allowance. A
//     missing or differently typed value means 100; a non-positive value is
//     replaced by 100 with a warning.
//   - "window" is a time.ParseDuration string. A missing, unparsable or
//     non-positive value means one minute. A non-positive window would make
//     every entry expire immediately and silently disable limiting.
func (e *SecurityExtension) configureRateLimit(cfg map[string]interface{}, logger *logging.Logger) {
	if enabled, ok := cfg["enabled"].(bool); !ok || !enabled {
		return
	}

	const (
		defaultRequests = 100
		defaultWindow   = time.Minute
	)

	e.rateLimitEnabled = true

	e.rateLimitRequests = defaultRequests
	switch n := cfg["requests"].(type) {
	case int:
		e.rateLimitRequests = n
	case int64:
		e.rateLimitRequests = int(n)
	case float64:
		e.rateLimitRequests = int(n)
	}
	if e.rateLimitRequests <= 0 {
		logger.Warn("Ignoring non-positive rate_limit.requests, using default",
			"requests", e.rateLimitRequests, "default", defaultRequests)
		e.rateLimitRequests = defaultRequests
	}

	e.rateLimitWindow = defaultWindow
	if raw, ok := cfg["window"].(string); ok {
		d, err := time.ParseDuration(raw)
		switch {
		case err != nil:
			logger.Warn("Invalid rate_limit.window, using default",
				"window", raw, "error", err.Error(), "default", defaultWindow.String())
		case d <= 0:
			logger.Warn("Ignoring non-positive rate_limit.window, using default",
				"window", raw, "default", defaultWindow.String())
		default:
			e.rateLimitWindow = d
		}
	}

	logger.Info("Rate limiting enabled",
		"requests", e.rateLimitRequests,
		"window", e.rateLimitWindow.String())
}
// addAllowedIP registers ip on the allow list. An entry containing "/" is
// parsed as a CIDR range; anything else is stored verbatim and later compared
// as a string with the client IP.
//
// A CIDR that fails to parse is skipped, but now with a warning instead of
// silently. This matters because ProcessRequest only enforces the allow list
// when it has at least one entry: if every configured allow entry is invalid,
// allow-list filtering is effectively off.
func (e *SecurityExtension) addAllowedIP(ip string) {
	if !strings.Contains(ip, "/") {
		e.allowedIPs[ip] = true
		return
	}
	_, cidr, err := net.ParseCIDR(ip)
	if err != nil {
		e.logger.Warn("Ignoring invalid CIDR in allowed IPs", "cidr", ip, "error", err.Error())
		return
	}
	e.allowedCIDRs = append(e.allowedCIDRs, cidr)
}
// addBlockedIP registers ip on the block list. An entry containing "/" is
// parsed as a CIDR range; anything else is stored verbatim and later compared
// as a string with the client IP.
//
// A CIDR that fails to parse is skipped with a warning. Previously such entries
// vanished silently, so a typo quietly left the intended range unblocked.
func (e *SecurityExtension) addBlockedIP(ip string) {
	if !strings.Contains(ip, "/") {
		e.blockedIPs[ip] = true
		return
	}
	_, cidr, err := net.ParseCIDR(ip)
	if err != nil {
		e.logger.Warn("Ignoring invalid CIDR in blocked IPs", "cidr", ip, "error", err.Error())
		return
	}
	e.blockedCIDRs = append(e.blockedCIDRs, cidr)
}
// ProcessRequest enforces the IP and rate-limit rules for an incoming request.
//
// The checks run in this order, and the first one that fails answers the
// request itself via http.Error:
//  1. block list (exact IP or CIDR): 403 Forbidden.
//  2. allow list, only when it has any entries: 403 Forbidden.
//  3. per-IP rate limit, only when enabled: 429 Too Many Requests.
//
// It returns true after writing such a rejection (presumably telling the
// caller to stop processing; confirm with the extension manager) and false
// when the request may proceed. The error result is always nil and ctx is not
// used.
//
// NOTE(review): Retry-After is always "60", independent of the configured
// rate-limit window.
func (e *SecurityExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
	ip := getClientIP(r)
	parsed := net.ParseIP(ip)

	switch {
	case e.isBlocked(ip, parsed):
		e.logger.Warn("Blocked request from IP", "ip", ip)
		http.Error(w, "403 Forbidden", http.StatusForbidden)
	case (len(e.allowedIPs) > 0 || len(e.allowedCIDRs) > 0) && !e.isAllowed(ip, parsed):
		e.logger.Warn("Access denied for IP", "ip", ip)
		http.Error(w, "403 Forbidden", http.StatusForbidden)
	case e.rateLimitEnabled && !e.checkRateLimit(ip):
		e.logger.Warn("Rate limit exceeded", "ip", ip)
		w.Header().Set("Retry-After", "60")
		http.Error(w, "429 Too Many Requests", http.StatusTooManyRequests)
	default:
		return false, nil
	}
	return true, nil
}
// ProcessResponse copies every configured security header onto the response,
// replacing any value already set under the same (canonicalized) name. ctx
// and r are not used.
func (e *SecurityExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
	out := w.Header()
	for name, value := range e.securityHeaders {
		out.Set(name, value)
	}
}
// isBlocked reports whether the client is on the block list. That is true
// either for an exact string match on ip, or, when parsedIP is non-nil, if
// any blocked CIDR range contains it.
func (e *SecurityExtension) isBlocked(ip string, parsedIP net.IP) bool {
	if e.blockedIPs[ip] {
		return true
	}
	if parsedIP == nil {
		// Unparseable addresses can only match exact entries.
		return false
	}
	for _, network := range e.blockedCIDRs {
		if network.Contains(parsedIP) {
			return true
		}
	}
	return false
}
// isAllowed reports whether the client is on the allow list. That is true
// either for an exact string match on ip, or, when parsedIP is non-nil, if
// any allowed CIDR range contains it.
func (e *SecurityExtension) isAllowed(ip string, parsedIP net.IP) bool {
	if e.allowedIPs[ip] {
		return true
	}
	if parsedIP == nil {
		// Unparseable addresses can only match exact entries.
		return false
	}
	for _, network := range e.allowedCIDRs {
		if network.Contains(parsedIP) {
			return true
		}
	}
	return false
}
// checkRateLimit records one request from ip and reports whether it falls
// within the limit.
//
// Each IP gets a fixed window of rateLimitWindow that starts at its first
// request. That first request is always admitted; later ones are admitted
// while the running count stays <= rateLimitRequests. An expired entry is
// replaced lazily on the IP's next request. Runs under the rate-limit write
// lock.
func (e *SecurityExtension) checkRateLimit(ip string) bool {
	e.rateLimitMu.Lock()
	defer e.rateLimitMu.Unlock()

	now := time.Now()
	if entry, ok := e.rateLimitByIP[ip]; ok && !now.After(entry.resetTime) {
		// Still inside the current window: count this request.
		entry.count++
		return entry.count <= e.rateLimitRequests
	}

	// No window yet, or the previous one has ended: start a fresh one.
	e.rateLimitByIP[ip] = &rateLimitEntry{
		count:     1,
		resetTime: now.Add(e.rateLimitWindow),
	}
	return true
}
// AddBlockedIP puts an exact IP, or a CIDR range when ip contains "/", on the
// block list after construction.
//
// NOTE(review): the block list is not lock-protected. Calling this while
// ProcessRequest runs on another goroutine would be a data race; confirm that
// callers serialize it.
func (e *SecurityExtension) AddBlockedIP(ip string) {
	e.addBlockedIP(ip)
}
// RemoveBlockedIP takes an entry added by AddBlockedIP (or the blocked_ips
// config) off the block list.
//
// A plain address is deleted from the exact-match set. An entry containing
// "/" is parsed as a CIDR, and every blocked range with the same network in
// canonical form is removed (so "10.1.2.3/8" lifts "10.0.0.0/8"). CIDR entries
// are never stored in the exact-match set, so they have to be removed this
// way. An unparsable CIDR is a no-op.
func (e *SecurityExtension) RemoveBlockedIP(ip string) {
	if !strings.Contains(ip, "/") {
		delete(e.blockedIPs, ip)
		return
	}
	_, target, err := net.ParseCIDR(ip)
	if err != nil {
		return
	}
	want := target.String()

	// Filter into a fresh slice so removed ranges are not kept alive in the
	// tail of the old backing array.
	kept := make([]*net.IPNet, 0, len(e.blockedCIDRs))
	for _, cidr := range e.blockedCIDRs {
		if cidr.String() != want {
			kept = append(kept, cidr)
		}
	}
	e.blockedCIDRs = kept
}
// AddAllowedIP puts an exact IP, or a CIDR range when ip contains "/", on the
// allow list after construction. The first entry ever added also switches
// ProcessRequest into allow-list mode, in which unmatched clients get 403.
//
// NOTE(review): the allow list is not lock-protected. Calling this while
// ProcessRequest runs on another goroutine would be a data race; confirm that
// callers serialize it.
func (e *SecurityExtension) AddAllowedIP(ip string) {
	e.addAllowedIP(ip)
}
// RemoveAllowedIP takes an entry added by AddAllowedIP (or the allowed_ips
// config) off the allow list.
//
// A plain address is deleted from the exact-match set. An entry containing
// "/" is parsed as a CIDR, and every allowed range with the same network in
// canonical form is removed. CIDR entries are never stored in the exact-match
// set, so they have to be removed this way. An unparsable CIDR is a no-op.
//
// Removing the last remaining entry empties the allow list. ProcessRequest
// then stops enforcing it, so every non-blocked client is admitted again.
func (e *SecurityExtension) RemoveAllowedIP(ip string) {
	if !strings.Contains(ip, "/") {
		delete(e.allowedIPs, ip)
		return
	}
	_, target, err := net.ParseCIDR(ip)
	if err != nil {
		return
	}
	want := target.String()

	// Filter into a fresh slice so removed ranges are not kept alive in the
	// tail of the old backing array.
	kept := make([]*net.IPNet, 0, len(e.allowedCIDRs))
	for _, cidr := range e.allowedCIDRs {
		if cidr.String() != want {
			kept = append(kept, cidr)
		}
	}
	e.allowedCIDRs = kept
}
// SetSecurityHeader sets or replaces the value ProcessResponse writes for the
// named header.
//
// http.Header.Set canonicalizes names, so before storing, any existing entry
// that canonicalizes to the same header is removed. An example is the
// default "X-XSS-Protection" when name is "x-xss-protection". Otherwise both
// keys would survive, and ProcessResponse, which ranges over the map, would
// emit whichever value it happened to visit last.
func (e *SecurityExtension) SetSecurityHeader(name, value string) {
	canonical := http.CanonicalHeaderKey(name)
	for existing := range e.securityHeaders {
		if existing != name && http.CanonicalHeaderKey(existing) == canonical {
			delete(e.securityHeaders, existing)
		}
	}
	e.securityHeaders[name] = value
}
// GetMetrics returns a snapshot of counts for the configured IP rules and
// headers, whether rate limiting is on, and how many per-IP rate-limit
// entries are currently tracked. Only the rate-limit table is read under its
// lock.
func (e *SecurityExtension) GetMetrics() map[string]interface{} {
	metrics := map[string]interface{}{
		"allowed_ips":        len(e.allowedIPs),
		"allowed_cidrs":      len(e.allowedCIDRs),
		"blocked_ips":        len(e.blockedIPs),
		"blocked_cidrs":      len(e.blockedCIDRs),
		"security_headers":   len(e.securityHeaders),
		"rate_limit_enabled": e.rateLimitEnabled,
	}

	e.rateLimitMu.RLock()
	metrics["active_rate_limits"] = len(e.rateLimitByIP)
	e.rateLimitMu.RUnlock()

	return metrics
}
// Cleanup evicts rate-limit entries whose window has already ended, so the
// table does not keep entries for clients that have gone quiet. It holds the
// rate-limit write lock for the whole sweep and always returns nil. How often
// it is invoked is up to the caller and is not visible in this file.
func (e *SecurityExtension) Cleanup() error {
	e.rateLimitMu.Lock()
	defer e.rateLimitMu.Unlock()

	cutoff := time.Now()
	for ip, entry := range e.rateLimitByIP {
		if entry.resetTime.Before(cutoff) {
			// Deleting the current key during range is permitted in Go.
			delete(e.rateLimitByIP, ip)
		}
	}
	return nil
}