Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 75 additions & 74 deletions backend-data-elaborator/api/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
import asyncio
import time
import hashlib
import os
from datetime import datetime
from typing import List
from contextlib import asynccontextmanager

from fastapi import FastAPI, Depends, HTTPException, status, WebSocket, WebSocketDisconnect
from fastapi import FastAPI, Depends, HTTPException, status, WebSocket, WebSocketDisconnect, Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from sqlalchemy.orm import Session
from sqlalchemy import func, desc, text
from sqlalchemy.exc import OperationalError
Expand All @@ -34,40 +37,13 @@

# --- CONFIGURATION ---
MAX_TIMESTAMP_SKEW = 60
REDIS_URL = "redis://redis:6379/0"
REDIS_URL = os.getenv("REDIS_URL", "redis://redis:6379/0")

# ⚠️ SECURITY: Shared Secret for Device Provisioning (Must match Firmware)
ENROLLMENT_TOKEN = os.getenv("ENROLLMENT_TOKEN", "S3cret_Qu4k3_K3y")

# ==========================================
# WEBSOCKET CONNECTION MANAGER
# ==========================================

class ConnectionManager:
"""Manages active WebSocket connections for broadcasting alerts."""

def __init__(self):
self.active_connections: List[WebSocket] = []

async def connect(self, websocket: WebSocket) -> None:
await websocket.accept()
self.active_connections.append(websocket)

def disconnect(self, websocket: WebSocket) -> None:
if websocket in self.active_connections:
self.active_connections.remove(websocket)

async def broadcast(self, message: str) -> None:
"""Pushes a message to all connected clients."""
for connection in self.active_connections:
try:
await connection.send_text(message)
except Exception as e:
# FIX: Log the error instead of 'pass'
print(f"⚠️ Warning: Failed to broadcast to a client. Error: {e}")
self.disconnect(connection)

ws_manager = ConnectionManager()
# ⚠️ SECURITY: API Key for IoT Endpoints
IOT_API_KEY = os.getenv("IOT_API_KEY", "SuperSecretIoTKey2024")

# ==========================================
# INFRASTRUCTURE INITIALIZATION
Expand All @@ -92,7 +68,8 @@ def wait_for_db(retries: int = 10, delay: int = 3) -> None:
models.Base.metadata.create_all(bind=engine)

# 2. Initialize Redis Client (Async)
redis_client = aioredis.from_url("redis://redis:6379/0", decode_responses=True)
redis_client = aioredis.from_url(REDIS_URL, decode_responses=True)


# ==========================================
# REAL-TIME NOTIFICATION SYSTEM (PUBSUB)
Expand All @@ -107,21 +84,28 @@ def __init__(self):
async def connect(self, websocket: WebSocket) -> None:
await websocket.accept()
self.active_connections.append(websocket)
print(f"📱 Client Connected. Active: {len(self.active_connections)}")

def disconnect(self, websocket: WebSocket) -> None:
if websocket in self.active_connections:
self.active_connections.remove(websocket)
print(f"📱 Client Disconnected. Active: {len(self.active_connections)}")

async def broadcast(self, message: str) -> None:
"""Pushes a message to all connected clients."""
dead_connections = []
for connection in self.active_connections:
try:
await connection.send_text(message)
except Exception as e:
# FIX: We log the error here instead of using 'pass'
print(f"⚠️ Client disconnected abruptly. Error: {e}")
self.disconnect(connection)
# FIXED [B110:try_except_pass]: We now log the exception and schedule cleanup
print(f"⚠️ Failed to broadcast to a client: {e}")
dead_connections.append(connection)

# Clean up any connections that threw errors
for dead in dead_connections:
self.disconnect(dead)

manager = ConnectionManager()

async def redis_alert_listener() -> None:
Expand All @@ -145,19 +129,63 @@ async def lifespan(app: FastAPI):
# 3. Initialize FastAPI
app = FastAPI(title="QuakeGuard Backend", version="2.2.0", lifespan=lifespan)

# ==========================================
# SECURITY MIDDLEWARE & DEPENDENCIES
# ==========================================

class IoTAuthenticationMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# Exclude public endpoints, websockets, and health checks from Auth
if request.url.path.startswith(("/docs", "/openapi", "/health", "/ws")):
return await call_next(request)

# Extract the key from headers
api_key = request.headers.get("X-API-Key")

if api_key != IOT_API_KEY:
return JSONResponse(
status_code=401,
content={"detail": "Unauthorized: Invalid or missing X-API-Key header"}
)

return await call_next(request)

app.add_middleware(IoTAuthenticationMiddleware)

async def rate_limiter(request: Request):
"""
Fixed-window rate limiter using Redis.
Restricts ingestion per IP to prevent Thundering Herd / DoS attacks.
"""
client_ip = request.client.host
current_second = int(time.time())

# Create a unique Redis key for this IP for the current second
key = f"rate_limit:{client_ip}:{current_second}"

# Increment the request count
request_count = await redis_client.incr(key)

# Set expiration on the key the first time it is created
if request_count == 1:
await redis_client.expire(key, 5)

# Threshold: Allow max 50 requests per second per IP
if request_count > 50:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded. Too many requests from this IP."
)


@app.websocket("/ws/alerts")
async def websocket_endpoint(websocket: WebSocket):
"""Clients connect here to receive real-time updates."""
await manager.connect(websocket)
try:
while True:
await websocket.receive_text()
except WebSocketDisconnect:
# Standard expected disconnection
manager.disconnect(websocket)
except Exception as e:
# FIX: Log unexpected disconnects
print(f"⚠️ Unexpected WebSocket error: {e}")
except (WebSocketDisconnect, Exception):
manager.disconnect(websocket)

