Fix circular import in __init__.py (#35)
This commit is contained in:
@@ -7,18 +7,46 @@ import os
|
||||
from typing import Any, Optional
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, StreamingResponse
|
||||
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, StreamingResponse, Depends, Security
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from pydantic import BaseModel, Field
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
from opus_orchestrator.config import get_config
|
||||
from opus_orchestrator import run_opus, OpusOrchestrator
|
||||
from opus_orchestrator.frameworks import FRAMEWORKS
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# AUTHENTICATION
|
||||
# =============================================================================
|
||||
|
||||
API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
async def get_api_key(api_key: str = Security(API_KEY_HEADER)) -> str:
|
||||
"""Validate API key from header or environment.
|
||||
|
||||
If no API key is configured (for development), allow all requests.
|
||||
Set OPUS_API_KEY environment variable to protect production endpoints.
|
||||
"""
|
||||
configured_key = os.environ.get("OPUS_API_KEY")
|
||||
|
||||
# No key configured - allow all (development mode)
|
||||
if not configured_key:
|
||||
return "dev"
|
||||
|
||||
# Key configured - validate
|
||||
if api_key is None:
|
||||
raise HTTPException(status_code=401, detail="API key required. Set X-API-Key header.")
|
||||
|
||||
if api_key != configured_key:
|
||||
raise HTTPException(status_code=403, detail="Invalid API key")
|
||||
|
||||
return api_key
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# REQUEST/RESPONSE MODELS
|
||||
# =============================================================================
|
||||
@@ -194,7 +222,7 @@ async def list_frameworks():
|
||||
|
||||
|
||||
@app.post("/generate", response_model=GenerateResponse, tags=["generate"])
|
||||
async def generate(request: GenerateRequest, background_tasks: BackgroundTasks):
|
||||
async def generate(request: GenerateRequest, background_tasks: BackgroundTasks, api_key: str = Depends(get_api_key)):
|
||||
"""Generate a manuscript from concept or GitHub repo."""
|
||||
import traceback
|
||||
try:
|
||||
@@ -302,7 +330,7 @@ async def generate_stream(request: GenerateRequest):
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
@app.post("/ingest", response_model=IngestResponse, tags=["ingest"])
|
||||
async def ingest(request: IngestRequest):
|
||||
async def ingest(request: IngestRequest, api_key: str = Depends(get_api_key)):
|
||||
"""Ingest content from a GitHub repository."""
|
||||
try:
|
||||
orch = OpusOrchestrator(book_type="fiction")
|
||||
@@ -351,7 +379,7 @@ class S3UploadResponse(BaseModel):
|
||||
|
||||
|
||||
@app.post("/upload", response_model=UploadResponse, tags=["upload"])
|
||||
async def upload_file(file: UploadFile = File(...)):
|
||||
async def upload_file(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
|
||||
"""Upload a file for processing."""
|
||||
try:
|
||||
content = await file.read()
|
||||
@@ -368,7 +396,7 @@ async def upload_file(file: UploadFile = File(...)):
|
||||
|
||||
|
||||
@app.post("/upload/s3", response_model=S3UploadResponse, tags=["upload"])
|
||||
async def upload_to_s3(request: S3UploadRequest):
|
||||
async def upload_to_s3(request: S3UploadRequest, api_key: str = Depends(get_api_key)):
|
||||
"""Upload content to S3-compatible storage."""
|
||||
try:
|
||||
from opus_orchestrator import S3Ingestor
|
||||
@@ -441,3 +469,64 @@ if __name__ == "__main__":
|
||||
|
||||
port = int(sys.argv[1]) if len(sys.argv) > 1 else 8000
|
||||
uvicorn.run(app, host="0.0.0.0", port=port)
|
||||
|
||||
# =============================================================================
|
||||
# RATE LIMITING
|
||||
# =============================================================================
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Simple in-memory rate limiter."""
|
||||
|
||||
def __init__(self, app, requests_per_minute: int = 30):
|
||||
super().__init__(app)
|
||||
self.requests_per_minute = requests_per_minute
|
||||
self.requests = defaultdict(list)
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Skip rate limiting for health check
|
||||
if request.url.path == "/health":
|
||||
return await call_next(request)
|
||||
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
current_time = time.time()
|
||||
|
||||
# Clean old requests (older than 1 minute)
|
||||
self.requests[client_ip] = [
|
||||
t for t in self.requests[client_ip]
|
||||
if current_time - t < 60
|
||||
]
|
||||
|
||||
# Check rate limit
|
||||
if len(self.requests[client_ip]) >= self.requests_per_minute:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded. Please try again later."}
|
||||
)
|
||||
|
||||
# Record this request
|
||||
self.requests[client_ip].append(current_time)
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
# Get rate limit from environment, default to 30/minute
|
||||
_rate_limit = int(os.environ.get("RATE_LIMIT_PER_MINUTE", "30"))
|
||||
app.add_middleware(RateLimitMiddleware, requests_per_minute=_rate_limit)
|
||||
|
||||
# CORS middleware - secure configuration
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
# Get allowed origins from environment, default to restricted set
|
||||
_cors_origins = os.environ.get("CORS_ORIGINS", "").split(",") if os.environ.get("CORS_ORIGINS") else []
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=_cors_origins if _cors_origins else ["http://localhost:3000", "http://localhost:8000"],
|
||||
allow_credentials=True if _cors_origins else False,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE"],
|
||||
allow_headers=["Content-Type", "Authorization"],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user