go implementation

This commit is contained in:
Илья Глазунов 2025-12-11 16:52:13 +03:00
parent c04ab283a6
commit 8f5b9a5cd1
20 changed files with 4146 additions and 1 deletions

5
.gitignore vendored
View File

@ -27,4 +27,7 @@ build/
.idea/
.vscode/
*.swp
*.swo
*.swo
# Go binaries
go/bin

34
go/Dockerfile Normal file
View File

@ -0,0 +1,34 @@
# Multi-stage build for Konduktor
FROM golang:1.23-alpine AS builder
RUN apk add --no-cache git make
WORKDIR /build
COPY go.mod go.sum* ./
RUN go mod download
COPY . .
RUN make build
FROM alpine:3.19
RUN apk add --no-cache ca-certificates tzdata
RUN adduser -D -g '' konduktor
WORKDIR /app
COPY --from=builder /build/bin/konduktor /usr/local/bin/
COPY --from=builder /build/bin/konduktorctl /usr/local/bin/
RUN mkdir -p /app/static /app/templates /app/logs && \
chown -R konduktor:konduktor /app
USER konduktor
EXPOSE 8080
ENTRYPOINT ["konduktor"]
CMD ["-c", "/app/config.yaml"]

108
go/Makefile Normal file
View File

@ -0,0 +1,108 @@
# Konduktor Go Build
# Makefile for building and testing Konduktor
.PHONY: all build build-konduktor build-konduktorctl test clean deps fmt lint run
# Build configuration
VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
GIT_COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
BUILD_TIME ?= $(shell date -u '+%Y-%m-%dT%H:%M:%SZ')
LDFLAGS := -X main.Version=$(VERSION) -X main.GitCommit=$(GIT_COMMIT) -X main.BuildTime=$(BUILD_TIME)
# Output directories
BIN_DIR := bin
all: deps build
# Download dependencies
deps:
@echo "==> Downloading dependencies..."
go mod download
go mod tidy
# Build all binaries
build: build-konduktor build-konduktorctl
# Build konduktor server
build-konduktor:
@echo "==> Building konduktor..."
@mkdir -p $(BIN_DIR)
go build -ldflags "$(LDFLAGS)" -o $(BIN_DIR)/konduktor ./cmd/konduktor
# Build konduktorctl CLI
build-konduktorctl:
@echo "==> Building konduktorctl..."
@mkdir -p $(BIN_DIR)
go build -ldflags "$(LDFLAGS)" -o $(BIN_DIR)/konduktorctl ./cmd/konduktorctl
# Run tests
test:
@echo "==> Running tests..."
go test -v -race -cover ./...
# Run tests with coverage report
test-coverage:
@echo "==> Running tests with coverage..."
go test -v -race -coverprofile=coverage.out ./...
go tool cover -html=coverage.out -o coverage.html
@echo "Coverage report: coverage.html"
# Format code
fmt:
@echo "==> Formatting code..."
go fmt ./...
goimports -w .
# Lint code
lint:
@echo "==> Linting code..."
golangci-lint run ./...
# Run the server (development)
run: build-konduktor
@echo "==> Running konduktor..."
./$(BIN_DIR)/konduktor -c ../config.yaml
# Clean build artifacts
clean:
@echo "==> Cleaning..."
rm -rf $(BIN_DIR)
rm -f coverage.out coverage.html
# Install binaries to GOPATH/bin
install: build
@echo "==> Installing binaries..."
cp $(BIN_DIR)/konduktor $(GOPATH)/bin/
cp $(BIN_DIR)/konduktorctl $(GOPATH)/bin/
# Generate mocks (for testing)
generate:
@echo "==> Generating code..."
go generate ./...
# Docker build
docker-build:
@echo "==> Building Docker image..."
docker build -t konduktor:$(VERSION) .
# Show help
help:
@echo "Konduktor Build System"
@echo ""
@echo "Usage: make [target]"
@echo ""
@echo "Targets:"
@echo " all Download deps and build all binaries"
@echo " deps Download and tidy dependencies"
@echo " build Build all binaries"
@echo " build-konduktor Build the server binary"
@echo " build-konduktorctl Build the CLI binary"
@echo " test Run tests"
@echo " test-coverage Run tests with coverage report"
@echo " fmt Format code"
@echo " lint Lint code"
@echo " run Build and run the server"
@echo " clean Clean build artifacts"
@echo " install Install binaries to GOPATH/bin"
@echo " docker-build Build Docker image"
@echo " help Show this help"

149
go/README.md Normal file
View File

@ -0,0 +1,149 @@
# Konduktor (Go)
High-performance HTTP web server with extensible routing and process orchestration. (Previously known as PyServe in Python)
## Project Structure
```
go/
├── cmd/
│ ├── konduktor/ # Main server binary
│ └── konduktorctl/ # CLI management tool
├── internal/
│ ├── config/ # Configuration management
│ ├── logging/ # Structured logging
│ ├── middleware/ # HTTP middleware
│ ├── routing/ # HTTP routing
│ ├── extensions/ # Extension system (TODO)
│ └── process/ # Process management (TODO)
├── pkg/ # Public packages (TODO)
├── go.mod
├── go.sum
└── Makefile
```
## Building
```bash
cd go
# Download dependencies
make deps
# Build all binaries
make build
# Or build individually
make build-konduktor
make build-konduktorctl
```
## Running
```bash
# Run with default config
./bin/konduktor
# Run with custom config
./bin/konduktor -c ../config.yaml
# Run with flags
./bin/konduktor --host 127.0.0.1 --port 3000 --debug
```
## CLI Commands (konduktorctl)
```bash
# Start services
konduktorctl up
# Stop services
konduktorctl down
# View status
konduktorctl status
# View logs
konduktorctl logs -f
# Health check
konduktorctl health
# Scale services
konduktorctl scale api=3
# Configuration management
konduktorctl config show
konduktorctl config validate
# Initialize new project
konduktorctl init
```
## Configuration
Uses the same YAML configuration format as the Python version:
```yaml
server:
host: 0.0.0.0
port: 8080
http:
static_dir: ./static
templates_dir: ./templates
ssl:
enabled: false
cert_file: ./ssl/cert.pem
key_file: ./ssl/key.pem
logging:
level: INFO
console_output: true
extensions:
- type: routing
config:
regex_locations:
"=/health":
return: "200 OK"
```
## Development
```bash
# Format code
make fmt
# Run linter
make lint
# Run tests
make test
# Run with coverage
make test-coverage
```
## Migration from Python
This is a gradual rewrite of PyServe to Go. The project is now called **Konduktor**.
### Completed
- [x] Basic project structure
- [x] Configuration loading
- [x] HTTP server with graceful shutdown
- [x] Basic routing
- [x] Middleware (access log, recovery, server header)
- [x] CLI structure (konduktor, konduktorctl)
### TODO
- [ ] Extension system
- [x] Regex routing
- [x] Reverse proxy
- [ ] Process orchestration
- [ ] ASGI/WSGI adapter support
- [ ] WebSocket support
- [ ] Hot reload
- [ ] Metrics and monitoring

79
go/cmd/konduktor/main.go Normal file
View File

@ -0,0 +1,79 @@
package main
import (
"fmt"
"os"
"github.com/spf13/cobra"
"github.com/konduktor/konduktor/internal/config"
"github.com/konduktor/konduktor/internal/server"
)
var (
Version = "0.1.0"
BuildTime = "unknown"
GitCommit = "unknown"
)
var (
cfgFile string
host string
port int
debug bool
)
func main() {
rootCmd := &cobra.Command{
Use: "konduktor",
Short: "Konduktor - HTTP web server",
Long: `Konduktor is a high-performance HTTP web server with extensible routing and process orchestration.`,
Version: fmt.Sprintf("%s (commit: %s, built: %s)", Version, GitCommit, BuildTime),
RunE: runServer,
}
rootCmd.Flags().StringVarP(&cfgFile, "config", "c", "config.yaml", "Path to configuration file")
rootCmd.Flags().StringVar(&host, "host", "", "Host to bind the server to")
rootCmd.Flags().IntVar(&port, "port", 0, "Port to bind the server to")
rootCmd.Flags().BoolVar(&debug, "debug", false, "Enable debug mode")
if err := rootCmd.Execute(); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func runServer(cmd *cobra.Command, args []string) error {
cfg, err := config.Load(cfgFile)
if err != nil {
if os.IsNotExist(err) {
fmt.Printf("Configuration file %s not found, using defaults\n", cfgFile)
cfg = config.Default()
} else {
return fmt.Errorf("configuration loading error: %w", err)
}
}
if host != "" {
cfg.Server.Host = host
}
if port != 0 {
cfg.Server.Port = port
}
if debug {
cfg.Logging.Level = "DEBUG"
}
srv, err := server.New(cfg)
if err != nil {
return fmt.Errorf("server creation error: %w", err)
}
fmt.Printf("Starting Konduktor server on %s:%d\n", cfg.Server.Host, cfg.Server.Port)
if err := srv.Run(); err != nil {
return fmt.Errorf("server startup error: %w", err)
}
return nil
}

180
go/cmd/konduktorctl/main.go Normal file
View File

@ -0,0 +1,180 @@
package main
import (
"fmt"
"os"
"github.com/spf13/cobra"
)
var (
Version = "0.1.0"
BuildTime = "unknown"
GitCommit = "unknown"
)
func main() {
rootCmd := &cobra.Command{
Use: "konduktorctl",
Short: "Konduktorctl - Service management CLI",
Long: `Konduktorctl is a CLI tool for managing Konduktor services.`,
Version: fmt.Sprintf("%s (commit: %s, built: %s)", Version, GitCommit, BuildTime),
}
rootCmd.AddCommand(
newUpCmd(),
newDownCmd(),
newStatusCmd(),
newLogsCmd(),
newHealthCmd(),
newScaleCmd(),
newConfigCmd(),
newInitCmd(),
newTopCmd(),
)
if err := rootCmd.Execute(); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func newUpCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "up [service...]",
Short: "Start services",
Long: `Start one or more services. If no service is specified, all services are started.`,
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("Starting services...")
// TODO: Implement service start logic
return nil
},
}
cmd.Flags().BoolP("detach", "d", false, "Run in background")
return cmd
}
func newDownCmd() *cobra.Command {
return &cobra.Command{
Use: "down [service...]",
Short: "Stop services",
Long: `Stop one or more services. If no service is specified, all services are stopped.`,
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("Stopping services...")
// TODO: Implement service stop logic
return nil
},
}
}
func newStatusCmd() *cobra.Command {
return &cobra.Command{
Use: "status [service...]",
Short: "Show service status",
Long: `Show the status of one or more services.`,
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("Service status:")
// TODO: Implement status display logic
return nil
},
}
}
func newLogsCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "logs [service]",
Short: "View service logs",
Long: `View logs for a specific service.`,
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("Fetching logs...")
// TODO: Implement logs viewing logic
return nil
},
}
cmd.Flags().BoolP("follow", "f", false, "Follow log output")
cmd.Flags().IntP("tail", "n", 100, "Number of lines to show from the end")
return cmd
}
func newHealthCmd() *cobra.Command {
return &cobra.Command{
Use: "health",
Short: "Check service health",
Long: `Check the health status of all services.`,
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("Health check:")
// TODO: Implement health check logic
return nil
},
}
}
func newScaleCmd() *cobra.Command {
return &cobra.Command{
Use: "scale <service>=<count>",
Short: "Scale a service",
Long: `Scale a service to a specific number of instances.`,
Args: cobra.MinimumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Printf("Scaling: %v\n", args)
// TODO: Implement scaling logic
return nil
},
}
}
func newConfigCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "config",
Short: "Manage configuration",
Long: `View and validate configuration.`,
}
cmd.AddCommand(&cobra.Command{
Use: "show",
Short: "Show current configuration",
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("Current configuration:")
// TODO: Implement config show logic
return nil
},
})
cmd.AddCommand(&cobra.Command{
Use: "validate",
Short: "Validate configuration file",
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("Validating configuration...")
// TODO: Implement config validation logic
return nil
},
})
return cmd
}
func newInitCmd() *cobra.Command {
return &cobra.Command{
Use: "init",
Short: "Initialize a new project",
Long: `Create a new Konduktor project with default configuration.`,
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("Initializing new project...")
// TODO: Implement init logic
return nil
},
}
}
func newTopCmd() *cobra.Command {
return &cobra.Command{
Use: "top",
Short: "Display running processes",
Long: `Display real-time view of running processes and resource usage.`,
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("Process monitor:")
// TODO: Implement top-like display
return nil
},
}
}

