- Add optional conn parameter to get_effective_tier() - Add optional conn parameter to check_feature_access() - Pass existing connection in features.py loop - Prevents opening 20+ connections simultaneously - Fixes "connection pool exhausted" error Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
375 lines
11 KiB
Python
375 lines
11 KiB
Python
"""
|
|
Authentication and Authorization for Mitai Jinkendo
|
|
|
|
Provides password hashing, session management, and auth dependencies
|
|
for FastAPI endpoints.
|
|
"""
|
|
import hashlib
|
|
import secrets
|
|
from typing import Optional
|
|
from datetime import datetime, timedelta
|
|
from fastapi import Header, Query, HTTPException
|
|
import bcrypt
|
|
|
|
from db import get_db, get_cursor
|
|
|
|
|
|
def hash_pin(pin: str) -> str:
    """
    Hash a password with bcrypt using a freshly generated salt.

    Always produces a bcrypt hash; legacy SHA256 hashes are only *read*
    (see verify_pin), never written.

    Args:
        pin: The plaintext password.

    Returns:
        The bcrypt hash as a str, suitable for storing and later passing
        to verify_pin().

    NOTE(review): bcrypt only considers the first 72 bytes of its input
    (recent bcrypt releases raise ValueError for longer input instead of
    truncating) -- confirm password length is bounded upstream.
    """
    return bcrypt.hashpw(pin.encode(), bcrypt.gensalt()).decode()
|
|
|
|
|
|
def verify_pin(pin: str, stored_hash: str) -> bool:
    """
    Verify a password against its stored hash.

    Two stored formats are accepted:
    - bcrypt hashes, recognised by the ``$2`` prefix (``$2a$``, ``$2b$``, ...)
    - legacy unsalted SHA256 hex digests

    Args:
        pin: The plaintext password entered by the user.
        stored_hash: The hash stored for the profile. Empty/None never matches.

    Returns:
        True if ``pin`` matches ``stored_hash``, otherwise False. A malformed
        bcrypt hash yields False instead of raising.
    """
    if not stored_hash:
        return False

    if stored_hash.startswith('$2'):
        try:
            return bcrypt.checkpw(pin.encode(), stored_hash.encode())
        except Exception:
            # Corrupt/malformed stored hash: report a failed login instead of
            # surfacing an internal error to the client.
            return False

    # Legacy SHA256 digest. Compare in constant time so response timing does
    # not reveal how much of the digest matched; bytes are compared because
    # compare_digest rejects non-ASCII str arguments.
    # NOTE(review): no upgrade to bcrypt happens here -- presumably the login
    # endpoint re-hashes with hash_pin() after a successful legacy match.
    legacy_digest = hashlib.sha256(pin.encode()).hexdigest()
    return secrets.compare_digest(stored_hash.encode(), legacy_digest.encode())
|
|
|
|
|
|
def make_token() -> str:
    """Return a new URL-safe random session token (32 bytes of entropy)."""
    entropy_bytes = 32
    return secrets.token_urlsafe(entropy_bytes)
|
|
|
|
|
|
def get_session(token: str):
    """
    Look up the live session belonging to ``token``.

    The session row is joined with its profile so the result also carries
    role, name, ai_enabled, ai_limit_day and export_enabled.

    Returns:
        The fetched row, or None when the token is empty, unknown or expired.
    """
    if not token:
        return None

    query = (
        "SELECT s.*, p.role, p.name, p.ai_enabled, p.ai_limit_day, p.export_enabled "
        "FROM sessions s JOIN profiles p ON s.profile_id=p.id "
        "WHERE s.token=%s AND s.expires_at > CURRENT_TIMESTAMP"
    )
    with get_db() as conn:
        cursor = get_cursor(conn)
        cursor.execute(query, (token,))
        return cursor.fetchone()
