Fix circular import in __init__.py (#35)

This commit is contained in:
2026-03-14 09:24:31 +00:00
parent 1b116108a6
commit 0f62267806
25 changed files with 517 additions and 137 deletions
+95 -6
View File
@@ -7,18 +7,46 @@ import os
from typing import Any, Optional
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, StreamingResponse
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, StreamingResponse, Depends, Security
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field
from dotenv import load_dotenv
load_dotenv()
from opus_orchestrator.config import get_config
from opus_orchestrator import run_opus, OpusOrchestrator
from opus_orchestrator.frameworks import FRAMEWORKS
# =============================================================================
# AUTHENTICATION
# =============================================================================
API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False)
async def get_api_key(api_key: str = Security(API_KEY_HEADER)) -> str:
"""Validate API key from header or environment.
If no API key is configured (for development), allow all requests.
Set OPUS_API_KEY environment variable to protect production endpoints.
"""
configured_key = os.environ.get("OPUS_API_KEY")
# No key configured - allow all (development mode)
if not configured_key:
return "dev"
# Key configured - validate
if api_key is None:
raise HTTPException(status_code=401, detail="API key required. Set X-API-Key header.")
if api_key != configured_key:
raise HTTPException(status_code=403, detail="Invalid API key")
return api_key
# =============================================================================
# REQUEST/RESPONSE MODELS
# =============================================================================
@@ -194,7 +222,7 @@ async def list_frameworks():
@app.post("/generate", response_model=GenerateResponse, tags=["generate"])
async def generate(request: GenerateRequest, background_tasks: BackgroundTasks):
async def generate(request: GenerateRequest, background_tasks: BackgroundTasks, api_key: str = Depends(get_api_key)):
"""Generate a manuscript from concept or GitHub repo."""
import traceback
try:
@@ -302,7 +330,7 @@ async def generate_stream(request: GenerateRequest):
return StreamingResponse(event_generator(), media_type="text/event-stream")
@app.post("/ingest", response_model=IngestResponse, tags=["ingest"])
async def ingest(request: IngestRequest):
async def ingest(request: IngestRequest, api_key: str = Depends(get_api_key)):
"""Ingest content from a GitHub repository."""
try:
orch = OpusOrchestrator(book_type="fiction")
@@ -351,7 +379,7 @@ class S3UploadResponse(BaseModel):
@app.post("/upload", response_model=UploadResponse, tags=["upload"])
async def upload_file(file: UploadFile = File(...)):
async def upload_file(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
"""Upload a file for processing."""
try:
content = await file.read()
@@ -368,7 +396,7 @@ async def upload_file(file: UploadFile = File(...)):
@app.post("/upload/s3", response_model=S3UploadResponse, tags=["upload"])
async def upload_to_s3(request: S3UploadRequest):
async def upload_to_s3(request: S3UploadRequest, api_key: str = Depends(get_api_key)):
"""Upload content to S3-compatible storage."""
try:
from opus_orchestrator import S3Ingestor
@@ -441,3 +469,64 @@ if __name__ == "__main__":
port = int(sys.argv[1]) if len(sys.argv) > 1 else 8000
uvicorn.run(app, host="0.0.0.0", port=port)
# =============================================================================
# RATE LIMITING
# =============================================================================
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
import time
from collections import defaultdict
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Simple in-memory rate limiter."""
def __init__(self, app, requests_per_minute: int = 30):
super().__init__(app)
self.requests_per_minute = requests_per_minute
self.requests = defaultdict(list)
async def dispatch(self, request: Request, call_next):
# Skip rate limiting for health check
if request.url.path == "/health":
return await call_next(request)
client_ip = request.client.host if request.client else "unknown"
current_time = time.time()
# Clean old requests (older than 1 minute)
self.requests[client_ip] = [
t for t in self.requests[client_ip]
if current_time - t < 60
]
# Check rate limit
if len(self.requests[client_ip]) >= self.requests_per_minute:
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded. Please try again later."}
)
# Record this request
self.requests[client_ip].append(current_time)
return await call_next(request)
# Get rate limit from environment, default to 30/minute
_rate_limit = int(os.environ.get("RATE_LIMIT_PER_MINUTE", "30"))
app.add_middleware(RateLimitMiddleware, requests_per_minute=_rate_limit)
# CORS middleware - secure configuration
from fastapi.middleware.cors import CORSMiddleware
# Get allowed origins from environment, default to restricted set
_cors_origins = os.environ.get("CORS_ORIGINS", "").split(",") if os.environ.get("CORS_ORIGINS") else []
app.add_middleware(
CORSMiddleware,
allow_origins=_cors_origins if _cors_origins else ["http://localhost:3000", "http://localhost:8000"],
allow_credentials=True if _cors_origins else False,
allow_methods=["GET", "POST", "PUT", "DELETE"],
allow_headers=["Content-Type", "Authorization"],
)