15
go/go.mod Normal file
View File

@ -0,0 +1,15 @@
module github.com/konduktor/konduktor
go 1.23.0
toolchain go1.24.2
require (
github.com/spf13/cobra v1.10.2
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
)

View File

@ -0,0 +1,134 @@
package config
import (
"fmt"
"os"
"time"
"gopkg.in/yaml.v3"
)
type Config struct {
HTTP HTTPConfig `yaml:"http"`
Server ServerConfig `yaml:"server"`
SSL SSLConfig `yaml:"ssl"`
Logging LoggingConfig `yaml:"logging"`
Extensions []ExtensionConfig `yaml:"extensions"`
}
type HTTPConfig struct {
StaticDir string `yaml:"static_dir"`
TemplatesDir string `yaml:"templates_dir"`
}
type ServerConfig struct {
Host string `yaml:"host"`
Port int `yaml:"port"`
Backlog int `yaml:"backlog"`
DefaultRoot bool `yaml:"default_root"`
ProxyTimeout time.Duration `yaml:"proxy_timeout"`
RedirectInstructions map[string]string `yaml:"redirect_instructions"`
}
type SSLConfig struct {
Enabled bool `yaml:"enabled"`
CertFile string `yaml:"cert_file"`
KeyFile string `yaml:"key_file"`
}
type LoggingConfig struct {
Level string `yaml:"level"`
ConsoleOutput bool `yaml:"console_output"`
Format LogFormatConfig `yaml:"format"`
Console *ConsoleLogConfig `yaml:"console"`
Files []FileLogConfig `yaml:"files"`
}
type LogFormatConfig struct {
Type string `yaml:"type"`
UseColors bool `yaml:"use_colors"`
ShowModule bool `yaml:"show_module"`
TimestampFormat string `yaml:"timestamp_format"`
}
type ConsoleLogConfig struct {
Format LogFormatConfig `yaml:"format"`
Level string `yaml:"level"`
}
type FileLogConfig struct {
Path string `yaml:"path"`
Level string `yaml:"level"`
Loggers []string `yaml:"loggers"`
Format LogFormatConfig `yaml:"format"`
MaxBytes int64 `yaml:"max_bytes"`
BackupCount int `yaml:"backup_count"`
}
type ExtensionConfig struct {
Type string `yaml:"type"`
Config map[string]interface{} `yaml:"config"`
}
func Load(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
cfg := Default()
if err := yaml.Unmarshal(data, cfg); err != nil {
return nil, fmt.Errorf("failed to parse config: %w", err)
}
return cfg, nil
}
func Default() *Config {
return &Config{
HTTP: HTTPConfig{
StaticDir: "./static",
TemplatesDir: "./templates",
},
Server: ServerConfig{
Host: "0.0.0.0",
Port: 8080,
Backlog: 5,
DefaultRoot: false,
ProxyTimeout: 30 * time.Second,
},
SSL: SSLConfig{
Enabled: false,
CertFile: "./ssl/cert.pem",
KeyFile: "./ssl/key.pem",
},
Logging: LoggingConfig{
Level: "INFO",
ConsoleOutput: true,
Format: LogFormatConfig{
Type: "standard",
UseColors: true,
ShowModule: true,
TimestampFormat: "2006-01-02 15:04:05",
},
},
Extensions: []ExtensionConfig{},
}
}
func (c *Config) Validate() error {
if c.Server.Port < 1 || c.Server.Port > 65535 {
return fmt.Errorf("invalid port: %d", c.Server.Port)
}
if c.SSL.Enabled {
if c.SSL.CertFile == "" {
return fmt.Errorf("SSL enabled but cert_file not specified")
}
if c.SSL.KeyFile == "" {
return fmt.Errorf("SSL enabled but key_file not specified")
}
}
return nil
}

View File

@ -0,0 +1,127 @@
package config
import (
"os"
"testing"
)
func TestDefault(t *testing.T) {
cfg := Default()
if cfg.Server.Host != "0.0.0.0" {
t.Errorf("Expected host 0.0.0.0, got %s", cfg.Server.Host)
}
if cfg.Server.Port != 8080 {
t.Errorf("Expected port 8080, got %d", cfg.Server.Port)
}
if cfg.SSL.Enabled {
t.Error("Expected SSL to be disabled by default")
}
}
func TestValidate(t *testing.T) {
tests := []struct {
name string
modify func(*Config)
wantErr bool
}{
{
name: "valid default config",
modify: func(c *Config) {},
wantErr: false,
},
{
name: "invalid port - too low",
modify: func(c *Config) {
c.Server.Port = 0
},
wantErr: true,
},
{
name: "invalid port - too high",
modify: func(c *Config) {
c.Server.Port = 70000
},
wantErr: true,
},
{
name: "SSL enabled without cert",
modify: func(c *Config) {
c.SSL.Enabled = true
c.SSL.CertFile = ""
},
wantErr: true,
},
{
name: "SSL enabled without key",
modify: func(c *Config) {
c.SSL.Enabled = true
c.SSL.CertFile = "cert.pem"
c.SSL.KeyFile = ""
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := Default()
tt.modify(cfg)
err := cfg.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestLoad(t *testing.T) {
// Create temporary config file
content := `
server:
host: 127.0.0.1
port: 3000
logging:
level: DEBUG
`
tmpfile, err := os.CreateTemp("", "config-*.yaml")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpfile.Name())
if _, err := tmpfile.Write([]byte(content)); err != nil {
t.Fatal(err)
}
if err := tmpfile.Close(); err != nil {
t.Fatal(err)
}
cfg, err := Load(tmpfile.Name())
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
if cfg.Server.Host != "127.0.0.1" {
t.Errorf("Expected host 127.0.0.1, got %s", cfg.Server.Host)
}
if cfg.Server.Port != 3000 {
t.Errorf("Expected port 3000, got %d", cfg.Server.Port)
}
if cfg.Logging.Level != "DEBUG" {
t.Errorf("Expected level DEBUG, got %s", cfg.Logging.Level)
}
}
func TestLoadNotFound(t *testing.T) {
_, err := Load("/nonexistent/config.yaml")
if err == nil {
t.Error("Expected error for non-existent file")
}
}

View File

@ -0,0 +1,136 @@
package logging
import (
"fmt"
"os"
"time"
"github.com/konduktor/konduktor/internal/config"
)
type Config struct {
Level string
TimestampFormat string
}
type Logger struct {
level string
timestampFormat string
configFull *config.LoggingConfig
}
func New(cfg Config) (*Logger, error) {
timestampFormat := cfg.TimestampFormat
if timestampFormat == "" {
timestampFormat = "2006-01-02 15:04:05"
}
return &Logger{
level: cfg.Level,
timestampFormat: timestampFormat,
}, nil
}
func NewFromConfig(cfg config.LoggingConfig) (*Logger, error) {
timestampFormat := cfg.Format.TimestampFormat
if timestampFormat == "" {
timestampFormat = "2006-01-02 15:04:05"
}
return &Logger{
level: cfg.Level,
timestampFormat: timestampFormat,
configFull: &cfg,
}, nil
}
func (l *Logger) formatTime() string {
return time.Now().Format(l.timestampFormat)
}
func (l *Logger) log(level string, msg string, fields ...interface{}) {
timestamp := l.formatTime()
// Simple console output for now
// TODO: Implement proper structured logging with zap
output := timestamp + " [" + level + "] " + msg
if len(fields) > 0 {
output += " {"
for i := 0; i < len(fields); i += 2 {
if i > 0 {
output += ", "
}
if i+1 < len(fields) {
output += fields[i].(string) + "=" + formatValue(fields[i+1])
}
}
output += "}"
}
os.Stdout.WriteString(output + "\n")
}
func formatValue(v interface{}) string {
switch val := v.(type) {
case string:
return val
case int:
return fmt.Sprintf("%d", val)
case int64:
return fmt.Sprintf("%d", val)
case float64:
return fmt.Sprintf("%.2f", val)
case bool:
return fmt.Sprintf("%t", val)
case error:
return val.Error()
default:
return fmt.Sprintf("%v", val)
}
}
func (l *Logger) Debug(msg string, fields ...interface{}) {
if l.shouldLog("DEBUG") {
l.log("DEBUG", msg, fields...)
}
}
func (l *Logger) Info(msg string, fields ...interface{}) {
if l.shouldLog("INFO") {
l.log("INFO", msg, fields...)
}
}
func (l *Logger) Warn(msg string, fields ...interface{}) {
if l.shouldLog("WARN") {
l.log("WARN", msg, fields...)
}
}
func (l *Logger) Error(msg string, fields ...interface{}) {
if l.shouldLog("ERROR") {
l.log("ERROR", msg, fields...)
}
}
func (l *Logger) shouldLog(level string) bool {
levels := map[string]int{
"DEBUG": 0,
"INFO": 1,
"WARN": 2,
"ERROR": 3,
}
currentLevel, ok := levels[l.level]
if !ok {
currentLevel = 1 // Default to INFO
}
msgLevel, ok := levels[level]
if !ok {
msgLevel = 1
}
return msgLevel >= currentLevel
}

View File

@ -0,0 +1,172 @@
package logging
import (
"testing"
)
func TestNew(t *testing.T) {
logger, err := New(Config{Level: "INFO"})
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if logger == nil {
t.Fatal("Expected logger, got nil")
}
if logger.level != "INFO" {
t.Errorf("Expected level INFO, got %s", logger.level)
}
}
func TestNew_DefaultTimestampFormat(t *testing.T) {
logger, _ := New(Config{Level: "DEBUG"})
if logger.timestampFormat != "2006-01-02 15:04:05" {
t.Errorf("Expected default timestamp format, got %s", logger.timestampFormat)
}
}
func TestNew_CustomTimestampFormat(t *testing.T) {
logger, _ := New(Config{
Level: "DEBUG",
TimestampFormat: "15:04:05",
})
if logger.timestampFormat != "15:04:05" {
t.Errorf("Expected custom timestamp format, got %s", logger.timestampFormat)
}
}
func TestLogger_ShouldLog(t *testing.T) {
tests := []struct {
loggerLevel string
msgLevel string
shouldLog bool
}{
{"DEBUG", "DEBUG", true},
{"DEBUG", "INFO", true},
{"DEBUG", "WARN", true},
{"DEBUG", "ERROR", true},
{"INFO", "DEBUG", false},
{"INFO", "INFO", true},
{"INFO", "WARN", true},
{"INFO", "ERROR", true},
{"WARN", "DEBUG", false},
{"WARN", "INFO", false},
{"WARN", "WARN", true},
{"WARN", "ERROR", true},
{"ERROR", "DEBUG", false},
{"ERROR", "INFO", false},
{"ERROR", "WARN", false},
{"ERROR", "ERROR", true},
}
for _, tt := range tests {
t.Run(tt.loggerLevel+"_"+tt.msgLevel, func(t *testing.T) {
logger, _ := New(Config{Level: tt.loggerLevel})
if got := logger.shouldLog(tt.msgLevel); got != tt.shouldLog {
t.Errorf("shouldLog(%s) = %v, want %v", tt.msgLevel, got, tt.shouldLog)
}
})
}
}
func TestLogger_ShouldLog_InvalidLevel(t *testing.T) {
logger, _ := New(Config{Level: "INVALID"})
// Should default to INFO level
if !logger.shouldLog("INFO") {
t.Error("Invalid level should default to INFO")
}
}
func TestLogger_Debug(t *testing.T) {
logger, _ := New(Config{Level: "DEBUG"})
// Should not panic
logger.Debug("test message", "key", "value")
}
func TestLogger_Info(t *testing.T) {
logger, _ := New(Config{Level: "INFO"})
// Should not panic
logger.Info("test message", "key", "value")
}
func TestLogger_Warn(t *testing.T) {
logger, _ := New(Config{Level: "WARN"})
// Should not panic
logger.Warn("test message", "key", "value")
}
func TestLogger_Error(t *testing.T) {
logger, _ := New(Config{Level: "ERROR"})
// Should not panic
logger.Error("test message", "key", "value")
}
func TestFormatValue(t *testing.T) {
tests := []struct {
input interface{}
expected string
}{
{"test", "test"},
{42, "*"}, // int converts to rune
{nil, ""},
}
for _, tt := range tests {
got := formatValue(tt.input)
// Just check it doesn't panic
_ = got
}
}
func TestLogger_FormatTime(t *testing.T) {
logger, _ := New(Config{
Level: "INFO",
TimestampFormat: "2006-01-02",
})
result := logger.formatTime()
// Should be in expected format (YYYY-MM-DD)
if len(result) != 10 {
t.Errorf("Expected date format YYYY-MM-DD, got %s", result)
}
}
// ============== Benchmarks ==============
func BenchmarkLogger_Info(b *testing.B) {
logger, _ := New(Config{Level: "INFO"})
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Info("test message", "key", "value")
}
}
func BenchmarkLogger_Debug_Filtered(b *testing.B) {
logger, _ := New(Config{Level: "ERROR"})
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Debug("test message", "key", "value")
}
}
func BenchmarkLogger_ShouldLog(b *testing.B) {
logger, _ := New(Config{Level: "INFO"})
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.shouldLog("DEBUG")
}
}

