Илья Глазунов 881028c1e6 feat: Add reverse proxy functionality with enhanced routing capabilities
- Introduced IgnoreRequestPath option in proxy configuration to allow exact match routing.
- Implemented proxy_pass directive in routing extension to handle backend requests.
- Enhanced error handling for backend unavailability and timeouts.
- Added integration tests for reverse proxy, including basic requests, exact match routes, regex routes, header forwarding, and query string preservation.
- Created helper functions for setting up test servers and backends, along with assertion utilities for response validation.
- Updated server initialization to support extension management and middleware chaining.
- Improved logging for debugging purposes during request handling.
2025-12-12 00:38:30 +03:00

429 lines
10 KiB
Go

package extension
import (
	"context"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/konduktor/konduktor/internal/logging"
	"github.com/konduktor/konduktor/internal/proxy"
)
// RoutingExtension handles request routing based on patterns.
//
// ProcessRequest consults exact-match routes first, then regex routes (the
// first match in slice order wins), then the optional default route.
type RoutingExtension struct {
	BaseExtension
	// exactRoutes maps a literal request path (config key "=<path>") to its route.
	exactRoutes map[string]RouteConfig
	// regexRoutes holds "~" (case-sensitive) and "~*" (case-insensitive)
	// routes, tried in slice order.
	regexRoutes []*regexRoute
	// defaultRoute is the "__default__" route, or nil when none is configured.
	defaultRoute *RouteConfig
	// staticDir is the directory SPA fallback uses when a route has no Root;
	// NewRoutingExtension initialises it to "./static".
	staticDir string
}
// RouteConfig holds configuration for a route, as produced by parseRouteConfig.
// handleRoute checks Return, then ProxyPass, then Root, then SPAFallback.
type RouteConfig struct {
	// ProxyPass is the backend URL for proxied routes; "{name}" placeholders
	// are replaced with named regex captures by handleProxy.
	ProxyPass string
	// Root is the directory static files are served from.
	Root string
	// IndexFile is the file served for directory-style paths and by SPA
	// fallback; parseRouteConfig defaults it to "index.html".
	IndexFile string
	// Return is a fixed response of the form "<status> <body>".
	Return string
	// ContentType is the Content-Type of Return responses ("text/plain" when empty).
	ContentType string
	// Headers are "Name: value" entries: copied into the proxy config for
	// ProxyPass routes and set on the response for Root routes.
	Headers []string
	// CacheControl, when non-empty, is sent as Cache-Control for static files.
	CacheControl string
	// SPAFallback enables serving IndexFile for paths that did not resolve.
	SPAFallback bool
	// ExcludePatterns are path prefixes for which SPA fallback declines.
	ExcludePatterns []string
	// Timeout is the proxy timeout in seconds; values <= 0 leave
	// proxy.Config.Timeout unset.
	Timeout float64
}
// regexRoute is a compiled "~" or "~*" location pattern together with its route.
type regexRoute struct {
	// pattern is the compiled expression; "(?i)" is prepended for "~*" routes.
	pattern *regexp.Regexp
	config  RouteConfig
	// caseSensitive is true for "~" routes and false for "~*" routes.
	caseSensitive bool
	// originalExpr is the expression without the "~"/"~*" modifier or "(?i)".
	originalExpr string
}
// NewRoutingExtension creates a routing extension from its configuration map.
//
// Recognised keys:
//   - "regex_locations": map of location pattern -> route settings (pattern
//     syntax is handled by addRoute, settings by parseRouteConfig);
//   - "static_dir": directory used by SPA fallback when a route has no root
//     (default "./static").
//
// The returned error is always nil: invalid patterns and malformed route
// entries are logged and skipped.
func NewRoutingExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) {
	ext := &RoutingExtension{
		BaseExtension: NewBaseExtension("routing", 50, logger), // Middle priority
		exactRoutes:   make(map[string]RouteConfig),
		regexRoutes:   make([]*regexRoute, 0),
		staticDir:     "./static",
	}

	logger.Debug("Routing extension config", "config", config)

	if locations, ok := config["regex_locations"].(map[string]interface{}); ok {
		logger.Debug("Found regex_locations", "count", len(locations))

		// Go randomises map iteration order, so registering regex routes
		// straight from the map would give them a different priority on every
		// start-up, and overlapping patterns would win at random. The original
		// configuration order is already lost once it is a map, so sort the
		// patterns to make matching at least deterministic.
		patterns := make([]string, 0, len(locations))
		for pattern := range locations {
			patterns = append(patterns, pattern)
		}
		sort.Strings(patterns)

		for _, pattern := range patterns {
			logger.Debug("Adding route", "pattern", pattern)
			rc, ok := locations[pattern].(map[string]interface{})
			if !ok {
				logger.Warn("Ignoring route with non-map configuration", "pattern", pattern)
				continue
			}
			ext.addRoute(pattern, parseRouteConfig(rc))
		}
	} else {
		logger.Warn("No regex_locations found in config", "config_keys", getKeys(config))
	}

	if staticDir, ok := config["static_dir"].(string); ok {
		ext.staticDir = staticDir
	}

	return ext, nil
}
// parseRouteConfig converts one raw route entry from the configuration map
// into a RouteConfig. Unknown keys and values of an unexpected type are
// ignored, and IndexFile defaults to "index.html".
//
// Values are accepted in the shapes common decoders produce. "timeout" may be
// any built-in integer or float type: YAML decoders yield int for `30`, while
// encoding/json yields float64. "headers" and "exclude_patterns" may be
// []interface{} (decoded config) or []string (config built in Go code).
func parseRouteConfig(cfg map[string]interface{}) RouteConfig {
	rc := RouteConfig{IndexFile: "index.html"}

	stringFields := map[string]*string{
		"proxy_pass":    &rc.ProxyPass,
		"root":          &rc.Root,
		"index_file":    &rc.IndexFile,
		"return":        &rc.Return,
		"content_type":  &rc.ContentType,
		"cache_control": &rc.CacheControl,
	}
	for key, dst := range stringFields {
		if v, ok := cfg[key].(string); ok {
			*dst = v
		}
	}

	if v, ok := cfg["spa_fallback"].(bool); ok {
		rc.SPAFallback = v
	}
	if v, ok := routeConfigNumber(cfg["timeout"]); ok {
		rc.Timeout = v
	}

	rc.Headers = routeConfigStrings(cfg["headers"])
	rc.ExcludePatterns = routeConfigStrings(cfg["exclude_patterns"])

	return rc
}

