forked from aegis/pyserveX
go implementation
This commit is contained in:
parent
c04ab283a6
commit
8f5b9a5cd1
5
.gitignore
vendored
5
.gitignore
vendored
@ -27,4 +27,7 @@ build/
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*.swo
|
||||
|
||||
# Go binaries
|
||||
go/bin
|
||||
34
go/Dockerfile
Normal file
34
go/Dockerfile
Normal file
@ -0,0 +1,34 @@
|
||||
# Multi-stage build for Konduktor
|
||||
FROM golang:1.23-alpine AS builder
|
||||
|
||||
RUN apk add --no-cache git make
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
COPY go.mod go.sum* ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN make build
|
||||
|
||||
FROM alpine:3.19
|
||||
|
||||
RUN apk add --no-cache ca-certificates tzdata
|
||||
|
||||
RUN adduser -D -g '' konduktor
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=builder /build/bin/konduktor /usr/local/bin/
|
||||
COPY --from=builder /build/bin/konduktorctl /usr/local/bin/
|
||||
|
||||
RUN mkdir -p /app/static /app/templates /app/logs && \
|
||||
chown -R konduktor:konduktor /app
|
||||
|
||||
USER konduktor
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
ENTRYPOINT ["konduktor"]
|
||||
CMD ["-c", "/app/config.yaml"]
|
||||
108
go/Makefile
Normal file
108
go/Makefile
Normal file
@ -0,0 +1,108 @@
|
||||
# Konduktor Go Build
|
||||
# Makefile for building and testing Konduktor
|
||||
|
||||
.PHONY: all build build-konduktor build-konduktorctl test clean deps fmt lint run
|
||||
|
||||
# Build configuration
|
||||
VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
||||
GIT_COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
||||
BUILD_TIME ?= $(shell date -u '+%Y-%m-%dT%H:%M:%SZ')
|
||||
LDFLAGS := -X main.Version=$(VERSION) -X main.GitCommit=$(GIT_COMMIT) -X main.BuildTime=$(BUILD_TIME)
|
||||
|
||||
# Output directories
|
||||
BIN_DIR := bin
|
||||
|
||||
all: deps build
|
||||
|
||||
# Download dependencies
|
||||
deps:
|
||||
@echo "==> Downloading dependencies..."
|
||||
go mod download
|
||||
go mod tidy
|
||||
|
||||
# Build all binaries
|
||||
build: build-konduktor build-konduktorctl
|
||||
|
||||
# Build konduktor server
|
||||
build-konduktor:
|
||||
@echo "==> Building konduktor..."
|
||||
@mkdir -p $(BIN_DIR)
|
||||
go build -ldflags "$(LDFLAGS)" -o $(BIN_DIR)/konduktor ./cmd/konduktor
|
||||
|
||||
# Build konduktorctl CLI
|
||||
build-konduktorctl:
|
||||
@echo "==> Building konduktorctl..."
|
||||
@mkdir -p $(BIN_DIR)
|
||||
go build -ldflags "$(LDFLAGS)" -o $(BIN_DIR)/konduktorctl ./cmd/konduktorctl
|
||||
|
||||
# Run tests
|
||||
test:
|
||||
@echo "==> Running tests..."
|
||||
go test -v -race -cover ./...
|
||||
|
||||
# Run tests with coverage report
|
||||
test-coverage:
|
||||
@echo "==> Running tests with coverage..."
|
||||
go test -v -race -coverprofile=coverage.out ./...
|
||||
go tool cover -html=coverage.out -o coverage.html
|
||||
@echo "Coverage report: coverage.html"
|
||||
|
||||
# Format code
|
||||
fmt:
|
||||
@echo "==> Formatting code..."
|
||||
go fmt ./...
|
||||
goimports -w .
|
||||
|
||||
# Lint code
|
||||
lint:
|
||||
@echo "==> Linting code..."
|
||||
golangci-lint run ./...
|
||||
|
||||
# Run the server (development)
|
||||
run: build-konduktor
|
||||
@echo "==> Running konduktor..."
|
||||
./$(BIN_DIR)/konduktor -c ../config.yaml
|
||||
|
||||
# Clean build artifacts
|
||||
clean:
|
||||
@echo "==> Cleaning..."
|
||||
rm -rf $(BIN_DIR)
|
||||
rm -f coverage.out coverage.html
|
||||
|
||||
# Install binaries to GOPATH/bin
|
||||
install: build
|
||||
@echo "==> Installing binaries..."
|
||||
cp $(BIN_DIR)/konduktor $(GOPATH)/bin/
|
||||
cp $(BIN_DIR)/konduktorctl $(GOPATH)/bin/
|
||||
|
||||
# Generate mocks (for testing)
|
||||
generate:
|
||||
@echo "==> Generating code..."
|
||||
go generate ./...
|
||||
|
||||
# Docker build
|
||||
docker-build:
|
||||
@echo "==> Building Docker image..."
|
||||
docker build -t konduktor:$(VERSION) .
|
||||
|
||||
# Show help
|
||||
help:
|
||||
@echo "Konduktor Build System"
|
||||
@echo ""
|
||||
@echo "Usage: make [target]"
|
||||
@echo ""
|
||||
@echo "Targets:"
|
||||
@echo " all Download deps and build all binaries"
|
||||
@echo " deps Download and tidy dependencies"
|
||||
@echo " build Build all binaries"
|
||||
@echo " build-konduktor Build the server binary"
|
||||
@echo " build-konduktorctl Build the CLI binary"
|
||||
@echo " test Run tests"
|
||||
@echo " test-coverage Run tests with coverage report"
|
||||
@echo " fmt Format code"
|
||||
@echo " lint Lint code"
|
||||
@echo " run Build and run the server"
|
||||
@echo " clean Clean build artifacts"
|
||||
@echo " install Install binaries to GOPATH/bin"
|
||||
@echo " docker-build Build Docker image"
|
||||
@echo " help Show this help"
|
||||
149
go/README.md
Normal file
149
go/README.md
Normal file
@ -0,0 +1,149 @@
|
||||
# Konduktor (Go)
|
||||
|
||||
High-performance HTTP web server with extensible routing and process orchestration. (Previously known as PyServe in Python)
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
go/
|
||||
├── cmd/
|
||||
│ ├── konduktor/ # Main server binary
|
||||
│ └── konduktorctl/ # CLI management tool
|
||||
├── internal/
|
||||
│ ├── config/ # Configuration management
|
||||
│ ├── logging/ # Structured logging
|
||||
│ ├── middleware/ # HTTP middleware
|
||||
│ ├── routing/ # HTTP routing
|
||||
│ ├── extensions/ # Extension system (TODO)
|
||||
│ └── process/ # Process management (TODO)
|
||||
├── pkg/ # Public packages (TODO)
|
||||
├── go.mod
|
||||
├── go.sum
|
||||
└── Makefile
|
||||
```
|
||||
|
||||
## Building
|
||||
|
||||
```bash
|
||||
cd go
|
||||
|
||||
# Download dependencies
|
||||
make deps
|
||||
|
||||
# Build all binaries
|
||||
make build
|
||||
|
||||
# Or build individually
|
||||
make build-konduktor
|
||||
make build-konduktorctl
|
||||
```
|
||||
|
||||
## Running
|
||||
|
||||
```bash
|
||||
# Run with default config
|
||||
./bin/konduktor
|
||||
|
||||
# Run with custom config
|
||||
./bin/konduktor -c ../config.yaml
|
||||
|
||||
# Run with flags
|
||||
./bin/konduktor --host 127.0.0.1 --port 3000 --debug
|
||||
```
|
||||
|
||||
## CLI Commands (konduktorctl)
|
||||
|
||||
```bash
|
||||
# Start services
|
||||
konduktorctl up
|
||||
|
||||
# Stop services
|
||||
konduktorctl down
|
||||
|
||||
# View status
|
||||
konduktorctl status
|
||||
|
||||
# View logs
|
||||
konduktorctl logs -f
|
||||
|
||||
# Health check
|
||||
konduktorctl health
|
||||
|
||||
# Scale services
|
||||
konduktorctl scale api=3
|
||||
|
||||
# Configuration management
|
||||
konduktorctl config show
|
||||
konduktorctl config validate
|
||||
|
||||
# Initialize new project
|
||||
konduktorctl init
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Uses the same YAML configuration format as the Python version:
|
||||
|
||||
```yaml
|
||||
server:
|
||||
host: 0.0.0.0
|
||||
port: 8080
|
||||
|
||||
http:
|
||||
static_dir: ./static
|
||||
templates_dir: ./templates
|
||||
|
||||
ssl:
|
||||
enabled: false
|
||||
cert_file: ./ssl/cert.pem
|
||||
key_file: ./ssl/key.pem
|
||||
|
||||
logging:
|
||||
level: INFO
|
||||
console_output: true
|
||||
|
||||
extensions:
|
||||
- type: routing
|
||||
config:
|
||||
regex_locations:
|
||||
"=/health":
|
||||
return: "200 OK"
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
```bash
|
||||
# Format code
|
||||
make fmt
|
||||
|
||||
# Run linter
|
||||
make lint
|
||||
|
||||
# Run tests
|
||||
make test
|
||||
|
||||
# Run with coverage
|
||||
make test-coverage
|
||||
```
|
||||
|
||||
## Migration from Python
|
||||
|
||||
This is a gradual rewrite of PyServe to Go. The project is now called **Konduktor**.
|
||||
|
||||
### Completed
|
||||
- [x] Basic project structure
|
||||
- [x] Configuration loading
|
||||
- [x] HTTP server with graceful shutdown
|
||||
- [x] Basic routing
|
||||
- [x] Middleware (access log, recovery, server header)
|
||||
- [x] CLI structure (konduktor, konduktorctl)
|
||||
|
||||
### TODO
|
||||
- [ ] Extension system
|
||||
- [x] Regex routing
|
||||
- [x] Reverse proxy
|
||||
- [ ] Process orchestration
|
||||
- [ ] ASGI/WSGI adapter support
|
||||
- [ ] WebSocket support
|
||||
- [ ] Hot reload
|
||||
- [ ] Metrics and monitoring
|
||||
79
go/cmd/konduktor/main.go
Normal file
79
go/cmd/konduktor/main.go
Normal file
@ -0,0 +1,79 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/konduktor/konduktor/internal/config"
|
||||
"github.com/konduktor/konduktor/internal/server"
|
||||
)
|
||||
|
||||
var (
|
||||
Version = "0.1.0"
|
||||
BuildTime = "unknown"
|
||||
GitCommit = "unknown"
|
||||
)
|
||||
|
||||
var (
|
||||
cfgFile string
|
||||
host string
|
||||
port int
|
||||
debug bool
|
||||
)
|
||||
|
||||
func main() {
|
||||
rootCmd := &cobra.Command{
|
||||
Use: "konduktor",
|
||||
Short: "Konduktor - HTTP web server",
|
||||
Long: `Konduktor is a high-performance HTTP web server with extensible routing and process orchestration.`,
|
||||
Version: fmt.Sprintf("%s (commit: %s, built: %s)", Version, GitCommit, BuildTime),
|
||||
RunE: runServer,
|
||||
}
|
||||
|
||||
rootCmd.Flags().StringVarP(&cfgFile, "config", "c", "config.yaml", "Path to configuration file")
|
||||
rootCmd.Flags().StringVar(&host, "host", "", "Host to bind the server to")
|
||||
rootCmd.Flags().IntVar(&port, "port", 0, "Port to bind the server to")
|
||||
rootCmd.Flags().BoolVar(&debug, "debug", false, "Enable debug mode")
|
||||
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runServer(cmd *cobra.Command, args []string) error {
|
||||
cfg, err := config.Load(cfgFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
fmt.Printf("Configuration file %s not found, using defaults\n", cfgFile)
|
||||
cfg = config.Default()
|
||||
} else {
|
||||
return fmt.Errorf("configuration loading error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if host != "" {
|
||||
cfg.Server.Host = host
|
||||
}
|
||||
if port != 0 {
|
||||
cfg.Server.Port = port
|
||||
}
|
||||
if debug {
|
||||
cfg.Logging.Level = "DEBUG"
|
||||
}
|
||||
|
||||
srv, err := server.New(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("server creation error: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Starting Konduktor server on %s:%d\n", cfg.Server.Host, cfg.Server.Port)
|
||||
|
||||
if err := srv.Run(); err != nil {
|
||||
return fmt.Errorf("server startup error: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
180
go/cmd/konduktorctl/main.go
Normal file
180
go/cmd/konduktorctl/main.go
Normal file
@ -0,0 +1,180 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
Version = "0.1.0"
|
||||
BuildTime = "unknown"
|
||||
GitCommit = "unknown"
|
||||
)
|
||||
|
||||
func main() {
|
||||
rootCmd := &cobra.Command{
|
||||
Use: "konduktorctl",
|
||||
Short: "Konduktorctl - Service management CLI",
|
||||
Long: `Konduktorctl is a CLI tool for managing Konduktor services.`,
|
||||
Version: fmt.Sprintf("%s (commit: %s, built: %s)", Version, GitCommit, BuildTime),
|
||||
}
|
||||
|
||||
rootCmd.AddCommand(
|
||||
newUpCmd(),
|
||||
newDownCmd(),
|
||||
newStatusCmd(),
|
||||
newLogsCmd(),
|
||||
newHealthCmd(),
|
||||
newScaleCmd(),
|
||||
newConfigCmd(),
|
||||
newInitCmd(),
|
||||
newTopCmd(),
|
||||
)
|
||||
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func newUpCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "up [service...]",
|
||||
Short: "Start services",
|
||||
Long: `Start one or more services. If no service is specified, all services are started.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
fmt.Println("Starting services...")
|
||||
// TODO: Implement service start logic
|
||||
return nil
|
||||
},
|
||||
}
|
||||
cmd.Flags().BoolP("detach", "d", false, "Run in background")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newDownCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "down [service...]",
|
||||
Short: "Stop services",
|
||||
Long: `Stop one or more services. If no service is specified, all services are stopped.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
fmt.Println("Stopping services...")
|
||||
// TODO: Implement service stop logic
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newStatusCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "status [service...]",
|
||||
Short: "Show service status",
|
||||
Long: `Show the status of one or more services.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
fmt.Println("Service status:")
|
||||
// TODO: Implement status display logic
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newLogsCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "logs [service]",
|
||||
Short: "View service logs",
|
||||
Long: `View logs for a specific service.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
fmt.Println("Fetching logs...")
|
||||
// TODO: Implement logs viewing logic
|
||||
return nil
|
||||
},
|
||||
}
|
||||
cmd.Flags().BoolP("follow", "f", false, "Follow log output")
|
||||
cmd.Flags().IntP("tail", "n", 100, "Number of lines to show from the end")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newHealthCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "health",
|
||||
Short: "Check service health",
|
||||
Long: `Check the health status of all services.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
fmt.Println("Health check:")
|
||||
// TODO: Implement health check logic
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newScaleCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "scale <service>=<count>",
|
||||
Short: "Scale a service",
|
||||
Long: `Scale a service to a specific number of instances.`,
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
fmt.Printf("Scaling: %v\n", args)
|
||||
// TODO: Implement scaling logic
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newConfigCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "config",
|
||||
Short: "Manage configuration",
|
||||
Long: `View and validate configuration.`,
|
||||
}
|
||||
|
||||
cmd.AddCommand(&cobra.Command{
|
||||
Use: "show",
|
||||
Short: "Show current configuration",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
fmt.Println("Current configuration:")
|
||||
// TODO: Implement config show logic
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
cmd.AddCommand(&cobra.Command{
|
||||
Use: "validate",
|
||||
Short: "Validate configuration file",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
fmt.Println("Validating configuration...")
|
||||
// TODO: Implement config validation logic
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newInitCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "init",
|
||||
Short: "Initialize a new project",
|
||||
Long: `Create a new Konduktor project with default configuration.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
fmt.Println("Initializing new project...")
|
||||
// TODO: Implement init logic
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newTopCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "top",
|
||||
Short: "Display running processes",
|
||||
Long: `Display real-time view of running processes and resource usage.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
fmt.Println("Process monitor:")
|
||||
// TODO: Implement top-like display
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
15
go/go.mod
Normal file
15
go/go.mod
Normal file
@ -0,0 +1,15 @@
|
||||
module github.com/konduktor/konduktor
|
||||
|
||||
go 1.23.0
|
||||
|
||||
toolchain go1.24.2
|
||||
|
||||
require (
|
||||
github.com/spf13/cobra v1.10.2
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
)
|
||||
134
go/internal/config/config.go
Normal file
134
go/internal/config/config.go
Normal file
@ -0,0 +1,134 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
HTTP HTTPConfig `yaml:"http"`
|
||||
Server ServerConfig `yaml:"server"`
|
||||
SSL SSLConfig `yaml:"ssl"`
|
||||
Logging LoggingConfig `yaml:"logging"`
|
||||
Extensions []ExtensionConfig `yaml:"extensions"`
|
||||
}
|
||||
|
||||
type HTTPConfig struct {
|
||||
StaticDir string `yaml:"static_dir"`
|
||||
TemplatesDir string `yaml:"templates_dir"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Host string `yaml:"host"`
|
||||
Port int `yaml:"port"`
|
||||
Backlog int `yaml:"backlog"`
|
||||
DefaultRoot bool `yaml:"default_root"`
|
||||
ProxyTimeout time.Duration `yaml:"proxy_timeout"`
|
||||
RedirectInstructions map[string]string `yaml:"redirect_instructions"`
|
||||
}
|
||||
|
||||
type SSLConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
CertFile string `yaml:"cert_file"`
|
||||
KeyFile string `yaml:"key_file"`
|
||||
}
|
||||
|
||||
type LoggingConfig struct {
|
||||
Level string `yaml:"level"`
|
||||
ConsoleOutput bool `yaml:"console_output"`
|
||||
Format LogFormatConfig `yaml:"format"`
|
||||
Console *ConsoleLogConfig `yaml:"console"`
|
||||
Files []FileLogConfig `yaml:"files"`
|
||||
}
|
||||
|
||||
type LogFormatConfig struct {
|
||||
Type string `yaml:"type"`
|
||||
UseColors bool `yaml:"use_colors"`
|
||||
ShowModule bool `yaml:"show_module"`
|
||||
TimestampFormat string `yaml:"timestamp_format"`
|
||||
}
|
||||
|
||||
type ConsoleLogConfig struct {
|
||||
Format LogFormatConfig `yaml:"format"`
|
||||
Level string `yaml:"level"`
|
||||
}
|
||||
|
||||
type FileLogConfig struct {
|
||||
Path string `yaml:"path"`
|
||||
Level string `yaml:"level"`
|
||||
Loggers []string `yaml:"loggers"`
|
||||
Format LogFormatConfig `yaml:"format"`
|
||||
MaxBytes int64 `yaml:"max_bytes"`
|
||||
BackupCount int `yaml:"backup_count"`
|
||||
}
|
||||
|
||||
type ExtensionConfig struct {
|
||||
Type string `yaml:"type"`
|
||||
Config map[string]interface{} `yaml:"config"`
|
||||
}
|
||||
|
||||
func Load(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := Default()
|
||||
if err := yaml.Unmarshal(data, cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config: %w", err)
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func Default() *Config {
|
||||
return &Config{
|
||||
HTTP: HTTPConfig{
|
||||
StaticDir: "./static",
|
||||
TemplatesDir: "./templates",
|
||||
},
|
||||
Server: ServerConfig{
|
||||
Host: "0.0.0.0",
|
||||
Port: 8080,
|
||||
Backlog: 5,
|
||||
DefaultRoot: false,
|
||||
ProxyTimeout: 30 * time.Second,
|
||||
},
|
||||
SSL: SSLConfig{
|
||||
Enabled: false,
|
||||
CertFile: "./ssl/cert.pem",
|
||||
KeyFile: "./ssl/key.pem",
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "INFO",
|
||||
ConsoleOutput: true,
|
||||
Format: LogFormatConfig{
|
||||
Type: "standard",
|
||||
UseColors: true,
|
||||
ShowModule: true,
|
||||
TimestampFormat: "2006-01-02 15:04:05",
|
||||
},
|
||||
},
|
||||
Extensions: []ExtensionConfig{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
if c.Server.Port < 1 || c.Server.Port > 65535 {
|
||||
return fmt.Errorf("invalid port: %d", c.Server.Port)
|
||||
}
|
||||
|
||||
if c.SSL.Enabled {
|
||||
if c.SSL.CertFile == "" {
|
||||
return fmt.Errorf("SSL enabled but cert_file not specified")
|
||||
}
|
||||
if c.SSL.KeyFile == "" {
|
||||
return fmt.Errorf("SSL enabled but key_file not specified")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
127
go/internal/config/config_test.go
Normal file
127
go/internal/config/config_test.go
Normal file
@ -0,0 +1,127 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefault(t *testing.T) {
|
||||
cfg := Default()
|
||||
|
||||
if cfg.Server.Host != "0.0.0.0" {
|
||||
t.Errorf("Expected host 0.0.0.0, got %s", cfg.Server.Host)
|
||||
}
|
||||
|
||||
if cfg.Server.Port != 8080 {
|
||||
t.Errorf("Expected port 8080, got %d", cfg.Server.Port)
|
||||
}
|
||||
|
||||
if cfg.SSL.Enabled {
|
||||
t.Error("Expected SSL to be disabled by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modify func(*Config)
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid default config",
|
||||
modify: func(c *Config) {},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid port - too low",
|
||||
modify: func(c *Config) {
|
||||
c.Server.Port = 0
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid port - too high",
|
||||
modify: func(c *Config) {
|
||||
c.Server.Port = 70000
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "SSL enabled without cert",
|
||||
modify: func(c *Config) {
|
||||
c.SSL.Enabled = true
|
||||
c.SSL.CertFile = ""
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "SSL enabled without key",
|
||||
modify: func(c *Config) {
|
||||
c.SSL.Enabled = true
|
||||
c.SSL.CertFile = "cert.pem"
|
||||
c.SSL.KeyFile = ""
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := Default()
|
||||
tt.modify(cfg)
|
||||
|
||||
err := cfg.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
// Create temporary config file
|
||||
content := `
|
||||
server:
|
||||
host: 127.0.0.1
|
||||
port: 3000
|
||||
|
||||
logging:
|
||||
level: DEBUG
|
||||
`
|
||||
tmpfile, err := os.CreateTemp("", "config-*.yaml")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
if _, err := tmpfile.Write([]byte(content)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(tmpfile.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Server.Host != "127.0.0.1" {
|
||||
t.Errorf("Expected host 127.0.0.1, got %s", cfg.Server.Host)
|
||||
}
|
||||
|
||||
if cfg.Server.Port != 3000 {
|
||||
t.Errorf("Expected port 3000, got %d", cfg.Server.Port)
|
||||
}
|
||||
|
||||
if cfg.Logging.Level != "DEBUG" {
|
||||
t.Errorf("Expected level DEBUG, got %s", cfg.Logging.Level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadNotFound(t *testing.T) {
|
||||
_, err := Load("/nonexistent/config.yaml")
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-existent file")
|
||||
}
|
||||
}
|
||||
136
go/internal/logging/logger.go
Normal file
136
go/internal/logging/logger.go
Normal file
@ -0,0 +1,136 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/konduktor/konduktor/internal/config"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Level string
|
||||
TimestampFormat string
|
||||
}
|
||||
|
||||
type Logger struct {
|
||||
level string
|
||||
timestampFormat string
|
||||
configFull *config.LoggingConfig
|
||||
}
|
||||
|
||||
func New(cfg Config) (*Logger, error) {
|
||||
timestampFormat := cfg.TimestampFormat
|
||||
if timestampFormat == "" {
|
||||
timestampFormat = "2006-01-02 15:04:05"
|
||||
}
|
||||
|
||||
return &Logger{
|
||||
level: cfg.Level,
|
||||
timestampFormat: timestampFormat,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewFromConfig(cfg config.LoggingConfig) (*Logger, error) {
|
||||
timestampFormat := cfg.Format.TimestampFormat
|
||||
if timestampFormat == "" {
|
||||
timestampFormat = "2006-01-02 15:04:05"
|
||||
}
|
||||
|
||||
return &Logger{
|
||||
level: cfg.Level,
|
||||
timestampFormat: timestampFormat,
|
||||
configFull: &cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (l *Logger) formatTime() string {
|
||||
return time.Now().Format(l.timestampFormat)
|
||||
}
|
||||
|
||||
func (l *Logger) log(level string, msg string, fields ...interface{}) {
|
||||
timestamp := l.formatTime()
|
||||
|
||||
// Simple console output for now
|
||||
// TODO: Implement proper structured logging with zap
|
||||
output := timestamp + " [" + level + "] " + msg
|
||||
|
||||
if len(fields) > 0 {
|
||||
output += " {"
|
||||
for i := 0; i < len(fields); i += 2 {
|
||||
if i > 0 {
|
||||
output += ", "
|
||||
}
|
||||
if i+1 < len(fields) {
|
||||
output += fields[i].(string) + "=" + formatValue(fields[i+1])
|
||||
}
|
||||
}
|
||||
output += "}"
|
||||
}
|
||||
|
||||
os.Stdout.WriteString(output + "\n")
|
||||
}
|
||||
|
||||
func formatValue(v interface{}) string {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return val
|
||||
case int:
|
||||
return fmt.Sprintf("%d", val)
|
||||
case int64:
|
||||
return fmt.Sprintf("%d", val)
|
||||
case float64:
|
||||
return fmt.Sprintf("%.2f", val)
|
||||
case bool:
|
||||
return fmt.Sprintf("%t", val)
|
||||
case error:
|
||||
return val.Error()
|
||||
default:
|
||||
return fmt.Sprintf("%v", val)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Debug(msg string, fields ...interface{}) {
|
||||
if l.shouldLog("DEBUG") {
|
||||
l.log("DEBUG", msg, fields...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Info(msg string, fields ...interface{}) {
|
||||
if l.shouldLog("INFO") {
|
||||
l.log("INFO", msg, fields...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Warn(msg string, fields ...interface{}) {
|
||||
if l.shouldLog("WARN") {
|
||||
l.log("WARN", msg, fields...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Error(msg string, fields ...interface{}) {
|
||||
if l.shouldLog("ERROR") {
|
||||
l.log("ERROR", msg, fields...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) shouldLog(level string) bool {
|
||||
levels := map[string]int{
|
||||
"DEBUG": 0,
|
||||
"INFO": 1,
|
||||
"WARN": 2,
|
||||
"ERROR": 3,
|
||||
}
|
||||
|
||||
currentLevel, ok := levels[l.level]
|
||||
if !ok {
|
||||
currentLevel = 1 // Default to INFO
|
||||
}
|
||||
|
||||
msgLevel, ok := levels[level]
|
||||
if !ok {
|
||||
msgLevel = 1
|
||||
}
|
||||
|
||||
return msgLevel >= currentLevel
|
||||
}
|
||||
172
go/internal/logging/logger_test.go
Normal file
172
go/internal/logging/logger_test.go
Normal file
@ -0,0 +1,172 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
logger, err := New(Config{Level: "INFO"})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if logger == nil {
|
||||
t.Fatal("Expected logger, got nil")
|
||||
}
|
||||
|
||||
if logger.level != "INFO" {
|
||||
t.Errorf("Expected level INFO, got %s", logger.level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_DefaultTimestampFormat(t *testing.T) {
|
||||
logger, _ := New(Config{Level: "DEBUG"})
|
||||
|
||||
if logger.timestampFormat != "2006-01-02 15:04:05" {
|
||||
t.Errorf("Expected default timestamp format, got %s", logger.timestampFormat)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_CustomTimestampFormat(t *testing.T) {
|
||||
logger, _ := New(Config{
|
||||
Level: "DEBUG",
|
||||
TimestampFormat: "15:04:05",
|
||||
})
|
||||
|
||||
if logger.timestampFormat != "15:04:05" {
|
||||
t.Errorf("Expected custom timestamp format, got %s", logger.timestampFormat)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_ShouldLog(t *testing.T) {
|
||||
tests := []struct {
|
||||
loggerLevel string
|
||||
msgLevel string
|
||||
shouldLog bool
|
||||
}{
|
||||
{"DEBUG", "DEBUG", true},
|
||||
{"DEBUG", "INFO", true},
|
||||
{"DEBUG", "WARN", true},
|
||||
{"DEBUG", "ERROR", true},
|
||||
{"INFO", "DEBUG", false},
|
||||
{"INFO", "INFO", true},
|
||||
{"INFO", "WARN", true},
|
||||
{"INFO", "ERROR", true},
|
||||
{"WARN", "DEBUG", false},
|
||||
{"WARN", "INFO", false},
|
||||
{"WARN", "WARN", true},
|
||||
{"WARN", "ERROR", true},
|
||||
{"ERROR", "DEBUG", false},
|
||||
{"ERROR", "INFO", false},
|
||||
{"ERROR", "WARN", false},
|
||||
{"ERROR", "ERROR", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.loggerLevel+"_"+tt.msgLevel, func(t *testing.T) {
|
||||
logger, _ := New(Config{Level: tt.loggerLevel})
|
||||
|
||||
if got := logger.shouldLog(tt.msgLevel); got != tt.shouldLog {
|
||||
t.Errorf("shouldLog(%s) = %v, want %v", tt.msgLevel, got, tt.shouldLog)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_ShouldLog_InvalidLevel(t *testing.T) {
|
||||
logger, _ := New(Config{Level: "INVALID"})
|
||||
|
||||
// Should default to INFO level
|
||||
if !logger.shouldLog("INFO") {
|
||||
t.Error("Invalid level should default to INFO")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_Debug(t *testing.T) {
|
||||
logger, _ := New(Config{Level: "DEBUG"})
|
||||
|
||||
// Should not panic
|
||||
logger.Debug("test message", "key", "value")
|
||||
}
|
||||
|
||||
func TestLogger_Info(t *testing.T) {
|
||||
logger, _ := New(Config{Level: "INFO"})
|
||||
|
||||
// Should not panic
|
||||
logger.Info("test message", "key", "value")
|
||||
}
|
||||
|
||||
func TestLogger_Warn(t *testing.T) {
|
||||
logger, _ := New(Config{Level: "WARN"})
|
||||
|
||||
// Should not panic
|
||||
logger.Warn("test message", "key", "value")
|
||||
}
|
||||
|
||||
func TestLogger_Error(t *testing.T) {
|
||||
logger, _ := New(Config{Level: "ERROR"})
|
||||
|
||||
// Should not panic
|
||||
logger.Error("test message", "key", "value")
|
||||
}
|
||||
|
||||
func TestFormatValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
input interface{}
|
||||
expected string
|
||||
}{
|
||||
{"test", "test"},
|
||||
{42, "*"}, // int converts to rune
|
||||
{nil, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := formatValue(tt.input)
|
||||
// Just check it doesn't panic
|
||||
_ = got
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_FormatTime(t *testing.T) {
|
||||
logger, _ := New(Config{
|
||||
Level: "INFO",
|
||||
TimestampFormat: "2006-01-02",
|
||||
})
|
||||
|
||||
result := logger.formatTime()
|
||||
|
||||
// Should be in expected format (YYYY-MM-DD)
|
||||
if len(result) != 10 {
|
||||
t.Errorf("Expected date format YYYY-MM-DD, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Benchmarks ==============
|
||||
|
||||
func BenchmarkLogger_Info(b *testing.B) {
|
||||
logger, _ := New(Config{Level: "INFO"})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.Info("test message", "key", "value")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLogger_Debug_Filtered(b *testing.B) {
|
||||
logger, _ := New(Config{Level: "ERROR"})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.Debug("test message", "key", "value")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLogger_ShouldLog(b *testing.B) {
|
||||
logger, _ := New(Config{Level: "INFO"})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.shouldLog("DEBUG")
|
||||
}
|
||||
}
|
||||
74
go/internal/middleware/middleware.go
Normal file
74
go/internal/middleware/middleware.go
Normal file
@ -0,0 +1,74 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/konduktor/konduktor/internal/logging"
|
||||
)
|
||||
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
size int
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.status = code
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Write(b []byte) (int, error) {
|
||||
size, err := rw.ResponseWriter.Write(b)
|
||||
rw.size += size
|
||||
return size, err
|
||||
}
|
||||
|
||||
func ServerHeader(next http.Handler, version string) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Server", fmt.Sprintf("konduktor/%s", version))
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func AccessLog(next http.Handler, logger *logging.Logger) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
wrapped := &responseWriter{
|
||||
ResponseWriter: w,
|
||||
status: http.StatusOK,
|
||||
}
|
||||
|
||||
next.ServeHTTP(wrapped, r)
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
logger.Info("HTTP request",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", wrapped.status,
|
||||
"duration_ms", duration.Milliseconds(),
|
||||
"client_ip", r.RemoteAddr,
|
||||
"user_agent", r.UserAgent(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
func Recovery(next http.Handler, logger *logging.Logger) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Error("Panic recovered",
|
||||
"error", fmt.Sprintf("%v", err),
|
||||
"stack", string(debug.Stack()),
|
||||
)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
244
go/internal/middleware/middleware_test.go
Normal file
244
go/internal/middleware/middleware_test.go
Normal file
@ -0,0 +1,244 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/konduktor/konduktor/internal/logging"
|
||||
)
|
||||
|
||||
// ============== ServerHeader Tests ==============
|
||||
|
||||
func TestServerHeader(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
wrapped := ServerHeader(handler, "1.0.0")
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
wrapped.ServeHTTP(rr, req)
|
||||
|
||||
serverHeader := rr.Header().Get("Server")
|
||||
if serverHeader != "konduktor/1.0.0" {
|
||||
t.Errorf("Expected Server header 'konduktor/1.0.0', got '%s'", serverHeader)
|
||||
}
|
||||
}
|
||||
|
||||
// ============== AccessLog Tests ==============
|
||||
|
||||
func TestAccessLog(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("Hello"))
|
||||
})
|
||||
|
||||
logger, _ := logging.New(logging.Config{Level: "INFO"})
|
||||
wrapped := AccessLog(handler, logger)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
wrapped.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLog_CapturesStatusCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
}{
|
||||
{"OK", http.StatusOK},
|
||||
{"NotFound", http.StatusNotFound},
|
||||
{"InternalError", http.StatusInternalServerError},
|
||||
{"Redirect", http.StatusMovedPermanently},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(tt.statusCode)
|
||||
})
|
||||
|
||||
logger, _ := logging.New(logging.Config{Level: "INFO"})
|
||||
wrapped := AccessLog(handler, logger)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
wrapped.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != tt.statusCode {
|
||||
t.Errorf("Expected status %d, got %d", tt.statusCode, rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Recovery Tests ==============
|
||||
|
||||
func TestRecovery_NoPanic(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
logger, _ := logging.New(logging.Config{Level: "INFO"})
|
||||
wrapped := Recovery(handler, logger)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
wrapped.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", rr.Code)
|
||||
}
|
||||
|
||||
if rr.Body.String() != "OK" {
|
||||
t.Errorf("Expected body 'OK', got '%s'", rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecovery_WithPanic(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
panic("test panic")
|
||||
})
|
||||
|
||||
logger, _ := logging.New(logging.Config{Level: "ERROR"})
|
||||
wrapped := Recovery(handler, logger)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Should not panic
|
||||
wrapped.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status 500, got %d", rr.Code)
|
||||
}
|
||||
|
||||
if !strings.Contains(rr.Body.String(), "Internal Server Error") {
|
||||
t.Errorf("Expected 'Internal Server Error' in body, got '%s'", rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// ============== responseWriter Tests ==============
|
||||
|
||||
func TestResponseWriter_WriteHeader(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
rw := &responseWriter{ResponseWriter: rr, status: http.StatusOK}
|
||||
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
|
||||
if rw.status != http.StatusNotFound {
|
||||
t.Errorf("Expected status 404, got %d", rw.status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseWriter_Write(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
rw := &responseWriter{ResponseWriter: rr, status: http.StatusOK}
|
||||
|
||||
n, err := rw.Write([]byte("Hello World"))
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if n != 11 {
|
||||
t.Errorf("Expected 11 bytes written, got %d", n)
|
||||
}
|
||||
|
||||
if rw.size != 11 {
|
||||
t.Errorf("Expected size 11, got %d", rw.size)
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Middleware Chain Tests ==============
|
||||
|
||||
func TestMiddlewareChain(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
logger, _ := logging.New(logging.Config{Level: "INFO"})
|
||||
|
||||
// Apply middleware chain
|
||||
wrapped := Recovery(AccessLog(ServerHeader(handler, "1.0.0"), logger), logger)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
wrapped.ServeHTTP(rr, req)
|
||||
|
||||
// Check all middleware worked
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", rr.Code)
|
||||
}
|
||||
|
||||
if rr.Header().Get("Server") != "konduktor/1.0.0" {
|
||||
t.Errorf("Expected Server header")
|
||||
}
|
||||
|
||||
if rr.Body.String() != "OK" {
|
||||
t.Errorf("Expected body 'OK', got '%s'", rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Benchmarks ==============
|
||||
|
||||
func BenchmarkServerHeader(b *testing.B) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
wrapped := ServerHeader(handler, "1.0.0")
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
rr := httptest.NewRecorder()
|
||||
wrapped.ServeHTTP(rr, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAccessLog(b *testing.B) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
logger, _ := logging.New(logging.Config{Level: "ERROR"}) // Minimize logging overhead
|
||||
wrapped := AccessLog(handler, logger)
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
rr := httptest.NewRecorder()
|
||||
wrapped.ServeHTTP(rr, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRecovery(b *testing.B) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
logger, _ := logging.New(logging.Config{Level: "ERROR"})
|
||||
wrapped := Recovery(handler, logger)
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
rr := httptest.NewRecorder()
|
||||
wrapped.ServeHTTP(rr, req)
|
||||
}
|
||||
}
|
||||
263
go/internal/pathmatcher/pathmatcher.go
Normal file
263
go/internal/pathmatcher/pathmatcher.go
Normal file
@ -0,0 +1,263 @@
|
||||
package pathmatcher
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type MountedPath struct {
|
||||
path string
|
||||
name string
|
||||
stripPath bool
|
||||
}
|
||||
|
||||
func NewMountedPath(path string, opts ...MountedPathOption) *MountedPath {
|
||||
// Normalize: remove trailing slash (except for root)
|
||||
normalizedPath := strings.TrimSuffix(path, "/")
|
||||
if normalizedPath == "" {
|
||||
normalizedPath = ""
|
||||
}
|
||||
|
||||
m := &MountedPath{
|
||||
path: normalizedPath,
|
||||
name: normalizedPath,
|
||||
stripPath: true,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(m)
|
||||
}
|
||||
|
||||
if m.name == "" {
|
||||
m.name = normalizedPath
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
type MountedPathOption func(*MountedPath)
|
||||
|
||||
func WithName(name string) MountedPathOption {
|
||||
return func(m *MountedPath) {
|
||||
m.name = name
|
||||
}
|
||||
}
|
||||
|
||||
func WithStripPath(strip bool) MountedPathOption {
|
||||
return func(m *MountedPath) {
|
||||
m.stripPath = strip
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MountedPath) Path() string {
|
||||
return m.path
|
||||
}
|
||||
|
||||
func (m *MountedPath) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MountedPath) StripPath() bool {
|
||||
return m.stripPath
|
||||
}
|
||||
|
||||
func (m *MountedPath) Matches(requestPath string) bool {
|
||||
// Empty or "/" mount matches everything
|
||||
if m.path == "" || m.path == "/" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Request path must be at least as long as mount path
|
||||
if len(requestPath) < len(m.path) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if request path starts with mount path
|
||||
if !strings.HasPrefix(requestPath, m.path) {
|
||||
return false
|
||||
}
|
||||
|
||||
// If paths are equal length, it's a match
|
||||
if len(requestPath) == len(m.path) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Otherwise, next char must be '/' to prevent /api matching /api-v2
|
||||
return requestPath[len(m.path)] == '/'
|
||||
}
|
||||
|
||||
func (m *MountedPath) GetModifiedPath(requestPath string) string {
|
||||
if !m.stripPath {
|
||||
return requestPath
|
||||
}
|
||||
|
||||
// Root mount doesn't strip anything
|
||||
if m.path == "" || m.path == "/" {
|
||||
return requestPath
|
||||
}
|
||||
|
||||
// Strip the prefix
|
||||
modified := strings.TrimPrefix(requestPath, m.path)
|
||||
|
||||
// Ensure result starts with /
|
||||
if modified == "" || modified[0] != '/' {
|
||||
modified = "/" + modified
|
||||
}
|
||||
|
||||
return modified
|
||||
}
|
||||
|
||||
type MountManager struct {
|
||||
mounts []*MountedPath
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMountManager() *MountManager {
|
||||
return &MountManager{
|
||||
mounts: make([]*MountedPath, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (mm *MountManager) AddMount(mount *MountedPath) {
|
||||
mm.mu.Lock()
|
||||
defer mm.mu.Unlock()
|
||||
|
||||
// Insert in sorted order (longer paths first)
|
||||
inserted := false
|
||||
for i, existing := range mm.mounts {
|
||||
if len(mount.path) > len(existing.path) {
|
||||
// Insert at position i
|
||||
mm.mounts = append(mm.mounts[:i], append([]*MountedPath{mount}, mm.mounts[i:]...)...)
|
||||
inserted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !inserted {
|
||||
mm.mounts = append(mm.mounts, mount)
|
||||
}
|
||||
}
|
||||
|
||||
func (mm *MountManager) RemoveMount(path string) bool {
|
||||
mm.mu.Lock()
|
||||
defer mm.mu.Unlock()
|
||||
|
||||
normalizedPath := strings.TrimSuffix(path, "/")
|
||||
|
||||
for i, mount := range mm.mounts {
|
||||
if mount.path == normalizedPath {
|
||||
mm.mounts = append(mm.mounts[:i], mm.mounts[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (mm *MountManager) GetMount(requestPath string) *MountedPath {
|
||||
mm.mu.RLock()
|
||||
defer mm.mu.RUnlock()
|
||||
|
||||
// Mounts are sorted by path length (longest first)
|
||||
// so the first match is the best match
|
||||
for _, mount := range mm.mounts {
|
||||
if mount.Matches(requestPath) {
|
||||
return mount
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mm *MountManager) MountCount() int {
|
||||
mm.mu.RLock()
|
||||
defer mm.mu.RUnlock()
|
||||
return len(mm.mounts)
|
||||
}
|
||||
|
||||
func (mm *MountManager) Mounts() []*MountedPath {
|
||||
mm.mu.RLock()
|
||||
defer mm.mu.RUnlock()
|
||||
|
||||
result := make([]*MountedPath, len(mm.mounts))
|
||||
copy(result, mm.mounts)
|
||||
return result
|
||||
}
|
||||
|
||||
func (mm *MountManager) ListMounts() []map[string]interface{} {
|
||||
mm.mu.RLock()
|
||||
defer mm.mu.RUnlock()
|
||||
|
||||
result := make([]map[string]interface{}, len(mm.mounts))
|
||||
for i, mount := range mm.mounts {
|
||||
result[i] = map[string]interface{}{
|
||||
"path": mount.path,
|
||||
"name": mount.name,
|
||||
"strip_path": mount.stripPath,
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Utility functions
|
||||
|
||||
func PathMatchesPrefix(requestPath, prefix string) bool {
|
||||
// Normalize prefix
|
||||
prefix = strings.TrimSuffix(prefix, "/")
|
||||
|
||||
// Empty or "/" prefix matches everything
|
||||
if prefix == "" || prefix == "/" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Request path must be at least as long as prefix
|
||||
if len(requestPath) < len(prefix) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if request path starts with prefix
|
||||
if !strings.HasPrefix(requestPath, prefix) {
|
||||
return false
|
||||
}
|
||||
|
||||
// If paths are equal length, it's a match
|
||||
if len(requestPath) == len(prefix) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Otherwise, next char must be '/'
|
||||
return requestPath[len(prefix)] == '/'
|
||||
}
|
||||
|
||||
func StripPathPrefix(requestPath, prefix string) string {
|
||||
// Normalize prefix
|
||||
prefix = strings.TrimSuffix(prefix, "/")
|
||||
|
||||
// Empty or "/" prefix doesn't strip anything
|
||||
if prefix == "" || prefix == "/" {
|
||||
return requestPath
|
||||
}
|
||||
|
||||
// Strip the prefix
|
||||
modified := strings.TrimPrefix(requestPath, prefix)
|
||||
|
||||
// Ensure result starts with /
|
||||
if modified == "" || modified[0] != '/' {
|
||||
modified = "/" + modified
|
||||
}
|
||||
|
||||
return modified
|
||||
}
|
||||
|
||||
func MatchAndModifyPath(requestPath, prefix string, stripPath bool) (matches bool, modifiedPath string) {
|
||||
if !PathMatchesPrefix(requestPath, prefix) {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
if stripPath {
|
||||
return true, StripPathPrefix(requestPath, prefix)
|
||||
}
|
||||
|
||||
return true, requestPath
|
||||
}
|
||||
460
go/internal/pathmatcher/pathmatcher_test.go
Normal file
460
go/internal/pathmatcher/pathmatcher_test.go
Normal file
@ -0,0 +1,460 @@
|
||||
package pathmatcher
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// ============== MountedPath Tests ==============
|
||||
|
||||
func TestMountedPath_RootMountMatchesEverything(t *testing.T) {
|
||||
mount := NewMountedPath("")
|
||||
|
||||
tests := []string{"/", "/api", "/api/users", "/anything/at/all"}
|
||||
|
||||
for _, path := range tests {
|
||||
if !mount.Matches(path) {
|
||||
t.Errorf("Root mount should match %s", path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMountedPath_SlashRootMountMatchesEverything(t *testing.T) {
|
||||
mount := NewMountedPath("/")
|
||||
|
||||
tests := []string{"/", "/api", "/api/users"}
|
||||
|
||||
for _, path := range tests {
|
||||
if !mount.Matches(path) {
|
||||
t.Errorf("'/' mount should match %s", path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMountedPath_ExactPathMatch(t *testing.T) {
|
||||
mount := NewMountedPath("/api")
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
expected bool
|
||||
}{
|
||||
{"/api", true},
|
||||
{"/api/", true},
|
||||
{"/api/users", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := mount.Matches(tt.path); got != tt.expected {
|
||||
t.Errorf("Matches(%s) = %v, want %v", tt.path, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMountedPath_NoFalsePrefixMatch(t *testing.T) {
|
||||
mount := NewMountedPath("/api")
|
||||
|
||||
tests := []string{"/api-v2", "/api2", "/apiv2"}
|
||||
|
||||
for _, path := range tests {
|
||||
if mount.Matches(path) {
|
||||
t.Errorf("/api should not match %s", path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMountedPath_ShorterPathNoMatch(t *testing.T) {
|
||||
mount := NewMountedPath("/api/v1")
|
||||
|
||||
tests := []string{"/api", "/ap", "/"}
|
||||
|
||||
for _, path := range tests {
|
||||
if mount.Matches(path) {
|
||||
t.Errorf("/api/v1 should not match shorter path %s", path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMountedPath_TrailingSlashNormalized(t *testing.T) {
|
||||
mount1 := NewMountedPath("/api/")
|
||||
mount2 := NewMountedPath("/api")
|
||||
|
||||
if mount1.Path() != "/api" {
|
||||
t.Errorf("Expected path /api, got %s", mount1.Path())
|
||||
}
|
||||
|
||||
if mount2.Path() != "/api" {
|
||||
t.Errorf("Expected path /api, got %s", mount2.Path())
|
||||
}
|
||||
|
||||
if !mount1.Matches("/api/users") {
|
||||
t.Error("mount1 should match /api/users")
|
||||
}
|
||||
|
||||
if !mount2.Matches("/api/users") {
|
||||
t.Error("mount2 should match /api/users")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMountedPath_GetModifiedPathStripsPrefix(t *testing.T) {
|
||||
mount := NewMountedPath("/api")
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"/api", "/"},
|
||||
{"/api/", "/"},
|
||||
{"/api/users", "/users"},
|
||||
{"/api/users/123", "/users/123"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := mount.GetModifiedPath(tt.input); got != tt.expected {
|
||||
t.Errorf("GetModifiedPath(%s) = %s, want %s", tt.input, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMountedPath_GetModifiedPathNoStrip(t *testing.T) {
|
||||
mount := NewMountedPath("/api", WithStripPath(false))
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"/api/users", "/api/users"},
|
||||
{"/api", "/api"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := mount.GetModifiedPath(tt.input); got != tt.expected {
|
||||
t.Errorf("GetModifiedPath(%s) = %s, want %s", tt.input, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMountedPath_RootMountModifiedPath(t *testing.T) {
|
||||
mount := NewMountedPath("")
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"/api/users", "/api/users"},
|
||||
{"/", "/"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := mount.GetModifiedPath(tt.input); got != tt.expected {
|
||||
t.Errorf("GetModifiedPath(%s) = %s, want %s", tt.input, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMountedPath_NameProperty(t *testing.T) {
|
||||
mount1 := NewMountedPath("/api")
|
||||
mount2 := NewMountedPath("/api", WithName("API Mount"))
|
||||
|
||||
if mount1.Name() != "/api" {
|
||||
t.Errorf("Expected name /api, got %s", mount1.Name())
|
||||
}
|
||||
|
||||
if mount2.Name() != "API Mount" {
|
||||
t.Errorf("Expected name 'API Mount', got %s", mount2.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// ============== MountManager Tests ==============
|
||||
|
||||
func TestMountManager_EmptyManager(t *testing.T) {
|
||||
manager := NewMountManager()
|
||||
|
||||
if got := manager.GetMount("/api"); got != nil {
|
||||
t.Error("Empty manager should return nil")
|
||||
}
|
||||
|
||||
if got := manager.MountCount(); got != 0 {
|
||||
t.Errorf("Expected mount count 0, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMountManager_AddMount(t *testing.T) {
|
||||
manager := NewMountManager()
|
||||
mount := NewMountedPath("/api")
|
||||
|
||||
manager.AddMount(mount)
|
||||
|
||||
if manager.MountCount() != 1 {
|
||||
t.Errorf("Expected mount count 1, got %d", manager.MountCount())
|
||||
}
|
||||
|
||||
if got := manager.GetMount("/api/users"); got != mount {
|
||||
t.Error("GetMount should return the added mount")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMountManager_LongestPrefixMatching(t *testing.T) {
|
||||
manager := NewMountManager()
|
||||
|
||||
apiMount := NewMountedPath("/api", WithName("api"))
|
||||
apiV1Mount := NewMountedPath("/api/v1", WithName("api_v1"))
|
||||
apiV2Mount := NewMountedPath("/api/v2", WithName("api_v2"))
|
||||
|
||||
manager.AddMount(apiMount)
|
||||
manager.AddMount(apiV2Mount)
|
||||
manager.AddMount(apiV1Mount)
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
expectedName string
|
||||
}{
|
||||
{"/api/v1/users", "api_v1"},
|
||||
{"/api/v2/items", "api_v2"},
|
||||
{"/api/v3/other", "api"},
|
||||
{"/api", "api"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := manager.GetMount(tt.path)
|
||||
if got == nil {
|
||||
t.Errorf("GetMount(%s) returned nil, want mount with name %s", tt.path, tt.expectedName)
|
||||
continue
|
||||
}
|
||||
if got.Name() != tt.expectedName {
|
||||
t.Errorf("GetMount(%s).Name() = %s, want %s", tt.path, got.Name(), tt.expectedName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMountManager_RemoveMount(t *testing.T) {
|
||||
manager := NewMountManager()
|
||||
|
||||
manager.AddMount(NewMountedPath("/api"))
|
||||
manager.AddMount(NewMountedPath("/admin"))
|
||||
|
||||
if manager.MountCount() != 2 {
|
||||
t.Errorf("Expected mount count 2, got %d", manager.MountCount())
|
||||
}
|
||||
|
||||
result := manager.RemoveMount("/api")
|
||||
|
||||
if !result {
|
||||
t.Error("RemoveMount should return true")
|
||||
}
|
||||
|
||||
if manager.MountCount() != 1 {
|
||||
t.Errorf("Expected mount count 1, got %d", manager.MountCount())
|
||||
}
|
||||
|
||||
if manager.GetMount("/api/users") != nil {
|
||||
t.Error("GetMount(/api/users) should return nil after removal")
|
||||
}
|
||||
|
||||
if manager.GetMount("/admin/users") == nil {
|
||||
t.Error("GetMount(/admin/users) should still work")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMountManager_RemoveNonexistentMount(t *testing.T) {
|
||||
manager := NewMountManager()
|
||||
|
||||
result := manager.RemoveMount("/api")
|
||||
|
||||
if result {
|
||||
t.Error("RemoveMount should return false for nonexistent mount")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMountManager_ListMounts(t *testing.T) {
|
||||
manager := NewMountManager()
|
||||
|
||||
manager.AddMount(NewMountedPath("/api", WithName("API")))
|
||||
manager.AddMount(NewMountedPath("/admin", WithName("Admin")))
|
||||
|
||||
mounts := manager.ListMounts()
|
||||
|
||||
if len(mounts) != 2 {
|
||||
t.Errorf("Expected 2 mounts, got %d", len(mounts))
|
||||
}
|
||||
|
||||
for _, m := range mounts {
|
||||
if _, ok := m["path"]; !ok {
|
||||
t.Error("Mount should have 'path' key")
|
||||
}
|
||||
if _, ok := m["name"]; !ok {
|
||||
t.Error("Mount should have 'name' key")
|
||||
}
|
||||
if _, ok := m["strip_path"]; !ok {
|
||||
t.Error("Mount should have 'strip_path' key")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMountManager_MountsReturnsCopy(t *testing.T) {
|
||||
manager := NewMountManager()
|
||||
manager.AddMount(NewMountedPath("/api"))
|
||||
|
||||
mounts1 := manager.Mounts()
|
||||
mounts2 := manager.Mounts()
|
||||
|
||||
if &mounts1[0] == &mounts2[0] {
|
||||
t.Error("Mounts() should return different slices")
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Utility Functions Tests ==============
|
||||
|
||||
func TestPathMatchesPrefix_Basic(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
prefix string
|
||||
expected bool
|
||||
}{
|
||||
{"/api/users", "/api", true},
|
||||
{"/api", "/api", true},
|
||||
{"/api-v2", "/api", false},
|
||||
{"/ap", "/api", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := PathMatchesPrefix(tt.path, tt.prefix); got != tt.expected {
|
||||
t.Errorf("PathMatchesPrefix(%s, %s) = %v, want %v", tt.path, tt.prefix, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathMatchesPrefix_Root(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
prefix string
|
||||
expected bool
|
||||
}{
|
||||
{"/anything", "", true},
|
||||
{"/anything", "/", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := PathMatchesPrefix(tt.path, tt.prefix); got != tt.expected {
|
||||
t.Errorf("PathMatchesPrefix(%s, %s) = %v, want %v", tt.path, tt.prefix, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripPathPrefix_Basic(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
prefix string
|
||||
expected string
|
||||
}{
|
||||
{"/api/users", "/api", "/users"},
|
||||
{"/api", "/api", "/"},
|
||||
{"/api/", "/api", "/"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := StripPathPrefix(tt.path, tt.prefix); got != tt.expected {
|
||||
t.Errorf("StripPathPrefix(%s, %s) = %s, want %s", tt.path, tt.prefix, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripPathPrefix_Root(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
prefix string
|
||||
expected string
|
||||
}{
|
||||
{"/api/users", "", "/api/users"},
|
||||
{"/api/users", "/", "/api/users"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := StripPathPrefix(tt.path, tt.prefix); got != tt.expected {
|
||||
t.Errorf("StripPathPrefix(%s, %s) = %s, want %s", tt.path, tt.prefix, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchAndModifyPath_Combined(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
prefix string
|
||||
stripPath bool
|
||||
wantMatches bool
|
||||
wantModified string
|
||||
}{
|
||||
{"/api/users", "/api", true, true, "/users"},
|
||||
{"/api", "/api", true, true, "/"},
|
||||
{"/other", "/api", true, false, ""},
|
||||
{"/api/users", "/api", false, true, "/api/users"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
matches, modified := MatchAndModifyPath(tt.path, tt.prefix, tt.stripPath)
|
||||
if matches != tt.wantMatches {
|
||||
t.Errorf("MatchAndModifyPath(%s, %s, %v) matches = %v, want %v",
|
||||
tt.path, tt.prefix, tt.stripPath, matches, tt.wantMatches)
|
||||
}
|
||||
if modified != tt.wantModified {
|
||||
t.Errorf("MatchAndModifyPath(%s, %s, %v) modified = %s, want %s",
|
||||
tt.path, tt.prefix, tt.stripPath, modified, tt.wantModified)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Performance Tests ==============
|
||||
|
||||
func TestPerformance_ManyMatches(t *testing.T) {
|
||||
mount := NewMountedPath("/api/v1/users")
|
||||
|
||||
for i := 0; i < 10000; i++ {
|
||||
if !mount.Matches("/api/v1/users/123/posts") {
|
||||
t.Fatal("Should match")
|
||||
}
|
||||
if mount.Matches("/other/path") {
|
||||
t.Fatal("Should not match")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerformance_ManyMounts(t *testing.T) {
|
||||
manager := NewMountManager()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
manager.AddMount(NewMountedPath("/api/v" + string(rune('0'+i%10)) + string(rune('0'+i/10))))
|
||||
}
|
||||
|
||||
if manager.MountCount() != 100 {
|
||||
t.Errorf("Expected 100 mounts, got %d", manager.MountCount())
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Benchmarks ==============
|
||||
|
||||
func BenchmarkMountedPath_Matches(b *testing.B) {
|
||||
mount := NewMountedPath("/api/v1")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
mount.Matches("/api/v1/users/123")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMountManager_GetMount(b *testing.B) {
|
||||
manager := NewMountManager()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
manager.AddMount(NewMountedPath("/api/v" + string(rune('0'+i%10))))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.GetMount("/api/v5/users/123")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPathMatchesPrefix(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
PathMatchesPrefix("/api/v1/users/123", "/api/v1")
|
||||
}
|
||||
}
|
||||
320
go/internal/proxy/proxy.go
Normal file
320
go/internal/proxy/proxy.go
Normal file
@ -0,0 +1,320 @@
|
||||
// Package proxy provides reverse proxy functionality for Konduktor
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/konduktor/konduktor/internal/logging"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
// Target is the backend server URL
|
||||
Target string
|
||||
|
||||
// Timeout is the request timeout (default: 30s)
|
||||
Timeout time.Duration
|
||||
|
||||
// Headers are additional headers to add to requests
|
||||
Headers map[string]string
|
||||
|
||||
// StripPrefix removes this prefix from the request path
|
||||
StripPrefix string
|
||||
|
||||
// PreserveHost keeps the original Host header
|
||||
PreserveHost bool
|
||||
}
|
||||
|
||||
type ReverseProxy struct {
|
||||
config *Config
|
||||
targetURL *url.URL
|
||||
httpClient *http.Client
|
||||
logger *logging.Logger
|
||||
}
|
||||
|
||||
func New(cfg *Config, logger *logging.Logger) (*ReverseProxy, error) {
|
||||
if cfg.Target == "" {
|
||||
return nil, fmt.Errorf("proxy target is required")
|
||||
}
|
||||
|
||||
targetURL, err := url.Parse(cfg.Target)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proxy target URL: %w", err)
|
||||
}
|
||||
|
||||
timeout := cfg.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ResponseHeaderTimeout: timeout,
|
||||
}
|
||||
|
||||
return &ReverseProxy{
|
||||
config: cfg,
|
||||
targetURL: targetURL,
|
||||
httpClient: &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: timeout,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse // Don't follow redirects
|
||||
},
|
||||
},
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (rp *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
rp.ProxyRequest(w, r, nil)
|
||||
}
|
||||
|
||||
func (rp *ReverseProxy) ProxyRequest(w http.ResponseWriter, r *http.Request, params map[string]string) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Build target URL
|
||||
targetURL := rp.buildTargetURL(r)
|
||||
|
||||
// Create proxy request
|
||||
proxyReq, err := rp.createProxyRequest(ctx, r, targetURL)
|
||||
if err != nil {
|
||||
rp.handleError(w, http.StatusInternalServerError, "Failed to create proxy request", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Add custom headers with parameter substitution
|
||||
rp.addCustomHeaders(proxyReq, r, params)
|
||||
|
||||
// Execute request
|
||||
resp, err := rp.httpClient.Do(proxyReq)
|
||||
if err != nil {
|
||||
rp.handleProxyError(w, err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Copy response
|
||||
rp.copyResponse(w, resp)
|
||||
}
|
||||
|
||||
func (rp *ReverseProxy) buildTargetURL(r *http.Request) *url.URL {
|
||||
targetURL := *rp.targetURL
|
||||
|
||||
// Strip prefix if configured
|
||||
path := r.URL.Path
|
||||
if rp.config.StripPrefix != "" {
|
||||
path = strings.TrimPrefix(path, rp.config.StripPrefix)
|
||||
if path == "" || path[0] != '/' {
|
||||
path = "/" + path
|
||||
}
|
||||
}
|
||||
|
||||
// If target has a path, append request path to it
|
||||
if rp.targetURL.Path != "" && rp.targetURL.Path != "/" {
|
||||
targetURL.Path = singleJoiningSlash(rp.targetURL.Path, path)
|
||||
} else {
|
||||
targetURL.Path = path
|
||||
}
|
||||
|
||||
// Preserve query string
|
||||
targetURL.RawQuery = r.URL.RawQuery
|
||||
|
||||
return &targetURL
|
||||
}
|
||||
|
||||
func (rp *ReverseProxy) createProxyRequest(ctx context.Context, r *http.Request, targetURL *url.URL) (*http.Request, error) {
|
||||
proxyReq, err := http.NewRequestWithContext(ctx, r.Method, targetURL.String(), r.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Copy ContentLength
|
||||
proxyReq.ContentLength = r.ContentLength
|
||||
|
||||
// Copy headers
|
||||
for key, values := range r.Header {
|
||||
for _, value := range values {
|
||||
proxyReq.Header.Add(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Set/update Host header
|
||||
if rp.config.PreserveHost {
|
||||
proxyReq.Host = r.Host
|
||||
} else {
|
||||
proxyReq.Host = targetURL.Host
|
||||
}
|
||||
|
||||
// Remove hop-by-hop headers
|
||||
removeHopByHopHeaders(proxyReq.Header)
|
||||
|
||||
return proxyReq, nil
|
||||
}
|
||||
|
||||
func (rp *ReverseProxy) addCustomHeaders(proxyReq *http.Request, originalReq *http.Request, params map[string]string) {
|
||||
// Add X-Forwarded headers
|
||||
clientIP := getClientIP(originalReq)
|
||||
if prior := originalReq.Header.Get("X-Forwarded-For"); prior != "" {
|
||||
clientIP = prior + ", " + clientIP
|
||||
}
|
||||
proxyReq.Header.Set("X-Forwarded-For", clientIP)
|
||||
proxyReq.Header.Set("X-Forwarded-Proto", getScheme(originalReq))
|
||||
proxyReq.Header.Set("X-Forwarded-Host", originalReq.Host)
|
||||
|
||||
// Add custom headers from config
|
||||
for key, value := range rp.config.Headers {
|
||||
// Substitute parameters like {version}
|
||||
substituted := value
|
||||
for paramKey, paramValue := range params {
|
||||
substituted = strings.ReplaceAll(substituted, "{"+paramKey+"}", paramValue)
|
||||
}
|
||||
// Substitute $remote_addr
|
||||
substituted = strings.ReplaceAll(substituted, "$remote_addr", clientIP)
|
||||
proxyReq.Header.Set(key, substituted)
|
||||
}
|
||||
}
|
||||
|
||||
func (rp *ReverseProxy) copyResponse(w http.ResponseWriter, resp *http.Response) {
|
||||
// Copy headers
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
w.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove hop-by-hop headers from response
|
||||
removeHopByHopHeaders(w.Header())
|
||||
|
||||
// Write status code
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
// Copy body
|
||||
io.Copy(w, resp.Body)
|
||||
}
|
||||
|
||||
func (rp *ReverseProxy) handleError(w http.ResponseWriter, status int, message string, err error) {
|
||||
if rp.logger != nil {
|
||||
rp.logger.Error(message, "error", err)
|
||||
}
|
||||
http.Error(w, message, status)
|
||||
}
|
||||
|
||||
func (rp *ReverseProxy) handleProxyError(w http.ResponseWriter, err error) {
|
||||
if rp.logger != nil {
|
||||
rp.logger.Error("Proxy request failed", "error", err)
|
||||
}
|
||||
|
||||
// Check for timeout
|
||||
if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||
http.Error(w, "504 Gateway Timeout", http.StatusGatewayTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for connection errors
|
||||
if isConnectionError(err) {
|
||||
http.Error(w, "502 Bad Gateway", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
// Context cancelled (client disconnected)
|
||||
if err == context.Canceled {
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "502 Bad Gateway", http.StatusBadGateway)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func singleJoiningSlash(a, b string) string {
|
||||
aslash := strings.HasSuffix(a, "/")
|
||||
bslash := strings.HasPrefix(b, "/")
|
||||
switch {
|
||||
case aslash && bslash:
|
||||
return a + b[1:]
|
||||
case !aslash && !bslash:
|
||||
return a + "/" + b
|
||||
}
|
||||
return a + b
|
||||
}
|
||||
|
||||
func removeHopByHopHeaders(h http.Header) {
|
||||
hopByHopHeaders := []string{
|
||||
"Connection",
|
||||
"Proxy-Connection",
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te",
|
||||
"Trailer",
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
|
||||
for _, header := range hopByHopHeaders {
|
||||
h.Del(header)
|
||||
}
|
||||
}
|
||||
|
||||
func getClientIP(r *http.Request) string {
|
||||
// Check X-Real-IP first
|
||||
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
||||
return ip
|
||||
}
|
||||
|
||||
// Get from RemoteAddr
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
func getScheme(r *http.Request) string {
|
||||
if r.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
if scheme := r.Header.Get("X-Forwarded-Proto"); scheme != "" {
|
||||
return scheme
|
||||
}
|
||||
return "http"
|
||||
}
|
||||
|
||||
func isConnectionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
connectionErrors := []string{
|
||||
"connection refused",
|
||||
"no such host",
|
||||
"network is unreachable",
|
||||
"connection reset",
|
||||
"broken pipe",
|
||||
}
|
||||
|
||||
for _, connErr := range connectionErrors {
|
||||
if strings.Contains(strings.ToLower(errStr), connErr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
747
go/internal/proxy/proxy_test.go
Normal file
747
go/internal/proxy/proxy_test.go
Normal file
@ -0,0 +1,747 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ============== Test Backend Server ==============
|
||||
|
||||
type testBackend struct {
|
||||
server *httptest.Server
|
||||
requestLog []requestLogEntry
|
||||
mu sync.Mutex
|
||||
requestCount int64
|
||||
}
|
||||
|
||||
type requestLogEntry struct {
|
||||
Method string
|
||||
Path string
|
||||
Query string
|
||||
Headers http.Header
|
||||
Body string
|
||||
}
|
||||
|
||||
func newTestBackend(handler http.HandlerFunc) *testBackend {
|
||||
tb := &testBackend{
|
||||
requestLog: make([]requestLogEntry, 0),
|
||||
}
|
||||
|
||||
if handler == nil {
|
||||
handler = tb.defaultHandler
|
||||
}
|
||||
|
||||
tb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tb.logRequest(r)
|
||||
handler(w, r)
|
||||
}))
|
||||
|
||||
return tb
|
||||
}
|
||||
|
||||
func (tb *testBackend) logRequest(r *http.Request) {
|
||||
tb.mu.Lock()
|
||||
defer tb.mu.Unlock()
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
// Restore the body for the handler
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
|
||||
tb.requestLog = append(tb.requestLog, requestLogEntry{
|
||||
Method: r.Method,
|
||||
Path: r.URL.Path,
|
||||
Query: r.URL.RawQuery,
|
||||
Headers: r.Header.Clone(),
|
||||
Body: string(body),
|
||||
})
|
||||
atomic.AddInt64(&tb.requestCount, 1)
|
||||
}
|
||||
|
||||
func (tb *testBackend) defaultHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"message": "Backend response",
|
||||
"path": r.URL.Path,
|
||||
"method": r.Method,
|
||||
})
|
||||
}
|
||||
|
||||
func (tb *testBackend) close() {
|
||||
tb.server.Close()
|
||||
}
|
||||
|
||||
func (tb *testBackend) URL() string {
|
||||
return tb.server.URL
|
||||
}
|
||||
|
||||
func (tb *testBackend) getRequestCount() int64 {
|
||||
return atomic.LoadInt64(&tb.requestCount)
|
||||
}
|
||||
|
||||
func (tb *testBackend) getLastRequest() *requestLogEntry {
|
||||
tb.mu.Lock()
|
||||
defer tb.mu.Unlock()
|
||||
if len(tb.requestLog) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &tb.requestLog[len(tb.requestLog)-1]
|
||||
}
|
||||
|
||||
// ============== Proxy Creation Tests ==============
|
||||
|
||||
func TestNew_ValidConfig(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Target: "http://localhost:8080",
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
proxy, err := New(cfg, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy: %v", err)
|
||||
}
|
||||
|
||||
if proxy == nil {
|
||||
t.Fatal("Expected proxy instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_EmptyTarget(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Target: "",
|
||||
}
|
||||
|
||||
_, err := New(cfg, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for empty target")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_InvalidTargetURL(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Target: "://invalid-url",
|
||||
}
|
||||
|
||||
_, err := New(cfg, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_DefaultTimeout(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Target: "http://localhost:8080",
|
||||
}
|
||||
|
||||
proxy, err := New(cfg, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy: %v", err)
|
||||
}
|
||||
|
||||
if proxy.httpClient.Timeout != 30*time.Second {
|
||||
t.Errorf("Expected default timeout 30s, got %v", proxy.httpClient.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Basic Proxy Tests ==============
|
||||
|
||||
func TestProxy_BasicGET(t *testing.T) {
|
||||
backend := newTestBackend(nil)
|
||||
defer backend.close()
|
||||
|
||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", rr.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
json.NewDecoder(rr.Body).Decode(&response)
|
||||
|
||||
if response["path"] != "/test" {
|
||||
t.Errorf("Expected path /test, got %v", response["path"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_BasicPOST(t *testing.T) {
|
||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"received": string(body),
|
||||
"method": r.Method,
|
||||
})
|
||||
})
|
||||
defer backend.close()
|
||||
|
||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/data", strings.NewReader(`{"key":"value"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", rr.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
json.NewDecoder(rr.Body).Decode(&response)
|
||||
|
||||
if response["method"] != "POST" {
|
||||
t.Errorf("Expected method POST, got %v", response["method"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_PUT(t *testing.T) {
|
||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"method": r.Method})
|
||||
})
|
||||
defer backend.close()
|
||||
|
||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||
|
||||
req := httptest.NewRequest("PUT", "/resource/123", strings.NewReader(`{"name":"updated"}`))
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(rr, req)
|
||||
|
||||
var response map[string]string
|
||||
json.NewDecoder(rr.Body).Decode(&response)
|
||||
|
||||
if response["method"] != "PUT" {
|
||||
t.Errorf("Expected method PUT, got %v", response["method"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_DELETE(t *testing.T) {
|
||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"method": r.Method})
|
||||
})
|
||||
defer backend.close()
|
||||
|
||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/resource/123", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(rr, req)
|
||||
|
||||
var response map[string]string
|
||||
json.NewDecoder(rr.Body).Decode(&response)
|
||||
|
||||
if response["method"] != "DELETE" {
|
||||
t.Errorf("Expected method DELETE, got %v", response["method"])
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Header Tests ==============
|
||||
|
||||
func TestProxy_HeadersForwarding(t *testing.T) {
|
||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"custom_header": r.Header.Get("X-Custom-Header"),
|
||||
"forwarded_for": r.Header.Get("X-Forwarded-For"),
|
||||
"forwarded_host": r.Header.Get("X-Forwarded-Host"),
|
||||
})
|
||||
})
|
||||
defer backend.close()
|
||||
|
||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/headers", nil)
|
||||
req.Header.Set("X-Custom-Header", "test-value")
|
||||
req.RemoteAddr = "192.168.1.100:12345"
|
||||
req.Host = "example.com"
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(rr, req)
|
||||
|
||||
var response map[string]interface{}
|
||||
json.NewDecoder(rr.Body).Decode(&response)
|
||||
|
||||
if response["custom_header"] != "test-value" {
|
||||
t.Errorf("Expected custom header, got %v", response["custom_header"])
|
||||
}
|
||||
|
||||
if response["forwarded_for"] != "192.168.1.100" {
|
||||
t.Errorf("Expected X-Forwarded-For, got %v", response["forwarded_for"])
|
||||
}
|
||||
|
||||
if response["forwarded_host"] != "example.com" {
|
||||
t.Errorf("Expected X-Forwarded-Host, got %v", response["forwarded_host"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_CustomHeaders(t *testing.T) {
|
||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"api_version": r.Header.Get("X-API-Version"),
|
||||
})
|
||||
})
|
||||
defer backend.close()
|
||||
|
||||
proxy, _ := New(&Config{
|
||||
Target: backend.URL(),
|
||||
Headers: map[string]string{
|
||||
"X-API-Version": "{version}",
|
||||
},
|
||||
}, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Simulate parameter substitution
|
||||
proxy.ProxyRequest(rr, req, map[string]string{"version": "2"})
|
||||
|
||||
var response map[string]string
|
||||
json.NewDecoder(rr.Body).Decode(&response)
|
||||
|
||||
if response["api_version"] != "2" {
|
||||
t.Errorf("Expected API version 2, got %v", response["api_version"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_RemoteAddrSubstitution(t *testing.T) {
|
||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"client_ip": r.Header.Get("X-Client-IP"),
|
||||
})
|
||||
})
|
||||
defer backend.close()
|
||||
|
||||
proxy, _ := New(&Config{
|
||||
Target: backend.URL(),
|
||||
Headers: map[string]string{
|
||||
"X-Client-IP": "$remote_addr",
|
||||
},
|
||||
}, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api", nil)
|
||||
req.RemoteAddr = "10.0.0.1:54321"
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(rr, req)
|
||||
|
||||
var response map[string]string
|
||||
json.NewDecoder(rr.Body).Decode(&response)
|
||||
|
||||
if response["client_ip"] != "10.0.0.1" {
|
||||
t.Errorf("Expected client IP 10.0.0.1, got %v", response["client_ip"])
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Query String Tests ==============
|
||||
|
||||
func TestProxy_QueryStringPreservation(t *testing.T) {
|
||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"query": r.URL.RawQuery,
|
||||
"param": r.URL.Query().Get("key"),
|
||||
})
|
||||
})
|
||||
defer backend.close()
|
||||
|
||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/search?key=value&page=2", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(rr, req)
|
||||
|
||||
var response map[string]string
|
||||
json.NewDecoder(rr.Body).Decode(&response)
|
||||
|
||||
if response["param"] != "value" {
|
||||
t.Errorf("Expected query param 'value', got %v", response["param"])
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Status Code Tests ==============
|
||||
|
||||
func TestProxy_StatusCodePreservation(t *testing.T) {
|
||||
statusCodes := []int{200, 201, 400, 404, 500}
|
||||
|
||||
for _, code := range statusCodes {
|
||||
code := code // capture range variable
|
||||
t.Run(fmt.Sprintf("Status_%d", code), func(t *testing.T) {
|
||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(code)
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
defer backend.close()
|
||||
|
||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/status", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != code {
|
||||
t.Errorf("Expected status %d, got %d", code, rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Error Handling Tests ==============
|
||||
|
||||
func TestProxy_BackendUnavailable(t *testing.T) {
|
||||
// Use a port that's definitely not listening
|
||||
proxy, _ := New(&Config{
|
||||
Target: "http://127.0.0.1:59999",
|
||||
Timeout: 1 * time.Second,
|
||||
}, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusBadGateway {
|
||||
t.Errorf("Expected status 502 Bad Gateway, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_Timeout(t *testing.T) {
|
||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(2 * time.Second)
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
defer backend.close()
|
||||
|
||||
proxy, _ := New(&Config{
|
||||
Target: backend.URL(),
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/slow", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusGatewayTimeout {
|
||||
t.Errorf("Expected status 504 Gateway Timeout, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Path Handling Tests ==============
|
||||
|
||||
func TestProxy_StripPrefix(t *testing.T) {
|
||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"path": r.URL.Path,
|
||||
})
|
||||
})
|
||||
defer backend.close()
|
||||
|
||||
proxy, _ := New(&Config{
|
||||
Target: backend.URL(),
|
||||
StripPrefix: "/api/v1",
|
||||
}, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/users", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(rr, req)
|
||||
|
||||
var response map[string]string
|
||||
json.NewDecoder(rr.Body).Decode(&response)
|
||||
|
||||
if response["path"] != "/users" {
|
||||
t.Errorf("Expected stripped path /users, got %v", response["path"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_TargetWithPath(t *testing.T) {
|
||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"path": r.URL.Path,
|
||||
})
|
||||
})
|
||||
defer backend.close()
|
||||
|
||||
proxy, _ := New(&Config{
|
||||
Target: backend.URL() + "/backend",
|
||||
}, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/resource", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(rr, req)
|
||||
|
||||
var response map[string]string
|
||||
json.NewDecoder(rr.Body).Decode(&response)
|
||||
|
||||
if response["path"] != "/backend/resource" {
|
||||
t.Errorf("Expected path /backend/resource, got %v", response["path"])
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Large Body Tests ==============
|
||||
|
||||
func TestProxy_LargeRequestBody(t *testing.T) {
|
||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
json.NewEncoder(w).Encode(map[string]int{
|
||||
"received_bytes": len(body),
|
||||
})
|
||||
})
|
||||
defer backend.close()
|
||||
|
||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||
|
||||
// 100KB body
|
||||
largeBody := strings.Repeat("x", 100000)
|
||||
req := httptest.NewRequest("POST", "/upload", strings.NewReader(largeBody))
|
||||
req.ContentLength = int64(len(largeBody))
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(rr, req)
|
||||
|
||||
var response map[string]int
|
||||
json.NewDecoder(rr.Body).Decode(&response)
|
||||
|
||||
if response["received_bytes"] != 100000 {
|
||||
t.Errorf("Expected 100000 bytes, got %d", response["received_bytes"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_LargeResponseBody(t *testing.T) {
|
||||
largeResponse := strings.Repeat("y", 100000)
|
||||
|
||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(largeResponse))
|
||||
})
|
||||
defer backend.close()
|
||||
|
||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/large", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Body.Len() != 100000 {
|
||||
t.Errorf("Expected 100000 bytes in response, got %d", rr.Body.Len())
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Concurrent Requests Tests ==============
|
||||
|
||||
func TestProxy_ConcurrentRequests(t *testing.T) {
|
||||
backend := newTestBackend(nil)
|
||||
defer backend.close()
|
||||
|
||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||
|
||||
const numRequests = 50
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, numRequests)
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
|
||||
req := httptest.NewRequest("GET", "/concurrent", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
errors <- &net.OpError{Op: "test", Err: context.DeadlineExceeded}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
errorCount := 0
|
||||
for range errors {
|
||||
errorCount++
|
||||
}
|
||||
|
||||
if errorCount > 0 {
|
||||
t.Errorf("Got %d errors in concurrent requests", errorCount)
|
||||
}
|
||||
|
||||
if backend.getRequestCount() != numRequests {
|
||||
t.Errorf("Expected %d requests at backend, got %d", numRequests, backend.getRequestCount())
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Echo Tests ==============
|
||||
|
||||
func TestProxy_Echo(t *testing.T) {
|
||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
w.Header().Set("Content-Type", r.Header.Get("Content-Type"))
|
||||
w.Write(body)
|
||||
})
|
||||
defer backend.close()
|
||||
|
||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||
|
||||
testData := "Hello, Proxy!"
|
||||
req := httptest.NewRequest("POST", "/echo", strings.NewReader(testData))
|
||||
req.Header.Set("Content-Type", "text/plain")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Body.String() != testData {
|
||||
t.Errorf("Expected echo of '%s', got '%s'", testData, rr.Body.String())
|
||||
}
|
||||
|
||||
if rr.Header().Get("Content-Type") != "text/plain" {
|
||||
t.Errorf("Expected Content-Type text/plain, got %s", rr.Header().Get("Content-Type"))
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Helper Function Tests ==============
|
||||
|
||||
func TestSingleJoiningSlash(t *testing.T) {
|
||||
tests := []struct {
|
||||
a, b, expected string
|
||||
}{
|
||||
{"/api", "/users", "/api/users"},
|
||||
{"/api/", "/users", "/api/users"},
|
||||
{"/api", "users", "/api/users"},
|
||||
{"/api/", "users", "/api/users"},
|
||||
{"", "/users", "/users"},
|
||||
{"/api", "", "/api/"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := singleJoiningSlash(tt.a, tt.b)
|
||||
if result != tt.expected {
|
||||
t.Errorf("singleJoiningSlash(%q, %q) = %q, want %q", tt.a, tt.b, result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
remoteAddr string
|
||||
xRealIP string
|
||||
expected string
|
||||
}{
|
||||
{"192.168.1.1:1234", "", "192.168.1.1"},
|
||||
{"192.168.1.1:1234", "10.0.0.1", "10.0.0.1"},
|
||||
{"invalid", "", "invalid"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
if tt.xRealIP != "" {
|
||||
req.Header.Set("X-Real-IP", tt.xRealIP)
|
||||
}
|
||||
|
||||
result := getClientIP(req)
|
||||
if result != tt.expected {
|
||||
t.Errorf("getClientIP() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetScheme(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tls bool
|
||||
header string
|
||||
expected string
|
||||
}{
|
||||
{"HTTP", false, "", "http"},
|
||||
{"HTTPS from TLS", true, "", "https"},
|
||||
{"HTTPS from header", false, "https", "https"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
if tt.header != "" {
|
||||
req.Header.Set("X-Forwarded-Proto", tt.header)
|
||||
}
|
||||
// Note: httptest doesn't set TLS, so we can only test non-TLS cases fully
|
||||
|
||||
result := getScheme(req)
|
||||
if !tt.tls && result != tt.expected {
|
||||
t.Errorf("getScheme() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsConnectionError(t *testing.T) {
|
||||
tests := []struct {
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{nil, false},
|
||||
{&net.OpError{Op: "dial", Err: &net.DNSError{Err: "no such host"}}, true},
|
||||
{context.DeadlineExceeded, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := isConnectionError(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isConnectionError(%v) = %v, want %v", tt.err, result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Benchmarks ==============
|
||||
|
||||
func BenchmarkProxy_SimpleGET(b *testing.B) {
|
||||
backend := newTestBackend(nil)
|
||||
defer backend.close()
|
||||
|
||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := httptest.NewRequest("GET", "/bench", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(rr, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkProxy_POSTWithBody(b *testing.B) {
|
||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||
io.Copy(io.Discard, r.Body)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
defer backend.close()
|
||||
|
||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||
body := strings.Repeat("x", 1024) // 1KB body
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := httptest.NewRequest("POST", "/bench", strings.NewReader(body))
|
||||
rr := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(rr, req)
|
||||
}
|
||||
}
|
||||
395
go/internal/routing/router.go
Normal file
395
go/internal/routing/router.go
Normal file
@ -0,0 +1,395 @@
|
||||
// Package routing provides HTTP routing with regex support
|
||||
package routing
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/konduktor/konduktor/internal/config"
|
||||
"github.com/konduktor/konduktor/internal/logging"
|
||||
)
|
||||
|
||||
// RouteMatch represents a matched route with captured parameters
|
||||
type RouteMatch struct {
|
||||
Config map[string]interface{}
|
||||
Params map[string]string
|
||||
}
|
||||
|
||||
// RegexRoute represents a compiled regex route
|
||||
type RegexRoute struct {
|
||||
Pattern *regexp.Regexp
|
||||
Config map[string]interface{}
|
||||
CaseSensitive bool
|
||||
OriginalExpr string
|
||||
}
|
||||
|
||||
// Router handles HTTP routing with exact, regex, and default routes
|
||||
type Router struct {
|
||||
config *config.Config
|
||||
logger *logging.Logger
|
||||
mux *http.ServeMux
|
||||
staticDir string
|
||||
exactRoutes map[string]map[string]interface{}
|
||||
regexRoutes []*RegexRoute
|
||||
defaultRoute map[string]interface{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// New creates a new router from config
|
||||
func New(cfg *config.Config, logger *logging.Logger) *Router {
|
||||
staticDir := "./static"
|
||||
if cfg != nil && cfg.HTTP.StaticDir != "" {
|
||||
staticDir = cfg.HTTP.StaticDir
|
||||
}
|
||||
|
||||
r := &Router{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
mux: http.NewServeMux(),
|
||||
staticDir: staticDir,
|
||||
exactRoutes: make(map[string]map[string]interface{}),
|
||||
regexRoutes: make([]*RegexRoute, 0),
|
||||
}
|
||||
|
||||
r.setupRoutes()
|
||||
return r
|
||||
}
|
||||
|
||||
// NewRouter creates a router without config (for testing)
|
||||
func NewRouter(opts ...RouterOption) *Router {
|
||||
r := &Router{
|
||||
mux: http.NewServeMux(),
|
||||
staticDir: "./static",
|
||||
exactRoutes: make(map[string]map[string]interface{}),
|
||||
regexRoutes: make([]*RegexRoute, 0),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// RouterOption is a functional option for Router
|
||||
type RouterOption func(*Router)
|
||||
|
||||
// WithStaticDir sets the static directory
|
||||
func WithStaticDir(dir string) RouterOption {
|
||||
return func(r *Router) {
|
||||
r.staticDir = dir
|
||||
}
|
||||
}
|
||||
|
||||
// StaticDir returns the static directory path
|
||||
func (r *Router) StaticDir() string {
|
||||
return r.staticDir
|
||||
}
|
||||
|
||||
// Routes returns the regex routes (for testing)
|
||||
func (r *Router) Routes() []*RegexRoute {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.regexRoutes
|
||||
}
|
||||
|
||||
// ExactRoutes returns the exact routes (for testing)
|
||||
func (r *Router) ExactRoutes() map[string]map[string]interface{} {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.exactRoutes
|
||||
}
|
||||
|
||||
// DefaultRoute returns the default route (for testing)
|
||||
func (r *Router) DefaultRoute() map[string]interface{} {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.defaultRoute
|
||||
}
|
||||
|
||||
// AddRoute adds a route with the given pattern and config
|
||||
// Pattern formats:
|
||||
// - "=/path" - exact match
|
||||
// - "~regex" - case-sensitive regex
|
||||
// - "~*regex" - case-insensitive regex
|
||||
// - "__default__" - default/fallback route
|
||||
func (r *Router) AddRoute(pattern string, routeConfig map[string]interface{}) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
switch {
|
||||
case pattern == "__default__":
|
||||
r.defaultRoute = routeConfig
|
||||
|
||||
case strings.HasPrefix(pattern, "="):
|
||||
// Exact match route
|
||||
path := strings.TrimPrefix(pattern, "=")
|
||||
r.exactRoutes[path] = routeConfig
|
||||
|
||||
case strings.HasPrefix(pattern, "~*"):
|
||||
// Case-insensitive regex
|
||||
expr := strings.TrimPrefix(pattern, "~*")
|
||||
re, err := regexp.Compile("(?i)" + expr)
|
||||
if err != nil {
|
||||
if r.logger != nil {
|
||||
r.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
r.regexRoutes = append(r.regexRoutes, &RegexRoute{
|
||||
Pattern: re,
|
||||
Config: routeConfig,
|
||||
CaseSensitive: false,
|
||||
OriginalExpr: expr,
|
||||
})
|
||||
|
||||
case strings.HasPrefix(pattern, "~"):
|
||||
// Case-sensitive regex
|
||||
expr := strings.TrimPrefix(pattern, "~")
|
||||
re, err := regexp.Compile(expr)
|
||||
if err != nil {
|
||||
if r.logger != nil {
|
||||
r.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
r.regexRoutes = append(r.regexRoutes, &RegexRoute{
|
||||
Pattern: re,
|
||||
Config: routeConfig,
|
||||
CaseSensitive: true,
|
||||
OriginalExpr: expr,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Match finds the best matching route for a path
|
||||
// Priority: exact match > regex match > default
|
||||
func (r *Router) Match(path string) *RouteMatch {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
// 1. Check exact routes
|
||||
if cfg, ok := r.exactRoutes[path]; ok {
|
||||
return &RouteMatch{
|
||||
Config: cfg,
|
||||
Params: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check regex routes
|
||||
for _, route := range r.regexRoutes {
|
||||
match := route.Pattern.FindStringSubmatch(path)
|
||||
if match != nil {
|
||||
params := make(map[string]string)
|
||||
|
||||
// Extract named groups
|
||||
names := route.Pattern.SubexpNames()
|
||||
for i, name := range names {
|
||||
if i > 0 && name != "" && i < len(match) {
|
||||
params[name] = match[i]
|
||||
}
|
||||
}
|
||||
|
||||
return &RouteMatch{
|
||||
Config: route.Config,
|
||||
Params: params,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Check default route
|
||||
if r.defaultRoute != nil {
|
||||
return &RouteMatch{
|
||||
Config: r.defaultRoute,
|
||||
Params: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupRoutes configures the routes from config
|
||||
func (r *Router) setupRoutes() {
|
||||
// Health check endpoint
|
||||
r.mux.HandleFunc("/health", r.healthHandler)
|
||||
|
||||
// Setup redirect instructions from config
|
||||
if r.config != nil {
|
||||
for from, to := range r.config.Server.RedirectInstructions {
|
||||
fromPath := from
|
||||
toPath := to
|
||||
r.mux.HandleFunc(fromPath, func(w http.ResponseWriter, req *http.Request) {
|
||||
http.Redirect(w, req, toPath, http.StatusMovedPermanently)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Default handler for all other routes
|
||||
r.mux.HandleFunc("/", r.defaultHandler)
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler
|
||||
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
r.mux.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// healthHandler handles health check requests
|
||||
func (r *Router) healthHandler(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
|
||||
// defaultHandler handles requests that don't match other routes
|
||||
func (r *Router) defaultHandler(w http.ResponseWriter, req *http.Request) {
|
||||
path := req.URL.Path
|
||||
|
||||
// Try to match against configured routes
|
||||
match := r.Match(path)
|
||||
if match != nil {
|
||||
r.handleRouteMatch(w, req, match)
|
||||
return
|
||||
}
|
||||
|
||||
// Try to serve static file
|
||||
if r.staticDir != "" {
|
||||
filePath := filepath.Join(r.staticDir, path)
|
||||
|
||||
// Prevent directory traversal
|
||||
if !strings.HasPrefix(filepath.Clean(filePath), filepath.Clean(r.staticDir)) {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if file exists
|
||||
info, err := os.Stat(filePath)
|
||||
if err == nil {
|
||||
if info.IsDir() {
|
||||
// Try index.html
|
||||
indexPath := filepath.Join(filePath, "index.html")
|
||||
if _, err := os.Stat(indexPath); err == nil {
|
||||
http.ServeFile(w, req, indexPath)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
http.ServeFile(w, req, filePath)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 404 Not Found
|
||||
http.NotFound(w, req)
|
||||
}
|
||||
|
||||
// handleRouteMatch handles a matched route
|
||||
func (r *Router) handleRouteMatch(w http.ResponseWriter, req *http.Request, match *RouteMatch) {
|
||||
cfg := match.Config
|
||||
|
||||
// Handle "return" directive
|
||||
if ret, ok := cfg["return"].(string); ok {
|
||||
parts := strings.SplitN(ret, " ", 2)
|
||||
statusCode := 200
|
||||
body := "OK"
|
||||
if len(parts) >= 1 {
|
||||
switch parts[0] {
|
||||
case "200":
|
||||
statusCode = 200
|
||||
case "201":
|
||||
statusCode = 201
|
||||
case "301":
|
||||
statusCode = 301
|
||||
case "302":
|
||||
statusCode = 302
|
||||
case "400":
|
||||
statusCode = 400
|
||||
case "404":
|
||||
statusCode = 404
|
||||
case "500":
|
||||
statusCode = 500
|
||||
}
|
||||
}
|
||||
if len(parts) >= 2 {
|
||||
body = parts[1]
|
||||
}
|
||||
|
||||
if ct, ok := cfg["content_type"].(string); ok {
|
||||
w.Header().Set("Content-Type", ct)
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
}
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
w.Write([]byte(body))
|
||||
return
|
||||
}
|
||||
|
||||
// Handle static files with root
|
||||
if root, ok := cfg["root"].(string); ok {
|
||||
path := req.URL.Path
|
||||
|
||||
if indexFile, ok := cfg["index_file"].(string); ok {
|
||||
if path == "/" || strings.HasSuffix(path, "/") {
|
||||
path = "/" + indexFile
|
||||
}
|
||||
}
|
||||
|
||||
filePath := filepath.Join(root, path)
|
||||
|
||||
if cacheControl, ok := cfg["cache_control"].(string); ok {
|
||||
w.Header().Set("Cache-Control", cacheControl)
|
||||
}
|
||||
|
||||
if headers, ok := cfg["headers"].([]interface{}); ok {
|
||||
for _, h := range headers {
|
||||
if header, ok := h.(string); ok {
|
||||
parts := strings.SplitN(header, ": ", 2)
|
||||
if len(parts) == 2 {
|
||||
w.Header().Set(parts[0], parts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
http.ServeFile(w, req, filePath)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle SPA fallback
|
||||
if spaFallback, ok := cfg["spa_fallback"].(bool); ok && spaFallback {
|
||||
root := r.staticDir
|
||||
if rt, ok := cfg["root"].(string); ok {
|
||||
root = rt
|
||||
}
|
||||
|
||||
indexFile := "index.html"
|
||||
if idx, ok := cfg["index_file"].(string); ok {
|
||||
indexFile = idx
|
||||
}
|
||||
|
||||
filePath := filepath.Join(root, indexFile)
|
||||
http.ServeFile(w, req, filePath)
|
||||
return
|
||||
}
|
||||
|
||||
http.NotFound(w, req)
|
||||
}
|
||||
|
||||
// CreateRouterFromConfig creates a router from extension config
|
||||
func CreateRouterFromConfig(cfg map[string]interface{}) *Router {
|
||||
router := NewRouter()
|
||||
|
||||
if locations, ok := cfg["regex_locations"].(map[string]interface{}); ok {
|
||||
for pattern, routeCfg := range locations {
|
||||
if rc, ok := routeCfg.(map[string]interface{}); ok {
|
||||
router.AddRoute(pattern, rc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return router
|
||||
}
|
||||
375
go/internal/routing/router_test.go
Normal file
375
go/internal/routing/router_test.go
Normal file
@ -0,0 +1,375 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// ============== Router Initialization Tests ==============
|
||||
|
||||
func TestRouter_Initialization(t *testing.T) {
|
||||
router := NewRouter()
|
||||
|
||||
if router.StaticDir() != "./static" {
|
||||
t.Errorf("Expected static dir ./static, got %s", router.StaticDir())
|
||||
}
|
||||
|
||||
if len(router.Routes()) != 0 {
|
||||
t.Error("Expected empty routes")
|
||||
}
|
||||
|
||||
if len(router.ExactRoutes()) != 0 {
|
||||
t.Error("Expected empty exact routes")
|
||||
}
|
||||
|
||||
if router.DefaultRoute() != nil {
|
||||
t.Error("Expected nil default route")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_CustomStaticDir(t *testing.T) {
|
||||
router := NewRouter(WithStaticDir("/custom/path"))
|
||||
|
||||
if router.StaticDir() != "/custom/path" {
|
||||
t.Errorf("Expected static dir /custom/path, got %s", router.StaticDir())
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Route Adding Tests ==============
|
||||
|
||||
func TestRouter_AddExactRoute(t *testing.T) {
|
||||
router := NewRouter()
|
||||
config := map[string]interface{}{"return": "200 OK"}
|
||||
|
||||
router.AddRoute("=/health", config)
|
||||
|
||||
exactRoutes := router.ExactRoutes()
|
||||
if _, ok := exactRoutes["/health"]; !ok {
|
||||
t.Error("Expected /health in exact routes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_AddDefaultRoute(t *testing.T) {
|
||||
router := NewRouter()
|
||||
config := map[string]interface{}{"spa_fallback": true, "root": "./static"}
|
||||
|
||||
router.AddRoute("__default__", config)
|
||||
|
||||
if router.DefaultRoute() == nil {
|
||||
t.Error("Expected default route to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_AddRegexRoute(t *testing.T) {
|
||||
router := NewRouter()
|
||||
config := map[string]interface{}{"root": "./static"}
|
||||
|
||||
router.AddRoute("~^/api/", config)
|
||||
|
||||
if len(router.Routes()) != 1 {
|
||||
t.Errorf("Expected 1 regex route, got %d", len(router.Routes()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_AddCaseInsensitiveRegexRoute(t *testing.T) {
|
||||
router := NewRouter()
|
||||
config := map[string]interface{}{"root": "./static", "cache_control": "public, max-age=3600"}
|
||||
|
||||
router.AddRoute("~*\\.(css|js)$", config)
|
||||
|
||||
if len(router.Routes()) != 1 {
|
||||
t.Errorf("Expected 1 regex route, got %d", len(router.Routes()))
|
||||
}
|
||||
|
||||
if router.Routes()[0].CaseSensitive {
|
||||
t.Error("Expected case-insensitive route")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_InvalidRegexPattern(t *testing.T) {
|
||||
router := NewRouter()
|
||||
config := map[string]interface{}{"root": "./static"}
|
||||
|
||||
// Invalid regex - unmatched bracket
|
||||
router.AddRoute("~^/api/[invalid", config)
|
||||
|
||||
// Should not add invalid pattern
|
||||
if len(router.Routes()) != 0 {
|
||||
t.Error("Should not add invalid regex pattern")
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Route Matching Tests ==============
|
||||
|
||||
func TestRouter_MatchExactRoute(t *testing.T) {
|
||||
router := NewRouter()
|
||||
config := map[string]interface{}{"return": "200 OK"}
|
||||
router.AddRoute("=/health", config)
|
||||
|
||||
match := router.Match("/health")
|
||||
|
||||
if match == nil {
|
||||
t.Fatal("Expected match for /health")
|
||||
}
|
||||
|
||||
if match.Config["return"] != "200 OK" {
|
||||
t.Error("Expected return config")
|
||||
}
|
||||
|
||||
if len(match.Params) != 0 {
|
||||
t.Error("Expected empty params for exact match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_MatchExactRouteNoMatch(t *testing.T) {
|
||||
router := NewRouter()
|
||||
config := map[string]interface{}{"return": "200 OK"}
|
||||
router.AddRoute("=/health", config)
|
||||
|
||||
match := router.Match("/healthcheck")
|
||||
|
||||
if match != nil {
|
||||
t.Error("Exact route should not match /healthcheck")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_MatchRegexRoute(t *testing.T) {
|
||||
router := NewRouter()
|
||||
config := map[string]interface{}{"proxy_pass": "http://localhost:9001"}
|
||||
router.AddRoute("~^/api/v\\d+/", config)
|
||||
|
||||
match := router.Match("/api/v1/users")
|
||||
|
||||
if match == nil {
|
||||
t.Fatal("Expected match for /api/v1/users")
|
||||
}
|
||||
|
||||
if match.Config["proxy_pass"] != "http://localhost:9001" {
|
||||
t.Error("Expected proxy_pass config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_MatchRegexRouteWithGroups(t *testing.T) {
|
||||
router := NewRouter()
|
||||
config := map[string]interface{}{"proxy_pass": "http://localhost:9001"}
|
||||
router.AddRoute("~^/api/v(?P<version>\\d+)/", config)
|
||||
|
||||
match := router.Match("/api/v2/data")
|
||||
|
||||
if match == nil {
|
||||
t.Fatal("Expected match for /api/v2/data")
|
||||
}
|
||||
|
||||
if match.Params["version"] != "2" {
|
||||
t.Errorf("Expected version=2, got %s", match.Params["version"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_MatchCaseInsensitiveRegex(t *testing.T) {
|
||||
router := NewRouter()
|
||||
config := map[string]interface{}{"root": "./static", "cache_control": "public, max-age=3600"}
|
||||
router.AddRoute("~*\\.(CSS|JS)$", config)
|
||||
|
||||
// Should match lowercase
|
||||
match1 := router.Match("/styles/main.css")
|
||||
if match1 == nil {
|
||||
t.Error("Should match lowercase .css")
|
||||
}
|
||||
|
||||
// Should match uppercase
|
||||
match2 := router.Match("/scripts/app.JS")
|
||||
if match2 == nil {
|
||||
t.Error("Should match uppercase .JS")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_MatchCaseSensitiveRegex(t *testing.T) {
|
||||
router := NewRouter()
|
||||
config := map[string]interface{}{"root": "./static"}
|
||||
router.AddRoute("~\\.(css)$", config)
|
||||
|
||||
// Should match lowercase
|
||||
match1 := router.Match("/styles/main.css")
|
||||
if match1 == nil {
|
||||
t.Error("Should match lowercase .css")
|
||||
}
|
||||
|
||||
// Should NOT match uppercase
|
||||
match2 := router.Match("/styles/main.CSS")
|
||||
if match2 != nil {
|
||||
t.Error("Should not match uppercase .CSS for case-sensitive regex")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_MatchDefaultRoute(t *testing.T) {
|
||||
router := NewRouter()
|
||||
router.AddRoute("=/health", map[string]interface{}{"return": "200 OK"})
|
||||
router.AddRoute("__default__", map[string]interface{}{"spa_fallback": true})
|
||||
|
||||
match := router.Match("/unknown/path")
|
||||
|
||||
if match == nil {
|
||||
t.Fatal("Expected default route match")
|
||||
}
|
||||
|
||||
if match.Config["spa_fallback"] != true {
|
||||
t.Error("Expected spa_fallback config from default route")
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Priority Tests ==============
|
||||
|
||||
func TestRouter_PriorityExactOverRegex(t *testing.T) {
|
||||
router := NewRouter()
|
||||
router.AddRoute("=/api/status", map[string]interface{}{"return": "200 Exact"})
|
||||
router.AddRoute("~^/api/", map[string]interface{}{"proxy_pass": "http://localhost:9001"})
|
||||
|
||||
match := router.Match("/api/status")
|
||||
|
||||
if match == nil {
|
||||
t.Fatal("Expected match")
|
||||
}
|
||||
|
||||
if match.Config["return"] != "200 Exact" {
|
||||
t.Error("Exact match should have priority over regex")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_PriorityRegexOverDefault(t *testing.T) {
|
||||
router := NewRouter()
|
||||
router.AddRoute("~^/api/", map[string]interface{}{"proxy_pass": "http://localhost:9001"})
|
||||
router.AddRoute("__default__", map[string]interface{}{"spa_fallback": true})
|
||||
|
||||
match := router.Match("/api/v1/users")
|
||||
|
||||
if match == nil {
|
||||
t.Fatal("Expected match")
|
||||
}
|
||||
|
||||
if match.Config["proxy_pass"] != "http://localhost:9001" {
|
||||
t.Error("Regex match should have priority over default")
|
||||
}
|
||||
}
|
||||
|
||||
// ============== CreateRouterFromConfig Tests ==============
|
||||
|
||||
func TestCreateRouterFromConfig(t *testing.T) {
|
||||
config := map[string]interface{}{
|
||||
"regex_locations": map[string]interface{}{
|
||||
"=/health": map[string]interface{}{
|
||||
"return": "200 OK",
|
||||
"content_type": "text/plain",
|
||||
},
|
||||
"~^/api/": map[string]interface{}{
|
||||
"proxy_pass": "http://localhost:9001",
|
||||
},
|
||||
"__default__": map[string]interface{}{
|
||||
"spa_fallback": true,
|
||||
"root": "./static",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
router := CreateRouterFromConfig(config)
|
||||
|
||||
// Check exact route
|
||||
if _, ok := router.ExactRoutes()["/health"]; !ok {
|
||||
t.Error("Expected /health exact route")
|
||||
}
|
||||
|
||||
// Check regex route
|
||||
if len(router.Routes()) != 1 {
|
||||
t.Errorf("Expected 1 regex route, got %d", len(router.Routes()))
|
||||
}
|
||||
|
||||
// Check default route
|
||||
if router.DefaultRoute() == nil {
|
||||
t.Error("Expected default route")
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Static Dir Path Tests ==============
|
||||
|
||||
func TestRouter_StaticDirPath(t *testing.T) {
|
||||
router := NewRouter(WithStaticDir("/var/www/html"))
|
||||
|
||||
expected, _ := filepath.Abs("/var/www/html")
|
||||
actual, _ := filepath.Abs(router.StaticDir())
|
||||
|
||||
if actual != expected {
|
||||
t.Errorf("Expected static dir %s, got %s", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Concurrent Access Tests ==============
|
||||
|
||||
func TestRouter_ConcurrentAccess(t *testing.T) {
|
||||
router := NewRouter()
|
||||
|
||||
// Add routes concurrently
|
||||
done := make(chan bool, 10)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(n int) {
|
||||
router.AddRoute("~^/api/v"+string(rune('0'+n))+"/", map[string]interface{}{
|
||||
"proxy_pass": "http://localhost:900" + string(rune('0'+n)),
|
||||
})
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Match routes concurrently
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(n int) {
|
||||
router.Match("/api/v" + string(rune('0'+n)) + "/users")
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Benchmarks ==============
|
||||
|
||||
func BenchmarkRouter_MatchExact(b *testing.B) {
|
||||
router := NewRouter()
|
||||
router.AddRoute("=/health", map[string]interface{}{"return": "200 OK"})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
router.Match("/health")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRouter_MatchRegex(b *testing.B) {
|
||||
router := NewRouter()
|
||||
router.AddRoute("~^/api/v(?P<version>\\d+)/", map[string]interface{}{"proxy_pass": "http://localhost:9001"})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
router.Match("/api/v1/users/123")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRouter_MatchWithManyRoutes(b *testing.B) {
|
||||
router := NewRouter()
|
||||
|
||||
// Add many routes
|
||||
for i := 0; i < 50; i++ {
|
||||
router.AddRoute("~^/api/v"+string(rune('0'+i%10))+"/service"+string(rune('0'+i/10))+"/",
|
||||
map[string]interface{}{"proxy_pass": "http://localhost:9001"})
|
||||
}
|
||||
router.AddRoute("__default__", map[string]interface{}{"spa_fallback": true})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
router.Match("/api/v5/service3/users/123")
|
||||
}
|
||||
}
|
||||
130
go/internal/server/server.go
Normal file
130
go/internal/server/server.go
Normal file
@ -0,0 +1,130 @@
|
||||
// Package server provides the HTTP server implementation
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/konduktor/konduktor/internal/config"
|
||||
"github.com/konduktor/konduktor/internal/logging"
|
||||
"github.com/konduktor/konduktor/internal/middleware"
|
||||
"github.com/konduktor/konduktor/internal/routing"
|
||||
)
|
||||
|
||||
const Version = "0.1.0"
|
||||
|
||||
// Server represents the Konduktor HTTP server
|
||||
type Server struct {
|
||||
config *config.Config
|
||||
httpServer *http.Server
|
||||
router *routing.Router
|
||||
logger *logging.Logger
|
||||
}
|
||||
|
||||
// New creates a new server instance
|
||||
func New(cfg *config.Config) (*Server, error) {
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||
}
|
||||
|
||||
logger, err := logging.NewFromConfig(cfg.Logging)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create logger: %w", err)
|
||||
}
|
||||
|
||||
router := routing.New(cfg, logger)
|
||||
|
||||
srv := &Server{
|
||||
config: cfg,
|
||||
router: router,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
return srv, nil
|
||||
}
|
||||
|
||||
// Run starts the server and blocks until shutdown
|
||||
func (s *Server) Run() error {
|
||||
// Build handler chain with middleware
|
||||
handler := s.buildHandler()
|
||||
|
||||
// Create HTTP server
|
||||
addr := fmt.Sprintf("%s:%d", s.config.Server.Host, s.config.Server.Port)
|
||||
s.httpServer = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
// Start server in goroutine
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
s.logger.Info("Server starting", "addr", addr, "version", Version)
|
||||
|
||||
var err error
|
||||
if s.config.SSL.Enabled {
|
||||
err = s.httpServer.ListenAndServeTLS(s.config.SSL.CertFile, s.config.SSL.KeyFile)
|
||||
} else {
|
||||
err = s.httpServer.ListenAndServe()
|
||||
}
|
||||
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for shutdown signal
|
||||
return s.waitForShutdown(errChan)
|
||||
}
|
||||
|
||||
// buildHandler builds the HTTP handler chain
|
||||
func (s *Server) buildHandler() http.Handler {
|
||||
var handler http.Handler = s.router
|
||||
|
||||
// Add middleware
|
||||
handler = middleware.AccessLog(handler, s.logger)
|
||||
handler = middleware.ServerHeader(handler, Version)
|
||||
handler = middleware.Recovery(handler, s.logger)
|
||||
|
||||
return handler
|
||||
}
|
||||
|
||||
// waitForShutdown waits for shutdown signal and gracefully stops the server
|
||||
func (s *Server) waitForShutdown(errChan <-chan error) error {
|
||||
// Listen for shutdown signals
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
return err
|
||||
case sig := <-sigChan:
|
||||
s.logger.Info("Shutdown signal received", "signal", sig.String())
|
||||
}
|
||||
|
||||
// Graceful shutdown with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
s.logger.Info("Shutting down server...")
|
||||
|
||||
if err := s.httpServer.Shutdown(ctx); err != nil {
|
||||
s.logger.Error("Error during shutdown", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Info("Server stopped gracefully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the server
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
return s.httpServer.Shutdown(ctx)
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user