View File

@ -0,0 +1,74 @@
package middleware
import (
"fmt"
"net/http"
"runtime/debug"
"time"
"github.com/konduktor/konduktor/internal/logging"
)
type responseWriter struct {
http.ResponseWriter
status int
size int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.status = code
rw.ResponseWriter.WriteHeader(code)
}
func (rw *responseWriter) Write(b []byte) (int, error) {
size, err := rw.ResponseWriter.Write(b)
rw.size += size
return size, err
}
func ServerHeader(next http.Handler, version string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Server", fmt.Sprintf("konduktor/%s", version))
next.ServeHTTP(w, r)
})
}
func AccessLog(next http.Handler, logger *logging.Logger) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
wrapped := &responseWriter{
ResponseWriter: w,
status: http.StatusOK,
}
next.ServeHTTP(wrapped, r)
duration := time.Since(start)
logger.Info("HTTP request",
"method", r.Method,
"path", r.URL.Path,
"status", wrapped.status,
"duration_ms", duration.Milliseconds(),
"client_ip", r.RemoteAddr,
"user_agent", r.UserAgent(),
)
})
}
func Recovery(next http.Handler, logger *logging.Logger) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
logger.Error("Panic recovered",
"error", fmt.Sprintf("%v", err),
"stack", string(debug.Stack()),
)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}

View File

@ -0,0 +1,244 @@
package middleware
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/konduktor/konduktor/internal/logging"
)
// ============== ServerHeader Tests ==============
func TestServerHeader(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
wrapped := ServerHeader(handler, "1.0.0")
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req)
serverHeader := rr.Header().Get("Server")
if serverHeader != "konduktor/1.0.0" {
t.Errorf("Expected Server header 'konduktor/1.0.0', got '%s'", serverHeader)
}
}
// ============== AccessLog Tests ==============
func TestAccessLog(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("Hello"))
})
logger, _ := logging.New(logging.Config{Level: "INFO"})
wrapped := AccessLog(handler, logger)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
}
func TestAccessLog_CapturesStatusCode(t *testing.T) {
tests := []struct {
name string
statusCode int
}{
{"OK", http.StatusOK},
{"NotFound", http.StatusNotFound},
{"InternalError", http.StatusInternalServerError},
{"Redirect", http.StatusMovedPermanently},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tt.statusCode)
})
logger, _ := logging.New(logging.Config{Level: "INFO"})
wrapped := AccessLog(handler, logger)
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req)
if rr.Code != tt.statusCode {
t.Errorf("Expected status %d, got %d", tt.statusCode, rr.Code)
}
})
}
}
// ============== Recovery Tests ==============
func TestRecovery_NoPanic(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})
logger, _ := logging.New(logging.Config{Level: "INFO"})
wrapped := Recovery(handler, logger)
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
if rr.Body.String() != "OK" {
t.Errorf("Expected body 'OK', got '%s'", rr.Body.String())
}
}
func TestRecovery_WithPanic(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("test panic")
})
logger, _ := logging.New(logging.Config{Level: "ERROR"})
wrapped := Recovery(handler, logger)
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
// Should not panic
wrapped.ServeHTTP(rr, req)
if rr.Code != http.StatusInternalServerError {
t.Errorf("Expected status 500, got %d", rr.Code)
}
if !strings.Contains(rr.Body.String(), "Internal Server Error") {
t.Errorf("Expected 'Internal Server Error' in body, got '%s'", rr.Body.String())
}
}
// ============== responseWriter Tests ==============
func TestResponseWriter_WriteHeader(t *testing.T) {
rr := httptest.NewRecorder()
rw := &responseWriter{ResponseWriter: rr, status: http.StatusOK}
rw.WriteHeader(http.StatusNotFound)
if rw.status != http.StatusNotFound {
t.Errorf("Expected status 404, got %d", rw.status)
}
}
func TestResponseWriter_Write(t *testing.T) {
rr := httptest.NewRecorder()
rw := &responseWriter{ResponseWriter: rr, status: http.StatusOK}
n, err := rw.Write([]byte("Hello World"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if n != 11 {
t.Errorf("Expected 11 bytes written, got %d", n)
}
if rw.size != 11 {
t.Errorf("Expected size 11, got %d", rw.size)
}
}
// ============== Middleware Chain Tests ==============
func TestMiddlewareChain(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})
logger, _ := logging.New(logging.Config{Level: "INFO"})
// Apply middleware chain
wrapped := Recovery(AccessLog(ServerHeader(handler, "1.0.0"), logger), logger)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req)
// Check all middleware worked
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
if rr.Header().Get("Server") != "konduktor/1.0.0" {
t.Errorf("Expected Server header")
}
if rr.Body.String() != "OK" {
t.Errorf("Expected body 'OK', got '%s'", rr.Body.String())
}
}
// ============== Benchmarks ==============
func BenchmarkServerHeader(b *testing.B) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
wrapped := ServerHeader(handler, "1.0.0")
req := httptest.NewRequest("GET", "/", nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req)
}
}
func BenchmarkAccessLog(b *testing.B) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
logger, _ := logging.New(logging.Config{Level: "ERROR"}) // Minimize logging overhead
wrapped := AccessLog(handler, logger)
req := httptest.NewRequest("GET", "/", nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req)
}
}
func BenchmarkRecovery(b *testing.B) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
logger, _ := logging.New(logging.Config{Level: "ERROR"})
wrapped := Recovery(handler, logger)
req := httptest.NewRequest("GET", "/", nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req)
}
}

View File

@ -0,0 +1,263 @@
package pathmatcher
import (
"strings"
"sync"
)
type MountedPath struct {
path string
name string
stripPath bool
}
func NewMountedPath(path string, opts ...MountedPathOption) *MountedPath {
// Normalize: remove trailing slash (except for root)
normalizedPath := strings.TrimSuffix(path, "/")
if normalizedPath == "" {
normalizedPath = ""
}
m := &MountedPath{
path: normalizedPath,
name: normalizedPath,
stripPath: true,
}
for _, opt := range opts {
opt(m)
}
if m.name == "" {
m.name = normalizedPath
}
return m
}
type MountedPathOption func(*MountedPath)
func WithName(name string) MountedPathOption {
return func(m *MountedPath) {
m.name = name
}
}
func WithStripPath(strip bool) MountedPathOption {
return func(m *MountedPath) {
m.stripPath = strip
}
}
func (m *MountedPath) Path() string {
return m.path
}
func (m *MountedPath) Name() string {
return m.name
}
func (m *MountedPath) StripPath() bool {
return m.stripPath
}
func (m *MountedPath) Matches(requestPath string) bool {
// Empty or "/" mount matches everything
if m.path == "" || m.path == "/" {
return true
}
// Request path must be at least as long as mount path
if len(requestPath) < len(m.path) {
return false
}
// Check if request path starts with mount path
if !strings.HasPrefix(requestPath, m.path) {
return false
}
// If paths are equal length, it's a match
if len(requestPath) == len(m.path) {
return true
}
// Otherwise, next char must be '/' to prevent /api matching /api-v2
return requestPath[len(m.path)] == '/'
}
func (m *MountedPath) GetModifiedPath(requestPath string) string {
if !m.stripPath {
return requestPath
}
// Root mount doesn't strip anything
if m.path == "" || m.path == "/" {
return requestPath
}
// Strip the prefix
modified := strings.TrimPrefix(requestPath, m.path)
// Ensure result starts with /
if modified == "" || modified[0] != '/' {
modified = "/" + modified
}
return modified
}
type MountManager struct {
mounts []*MountedPath
mu sync.RWMutex
}
func NewMountManager() *MountManager {
return &MountManager{
mounts: make([]*MountedPath, 0),
}
}
func (mm *MountManager) AddMount(mount *MountedPath) {
mm.mu.Lock()
defer mm.mu.Unlock()
// Insert in sorted order (longer paths first)
inserted := false
for i, existing := range mm.mounts {
if len(mount.path) > len(existing.path) {
// Insert at position i
mm.mounts = append(mm.mounts[:i], append([]*MountedPath{mount}, mm.mounts[i:]...)...)
inserted = true
break
}
}
if !inserted {
mm.mounts = append(mm.mounts, mount)
}
}
func (mm *MountManager) RemoveMount(path string) bool {
mm.mu.Lock()
defer mm.mu.Unlock()
normalizedPath := strings.TrimSuffix(path, "/")
for i, mount := range mm.mounts {
if mount.path == normalizedPath {
mm.mounts = append(mm.mounts[:i], mm.mounts[i+1:]...)
return true
}
}
return false
}
func (mm *MountManager) GetMount(requestPath string) *MountedPath {
mm.mu.RLock()
defer mm.mu.RUnlock()
// Mounts are sorted by path length (longest first)
// so the first match is the best match
for _, mount := range mm.mounts {
if mount.Matches(requestPath) {
return mount
}
}
return nil
}
func (mm *MountManager) MountCount() int {
mm.mu.RLock()
defer mm.mu.RUnlock()
return len(mm.mounts)
}
func (mm *MountManager) Mounts() []*MountedPath {
mm.mu.RLock()
defer mm.mu.RUnlock()
result := make([]*MountedPath, len(mm.mounts))
copy(result, mm.mounts)
return result
}
func (mm *MountManager) ListMounts() []map[string]interface{} {
mm.mu.RLock()
defer mm.mu.RUnlock()
result := make([]map[string]interface{}, len(mm.mounts))
for i, mount := range mm.mounts {
result[i] = map[string]interface{}{
"path": mount.path,
"name": mount.name,
"strip_path": mount.stripPath,
}
}
return result
}
// Utility functions
func PathMatchesPrefix(requestPath, prefix string) bool {
// Normalize prefix
prefix = strings.TrimSuffix(prefix, "/")
// Empty or "/" prefix matches everything
if prefix == "" || prefix == "/" {
return true
}
// Request path must be at least as long as prefix
if len(requestPath) < len(prefix) {
return false
}
// Check if request path starts with prefix
if !strings.HasPrefix(requestPath, prefix) {
return false
}
// If paths are equal length, it's a match
if len(requestPath) == len(prefix) {
return true
}
// Otherwise, next char must be '/'
return requestPath[len(prefix)] == '/'
}
func StripPathPrefix(requestPath, prefix string) string {
// Normalize prefix
prefix = strings.TrimSuffix(prefix, "/")
// Empty or "/" prefix doesn't strip anything
if prefix == "" || prefix == "/" {
return requestPath
}
// Strip the prefix
modified := strings.TrimPrefix(requestPath, prefix)
// Ensure result starts with /
if modified == "" || modified[0] != '/' {
modified = "/" + modified
}
return modified
}
func MatchAndModifyPath(requestPath, prefix string, stripPath bool) (matches bool, modifiedPath string) {
if !PathMatchesPrefix(requestPath, prefix) {
return false, ""
}
if stripPath {
return true, StripPathPrefix(requestPath, prefix)
}
return true, requestPath
}

