mitai-jinkendo/backend/auth.py
Lars 329daaef1c
All checks were successful
Deploy Development / deploy (push) Successful in 35s
Build Test / lint-backend (push) Successful in 0s
Build Test / build-frontend (push) Successful in 12s
fix: prevent connection pool exhaustion in features/usage
- Add optional conn parameter to get_effective_tier()
- Add optional conn parameter to check_feature_access()
- Pass existing connection in features.py loop
- Prevents opening 20+ connections simultaneously
- Fixes "connection pool exhausted" error

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-21 07:02:42 +01:00

375 lines
11 KiB
Python

"""
Authentication and Authorization for Mitai Jinkendo
Provides password hashing, session management, and auth dependencies
for FastAPI endpoints.
"""
import hashlib
import secrets
from typing import Optional
from datetime import datetime, timedelta
from fastapi import Header, Query, HTTPException
import bcrypt
from db import get_db, get_cursor
def hash_pin(pin: str) -> str:
    """Return a freshly salted bcrypt hash of ``pin`` as text.

    The PIN is encoded to bytes, hashed with a newly generated bcrypt
    salt, and the resulting hash is decoded back to ``str`` so it can be
    stored as text. Legacy SHA256 hashes are only handled on the
    verification side (see ``verify_pin``); this function always produces
    bcrypt.
    """
    pin_bytes = pin.encode()
    salt = bcrypt.gensalt()
    hashed = bcrypt.hashpw(pin_bytes, salt)
    return hashed.decode()
def verify_pin(pin: str, stored_hash: str) -> bool:
    """Verify a password against a stored bcrypt or legacy SHA256 hash.

    Args:
        pin: Plain-text password/PIN supplied by the user.
        stored_hash: Hash loaded from storage. Values starting with ``$2``
            are treated as bcrypt hashes; anything else is treated as a
            legacy hex-encoded SHA256 digest.

    Returns:
        True if ``pin`` matches ``stored_hash``; False otherwise, including
        when ``stored_hash`` is empty/None or is a bcrypt hash that cannot
        be checked.
    """
    if not stored_hash:
        return False

    # bcrypt hashes carry a "$2" version prefix ($2a$, $2b$, ...).
    if stored_hash.startswith('$2'):
        try:
            return bcrypt.checkpw(pin.encode(), stored_hash.encode())
        except Exception:
            # Deliberate best-effort: a stored hash that bcrypt cannot
            # process counts as a failed verification, not a server error.
            return False

    # Legacy SHA256 support. This function only verifies; any upgrade of
    # the stored hash to bcrypt has to be done by the caller.
    legacy_digest = hashlib.sha256(pin.encode()).hexdigest()
    # Constant-time comparison so response timing does not leak how many
    # leading characters of the digest matched. Both sides are compared as
    # bytes because compare_digest rejects non-ASCII str arguments.
    return secrets.compare_digest(stored_hash.encode(), legacy_digest.encode())
def make_token() -> str:
    """Create a new random, URL-safe session token.

    The token is drawn from the ``secrets`` module using 32 random bytes.
    """
    token = secrets.token_urlsafe(32)
    return token
def get_session(token: str):
    """
    Look up the non-expired session that belongs to ``token``.

    The session row is joined with its profile, so the result also carries
    the profile's role, name, ai_enabled, ai_limit_day and export_enabled
    columns.

    Returns:
        The row returned by the cursor, or None when ``token`` is empty or
        no unexpired session matches it.
    """
    if not token:
        return None

    # Expiry is enforced by the database clock (CURRENT_TIMESTAMP).
    query = (
        "SELECT s.*, p.role, p.name, p.ai_enabled, p.ai_limit_day, p.export_enabled "
        "FROM sessions s JOIN profiles p ON s.profile_id=p.id "
        "WHERE s.token=%s AND s.expires_at > CURRENT_TIMESTAMP"
    )
    params = (token,)

    with get_db() as db:
        cursor = get_cursor(db)
        cursor.execute(query, params)
        return cursor.fetchone()
def require_auth(x_auth_token: Optional[str] = Header(default=None)):
    """
    FastAPI dependency that only lets authenticated requests through.

    The token is read from the ``x_auth_token`` header parameter and
    resolved with ``get_session``.

    Usage:
        @app.get("/api/endpoint")
        def endpoint(session: dict = Depends(require_auth)):
            profile_id = session['profile_id']
            ...

    Returns:
        The session returned by ``get_session``.

    Raises:
        HTTPException 401 if no valid session exists for the token
    """
    current_session = get_session(x_auth_token)
    if current_session:
        return current_session
    raise HTTPException(401, "Nicht eingeloggt")
def require_auth_flexible(x_auth_token: Optional[str] = Header(default=None), token: Optional[str] = Query(default=None)):
    """
    FastAPI dependency that accepts the token from a header OR a query parameter.

    Intended for endpoints loaded by HTML img tags, which cannot send
    custom headers. The header value wins when both are present.

    Usage:
        @app.get("/api/photos/{id}")
        def get_photo(id: str, session: dict = Depends(require_auth_flexible)):
            ...

    Returns:
        The session returned by ``get_session``.

    Raises:
        HTTPException 401 if no valid session exists for the token
    """
    supplied_token = x_auth_token or token
    current_session = get_session(supplied_token)
    if current_session:
        return current_session
    raise HTTPException(401, "Nicht eingeloggt")