// routeConfigNumber converts v to float64 when it holds a built-in integer or
// float type, reporting false for anything else (including nil).
func routeConfigNumber(v interface{}) (float64, bool) {
	switch n := v.(type) {
	case float64:
		return n, true
	case float32:
		return float64(n), true
	case int:
		return float64(n), true
	case int64:
		return float64(n), true
	case int32:
		return float64(n), true
	case uint:
		return float64(n), true
	case uint64:
		return float64(n), true
	case uint32:
		return float64(n), true
	}
	return 0, false
}

// routeConfigStrings returns the string elements of a []string or
// []interface{} value, skipping non-string elements; it returns nil when there
// are none or v has another type.
func routeConfigStrings(v interface{}) []string {
	var out []string
	switch list := v.(type) {
	case []string:
		out = append(out, list...)
	case []interface{}:
		for _, item := range list {
			if s, ok := item.(string); ok {
				out = append(out, s)
			}
		}
	}
	return out
}
// addRoute registers a single location pattern:
//   - "__default__" is the fallback route used when nothing else matches;
//   - "=<path>" is an exact match on the request path;
//   - "~*<regex>" is a case-insensitive regular expression;
//   - "~<regex>" is a case-sensitive regular expression.
//
// Spaces or tabs between the modifier and the path/expression are ignored, so
// nginx-style keys such as "= /health" and "~ ^/api/" work as well. Patterns
// with an invalid regex or an unrecognised form are logged and skipped.
func (e *RoutingExtension) addRoute(pattern string, config RouteConfig) {
	switch {
	case pattern == "__default__":
		e.defaultRoute = &config
	case strings.HasPrefix(pattern, "="):
		path := strings.TrimLeft(strings.TrimPrefix(pattern, "="), " \t")
		e.exactRoutes[path] = config
	// "~*" must be tested before "~", which is its prefix.
	case strings.HasPrefix(pattern, "~*"):
		e.addRegexRoute(pattern, strings.TrimPrefix(pattern, "~*"), false, config)
	case strings.HasPrefix(pattern, "~"):
		e.addRegexRoute(pattern, strings.TrimPrefix(pattern, "~"), true, config)
	default:
		e.logger.Warn("Unsupported route pattern, ignoring", "pattern", pattern)
	}
}