View File

@ -0,0 +1,460 @@
package pathmatcher
import (
"testing"
)
// ============== MountedPath Tests ==============
func TestMountedPath_RootMountMatchesEverything(t *testing.T) {
mount := NewMountedPath("")
tests := []string{"/", "/api", "/api/users", "/anything/at/all"}
for _, path := range tests {
if !mount.Matches(path) {
t.Errorf("Root mount should match %s", path)
}
}
}
func TestMountedPath_SlashRootMountMatchesEverything(t *testing.T) {
mount := NewMountedPath("/")
tests := []string{"/", "/api", "/api/users"}
for _, path := range tests {
if !mount.Matches(path) {
t.Errorf("'/' mount should match %s", path)
}
}
}
func TestMountedPath_ExactPathMatch(t *testing.T) {
mount := NewMountedPath("/api")
tests := []struct {
path string
expected bool
}{
{"/api", true},
{"/api/", true},
{"/api/users", true},
}
for _, tt := range tests {
if got := mount.Matches(tt.path); got != tt.expected {
t.Errorf("Matches(%s) = %v, want %v", tt.path, got, tt.expected)
}
}
}
func TestMountedPath_NoFalsePrefixMatch(t *testing.T) {
mount := NewMountedPath("/api")
tests := []string{"/api-v2", "/api2", "/apiv2"}
for _, path := range tests {
if mount.Matches(path) {
t.Errorf("/api should not match %s", path)
}
}
}
func TestMountedPath_ShorterPathNoMatch(t *testing.T) {
mount := NewMountedPath("/api/v1")
tests := []string{"/api", "/ap", "/"}
for _, path := range tests {
if mount.Matches(path) {
t.Errorf("/api/v1 should not match shorter path %s", path)
}
}
}
func TestMountedPath_TrailingSlashNormalized(t *testing.T) {
mount1 := NewMountedPath("/api/")
mount2 := NewMountedPath("/api")
if mount1.Path() != "/api" {
t.Errorf("Expected path /api, got %s", mount1.Path())
}
if mount2.Path() != "/api" {
t.Errorf("Expected path /api, got %s", mount2.Path())
}
if !mount1.Matches("/api/users") {
t.Error("mount1 should match /api/users")
}
if !mount2.Matches("/api/users") {
t.Error("mount2 should match /api/users")
}
}
func TestMountedPath_GetModifiedPathStripsPrefix(t *testing.T) {
mount := NewMountedPath("/api")
tests := []struct {
input string
expected string
}{
{"/api", "/"},
{"/api/", "/"},
{"/api/users", "/users"},
{"/api/users/123", "/users/123"},
}
for _, tt := range tests {
if got := mount.GetModifiedPath(tt.input); got != tt.expected {
t.Errorf("GetModifiedPath(%s) = %s, want %s", tt.input, got, tt.expected)
}
}
}
func TestMountedPath_GetModifiedPathNoStrip(t *testing.T) {
mount := NewMountedPath("/api", WithStripPath(false))
tests := []struct {
input string
expected string
}{
{"/api/users", "/api/users"},
{"/api", "/api"},
}
for _, tt := range tests {
if got := mount.GetModifiedPath(tt.input); got != tt.expected {
t.Errorf("GetModifiedPath(%s) = %s, want %s", tt.input, got, tt.expected)
}
}
}
func TestMountedPath_RootMountModifiedPath(t *testing.T) {
mount := NewMountedPath("")
tests := []struct {
input string
expected string
}{
{"/api/users", "/api/users"},
{"/", "/"},
}
for _, tt := range tests {
if got := mount.GetModifiedPath(tt.input); got != tt.expected {
t.Errorf("GetModifiedPath(%s) = %s, want %s", tt.input, got, tt.expected)
}
}
}
func TestMountedPath_NameProperty(t *testing.T) {
mount1 := NewMountedPath("/api")
mount2 := NewMountedPath("/api", WithName("API Mount"))
if mount1.Name() != "/api" {
t.Errorf("Expected name /api, got %s", mount1.Name())
}
if mount2.Name() != "API Mount" {
t.Errorf("Expected name 'API Mount', got %s", mount2.Name())
}
}
// ============== MountManager Tests ==============
func TestMountManager_EmptyManager(t *testing.T) {
manager := NewMountManager()
if got := manager.GetMount("/api"); got != nil {
t.Error("Empty manager should return nil")
}
if got := manager.MountCount(); got != 0 {
t.Errorf("Expected mount count 0, got %d", got)
}
}
func TestMountManager_AddMount(t *testing.T) {
manager := NewMountManager()
mount := NewMountedPath("/api")
manager.AddMount(mount)
if manager.MountCount() != 1 {
t.Errorf("Expected mount count 1, got %d", manager.MountCount())
}
if got := manager.GetMount("/api/users"); got != mount {
t.Error("GetMount should return the added mount")
}
}
func TestMountManager_LongestPrefixMatching(t *testing.T) {
manager := NewMountManager()
apiMount := NewMountedPath("/api", WithName("api"))
apiV1Mount := NewMountedPath("/api/v1", WithName("api_v1"))
apiV2Mount := NewMountedPath("/api/v2", WithName("api_v2"))
manager.AddMount(apiMount)
manager.AddMount(apiV2Mount)
manager.AddMount(apiV1Mount)
tests := []struct {
path string
expectedName string
}{
{"/api/v1/users", "api_v1"},
{"/api/v2/items", "api_v2"},
{"/api/v3/other", "api"},
{"/api", "api"},
}
for _, tt := range tests {
got := manager.GetMount(tt.path)
if got == nil {
t.Errorf("GetMount(%s) returned nil, want mount with name %s", tt.path, tt.expectedName)
continue
}
if got.Name() != tt.expectedName {
t.Errorf("GetMount(%s).Name() = %s, want %s", tt.path, got.Name(), tt.expectedName)
}
}
}
func TestMountManager_RemoveMount(t *testing.T) {
manager := NewMountManager()
manager.AddMount(NewMountedPath("/api"))
manager.AddMount(NewMountedPath("/admin"))
if manager.MountCount() != 2 {
t.Errorf("Expected mount count 2, got %d", manager.MountCount())
}
result := manager.RemoveMount("/api")
if !result {
t.Error("RemoveMount should return true")
}
if manager.MountCount() != 1 {
t.Errorf("Expected mount count 1, got %d", manager.MountCount())
}
if manager.GetMount("/api/users") != nil {
t.Error("GetMount(/api/users) should return nil after removal")
}
if manager.GetMount("/admin/users") == nil {
t.Error("GetMount(/admin/users) should still work")
}
}
func TestMountManager_RemoveNonexistentMount(t *testing.T) {
manager := NewMountManager()
result := manager.RemoveMount("/api")
if result {
t.Error("RemoveMount should return false for nonexistent mount")
}
}
func TestMountManager_ListMounts(t *testing.T) {
manager := NewMountManager()
manager.AddMount(NewMountedPath("/api", WithName("API")))
manager.AddMount(NewMountedPath("/admin", WithName("Admin")))
mounts := manager.ListMounts()
if len(mounts) != 2 {
t.Errorf("Expected 2 mounts, got %d", len(mounts))
}
for _, m := range mounts {
if _, ok := m["path"]; !ok {
t.Error("Mount should have 'path' key")
}
if _, ok := m["name"]; !ok {
t.Error("Mount should have 'name' key")
}
if _, ok := m["strip_path"]; !ok {
t.Error("Mount should have 'strip_path' key")
}
}
}
func TestMountManager_MountsReturnsCopy(t *testing.T) {
manager := NewMountManager()
manager.AddMount(NewMountedPath("/api"))
mounts1 := manager.Mounts()
mounts2 := manager.Mounts()
if &mounts1[0] == &mounts2[0] {
t.Error("Mounts() should return different slices")
}
}
// ============== Utility Functions Tests ==============
func TestPathMatchesPrefix_Basic(t *testing.T) {
tests := []struct {
path string
prefix string
expected bool
}{
{"/api/users", "/api", true},
{"/api", "/api", true},
{"/api-v2", "/api", false},
{"/ap", "/api", false},
}
for _, tt := range tests {
if got := PathMatchesPrefix(tt.path, tt.prefix); got != tt.expected {
t.Errorf("PathMatchesPrefix(%s, %s) = %v, want %v", tt.path, tt.prefix, got, tt.expected)
}
}
}
func TestPathMatchesPrefix_Root(t *testing.T) {
tests := []struct {
path string
prefix string
expected bool
}{
{"/anything", "", true},
{"/anything", "/", true},
}
for _, tt := range tests {
if got := PathMatchesPrefix(tt.path, tt.prefix); got != tt.expected {
t.Errorf("PathMatchesPrefix(%s, %s) = %v, want %v", tt.path, tt.prefix, got, tt.expected)
}
}
}
func TestStripPathPrefix_Basic(t *testing.T) {
tests := []struct {
path string
prefix string
expected string
}{
{"/api/users", "/api", "/users"},
{"/api", "/api", "/"},
{"/api/", "/api", "/"},
}
for _, tt := range tests {
if got := StripPathPrefix(tt.path, tt.prefix); got != tt.expected {
t.Errorf("StripPathPrefix(%s, %s) = %s, want %s", tt.path, tt.prefix, got, tt.expected)
}
}
}
func TestStripPathPrefix_Root(t *testing.T) {
tests := []struct {
path string
prefix string
expected string
}{
{"/api/users", "", "/api/users"},
{"/api/users", "/", "/api/users"},
}
for _, tt := range tests {
if got := StripPathPrefix(tt.path, tt.prefix); got != tt.expected {
t.Errorf("StripPathPrefix(%s, %s) = %s, want %s", tt.path, tt.prefix, got, tt.expected)
}
}
}
func TestMatchAndModifyPath_Combined(t *testing.T) {
tests := []struct {
path string
prefix string
stripPath bool
wantMatches bool
wantModified string
}{
{"/api/users", "/api", true, true, "/users"},
{"/api", "/api", true, true, "/"},
{"/other", "/api", true, false, ""},
{"/api/users", "/api", false, true, "/api/users"},
}
for _, tt := range tests {
matches, modified := MatchAndModifyPath(tt.path, tt.prefix, tt.stripPath)
if matches != tt.wantMatches {
t.Errorf("MatchAndModifyPath(%s, %s, %v) matches = %v, want %v",
tt.path, tt.prefix, tt.stripPath, matches, tt.wantMatches)
}
if modified != tt.wantModified {
t.Errorf("MatchAndModifyPath(%s, %s, %v) modified = %s, want %s",
tt.path, tt.prefix, tt.stripPath, modified, tt.wantModified)
}
}
}
// ============== Performance Tests ==============
func TestPerformance_ManyMatches(t *testing.T) {
mount := NewMountedPath("/api/v1/users")
for i := 0; i < 10000; i++ {
if !mount.Matches("/api/v1/users/123/posts") {
t.Fatal("Should match")
}
if mount.Matches("/other/path") {
t.Fatal("Should not match")
}
}
}
func TestPerformance_ManyMounts(t *testing.T) {
manager := NewMountManager()
for i := 0; i < 100; i++ {
manager.AddMount(NewMountedPath("/api/v" + string(rune('0'+i%10)) + string(rune('0'+i/10))))
}
if manager.MountCount() != 100 {
t.Errorf("Expected 100 mounts, got %d", manager.MountCount())
}
}
// ============== Benchmarks ==============
func BenchmarkMountedPath_Matches(b *testing.B) {
mount := NewMountedPath("/api/v1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
mount.Matches("/api/v1/users/123")
}
}
func BenchmarkMountManager_GetMount(b *testing.B) {
manager := NewMountManager()
for i := 0; i < 20; i++ {
manager.AddMount(NewMountedPath("/api/v" + string(rune('0'+i%10))))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.GetMount("/api/v5/users/123")
}
}
func BenchmarkPathMatchesPrefix(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
PathMatchesPrefix("/api/v1/users/123", "/api/v1")
}
}

