diff --git a/backend/auth.py b/backend/auth.py index 21b0042..936628b 100644 --- a/backend/auth.py +++ b/backend/auth.py @@ -121,17 +121,22 @@ def require_admin(x_auth_token: Optional[str] = Header(default=None)): # Feature Access Control (v9c) # ============================================================================ -def get_effective_tier(profile_id: str) -> str: +def get_effective_tier(profile_id: str, conn=None) -> str: """ Get the effective tier for a profile. Checks for active access_grants first (from coupons, trials, etc.), then falls back to profile.tier. + Args: + profile_id: User profile ID + conn: Optional existing DB connection (to avoid pool exhaustion) + Returns: tier_id (str): 'free', 'basic', 'premium', or 'selfhosted' """ - with get_db() as conn: + # Use existing connection if provided, otherwise open new one + if conn: cur = get_cursor(conn) # Check for active access grants (highest priority) @@ -154,9 +159,13 @@ def get_effective_tier(profile_id: str) -> str: cur.execute("SELECT tier FROM profiles WHERE id = %s", (profile_id,)) profile = cur.fetchone() return profile['tier'] if profile else 'free' + else: + # Open new connection if none provided + with get_db() as conn: + return get_effective_tier(profile_id, conn) -def check_feature_access(profile_id: str, feature_id: str) -> dict: +def check_feature_access(profile_id: str, feature_id: str, conn=None) -> dict: """ Check if a profile has access to a feature. @@ -165,6 +174,11 @@ def check_feature_access(profile_id: str, feature_id: str) -> dict: 2. Tier limit (tier_limits) 3. Feature default (features.default_limit) + Args: + profile_id: User profile ID + feature_id: Feature ID to check + conn: Optional existing DB connection (to avoid pool exhaustion) + Returns: dict: { 'allowed': bool, @@ -174,8 +188,17 @@ def check_feature_access(profile_id: str, feature_id: str) -> dict: 'reason': str # 'unlimited', 'within_limit', 'limit_exceeded', 'feature_disabled' } """ - with get_db() as conn: - cur = get_cursor(conn) + # Use existing connection if provided + if conn: + return _check_impl(profile_id, feature_id, conn) + else: + with get_db() as conn: + return _check_impl(profile_id, feature_id, conn) + + +def _check_impl(profile_id: str, feature_id: str, conn) -> dict: + """Internal implementation of check_feature_access.""" + cur = get_cursor(conn) # Get feature info cur.execute(""" @@ -206,7 +229,7 @@ def check_feature_access(profile_id: str, feature_id: str) -> dict: limit = restriction['limit_value'] else: # Priority 2: Check tier limit - tier_id = get_effective_tier(profile_id) + tier_id = get_effective_tier(profile_id, conn) cur.execute(""" SELECT limit_value FROM tier_limits diff --git a/backend/routers/features.py b/backend/routers/features.py index 69766aa..228315a 100644 --- a/backend/routers/features.py +++ b/backend/routers/features.py @@ -187,7 +187,8 @@ def get_feature_usage(x_profile_id: Optional[str]=Header(default=None), session: for feature in features: # Use existing check_feature_access to get usage and limits # This respects user overrides, tier limits, and feature defaults - access = check_feature_access(pid, feature['id']) + # Pass connection to avoid pool exhaustion + access = check_feature_access(pid, feature['id'], conn) # Get reset date from user_feature_usage cur.execute("""