def require_admin(x_auth_token: Optional[str] = Header(default=None)):
    """
    FastAPI dependency that only lets authenticated admins through.

    Usage:
        @app.put("/api/admin/endpoint")
        def admin_endpoint(session: dict = Depends(require_admin)):
            ...

    Returns:
        The session returned by ``get_session`` when its role is 'admin'.

    Raises:
        HTTPException 401 if no valid session exists for the token
        HTTPException 403 if the session's role is not 'admin'
    """
    current_session = get_session(x_auth_token)
    if current_session:
        if current_session['role'] == 'admin':
            return current_session
        # Authenticated, but not privileged enough.
        raise HTTPException(403, "Nur für Admins")
    raise HTTPException(401, "Nicht eingeloggt")
# ============================================================================
# Feature Access Control (v9c)
# ============================================================================
def get_effective_tier(profile_id: str, conn=None) -> str:
    """
    Get the effective tier for a profile.

    Checks for an active access grant first (from coupons, trials, etc.)
    and falls back to the tier stored on the profile.

    Args:
        profile_id: User profile ID
        conn: Optional existing DB connection (to avoid pool exhaustion).
            When omitted, a connection is opened just for this lookup.

    Returns:
        tier_id (str): e.g. 'free', 'basic', 'premium', or 'selfhosted'.
        'free' is returned when the profile does not exist or has no tier.
    """
    if conn is None:
        # No connection supplied: open one for the duration of this call
        # and run the lookup on it.
        with get_db() as own_conn:
            return get_effective_tier(profile_id, own_conn)

    cur = get_cursor(conn)

    # Active access grants take priority over the profile's own tier.
    # If several grants are valid at once, the one that expires last wins.
    cur.execute("""
        SELECT tier_id
        FROM access_grants
        WHERE profile_id = %s
          AND is_active = true
          AND valid_from <= CURRENT_TIMESTAMP
          AND valid_until > CURRENT_TIMESTAMP
        ORDER BY valid_until DESC
        LIMIT 1
    """, (profile_id,))
    grant = cur.fetchone()
    if grant:
        return grant['tier_id']

    # Fall back to the tier stored on the profile.
    cur.execute("SELECT tier FROM profiles WHERE id = %s", (profile_id,))
    profile = cur.fetchone()
    if not profile or profile['tier'] is None:
        # Missing profile or NULL tier: honour the declared str return
        # type instead of leaking None to callers.
        return 'free'
    return profile['tier']
def check_feature_access(profile_id: str, feature_id: str, conn=None) -> dict:
    """
    Check if a profile has access to a feature.

    Access hierarchy:
        1. User-specific restriction (user_feature_restrictions)
        2. Tier limit (tier_limits)
        3. Feature default (features.default_limit)

    Args:
        profile_id: User profile ID
        feature_id: Feature ID to check
        conn: Optional existing DB connection (to avoid pool exhaustion).
            When omitted, a connection is opened just for this check.

    Returns:
        dict: {
            'allowed': bool,
            'limit': int | None,       # NULL = unlimited
            'used': int,
            'remaining': int | None,   # NULL = unlimited
            'reason': str              # 'unlimited', 'within_limit',
                                       # 'limit_exceeded', 'feature_disabled',
                                       # 'enabled', 'feature_not_found'
        }
    """
    # Reuse the caller's connection when one was passed in. The check is
    # "is not None" so that only an omitted argument opens a new connection.
    if conn is not None:
        return _check_impl(profile_id, feature_id, conn)

    with get_db() as own_conn:
        return _check_impl(profile_id, feature_id, own_conn)