|
|
|
|
|
|
def require_auth(x_auth_token: Optional[str] = Header(default=None)):
    """
    FastAPI dependency that insists on a valid session token header.

    Usage:
        @app.get("/api/endpoint")
        def endpoint(session: dict = Depends(require_auth)):
            profile_id = session['profile_id']
            ...

    Returns:
        The session row from get_session().

    Raises:
        HTTPException 401 if the token is missing, unknown or expired.
    """
    session = get_session(x_auth_token)
    if session:
        return session
    raise HTTPException(401, "Nicht eingeloggt")
|
|
|
|
|
|
def require_auth_flexible(x_auth_token: Optional[str] = Header(default=None), token: Optional[str] = Query(default=None)):
    """
    FastAPI dependency accepting the session token from a header or a query parameter.

    Intended for endpoints fetched by <img> tags, which cannot send custom
    headers. When both are present, the header value wins.

    Usage:
        @app.get("/api/photos/{id}")
        def get_photo(id: str, session: dict = Depends(require_auth_flexible)):
            ...

    NOTE(review): tokens passed in the query string can end up in access
    logs and browser history -- keep this dependency limited to such endpoints.

    Raises:
        HTTPException 401 if no valid session is found.
    """
    effective_token = x_auth_token or token
    session = get_session(effective_token)
    if session:
        return session
    raise HTTPException(401, "Nicht eingeloggt")
|
|
|
|
|
|
def require_admin(x_auth_token: Optional[str] = Header(default=None)):
    """
    FastAPI dependency that requires an authenticated admin session.

    Usage:
        @app.put("/api/admin/endpoint")
        def admin_endpoint(session: dict = Depends(require_admin)):
            ...

    Returns:
        The session row of the admin user.

    Raises:
        HTTPException 401 if not authenticated (raised by require_auth).
        HTTPException 403 if the profile's role is not 'admin'.
    """
    # Reuse the plain auth check so the 401 path is defined in one place.
    session = require_auth(x_auth_token)
    if session['role'] == 'admin':
        return session
    raise HTTPException(403, "Nur für Admins")
|
|
|
|
|
|
# ============================================================================
|
|
# Feature Access Control (v9c)
|
|
# ============================================================================
|
|
|
|
def get_effective_tier(profile_id: str, conn=None) -> str:
    """
    Get the effective tier for a profile.

    An active access grant (valid_from <= now < valid_until, is_active) takes
    precedence; if several match, the one with the latest valid_until wins.
    Otherwise the tier stored on the profile is used, and 'free' if the
    profile does not exist.

    Args:
        profile_id: User profile ID
        conn: Optional existing DB connection (to avoid pool exhaustion).
            When omitted, a connection is taken from get_db() for this call.

    Returns:
        The tier id as stored in access_grants.tier_id / profiles.tier
        (presumably 'free', 'basic', 'premium' or 'selfhosted').
        NOTE(review): a NULL profiles.tier is returned as None, not 'free'.
    """
    if conn is None:
        # No connection supplied: borrow one just for this lookup.
        with get_db() as own_conn:
            return get_effective_tier(profile_id, own_conn)

    cur = get_cursor(conn)

    # Active access grants (coupons, trials, ...) have the highest priority.
    cur.execute("""
        SELECT tier_id
        FROM access_grants
        WHERE profile_id = %s
          AND is_active = true
          AND valid_from <= CURRENT_TIMESTAMP
          AND valid_until > CURRENT_TIMESTAMP
        ORDER BY valid_until DESC
        LIMIT 1
    """, (profile_id,))
    grant = cur.fetchone()
    if grant:
        return grant['tier_id']

    # Fall back to the tier stored on the profile itself.
    cur.execute("SELECT tier FROM profiles WHERE id = %s", (profile_id,))
    profile = cur.fetchone()
    return profile['tier'] if profile else 'free'