320
go/internal/proxy/proxy.go Normal file
View File

@ -0,0 +1,320 @@
// Package proxy provides reverse proxy functionality for Konduktor
package proxy
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/konduktor/konduktor/internal/logging"
)
type Config struct {
// Target is the backend server URL
Target string
// Timeout is the request timeout (default: 30s)
Timeout time.Duration
// Headers are additional headers to add to requests
Headers map[string]string
// StripPrefix removes this prefix from the request path
StripPrefix string
// PreserveHost keeps the original Host header
PreserveHost bool
}
type ReverseProxy struct {
config *Config
targetURL *url.URL
httpClient *http.Client
logger *logging.Logger
}
func New(cfg *Config, logger *logging.Logger) (*ReverseProxy, error) {
if cfg.Target == "" {
return nil, fmt.Errorf("proxy target is required")
}
targetURL, err := url.Parse(cfg.Target)
if err != nil {
return nil, fmt.Errorf("invalid proxy target URL: %w", err)
}
timeout := cfg.Timeout
if timeout == 0 {
timeout = 30 * time.Second
}
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: timeout,
}
return &ReverseProxy{
config: cfg,
targetURL: targetURL,
httpClient: &http.Client{
Transport: transport,
Timeout: timeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse // Don't follow redirects
},
},
logger: logger,
}, nil
}
func (rp *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
rp.ProxyRequest(w, r, nil)
}
func (rp *ReverseProxy) ProxyRequest(w http.ResponseWriter, r *http.Request, params map[string]string) {
ctx := r.Context()
// Build target URL
targetURL := rp.buildTargetURL(r)
// Create proxy request
proxyReq, err := rp.createProxyRequest(ctx, r, targetURL)
if err != nil {
rp.handleError(w, http.StatusInternalServerError, "Failed to create proxy request", err)
return
}
// Add custom headers with parameter substitution
rp.addCustomHeaders(proxyReq, r, params)
// Execute request
resp, err := rp.httpClient.Do(proxyReq)
if err != nil {
rp.handleProxyError(w, err)
return
}
defer resp.Body.Close()
// Copy response
rp.copyResponse(w, resp)
}
func (rp *ReverseProxy) buildTargetURL(r *http.Request) *url.URL {
targetURL := *rp.targetURL
// Strip prefix if configured
path := r.URL.Path
if rp.config.StripPrefix != "" {
path = strings.TrimPrefix(path, rp.config.StripPrefix)
if path == "" || path[0] != '/' {
path = "/" + path
}
}
// If target has a path, append request path to it
if rp.targetURL.Path != "" && rp.targetURL.Path != "/" {
targetURL.Path = singleJoiningSlash(rp.targetURL.Path, path)
} else {
targetURL.Path = path
}
// Preserve query string
targetURL.RawQuery = r.URL.RawQuery
return &targetURL
}
func (rp *ReverseProxy) createProxyRequest(ctx context.Context, r *http.Request, targetURL *url.URL) (*http.Request, error) {
proxyReq, err := http.NewRequestWithContext(ctx, r.Method, targetURL.String(), r.Body)
if err != nil {
return nil, err
}
// Copy ContentLength
proxyReq.ContentLength = r.ContentLength
// Copy headers
for key, values := range r.Header {
for _, value := range values {
proxyReq.Header.Add(key, value)
}
}
// Set/update Host header
if rp.config.PreserveHost {
proxyReq.Host = r.Host
} else {
proxyReq.Host = targetURL.Host
}
// Remove hop-by-hop headers
removeHopByHopHeaders(proxyReq.Header)
return proxyReq, nil
}
func (rp *ReverseProxy) addCustomHeaders(proxyReq *http.Request, originalReq *http.Request, params map[string]string) {
// Add X-Forwarded headers
clientIP := getClientIP(originalReq)
if prior := originalReq.Header.Get("X-Forwarded-For"); prior != "" {
clientIP = prior + ", " + clientIP
}
proxyReq.Header.Set("X-Forwarded-For", clientIP)
proxyReq.Header.Set("X-Forwarded-Proto", getScheme(originalReq))
proxyReq.Header.Set("X-Forwarded-Host", originalReq.Host)
// Add custom headers from config
for key, value := range rp.config.Headers {
// Substitute parameters like {version}
substituted := value
for paramKey, paramValue := range params {
substituted = strings.ReplaceAll(substituted, "{"+paramKey+"}", paramValue)
}
// Substitute $remote_addr
substituted = strings.ReplaceAll(substituted, "$remote_addr", clientIP)
proxyReq.Header.Set(key, substituted)
}
}
func (rp *ReverseProxy) copyResponse(w http.ResponseWriter, resp *http.Response) {
// Copy headers
for key, values := range resp.Header {
for _, value := range values {
w.Header().Add(key, value)
}
}
// Remove hop-by-hop headers from response
removeHopByHopHeaders(w.Header())
// Write status code
w.WriteHeader(resp.StatusCode)
// Copy body
io.Copy(w, resp.Body)
}
func (rp *ReverseProxy) handleError(w http.ResponseWriter, status int, message string, err error) {
if rp.logger != nil {
rp.logger.Error(message, "error", err)
}
http.Error(w, message, status)
}
func (rp *ReverseProxy) handleProxyError(w http.ResponseWriter, err error) {
if rp.logger != nil {
rp.logger.Error("Proxy request failed", "error", err)
}
// Check for timeout
if err, ok := err.(net.Error); ok && err.Timeout() {
http.Error(w, "504 Gateway Timeout", http.StatusGatewayTimeout)
return
}
// Check for connection errors
if isConnectionError(err) {
http.Error(w, "502 Bad Gateway", http.StatusBadGateway)
return
}
// Context cancelled (client disconnected)
if err == context.Canceled {
return
}
http.Error(w, "502 Bad Gateway", http.StatusBadGateway)
}
// Helper functions
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
func removeHopByHopHeaders(h http.Header) {
hopByHopHeaders := []string{
"Connection",
"Proxy-Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te",
"Trailer",
"Transfer-Encoding",
"Upgrade",
}
for _, header := range hopByHopHeaders {
h.Del(header)
}
}
func getClientIP(r *http.Request) string {
// Check X-Real-IP first
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
// Get from RemoteAddr
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return host
}
func getScheme(r *http.Request) string {
if r.TLS != nil {
return "https"
}
if scheme := r.Header.Get("X-Forwarded-Proto"); scheme != "" {
return scheme
}
return "http"
}
func isConnectionError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
connectionErrors := []string{
"connection refused",
"no such host",
"network is unreachable",
"connection reset",
"broken pipe",
}
for _, connErr := range connectionErrors {
if strings.Contains(strings.ToLower(errStr), connErr) {
return true
}
}
return false
}

View File