// addRegexRoute compiles expr, prefixed with "(?i)" when caseSensitive is
// false, and appends it to e.regexRoutes. On a compile error it logs the
// original pattern and skips the route.
func (e *RoutingExtension) addRegexRoute(pattern, expr string, caseSensitive bool, config RouteConfig) {
	expr = strings.TrimLeft(expr, " \t")

	source := expr
	if !caseSensitive {
		source = "(?i)" + expr
	}

	re, err := regexp.Compile(source)
	if err != nil {
		e.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err)
		return
	}

	e.regexRoutes = append(e.regexRoutes, &regexRoute{
		pattern:       re,
		config:        config,
		caseSensitive: caseSensitive,
		originalExpr:  expr,
	})
}
// ProcessRequest routes r to the first matching route and reports whether the
// request was handled.
//
// Lookup order is exact routes, then regex routes in registration order, then
// the default route. Named capture groups of a matching regex are handed to
// the route handler as params. ctx is currently unused.
func (e *RoutingExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
	requestPath := r.URL.Path

	// Exact routes tell the proxy to ignore the request path.
	if exact, found := e.exactRoutes[requestPath]; found {
		return e.handleRoute(w, r, exact, nil, true)
	}

	for _, rr := range e.regexRoutes {
		groups := rr.pattern.FindStringSubmatch(requestPath)
		if groups == nil {
			continue
		}

		captured := make(map[string]string)
		for idx, groupName := range rr.pattern.SubexpNames() {
			// Index 0 is the whole match; unnamed groups have an empty name.
			if idx == 0 || groupName == "" || idx >= len(groups) {
				continue
			}
			captured[groupName] = groups[idx]
		}
		return e.handleRoute(w, r, rr.config, captured, false)
	}

	if e.defaultRoute == nil {
		return false, nil
	}
	return e.handleRoute(w, r, *e.defaultRoute, nil, false)
}
// handleRoute runs the directive configured on a matched route and reports
// whether the request was handled. Directives are checked in priority order:
// "return", "proxy_pass", static "root", then SPA fallback.
//
// params holds the named regex captures (nil for exact and default routes).
// exactMatch is passed through to the proxy handler.
//
// A route that sets both Root and SPAFallback behaves like nginx's
// `try_files $uri /index.html`. The file under Root is served when it exists;
// otherwise the SPA index is served. Without this, a missing file would be
// reported as unhandled even though the route asked for a fallback.
func (e *RoutingExtension) handleRoute(w http.ResponseWriter, r *http.Request, config RouteConfig, params map[string]string, exactMatch bool) (bool, error) {
	if config.Return != "" {
		return e.handleReturn(w, config)
	}

	if config.ProxyPass != "" {
		return e.handleProxy(w, r, config, params, exactMatch)
	}

	if config.Root != "" {
		handled, err := e.handleStatic(w, r, config)
		// handleStatic declines with (false, nil) before writing anything when
		// the file is missing, so trying the SPA index afterwards is safe.
		if handled || err != nil || !config.SPAFallback {
			return handled, err
		}
	}

	if config.SPAFallback {
		return e.handleSPAFallback(w, r, config)
	}

	return false, nil
}
// handleReturn writes the fixed response described by config.Return, which has
// the form "<status> [text]", e.g. "404 Not here" or "301 https://example.com/".
//
//   - The first word is the status code when it is a number in 200-599.
//     Otherwise the whole directive is the body of a 200 response.
//   - Without text, the body is the standard status text ("OK", "Not Found", ...).
//   - For redirect codes (301, 302, 303, 307, 308) the text is also sent as the
//     Location header, so clients actually follow the redirect.
//
// Content-Type is config.ContentType, or "text/plain" when that is empty. The
// request is always reported as handled.
func (e *RoutingExtension) handleReturn(w http.ResponseWriter, config RouteConfig) (bool, error) {
	statusCode := http.StatusOK
	body := config.Return

	// Split "<status> <text>" on the first space; a bare directive has no text.
	head, tail := config.Return, ""
	if i := strings.IndexByte(config.Return, ' '); i >= 0 {
		head, tail = config.Return[:i], config.Return[i+1:]
	}
	if code, err := strconv.Atoi(head); err == nil && code >= 200 && code <= 599 {
		statusCode = code
		body = tail
	}
	if body == "" {
		body = http.StatusText(statusCode)
	}

	contentType := config.ContentType
	if contentType == "" {
		contentType = "text/plain"
	}
	w.Header().Set("Content-Type", contentType)

	switch statusCode {
	case http.StatusMovedPermanently, http.StatusFound, http.StatusSeeOther,
		http.StatusTemporaryRedirect, http.StatusPermanentRedirect:
		if tail != "" {
			w.Header().Set("Location", tail)
		}
	}

	w.WriteHeader(statusCode)
	// A failed write means the client has gone away; there is nobody to report it to.
	_, _ = w.Write([]byte(body))
	return true, nil
}
// handleProxy forwards r to the backend in config.ProxyPass.
//
// "{name}" placeholders in the target URL and in header values are replaced
// with the named regex captures in params. The proxy is told to ignore the
// request path (IgnoreRequestPath) when the route was an exact match, or when
// the configured target contains a "{...}" placeholder; in both cases the
// target path is already fully specified.
//
// Header entries have the form "Name: value" (the space after the colon is
// optional). Any "$remote_addr" in a value is replaced with getClientIP(r).
// A positive config.Timeout, in seconds, becomes the proxy timeout.
//
// The request is always reported as handled. If the proxy cannot be created,
// a 502 Bad Gateway is written.
func (e *RoutingExtension) handleProxy(w http.ResponseWriter, r *http.Request, config RouteConfig, params map[string]string, exactMatch bool) (bool, error) {
	// A single strings.Replacer substitutes every placeholder in one pass. A
	// captured value that itself contains "{other}" is therefore not expanded
	// again, whatever the (random) map iteration order.
	var subst *strings.Replacer
	if len(params) > 0 {
		pairs := make([]string, 0, 2*len(params))
		for key, value := range params {
			pairs = append(pairs, "{"+key+"}", value)
		}
		subst = strings.NewReplacer(pairs...)
	}
	expand := func(s string) string {
		if subst == nil {
			return s
		}
		return subst.Replace(s)
	}

	// Decided on the configured target, before any substitution.
	hasParams := strings.Contains(config.ProxyPass, "{") && strings.Contains(config.ProxyPass, "}")
	target := expand(config.ProxyPass)

	proxyConfig := &proxy.Config{
		Target:            target,
		Headers:           make(map[string]string, len(config.Headers)),
		IgnoreRequestPath: exactMatch || hasParams,
	}
	if config.Timeout > 0 {
		proxyConfig.Timeout = time.Duration(config.Timeout * float64(time.Second))
	}

	clientIP := getClientIP(r)
	for _, header := range config.Headers {
		name, value, ok := strings.Cut(header, ":")
		name = strings.TrimSpace(name)
		if !ok || name == "" {
			continue
		}
		value = expand(strings.TrimSpace(value))
		proxyConfig.Headers[name] = strings.ReplaceAll(value, "$remote_addr", clientIP)
	}

	// NOTE(review): a new proxy is built for every request. If proxy.New sets
	// up its own transport, caching one per route would allow connection
	// reuse — confirm against the proxy package.
	p, err := proxy.New(proxyConfig, e.logger)
	if err != nil {
		e.logger.Error("Failed to create proxy", "target", target, "error", err)
		http.Error(w, "Bad Gateway", http.StatusBadGateway)
		return true, nil
	}

	p.ProxyRequest(w, r, params)
	return true, nil
}
// handleStatic serves a file from config.Root for the request path.
//
// A path ending in "/" is served as <path><IndexFile>, so "/" maps to
// "/index.html" and "/docs/" to "/docs/index.html". IndexFile falls back to
// "index.html" when empty, so a directory is never handed to http.ServeFile
// (which would render a listing).
//
// The path is cleaned and confined to Root; anything escaping it gets
// 403 Forbidden. A missing file is reported as unhandled (false) with nothing
// written, so the caller can let another handler try. Otherwise Cache-Control
// and the route's "Name: value" headers are set and the file is served.
func (e *RoutingExtension) handleStatic(w http.ResponseWriter, r *http.Request, config RouteConfig) (bool, error) {
	indexFile := config.IndexFile
	if indexFile == "" {
		indexFile = "index.html"
	}

	reqPath := r.URL.Path
	if reqPath == "" || strings.HasSuffix(reqPath, "/") {
		// Directory-style request: serve that directory's own index file.
		reqPath += indexFile
	}

	absRoot, err := filepath.Abs(config.Root)
	if err != nil {
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return true, nil
	}

	// Cleaning "/"+reqPath resolves any ".." segments against a virtual root,
	// so the joined path should never leave absRoot. The prefix check is
	// defence in depth.
	filePath := filepath.Join(absRoot, filepath.Clean("/"+reqPath))
	if !strings.HasPrefix(filePath+string(filepath.Separator), absRoot+string(filepath.Separator)) {
		http.Error(w, "Forbidden", http.StatusForbidden)
		return true, nil
	}

	if _, err := os.Stat(filePath); os.IsNotExist(err) {
		return false, nil // Let other handlers try
	}

	if config.CacheControl != "" {
		w.Header().Set("Cache-Control", config.CacheControl)
	}
	for _, header := range config.Headers {
		name, value, ok := strings.Cut(header, ":")
		name = strings.TrimSpace(name)
		if ok && name != "" {
			w.Header().Set(name, strings.TrimSpace(value))
		}
	}

	http.ServeFile(w, r, filePath)
	return true, nil
}
// handleSPAFallback serves the single-page-app index file for the request.
//
// It declines (returns false without writing) when the request path starts
// with one of config.ExcludePatterns, or when the index file does not exist.
// The index is looked up in config.Root, falling back to e.staticDir, and the
// file name defaults to "index.html" when config.IndexFile is empty.
func (e *RoutingExtension) handleSPAFallback(w http.ResponseWriter, r *http.Request, config RouteConfig) (bool, error) {
	for _, prefix := range config.ExcludePatterns {
		if strings.HasPrefix(r.URL.Path, prefix) {
			return false, nil
		}
	}

	dir, index := e.staticDir, "index.html"
	if config.Root != "" {
		dir = config.Root
	}
	if config.IndexFile != "" {
		index = config.IndexFile
	}

	indexPath := filepath.Join(dir, index)
	_, statErr := os.Stat(indexPath)
	if os.IsNotExist(statErr) {
		return false, nil
	}

	http.ServeFile(w, r, indexPath)
	return true, nil
}
// getClientIP returns the client address used to expand "$remote_addr".
//
// Sources are tried in order:
//   - the first entry of X-Forwarded-For;
//   - X-Real-IP;
//   - the host part of r.RemoteAddr.
//
// Header values are whitespace-trimmed, and an empty value falls through to
// the next source. RemoteAddr is split with net.SplitHostPort, so "[::1]:8080"
// yields "::1"; an address without a port is returned unchanged.
//
// NOTE(review): both headers are client-controlled unless a trusted proxy in
// front of this server overwrites them, so the result can be spoofed. Do not
// use it for access control.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No port (or otherwise unsplittable): use the address as given.
		return r.RemoteAddr
	}
	return host
}
// GetMetrics reports the number of registered exact and regex routes and
// whether a default route is configured.
func (e *RoutingExtension) GetMetrics() map[string]interface{} {
	metrics := make(map[string]interface{}, 3)
	metrics["exact_routes"] = len(e.exactRoutes)
	metrics["regex_routes"] = len(e.regexRoutes)
	metrics["has_default"] = e.defaultRoute != nil
	return metrics
}
// getKeys returns the keys of m in ascending order. Go randomises map
// iteration order, so sorting keeps log lines that include the keys stable
// across runs. The result is non-nil even for an empty or nil map.
func getKeys(m map[string]interface{}) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}