def _check_impl(profile_id: str, feature_id: str, conn) -> dict:
    """
    Internal implementation of check_feature_access.

    Resolves the applicable limit (user restriction, then tier limit, then
    feature default), reads the current usage counter, resets the counter
    if its reset time has passed, and reports whether the feature may be
    used.

    Args:
        profile_id: User profile ID
        feature_id: Feature ID to check
        conn: Open DB connection. NOTE: when an expired usage window is
            reset, this function calls ``conn.commit()`` on it.

    Returns:
        dict with the keys 'allowed', 'limit', 'used', 'remaining' and
        'reason'. 'remaining' is None for unlimited and boolean features
        and is never negative.
    """
    cur = get_cursor(conn)

    # Get feature info; inactive features are treated as not found.
    cur.execute("""
        SELECT limit_type, reset_period, default_limit
        FROM features
        WHERE id = %s AND active = true
    """, (feature_id,))
    feature = cur.fetchone()
    if not feature:
        return {
            'allowed': False,
            'limit': None,
            'used': 0,
            'remaining': None,
            'reason': 'feature_not_found'
        }

    # Priority 1: user-specific restriction
    cur.execute("""
        SELECT limit_value
        FROM user_feature_restrictions
        WHERE profile_id = %s AND feature_id = %s
    """, (profile_id, feature_id))
    restriction = cur.fetchone()
    if restriction is not None:
        limit = restriction['limit_value']
    else:
        # Priority 2: tier limit (same connection, to avoid pool exhaustion)
        tier_id = get_effective_tier(profile_id, conn)
        cur.execute("""
            SELECT limit_value
            FROM tier_limits
            WHERE tier_id = %s AND feature_id = %s
        """, (tier_id, feature_id))
        tier_limit = cur.fetchone()
        if tier_limit is not None:
            limit = tier_limit['limit_value']
        else:
            # Priority 3: feature default
            limit = feature['default_limit']

    # Boolean features: limit 1 = enabled, anything else = disabled.
    if feature['limit_type'] == 'boolean':
        allowed = limit == 1
        return {
            'allowed': allowed,
            'limit': limit,
            'used': 0,
            'remaining': None,
            'reason': 'enabled' if allowed else 'feature_disabled'
        }

    # Count-based features: read current usage.
    cur.execute("""
        SELECT usage_count, reset_at
        FROM user_feature_usage
        WHERE profile_id = %s AND feature_id = %s
    """, (profile_id, feature_id))
    usage = cur.fetchone()
    used = usage['usage_count'] if usage else 0
    reset_at = usage['reset_at'] if usage else None

    # Reset the counter when its window has expired. "now" is taken with
    # the same tzinfo as reset_at (None -> naive local time, as before) so
    # the comparison cannot fail on naive-vs-aware datetimes.
    # NOTE(review): whether reset_at is stored naive or aware depends on
    # the column type, which is not visible here.
    if reset_at and datetime.now(tz=reset_at.tzinfo) > reset_at:
        used = 0
        next_reset = _calculate_next_reset(feature['reset_period'])
        cur.execute("""
            UPDATE user_feature_usage
            SET usage_count = 0, reset_at = %s, updated = CURRENT_TIMESTAMP
            WHERE profile_id = %s AND feature_id = %s
        """, (next_reset, profile_id, feature_id))
        conn.commit()

    # NULL limit = unlimited
    if limit is None:
        return {
            'allowed': True,
            'limit': None,
            'used': used,
            'remaining': None,
            'reason': 'unlimited'
        }

    # 0 limit = disabled
    if limit == 0:
        return {
            'allowed': False,
            'limit': 0,
            'used': used,
            'remaining': 0,
            'reason': 'feature_disabled'
        }

    # Finite, non-zero limit. Clamp at 0 so usage above the limit (e.g.
    # after a limit was lowered) is not reported as a negative remainder.
    allowed = used < limit
    remaining = max(limit - used, 0)
    return {
        'allowed': allowed,
        'limit': limit,
        'used': used,
        'remaining': remaining,
        'reason': 'within_limit' if allowed else 'limit_exceeded'
    }
def increment_feature_usage(profile_id: str, feature_id: str) -> None:
    """
    Count one more use of ``feature_id`` for ``profile_id``.

    If no usage row exists yet, one is inserted with a count of 1 and a
    reset_at derived from the feature's reset_period. If a row already
    exists, only its counter and ``updated`` timestamp change; the stored
    reset_at is left as it is. Nothing happens when the feature ID is
    unknown.
    """
    period_sql = """
        SELECT reset_period
        FROM features
        WHERE id = %s
    """
    upsert_sql = """
        INSERT INTO user_feature_usage (profile_id, feature_id, usage_count, reset_at)
        VALUES (%s, %s, 1, %s)
        ON CONFLICT (profile_id, feature_id)
        DO UPDATE SET
            usage_count = user_feature_usage.usage_count + 1,
            updated = CURRENT_TIMESTAMP
    """

    with get_db() as db:
        cursor = get_cursor(db)

        cursor.execute(period_sql, (feature_id,))
        feature_row = cursor.fetchone()
        if not feature_row:
            # Unknown feature: nothing to count.
            return

        # Only used when the row is inserted for the first time.
        initial_reset_at = _calculate_next_reset(feature_row['reset_period'])

        cursor.execute(upsert_sql, (profile_id, feature_id, initial_reset_at))
        db.commit()
def _calculate_next_reset(reset_period: str) -> Optional[datetime]:
    """
    Work out when a usage counter with the given period resets next.

    Args:
        reset_period: 'never', 'daily' or 'monthly'

    Returns:
        For 'daily', midnight at the start of tomorrow; for 'monthly',
        midnight on the first day of next month; None for 'never' and for
        any unrecognised value. Results are naive datetimes based on
        ``datetime.now()``.
    """
    if reset_period == 'daily':
        today = datetime.now().date()
        midnight = datetime.min.time()
        return datetime.combine(today + timedelta(days=1), midnight)

    if reset_period == 'monthly':
        current = datetime.now()
        # December rolls over into January of the following year.
        rolls_over = current.month == 12
        year = current.year + 1 if rolls_over else current.year
        month = 1 if rolls_over else current.month + 1
        return datetime(year, month, 1)

    # 'never' and unknown periods do not reset.
    return None