@ -0,0 +1,747 @@
package proxy
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
// ============== Test Backend Server ==============
type testBackend struct {
server *httptest.Server
requestLog []requestLogEntry
mu sync.Mutex
requestCount int64
}
type requestLogEntry struct {
Method string
Path string
Query string
Headers http.Header
Body string
}
func newTestBackend(handler http.HandlerFunc) *testBackend {
tb := &testBackend{
requestLog: make([]requestLogEntry, 0),
}
if handler == nil {
handler = tb.defaultHandler
}
tb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tb.logRequest(r)
handler(w, r)
}))
return tb
}
func (tb *testBackend) logRequest(r *http.Request) {
tb.mu.Lock()
defer tb.mu.Unlock()
body, _ := io.ReadAll(r.Body)
// Restore the body for the handler
r.Body = io.NopCloser(bytes.NewReader(body))
tb.requestLog = append(tb.requestLog, requestLogEntry{
Method: r.Method,
Path: r.URL.Path,
Query: r.URL.RawQuery,
Headers: r.Header.Clone(),
Body: string(body),
})
atomic.AddInt64(&tb.requestCount, 1)
}
func (tb *testBackend) defaultHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"message": "Backend response",
"path": r.URL.Path,
"method": r.Method,
})
}
func (tb *testBackend) close() {
tb.server.Close()
}
func (tb *testBackend) URL() string {
return tb.server.URL
}
func (tb *testBackend) getRequestCount() int64 {
return atomic.LoadInt64(&tb.requestCount)
}
func (tb *testBackend) getLastRequest() *requestLogEntry {
tb.mu.Lock()
defer tb.mu.Unlock()
if len(tb.requestLog) == 0 {
return nil
}
return &tb.requestLog[len(tb.requestLog)-1]
}
// ============== Proxy Creation Tests ==============
func TestNew_ValidConfig(t *testing.T) {
cfg := &Config{
Target: "http://localhost:8080",
Timeout: 10 * time.Second,
}
proxy, err := New(cfg, nil)
if err != nil {
t.Fatalf("Failed to create proxy: %v", err)
}
if proxy == nil {
t.Fatal("Expected proxy instance")
}
}
func TestNew_EmptyTarget(t *testing.T) {
cfg := &Config{
Target: "",
}
_, err := New(cfg, nil)
if err == nil {
t.Error("Expected error for empty target")
}
}
func TestNew_InvalidTargetURL(t *testing.T) {
cfg := &Config{
Target: "://invalid-url",
}
_, err := New(cfg, nil)
if err == nil {
t.Error("Expected error for invalid URL")
}
}
func TestNew_DefaultTimeout(t *testing.T) {
cfg := &Config{
Target: "http://localhost:8080",
}
proxy, err := New(cfg, nil)
if err != nil {
t.Fatalf("Failed to create proxy: %v", err)
}
if proxy.httpClient.Timeout != 30*time.Second {
t.Errorf("Expected default timeout 30s, got %v", proxy.httpClient.Timeout)
}
}
// ============== Basic Proxy Tests ==============
func TestProxy_BasicGET(t *testing.T) {
backend := newTestBackend(nil)
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
var response map[string]interface{}
json.NewDecoder(rr.Body).Decode(&response)
if response["path"] != "/test" {
t.Errorf("Expected path /test, got %v", response["path"])
}
}
func TestProxy_BasicPOST(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"received": string(body),
"method": r.Method,
})
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
req := httptest.NewRequest("POST", "/api/data", strings.NewReader(`{"key":"value"}`))
req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
var response map[string]interface{}
json.NewDecoder(rr.Body).Decode(&response)
if response["method"] != "POST" {
t.Errorf("Expected method POST, got %v", response["method"])
}
}
func TestProxy_PUT(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"method": r.Method})
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
req := httptest.NewRequest("PUT", "/resource/123", strings.NewReader(`{"name":"updated"}`))
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
var response map[string]string
json.NewDecoder(rr.Body).Decode(&response)
if response["method"] != "PUT" {
t.Errorf("Expected method PUT, got %v", response["method"])
}
}
func TestProxy_DELETE(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"method": r.Method})
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
req := httptest.NewRequest("DELETE", "/resource/123", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
var response map[string]string
json.NewDecoder(rr.Body).Decode(&response)
if response["method"] != "DELETE" {
t.Errorf("Expected method DELETE, got %v", response["method"])
}
}
// ============== Header Tests ==============
func TestProxy_HeadersForwarding(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]interface{}{
"custom_header": r.Header.Get("X-Custom-Header"),
"forwarded_for": r.Header.Get("X-Forwarded-For"),
"forwarded_host": r.Header.Get("X-Forwarded-Host"),
})
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
req := httptest.NewRequest("GET", "/headers", nil)
req.Header.Set("X-Custom-Header", "test-value")
req.RemoteAddr = "192.168.1.100:12345"
req.Host = "example.com"
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
var response map[string]interface{}
json.NewDecoder(rr.Body).Decode(&response)
if response["custom_header"] != "test-value" {
t.Errorf("Expected custom header, got %v", response["custom_header"])
}
if response["forwarded_for"] != "192.168.1.100" {
t.Errorf("Expected X-Forwarded-For, got %v", response["forwarded_for"])
}
if response["forwarded_host"] != "example.com" {
t.Errorf("Expected X-Forwarded-Host, got %v", response["forwarded_host"])
}
}
func TestProxy_CustomHeaders(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"api_version": r.Header.Get("X-API-Version"),
})
})
defer backend.close()
proxy, _ := New(&Config{
Target: backend.URL(),
Headers: map[string]string{
"X-API-Version": "{version}",
},
}, nil)
req := httptest.NewRequest("GET", "/api", nil)
rr := httptest.NewRecorder()
// Simulate parameter substitution
proxy.ProxyRequest(rr, req, map[string]string{"version": "2"})
var response map[string]string
json.NewDecoder(rr.Body).Decode(&response)
if response["api_version"] != "2" {
t.Errorf("Expected API version 2, got %v", response["api_version"])
}
}
func TestProxy_RemoteAddrSubstitution(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"client_ip": r.Header.Get("X-Client-IP"),
})
})
defer backend.close()
proxy, _ := New(&Config{
Target: backend.URL(),
Headers: map[string]string{
"X-Client-IP": "$remote_addr",
},
}, nil)
req := httptest.NewRequest("GET", "/api", nil)
req.RemoteAddr = "10.0.0.1:54321"
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
var response map[string]string
json.NewDecoder(rr.Body).Decode(&response)
if response["client_ip"] != "10.0.0.1" {
t.Errorf("Expected client IP 10.0.0.1, got %v", response["client_ip"])
}
}
// ============== Query String Tests ==============
func TestProxy_QueryStringPreservation(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"query": r.URL.RawQuery,
"param": r.URL.Query().Get("key"),
})
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
req := httptest.NewRequest("GET", "/search?key=value&page=2", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
var response map[string]string
json.NewDecoder(rr.Body).Decode(&response)
if response["param"] != "value" {
t.Errorf("Expected query param 'value', got %v", response["param"])
}
}
// ============== Status Code Tests ==============
func TestProxy_StatusCodePreservation(t *testing.T) {
statusCodes := []int{200, 201, 400, 404, 500}
for _, code := range statusCodes {
code := code // capture range variable
t.Run(fmt.Sprintf("Status_%d", code), func(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(code)
w.Write([]byte("OK"))
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
req := httptest.NewRequest("GET", "/status", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
if rr.Code != code {
t.Errorf("Expected status %d, got %d", code, rr.Code)
}
})
}
}
// ============== Error Handling Tests ==============
func TestProxy_BackendUnavailable(t *testing.T) {
// Use a port that's definitely not listening
proxy, _ := New(&Config{
Target: "http://127.0.0.1:59999",
Timeout: 1 * time.Second,
}, nil)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
if rr.Code != http.StatusBadGateway {
t.Errorf("Expected status 502 Bad Gateway, got %d", rr.Code)
}
}
func TestProxy_Timeout(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(2 * time.Second)
w.Write([]byte("OK"))
})
defer backend.close()
proxy, _ := New(&Config{
Target: backend.URL(),
Timeout: 100 * time.Millisecond,
}, nil)
req := httptest.NewRequest("GET", "/slow", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
if rr.Code != http.StatusGatewayTimeout {
t.Errorf("Expected status 504 Gateway Timeout, got %d", rr.Code)
}
}
// ============== Path Handling Tests ==============
func TestProxy_StripPrefix(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"path": r.URL.Path,
})
})
defer backend.close()
proxy, _ := New(&Config{
Target: backend.URL(),
StripPrefix: "/api/v1",
}, nil)
req := httptest.NewRequest("GET", "/api/v1/users", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
var response map[string]string
json.NewDecoder(rr.Body).Decode(&response)
if response["path"] != "/users" {
t.Errorf("Expected stripped path /users, got %v", response["path"])
}
}
func TestProxy_TargetWithPath(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"path": r.URL.Path,
})
})
defer backend.close()
proxy, _ := New(&Config{
Target: backend.URL() + "/backend",
}, nil)
req := httptest.NewRequest("GET", "/resource", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
var response map[string]string
json.NewDecoder(rr.Body).Decode(&response)
if response["path"] != "/backend/resource" {
t.Errorf("Expected path /backend/resource, got %v", response["path"])
}
}
// ============== Large Body Tests ==============
func TestProxy_LargeRequestBody(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
json.NewEncoder(w).Encode(map[string]int{
"received_bytes": len(body),
})
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
// 100KB body
largeBody := strings.Repeat("x", 100000)
req := httptest.NewRequest("POST", "/upload", strings.NewReader(largeBody))
req.ContentLength = int64(len(largeBody))
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
var response map[string]int
json.NewDecoder(rr.Body).Decode(&response)
if response["received_bytes"] != 100000 {
t.Errorf("Expected 100000 bytes, got %d", response["received_bytes"])
}
}
func TestProxy_LargeResponseBody(t *testing.T) {
largeResponse := strings.Repeat("y", 100000)
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(largeResponse))
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
req := httptest.NewRequest("GET", "/large", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
if rr.Body.Len() != 100000 {
t.Errorf("Expected 100000 bytes in response, got %d", rr.Body.Len())
}
}
// ============== Concurrent Requests Tests ==============
func TestProxy_ConcurrentRequests(t *testing.T) {
backend := newTestBackend(nil)
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
const numRequests = 50
var wg sync.WaitGroup
errors := make(chan error, numRequests)
for i := 0; i < numRequests; i++ {
wg.Add(1)
go func(n int) {
defer wg.Done()
req := httptest.NewRequest("GET", "/concurrent", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
errors <- &net.OpError{Op: "test", Err: context.DeadlineExceeded}
}
}(i)
}
wg.Wait()
close(errors)
errorCount := 0
for range errors {
errorCount++
}
if errorCount > 0 {
t.Errorf("Got %d errors in concurrent requests", errorCount)
}
if backend.getRequestCount() != numRequests {
t.Errorf("Expected %d requests at backend, got %d", numRequests, backend.getRequestCount())
}
}
// ============== Echo Tests ==============
func TestProxy_Echo(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
w.Header().Set("Content-Type", r.Header.Get("Content-Type"))
w.Write(body)
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
testData := "Hello, Proxy!"
req := httptest.NewRequest("POST", "/echo", strings.NewReader(testData))
req.Header.Set("Content-Type", "text/plain")
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
if rr.Body.String() != testData {
t.Errorf("Expected echo of '%s', got '%s'", testData, rr.Body.String())
}
if rr.Header().Get("Content-Type") != "text/plain" {
t.Errorf("Expected Content-Type text/plain, got %s", rr.Header().Get("Content-Type"))
}
}
// ============== Helper Function Tests ==============
func TestSingleJoiningSlash(t *testing.T) {
tests := []struct {
a, b, expected string
}{
{"/api", "/users", "/api/users"},
{"/api/", "/users", "/api/users"},
{"/api", "users", "/api/users"},
{"/api/", "users", "/api/users"},
{"", "/users", "/users"},
{"/api", "", "/api/"},
}
for _, tt := range tests {
result := singleJoiningSlash(tt.a, tt.b)
if result != tt.expected {
t.Errorf("singleJoiningSlash(%q, %q) = %q, want %q", tt.a, tt.b, result, tt.expected)
}
}
}
func TestGetClientIP(t *testing.T) {
tests := []struct {
remoteAddr string
xRealIP string
expected string
}{
{"192.168.1.1:1234", "", "192.168.1.1"},
{"192.168.1.1:1234", "10.0.0.1", "10.0.0.1"},
{"invalid", "", "invalid"},
}
for _, tt := range tests {
req := httptest.NewRequest("GET", "/", nil)
req.RemoteAddr = tt.remoteAddr
if tt.xRealIP != "" {
req.Header.Set("X-Real-IP", tt.xRealIP)
}
result := getClientIP(req)
if result != tt.expected {
t.Errorf("getClientIP() = %q, want %q", result, tt.expected)
}
}
}
func TestGetScheme(t *testing.T) {
tests := []struct {
name string
tls bool
header string
expected string
}{
{"HTTP", false, "", "http"},
{"HTTPS from TLS", true, "", "https"},
{"HTTPS from header", false, "https", "https"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
if tt.header != "" {
req.Header.Set("X-Forwarded-Proto", tt.header)
}
// Note: httptest doesn't set TLS, so we can only test non-TLS cases fully
result := getScheme(req)
if !tt.tls && result != tt.expected {
t.Errorf("getScheme() = %q, want %q", result, tt.expected)
}
})
}
}
func TestIsConnectionError(t *testing.T) {
tests := []struct {
err error
expected bool
}{
{nil, false},
{&net.OpError{Op: "dial", Err: &net.DNSError{Err: "no such host"}}, true},
{context.DeadlineExceeded, false},
}
for _, tt := range tests {
result := isConnectionError(tt.err)
if result != tt.expected {
t.Errorf("isConnectionError(%v) = %v, want %v", tt.err, result, tt.expected)
}
}
}
// ============== Benchmarks ==============
func BenchmarkProxy_SimpleGET(b *testing.B) {
backend := newTestBackend(nil)
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
req := httptest.NewRequest("GET", "/bench", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
}
}
func BenchmarkProxy_POSTWithBody(b *testing.B) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
io.Copy(io.Discard, r.Body)
w.WriteHeader(http.StatusOK)
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
body := strings.Repeat("x", 1024) // 1KB body
b.ResetTimer()
for i := 0; i < b.N; i++ {
req := httptest.NewRequest("POST", "/bench", strings.NewReader(body))
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
}
}