|
|
|
|
|
|
def check_feature_access(profile_id: str, feature_id: str, conn=None) -> dict:
    """
    Check if a profile has access to a feature.

    The applicable limit is resolved in this order:
        1. User-specific restriction (user_feature_restrictions)
        2. Tier limit (tier_limits) for the profile's effective tier
        3. Feature default (features.default_limit)

    Args:
        profile_id: User profile ID
        feature_id: Feature ID to check
        conn: Optional existing DB connection (to avoid pool exhaustion).
            When omitted, a connection is taken from get_db() for this call.

    Returns:
        dict: {
            'allowed': bool,
            'limit': int | None,      # None = unlimited
            'used': int,
            'remaining': int | None,  # None = unlimited / not applicable
            'reason': str  # 'feature_not_found', 'enabled', 'feature_disabled',
                           # 'unlimited', 'within_limit' or 'limit_exceeded'
        }
    """
    if conn is None:
        with get_db() as own_conn:
            return _check_impl(profile_id, feature_id, own_conn)
    return _check_impl(profile_id, feature_id, conn)
|
|
|
|
|
|
def _check_impl(profile_id: str, feature_id: str, conn) -> dict:
    """
    Internal implementation of check_feature_access.

    All queries run on ``conn``; no new connection is opened. If the stored
    usage period of a count-based feature has expired, the usage row is
    reset and ``conn.commit()`` is called on the supplied connection (which
    may be the caller's own connection).

    Returns:
        dict with keys 'allowed', 'limit', 'used', 'remaining', 'reason'.
        'reason' is one of 'feature_not_found', 'enabled', 'feature_disabled',
        'unlimited', 'within_limit', 'limit_exceeded'.
    """
    cur = get_cursor(conn)

    # Unknown or inactive features are never accessible.
    cur.execute("""
        SELECT limit_type, reset_period, default_limit
        FROM features
        WHERE id = %s AND active = true
    """, (feature_id,))
    feature = cur.fetchone()

    if not feature:
        return {
            'allowed': False,
            'limit': None,
            'used': 0,
            'remaining': None,
            'reason': 'feature_not_found'
        }

    limit = _resolve_limit(cur, conn, profile_id, feature_id, feature['default_limit'])

    # Boolean features: limit 1 = enabled, 0 = disabled.
    # NOTE(review): a NULL limit on a boolean feature is treated as disabled
    # here, whereas NULL means unlimited for count features -- confirm intended.
    if feature['limit_type'] == 'boolean':
        allowed = limit == 1
        return {
            'allowed': allowed,
            'limit': limit,
            'used': 0,
            'remaining': None,
            'reason': 'enabled' if allowed else 'feature_disabled'
        }

    # Count-based features.
    used = _current_usage(cur, conn, profile_id, feature_id, feature['reset_period'])

    # NULL limit = unlimited
    if limit is None:
        return {
            'allowed': True,
            'limit': None,
            'used': used,
            'remaining': None,
            'reason': 'unlimited'
        }

    # 0 limit = disabled
    if limit == 0:
        return {
            'allowed': False,
            'limit': 0,
            'used': used,
            'remaining': 0,
            'reason': 'feature_disabled'
        }

    allowed = used < limit
    # Usage can exceed the limit (e.g. after a limit was lowered); clamp so
    # callers never see a negative remaining count.
    remaining = max(limit - used, 0)

    return {
        'allowed': allowed,
        'limit': limit,
        'used': used,
        'remaining': remaining,
        'reason': 'within_limit' if allowed else 'limit_exceeded'
    }


def _resolve_limit(cur, conn, profile_id: str, feature_id: str, default_limit):
    """Return the applicable limit: user restriction > tier limit > feature default."""
    # Priority 1: user-specific restriction (a row with NULL limit_value
    # still wins and means unlimited).
    cur.execute("""
        SELECT limit_value
        FROM user_feature_restrictions
        WHERE profile_id = %s AND feature_id = %s
    """, (profile_id, feature_id))
    restriction = cur.fetchone()
    if restriction is not None:
        return restriction['limit_value']

    # Priority 2: limit configured for the profile's effective tier.
    tier_id = get_effective_tier(profile_id, conn)
    cur.execute("""
        SELECT limit_value
        FROM tier_limits
        WHERE tier_id = %s AND feature_id = %s
    """, (tier_id, feature_id))
    tier_limit = cur.fetchone()
    if tier_limit is not None:
        return tier_limit['limit_value']

    # Priority 3: feature default.
    return default_limit


