package extension

import (
	"context"
	"fmt"
	"net/http"
	"sort"
	"sync"

	"github.com/konduktor/konduktor/internal/logging"
)

// Manager manages all loaded extensions.
//
// The extensions slice is copy-on-write: every mutation builds a fresh slice
// and swaps it in under mu, and the manager never modifies a slice after
// publishing it. A snapshot read under RLock can therefore be iterated after
// the lock is released without racing a concurrent LoadExtension/AddExtension.
type Manager struct {
	extensions []Extension
	registry   map[string]ExtensionFactory
	logger     *logging.Logger
	mu         sync.RWMutex
}

// NewManager creates a new extension manager with the built-in "routing",
// "security" and "caching" factories already registered.
func NewManager(logger *logging.Logger) *Manager {
	m := &Manager{
		extensions: make([]Extension, 0),
		registry:   make(map[string]ExtensionFactory),
		logger:     logger,
	}

	// Register built-in extensions
	m.RegisterFactory("routing", NewRoutingExtension)
	m.RegisterFactory("security", NewSecurityExtension)
	m.RegisterFactory("caching", NewCachingExtension)

	return m
}

// RegisterFactory registers an extension factory under name. A factory
// previously registered under the same name is replaced.
func (m *Manager) RegisterFactory(name string, factory ExtensionFactory) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.registry[name] = factory
}

// LoadExtension creates an extension of type extType from config using the
// registered factory, initializes it, and adds it to the manager in
// Priority() order.
//
// It returns an error, and adds nothing, if extType has no registered
// factory or if the factory or the extension's Initialize fails.
func (m *Manager) LoadExtension(extType string, config map[string]interface{}) error {
	m.mu.RLock()
	factory, ok := m.registry[extType]
	m.mu.RUnlock()
	if !ok {
		return fmt.Errorf("unknown extension type: %s", extType)
	}

	// The factory and Initialize are called without holding mu so that slow
	// initialization does not block request processing, and so an extension
	// can call back into the manager without deadlocking.
	ext, err := factory(config, m.logger)
	if err != nil {
		return fmt.Errorf("failed to create extension %s: %w", extType, err)
	}

	if err := ext.Initialize(); err != nil {
		return fmt.Errorf("failed to initialize extension %s: %w", extType, err)
	}

	m.insertByPriority(ext)

	m.logger.Info("Loaded extension", "type", extType, "name", ext.Name(), "priority", ext.Priority())
	return nil
}

// AddExtension initializes a pre-created extension and adds it to the
// manager in Priority() order. Nothing is added if Initialize fails.
func (m *Manager) AddExtension(ext Extension) error {
	if err := ext.Initialize(); err != nil {
		return fmt.Errorf("failed to initialize extension %s: %w", ext.Name(), err)
	}

	m.insertByPriority(ext)

	m.logger.Info("Added extension", "name", ext.Name(), "priority", ext.Priority())
	return nil
}

// insertByPriority publishes a new extension list containing ext, sorted by
// ascending Priority(). Extensions with equal priority keep their insertion
// order.
//
// A new backing array is always allocated. Sorting in place would mutate
// slices that in-flight requests obtained from snapshot().
func (m *Manager) insertByPriority(ext Extension) {
	m.mu.Lock()
	defer m.mu.Unlock()

	next := make([]Extension, len(m.extensions), len(m.extensions)+1)
	copy(next, m.extensions)
	next = append(next, ext)
	sort.SliceStable(next, func(i, j int) bool {
		return next[i].Priority() < next[j].Priority()
	})
	m.extensions = next
}

// snapshot returns the current extension list. Callers must treat the
// result as read-only; the manager never mutates a published slice.
func (m *Manager) snapshot() []Extension {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.extensions
}

// ProcessRequest runs every enabled extension's ProcessRequest in ascending
// Priority() order and stops at the first one that reports it handled the
// request.
//
// An error from an individual extension is logged and the next extension is
// tried. The returned error is therefore currently always nil.
func (m *Manager) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
	return m.processRequest(ctx, m.snapshot(), w, r)
}

// processRequest implements ProcessRequest over an explicit extension list.
func (m *Manager) processRequest(ctx context.Context, extensions []Extension, w http.ResponseWriter, r *http.Request) (bool, error) {
	for _, ext := range extensions {
		if !ext.Enabled() {
			continue
		}

		handled, err := ext.ProcessRequest(ctx, w, r)
		if err != nil {
			m.logger.Error("Extension error", "extension", ext.Name(), "error", err)
			// Continue to next extension on error
			continue
		}

		if handled {
			return true, nil
		}
	}

	return false, nil
}

// ProcessResponse runs every enabled extension's ProcessResponse in
// descending Priority() order, i.e. the reverse of the request order.
func (m *Manager) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
	m.processResponse(ctx, m.snapshot(), w, r)
}