# ==========================================
Expand Down Expand Up @@ -240,7 +268,7 @@ def create_misurator(misurator: schemas.MisuratorCreate, db: Session = Depends(g
def get_misurators(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
return db.query(models.Misurator).offset(skip).limit(limit).all()

@app.post("/misurations/", status_code=status.HTTP_202_ACCEPTED, tags=["Ingestion"])
@app.post("/misurations/", status_code=status.HTTP_202_ACCEPTED, tags=["Ingestion"], dependencies=[Depends(rate_limiter)])
async def create_misuration_async(misuration: schemas.MisurationCreate, db: Session = Depends(get_db)):
misurator = db.query(models.Misurator).filter(models.Misurator.id == misuration.misurator_id).first()
if not misurator or not misurator.active:
Expand All @@ -258,39 +286,12 @@ async def create_misuration_async(misuration: schemas.MisurationCreate, db: Sess

# 3. Check Replay Attack
if abs(time.time() - misuration.device_timestamp) > 60:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Replay Attack Detected: Timestamp invalid")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Replay Attack Detected: Timestamp invalid")

# 4. Enqueue for Worker
payload = misuration.model_dump()
payload['zone_id'] = misurator.zone_id
await redis_client.lpush("seismic_events", json.dumps(payload))
payload['zone_id'] = misurator.zone_id

return {"status": "accepted"}

@app.get("/zones/{zone_id}/alerts", response_model=List[schemas.AlertResponse], tags=["Data Retrieval"])
def get_zone_alerts(zone_id: int, limit: int = 10, db: Session = Depends(get_db)):
return db.query(models.Alert).filter(models.Alert.zone_id == zone_id).order_by(desc(models.Alert.timestamp)).limit(limit).all()

@app.get("/sensors/{misurator_id}/statistics", tags=["Analytics"])
def get_sensor_statistics(misurator_id: int, db: Session = Depends(get_db)):
sensor = db.query(models.Misurator).filter(models.Misurator.id == misurator_id).first()
if not sensor:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Sensor not found")

stats = db.query(
func.count(models.Misuration.id).label("count"),
func.avg(models.Misuration.value).label("average"),
func.max(models.Misuration.value).label("max_value"),
func.min(models.Misuration.value).label("min_value")
).filter(models.Misuration.misurator_id == misurator_id).first()

return {
"misurator_id": misurator_id, "total_readings": stats.count,
"average_value": round(stats.average, 2) if stats.average else 0.0,
"max_recorded": stats.max_value, "min_recorded": stats.min_value,
"generated_at": datetime.utcnow().isoformat()
}

@app.get("/health", tags=["System"])
async def health_check():
return {"status": "ok"}
# Offload the rest to the Redis queue for async processing
await redis_client.lpush("seismic_events", json.dumps(payload))
return {"status": "accepted"}
92 changes: 21 additions & 71 deletions backend-data-elaborator/api/src/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@
# ==========================================
# ZONE SCHEMAS
# ==========================================

class ZoneBase(BaseModel):
city: str
city: str

class ZoneCreate(ZoneBase):
pass
pass

class ZoneUpdate(BaseModel):
city: Optional[str] = None
Expand All @@ -26,13 +25,11 @@ class Zone(ZoneBase):
id: int
model_config = ConfigDict(from_attributes=True)


# ==========================================
# MISURATOR (SENSOR) SCHEMAS
# ==========================================

class MisuratorBase(BaseModel):
active: bool
active: bool
zone_id: int

class MisuratorCreate(MisuratorBase):
Expand All @@ -42,7 +39,6 @@ class MisuratorCreate(MisuratorBase):
"""
latitude: float = Field(..., ge=-90, le=90, description="GPS Latitude")
longitude: float = Field(..., ge=-180, le=180, description="GPS Longitude")

# The Public Key generated by the ESP32 (Hex string)
public_key_hex: str = Field(..., description="ECDSA Public Key (NIST256p) in Hex format")

Expand All @@ -55,80 +51,34 @@ class Misurator(MisuratorBase):
id: int
latitude: Optional[float] = None
longitude: Optional[float] = None
# We do not return the public key by default to keep responses clean,
# We do not return the public key by default to keep responses clean,
# but it can be added if needed.

model_config = ConfigDict(from_attributes=True)


# ==========================================
# MISURATION (DATA POINT) SCHEMAS
# ==========================================

class MisurationBase(BaseModel):
value: int
misurator_id: int

class MisurationCreate(MisurationBase):
class MisurationCreate(BaseModel):
"""
Payload for data ingestion.
MUST include signature and device timestamp for verification.
"""
# Timestamp generated by the device (Unix epoch float/int) used for replay protection
device_timestamp: float

# The digital signature of "value:device_timestamp"
signature_hex: str

class MisurationUpdate(BaseModel):
value: Optional[int] = None
misurator_id: Optional[int] = None

class Misuration(MisurationBase):
value: int = Field(
...,
ge=-8192,
le=8192,
description="Vibration value. Must be within the physical constraints of the ADXL345 sensor to prevent spoofing."
)
misurator_id: int = Field(..., gt=0, description="The ID of the registered sensor")
device_timestamp: int = Field(..., gt=1600000000, description="Unix timestamp of the event")
signature_hex: str = Field(..., min_length=64, description="ECDSA NIST256p Signature")

class Misuration(BaseModel):
id: int
created_at: datetime
model_config = ConfigDict(from_attributes=True)


# ==========================================
# ANALYTICS & ALERTS SCHEMAS
# ==========================================

class ZoneStats(BaseModel):
"""
DTO for providing statistical aggregated data about a zone.
"""
zone_id: int
city: str
active_misurators: int
total_misurators: int
avg_misuration_value: Optional[float] = None
last_misuration: Optional[datetime] = None


# --- ALERT DEFINITIONS ---

class AlertBase(BaseModel):
"""
Base properties shared between creation and retrieval of Alerts.
"""
zone_id: int
severity: float
message: Optional[str] = None
timestamp: datetime

class AlertCreate(AlertBase):
"""
Schema for internal creation of alerts (used by the Worker).
"""
pass
value: int
misurator_id: int
device_timestamp: int
server_timestamp: datetime
signature_hex: str

class AlertResponse(AlertBase):
"""
Schema for API responses returning Alert data.
Includes the database ID.
"""
id: int

# Config to allow Pydantic to read data from the SQLAlchemy object
model_config = ConfigDict(from_attributes=True)
Loading