def _current_usage(cur, conn, profile_id: str, feature_id: str, reset_period) -> int:
    """Return the current usage count, resetting (and committing) an expired period."""
    cur.execute("""
        SELECT usage_count, reset_at
        FROM user_feature_usage
        WHERE profile_id = %s AND feature_id = %s
    """, (profile_id, feature_id))
    usage = cur.fetchone()
    if not usage:
        return 0

    reset_at = usage['reset_at']
    # Take "now" with the same tz-awareness as the stored value: naive stays
    # naive (datetime.now(None)), while an aware reset_at no longer raises
    # TypeError on comparison.
    if reset_at and datetime.now(reset_at.tzinfo) > reset_at:
        next_reset = _calculate_next_reset(reset_period)
        cur.execute("""
            UPDATE user_feature_usage
            SET usage_count = 0, reset_at = %s, updated = CURRENT_TIMESTAMP
            WHERE profile_id = %s AND feature_id = %s
        """, (next_reset, profile_id, feature_id))
        conn.commit()
        return 0

    return usage['usage_count']
|
|
|
|
|
|
def increment_feature_usage(profile_id: str, feature_id: str, conn=None) -> None:
    """
    Increment the usage counter of a feature for a profile.

    Inserts a user_feature_usage row (usage_count 1, reset_at derived from
    the feature's reset_period) or increments usage_count of the existing
    row. Unknown feature ids are ignored. The change is committed.

    Args:
        profile_id: User profile ID
        feature_id: Feature ID whose counter is incremented
        conn: Optional existing DB connection (to avoid pool exhaustion).
            When given it is used and committed; otherwise a connection is
            taken from get_db() for this call only.
    """
    if conn is None:
        with get_db() as own_conn:
            increment_feature_usage(profile_id, feature_id, own_conn)
        return

    cur = get_cursor(conn)

    cur.execute("""
        SELECT reset_period
        FROM features
        WHERE id = %s
    """, (feature_id,))
    feature = cur.fetchone()

    if not feature:
        return

    next_reset = _calculate_next_reset(feature['reset_period'])

    # Upsert. On conflict only the count is bumped and the existing reset_at
    # is kept; expired periods are reset by the access check.
    # NOTE(review): if this runs after reset_at has passed without a
    # preceding access check, the old count carries into the new period.
    cur.execute("""
        INSERT INTO user_feature_usage (profile_id, feature_id, usage_count, reset_at)
        VALUES (%s, %s, 1, %s)
        ON CONFLICT (profile_id, feature_id)
        DO UPDATE SET
            usage_count = user_feature_usage.usage_count + 1,
            updated = CURRENT_TIMESTAMP
    """, (profile_id, feature_id, next_reset))

    conn.commit()
|
|
|
|
|
|
def _calculate_next_reset(reset_period: str) -> Optional[datetime]:
|
|
"""
|
|
Calculate next reset timestamp based on reset period.
|
|
|
|
Args:
|
|
reset_period: 'never', 'daily', 'monthly'
|
|
|
|
Returns:
|
|
datetime or None (for 'never')
|
|
"""
|
|
if reset_period == 'never':
|
|
return None
|
|
elif reset_period == 'daily':
|
|
# Reset at midnight
|
|
tomorrow = datetime.now().date() + timedelta(days=1)
|
|
return datetime.combine(tomorrow, datetime.min.time())
|
|
elif reset_period == 'monthly':
|
|
# Reset at start of next month
|
|
now = datetime.now()
|
|
if now.month == 12:
|
|
return datetime(now.year + 1, 1, 1)
|
|
else:
|
|
return datetime(now.year, now.month + 1, 1)
|
|
else:
|
|
return None
|