View File

@ -0,0 +1,395 @@
// Package routing provides HTTP routing with regex support
package routing
import (
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"github.com/konduktor/konduktor/internal/config"
"github.com/konduktor/konduktor/internal/logging"
)
// RouteMatch represents a matched route with captured parameters
type RouteMatch struct {
Config map[string]interface{}
Params map[string]string
}
// RegexRoute represents a compiled regex route
type RegexRoute struct {
Pattern *regexp.Regexp
Config map[string]interface{}
CaseSensitive bool
OriginalExpr string
}
// Router handles HTTP routing with exact, regex, and default routes
type Router struct {
config *config.Config
logger *logging.Logger
mux *http.ServeMux
staticDir string
exactRoutes map[string]map[string]interface{}
regexRoutes []*RegexRoute
defaultRoute map[string]interface{}
mu sync.RWMutex
}
// New creates a new router from config
func New(cfg *config.Config, logger *logging.Logger) *Router {
staticDir := "./static"
if cfg != nil && cfg.HTTP.StaticDir != "" {
staticDir = cfg.HTTP.StaticDir
}
r := &Router{
config: cfg,
logger: logger,
mux: http.NewServeMux(),
staticDir: staticDir,
exactRoutes: make(map[string]map[string]interface{}),
regexRoutes: make([]*RegexRoute, 0),
}
r.setupRoutes()
return r
}
// NewRouter creates a router without config (for testing)
func NewRouter(opts ...RouterOption) *Router {
r := &Router{
mux: http.NewServeMux(),
staticDir: "./static",
exactRoutes: make(map[string]map[string]interface{}),
regexRoutes: make([]*RegexRoute, 0),
}
for _, opt := range opts {
opt(r)
}
return r
}
// RouterOption is a functional option for Router
type RouterOption func(*Router)
// WithStaticDir sets the static directory
func WithStaticDir(dir string) RouterOption {
return func(r *Router) {
r.staticDir = dir
}
}
// StaticDir returns the static directory path
func (r *Router) StaticDir() string {
return r.staticDir
}
// Routes returns the regex routes (for testing)
func (r *Router) Routes() []*RegexRoute {
r.mu.RLock()
defer r.mu.RUnlock()
return r.regexRoutes
}
// ExactRoutes returns the exact routes (for testing)
func (r *Router) ExactRoutes() map[string]map[string]interface{} {
r.mu.RLock()
defer r.mu.RUnlock()
return r.exactRoutes
}
// DefaultRoute returns the default route (for testing)
func (r *Router) DefaultRoute() map[string]interface{} {
r.mu.RLock()
defer r.mu.RUnlock()
return r.defaultRoute
}
// AddRoute adds a route with the given pattern and config
// Pattern formats:
// - "=/path" - exact match
// - "~regex" - case-sensitive regex
// - "~*regex" - case-insensitive regex
// - "__default__" - default/fallback route
func (r *Router) AddRoute(pattern string, routeConfig map[string]interface{}) {
r.mu.Lock()
defer r.mu.Unlock()
switch {
case pattern == "__default__":
r.defaultRoute = routeConfig
case strings.HasPrefix(pattern, "="):
// Exact match route
path := strings.TrimPrefix(pattern, "=")
r.exactRoutes[path] = routeConfig
case strings.HasPrefix(pattern, "~*"):
// Case-insensitive regex
expr := strings.TrimPrefix(pattern, "~*")
re, err := regexp.Compile("(?i)" + expr)
if err != nil {
if r.logger != nil {
r.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err)
}
return
}
r.regexRoutes = append(r.regexRoutes, &RegexRoute{
Pattern: re,
Config: routeConfig,
CaseSensitive: false,
OriginalExpr: expr,
})
case strings.HasPrefix(pattern, "~"):
// Case-sensitive regex
expr := strings.TrimPrefix(pattern, "~")
re, err := regexp.Compile(expr)
if err != nil {
if r.logger != nil {
r.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err)
}
return
}
r.regexRoutes = append(r.regexRoutes, &RegexRoute{
Pattern: re,
Config: routeConfig,
CaseSensitive: true,
OriginalExpr: expr,
})
}
}
// Match finds the best matching route for a path
// Priority: exact match > regex match > default
func (r *Router) Match(path string) *RouteMatch {
r.mu.RLock()
defer r.mu.RUnlock()
// 1. Check exact routes
if cfg, ok := r.exactRoutes[path]; ok {
return &RouteMatch{
Config: cfg,
Params: make(map[string]string),
}
}
// 2. Check regex routes
for _, route := range r.regexRoutes {
match := route.Pattern.FindStringSubmatch(path)
if match != nil {
params := make(map[string]string)
// Extract named groups
names := route.Pattern.SubexpNames()
for i, name := range names {
if i > 0 && name != "" && i < len(match) {
params[name] = match[i]
}
}
return &RouteMatch{
Config: route.Config,
Params: params,
}
}
}
// 3. Check default route
if r.defaultRoute != nil {
return &RouteMatch{
Config: r.defaultRoute,
Params: make(map[string]string),
}
}
return nil
}
// setupRoutes configures the routes from config
func (r *Router) setupRoutes() {
// Health check endpoint
r.mux.HandleFunc("/health", r.healthHandler)
// Setup redirect instructions from config
if r.config != nil {
for from, to := range r.config.Server.RedirectInstructions {
fromPath := from
toPath := to
r.mux.HandleFunc(fromPath, func(w http.ResponseWriter, req *http.Request) {
http.Redirect(w, req, toPath, http.StatusMovedPermanently)
})
}
}
// Default handler for all other routes
r.mux.HandleFunc("/", r.defaultHandler)
}
// ServeHTTP implements http.Handler
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
r.mux.ServeHTTP(w, req)
}
// healthHandler handles health check requests
func (r *Router) healthHandler(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}
// defaultHandler handles requests that don't match other routes
func (r *Router) defaultHandler(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path
// Try to match against configured routes
match := r.Match(path)
if match != nil {
r.handleRouteMatch(w, req, match)
return
}
// Try to serve static file
if r.staticDir != "" {
filePath := filepath.Join(r.staticDir, path)
// Prevent directory traversal
if !strings.HasPrefix(filepath.Clean(filePath), filepath.Clean(r.staticDir)) {
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
// Check if file exists
info, err := os.Stat(filePath)
if err == nil {
if info.IsDir() {
// Try index.html
indexPath := filepath.Join(filePath, "index.html")
if _, err := os.Stat(indexPath); err == nil {
http.ServeFile(w, req, indexPath)
return
}
} else {
http.ServeFile(w, req, filePath)
return
}
}
}
// 404 Not Found
http.NotFound(w, req)
}
// handleRouteMatch handles a matched route
func (r *Router) handleRouteMatch(w http.ResponseWriter, req *http.Request, match *RouteMatch) {
cfg := match.Config
// Handle "return" directive
if ret, ok := cfg["return"].(string); ok {
parts := strings.SplitN(ret, " ", 2)
statusCode := 200
body := "OK"
if len(parts) >= 1 {
switch parts[0] {
case "200":
statusCode = 200
case "201":
statusCode = 201
case "301":
statusCode = 301
case "302":
statusCode = 302
case "400":
statusCode = 400
case "404":
statusCode = 404
case "500":
statusCode = 500
}
}
if len(parts) >= 2 {
body = parts[1]
}
if ct, ok := cfg["content_type"].(string); ok {
w.Header().Set("Content-Type", ct)
} else {
w.Header().Set("Content-Type", "text/plain")
}
w.WriteHeader(statusCode)
w.Write([]byte(body))
return
}
// Handle static files with root
if root, ok := cfg["root"].(string); ok {
path := req.URL.Path
if indexFile, ok := cfg["index_file"].(string); ok {
if path == "/" || strings.HasSuffix(path, "/") {
path = "/" + indexFile
}
}
filePath := filepath.Join(root, path)
if cacheControl, ok := cfg["cache_control"].(string); ok {
w.Header().Set("Cache-Control", cacheControl)
}
if headers, ok := cfg["headers"].([]interface{}); ok {
for _, h := range headers {
if header, ok := h.(string); ok {
parts := strings.SplitN(header, ": ", 2)
if len(parts) == 2 {
w.Header().Set(parts[0], parts[1])
}
}
}
}
http.ServeFile(w, req, filePath)
return
}
// Handle SPA fallback
if spaFallback, ok := cfg["spa_fallback"].(bool); ok && spaFallback {
root := r.staticDir
if rt, ok := cfg["root"].(string); ok {
root = rt
}
indexFile := "index.html"
if idx, ok := cfg["index_file"].(string); ok {
indexFile = idx
}
filePath := filepath.Join(root, indexFile)
http.ServeFile(w, req, filePath)
return
}
http.NotFound(w, req)
}
// CreateRouterFromConfig creates a router from extension config
func CreateRouterFromConfig(cfg map[string]interface{}) *Router {
router := NewRouter()
if locations, ok := cfg["regex_locations"].(map[string]interface{}); ok {
for pattern, routeCfg := range locations {
if rc, ok := routeCfg.(map[string]interface{}); ok {
router.AddRoute(pattern, rc)
}
}
}
return router
}

View File

@ -0,0 +1,375 @@
package routing
import (
"path/filepath"
"testing"
)
// ============== Router Initialization Tests ==============
func TestRouter_Initialization(t *testing.T) {
router := NewRouter()
if router.StaticDir() != "./static" {
t.Errorf("Expected static dir ./static, got %s", router.StaticDir())
}
if len(router.Routes()) != 0 {
t.Error("Expected empty routes")
}
if len(router.ExactRoutes()) != 0 {
t.Error("Expected empty exact routes")
}
if router.DefaultRoute() != nil {
t.Error("Expected nil default route")
}
}
func TestRouter_CustomStaticDir(t *testing.T) {
router := NewRouter(WithStaticDir("/custom/path"))
if router.StaticDir() != "/custom/path" {
t.Errorf("Expected static dir /custom/path, got %s", router.StaticDir())
}
}
// ============== Route Adding Tests ==============
func TestRouter_AddExactRoute(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"return": "200 OK"}
router.AddRoute("=/health", config)
exactRoutes := router.ExactRoutes()
if _, ok := exactRoutes["/health"]; !ok {
t.Error("Expected /health in exact routes")
}
}
func TestRouter_AddDefaultRoute(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"spa_fallback": true, "root": "./static"}
router.AddRoute("__default__", config)
if router.DefaultRoute() == nil {
t.Error("Expected default route to be set")
}
}
func TestRouter_AddRegexRoute(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"root": "./static"}
router.AddRoute("~^/api/", config)
if len(router.Routes()) != 1 {
t.Errorf("Expected 1 regex route, got %d", len(router.Routes()))
}
}
func TestRouter_AddCaseInsensitiveRegexRoute(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"root": "./static", "cache_control": "public, max-age=3600"}
router.AddRoute("~*\\.(css|js)$", config)
if len(router.Routes()) != 1 {
t.Errorf("Expected 1 regex route, got %d", len(router.Routes()))
}
if router.Routes()[0].CaseSensitive {
t.Error("Expected case-insensitive route")
}
}
func TestRouter_InvalidRegexPattern(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"root": "./static"}
// Invalid regex - unmatched bracket
router.AddRoute("~^/api/[invalid", config)
// Should not add invalid pattern
if len(router.Routes()) != 0 {
t.Error("Should not add invalid regex pattern")
}
}
// ============== Route Matching Tests ==============
func TestRouter_MatchExactRoute(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"return": "200 OK"}
router.AddRoute("=/health", config)
match := router.Match("/health")
if match == nil {
t.Fatal("Expected match for /health")
}
if match.Config["return"] != "200 OK" {
t.Error("Expected return config")
}
if len(match.Params) != 0 {
t.Error("Expected empty params for exact match")
}
}
func TestRouter_MatchExactRouteNoMatch(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"return": "200 OK"}
router.AddRoute("=/health", config)
match := router.Match("/healthcheck")
if match != nil {
t.Error("Exact route should not match /healthcheck")
}
}
func TestRouter_MatchRegexRoute(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"proxy_pass": "http://localhost:9001"}
router.AddRoute("~^/api/v\\d+/", config)
match := router.Match("/api/v1/users")
if match == nil {
t.Fatal("Expected match for /api/v1/users")
}
if match.Config["proxy_pass"] != "http://localhost:9001" {
t.Error("Expected proxy_pass config")
}
}
func TestRouter_MatchRegexRouteWithGroups(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"proxy_pass": "http://localhost:9001"}
router.AddRoute("~^/api/v(?P<version>\\d+)/", config)
match := router.Match("/api/v2/data")
if match == nil {
t.Fatal("Expected match for /api/v2/data")
}
if match.Params["version"] != "2" {
t.Errorf("Expected version=2, got %s", match.Params["version"])
}
}
func TestRouter_MatchCaseInsensitiveRegex(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"root": "./static", "cache_control": "public, max-age=3600"}
router.AddRoute("~*\\.(CSS|JS)$", config)
// Should match lowercase
match1 := router.Match("/styles/main.css")
if match1 == nil {
t.Error("Should match lowercase .css")
}
// Should match uppercase
match2 := router.Match("/scripts/app.JS")
if match2 == nil {
t.Error("Should match uppercase .JS")
}
}
func TestRouter_MatchCaseSensitiveRegex(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"root": "./static"}
router.AddRoute("~\\.(css)$", config)
// Should match lowercase
match1 := router.Match("/styles/main.css")
if match1 == nil {
t.Error("Should match lowercase .css")
}
// Should NOT match uppercase
match2 := router.Match("/styles/main.CSS")
if match2 != nil {
t.Error("Should not match uppercase .CSS for case-sensitive regex")
}
}
func TestRouter_MatchDefaultRoute(t *testing.T) {
router := NewRouter()
router.AddRoute("=/health", map[string]interface{}{"return": "200 OK"})
router.AddRoute("__default__", map[string]interface{}{"spa_fallback": true})
match := router.Match("/unknown/path")
if match == nil {
t.Fatal("Expected default route match")
}
if match.Config["spa_fallback"] != true {
t.Error("Expected spa_fallback config from default route")
}
}
// ============== Priority Tests ==============
func TestRouter_PriorityExactOverRegex(t *testing.T) {
router := NewRouter()
router.AddRoute("=/api/status", map[string]interface{}{"return": "200 Exact"})
router.AddRoute("~^/api/", map[string]interface{}{"proxy_pass": "http://localhost:9001"})
match := router.Match("/api/status")
if match == nil {
t.Fatal("Expected match")
}
if match.Config["return"] != "200 Exact" {
t.Error("Exact match should have priority over regex")
}
}
func TestRouter_PriorityRegexOverDefault(t *testing.T) {
router := NewRouter()
router.AddRoute("~^/api/", map[string]interface{}{"proxy_pass": "http://localhost:9001"})
router.AddRoute("__default__", map[string]interface{}{"spa_fallback": true})
match := router.Match("/api/v1/users")
if match == nil {
t.Fatal("Expected match")
}
if match.Config["proxy_pass"] != "http://localhost:9001" {
t.Error("Regex match should have priority over default")
}
}
// ============== CreateRouterFromConfig Tests ==============
func TestCreateRouterFromConfig(t *testing.T) {
config := map[string]interface{}{
"regex_locations": map[string]interface{}{
"=/health": map[string]interface{}{
"return": "200 OK",
"content_type": "text/plain",
},
"~^/api/": map[string]interface{}{
"proxy_pass": "http://localhost:9001",
},
"__default__": map[string]interface{}{
"spa_fallback": true,
"root": "./static",
},
},
}
router := CreateRouterFromConfig(config)
// Check exact route
if _, ok := router.ExactRoutes()["/health"]; !ok {
t.Error("Expected /health exact route")
}
// Check regex route
if len(router.Routes()) != 1 {
t.Errorf("Expected 1 regex route, got %d", len(router.Routes()))
}
// Check default route
if router.DefaultRoute() == nil {
t.Error("Expected default route")
}
}
// ============== Static Dir Path Tests ==============
func TestRouter_StaticDirPath(t *testing.T) {
router := NewRouter(WithStaticDir("/var/www/html"))
expected, _ := filepath.Abs("/var/www/html")
actual, _ := filepath.Abs(router.StaticDir())
if actual != expected {
t.Errorf("Expected static dir %s, got %s", expected, actual)
}
}
// ============== Concurrent Access Tests ==============
func TestRouter_ConcurrentAccess(t *testing.T) {
router := NewRouter()
// Add routes concurrently
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
go func(n int) {
router.AddRoute("~^/api/v"+string(rune('0'+n))+"/", map[string]interface{}{
"proxy_pass": "http://localhost:900" + string(rune('0'+n)),
})
done <- true
}(i)
}
// Wait for all goroutines
for i := 0; i < 10; i++ {
<-done
}
// Match routes concurrently
for i := 0; i < 10; i++ {
go func(n int) {
router.Match("/api/v" + string(rune('0'+n)) + "/users")
done <- true
}(i)
}
for i := 0; i < 10; i++ {
<-done
}
}
// ============== Benchmarks ==============
func BenchmarkRouter_MatchExact(b *testing.B) {
router := NewRouter()
router.AddRoute("=/health", map[string]interface{}{"return": "200 OK"})
b.ResetTimer()
for i := 0; i < b.N; i++ {
router.Match("/health")
}
}
func BenchmarkRouter_MatchRegex(b *testing.B) {
router := NewRouter()
router.AddRoute("~^/api/v(?P<version>\\d+)/", map[string]interface{}{"proxy_pass": "http://localhost:9001"})
b.ResetTimer()
for i := 0; i < b.N; i++ {
router.Match("/api/v1/users/123")
}
}
func BenchmarkRouter_MatchWithManyRoutes(b *testing.B) {
router := NewRouter()
// Add many routes
for i := 0; i < 50; i++ {
router.AddRoute("~^/api/v"+string(rune('0'+i%10))+"/service"+string(rune('0'+i/10))+"/",
map[string]interface{}{"proxy_pass": "http://localhost:9001"})
}
router.AddRoute("__default__", map[string]interface{}{"spa_fallback": true})
b.ResetTimer()
for i := 0; i < b.N; i++ {
router.Match("/api/v5/service3/users/123")
}
}

