forked from aegis/pyserveX
321 lines
7.1 KiB
Go
321 lines
7.1 KiB
Go
// Package proxy provides reverse proxy functionality for Konduktor.
|
|
package proxy
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/konduktor/konduktor/internal/logging"
)
|
|
|
|
// Config describes a single reverse-proxy backend and how requests are
// rewritten on their way to it.
type Config struct {
	// Target is the backend server URL, for example "http://localhost:8080".
	// It is required. Any path on the target is prepended to every proxied
	// request path.
	Target string

	// Timeout is the request timeout (default: 30s). It bounds both the wait
	// for the backend's response headers and the whole exchange, including
	// reading the response body.
	Timeout time.Duration

	// Headers are additional headers to set on proxied requests. Values may
	// contain {name} placeholders, filled from route parameters, and the
	// token $remote_addr, replaced with the client address.
	Headers map[string]string

	// StripPrefix removes this prefix from the request path before the path
	// is forwarded to the backend.
	StripPrefix string

	// PreserveHost forwards the client's original Host header instead of the
	// target's host.
	PreserveHost bool
}
|
|
|
|
type ReverseProxy struct {
|
|
config *Config
|
|
targetURL *url.URL
|
|
httpClient *http.Client
|
|
logger *logging.Logger
|
|
}
|
|
|
|
func New(cfg *Config, logger *logging.Logger) (*ReverseProxy, error) {
|
|
if cfg.Target == "" {
|
|
return nil, fmt.Errorf("proxy target is required")
|
|
}
|
|
|
|
targetURL, err := url.Parse(cfg.Target)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid proxy target URL: %w", err)
|
|
}
|
|
|
|
timeout := cfg.Timeout
|
|
if timeout == 0 {
|
|
timeout = 30 * time.Second
|
|
}
|
|
|
|
transport := &http.Transport{
|
|
Proxy: http.ProxyFromEnvironment,
|
|
DialContext: (&net.Dialer{
|
|
Timeout: 10 * time.Second,
|
|
KeepAlive: 30 * time.Second,
|
|
}).DialContext,
|
|
MaxIdleConns: 100,
|
|
MaxIdleConnsPerHost: 10,
|
|
IdleConnTimeout: 90 * time.Second,
|
|
TLSHandshakeTimeout: 10 * time.Second,
|
|
ExpectContinueTimeout: 1 * time.Second,
|
|
ResponseHeaderTimeout: timeout,
|
|
}
|
|
|
|
return &ReverseProxy{
|
|
config: cfg,
|
|
targetURL: targetURL,
|
|
httpClient: &http.Client{
|
|
Transport: transport,
|
|
Timeout: timeout,
|
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
return http.ErrUseLastResponse // Don't follow redirects
|
|
},
|
|
},
|
|
logger: logger,
|
|
}, nil
|
|
}
|
|
|
|
func (rp *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
rp.ProxyRequest(w, r, nil)
|
|
}
|
|
|
|
func (rp *ReverseProxy) ProxyRequest(w http.ResponseWriter, r *http.Request, params map[string]string) {
|
|
ctx := r.Context()
|
|
|
|
// Build target URL
|
|
targetURL := rp.buildTargetURL(r)
|
|
|
|
// Create proxy request
|
|
proxyReq, err := rp.createProxyRequest(ctx, r, targetURL)
|
|
if err != nil {
|
|
rp.handleError(w, http.StatusInternalServerError, "Failed to create proxy request", err)
|
|
return
|
|
}
|
|
|
|
// Add custom headers with parameter substitution
|
|
rp.addCustomHeaders(proxyReq, r, params)
|
|
|
|
// Execute request
|
|
resp, err := rp.httpClient.Do(proxyReq)
|
|
if err != nil {
|
|
rp.handleProxyError(w, err)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Copy response
|
|
rp.copyResponse(w, resp)
|
|
}
|
|
|
|
func (rp *ReverseProxy) buildTargetURL(r *http.Request) *url.URL {
|
|
targetURL := *rp.targetURL
|
|
|
|
// Strip prefix if configured
|
|
path := r.URL.Path
|
|
if rp.config.StripPrefix != "" {
|
|
path = strings.TrimPrefix(path, rp.config.StripPrefix)
|
|
if path == "" || path[0] != '/' {
|
|
path = "/" + path
|
|
}
|
|
}
|
|
|
|
// If target has a path, append request path to it
|
|
if rp.targetURL.Path != "" && rp.targetURL.Path != "/" {
|
|
targetURL.Path = singleJoiningSlash(rp.targetURL.Path, path)
|
|
} else {
|
|
targetURL.Path = path
|
|
}
|
|
|
|
// Preserve query string
|
|
targetURL.RawQuery = r.URL.RawQuery
|
|
|
|
return &targetURL
|
|
}
|
|
|
|
func (rp *ReverseProxy) createProxyRequest(ctx context.Context, r *http.Request, targetURL *url.URL) (*http.Request, error) {
|
|
proxyReq, err := http.NewRequestWithContext(ctx, r.Method, targetURL.String(), r.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Copy ContentLength
|
|
proxyReq.ContentLength = r.ContentLength
|
|
|
|
// Copy headers
|
|
for key, values := range r.Header {
|
|
for _, value := range values {
|
|
proxyReq.Header.Add(key, value)
|
|
}
|
|
}
|
|
|
|
// Set/update Host header
|
|
if rp.config.PreserveHost {
|
|
proxyReq.Host = r.Host
|
|
} else {
|
|
proxyReq.Host = targetURL.Host
|
|
}
|
|
|
|
// Remove hop-by-hop headers
|
|
removeHopByHopHeaders(proxyReq.Header)
|
|
|
|
return proxyReq, nil
|
|
}
|
|
|
|
func (rp *ReverseProxy) addCustomHeaders(proxyReq *http.Request, originalReq *http.Request, params map[string]string) {
|
|
// Add X-Forwarded headers
|
|
clientIP := getClientIP(originalReq)
|
|
if prior := originalReq.Header.Get("X-Forwarded-For"); prior != "" {
|
|
clientIP = prior + ", " + clientIP
|
|
}
|
|
proxyReq.Header.Set("X-Forwarded-For", clientIP)
|
|
proxyReq.Header.Set("X-Forwarded-Proto", getScheme(originalReq))
|
|
proxyReq.Header.Set("X-Forwarded-Host", originalReq.Host)
|
|
|
|
// Add custom headers from config
|
|
for key, value := range rp.config.Headers {
|
|
// Substitute parameters like {version}
|
|
substituted := value
|
|
for paramKey, paramValue := range params {
|
|
substituted = strings.ReplaceAll(substituted, "{"+paramKey+"}", paramValue)
|
|
}
|
|
// Substitute $remote_addr
|
|
substituted = strings.ReplaceAll(substituted, "$remote_addr", clientIP)
|
|
proxyReq.Header.Set(key, substituted)
|
|
}
|
|
}
|
|
|
|
func (rp *ReverseProxy) copyResponse(w http.ResponseWriter, resp *http.Response) {
|
|
// Copy headers
|
|
for key, values := range resp.Header {
|
|
for _, value := range values {
|
|
w.Header().Add(key, value)
|
|
}
|
|
}
|
|
|
|
// Remove hop-by-hop headers from response
|
|
removeHopByHopHeaders(w.Header())
|
|
|
|
// Write status code
|
|
w.WriteHeader(resp.StatusCode)
|
|
|
|
// Copy body
|
|
io.Copy(w, resp.Body)
|
|
}
|
|
|
|
func (rp *ReverseProxy) handleError(w http.ResponseWriter, status int, message string, err error) {
|
|
if rp.logger != nil {
|
|
rp.logger.Error(message, "error", err)
|
|
}
|
|
http.Error(w, message, status)
|
|
}
|
|
|
|
func (rp *ReverseProxy) handleProxyError(w http.ResponseWriter, err error) {
|
|
if rp.logger != nil {
|
|
rp.logger.Error("Proxy request failed", "error", err)
|
|
}
|
|
|
|
// Check for timeout
|
|
if err, ok := err.(net.Error); ok && err.Timeout() {
|
|
http.Error(w, "504 Gateway Timeout", http.StatusGatewayTimeout)
|
|
return
|
|
}
|
|
|
|
// Check for connection errors
|
|
if isConnectionError(err) {
|
|
http.Error(w, "502 Bad Gateway", http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
// Context cancelled (client disconnected)
|
|
if err == context.Canceled {
|
|
return
|
|
}
|
|
|
|
http.Error(w, "502 Bad Gateway", http.StatusBadGateway)
|
|
}
|
|
|
|
// Helper functions
|
|
|
|
// singleJoiningSlash concatenates a and b with exactly one "/" between them.
// It collapses a trailing slash on a and a leading slash on b into one, and
// inserts a slash when neither side has one.
func singleJoiningSlash(a, b string) string {
	endsWithSlash := strings.HasSuffix(a, "/")
	startsWithSlash := strings.HasPrefix(b, "/")

	if endsWithSlash && startsWithSlash {
		return a + strings.TrimPrefix(b, "/")
	}
	if !endsWithSlash && !startsWithSlash {
		return a + "/" + b
	}
	return a + b
}
|
|
|
|
// hopByHopHeaders is the fixed list of connection-specific header fields that
// are never forwarded by the proxy. It lives at package level so the list is
// not rebuilt on every call.
var hopByHopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}

// removeHopByHopHeaders deletes hop-by-hop headers from h in place. These are
// the fixed hopByHopHeaders list plus every field named in a Connection
// header, which RFC 7230 section 6.1 also marks as hop-by-hop.
func removeHopByHopHeaders(h http.Header) {
	// Read the Connection tokens before Connection itself is deleted.
	for _, value := range h["Connection"] {
		for _, field := range strings.Split(value, ",") {
			if field = strings.TrimSpace(field); field != "" {
				h.Del(field)
			}
		}
	}
	for _, name := range hopByHopHeaders {
		h.Del(name)
	}
}
|
|
|
|
// getClientIP returns the client address for r. A non-empty X-Real-IP header
// wins. Otherwise it returns the host part of r.RemoteAddr, or RemoteAddr
// unchanged when it is not in host:port form.
func getClientIP(r *http.Request) string {
	realIP := r.Header.Get("X-Real-IP")
	if realIP != "" {
		return realIP
	}
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
|
|
|
|
// getScheme reports the scheme the client used. A TLS connection always means
// "https". Otherwise a non-empty X-Forwarded-Proto header is trusted, and the
// fallback is "http".
func getScheme(r *http.Request) string {
	switch forwarded := r.Header.Get("X-Forwarded-Proto"); {
	case r.TLS != nil:
		return "https"
	case forwarded != "":
		return forwarded
	default:
		return "http"
	}
}
|
|
|
|
// isConnectionError reports whether err looks like a failure to reach, or stay
// connected to, the backend. It searches the error message, case-insensitively,
// for a handful of well-known network error phrases. A nil error is never a
// connection error.
func isConnectionError(err error) bool {
	if err == nil {
		return false
	}

	msg := strings.ToLower(err.Error())
	for _, phrase := range [...]string{
		"connection refused",
		"no such host",
		"network is unreachable",
		"connection reset",
		"broken pipe",
	} {
		if strings.Contains(msg, phrase) {
			return true
		}
	}
	return false
}
|