- Introduced an IgnoreRequestPath option in the proxy configuration to support exact-match routing.
- Implemented the proxy_pass directive in the routing extension to forward requests to backends.
- Enhanced error handling for backend unavailability and timeouts.
- Added integration tests for the reverse proxy, covering basic requests, exact-match routes, regex routes, header forwarding, and query-string preservation.
- Created helper functions for setting up test servers and backends, along with assertion utilities for response validation.
- Updated server initialization to support extension management and middleware chaining.
- Improved debug logging during request handling.
429 lines
10 KiB
Go
429 lines
10 KiB
Go
package extension
|
|
|
|
import (
	"context"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/konduktor/konduktor/internal/logging"
	"github.com/konduktor/konduktor/internal/proxy"
)
|
|
|
|
// RoutingExtension handles request routing based on patterns
|
|
type RoutingExtension struct {
|
|
BaseExtension
|
|
exactRoutes map[string]RouteConfig
|
|
regexRoutes []*regexRoute
|
|
defaultRoute *RouteConfig
|
|
staticDir string
|
|
}
|
|
|
|
// RouteConfig is the parsed form of one route entry from the
// "regex_locations" configuration map.
type RouteConfig struct {
	// ProxyPass is the backend URL for "proxy_pass"; "{name}" placeholders
	// are filled from named regex capture groups.
	ProxyPass string

	// Root is the directory static files are served from.
	Root string

	// IndexFile is the file served for directory requests
	// (parseRouteConfig defaults it to "index.html").
	IndexFile string

	// Return holds a "return" directive of the form "<status> [text]".
	Return string

	// ContentType overrides the Content-Type of "return" responses.
	ContentType string

	// Headers are "Name: value" strings. They are handed to the proxy config
	// for proxy_pass routes and set on the response for static routes.
	Headers []string

	// CacheControl, when non-empty, is sent as Cache-Control for static files.
	CacheControl string

	// SPAFallback enables serving the index file for requests this route
	// matched but could not otherwise handle.
	SPAFallback bool

	// ExcludePatterns are path prefixes the SPA fallback must decline.
	ExcludePatterns []string

	// Timeout is the proxy timeout in seconds; values <= 0 leave the proxy
	// config's timeout unset.
	Timeout float64
}
|
|
|
|
type regexRoute struct {
|
|
pattern *regexp.Regexp
|
|
config RouteConfig
|
|
caseSensitive bool
|
|
originalExpr string
|
|
}
|
|
|
|
// NewRoutingExtension builds a RoutingExtension from its configuration map.
//
// Recognised keys:
//   - "regex_locations": map of route pattern -> route settings (pattern
//     syntax is handled by addRoute, settings by parseRouteConfig);
//   - "static_dir": directory used by the SPA fallback when a route has no
//     root (default "./static").
//
// Patterns are registered in byte-wise sorted order. The configuration
// arrives as a Go map, whose iteration order is randomised, so without
// sorting the relative priority of overlapping regex routes would change
// from one start-up to the next. Entries whose settings are not a map are
// skipped with a warning.
//
// The returned error is currently always nil.
func NewRoutingExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) {
	ext := &RoutingExtension{
		BaseExtension: NewBaseExtension("routing", 50, logger), // Middle priority
		exactRoutes:   make(map[string]RouteConfig),
		regexRoutes:   make([]*regexRoute, 0),
		staticDir:     "./static",
	}

	logger.Debug("Routing extension config", "config", config)

	if locations, ok := config["regex_locations"].(map[string]interface{}); ok {
		logger.Debug("Found regex_locations", "count", len(locations))

		patterns := make([]string, 0, len(locations))
		for pattern := range locations {
			patterns = append(patterns, pattern)
		}
		sort.Strings(patterns)

		for _, pattern := range patterns {
			logger.Debug("Adding route", "pattern", pattern)
			rc, ok := locations[pattern].(map[string]interface{})
			if !ok {
				logger.Warn("Ignoring route with non-map config", "pattern", pattern)
				continue
			}
			ext.addRoute(pattern, parseRouteConfig(rc))
		}
	} else {
		logger.Warn("No regex_locations found in config", "config_keys", getKeys(config))
	}

	if staticDir, ok := config["static_dir"].(string); ok {
		ext.staticDir = staticDir
	}

	return ext, nil
}
|
|
|
|
func parseRouteConfig(cfg map[string]interface{}) RouteConfig {
|
|
rc := RouteConfig{
|
|
IndexFile: "index.html",
|
|
}
|
|
|
|
if v, ok := cfg["proxy_pass"].(string); ok {
|
|
rc.ProxyPass = v
|
|
}
|
|
if v, ok := cfg["root"].(string); ok {
|
|
rc.Root = v
|
|
}
|
|
if v, ok := cfg["index_file"].(string); ok {
|
|
rc.IndexFile = v
|
|
}
|
|
if v, ok := cfg["return"].(string); ok {
|
|
rc.Return = v
|
|
}
|
|
if v, ok := cfg["content_type"].(string); ok {
|
|
rc.ContentType = v
|
|
}
|
|
if v, ok := cfg["cache_control"].(string); ok {
|
|
rc.CacheControl = v
|
|
}
|
|
if v, ok := cfg["spa_fallback"].(bool); ok {
|
|
rc.SPAFallback = v
|
|
}
|
|
if v, ok := cfg["timeout"].(float64); ok {
|
|
rc.Timeout = v
|
|
}
|
|
|
|
// Parse headers
|
|
if headers, ok := cfg["headers"].([]interface{}); ok {
|
|
for _, h := range headers {
|
|
if header, ok := h.(string); ok {
|
|
rc.Headers = append(rc.Headers, header)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Parse exclude_patterns
|
|
if patterns, ok := cfg["exclude_patterns"].([]interface{}); ok {
|
|
for _, p := range patterns {
|
|
if pattern, ok := p.(string); ok {
|
|
rc.ExcludePatterns = append(rc.ExcludePatterns, pattern)
|
|
}
|
|
}
|
|
}
|
|
|
|
return rc
|
|
}
|
|
|
|
// addRoute registers config under one of the supported pattern forms:
//
//	"__default__"  fallback route used when nothing else matches
//	"=/path"       exact match on the request path
//	"~*regex"      case-insensitive regular expression
//	"~regex"       case-sensitive regular expression
//
// Whitespace between the marker and the path or expression is ignored, so
// the nginx-style "~ ^/api" is equivalent to "~^/api". Invalid regular
// expressions and patterns in any other form are logged and skipped.
func (e *RoutingExtension) addRoute(pattern string, config RouteConfig) {
	var (
		expr          string
		caseSensitive bool
	)

	switch {
	case pattern == "__default__":
		e.defaultRoute = &config
		return

	case strings.HasPrefix(pattern, "="):
		e.exactRoutes[strings.TrimSpace(strings.TrimPrefix(pattern, "="))] = config
		return

	case strings.HasPrefix(pattern, "~*"):
		// Must be tested before "~", which is a prefix of "~*".
		expr = strings.TrimSpace(strings.TrimPrefix(pattern, "~*"))

	case strings.HasPrefix(pattern, "~"):
		expr = strings.TrimSpace(strings.TrimPrefix(pattern, "~"))
		caseSensitive = true

	default:
		e.logger.Warn("Unsupported route pattern, ignoring", "pattern", pattern)
		return
	}

	source := expr
	if !caseSensitive {
		source = "(?i)" + expr
	}

	re, err := regexp.Compile(source)
	if err != nil {
		e.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err)
		return
	}

	e.regexRoutes = append(e.regexRoutes, &regexRoute{
		pattern:       re,
		config:        config,
		caseSensitive: caseSensitive,
		originalExpr:  expr,
	})
}
|
|
|
|
// ProcessRequest routes r to the first matching route and reports whether
// the request was handled.
//
// Matching order:
//  1. exact routes keyed by the full request path (handled with exactMatch=true);
//  2. regex routes in slice order; named capture groups are passed to the
//     handler as parameters;
//  3. the default route, if one is configured.
//
// The first route that matches decides the outcome, even when its handler
// declines the request. ctx is currently unused.
func (e *RoutingExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
	requestPath := r.URL.Path

	if exact, found := e.exactRoutes[requestPath]; found {
		return e.handleRoute(w, r, exact, nil, true)
	}

	for _, rr := range e.regexRoutes {
		groups := rr.pattern.FindStringSubmatch(requestPath)
		if groups == nil {
			continue
		}

		captured := make(map[string]string)
		for idx, groupName := range rr.pattern.SubexpNames() {
			if idx == 0 || groupName == "" || idx >= len(groups) {
				continue
			}
			captured[groupName] = groups[idx]
		}
		return e.handleRoute(w, r, rr.config, captured, false)
	}

	if e.defaultRoute == nil {
		return false, nil
	}
	return e.handleRoute(w, r, *e.defaultRoute, nil, false)
}
|
|
|
|
// handleRoute applies a matched route's configuration to the request.
//
// Directives are tried in priority order: "return", then "proxy_pass", then
// static files under "root". If the static handler does not find a file and
// the route enables spa_fallback, the SPA index is served instead. The SPA
// fallback also applies on its own to routes without a root.
//
// params holds named regex captures (nil for exact and default routes), and
// exactMatch is forwarded to the proxy handler. It returns false when no
// directive handled the request.
func (e *RoutingExtension) handleRoute(w http.ResponseWriter, r *http.Request, config RouteConfig, params map[string]string, exactMatch bool) (bool, error) {
	if config.Return != "" {
		return e.handleReturn(w, config)
	}

	if config.ProxyPass != "" {
		return e.handleProxy(w, r, config, params, exactMatch)
	}

	if config.Root != "" {
		handled, err := e.handleStatic(w, r, config)
		// handleStatic declines (false, nil) only for missing files. That is
		// exactly the case spa_fallback exists for, so fall through to it
		// instead of returning.
		if handled || err != nil || !config.SPAFallback {
			return handled, err
		}
	}

	if config.SPAFallback {
		return e.handleSPAFallback(w, r, config)
	}

	return false, nil
}
|
|
|
|
// handleReturn answers the request directly from a "return" directive.
//
// The directive has the form "<status> [text]", where status is an HTTP
// status code in the range 200-599:
//   - "<status> <text>" responds with that status and text as the body;
//   - "<status>" alone uses the standard status text as the body ("OK" for
//     200, "Not Found" for 404, empty for unregistered codes);
//   - a directive that does not start with a valid status code is sent
//     verbatim as the body of a 200 response.
//
// For redirect statuses (301, 302, 303, 307, 308) the text is also sent as
// the Location header, without which clients would not follow the redirect.
// Content-Type is "text/plain" unless content_type overrides it. The request
// is always reported as handled.
func (e *RoutingExtension) handleReturn(w http.ResponseWriter, config RouteConfig) (bool, error) {
	statusCode := http.StatusOK
	body := config.Return
	location := ""

	parts := strings.SplitN(config.Return, " ", 2)
	if code, err := strconv.Atoi(parts[0]); err == nil && code >= 200 && code <= 599 {
		statusCode = code
		if len(parts) == 2 {
			body = parts[1]
		} else {
			body = http.StatusText(code)
		}

		switch code {
		case http.StatusMovedPermanently, http.StatusFound, http.StatusSeeOther,
			http.StatusTemporaryRedirect, http.StatusPermanentRedirect:
			if len(parts) == 2 {
				location = strings.TrimSpace(parts[1])
			}
		}
	}

	contentType := "text/plain"
	if config.ContentType != "" {
		contentType = config.ContentType
	}

	w.Header().Set("Content-Type", contentType)
	if location != "" {
		w.Header().Set("Location", location)
	}
	w.WriteHeader(statusCode)
	// The status line is already committed, so a failed body write (e.g. the
	// client went away) cannot be reported to anyone; ignore it deliberately.
	_, _ = w.Write([]byte(body))
	return true, nil
}
|
|
|
|
// handleProxy forwards the request to the route's proxy_pass backend.
//
// Placeholders are expanded as follows:
//   - "{name}" in proxy_pass and in header values becomes the named regex
//     capture from params;
//   - "$remote_addr" in header values becomes getClientIP(r).
//
// Substitution runs in a single pass. A captured value that itself contains
// "{other}" or "$remote_addr" is therefore inserted literally, not expanded
// again; captures come from the client-supplied request path.
//
// IgnoreRequestPath is set for exact-match routes and for targets that
// contain placeholders, since such a target spells out the backend path. A
// positive Timeout (seconds) is copied into the proxy config. Headers are
// "Name: value" strings; whitespace around the colon is optional.
//
// If the proxy cannot be created, the failure is logged and the client gets
// 502 Bad Gateway. The request is always reported as handled.
func (e *RoutingExtension) handleProxy(w http.ResponseWriter, r *http.Request, config RouteConfig, params map[string]string, exactMatch bool) (bool, error) {
	target := config.ProxyPass

	// Decided before substitution, which removes the braces.
	hasParams := strings.Contains(target, "{") && strings.Contains(target, "}")

	// "{name}" -> value pairs for strings.Replacer. The Replacer makes one
	// pass, so the result does not depend on map iteration order.
	placeholders := make([]string, 0, 2*len(params))
	for key, value := range params {
		placeholders = append(placeholders, "{"+key+"}", value)
	}

	// NOTE(review): captured values are inserted without URL escaping, and
	// r.URL.Path (their source) is already percent-decoded, so a capture may
	// contain "?", "#" or "..". Confirm the proxy package tolerates this.
	target = strings.NewReplacer(placeholders...).Replace(target)

	proxyConfig := &proxy.Config{
		Target:            target,
		Headers:           make(map[string]string),
		IgnoreRequestPath: exactMatch || hasParams,
	}

	if config.Timeout > 0 {
		proxyConfig.Timeout = time.Duration(config.Timeout * float64(time.Second))
	}

	if len(config.Headers) > 0 {
		headerVars := make([]string, 0, len(placeholders)+2)
		headerVars = append(headerVars, placeholders...)
		headerVars = append(headerVars, "$remote_addr", getClientIP(r))
		expand := strings.NewReplacer(headerVars...)

		for _, header := range config.Headers {
			parts := strings.SplitN(header, ":", 2)
			if len(parts) != 2 {
				continue
			}
			name := strings.TrimSpace(parts[0])
			if name == "" {
				continue
			}
			proxyConfig.Headers[name] = expand.Replace(strings.TrimSpace(parts[1]))
		}
	}

	// NOTE(review): a proxy is constructed for every request. If proxy.New
	// builds its own transport, connection reuse is lost; consider caching
	// per target.
	p, err := proxy.New(proxyConfig, e.logger)
	if err != nil {
		e.logger.Error("Failed to create proxy", "target", target, "error", err)
		http.Error(w, "Bad Gateway", http.StatusBadGateway)
		return true, nil
	}

	p.ProxyRequest(w, r, params)
	return true, nil
}
|
|
|
|
// handleStatic serves a file from the route's root directory.
//
// The request path is joined onto root:
//   - a path ending in "/" maps to that directory's index file
//     (e.g. "/docs/" -> "<root>/docs/index.html");
//   - paths are cleaned and confined to root, and anything resolving
//     outside it gets 403;
//   - a missing file is left unhandled (false, nil) so another handler,
//     such as the SPA fallback, can try.
//
// When the file is served, CacheControl and the route's "Name: value"
// headers are set first, and the file is sent with http.ServeFile.
func (e *RoutingExtension) handleStatic(w http.ResponseWriter, r *http.Request, config RouteConfig) (bool, error) {
	urlPath := r.URL.Path

	// Serve the index of the requested directory, not always the root index.
	if urlPath == "" || strings.HasSuffix(urlPath, "/") {
		urlPath += config.IndexFile
	}

	absRoot, err := filepath.Abs(config.Root)
	if err != nil {
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return true, nil
	}

	// Cleaning "/"+path resolves any ".." against a virtual root first, so
	// the joined path cannot climb above absRoot.
	filePath := filepath.Join(absRoot, filepath.Clean("/"+urlPath))

	// Defence in depth: verify containment. Unlike a "root/" prefix check,
	// filepath.Rel also works when root is the filesystem root.
	rel, err := filepath.Rel(absRoot, filePath)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		http.Error(w, "Forbidden", http.StatusForbidden)
		return true, nil
	}

	if _, err := os.Stat(filePath); os.IsNotExist(err) {
		return false, nil // let other handlers try
	}

	if config.CacheControl != "" {
		w.Header().Set("Cache-Control", config.CacheControl)
	}

	for _, header := range config.Headers {
		parts := strings.SplitN(header, ":", 2)
		if len(parts) != 2 {
			continue
		}
		name := strings.TrimSpace(parts[0])
		if name == "" {
			continue
		}
		w.Header().Set(name, strings.TrimSpace(parts[1]))
	}

	http.ServeFile(w, r, filePath)
	return true, nil
}
|
|
|
|
// handleSPAFallback serves the single-page-app index file.
//
// It declines (false, nil) when the request path starts with any entry of
// config.ExcludePatterns, or when the index file does not exist. The index
// is config.IndexFile ("index.html" if empty), looked up inside config.Root,
// or inside the extension's staticDir when the route has no root.
func (e *RoutingExtension) handleSPAFallback(w http.ResponseWriter, r *http.Request, config RouteConfig) (bool, error) {
	for _, prefix := range config.ExcludePatterns {
		if strings.HasPrefix(r.URL.Path, prefix) {
			return false, nil
		}
	}

	dir := e.staticDir
	if config.Root != "" {
		dir = config.Root
	}

	name := "index.html"
	if config.IndexFile != "" {
		name = config.IndexFile
	}

	indexPath := filepath.Join(dir, name)
	_, statErr := os.Stat(indexPath)
	if os.IsNotExist(statErr) {
		return false, nil
	}

	http.ServeFile(w, r, indexPath)
	return true, nil
}
|
|
|
|
// getClientIP returns a best-effort client address for r.
//
// Sources are checked in this order:
//  1. the first entry of X-Forwarded-For;
//  2. X-Real-IP;
//  3. the host part of r.RemoteAddr, with IPv6 brackets removed
//     (e.g. "[::1]:8080" -> "::1").
//
// A RemoteAddr without a port is returned unchanged.
//
// NOTE(review): X-Forwarded-For and X-Real-IP are client-controlled unless a
// trusted proxy overwrites them, so the result can be spoofed. Do not use it
// for access control as-is.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		return strings.TrimSpace(strings.SplitN(xff, ",", 2)[0])
	}

	if xri := r.Header.Get("X-Real-IP"); xri != "" {
		return strings.TrimSpace(xri)
	}

	// net.SplitHostPort understands bracketed IPv6 literals. The previous
	// LastIndex(":") cut kept the brackets and mangled bare IPv6 addresses.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No port, or an unexpected format: use the address as given.
		return r.RemoteAddr
	}
	return host
}
|
|
|
|
// GetMetrics returns routing metrics
|
|
func (e *RoutingExtension) GetMetrics() map[string]interface{} {
|
|
return map[string]interface{}{
|
|
"exact_routes": len(e.exactRoutes),
|
|
"regex_routes": len(e.regexRoutes),
|
|
"has_default": e.defaultRoute != nil,
|
|
}
|
|
}
|
|
|
|
// getKeys returns the keys of m in unspecified (map iteration) order.
// A nil or empty map yields an empty, non-nil slice.
func getKeys(m map[string]interface{}) []string {
	out := make([]string, len(m))
	i := 0
	for key := range m {
		out[i] = key
		i++
	}
	return out
}
|