// processResponse implements ProcessResponse over an explicit extension list.
func (m *Manager) processResponse(ctx context.Context, extensions []Extension, w http.ResponseWriter, r *http.Request) {
	for i := len(extensions) - 1; i >= 0; i-- {
		ext := extensions[i]
		if !ext.Enabled() {
			continue
		}
		ext.ProcessResponse(ctx, w, r)
	}
}

// Cleanup calls Cleanup on every loaded extension, logging (not returning)
// any errors, and then removes all extensions from the manager.
func (m *Manager) Cleanup() {
	m.mu.Lock()
	defer m.mu.Unlock()

	for _, ext := range m.extensions {
		if err := ext.Cleanup(); err != nil {
			m.logger.Error("Extension cleanup error", "extension", ext.Name(), "error", err)
		}
	}

	m.extensions = nil
}

// GetExtension returns the first loaded extension whose Name() equals name,
// or nil if there is none.
func (m *Manager) GetExtension(name string) Extension {
	m.mu.RLock()
	defer m.mu.RUnlock()

	for _, ext := range m.extensions {
		if ext.Name() == name {
			return ext
		}
	}
	return nil
}

// Extensions returns a copy of all loaded extensions in Priority() order.
func (m *Manager) Extensions() []Extension {
	m.mu.RLock()
	defer m.mu.RUnlock()

	result := make([]Extension, len(m.extensions))
	copy(result, m.extensions)
	return result
}

// Handler returns an http.Handler that routes each request through the
// extensions before (optionally) delegating to next.
//
// Per request it:
//  1. wraps the ResponseWriter with every enabled ResponseWriterWrapper,
//  2. runs the request phase; if no extension handles the request,
//     next serves it,
//  3. runs the response phase,
//  4. calls Finalize on every wrapped writer that implements
//     ResponseFinalizer, outermost first.
func (m *Manager) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		// One snapshot for all phases, so wrapping, request and response
		// processing see the same set of extensions even if one is loaded
		// concurrently.
		extensions := m.snapshot()

		// Wrap in ascending Priority() order. Each wrapper wraps the previous
		// result, so the extension with the largest Priority() is outermost.
		wrappedWriter := w
		var finalizers []ResponseFinalizer
		for _, ext := range extensions {
			if !ext.Enabled() {
				continue
			}
			wrapper, ok := ext.(ResponseWriterWrapper)
			if !ok {
				continue
			}
			wrappedWriter = wrapper.WrapResponseWriter(wrappedWriter, r)
			if finalizer, ok := wrappedWriter.(ResponseFinalizer); ok {
				finalizers = append(finalizers, finalizer)
			}
		}

		// Create response wrapper to capture status code
		responseWrapper := newResponseWrapper(wrappedWriter)

		handled, err := m.processRequest(ctx, extensions, responseWrapper, r)
		if err != nil {
			m.logger.Error("Error processing request", "error", err)
		}

		if !handled {
			// No extension handled, pass to next handler
			next.ServeHTTP(responseWrapper, r)
		}

		m.processResponse(ctx, extensions, responseWrapper, r)

		// Finalize outermost wrapper first (reverse of wrapping order).
		for i := len(finalizers) - 1; i >= 0; i-- {
			finalizers[i].Finalize()
		}
	})
}

// responseWrapper wraps an http.ResponseWriter and records the final status
// code written to the response.
type responseWrapper struct {
	http.ResponseWriter
	statusCode int
	written    bool
}

// newResponseWrapper wraps w. The recorded status defaults to http.StatusOK
// until WriteHeader is called with a final status.
func newResponseWrapper(w http.ResponseWriter) *responseWrapper {
	return &responseWrapper{
		ResponseWriter: w,
		statusCode:     http.StatusOK,
	}
}

// WriteHeader forwards code to the wrapped writer.
//
// Informational 1xx codes other than 101 Switching Protocols are passed
// through without being recorded, because a final status still follows
// them; treating them as final would drop the real status. The first
// non-informational code is recorded, and later calls are ignored.
func (rw *responseWrapper) WriteHeader(code int) {
	if rw.written {
		return
	}
	if code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols {
		rw.ResponseWriter.WriteHeader(code)
		return
	}
	rw.statusCode = code
	rw.ResponseWriter.WriteHeader(code)
	rw.written = true
}
func (rw *responseWrapper) Write(b []byte) (int, error) { if !rw.written { rw.WriteHeader(http.StatusOK) } return rw.ResponseWriter.Write(b) } func (rw *responseWrapper) StatusCode() int { return rw.statusCode } // Unwrap returns the underlying ResponseWriter (for type assertions) func (rw *responseWrapper) Unwrap() http.ResponseWriter { return rw.ResponseWriter }