// Package proxy provides reverse proxy functionality for Konduktor package proxy import ( "context" "fmt" "io" "net" "net/http" "net/url" "strings" "time" "github.com/konduktor/konduktor/internal/logging" ) type Config struct { // Target is the backend server URL Target string // Timeout is the request timeout (default: 30s) Timeout time.Duration // Headers are additional headers to add to requests Headers map[string]string // StripPrefix removes this prefix from the request path StripPrefix string // PreserveHost keeps the original Host header PreserveHost bool // IgnoreRequestPath ignores the request path and uses only the target path // This is useful for exact match routes where target URL should be used as-is IgnoreRequestPath bool } type ReverseProxy struct { config *Config targetURL *url.URL httpClient *http.Client logger *logging.Logger } func New(cfg *Config, logger *logging.Logger) (*ReverseProxy, error) { if cfg.Target == "" { return nil, fmt.Errorf("proxy target is required") } targetURL, err := url.Parse(cfg.Target) if err != nil { return nil, fmt.Errorf("invalid proxy target URL: %w", err) } timeout := cfg.Timeout if timeout == 0 { timeout = 30 * time.Second } transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ Timeout: 10 * time.Second, KeepAlive: 30 * time.Second, }).DialContext, MaxIdleConns: 100, MaxIdleConnsPerHost: 10, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: timeout, } return &ReverseProxy{ config: cfg, targetURL: targetURL, httpClient: &http.Client{ Transport: transport, Timeout: timeout, CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse // Don't follow redirects }, }, logger: logger, }, nil } func (rp *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { rp.ProxyRequest(w, r, nil) } func (rp *ReverseProxy) ProxyRequest(w http.ResponseWriter, r *http.Request, params map[string]string) { ctx := r.Context() // Build target URL targetURL := rp.buildTargetURL(r) // Create proxy request proxyReq, err := rp.createProxyRequest(ctx, r, targetURL) if err != nil { rp.handleError(w, http.StatusInternalServerError, "Failed to create proxy request", err) return } // Add custom headers with parameter substitution rp.addCustomHeaders(proxyReq, r, params) // Execute request resp, err := rp.httpClient.Do(proxyReq) if err != nil { rp.handleProxyError(w, err) return } defer resp.Body.Close() // Copy response rp.copyResponse(w, resp) } func (rp *ReverseProxy) buildTargetURL(r *http.Request) *url.URL { targetURL := *rp.targetURL // If ignoring request path, use target URL path as-is if rp.config.IgnoreRequestPath { // Preserve query string only targetURL.RawQuery = r.URL.RawQuery return &targetURL } // Strip prefix if configured path := r.URL.Path if rp.config.StripPrefix != "" { path = strings.TrimPrefix(path, rp.config.StripPrefix) if path == "" || path[0] != '/' { path = "/" + path } } // If target URL has a non-empty path, combine it with the request path if rp.targetURL.Path != "" && rp.targetURL.Path != "/" { // Combine target path with request path targetURL.Path = strings.TrimSuffix(rp.targetURL.Path, "/") + path } else { // No path in target, use request path as-is targetURL.Path = path } // Preserve query string targetURL.RawQuery = r.URL.RawQuery return &targetURL } func (rp *ReverseProxy) createProxyRequest(ctx context.Context, r *http.Request, targetURL *url.URL) (*http.Request, error) { proxyReq, err := http.NewRequestWithContext(ctx, r.Method, targetURL.String(), r.Body) if err != nil { return nil, err } // Copy ContentLength proxyReq.ContentLength = r.ContentLength // Copy headers for key, values := range r.Header { for _, value := range values { proxyReq.Header.Add(key, value) } } // Set/update Host header if rp.config.PreserveHost { proxyReq.Host = r.Host } else { proxyReq.Host = targetURL.Host } // Remove hop-by-hop headers removeHopByHopHeaders(proxyReq.Header) return proxyReq, nil } func (rp *ReverseProxy) addCustomHeaders(proxyReq *http.Request, originalReq *http.Request, params map[string]string) { // Add X-Forwarded headers clientIP := getClientIP(originalReq) if prior := originalReq.Header.Get("X-Forwarded-For"); prior != "" { clientIP = prior + ", " + clientIP } proxyReq.Header.Set("X-Forwarded-For", clientIP) proxyReq.Header.Set("X-Forwarded-Proto", getScheme(originalReq)) proxyReq.Header.Set("X-Forwarded-Host", originalReq.Host) // Add custom headers from config for key, value := range rp.config.Headers { // Substitute parameters like {version} substituted := value for paramKey, paramValue := range params { substituted = strings.ReplaceAll(substituted, "{"+paramKey+"}", paramValue) } // Substitute $remote_addr substituted = strings.ReplaceAll(substituted, "$remote_addr", clientIP) proxyReq.Header.Set(key, substituted) } } func (rp *ReverseProxy) copyResponse(w http.ResponseWriter, resp *http.Response) { // Copy headers for key, values := range resp.Header { for _, value := range values { w.Header().Add(key, value) } } // Remove hop-by-hop headers from response removeHopByHopHeaders(w.Header()) // Write status code w.WriteHeader(resp.StatusCode) // Copy body io.Copy(w, resp.Body) } func (rp *ReverseProxy) handleError(w http.ResponseWriter, status int, message string, err error) { if rp.logger != nil { rp.logger.Error(message, "error", err) } http.Error(w, message, status) } func (rp *ReverseProxy) handleProxyError(w http.ResponseWriter, err error) { if rp.logger != nil { rp.logger.Error("Proxy request failed", "error", err) } // Check for timeout if err, ok := err.(net.Error); ok && err.Timeout() { http.Error(w, "504 Gateway Timeout", http.StatusGatewayTimeout) return } // Check for connection errors if isConnectionError(err) { http.Error(w, "502 Bad Gateway", http.StatusBadGateway) return } // Context cancelled (client disconnected) if err == context.Canceled { return } http.Error(w, "502 Bad Gateway", http.StatusBadGateway) } // Helper functions func singleJoiningSlash(a, b string) string { aslash := strings.HasSuffix(a, "/") bslash := strings.HasPrefix(b, "/") switch { case aslash && bslash: return a + b[1:] case !aslash && !bslash: return a + "/" + b } return a + b } func removeHopByHopHeaders(h http.Header) { hopByHopHeaders := []string{ "Connection", "Proxy-Connection", "Keep-Alive", "Proxy-Authenticate", "Proxy-Authorization", "Te", "Trailer", "Transfer-Encoding", "Upgrade", } for _, header := range hopByHopHeaders { h.Del(header) } } func getClientIP(r *http.Request) string { // Check X-Real-IP first if ip := r.Header.Get("X-Real-IP"); ip != "" { return ip } // Get from RemoteAddr host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return r.RemoteAddr } return host } func getScheme(r *http.Request) string { if r.TLS != nil { return "https" } if scheme := r.Header.Get("X-Forwarded-Proto"); scheme != "" { return scheme } return "http" } func isConnectionError(err error) bool { if err == nil { return false } errStr := err.Error() connectionErrors := []string{ "connection refused", "no such host", "network is unreachable", "connection reset", "broken pipe", } for _, connErr := range connectionErrors { if strings.Contains(strings.ToLower(errStr), connErr) { return true } } return false }