forked from aegis/pyserveX
- Introduced IgnoreRequestPath option in proxy configuration to allow exact match routing. - Implemented proxy_pass directive in routing extension to handle backend requests. - Enhanced error handling for backend unavailability and timeouts. - Added integration tests for reverse proxy, including basic requests, exact match routes, regex routes, header forwarding, and query string preservation. - Created helper functions for setting up test servers and backends, along with assertion utilities for response validation. - Updated server initialization to support extension management and middleware chaining. - Improved logging for debugging purposes during request handling.
313 lines
7.9 KiB
Go
313 lines
7.9 KiB
Go
package extension
|
|
|
|
import (
	"context"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/konduktor/konduktor/internal/logging"
)
|
|
|
|
// SecurityExtension provides security features like IP filtering and security headers
|
|
type SecurityExtension struct {
|
|
BaseExtension
|
|
allowedIPs map[string]bool
|
|
blockedIPs map[string]bool
|
|
allowedCIDRs []*net.IPNet
|
|
blockedCIDRs []*net.IPNet
|
|
securityHeaders map[string]string
|
|
|
|
// Rate limiting
|
|
rateLimitEnabled bool
|
|
rateLimitRequests int
|
|
rateLimitWindow time.Duration
|
|
rateLimitByIP map[string]*rateLimitEntry
|
|
rateLimitMu sync.RWMutex
|
|
}
|
|
|
|
// rateLimitEntry is the rate limiter's bookkeeping for a single client IP.
type rateLimitEntry struct {
	// count is the number of requests recorded in the current window.
	count int
	// resetTime is the instant after which the current window has expired.
	resetTime time.Time
}
|
|
|
|
// SecurityConfig holds security extension configuration
|
|
type SecurityConfig struct {
|
|
AllowedIPs []string `yaml:"allowed_ips"`
|
|
BlockedIPs []string `yaml:"blocked_ips"`
|
|
SecurityHeaders map[string]string `yaml:"security_headers"`
|
|
RateLimit *RateLimitConfig `yaml:"rate_limit"`
|
|
}
|
|
|
|
// RateLimitConfig is the YAML shape of the "rate_limit" settings section.
type RateLimitConfig struct {
	// Enabled switches rate limiting on.
	Enabled bool `yaml:"enabled"`
	// Requests is the request budget per window.
	Requests int `yaml:"requests"`
	// Window is the window length as a duration string, e.g. "1m" or "1h".
	Window string `yaml:"window"`
}
|
|
|
|
// defaultSecurityHeaders is the baseline header set. NewSecurityExtension
// copies it into each extension before applying any "security_headers"
// overrides from the configuration.
var defaultSecurityHeaders = map[string]string{
	"Referrer-Policy":        "strict-origin-when-cross-origin",
	"X-Content-Type-Options": "nosniff",
	"X-Frame-Options":        "DENY",
	"X-XSS-Protection":       "1; mode=block",
}
|
|
|
|
// NewSecurityExtension creates a new security extension
|
|
func NewSecurityExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) {
|
|
ext := &SecurityExtension{
|
|
BaseExtension: NewBaseExtension("security", 10, logger), // High priority (early execution)
|
|
allowedIPs: make(map[string]bool),
|
|
blockedIPs: make(map[string]bool),
|
|
allowedCIDRs: make([]*net.IPNet, 0),
|
|
blockedCIDRs: make([]*net.IPNet, 0),
|
|
securityHeaders: make(map[string]string),
|
|
rateLimitByIP: make(map[string]*rateLimitEntry),
|
|
}
|
|
|
|
// Copy default security headers
|
|
for k, v := range defaultSecurityHeaders {
|
|
ext.securityHeaders[k] = v
|
|
}
|
|
|
|
// Parse allowed_ips
|
|
if allowedIPs, ok := config["allowed_ips"].([]interface{}); ok {
|
|
for _, ip := range allowedIPs {
|
|
if ipStr, ok := ip.(string); ok {
|
|
ext.addAllowedIP(ipStr)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Parse blocked_ips
|
|
if blockedIPs, ok := config["blocked_ips"].([]interface{}); ok {
|
|
for _, ip := range blockedIPs {
|
|
if ipStr, ok := ip.(string); ok {
|
|
ext.addBlockedIP(ipStr)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Parse security_headers
|
|
if headers, ok := config["security_headers"].(map[string]interface{}); ok {
|
|
for k, v := range headers {
|
|
if vStr, ok := v.(string); ok {
|
|
ext.securityHeaders[k] = vStr
|
|
}
|
|
}
|
|
}
|
|
|
|
// Parse rate_limit
|
|
if rateLimit, ok := config["rate_limit"].(map[string]interface{}); ok {
|
|
if enabled, ok := rateLimit["enabled"].(bool); ok && enabled {
|
|
ext.rateLimitEnabled = true
|
|
|
|
if requests, ok := rateLimit["requests"].(int); ok {
|
|
ext.rateLimitRequests = requests
|
|
} else if requestsFloat, ok := rateLimit["requests"].(float64); ok {
|
|
ext.rateLimitRequests = int(requestsFloat)
|
|
} else {
|
|
ext.rateLimitRequests = 100 // default
|
|
}
|
|
|
|
if window, ok := rateLimit["window"].(string); ok {
|
|
if duration, err := time.ParseDuration(window); err == nil {
|
|
ext.rateLimitWindow = duration
|
|
} else {
|
|
ext.rateLimitWindow = time.Minute // default
|
|
}
|
|
} else {
|
|
ext.rateLimitWindow = time.Minute
|
|
}
|
|
|
|
logger.Info("Rate limiting enabled",
|
|
"requests", ext.rateLimitRequests,
|
|
"window", ext.rateLimitWindow.String())
|
|
}
|
|
}
|
|
|
|
return ext, nil
|
|
}
|
|
|
|
func (e *SecurityExtension) addAllowedIP(ip string) {
|
|
if strings.Contains(ip, "/") {
|
|
// CIDR notation
|
|
_, cidr, err := net.ParseCIDR(ip)
|
|
if err == nil {
|
|
e.allowedCIDRs = append(e.allowedCIDRs, cidr)
|
|
}
|
|
} else {
|
|
e.allowedIPs[ip] = true
|
|
}
|
|
}
|
|
|
|
func (e *SecurityExtension) addBlockedIP(ip string) {
|
|
if strings.Contains(ip, "/") {
|
|
// CIDR notation
|
|
_, cidr, err := net.ParseCIDR(ip)
|
|
if err == nil {
|
|
e.blockedCIDRs = append(e.blockedCIDRs, cidr)
|
|
}
|
|
} else {
|
|
e.blockedIPs[ip] = true
|
|
}
|
|
}
|
|
|
|
// ProcessRequest checks security rules
|
|
func (e *SecurityExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
|
|
clientIP := getClientIP(r)
|
|
parsedIP := net.ParseIP(clientIP)
|
|
|
|
// Check blocked IPs first
|
|
if e.isBlocked(clientIP, parsedIP) {
|
|
e.logger.Warn("Blocked request from IP", "ip", clientIP)
|
|
http.Error(w, "403 Forbidden", http.StatusForbidden)
|
|
return true, nil
|
|
}
|
|
|
|
// Check allowed IPs (if configured, only these IPs are allowed)
|
|
if len(e.allowedIPs) > 0 || len(e.allowedCIDRs) > 0 {
|
|
if !e.isAllowed(clientIP, parsedIP) {
|
|
e.logger.Warn("Access denied for IP", "ip", clientIP)
|
|
http.Error(w, "403 Forbidden", http.StatusForbidden)
|
|
return true, nil
|
|
}
|
|
}
|
|
|
|
// Check rate limit
|
|
if e.rateLimitEnabled {
|
|
if !e.checkRateLimit(clientIP) {
|
|
e.logger.Warn("Rate limit exceeded", "ip", clientIP)
|
|
w.Header().Set("Retry-After", "60")
|
|
http.Error(w, "429 Too Many Requests", http.StatusTooManyRequests)
|
|
return true, nil
|
|
}
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
// ProcessResponse adds security headers to the response
|
|
func (e *SecurityExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
|
for header, value := range e.securityHeaders {
|
|
w.Header().Set(header, value)
|
|
}
|
|
}
|
|
|
|
func (e *SecurityExtension) isBlocked(ip string, parsedIP net.IP) bool {
|
|
// Check exact match
|
|
if e.blockedIPs[ip] {
|
|
return true
|
|
}
|
|
|
|
// Check CIDR ranges
|
|
if parsedIP != nil {
|
|
for _, cidr := range e.blockedCIDRs {
|
|
if cidr.Contains(parsedIP) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (e *SecurityExtension) isAllowed(ip string, parsedIP net.IP) bool {
|
|
// Check exact match
|
|
if e.allowedIPs[ip] {
|
|
return true
|
|
}
|
|
|
|
// Check CIDR ranges
|
|
if parsedIP != nil {
|
|
for _, cidr := range e.allowedCIDRs {
|
|
if cidr.Contains(parsedIP) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (e *SecurityExtension) checkRateLimit(ip string) bool {
|
|
e.rateLimitMu.Lock()
|
|
defer e.rateLimitMu.Unlock()
|
|
|
|
now := time.Now()
|
|
entry, exists := e.rateLimitByIP[ip]
|
|
|
|
if !exists || now.After(entry.resetTime) {
|
|
// Create new entry or reset expired one
|
|
e.rateLimitByIP[ip] = &rateLimitEntry{
|
|
count: 1,
|
|
resetTime: now.Add(e.rateLimitWindow),
|
|
}
|
|
return true
|
|
}
|
|
|
|
// Increment counter
|
|
entry.count++
|
|
return entry.count <= e.rateLimitRequests
|
|
}
|
|
|
|
// AddBlockedIP adds an IP to the blocked list at runtime
|
|
func (e *SecurityExtension) AddBlockedIP(ip string) {
|
|
e.addBlockedIP(ip)
|
|
}
|
|
|
|
// RemoveBlockedIP removes an IP from the blocked list
|
|
func (e *SecurityExtension) RemoveBlockedIP(ip string) {
|
|
delete(e.blockedIPs, ip)
|
|
}
|
|
|
|
// AddAllowedIP adds an IP to the allowed list at runtime
|
|
func (e *SecurityExtension) AddAllowedIP(ip string) {
|
|
e.addAllowedIP(ip)
|
|
}
|
|
|
|
// RemoveAllowedIP removes an IP from the allowed list
|
|
func (e *SecurityExtension) RemoveAllowedIP(ip string) {
|
|
delete(e.allowedIPs, ip)
|
|
}
|
|
|
|
// SetSecurityHeader sets or updates a security header
|
|
func (e *SecurityExtension) SetSecurityHeader(name, value string) {
|
|
e.securityHeaders[name] = value
|
|
}
|
|
|
|
// GetMetrics returns security metrics
|
|
func (e *SecurityExtension) GetMetrics() map[string]interface{} {
|
|
e.rateLimitMu.RLock()
|
|
activeRateLimits := len(e.rateLimitByIP)
|
|
e.rateLimitMu.RUnlock()
|
|
|
|
return map[string]interface{}{
|
|
"allowed_ips": len(e.allowedIPs),
|
|
"allowed_cidrs": len(e.allowedCIDRs),
|
|
"blocked_ips": len(e.blockedIPs),
|
|
"blocked_cidrs": len(e.blockedCIDRs),
|
|
"security_headers": len(e.securityHeaders),
|
|
"rate_limit_enabled": e.rateLimitEnabled,
|
|
"active_rate_limits": activeRateLimits,
|
|
}
|
|
}
|
|
|
|
// Cleanup cleans up rate limit entries periodically
|
|
func (e *SecurityExtension) Cleanup() error {
|
|
e.rateLimitMu.Lock()
|
|
defer e.rateLimitMu.Unlock()
|
|
|
|
now := time.Now()
|
|
for ip, entry := range e.rateLimitByIP {
|
|
if now.After(entry.resetTime) {
|
|
delete(e.rateLimitByIP, ip)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|