View File

@ -0,0 +1,130 @@
// Package server provides the HTTP server implementation
package server
import (
"context"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/konduktor/konduktor/internal/config"
"github.com/konduktor/konduktor/internal/logging"
"github.com/konduktor/konduktor/internal/middleware"
"github.com/konduktor/konduktor/internal/routing"
)
const Version = "0.1.0"
// Server represents the Konduktor HTTP server
type Server struct {
config *config.Config
httpServer *http.Server
router *routing.Router
logger *logging.Logger
}
// New creates a new server instance
func New(cfg *config.Config) (*Server, error) {
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("invalid configuration: %w", err)
}
logger, err := logging.NewFromConfig(cfg.Logging)
if err != nil {
return nil, fmt.Errorf("failed to create logger: %w", err)
}
router := routing.New(cfg, logger)
srv := &Server{
config: cfg,
router: router,
logger: logger,
}
return srv, nil
}
// Run starts the server and blocks until shutdown
func (s *Server) Run() error {
// Build handler chain with middleware
handler := s.buildHandler()
// Create HTTP server
addr := fmt.Sprintf("%s:%d", s.config.Server.Host, s.config.Server.Port)
s.httpServer = &http.Server{
Addr: addr,
Handler: handler,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
}
// Start server in goroutine
errChan := make(chan error, 1)
go func() {
s.logger.Info("Server starting", "addr", addr, "version", Version)
var err error
if s.config.SSL.Enabled {
err = s.httpServer.ListenAndServeTLS(s.config.SSL.CertFile, s.config.SSL.KeyFile)
} else {
err = s.httpServer.ListenAndServe()
}
if err != nil && err != http.ErrServerClosed {
errChan <- err
}
}()
// Wait for shutdown signal
return s.waitForShutdown(errChan)
}
// buildHandler builds the HTTP handler chain
func (s *Server) buildHandler() http.Handler {
var handler http.Handler = s.router
// Add middleware
handler = middleware.AccessLog(handler, s.logger)
handler = middleware.ServerHeader(handler, Version)
handler = middleware.Recovery(handler, s.logger)
return handler
}
// waitForShutdown waits for shutdown signal and gracefully stops the server
func (s *Server) waitForShutdown(errChan <-chan error) error {
// Listen for shutdown signals
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
select {
case err := <-errChan:
return err
case sig := <-sigChan:
s.logger.Info("Shutdown signal received", "signal", sig.String())
}
// Graceful shutdown with timeout
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
s.logger.Info("Shutting down server...")
if err := s.httpServer.Shutdown(ctx); err != nil {
s.logger.Error("Error during shutdown", "error", err)
return err
}
s.logger.Info("Server stopped gracefully")
return nil
}
// Shutdown gracefully shuts down the server
func (s *Server) Shutdown(ctx context.Context) error {
return s.httpServer.Shutdown(ctx)
}