From 8f5b9a5cd1ef53dc879e85f119582807bd7f695f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=98=D0=BB=D1=8C=D1=8F=20=D0=93=D0=BB=D0=B0=D0=B7=D1=83?= =?UTF-8?q?=D0=BD=D0=BE=D0=B2?= Date: Thu, 11 Dec 2025 16:52:13 +0300 Subject: [PATCH] go implementation --- .gitignore | 5 +- go/Dockerfile | 34 + go/Makefile | 108 +++ go/README.md | 149 ++++ go/cmd/konduktor/main.go | 79 +++ go/cmd/konduktorctl/main.go | 180 +++++ go/go.mod | 15 + go/internal/config/config.go | 134 ++++ go/internal/config/config_test.go | 127 ++++ go/internal/logging/logger.go | 136 ++++ go/internal/logging/logger_test.go | 172 +++++ go/internal/middleware/middleware.go | 74 ++ go/internal/middleware/middleware_test.go | 244 +++++++ go/internal/pathmatcher/pathmatcher.go | 263 +++++++ go/internal/pathmatcher/pathmatcher_test.go | 460 ++++++++++++ go/internal/proxy/proxy.go | 320 +++++++++ go/internal/proxy/proxy_test.go | 747 ++++++++++++++++++++ go/internal/routing/router.go | 395 +++++++++++ go/internal/routing/router_test.go | 375 ++++++++++ go/internal/server/server.go | 130 ++++ 20 files changed, 4146 insertions(+), 1 deletion(-) create mode 100644 go/Dockerfile create mode 100644 go/Makefile create mode 100644 go/README.md create mode 100644 go/cmd/konduktor/main.go create mode 100644 go/cmd/konduktorctl/main.go create mode 100644 go/go.mod create mode 100644 go/internal/config/config.go create mode 100644 go/internal/config/config_test.go create mode 100644 go/internal/logging/logger.go create mode 100644 go/internal/logging/logger_test.go create mode 100644 go/internal/middleware/middleware.go create mode 100644 go/internal/middleware/middleware_test.go create mode 100644 go/internal/pathmatcher/pathmatcher.go create mode 100644 go/internal/pathmatcher/pathmatcher_test.go create mode 100644 go/internal/proxy/proxy.go create mode 100644 go/internal/proxy/proxy_test.go create mode 100644 go/internal/routing/router.go create mode 100644 go/internal/routing/router_test.go create mode 100644 go/internal/server/server.go diff --git a/.gitignore b/.gitignore index 11035c0..304766e 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,7 @@ build/ .idea/ .vscode/ *.swp -*.swo \ No newline at end of file +*.swo + +# Go binaries +go/bin \ No newline at end of file diff --git a/go/Dockerfile b/go/Dockerfile new file mode 100644 index 0000000..15be879 --- /dev/null +++ b/go/Dockerfile @@ -0,0 +1,34 @@ +# Multi-stage build for Konduktor +FROM golang:1.23-alpine AS builder + +RUN apk add --no-cache git make + +WORKDIR /build + +COPY go.mod go.sum* ./ +RUN go mod download + +COPY . . + +RUN make build + +FROM alpine:3.19 + +RUN apk add --no-cache ca-certificates tzdata + +RUN adduser -D -g '' konduktor + +WORKDIR /app + +COPY --from=builder /build/bin/konduktor /usr/local/bin/ +COPY --from=builder /build/bin/konduktorctl /usr/local/bin/ + +RUN mkdir -p /app/static /app/templates /app/logs && \ + chown -R konduktor:konduktor /app + +USER konduktor + +EXPOSE 8080 + +ENTRYPOINT ["konduktor"] +CMD ["-c", "/app/config.yaml"] diff --git a/go/Makefile b/go/Makefile new file mode 100644 index 0000000..45a5f48 --- /dev/null +++ b/go/Makefile @@ -0,0 +1,108 @@ +# Konduktor Go Build +# Makefile for building and testing Konduktor + +.PHONY: all build build-konduktor build-konduktorctl test clean deps fmt lint run + +# Build configuration +VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") +GIT_COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown") +BUILD_TIME ?= $(shell date -u '+%Y-%m-%dT%H:%M:%SZ') +LDFLAGS := -X main.Version=$(VERSION) -X main.GitCommit=$(GIT_COMMIT) -X main.BuildTime=$(BUILD_TIME) + +# Output directories +BIN_DIR := bin + +all: deps build + +# Download dependencies +deps: + @echo "==> Downloading dependencies..." + go mod download + go mod tidy + +# Build all binaries +build: build-konduktor build-konduktorctl + +# Build konduktor server +build-konduktor: + @echo "==> Building konduktor..." + @mkdir -p $(BIN_DIR) + go build -ldflags "$(LDFLAGS)" -o $(BIN_DIR)/konduktor ./cmd/konduktor + +# Build konduktorctl CLI +build-konduktorctl: + @echo "==> Building konduktorctl..." + @mkdir -p $(BIN_DIR) + go build -ldflags "$(LDFLAGS)" -o $(BIN_DIR)/konduktorctl ./cmd/konduktorctl + +# Run tests +test: + @echo "==> Running tests..." + go test -v -race -cover ./... + +# Run tests with coverage report +test-coverage: + @echo "==> Running tests with coverage..." + go test -v -race -coverprofile=coverage.out ./... + go tool cover -html=coverage.out -o coverage.html + @echo "Coverage report: coverage.html" + +# Format code +fmt: + @echo "==> Formatting code..." + go fmt ./... + goimports -w . + +# Lint code +lint: + @echo "==> Linting code..." + golangci-lint run ./... + +# Run the server (development) +run: build-konduktor + @echo "==> Running konduktor..." + ./$(BIN_DIR)/konduktor -c ../config.yaml + +# Clean build artifacts +clean: + @echo "==> Cleaning..." + rm -rf $(BIN_DIR) + rm -f coverage.out coverage.html + +# Install binaries to GOPATH/bin +install: build + @echo "==> Installing binaries..." + cp $(BIN_DIR)/konduktor $(GOPATH)/bin/ + cp $(BIN_DIR)/konduktorctl $(GOPATH)/bin/ + +# Generate mocks (for testing) +generate: + @echo "==> Generating code..." + go generate ./... + +# Docker build +docker-build: + @echo "==> Building Docker image..." + docker build -t konduktor:$(VERSION) . + +# Show help +help: + @echo "Konduktor Build System" + @echo "" + @echo "Usage: make [target]" + @echo "" + @echo "Targets:" + @echo " all Download deps and build all binaries" + @echo " deps Download and tidy dependencies" + @echo " build Build all binaries" + @echo " build-konduktor Build the server binary" + @echo " build-konduktorctl Build the CLI binary" + @echo " test Run tests" + @echo " test-coverage Run tests with coverage report" + @echo " fmt Format code" + @echo " lint Lint code" + @echo " run Build and run the server" + @echo " clean Clean build artifacts" + @echo " install Install binaries to GOPATH/bin" + @echo " docker-build Build Docker image" + @echo " help Show this help" diff --git a/go/README.md b/go/README.md new file mode 100644 index 0000000..b92dd1d --- /dev/null +++ b/go/README.md @@ -0,0 +1,149 @@ +# Konduktor (Go) + +High-performance HTTP web server with extensible routing and process orchestration. (Previously known as PyServe in Python) + +## Project Structure + +``` +go/ +├── cmd/ +│ ├── konduktor/ # Main server binary +│ └── konduktorctl/ # CLI management tool +├── internal/ +│ ├── config/ # Configuration management +│ ├── logging/ # Structured logging +│ ├── middleware/ # HTTP middleware +│ ├── routing/ # HTTP routing +│ ├── extensions/ # Extension system (TODO) +│ └── process/ # Process management (TODO) +├── pkg/ # Public packages (TODO) +├── go.mod +├── go.sum +└── Makefile +``` + +## Building + +```bash +cd go + +# Download dependencies +make deps + +# Build all binaries +make build + +# Or build individually +make build-konduktor +make build-konduktorctl +``` + +## Running + +```bash +# Run with default config +./bin/konduktor + +# Run with custom config +./bin/konduktor -c ../config.yaml + +# Run with flags +./bin/konduktor --host 127.0.0.1 --port 3000 --debug +``` + +## CLI Commands (konduktorctl) + +```bash +# Start services +konduktorctl up + +# Stop services +konduktorctl down + +# View status +konduktorctl status + +# View logs +konduktorctl logs -f + +# Health check +konduktorctl health + +# Scale services +konduktorctl scale api=3 + +# Configuration management +konduktorctl config show +konduktorctl config validate + +# Initialize new project +konduktorctl init +``` + +## Configuration + +Uses the same YAML configuration format as the Python version: + +```yaml +server: + host: 0.0.0.0 + port: 8080 + +http: + static_dir: ./static + templates_dir: ./templates + +ssl: + enabled: false + cert_file: ./ssl/cert.pem + key_file: ./ssl/key.pem + +logging: + level: INFO + console_output: true + +extensions: + - type: routing + config: + regex_locations: + "=/health": + return: "200 OK" +``` + +## Development + +```bash +# Format code +make fmt + +# Run linter +make lint + +# Run tests +make test + +# Run with coverage +make test-coverage +``` + +## Migration from Python + +This is a gradual rewrite of PyServe to Go. The project is now called **Konduktor**. + +### Completed +- [x] Basic project structure +- [x] Configuration loading +- [x] HTTP server with graceful shutdown +- [x] Basic routing +- [x] Middleware (access log, recovery, server header) +- [x] CLI structure (konduktor, konduktorctl) + +### TODO +- [ ] Extension system +- [x] Regex routing +- [x] Reverse proxy +- [ ] Process orchestration +- [ ] ASGI/WSGI adapter support +- [ ] WebSocket support +- [ ] Hot reload +- [ ] Metrics and monitoring diff --git a/go/cmd/konduktor/main.go b/go/cmd/konduktor/main.go new file mode 100644 index 0000000..44149f1 --- /dev/null +++ b/go/cmd/konduktor/main.go @@ -0,0 +1,79 @@ +package main + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + + "github.com/konduktor/konduktor/internal/config" + "github.com/konduktor/konduktor/internal/server" +) + +var ( + Version = "0.1.0" + BuildTime = "unknown" + GitCommit = "unknown" +) + +var ( + cfgFile string + host string + port int + debug bool +) + +func main() { + rootCmd := &cobra.Command{ + Use: "konduktor", + Short: "Konduktor - HTTP web server", + Long: `Konduktor is a high-performance HTTP web server with extensible routing and process orchestration.`, + Version: fmt.Sprintf("%s (commit: %s, built: %s)", Version, GitCommit, BuildTime), + RunE: runServer, + } + + rootCmd.Flags().StringVarP(&cfgFile, "config", "c", "config.yaml", "Path to configuration file") + rootCmd.Flags().StringVar(&host, "host", "", "Host to bind the server to") + rootCmd.Flags().IntVar(&port, "port", 0, "Port to bind the server to") + rootCmd.Flags().BoolVar(&debug, "debug", false, "Enable debug mode") + + if err := rootCmd.Execute(); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func runServer(cmd *cobra.Command, args []string) error { + cfg, err := config.Load(cfgFile) + if err != nil { + if os.IsNotExist(err) { + fmt.Printf("Configuration file %s not found, using defaults\n", cfgFile) + cfg = config.Default() + } else { + return fmt.Errorf("configuration loading error: %w", err) + } + } + + if host != "" { + cfg.Server.Host = host + } + if port != 0 { + cfg.Server.Port = port + } + if debug { + cfg.Logging.Level = "DEBUG" + } + + srv, err := server.New(cfg) + if err != nil { + return fmt.Errorf("server creation error: %w", err) + } + + fmt.Printf("Starting Konduktor server on %s:%d\n", cfg.Server.Host, cfg.Server.Port) + + if err := srv.Run(); err != nil { + return fmt.Errorf("server startup error: %w", err) + } + + return nil +} diff --git a/go/cmd/konduktorctl/main.go b/go/cmd/konduktorctl/main.go new file mode 100644 index 0000000..554224d --- /dev/null +++ b/go/cmd/konduktorctl/main.go @@ -0,0 +1,180 @@ +package main + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" +) + +var ( + Version = "0.1.0" + BuildTime = "unknown" + GitCommit = "unknown" +) + +func main() { + rootCmd := &cobra.Command{ + Use: "konduktorctl", + Short: "Konduktorctl - Service management CLI", + Long: `Konduktorctl is a CLI tool for managing Konduktor services.`, + Version: fmt.Sprintf("%s (commit: %s, built: %s)", Version, GitCommit, BuildTime), + } + + rootCmd.AddCommand( + newUpCmd(), + newDownCmd(), + newStatusCmd(), + newLogsCmd(), + newHealthCmd(), + newScaleCmd(), + newConfigCmd(), + newInitCmd(), + newTopCmd(), + ) + + if err := rootCmd.Execute(); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func newUpCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "up [service...]", + Short: "Start services", + Long: `Start one or more services. If no service is specified, all services are started.`, + RunE: func(cmd *cobra.Command, args []string) error { + fmt.Println("Starting services...") + // TODO: Implement service start logic + return nil + }, + } + cmd.Flags().BoolP("detach", "d", false, "Run in background") + return cmd +} + +func newDownCmd() *cobra.Command { + return &cobra.Command{ + Use: "down [service...]", + Short: "Stop services", + Long: `Stop one or more services. If no service is specified, all services are stopped.`, + RunE: func(cmd *cobra.Command, args []string) error { + fmt.Println("Stopping services...") + // TODO: Implement service stop logic + return nil + }, + } +} + +func newStatusCmd() *cobra.Command { + return &cobra.Command{ + Use: "status [service...]", + Short: "Show service status", + Long: `Show the status of one or more services.`, + RunE: func(cmd *cobra.Command, args []string) error { + fmt.Println("Service status:") + // TODO: Implement status display logic + return nil + }, + } +} + +func newLogsCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "logs [service]", + Short: "View service logs", + Long: `View logs for a specific service.`, + RunE: func(cmd *cobra.Command, args []string) error { + fmt.Println("Fetching logs...") + // TODO: Implement logs viewing logic + return nil + }, + } + cmd.Flags().BoolP("follow", "f", false, "Follow log output") + cmd.Flags().IntP("tail", "n", 100, "Number of lines to show from the end") + return cmd +} + +func newHealthCmd() *cobra.Command { + return &cobra.Command{ + Use: "health", + Short: "Check service health", + Long: `Check the health status of all services.`, + RunE: func(cmd *cobra.Command, args []string) error { + fmt.Println("Health check:") + // TODO: Implement health check logic + return nil + }, + } +} + +func newScaleCmd() *cobra.Command { + return &cobra.Command{ + Use: "scale =", + Short: "Scale a service", + Long: `Scale a service to a specific number of instances.`, + Args: cobra.MinimumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + fmt.Printf("Scaling: %v\n", args) + // TODO: Implement scaling logic + return nil + }, + } +} + +func newConfigCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "config", + Short: "Manage configuration", + Long: `View and validate configuration.`, + } + + cmd.AddCommand(&cobra.Command{ + Use: "show", + Short: "Show current configuration", + RunE: func(cmd *cobra.Command, args []string) error { + fmt.Println("Current configuration:") + // TODO: Implement config show logic + return nil + }, + }) + + cmd.AddCommand(&cobra.Command{ + Use: "validate", + Short: "Validate configuration file", + RunE: func(cmd *cobra.Command, args []string) error { + fmt.Println("Validating configuration...") + // TODO: Implement config validation logic + return nil + }, + }) + + return cmd +} + +func newInitCmd() *cobra.Command { + return &cobra.Command{ + Use: "init", + Short: "Initialize a new project", + Long: `Create a new Konduktor project with default configuration.`, + RunE: func(cmd *cobra.Command, args []string) error { + fmt.Println("Initializing new project...") + // TODO: Implement init logic + return nil + }, + } +} + +func newTopCmd() *cobra.Command { + return &cobra.Command{ + Use: "top", + Short: "Display running processes", + Long: `Display real-time view of running processes and resource usage.`, + RunE: func(cmd *cobra.Command, args []string) error { + fmt.Println("Process monitor:") + // TODO: Implement top-like display + return nil + }, + } +} diff --git a/go/go.mod b/go/go.mod new file mode 100644 index 0000000..bc5b813 --- /dev/null +++ b/go/go.mod @@ -0,0 +1,15 @@ +module github.com/konduktor/konduktor + +go 1.23.0 + +toolchain go1.24.2 + +require ( + github.com/spf13/cobra v1.10.2 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect +) diff --git a/go/internal/config/config.go b/go/internal/config/config.go new file mode 100644 index 0000000..bdfce1b --- /dev/null +++ b/go/internal/config/config.go @@ -0,0 +1,134 @@ +package config + +import ( + "fmt" + "os" + "time" + + "gopkg.in/yaml.v3" +) + +type Config struct { + HTTP HTTPConfig `yaml:"http"` + Server ServerConfig `yaml:"server"` + SSL SSLConfig `yaml:"ssl"` + Logging LoggingConfig `yaml:"logging"` + Extensions []ExtensionConfig `yaml:"extensions"` +} + +type HTTPConfig struct { + StaticDir string `yaml:"static_dir"` + TemplatesDir string `yaml:"templates_dir"` +} + +type ServerConfig struct { + Host string `yaml:"host"` + Port int `yaml:"port"` + Backlog int `yaml:"backlog"` + DefaultRoot bool `yaml:"default_root"` + ProxyTimeout time.Duration `yaml:"proxy_timeout"` + RedirectInstructions map[string]string `yaml:"redirect_instructions"` +} + +type SSLConfig struct { + Enabled bool `yaml:"enabled"` + CertFile string `yaml:"cert_file"` + KeyFile string `yaml:"key_file"` +} + +type LoggingConfig struct { + Level string `yaml:"level"` + ConsoleOutput bool `yaml:"console_output"` + Format LogFormatConfig `yaml:"format"` + Console *ConsoleLogConfig `yaml:"console"` + Files []FileLogConfig `yaml:"files"` +} + +type LogFormatConfig struct { + Type string `yaml:"type"` + UseColors bool `yaml:"use_colors"` + ShowModule bool `yaml:"show_module"` + TimestampFormat string `yaml:"timestamp_format"` +} + +type ConsoleLogConfig struct { + Format LogFormatConfig `yaml:"format"` + Level string `yaml:"level"` +} + +type FileLogConfig struct { + Path string `yaml:"path"` + Level string `yaml:"level"` + Loggers []string `yaml:"loggers"` + Format LogFormatConfig `yaml:"format"` + MaxBytes int64 `yaml:"max_bytes"` + BackupCount int `yaml:"backup_count"` +} + +type ExtensionConfig struct { + Type string `yaml:"type"` + Config map[string]interface{} `yaml:"config"` +} + +func Load(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + cfg := Default() + if err := yaml.Unmarshal(data, cfg); err != nil { + return nil, fmt.Errorf("failed to parse config: %w", err) + } + + return cfg, nil +} + +func Default() *Config { + return &Config{ + HTTP: HTTPConfig{ + StaticDir: "./static", + TemplatesDir: "./templates", + }, + Server: ServerConfig{ + Host: "0.0.0.0", + Port: 8080, + Backlog: 5, + DefaultRoot: false, + ProxyTimeout: 30 * time.Second, + }, + SSL: SSLConfig{ + Enabled: false, + CertFile: "./ssl/cert.pem", + KeyFile: "./ssl/key.pem", + }, + Logging: LoggingConfig{ + Level: "INFO", + ConsoleOutput: true, + Format: LogFormatConfig{ + Type: "standard", + UseColors: true, + ShowModule: true, + TimestampFormat: "2006-01-02 15:04:05", + }, + }, + Extensions: []ExtensionConfig{}, + } +} + +func (c *Config) Validate() error { + if c.Server.Port < 1 || c.Server.Port > 65535 { + return fmt.Errorf("invalid port: %d", c.Server.Port) + } + + if c.SSL.Enabled { + if c.SSL.CertFile == "" { + return fmt.Errorf("SSL enabled but cert_file not specified") + } + if c.SSL.KeyFile == "" { + return fmt.Errorf("SSL enabled but key_file not specified") + } + } + + return nil +} diff --git a/go/internal/config/config_test.go b/go/internal/config/config_test.go new file mode 100644 index 0000000..663101e --- /dev/null +++ b/go/internal/config/config_test.go @@ -0,0 +1,127 @@ +package config + +import ( + "os" + "testing" +) + +func TestDefault(t *testing.T) { + cfg := Default() + + if cfg.Server.Host != "0.0.0.0" { + t.Errorf("Expected host 0.0.0.0, got %s", cfg.Server.Host) + } + + if cfg.Server.Port != 8080 { + t.Errorf("Expected port 8080, got %d", cfg.Server.Port) + } + + if cfg.SSL.Enabled { + t.Error("Expected SSL to be disabled by default") + } +} + +func TestValidate(t *testing.T) { + tests := []struct { + name string + modify func(*Config) + wantErr bool + }{ + { + name: "valid default config", + modify: func(c *Config) {}, + wantErr: false, + }, + { + name: "invalid port - too low", + modify: func(c *Config) { + c.Server.Port = 0 + }, + wantErr: true, + }, + { + name: "invalid port - too high", + modify: func(c *Config) { + c.Server.Port = 70000 + }, + wantErr: true, + }, + { + name: "SSL enabled without cert", + modify: func(c *Config) { + c.SSL.Enabled = true + c.SSL.CertFile = "" + }, + wantErr: true, + }, + { + name: "SSL enabled without key", + modify: func(c *Config) { + c.SSL.Enabled = true + c.SSL.CertFile = "cert.pem" + c.SSL.KeyFile = "" + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := Default() + tt.modify(cfg) + + err := cfg.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestLoad(t *testing.T) { + // Create temporary config file + content := ` +server: + host: 127.0.0.1 + port: 3000 + +logging: + level: DEBUG +` + tmpfile, err := os.CreateTemp("", "config-*.yaml") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpfile.Name()) + + if _, err := tmpfile.Write([]byte(content)); err != nil { + t.Fatal(err) + } + if err := tmpfile.Close(); err != nil { + t.Fatal(err) + } + + cfg, err := Load(tmpfile.Name()) + if err != nil { + t.Fatalf("Failed to load config: %v", err) + } + + if cfg.Server.Host != "127.0.0.1" { + t.Errorf("Expected host 127.0.0.1, got %s", cfg.Server.Host) + } + + if cfg.Server.Port != 3000 { + t.Errorf("Expected port 3000, got %d", cfg.Server.Port) + } + + if cfg.Logging.Level != "DEBUG" { + t.Errorf("Expected level DEBUG, got %s", cfg.Logging.Level) + } +} + +func TestLoadNotFound(t *testing.T) { + _, err := Load("/nonexistent/config.yaml") + if err == nil { + t.Error("Expected error for non-existent file") + } +} diff --git a/go/internal/logging/logger.go b/go/internal/logging/logger.go new file mode 100644 index 0000000..fe50509 --- /dev/null +++ b/go/internal/logging/logger.go @@ -0,0 +1,136 @@ +package logging + +import ( + "fmt" + "os" + "time" + + "github.com/konduktor/konduktor/internal/config" +) + +type Config struct { + Level string + TimestampFormat string +} + +type Logger struct { + level string + timestampFormat string + configFull *config.LoggingConfig +} + +func New(cfg Config) (*Logger, error) { + timestampFormat := cfg.TimestampFormat + if timestampFormat == "" { + timestampFormat = "2006-01-02 15:04:05" + } + + return &Logger{ + level: cfg.Level, + timestampFormat: timestampFormat, + }, nil +} + +func NewFromConfig(cfg config.LoggingConfig) (*Logger, error) { + timestampFormat := cfg.Format.TimestampFormat + if timestampFormat == "" { + timestampFormat = "2006-01-02 15:04:05" + } + + return &Logger{ + level: cfg.Level, + timestampFormat: timestampFormat, + configFull: &cfg, + }, nil +} + +func (l *Logger) formatTime() string { + return time.Now().Format(l.timestampFormat) +} + +func (l *Logger) log(level string, msg string, fields ...interface{}) { + timestamp := l.formatTime() + + // Simple console output for now + // TODO: Implement proper structured logging with zap + output := timestamp + " [" + level + "] " + msg + + if len(fields) > 0 { + output += " {" + for i := 0; i < len(fields); i += 2 { + if i > 0 { + output += ", " + } + if i+1 < len(fields) { + output += fields[i].(string) + "=" + formatValue(fields[i+1]) + } + } + output += "}" + } + + os.Stdout.WriteString(output + "\n") +} + +func formatValue(v interface{}) string { + switch val := v.(type) { + case string: + return val + case int: + return fmt.Sprintf("%d", val) + case int64: + return fmt.Sprintf("%d", val) + case float64: + return fmt.Sprintf("%.2f", val) + case bool: + return fmt.Sprintf("%t", val) + case error: + return val.Error() + default: + return fmt.Sprintf("%v", val) + } +} + +func (l *Logger) Debug(msg string, fields ...interface{}) { + if l.shouldLog("DEBUG") { + l.log("DEBUG", msg, fields...) + } +} + +func (l *Logger) Info(msg string, fields ...interface{}) { + if l.shouldLog("INFO") { + l.log("INFO", msg, fields...) + } +} + +func (l *Logger) Warn(msg string, fields ...interface{}) { + if l.shouldLog("WARN") { + l.log("WARN", msg, fields...) + } +} + +func (l *Logger) Error(msg string, fields ...interface{}) { + if l.shouldLog("ERROR") { + l.log("ERROR", msg, fields...) + } +} + +func (l *Logger) shouldLog(level string) bool { + levels := map[string]int{ + "DEBUG": 0, + "INFO": 1, + "WARN": 2, + "ERROR": 3, + } + + currentLevel, ok := levels[l.level] + if !ok { + currentLevel = 1 // Default to INFO + } + + msgLevel, ok := levels[level] + if !ok { + msgLevel = 1 + } + + return msgLevel >= currentLevel +} diff --git a/go/internal/logging/logger_test.go b/go/internal/logging/logger_test.go new file mode 100644 index 0000000..342ac53 --- /dev/null +++ b/go/internal/logging/logger_test.go @@ -0,0 +1,172 @@ +package logging + +import ( + "testing" +) + +func TestNew(t *testing.T) { + logger, err := New(Config{Level: "INFO"}) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if logger == nil { + t.Fatal("Expected logger, got nil") + } + + if logger.level != "INFO" { + t.Errorf("Expected level INFO, got %s", logger.level) + } +} + +func TestNew_DefaultTimestampFormat(t *testing.T) { + logger, _ := New(Config{Level: "DEBUG"}) + + if logger.timestampFormat != "2006-01-02 15:04:05" { + t.Errorf("Expected default timestamp format, got %s", logger.timestampFormat) + } +} + +func TestNew_CustomTimestampFormat(t *testing.T) { + logger, _ := New(Config{ + Level: "DEBUG", + TimestampFormat: "15:04:05", + }) + + if logger.timestampFormat != "15:04:05" { + t.Errorf("Expected custom timestamp format, got %s", logger.timestampFormat) + } +} + +func TestLogger_ShouldLog(t *testing.T) { + tests := []struct { + loggerLevel string + msgLevel string + shouldLog bool + }{ + {"DEBUG", "DEBUG", true}, + {"DEBUG", "INFO", true}, + {"DEBUG", "WARN", true}, + {"DEBUG", "ERROR", true}, + {"INFO", "DEBUG", false}, + {"INFO", "INFO", true}, + {"INFO", "WARN", true}, + {"INFO", "ERROR", true}, + {"WARN", "DEBUG", false}, + {"WARN", "INFO", false}, + {"WARN", "WARN", true}, + {"WARN", "ERROR", true}, + {"ERROR", "DEBUG", false}, + {"ERROR", "INFO", false}, + {"ERROR", "WARN", false}, + {"ERROR", "ERROR", true}, + } + + for _, tt := range tests { + t.Run(tt.loggerLevel+"_"+tt.msgLevel, func(t *testing.T) { + logger, _ := New(Config{Level: tt.loggerLevel}) + + if got := logger.shouldLog(tt.msgLevel); got != tt.shouldLog { + t.Errorf("shouldLog(%s) = %v, want %v", tt.msgLevel, got, tt.shouldLog) + } + }) + } +} + +func TestLogger_ShouldLog_InvalidLevel(t *testing.T) { + logger, _ := New(Config{Level: "INVALID"}) + + // Should default to INFO level + if !logger.shouldLog("INFO") { + t.Error("Invalid level should default to INFO") + } +} + +func TestLogger_Debug(t *testing.T) { + logger, _ := New(Config{Level: "DEBUG"}) + + // Should not panic + logger.Debug("test message", "key", "value") +} + +func TestLogger_Info(t *testing.T) { + logger, _ := New(Config{Level: "INFO"}) + + // Should not panic + logger.Info("test message", "key", "value") +} + +func TestLogger_Warn(t *testing.T) { + logger, _ := New(Config{Level: "WARN"}) + + // Should not panic + logger.Warn("test message", "key", "value") +} + +func TestLogger_Error(t *testing.T) { + logger, _ := New(Config{Level: "ERROR"}) + + // Should not panic + logger.Error("test message", "key", "value") +} + +func TestFormatValue(t *testing.T) { + tests := []struct { + input interface{} + expected string + }{ + {"test", "test"}, + {42, "*"}, // int converts to rune + {nil, ""}, + } + + for _, tt := range tests { + got := formatValue(tt.input) + // Just check it doesn't panic + _ = got + } +} + +func TestLogger_FormatTime(t *testing.T) { + logger, _ := New(Config{ + Level: "INFO", + TimestampFormat: "2006-01-02", + }) + + result := logger.formatTime() + + // Should be in expected format (YYYY-MM-DD) + if len(result) != 10 { + t.Errorf("Expected date format YYYY-MM-DD, got %s", result) + } +} + +// ============== Benchmarks ============== + +func BenchmarkLogger_Info(b *testing.B) { + logger, _ := New(Config{Level: "INFO"}) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + logger.Info("test message", "key", "value") + } +} + +func BenchmarkLogger_Debug_Filtered(b *testing.B) { + logger, _ := New(Config{Level: "ERROR"}) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + logger.Debug("test message", "key", "value") + } +} + +func BenchmarkLogger_ShouldLog(b *testing.B) { + logger, _ := New(Config{Level: "INFO"}) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + logger.shouldLog("DEBUG") + } +} diff --git a/go/internal/middleware/middleware.go b/go/internal/middleware/middleware.go new file mode 100644 index 0000000..c8d3f2c --- /dev/null +++ b/go/internal/middleware/middleware.go @@ -0,0 +1,74 @@ +package middleware + +import ( + "fmt" + "net/http" + "runtime/debug" + "time" + + "github.com/konduktor/konduktor/internal/logging" +) + +type responseWriter struct { + http.ResponseWriter + status int + size int +} + +func (rw *responseWriter) WriteHeader(code int) { + rw.status = code + rw.ResponseWriter.WriteHeader(code) +} + +func (rw *responseWriter) Write(b []byte) (int, error) { + size, err := rw.ResponseWriter.Write(b) + rw.size += size + return size, err +} + +func ServerHeader(next http.Handler, version string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Server", fmt.Sprintf("konduktor/%s", version)) + next.ServeHTTP(w, r) + }) +} + +func AccessLog(next http.Handler, logger *logging.Logger) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + wrapped := &responseWriter{ + ResponseWriter: w, + status: http.StatusOK, + } + + next.ServeHTTP(wrapped, r) + + duration := time.Since(start) + + logger.Info("HTTP request", + "method", r.Method, + "path", r.URL.Path, + "status", wrapped.status, + "duration_ms", duration.Milliseconds(), + "client_ip", r.RemoteAddr, + "user_agent", r.UserAgent(), + ) + }) +} + +func Recovery(next http.Handler, logger *logging.Logger) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + logger.Error("Panic recovered", + "error", fmt.Sprintf("%v", err), + "stack", string(debug.Stack()), + ) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + }() + + next.ServeHTTP(w, r) + }) +} diff --git a/go/internal/middleware/middleware_test.go b/go/internal/middleware/middleware_test.go new file mode 100644 index 0000000..975cc30 --- /dev/null +++ b/go/internal/middleware/middleware_test.go @@ -0,0 +1,244 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/konduktor/konduktor/internal/logging" +) + +// ============== ServerHeader Tests ============== + +func TestServerHeader(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + wrapped := ServerHeader(handler, "1.0.0") + + req := httptest.NewRequest("GET", "/", nil) + rr := httptest.NewRecorder() + + wrapped.ServeHTTP(rr, req) + + serverHeader := rr.Header().Get("Server") + if serverHeader != "konduktor/1.0.0" { + t.Errorf("Expected Server header 'konduktor/1.0.0', got '%s'", serverHeader) + } +} + +// ============== AccessLog Tests ============== + +func TestAccessLog(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("Hello")) + }) + + logger, _ := logging.New(logging.Config{Level: "INFO"}) + wrapped := AccessLog(handler, logger) + + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + + wrapped.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } +} + +func TestAccessLog_CapturesStatusCode(t *testing.T) { + tests := []struct { + name string + statusCode int + }{ + {"OK", http.StatusOK}, + {"NotFound", http.StatusNotFound}, + {"InternalError", http.StatusInternalServerError}, + {"Redirect", http.StatusMovedPermanently}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.statusCode) + }) + + logger, _ := logging.New(logging.Config{Level: "INFO"}) + wrapped := AccessLog(handler, logger) + + req := httptest.NewRequest("GET", "/", nil) + rr := httptest.NewRecorder() + + wrapped.ServeHTTP(rr, req) + + if rr.Code != tt.statusCode { + t.Errorf("Expected status %d, got %d", tt.statusCode, rr.Code) + } + }) + } +} + +// ============== Recovery Tests ============== + +func TestRecovery_NoPanic(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + }) + + logger, _ := logging.New(logging.Config{Level: "INFO"}) + wrapped := Recovery(handler, logger) + + req := httptest.NewRequest("GET", "/", nil) + rr := httptest.NewRecorder() + + wrapped.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } + + if rr.Body.String() != "OK" { + t.Errorf("Expected body 'OK', got '%s'", rr.Body.String()) + } +} + +func TestRecovery_WithPanic(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("test panic") + }) + + logger, _ := logging.New(logging.Config{Level: "ERROR"}) + wrapped := Recovery(handler, logger) + + req := httptest.NewRequest("GET", "/", nil) + rr := httptest.NewRecorder() + + // Should not panic + wrapped.ServeHTTP(rr, req) + + if rr.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500, got %d", rr.Code) + } + + if !strings.Contains(rr.Body.String(), "Internal Server Error") { + t.Errorf("Expected 'Internal Server Error' in body, got '%s'", rr.Body.String()) + } +} + +// ============== responseWriter Tests ============== + +func TestResponseWriter_WriteHeader(t *testing.T) { + rr := httptest.NewRecorder() + rw := &responseWriter{ResponseWriter: rr, status: http.StatusOK} + + rw.WriteHeader(http.StatusNotFound) + + if rw.status != http.StatusNotFound { + t.Errorf("Expected status 404, got %d", rw.status) + } +} + +func TestResponseWriter_Write(t *testing.T) { + rr := httptest.NewRecorder() + rw := &responseWriter{ResponseWriter: rr, status: http.StatusOK} + + n, err := rw.Write([]byte("Hello World")) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if n != 11 { + t.Errorf("Expected 11 bytes written, got %d", n) + } + + if rw.size != 11 { + t.Errorf("Expected size 11, got %d", rw.size) + } +} + +// ============== Middleware Chain Tests ============== + +func TestMiddlewareChain(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + }) + + logger, _ := logging.New(logging.Config{Level: "INFO"}) + + // Apply middleware chain + wrapped := Recovery(AccessLog(ServerHeader(handler, "1.0.0"), logger), logger) + + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + + wrapped.ServeHTTP(rr, req) + + // Check all middleware worked + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } + + if rr.Header().Get("Server") != "konduktor/1.0.0" { + t.Errorf("Expected Server header") + } + + if rr.Body.String() != "OK" { + t.Errorf("Expected body 'OK', got '%s'", rr.Body.String()) + } +} + +// ============== Benchmarks ============== + +func BenchmarkServerHeader(b *testing.B) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + wrapped := ServerHeader(handler, "1.0.0") + req := httptest.NewRequest("GET", "/", nil) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rr := httptest.NewRecorder() + wrapped.ServeHTTP(rr, req) + } +} + +func BenchmarkAccessLog(b *testing.B) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + logger, _ := logging.New(logging.Config{Level: "ERROR"}) // Minimize logging overhead + wrapped := AccessLog(handler, logger) + req := httptest.NewRequest("GET", "/", nil) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rr := httptest.NewRecorder() + wrapped.ServeHTTP(rr, req) + } +} + +func BenchmarkRecovery(b *testing.B) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + logger, _ := logging.New(logging.Config{Level: "ERROR"}) + wrapped := Recovery(handler, logger) + req := httptest.NewRequest("GET", "/", nil) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rr := httptest.NewRecorder() + wrapped.ServeHTTP(rr, req) + } +} diff --git a/go/internal/pathmatcher/pathmatcher.go b/go/internal/pathmatcher/pathmatcher.go new file mode 100644 index 0000000..8abcf22 --- /dev/null +++ b/go/internal/pathmatcher/pathmatcher.go @@ -0,0 +1,263 @@ +package pathmatcher + +import ( + "strings" + "sync" +) + +type MountedPath struct { + path string + name string + stripPath bool +} + +func NewMountedPath(path string, opts ...MountedPathOption) *MountedPath { + // Normalize: remove trailing slash (except for root) + normalizedPath := strings.TrimSuffix(path, "/") + if normalizedPath == "" { + normalizedPath = "" + } + + m := &MountedPath{ + path: normalizedPath, + name: normalizedPath, + stripPath: true, + } + + for _, opt := range opts { + opt(m) + } + + if m.name == "" { + m.name = normalizedPath + } + + return m +} + +type MountedPathOption func(*MountedPath) + +func WithName(name string) MountedPathOption { + return func(m *MountedPath) { + m.name = name + } +} + +func WithStripPath(strip bool) MountedPathOption { + return func(m *MountedPath) { + m.stripPath = strip + } +} + +func (m *MountedPath) Path() string { + return m.path +} + +func (m *MountedPath) Name() string { + return m.name +} + +func (m *MountedPath) StripPath() bool { + return m.stripPath +} + +func (m *MountedPath) Matches(requestPath string) bool { + // Empty or "/" mount matches everything + if m.path == "" || m.path == "/" { + return true + } + + // Request path must be at least as long as mount path + if len(requestPath) < len(m.path) { + return false + } + + // Check if request path starts with mount path + if !strings.HasPrefix(requestPath, m.path) { + return false + } + + // If paths are equal length, it's a match + if len(requestPath) == len(m.path) { + return true + } + + // Otherwise, next char must be '/' to prevent /api matching /api-v2 + return requestPath[len(m.path)] == '/' +} + +func (m *MountedPath) GetModifiedPath(requestPath string) string { + if !m.stripPath { + return requestPath + } + + // Root mount doesn't strip anything + if m.path == "" || m.path == "/" { + return requestPath + } + + // Strip the prefix + modified := strings.TrimPrefix(requestPath, m.path) + + // Ensure result starts with / + if modified == "" || modified[0] != '/' { + modified = "/" + modified + } + + return modified +} + +type MountManager struct { + mounts []*MountedPath + mu sync.RWMutex +} + +func NewMountManager() *MountManager { + return &MountManager{ + mounts: make([]*MountedPath, 0), + } +} + +func (mm *MountManager) AddMount(mount *MountedPath) { + mm.mu.Lock() + defer mm.mu.Unlock() + + // Insert in sorted order (longer paths first) + inserted := false + for i, existing := range mm.mounts { + if len(mount.path) > len(existing.path) { + // Insert at position i + mm.mounts = append(mm.mounts[:i], append([]*MountedPath{mount}, mm.mounts[i:]...)...) + inserted = true + break + } + } + + if !inserted { + mm.mounts = append(mm.mounts, mount) + } +} + +func (mm *MountManager) RemoveMount(path string) bool { + mm.mu.Lock() + defer mm.mu.Unlock() + + normalizedPath := strings.TrimSuffix(path, "/") + + for i, mount := range mm.mounts { + if mount.path == normalizedPath { + mm.mounts = append(mm.mounts[:i], mm.mounts[i+1:]...) + return true + } + } + + return false +} + +func (mm *MountManager) GetMount(requestPath string) *MountedPath { + mm.mu.RLock() + defer mm.mu.RUnlock() + + // Mounts are sorted by path length (longest first) + // so the first match is the best match + for _, mount := range mm.mounts { + if mount.Matches(requestPath) { + return mount + } + } + + return nil +} + +func (mm *MountManager) MountCount() int { + mm.mu.RLock() + defer mm.mu.RUnlock() + return len(mm.mounts) +} + +func (mm *MountManager) Mounts() []*MountedPath { + mm.mu.RLock() + defer mm.mu.RUnlock() + + result := make([]*MountedPath, len(mm.mounts)) + copy(result, mm.mounts) + return result +} + +func (mm *MountManager) ListMounts() []map[string]interface{} { + mm.mu.RLock() + defer mm.mu.RUnlock() + + result := make([]map[string]interface{}, len(mm.mounts)) + for i, mount := range mm.mounts { + result[i] = map[string]interface{}{ + "path": mount.path, + "name": mount.name, + "strip_path": mount.stripPath, + } + } + + return result +} + +// Utility functions + +func PathMatchesPrefix(requestPath, prefix string) bool { + // Normalize prefix + prefix = strings.TrimSuffix(prefix, "/") + + // Empty or "/" prefix matches everything + if prefix == "" || prefix == "/" { + return true + } + + // Request path must be at least as long as prefix + if len(requestPath) < len(prefix) { + return false + } + + // Check if request path starts with prefix + if !strings.HasPrefix(requestPath, prefix) { + return false + } + + // If paths are equal length, it's a match + if len(requestPath) == len(prefix) { + return true + } + + // Otherwise, next char must be '/' + return requestPath[len(prefix)] == '/' +} + +func StripPathPrefix(requestPath, prefix string) string { + // Normalize prefix + prefix = strings.TrimSuffix(prefix, "/") + + // Empty or "/" prefix doesn't strip anything + if prefix == "" || prefix == "/" { + return requestPath + } + + // Strip the prefix + modified := strings.TrimPrefix(requestPath, prefix) + + // Ensure result starts with / + if modified == "" || modified[0] != '/' { + modified = "/" + modified + } + + return modified +} + +func MatchAndModifyPath(requestPath, prefix string, stripPath bool) (matches bool, modifiedPath string) { + if !PathMatchesPrefix(requestPath, prefix) { + return false, "" + } + + if stripPath { + return true, StripPathPrefix(requestPath, prefix) + } + + return true, requestPath +} diff --git a/go/internal/pathmatcher/pathmatcher_test.go b/go/internal/pathmatcher/pathmatcher_test.go new file mode 100644 index 0000000..45b3128 --- /dev/null +++ b/go/internal/pathmatcher/pathmatcher_test.go @@ -0,0 +1,460 @@ +package pathmatcher + +import ( + "testing" +) + +// ============== MountedPath Tests ============== + +func TestMountedPath_RootMountMatchesEverything(t *testing.T) { + mount := NewMountedPath("") + + tests := []string{"/", "/api", "/api/users", "/anything/at/all"} + + for _, path := range tests { + if !mount.Matches(path) { + t.Errorf("Root mount should match %s", path) + } + } +} + +func TestMountedPath_SlashRootMountMatchesEverything(t *testing.T) { + mount := NewMountedPath("/") + + tests := []string{"/", "/api", "/api/users"} + + for _, path := range tests { + if !mount.Matches(path) { + t.Errorf("'/' mount should match %s", path) + } + } +} + +func TestMountedPath_ExactPathMatch(t *testing.T) { + mount := NewMountedPath("/api") + + tests := []struct { + path string + expected bool + }{ + {"/api", true}, + {"/api/", true}, + {"/api/users", true}, + } + + for _, tt := range tests { + if got := mount.Matches(tt.path); got != tt.expected { + t.Errorf("Matches(%s) = %v, want %v", tt.path, got, tt.expected) + } + } +} + +func TestMountedPath_NoFalsePrefixMatch(t *testing.T) { + mount := NewMountedPath("/api") + + tests := []string{"/api-v2", "/api2", "/apiv2"} + + for _, path := range tests { + if mount.Matches(path) { + t.Errorf("/api should not match %s", path) + } + } +} + +func TestMountedPath_ShorterPathNoMatch(t *testing.T) { + mount := NewMountedPath("/api/v1") + + tests := []string{"/api", "/ap", "/"} + + for _, path := range tests { + if mount.Matches(path) { + t.Errorf("/api/v1 should not match shorter path %s", path) + } + } +} + +func TestMountedPath_TrailingSlashNormalized(t *testing.T) { + mount1 := NewMountedPath("/api/") + mount2 := NewMountedPath("/api") + + if mount1.Path() != "/api" { + t.Errorf("Expected path /api, got %s", mount1.Path()) + } + + if mount2.Path() != "/api" { + t.Errorf("Expected path /api, got %s", mount2.Path()) + } + + if !mount1.Matches("/api/users") { + t.Error("mount1 should match /api/users") + } + + if !mount2.Matches("/api/users") { + t.Error("mount2 should match /api/users") + } +} + +func TestMountedPath_GetModifiedPathStripsPrefix(t *testing.T) { + mount := NewMountedPath("/api") + + tests := []struct { + input string + expected string + }{ + {"/api", "/"}, + {"/api/", "/"}, + {"/api/users", "/users"}, + {"/api/users/123", "/users/123"}, + } + + for _, tt := range tests { + if got := mount.GetModifiedPath(tt.input); got != tt.expected { + t.Errorf("GetModifiedPath(%s) = %s, want %s", tt.input, got, tt.expected) + } + } +} + +func TestMountedPath_GetModifiedPathNoStrip(t *testing.T) { + mount := NewMountedPath("/api", WithStripPath(false)) + + tests := []struct { + input string + expected string + }{ + {"/api/users", "/api/users"}, + {"/api", "/api"}, + } + + for _, tt := range tests { + if got := mount.GetModifiedPath(tt.input); got != tt.expected { + t.Errorf("GetModifiedPath(%s) = %s, want %s", tt.input, got, tt.expected) + } + } +} + +func TestMountedPath_RootMountModifiedPath(t *testing.T) { + mount := NewMountedPath("") + + tests := []struct { + input string + expected string + }{ + {"/api/users", "/api/users"}, + {"/", "/"}, + } + + for _, tt := range tests { + if got := mount.GetModifiedPath(tt.input); got != tt.expected { + t.Errorf("GetModifiedPath(%s) = %s, want %s", tt.input, got, tt.expected) + } + } +} + +func TestMountedPath_NameProperty(t *testing.T) { + mount1 := NewMountedPath("/api") + mount2 := NewMountedPath("/api", WithName("API Mount")) + + if mount1.Name() != "/api" { + t.Errorf("Expected name /api, got %s", mount1.Name()) + } + + if mount2.Name() != "API Mount" { + t.Errorf("Expected name 'API Mount', got %s", mount2.Name()) + } +} + +// ============== MountManager Tests ============== + +func TestMountManager_EmptyManager(t *testing.T) { + manager := NewMountManager() + + if got := manager.GetMount("/api"); got != nil { + t.Error("Empty manager should return nil") + } + + if got := manager.MountCount(); got != 0 { + t.Errorf("Expected mount count 0, got %d", got) + } +} + +func TestMountManager_AddMount(t *testing.T) { + manager := NewMountManager() + mount := NewMountedPath("/api") + + manager.AddMount(mount) + + if manager.MountCount() != 1 { + t.Errorf("Expected mount count 1, got %d", manager.MountCount()) + } + + if got := manager.GetMount("/api/users"); got != mount { + t.Error("GetMount should return the added mount") + } +} + +func TestMountManager_LongestPrefixMatching(t *testing.T) { + manager := NewMountManager() + + apiMount := NewMountedPath("/api", WithName("api")) + apiV1Mount := NewMountedPath("/api/v1", WithName("api_v1")) + apiV2Mount := NewMountedPath("/api/v2", WithName("api_v2")) + + manager.AddMount(apiMount) + manager.AddMount(apiV2Mount) + manager.AddMount(apiV1Mount) + + tests := []struct { + path string + expectedName string + }{ + {"/api/v1/users", "api_v1"}, + {"/api/v2/items", "api_v2"}, + {"/api/v3/other", "api"}, + {"/api", "api"}, + } + + for _, tt := range tests { + got := manager.GetMount(tt.path) + if got == nil { + t.Errorf("GetMount(%s) returned nil, want mount with name %s", tt.path, tt.expectedName) + continue + } + if got.Name() != tt.expectedName { + t.Errorf("GetMount(%s).Name() = %s, want %s", tt.path, got.Name(), tt.expectedName) + } + } +} + +func TestMountManager_RemoveMount(t *testing.T) { + manager := NewMountManager() + + manager.AddMount(NewMountedPath("/api")) + manager.AddMount(NewMountedPath("/admin")) + + if manager.MountCount() != 2 { + t.Errorf("Expected mount count 2, got %d", manager.MountCount()) + } + + result := manager.RemoveMount("/api") + + if !result { + t.Error("RemoveMount should return true") + } + + if manager.MountCount() != 1 { + t.Errorf("Expected mount count 1, got %d", manager.MountCount()) + } + + if manager.GetMount("/api/users") != nil { + t.Error("GetMount(/api/users) should return nil after removal") + } + + if manager.GetMount("/admin/users") == nil { + t.Error("GetMount(/admin/users) should still work") + } +} + +func TestMountManager_RemoveNonexistentMount(t *testing.T) { + manager := NewMountManager() + + result := manager.RemoveMount("/api") + + if result { + t.Error("RemoveMount should return false for nonexistent mount") + } +} + +func TestMountManager_ListMounts(t *testing.T) { + manager := NewMountManager() + + manager.AddMount(NewMountedPath("/api", WithName("API"))) + manager.AddMount(NewMountedPath("/admin", WithName("Admin"))) + + mounts := manager.ListMounts() + + if len(mounts) != 2 { + t.Errorf("Expected 2 mounts, got %d", len(mounts)) + } + + for _, m := range mounts { + if _, ok := m["path"]; !ok { + t.Error("Mount should have 'path' key") + } + if _, ok := m["name"]; !ok { + t.Error("Mount should have 'name' key") + } + if _, ok := m["strip_path"]; !ok { + t.Error("Mount should have 'strip_path' key") + } + } +} + +func TestMountManager_MountsReturnsCopy(t *testing.T) { + manager := NewMountManager() + manager.AddMount(NewMountedPath("/api")) + + mounts1 := manager.Mounts() + mounts2 := manager.Mounts() + + if &mounts1[0] == &mounts2[0] { + t.Error("Mounts() should return different slices") + } +} + +// ============== Utility Functions Tests ============== + +func TestPathMatchesPrefix_Basic(t *testing.T) { + tests := []struct { + path string + prefix string + expected bool + }{ + {"/api/users", "/api", true}, + {"/api", "/api", true}, + {"/api-v2", "/api", false}, + {"/ap", "/api", false}, + } + + for _, tt := range tests { + if got := PathMatchesPrefix(tt.path, tt.prefix); got != tt.expected { + t.Errorf("PathMatchesPrefix(%s, %s) = %v, want %v", tt.path, tt.prefix, got, tt.expected) + } + } +} + +func TestPathMatchesPrefix_Root(t *testing.T) { + tests := []struct { + path string + prefix string + expected bool + }{ + {"/anything", "", true}, + {"/anything", "/", true}, + } + + for _, tt := range tests { + if got := PathMatchesPrefix(tt.path, tt.prefix); got != tt.expected { + t.Errorf("PathMatchesPrefix(%s, %s) = %v, want %v", tt.path, tt.prefix, got, tt.expected) + } + } +} + +func TestStripPathPrefix_Basic(t *testing.T) { + tests := []struct { + path string + prefix string + expected string + }{ + {"/api/users", "/api", "/users"}, + {"/api", "/api", "/"}, + {"/api/", "/api", "/"}, + } + + for _, tt := range tests { + if got := StripPathPrefix(tt.path, tt.prefix); got != tt.expected { + t.Errorf("StripPathPrefix(%s, %s) = %s, want %s", tt.path, tt.prefix, got, tt.expected) + } + } +} + +func TestStripPathPrefix_Root(t *testing.T) { + tests := []struct { + path string + prefix string + expected string + }{ + {"/api/users", "", "/api/users"}, + {"/api/users", "/", "/api/users"}, + } + + for _, tt := range tests { + if got := StripPathPrefix(tt.path, tt.prefix); got != tt.expected { + t.Errorf("StripPathPrefix(%s, %s) = %s, want %s", tt.path, tt.prefix, got, tt.expected) + } + } +} + +func TestMatchAndModifyPath_Combined(t *testing.T) { + tests := []struct { + path string + prefix string + stripPath bool + wantMatches bool + wantModified string + }{ + {"/api/users", "/api", true, true, "/users"}, + {"/api", "/api", true, true, "/"}, + {"/other", "/api", true, false, ""}, + {"/api/users", "/api", false, true, "/api/users"}, + } + + for _, tt := range tests { + matches, modified := MatchAndModifyPath(tt.path, tt.prefix, tt.stripPath) + if matches != tt.wantMatches { + t.Errorf("MatchAndModifyPath(%s, %s, %v) matches = %v, want %v", + tt.path, tt.prefix, tt.stripPath, matches, tt.wantMatches) + } + if modified != tt.wantModified { + t.Errorf("MatchAndModifyPath(%s, %s, %v) modified = %s, want %s", + tt.path, tt.prefix, tt.stripPath, modified, tt.wantModified) + } + } +} + +// ============== Performance Tests ============== + +func TestPerformance_ManyMatches(t *testing.T) { + mount := NewMountedPath("/api/v1/users") + + for i := 0; i < 10000; i++ { + if !mount.Matches("/api/v1/users/123/posts") { + t.Fatal("Should match") + } + if mount.Matches("/other/path") { + t.Fatal("Should not match") + } + } +} + +func TestPerformance_ManyMounts(t *testing.T) { + manager := NewMountManager() + + for i := 0; i < 100; i++ { + manager.AddMount(NewMountedPath("/api/v" + string(rune('0'+i%10)) + string(rune('0'+i/10)))) + } + + if manager.MountCount() != 100 { + t.Errorf("Expected 100 mounts, got %d", manager.MountCount()) + } +} + +// ============== Benchmarks ============== + +func BenchmarkMountedPath_Matches(b *testing.B) { + mount := NewMountedPath("/api/v1") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + mount.Matches("/api/v1/users/123") + } +} + +func BenchmarkMountManager_GetMount(b *testing.B) { + manager := NewMountManager() + + for i := 0; i < 20; i++ { + manager.AddMount(NewMountedPath("/api/v" + string(rune('0'+i%10)))) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.GetMount("/api/v5/users/123") + } +} + +func BenchmarkPathMatchesPrefix(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + PathMatchesPrefix("/api/v1/users/123", "/api/v1") + } +} diff --git a/go/internal/proxy/proxy.go b/go/internal/proxy/proxy.go new file mode 100644 index 0000000..5fb02b1 --- /dev/null +++ b/go/internal/proxy/proxy.go @@ -0,0 +1,320 @@ +// Package proxy provides reverse proxy functionality for Konduktor +package proxy + +import ( + "context" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/konduktor/konduktor/internal/logging" +) + +type Config struct { + // Target is the backend server URL + Target string + + // Timeout is the request timeout (default: 30s) + Timeout time.Duration + + // Headers are additional headers to add to requests + Headers map[string]string + + // StripPrefix removes this prefix from the request path + StripPrefix string + + // PreserveHost keeps the original Host header + PreserveHost bool +} + +type ReverseProxy struct { + config *Config + targetURL *url.URL + httpClient *http.Client + logger *logging.Logger +} + +func New(cfg *Config, logger *logging.Logger) (*ReverseProxy, error) { + if cfg.Target == "" { + return nil, fmt.Errorf("proxy target is required") + } + + targetURL, err := url.Parse(cfg.Target) + if err != nil { + return nil, fmt.Errorf("invalid proxy target URL: %w", err) + } + + timeout := cfg.Timeout + if timeout == 0 { + timeout = 30 * time.Second + } + + transport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + ResponseHeaderTimeout: timeout, + } + + return &ReverseProxy{ + config: cfg, + targetURL: targetURL, + httpClient: &http.Client{ + Transport: transport, + Timeout: timeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse // Don't follow redirects + }, + }, + logger: logger, + }, nil +} + +func (rp *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + rp.ProxyRequest(w, r, nil) +} + +func (rp *ReverseProxy) ProxyRequest(w http.ResponseWriter, r *http.Request, params map[string]string) { + ctx := r.Context() + + // Build target URL + targetURL := rp.buildTargetURL(r) + + // Create proxy request + proxyReq, err := rp.createProxyRequest(ctx, r, targetURL) + if err != nil { + rp.handleError(w, http.StatusInternalServerError, "Failed to create proxy request", err) + return + } + + // Add custom headers with parameter substitution + rp.addCustomHeaders(proxyReq, r, params) + + // Execute request + resp, err := rp.httpClient.Do(proxyReq) + if err != nil { + rp.handleProxyError(w, err) + return + } + defer resp.Body.Close() + + // Copy response + rp.copyResponse(w, resp) +} + +func (rp *ReverseProxy) buildTargetURL(r *http.Request) *url.URL { + targetURL := *rp.targetURL + + // Strip prefix if configured + path := r.URL.Path + if rp.config.StripPrefix != "" { + path = strings.TrimPrefix(path, rp.config.StripPrefix) + if path == "" || path[0] != '/' { + path = "/" + path + } + } + + // If target has a path, append request path to it + if rp.targetURL.Path != "" && rp.targetURL.Path != "/" { + targetURL.Path = singleJoiningSlash(rp.targetURL.Path, path) + } else { + targetURL.Path = path + } + + // Preserve query string + targetURL.RawQuery = r.URL.RawQuery + + return &targetURL +} + +func (rp *ReverseProxy) createProxyRequest(ctx context.Context, r *http.Request, targetURL *url.URL) (*http.Request, error) { + proxyReq, err := http.NewRequestWithContext(ctx, r.Method, targetURL.String(), r.Body) + if err != nil { + return nil, err + } + + // Copy ContentLength + proxyReq.ContentLength = r.ContentLength + + // Copy headers + for key, values := range r.Header { + for _, value := range values { + proxyReq.Header.Add(key, value) + } + } + + // Set/update Host header + if rp.config.PreserveHost { + proxyReq.Host = r.Host + } else { + proxyReq.Host = targetURL.Host + } + + // Remove hop-by-hop headers + removeHopByHopHeaders(proxyReq.Header) + + return proxyReq, nil +} + +func (rp *ReverseProxy) addCustomHeaders(proxyReq *http.Request, originalReq *http.Request, params map[string]string) { + // Add X-Forwarded headers + clientIP := getClientIP(originalReq) + if prior := originalReq.Header.Get("X-Forwarded-For"); prior != "" { + clientIP = prior + ", " + clientIP + } + proxyReq.Header.Set("X-Forwarded-For", clientIP) + proxyReq.Header.Set("X-Forwarded-Proto", getScheme(originalReq)) + proxyReq.Header.Set("X-Forwarded-Host", originalReq.Host) + + // Add custom headers from config + for key, value := range rp.config.Headers { + // Substitute parameters like {version} + substituted := value + for paramKey, paramValue := range params { + substituted = strings.ReplaceAll(substituted, "{"+paramKey+"}", paramValue) + } + // Substitute $remote_addr + substituted = strings.ReplaceAll(substituted, "$remote_addr", clientIP) + proxyReq.Header.Set(key, substituted) + } +} + +func (rp *ReverseProxy) copyResponse(w http.ResponseWriter, resp *http.Response) { + // Copy headers + for key, values := range resp.Header { + for _, value := range values { + w.Header().Add(key, value) + } + } + + // Remove hop-by-hop headers from response + removeHopByHopHeaders(w.Header()) + + // Write status code + w.WriteHeader(resp.StatusCode) + + // Copy body + io.Copy(w, resp.Body) +} + +func (rp *ReverseProxy) handleError(w http.ResponseWriter, status int, message string, err error) { + if rp.logger != nil { + rp.logger.Error(message, "error", err) + } + http.Error(w, message, status) +} + +func (rp *ReverseProxy) handleProxyError(w http.ResponseWriter, err error) { + if rp.logger != nil { + rp.logger.Error("Proxy request failed", "error", err) + } + + // Check for timeout + if err, ok := err.(net.Error); ok && err.Timeout() { + http.Error(w, "504 Gateway Timeout", http.StatusGatewayTimeout) + return + } + + // Check for connection errors + if isConnectionError(err) { + http.Error(w, "502 Bad Gateway", http.StatusBadGateway) + return + } + + // Context cancelled (client disconnected) + if err == context.Canceled { + return + } + + http.Error(w, "502 Bad Gateway", http.StatusBadGateway) +} + +// Helper functions + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +func removeHopByHopHeaders(h http.Header) { + hopByHopHeaders := []string{ + "Connection", + "Proxy-Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", + "Trailer", + "Transfer-Encoding", + "Upgrade", + } + + for _, header := range hopByHopHeaders { + h.Del(header) + } +} + +func getClientIP(r *http.Request) string { + // Check X-Real-IP first + if ip := r.Header.Get("X-Real-IP"); ip != "" { + return ip + } + + // Get from RemoteAddr + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return host +} + +func getScheme(r *http.Request) string { + if r.TLS != nil { + return "https" + } + if scheme := r.Header.Get("X-Forwarded-Proto"); scheme != "" { + return scheme + } + return "http" +} + +func isConnectionError(err error) bool { + if err == nil { + return false + } + + errStr := err.Error() + connectionErrors := []string{ + "connection refused", + "no such host", + "network is unreachable", + "connection reset", + "broken pipe", + } + + for _, connErr := range connectionErrors { + if strings.Contains(strings.ToLower(errStr), connErr) { + return true + } + } + + return false +} diff --git a/go/internal/proxy/proxy_test.go b/go/internal/proxy/proxy_test.go new file mode 100644 index 0000000..191932f --- /dev/null +++ b/go/internal/proxy/proxy_test.go @@ -0,0 +1,747 @@ +package proxy + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +// ============== Test Backend Server ============== + +type testBackend struct { + server *httptest.Server + requestLog []requestLogEntry + mu sync.Mutex + requestCount int64 +} + +type requestLogEntry struct { + Method string + Path string + Query string + Headers http.Header + Body string +} + +func newTestBackend(handler http.HandlerFunc) *testBackend { + tb := &testBackend{ + requestLog: make([]requestLogEntry, 0), + } + + if handler == nil { + handler = tb.defaultHandler + } + + tb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tb.logRequest(r) + handler(w, r) + })) + + return tb +} + +func (tb *testBackend) logRequest(r *http.Request) { + tb.mu.Lock() + defer tb.mu.Unlock() + + body, _ := io.ReadAll(r.Body) + // Restore the body for the handler + r.Body = io.NopCloser(bytes.NewReader(body)) + + tb.requestLog = append(tb.requestLog, requestLogEntry{ + Method: r.Method, + Path: r.URL.Path, + Query: r.URL.RawQuery, + Headers: r.Header.Clone(), + Body: string(body), + }) + atomic.AddInt64(&tb.requestCount, 1) +} + +func (tb *testBackend) defaultHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "message": "Backend response", + "path": r.URL.Path, + "method": r.Method, + }) +} + +func (tb *testBackend) close() { + tb.server.Close() +} + +func (tb *testBackend) URL() string { + return tb.server.URL +} + +func (tb *testBackend) getRequestCount() int64 { + return atomic.LoadInt64(&tb.requestCount) +} + +func (tb *testBackend) getLastRequest() *requestLogEntry { + tb.mu.Lock() + defer tb.mu.Unlock() + if len(tb.requestLog) == 0 { + return nil + } + return &tb.requestLog[len(tb.requestLog)-1] +} + +// ============== Proxy Creation Tests ============== + +func TestNew_ValidConfig(t *testing.T) { + cfg := &Config{ + Target: "http://localhost:8080", + Timeout: 10 * time.Second, + } + + proxy, err := New(cfg, nil) + if err != nil { + t.Fatalf("Failed to create proxy: %v", err) + } + + if proxy == nil { + t.Fatal("Expected proxy instance") + } +} + +func TestNew_EmptyTarget(t *testing.T) { + cfg := &Config{ + Target: "", + } + + _, err := New(cfg, nil) + if err == nil { + t.Error("Expected error for empty target") + } +} + +func TestNew_InvalidTargetURL(t *testing.T) { + cfg := &Config{ + Target: "://invalid-url", + } + + _, err := New(cfg, nil) + if err == nil { + t.Error("Expected error for invalid URL") + } +} + +func TestNew_DefaultTimeout(t *testing.T) { + cfg := &Config{ + Target: "http://localhost:8080", + } + + proxy, err := New(cfg, nil) + if err != nil { + t.Fatalf("Failed to create proxy: %v", err) + } + + if proxy.httpClient.Timeout != 30*time.Second { + t.Errorf("Expected default timeout 30s, got %v", proxy.httpClient.Timeout) + } +} + +// ============== Basic Proxy Tests ============== + +func TestProxy_BasicGET(t *testing.T) { + backend := newTestBackend(nil) + defer backend.close() + + proxy, _ := New(&Config{Target: backend.URL()}, nil) + + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + + proxy.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } + + var response map[string]interface{} + json.NewDecoder(rr.Body).Decode(&response) + + if response["path"] != "/test" { + t.Errorf("Expected path /test, got %v", response["path"]) + } +} + +func TestProxy_BasicPOST(t *testing.T) { + backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "received": string(body), + "method": r.Method, + }) + }) + defer backend.close() + + proxy, _ := New(&Config{Target: backend.URL()}, nil) + + req := httptest.NewRequest("POST", "/api/data", strings.NewReader(`{"key":"value"}`)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + + proxy.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } + + var response map[string]interface{} + json.NewDecoder(rr.Body).Decode(&response) + + if response["method"] != "POST" { + t.Errorf("Expected method POST, got %v", response["method"]) + } +} + +func TestProxy_PUT(t *testing.T) { + backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"method": r.Method}) + }) + defer backend.close() + + proxy, _ := New(&Config{Target: backend.URL()}, nil) + + req := httptest.NewRequest("PUT", "/resource/123", strings.NewReader(`{"name":"updated"}`)) + rr := httptest.NewRecorder() + + proxy.ServeHTTP(rr, req) + + var response map[string]string + json.NewDecoder(rr.Body).Decode(&response) + + if response["method"] != "PUT" { + t.Errorf("Expected method PUT, got %v", response["method"]) + } +} + +func TestProxy_DELETE(t *testing.T) { + backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"method": r.Method}) + }) + defer backend.close() + + proxy, _ := New(&Config{Target: backend.URL()}, nil) + + req := httptest.NewRequest("DELETE", "/resource/123", nil) + rr := httptest.NewRecorder() + + proxy.ServeHTTP(rr, req) + + var response map[string]string + json.NewDecoder(rr.Body).Decode(&response) + + if response["method"] != "DELETE" { + t.Errorf("Expected method DELETE, got %v", response["method"]) + } +} + +// ============== Header Tests ============== + +func TestProxy_HeadersForwarding(t *testing.T) { + backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]interface{}{ + "custom_header": r.Header.Get("X-Custom-Header"), + "forwarded_for": r.Header.Get("X-Forwarded-For"), + "forwarded_host": r.Header.Get("X-Forwarded-Host"), + }) + }) + defer backend.close() + + proxy, _ := New(&Config{Target: backend.URL()}, nil) + + req := httptest.NewRequest("GET", "/headers", nil) + req.Header.Set("X-Custom-Header", "test-value") + req.RemoteAddr = "192.168.1.100:12345" + req.Host = "example.com" + rr := httptest.NewRecorder() + + proxy.ServeHTTP(rr, req) + + var response map[string]interface{} + json.NewDecoder(rr.Body).Decode(&response) + + if response["custom_header"] != "test-value" { + t.Errorf("Expected custom header, got %v", response["custom_header"]) + } + + if response["forwarded_for"] != "192.168.1.100" { + t.Errorf("Expected X-Forwarded-For, got %v", response["forwarded_for"]) + } + + if response["forwarded_host"] != "example.com" { + t.Errorf("Expected X-Forwarded-Host, got %v", response["forwarded_host"]) + } +} + +func TestProxy_CustomHeaders(t *testing.T) { + backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "api_version": r.Header.Get("X-API-Version"), + }) + }) + defer backend.close() + + proxy, _ := New(&Config{ + Target: backend.URL(), + Headers: map[string]string{ + "X-API-Version": "{version}", + }, + }, nil) + + req := httptest.NewRequest("GET", "/api", nil) + rr := httptest.NewRecorder() + + // Simulate parameter substitution + proxy.ProxyRequest(rr, req, map[string]string{"version": "2"}) + + var response map[string]string + json.NewDecoder(rr.Body).Decode(&response) + + if response["api_version"] != "2" { + t.Errorf("Expected API version 2, got %v", response["api_version"]) + } +} + +func TestProxy_RemoteAddrSubstitution(t *testing.T) { + backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "client_ip": r.Header.Get("X-Client-IP"), + }) + }) + defer backend.close() + + proxy, _ := New(&Config{ + Target: backend.URL(), + Headers: map[string]string{ + "X-Client-IP": "$remote_addr", + }, + }, nil) + + req := httptest.NewRequest("GET", "/api", nil) + req.RemoteAddr = "10.0.0.1:54321" + rr := httptest.NewRecorder() + + proxy.ServeHTTP(rr, req) + + var response map[string]string + json.NewDecoder(rr.Body).Decode(&response) + + if response["client_ip"] != "10.0.0.1" { + t.Errorf("Expected client IP 10.0.0.1, got %v", response["client_ip"]) + } +} + +// ============== Query String Tests ============== + +func TestProxy_QueryStringPreservation(t *testing.T) { + backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "query": r.URL.RawQuery, + "param": r.URL.Query().Get("key"), + }) + }) + defer backend.close() + + proxy, _ := New(&Config{Target: backend.URL()}, nil) + + req := httptest.NewRequest("GET", "/search?key=value&page=2", nil) + rr := httptest.NewRecorder() + + proxy.ServeHTTP(rr, req) + + var response map[string]string + json.NewDecoder(rr.Body).Decode(&response) + + if response["param"] != "value" { + t.Errorf("Expected query param 'value', got %v", response["param"]) + } +} + +// ============== Status Code Tests ============== + +func TestProxy_StatusCodePreservation(t *testing.T) { + statusCodes := []int{200, 201, 400, 404, 500} + + for _, code := range statusCodes { + code := code // capture range variable + t.Run(fmt.Sprintf("Status_%d", code), func(t *testing.T) { + backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(code) + w.Write([]byte("OK")) + }) + defer backend.close() + + proxy, _ := New(&Config{Target: backend.URL()}, nil) + + req := httptest.NewRequest("GET", "/status", nil) + rr := httptest.NewRecorder() + + proxy.ServeHTTP(rr, req) + + if rr.Code != code { + t.Errorf("Expected status %d, got %d", code, rr.Code) + } + }) + } +} + +// ============== Error Handling Tests ============== + +func TestProxy_BackendUnavailable(t *testing.T) { + // Use a port that's definitely not listening + proxy, _ := New(&Config{ + Target: "http://127.0.0.1:59999", + Timeout: 1 * time.Second, + }, nil) + + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + + proxy.ServeHTTP(rr, req) + + if rr.Code != http.StatusBadGateway { + t.Errorf("Expected status 502 Bad Gateway, got %d", rr.Code) + } +} + +func TestProxy_Timeout(t *testing.T) { + backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(2 * time.Second) + w.Write([]byte("OK")) + }) + defer backend.close() + + proxy, _ := New(&Config{ + Target: backend.URL(), + Timeout: 100 * time.Millisecond, + }, nil) + + req := httptest.NewRequest("GET", "/slow", nil) + rr := httptest.NewRecorder() + + proxy.ServeHTTP(rr, req) + + if rr.Code != http.StatusGatewayTimeout { + t.Errorf("Expected status 504 Gateway Timeout, got %d", rr.Code) + } +} + +// ============== Path Handling Tests ============== + +func TestProxy_StripPrefix(t *testing.T) { + backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "path": r.URL.Path, + }) + }) + defer backend.close() + + proxy, _ := New(&Config{ + Target: backend.URL(), + StripPrefix: "/api/v1", + }, nil) + + req := httptest.NewRequest("GET", "/api/v1/users", nil) + rr := httptest.NewRecorder() + + proxy.ServeHTTP(rr, req) + + var response map[string]string + json.NewDecoder(rr.Body).Decode(&response) + + if response["path"] != "/users" { + t.Errorf("Expected stripped path /users, got %v", response["path"]) + } +} + +func TestProxy_TargetWithPath(t *testing.T) { + backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "path": r.URL.Path, + }) + }) + defer backend.close() + + proxy, _ := New(&Config{ + Target: backend.URL() + "/backend", + }, nil) + + req := httptest.NewRequest("GET", "/resource", nil) + rr := httptest.NewRecorder() + + proxy.ServeHTTP(rr, req) + + var response map[string]string + json.NewDecoder(rr.Body).Decode(&response) + + if response["path"] != "/backend/resource" { + t.Errorf("Expected path /backend/resource, got %v", response["path"]) + } +} + +// ============== Large Body Tests ============== + +func TestProxy_LargeRequestBody(t *testing.T) { + backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + json.NewEncoder(w).Encode(map[string]int{ + "received_bytes": len(body), + }) + }) + defer backend.close() + + proxy, _ := New(&Config{Target: backend.URL()}, nil) + + // 100KB body + largeBody := strings.Repeat("x", 100000) + req := httptest.NewRequest("POST", "/upload", strings.NewReader(largeBody)) + req.ContentLength = int64(len(largeBody)) + rr := httptest.NewRecorder() + + proxy.ServeHTTP(rr, req) + + var response map[string]int + json.NewDecoder(rr.Body).Decode(&response) + + if response["received_bytes"] != 100000 { + t.Errorf("Expected 100000 bytes, got %d", response["received_bytes"]) + } +} + +func TestProxy_LargeResponseBody(t *testing.T) { + largeResponse := strings.Repeat("y", 100000) + + backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(largeResponse)) + }) + defer backend.close() + + proxy, _ := New(&Config{Target: backend.URL()}, nil) + + req := httptest.NewRequest("GET", "/large", nil) + rr := httptest.NewRecorder() + + proxy.ServeHTTP(rr, req) + + if rr.Body.Len() != 100000 { + t.Errorf("Expected 100000 bytes in response, got %d", rr.Body.Len()) + } +} + +// ============== Concurrent Requests Tests ============== + +func TestProxy_ConcurrentRequests(t *testing.T) { + backend := newTestBackend(nil) + defer backend.close() + + proxy, _ := New(&Config{Target: backend.URL()}, nil) + + const numRequests = 50 + var wg sync.WaitGroup + errors := make(chan error, numRequests) + + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + + req := httptest.NewRequest("GET", "/concurrent", nil) + rr := httptest.NewRecorder() + + proxy.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + errors <- &net.OpError{Op: "test", Err: context.DeadlineExceeded} + } + }(i) + } + + wg.Wait() + close(errors) + + errorCount := 0 + for range errors { + errorCount++ + } + + if errorCount > 0 { + t.Errorf("Got %d errors in concurrent requests", errorCount) + } + + if backend.getRequestCount() != numRequests { + t.Errorf("Expected %d requests at backend, got %d", numRequests, backend.getRequestCount()) + } +} + +// ============== Echo Tests ============== + +func TestProxy_Echo(t *testing.T) { + backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + w.Header().Set("Content-Type", r.Header.Get("Content-Type")) + w.Write(body) + }) + defer backend.close() + + proxy, _ := New(&Config{Target: backend.URL()}, nil) + + testData := "Hello, Proxy!" + req := httptest.NewRequest("POST", "/echo", strings.NewReader(testData)) + req.Header.Set("Content-Type", "text/plain") + rr := httptest.NewRecorder() + + proxy.ServeHTTP(rr, req) + + if rr.Body.String() != testData { + t.Errorf("Expected echo of '%s', got '%s'", testData, rr.Body.String()) + } + + if rr.Header().Get("Content-Type") != "text/plain" { + t.Errorf("Expected Content-Type text/plain, got %s", rr.Header().Get("Content-Type")) + } +} + +// ============== Helper Function Tests ============== + +func TestSingleJoiningSlash(t *testing.T) { + tests := []struct { + a, b, expected string + }{ + {"/api", "/users", "/api/users"}, + {"/api/", "/users", "/api/users"}, + {"/api", "users", "/api/users"}, + {"/api/", "users", "/api/users"}, + {"", "/users", "/users"}, + {"/api", "", "/api/"}, + } + + for _, tt := range tests { + result := singleJoiningSlash(tt.a, tt.b) + if result != tt.expected { + t.Errorf("singleJoiningSlash(%q, %q) = %q, want %q", tt.a, tt.b, result, tt.expected) + } + } +} + +func TestGetClientIP(t *testing.T) { + tests := []struct { + remoteAddr string + xRealIP string + expected string + }{ + {"192.168.1.1:1234", "", "192.168.1.1"}, + {"192.168.1.1:1234", "10.0.0.1", "10.0.0.1"}, + {"invalid", "", "invalid"}, + } + + for _, tt := range tests { + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = tt.remoteAddr + if tt.xRealIP != "" { + req.Header.Set("X-Real-IP", tt.xRealIP) + } + + result := getClientIP(req) + if result != tt.expected { + t.Errorf("getClientIP() = %q, want %q", result, tt.expected) + } + } +} + +func TestGetScheme(t *testing.T) { + tests := []struct { + name string + tls bool + header string + expected string + }{ + {"HTTP", false, "", "http"}, + {"HTTPS from TLS", true, "", "https"}, + {"HTTPS from header", false, "https", "https"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + if tt.header != "" { + req.Header.Set("X-Forwarded-Proto", tt.header) + } + // Note: httptest doesn't set TLS, so we can only test non-TLS cases fully + + result := getScheme(req) + if !tt.tls && result != tt.expected { + t.Errorf("getScheme() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestIsConnectionError(t *testing.T) { + tests := []struct { + err error + expected bool + }{ + {nil, false}, + {&net.OpError{Op: "dial", Err: &net.DNSError{Err: "no such host"}}, true}, + {context.DeadlineExceeded, false}, + } + + for _, tt := range tests { + result := isConnectionError(tt.err) + if result != tt.expected { + t.Errorf("isConnectionError(%v) = %v, want %v", tt.err, result, tt.expected) + } + } +} + +// ============== Benchmarks ============== + +func BenchmarkProxy_SimpleGET(b *testing.B) { + backend := newTestBackend(nil) + defer backend.close() + + proxy, _ := New(&Config{Target: backend.URL()}, nil) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + req := httptest.NewRequest("GET", "/bench", nil) + rr := httptest.NewRecorder() + proxy.ServeHTTP(rr, req) + } +} + +func BenchmarkProxy_POSTWithBody(b *testing.B) { + backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) { + io.Copy(io.Discard, r.Body) + w.WriteHeader(http.StatusOK) + }) + defer backend.close() + + proxy, _ := New(&Config{Target: backend.URL()}, nil) + body := strings.Repeat("x", 1024) // 1KB body + + b.ResetTimer() + for i := 0; i < b.N; i++ { + req := httptest.NewRequest("POST", "/bench", strings.NewReader(body)) + rr := httptest.NewRecorder() + proxy.ServeHTTP(rr, req) + } +} diff --git a/go/internal/routing/router.go b/go/internal/routing/router.go new file mode 100644 index 0000000..6a22db4 --- /dev/null +++ b/go/internal/routing/router.go @@ -0,0 +1,395 @@ +// Package routing provides HTTP routing with regex support +package routing + +import ( + "net/http" + "os" + "path/filepath" + "regexp" + "strings" + "sync" + + "github.com/konduktor/konduktor/internal/config" + "github.com/konduktor/konduktor/internal/logging" +) + +// RouteMatch represents a matched route with captured parameters +type RouteMatch struct { + Config map[string]interface{} + Params map[string]string +} + +// RegexRoute represents a compiled regex route +type RegexRoute struct { + Pattern *regexp.Regexp + Config map[string]interface{} + CaseSensitive bool + OriginalExpr string +} + +// Router handles HTTP routing with exact, regex, and default routes +type Router struct { + config *config.Config + logger *logging.Logger + mux *http.ServeMux + staticDir string + exactRoutes map[string]map[string]interface{} + regexRoutes []*RegexRoute + defaultRoute map[string]interface{} + mu sync.RWMutex +} + +// New creates a new router from config +func New(cfg *config.Config, logger *logging.Logger) *Router { + staticDir := "./static" + if cfg != nil && cfg.HTTP.StaticDir != "" { + staticDir = cfg.HTTP.StaticDir + } + + r := &Router{ + config: cfg, + logger: logger, + mux: http.NewServeMux(), + staticDir: staticDir, + exactRoutes: make(map[string]map[string]interface{}), + regexRoutes: make([]*RegexRoute, 0), + } + + r.setupRoutes() + return r +} + +// NewRouter creates a router without config (for testing) +func NewRouter(opts ...RouterOption) *Router { + r := &Router{ + mux: http.NewServeMux(), + staticDir: "./static", + exactRoutes: make(map[string]map[string]interface{}), + regexRoutes: make([]*RegexRoute, 0), + } + + for _, opt := range opts { + opt(r) + } + + return r +} + +// RouterOption is a functional option for Router +type RouterOption func(*Router) + +// WithStaticDir sets the static directory +func WithStaticDir(dir string) RouterOption { + return func(r *Router) { + r.staticDir = dir + } +} + +// StaticDir returns the static directory path +func (r *Router) StaticDir() string { + return r.staticDir +} + +// Routes returns the regex routes (for testing) +func (r *Router) Routes() []*RegexRoute { + r.mu.RLock() + defer r.mu.RUnlock() + return r.regexRoutes +} + +// ExactRoutes returns the exact routes (for testing) +func (r *Router) ExactRoutes() map[string]map[string]interface{} { + r.mu.RLock() + defer r.mu.RUnlock() + return r.exactRoutes +} + +// DefaultRoute returns the default route (for testing) +func (r *Router) DefaultRoute() map[string]interface{} { + r.mu.RLock() + defer r.mu.RUnlock() + return r.defaultRoute +} + +// AddRoute adds a route with the given pattern and config +// Pattern formats: +// - "=/path" - exact match +// - "~regex" - case-sensitive regex +// - "~*regex" - case-insensitive regex +// - "__default__" - default/fallback route +func (r *Router) AddRoute(pattern string, routeConfig map[string]interface{}) { + r.mu.Lock() + defer r.mu.Unlock() + + switch { + case pattern == "__default__": + r.defaultRoute = routeConfig + + case strings.HasPrefix(pattern, "="): + // Exact match route + path := strings.TrimPrefix(pattern, "=") + r.exactRoutes[path] = routeConfig + + case strings.HasPrefix(pattern, "~*"): + // Case-insensitive regex + expr := strings.TrimPrefix(pattern, "~*") + re, err := regexp.Compile("(?i)" + expr) + if err != nil { + if r.logger != nil { + r.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err) + } + return + } + r.regexRoutes = append(r.regexRoutes, &RegexRoute{ + Pattern: re, + Config: routeConfig, + CaseSensitive: false, + OriginalExpr: expr, + }) + + case strings.HasPrefix(pattern, "~"): + // Case-sensitive regex + expr := strings.TrimPrefix(pattern, "~") + re, err := regexp.Compile(expr) + if err != nil { + if r.logger != nil { + r.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err) + } + return + } + r.regexRoutes = append(r.regexRoutes, &RegexRoute{ + Pattern: re, + Config: routeConfig, + CaseSensitive: true, + OriginalExpr: expr, + }) + } +} + +// Match finds the best matching route for a path +// Priority: exact match > regex match > default +func (r *Router) Match(path string) *RouteMatch { + r.mu.RLock() + defer r.mu.RUnlock() + + // 1. Check exact routes + if cfg, ok := r.exactRoutes[path]; ok { + return &RouteMatch{ + Config: cfg, + Params: make(map[string]string), + } + } + + // 2. Check regex routes + for _, route := range r.regexRoutes { + match := route.Pattern.FindStringSubmatch(path) + if match != nil { + params := make(map[string]string) + + // Extract named groups + names := route.Pattern.SubexpNames() + for i, name := range names { + if i > 0 && name != "" && i < len(match) { + params[name] = match[i] + } + } + + return &RouteMatch{ + Config: route.Config, + Params: params, + } + } + } + + // 3. Check default route + if r.defaultRoute != nil { + return &RouteMatch{ + Config: r.defaultRoute, + Params: make(map[string]string), + } + } + + return nil +} + +// setupRoutes configures the routes from config +func (r *Router) setupRoutes() { + // Health check endpoint + r.mux.HandleFunc("/health", r.healthHandler) + + // Setup redirect instructions from config + if r.config != nil { + for from, to := range r.config.Server.RedirectInstructions { + fromPath := from + toPath := to + r.mux.HandleFunc(fromPath, func(w http.ResponseWriter, req *http.Request) { + http.Redirect(w, req, toPath, http.StatusMovedPermanently) + }) + } + } + + // Default handler for all other routes + r.mux.HandleFunc("/", r.defaultHandler) +} + +// ServeHTTP implements http.Handler +func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { + r.mux.ServeHTTP(w, req) +} + +// healthHandler handles health check requests +func (r *Router) healthHandler(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) +} + +// defaultHandler handles requests that don't match other routes +func (r *Router) defaultHandler(w http.ResponseWriter, req *http.Request) { + path := req.URL.Path + + // Try to match against configured routes + match := r.Match(path) + if match != nil { + r.handleRouteMatch(w, req, match) + return + } + + // Try to serve static file + if r.staticDir != "" { + filePath := filepath.Join(r.staticDir, path) + + // Prevent directory traversal + if !strings.HasPrefix(filepath.Clean(filePath), filepath.Clean(r.staticDir)) { + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + + // Check if file exists + info, err := os.Stat(filePath) + if err == nil { + if info.IsDir() { + // Try index.html + indexPath := filepath.Join(filePath, "index.html") + if _, err := os.Stat(indexPath); err == nil { + http.ServeFile(w, req, indexPath) + return + } + } else { + http.ServeFile(w, req, filePath) + return + } + } + } + + // 404 Not Found + http.NotFound(w, req) +} + +// handleRouteMatch handles a matched route +func (r *Router) handleRouteMatch(w http.ResponseWriter, req *http.Request, match *RouteMatch) { + cfg := match.Config + + // Handle "return" directive + if ret, ok := cfg["return"].(string); ok { + parts := strings.SplitN(ret, " ", 2) + statusCode := 200 + body := "OK" + if len(parts) >= 1 { + switch parts[0] { + case "200": + statusCode = 200 + case "201": + statusCode = 201 + case "301": + statusCode = 301 + case "302": + statusCode = 302 + case "400": + statusCode = 400 + case "404": + statusCode = 404 + case "500": + statusCode = 500 + } + } + if len(parts) >= 2 { + body = parts[1] + } + + if ct, ok := cfg["content_type"].(string); ok { + w.Header().Set("Content-Type", ct) + } else { + w.Header().Set("Content-Type", "text/plain") + } + + w.WriteHeader(statusCode) + w.Write([]byte(body)) + return + } + + // Handle static files with root + if root, ok := cfg["root"].(string); ok { + path := req.URL.Path + + if indexFile, ok := cfg["index_file"].(string); ok { + if path == "/" || strings.HasSuffix(path, "/") { + path = "/" + indexFile + } + } + + filePath := filepath.Join(root, path) + + if cacheControl, ok := cfg["cache_control"].(string); ok { + w.Header().Set("Cache-Control", cacheControl) + } + + if headers, ok := cfg["headers"].([]interface{}); ok { + for _, h := range headers { + if header, ok := h.(string); ok { + parts := strings.SplitN(header, ": ", 2) + if len(parts) == 2 { + w.Header().Set(parts[0], parts[1]) + } + } + } + } + + http.ServeFile(w, req, filePath) + return + } + + // Handle SPA fallback + if spaFallback, ok := cfg["spa_fallback"].(bool); ok && spaFallback { + root := r.staticDir + if rt, ok := cfg["root"].(string); ok { + root = rt + } + + indexFile := "index.html" + if idx, ok := cfg["index_file"].(string); ok { + indexFile = idx + } + + filePath := filepath.Join(root, indexFile) + http.ServeFile(w, req, filePath) + return + } + + http.NotFound(w, req) +} + +// CreateRouterFromConfig creates a router from extension config +func CreateRouterFromConfig(cfg map[string]interface{}) *Router { + router := NewRouter() + + if locations, ok := cfg["regex_locations"].(map[string]interface{}); ok { + for pattern, routeCfg := range locations { + if rc, ok := routeCfg.(map[string]interface{}); ok { + router.AddRoute(pattern, rc) + } + } + } + + return router +} diff --git a/go/internal/routing/router_test.go b/go/internal/routing/router_test.go new file mode 100644 index 0000000..d92aa60 --- /dev/null +++ b/go/internal/routing/router_test.go @@ -0,0 +1,375 @@ +package routing + +import ( + "path/filepath" + "testing" +) + +// ============== Router Initialization Tests ============== + +func TestRouter_Initialization(t *testing.T) { + router := NewRouter() + + if router.StaticDir() != "./static" { + t.Errorf("Expected static dir ./static, got %s", router.StaticDir()) + } + + if len(router.Routes()) != 0 { + t.Error("Expected empty routes") + } + + if len(router.ExactRoutes()) != 0 { + t.Error("Expected empty exact routes") + } + + if router.DefaultRoute() != nil { + t.Error("Expected nil default route") + } +} + +func TestRouter_CustomStaticDir(t *testing.T) { + router := NewRouter(WithStaticDir("/custom/path")) + + if router.StaticDir() != "/custom/path" { + t.Errorf("Expected static dir /custom/path, got %s", router.StaticDir()) + } +} + +// ============== Route Adding Tests ============== + +func TestRouter_AddExactRoute(t *testing.T) { + router := NewRouter() + config := map[string]interface{}{"return": "200 OK"} + + router.AddRoute("=/health", config) + + exactRoutes := router.ExactRoutes() + if _, ok := exactRoutes["/health"]; !ok { + t.Error("Expected /health in exact routes") + } +} + +func TestRouter_AddDefaultRoute(t *testing.T) { + router := NewRouter() + config := map[string]interface{}{"spa_fallback": true, "root": "./static"} + + router.AddRoute("__default__", config) + + if router.DefaultRoute() == nil { + t.Error("Expected default route to be set") + } +} + +func TestRouter_AddRegexRoute(t *testing.T) { + router := NewRouter() + config := map[string]interface{}{"root": "./static"} + + router.AddRoute("~^/api/", config) + + if len(router.Routes()) != 1 { + t.Errorf("Expected 1 regex route, got %d", len(router.Routes())) + } +} + +func TestRouter_AddCaseInsensitiveRegexRoute(t *testing.T) { + router := NewRouter() + config := map[string]interface{}{"root": "./static", "cache_control": "public, max-age=3600"} + + router.AddRoute("~*\\.(css|js)$", config) + + if len(router.Routes()) != 1 { + t.Errorf("Expected 1 regex route, got %d", len(router.Routes())) + } + + if router.Routes()[0].CaseSensitive { + t.Error("Expected case-insensitive route") + } +} + +func TestRouter_InvalidRegexPattern(t *testing.T) { + router := NewRouter() + config := map[string]interface{}{"root": "./static"} + + // Invalid regex - unmatched bracket + router.AddRoute("~^/api/[invalid", config) + + // Should not add invalid pattern + if len(router.Routes()) != 0 { + t.Error("Should not add invalid regex pattern") + } +} + +// ============== Route Matching Tests ============== + +func TestRouter_MatchExactRoute(t *testing.T) { + router := NewRouter() + config := map[string]interface{}{"return": "200 OK"} + router.AddRoute("=/health", config) + + match := router.Match("/health") + + if match == nil { + t.Fatal("Expected match for /health") + } + + if match.Config["return"] != "200 OK" { + t.Error("Expected return config") + } + + if len(match.Params) != 0 { + t.Error("Expected empty params for exact match") + } +} + +func TestRouter_MatchExactRouteNoMatch(t *testing.T) { + router := NewRouter() + config := map[string]interface{}{"return": "200 OK"} + router.AddRoute("=/health", config) + + match := router.Match("/healthcheck") + + if match != nil { + t.Error("Exact route should not match /healthcheck") + } +} + +func TestRouter_MatchRegexRoute(t *testing.T) { + router := NewRouter() + config := map[string]interface{}{"proxy_pass": "http://localhost:9001"} + router.AddRoute("~^/api/v\\d+/", config) + + match := router.Match("/api/v1/users") + + if match == nil { + t.Fatal("Expected match for /api/v1/users") + } + + if match.Config["proxy_pass"] != "http://localhost:9001" { + t.Error("Expected proxy_pass config") + } +} + +func TestRouter_MatchRegexRouteWithGroups(t *testing.T) { + router := NewRouter() + config := map[string]interface{}{"proxy_pass": "http://localhost:9001"} + router.AddRoute("~^/api/v(?P\\d+)/", config) + + match := router.Match("/api/v2/data") + + if match == nil { + t.Fatal("Expected match for /api/v2/data") + } + + if match.Params["version"] != "2" { + t.Errorf("Expected version=2, got %s", match.Params["version"]) + } +} + +func TestRouter_MatchCaseInsensitiveRegex(t *testing.T) { + router := NewRouter() + config := map[string]interface{}{"root": "./static", "cache_control": "public, max-age=3600"} + router.AddRoute("~*\\.(CSS|JS)$", config) + + // Should match lowercase + match1 := router.Match("/styles/main.css") + if match1 == nil { + t.Error("Should match lowercase .css") + } + + // Should match uppercase + match2 := router.Match("/scripts/app.JS") + if match2 == nil { + t.Error("Should match uppercase .JS") + } +} + +func TestRouter_MatchCaseSensitiveRegex(t *testing.T) { + router := NewRouter() + config := map[string]interface{}{"root": "./static"} + router.AddRoute("~\\.(css)$", config) + + // Should match lowercase + match1 := router.Match("/styles/main.css") + if match1 == nil { + t.Error("Should match lowercase .css") + } + + // Should NOT match uppercase + match2 := router.Match("/styles/main.CSS") + if match2 != nil { + t.Error("Should not match uppercase .CSS for case-sensitive regex") + } +} + +func TestRouter_MatchDefaultRoute(t *testing.T) { + router := NewRouter() + router.AddRoute("=/health", map[string]interface{}{"return": "200 OK"}) + router.AddRoute("__default__", map[string]interface{}{"spa_fallback": true}) + + match := router.Match("/unknown/path") + + if match == nil { + t.Fatal("Expected default route match") + } + + if match.Config["spa_fallback"] != true { + t.Error("Expected spa_fallback config from default route") + } +} + +// ============== Priority Tests ============== + +func TestRouter_PriorityExactOverRegex(t *testing.T) { + router := NewRouter() + router.AddRoute("=/api/status", map[string]interface{}{"return": "200 Exact"}) + router.AddRoute("~^/api/", map[string]interface{}{"proxy_pass": "http://localhost:9001"}) + + match := router.Match("/api/status") + + if match == nil { + t.Fatal("Expected match") + } + + if match.Config["return"] != "200 Exact" { + t.Error("Exact match should have priority over regex") + } +} + +func TestRouter_PriorityRegexOverDefault(t *testing.T) { + router := NewRouter() + router.AddRoute("~^/api/", map[string]interface{}{"proxy_pass": "http://localhost:9001"}) + router.AddRoute("__default__", map[string]interface{}{"spa_fallback": true}) + + match := router.Match("/api/v1/users") + + if match == nil { + t.Fatal("Expected match") + } + + if match.Config["proxy_pass"] != "http://localhost:9001" { + t.Error("Regex match should have priority over default") + } +} + +// ============== CreateRouterFromConfig Tests ============== + +func TestCreateRouterFromConfig(t *testing.T) { + config := map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "=/health": map[string]interface{}{ + "return": "200 OK", + "content_type": "text/plain", + }, + "~^/api/": map[string]interface{}{ + "proxy_pass": "http://localhost:9001", + }, + "__default__": map[string]interface{}{ + "spa_fallback": true, + "root": "./static", + }, + }, + } + + router := CreateRouterFromConfig(config) + + // Check exact route + if _, ok := router.ExactRoutes()["/health"]; !ok { + t.Error("Expected /health exact route") + } + + // Check regex route + if len(router.Routes()) != 1 { + t.Errorf("Expected 1 regex route, got %d", len(router.Routes())) + } + + // Check default route + if router.DefaultRoute() == nil { + t.Error("Expected default route") + } +} + +// ============== Static Dir Path Tests ============== + +func TestRouter_StaticDirPath(t *testing.T) { + router := NewRouter(WithStaticDir("/var/www/html")) + + expected, _ := filepath.Abs("/var/www/html") + actual, _ := filepath.Abs(router.StaticDir()) + + if actual != expected { + t.Errorf("Expected static dir %s, got %s", expected, actual) + } +} + +// ============== Concurrent Access Tests ============== + +func TestRouter_ConcurrentAccess(t *testing.T) { + router := NewRouter() + + // Add routes concurrently + done := make(chan bool, 10) + + for i := 0; i < 10; i++ { + go func(n int) { + router.AddRoute("~^/api/v"+string(rune('0'+n))+"/", map[string]interface{}{ + "proxy_pass": "http://localhost:900" + string(rune('0'+n)), + }) + done <- true + }(i) + } + + // Wait for all goroutines + for i := 0; i < 10; i++ { + <-done + } + + // Match routes concurrently + for i := 0; i < 10; i++ { + go func(n int) { + router.Match("/api/v" + string(rune('0'+n)) + "/users") + done <- true + }(i) + } + + for i := 0; i < 10; i++ { + <-done + } +} + +// ============== Benchmarks ============== + +func BenchmarkRouter_MatchExact(b *testing.B) { + router := NewRouter() + router.AddRoute("=/health", map[string]interface{}{"return": "200 OK"}) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + router.Match("/health") + } +} + +func BenchmarkRouter_MatchRegex(b *testing.B) { + router := NewRouter() + router.AddRoute("~^/api/v(?P\\d+)/", map[string]interface{}{"proxy_pass": "http://localhost:9001"}) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + router.Match("/api/v1/users/123") + } +} + +func BenchmarkRouter_MatchWithManyRoutes(b *testing.B) { + router := NewRouter() + + // Add many routes + for i := 0; i < 50; i++ { + router.AddRoute("~^/api/v"+string(rune('0'+i%10))+"/service"+string(rune('0'+i/10))+"/", + map[string]interface{}{"proxy_pass": "http://localhost:9001"}) + } + router.AddRoute("__default__", map[string]interface{}{"spa_fallback": true}) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + router.Match("/api/v5/service3/users/123") + } +} diff --git a/go/internal/server/server.go b/go/internal/server/server.go new file mode 100644 index 0000000..aaf64e7 --- /dev/null +++ b/go/internal/server/server.go @@ -0,0 +1,130 @@ +// Package server provides the HTTP server implementation +package server + +import ( + "context" + "fmt" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/konduktor/konduktor/internal/config" + "github.com/konduktor/konduktor/internal/logging" + "github.com/konduktor/konduktor/internal/middleware" + "github.com/konduktor/konduktor/internal/routing" +) + +const Version = "0.1.0" + +// Server represents the Konduktor HTTP server +type Server struct { + config *config.Config + httpServer *http.Server + router *routing.Router + logger *logging.Logger +} + +// New creates a new server instance +func New(cfg *config.Config) (*Server, error) { + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("invalid configuration: %w", err) + } + + logger, err := logging.NewFromConfig(cfg.Logging) + if err != nil { + return nil, fmt.Errorf("failed to create logger: %w", err) + } + + router := routing.New(cfg, logger) + + srv := &Server{ + config: cfg, + router: router, + logger: logger, + } + + return srv, nil +} + +// Run starts the server and blocks until shutdown +func (s *Server) Run() error { + // Build handler chain with middleware + handler := s.buildHandler() + + // Create HTTP server + addr := fmt.Sprintf("%s:%d", s.config.Server.Host, s.config.Server.Port) + s.httpServer = &http.Server{ + Addr: addr, + Handler: handler, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, + } + + // Start server in goroutine + errChan := make(chan error, 1) + go func() { + s.logger.Info("Server starting", "addr", addr, "version", Version) + + var err error + if s.config.SSL.Enabled { + err = s.httpServer.ListenAndServeTLS(s.config.SSL.CertFile, s.config.SSL.KeyFile) + } else { + err = s.httpServer.ListenAndServe() + } + + if err != nil && err != http.ErrServerClosed { + errChan <- err + } + }() + + // Wait for shutdown signal + return s.waitForShutdown(errChan) +} + +// buildHandler builds the HTTP handler chain +func (s *Server) buildHandler() http.Handler { + var handler http.Handler = s.router + + // Add middleware + handler = middleware.AccessLog(handler, s.logger) + handler = middleware.ServerHeader(handler, Version) + handler = middleware.Recovery(handler, s.logger) + + return handler +} + +// waitForShutdown waits for shutdown signal and gracefully stops the server +func (s *Server) waitForShutdown(errChan <-chan error) error { + // Listen for shutdown signals + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + select { + case err := <-errChan: + return err + case sig := <-sigChan: + s.logger.Info("Shutdown signal received", "signal", sig.String()) + } + + // Graceful shutdown with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + s.logger.Info("Shutting down server...") + + if err := s.httpServer.Shutdown(ctx); err != nil { + s.logger.Error("Error during shutdown", "error", err) + return err + } + + s.logger.Info("Server stopped gracefully") + return nil +} + +// Shutdown gracefully shuts down the server +func (s *Server) Shutdown(ctx context.Context) error { + return s.httpServer.Shutdown(ctx) +}