Source code for utils.token_jwt_handling

import base64
import hmac
import logging
import os
from datetime import datetime, timedelta, timezone
from functools import wraps
from typing import Any, Dict, Optional, Tuple

import jwt
from firebase_admin import firestore
from flask import request


class AuthError(Exception):
    """
    Exception raised for authentication and token-related failures.

    Carries an HTTP status code alongside the message so a request handler
    can translate it into an HTTP response.

    Attributes:
        error (str): Human-readable error message.
        status_code (int): HTTP status code associated with the error.

    Usage:
        raise AuthError("Token expired", 401)
    """

    def __init__(self, error: str, status_code: int):
        # Forward the message to Exception so str(exc), exc.args, tracebacks
        # and log lines that format the exception show it instead of "".
        super().__init__(error)
        self.error = error
        self.status_code = status_code
""" Enhanced JWT Token Management System Version: 2.0 Author: Gabriel Cellammare Last Modified: 10/01/2025 This module implements a comprehensive JWT (JSON Web Token) handling system with a strong focus on security, encryption, and secure token lifecycle management. It provides robust token generation, validation, and storage mechanisms with defense-in-depth measures. Core Security Features: 1. Token Management - Secure token generation with user binding - Protected token storage with encryption - Automatic token expiration - Token rotation policies - Rate limiting implementation 2. Cryptographic Security - AES encryption for sensitive data - Salt-based key derivation - Secure token signing - Protected storage operations - Memory-safe cleanup 3. Access Control - User session binding - Request validation - IP tracking - Device fingerprinting - Audit logging 4. State Management - Secure token persistence - Protected state transitions - Safe cleanup operations - Error isolation - Recovery procedures Security Considerations: - All tokens are encrypted before storage - User sessions are strictly validated - Token operations are isolated - Rate limiting prevents abuse - Audit trails are maintained - Cleanup is automatic - Error states are safe Dependencies: - jwt: For token operations - firebase_admin: Database operations - cryptography: Encryption operations - datetime: Timestamp management - base64: Secure encoding """
class TokenJWTHandling:
    """
    Issue, persist and verify HS256 JWT access tokens for users.

    Tokens are signed with the ``JWT_SECRET_KEY`` environment variable, stored
    in the ``user_tokens`` Firestore collection with ``access_token``,
    ``refresh_token`` and ``expires_at`` encrypted per user (cipher + salt from
    ``user_security``), and re-checked against that stored copy by
    ``jwt_required`` on every protected request.
    """

    def __init__(self, db, cipher):
        """
        Initialize token handling with a database handle and a cipher.

        Args:
            db: Firestore client used for every read and write.
            cipher: Object exposing ``encrypt(value, user_id, salt)`` (returning
                bytes) and ``decrypt(value, user_id, salt)``; used for
                field-level encryption of stored tokens.

        Environment Variables:
            JWT_SECRET_KEY: Key used to sign and verify tokens. Read once here;
                it is ``None`` when the variable is unset.
        """
        self.__db = db
        self.__cipher = cipher
        # Rate-limit and lifetime configuration.
        self.__MAX_DAILY_TOKENS = 2
        self.__JWT_SECRET_KEY = os.environ.get('JWT_SECRET_KEY')
        self.__JWT_TOKEN_EXPIRATION = timedelta(days=7)
        self.__JWT_REQUEST_COOLDOWN = timedelta(hours=12)

    def _load_user_salt(self, user_id: str) -> Optional[bytes]:
        """
        Return the user's base64-decoded salt from ``user_security``.

        Returns:
            Optional[bytes]: The decoded salt, or None when the user's
            ``user_security`` document does not exist.
        """
        security_data = (self.__db.collection('user_security')
                         .document(user_id)
                         .get())
        if not security_data.exists:
            return None
        return base64.b64decode(security_data.to_dict()['salt'])

    def get_user_token_history(self, user_id: str, since: datetime) -> list:
        """
        Return the user's active token records created at or after ``since``.

        Args:
            user_id (str): User whose tokens are queried.
            since (datetime): Inclusive lower bound on ``created_at``.

        Returns:
            list: Raw (still encrypted) token documents as dicts, newest
            ``created_at`` first. Only documents with ``status == 'active'``
            are included.
        """
        token_docs = (self.__db.collection('user_tokens')
                      .where('user_id', '==', user_id)
                      .where('created_at', '>=', since)
                      # Only consider active tokens
                      .where('status', '==', 'active')
                      .order_by('created_at', direction='DESCENDING')
                      .stream())
        return [doc.to_dict() for doc in token_docs]

    def check_token_request_eligibility(
            self, user_id: str
    ) -> Tuple[bool, Optional[datetime], Optional[str]]:
        """
        Decide whether the user may request a new token right now.

        Two rules are applied to the user's active tokens: at most
        ``MAX_DAILY_TOKENS`` created since the start of the current UTC day,
        and a cooldown after the most recent token.

        Args:
            user_id (str): User requesting a token.

        Returns:
            Tuple[bool, Optional[datetime], Optional[str]]:
                ``(True, None, None)`` when allowed; otherwise ``False``, the
                earliest time a new request may succeed, and a message.
        """
        current_time = datetime.now(timezone.utc)
        day_start = current_time.replace(
            hour=0, minute=0, second=0, microsecond=0)

        # Look back far enough to cover both today's quota window and the
        # cooldown window. Otherwise a token issued shortly before UTC
        # midnight would be invisible after midnight and its cooldown skipped.
        since = min(day_start, current_time - self.__JWT_REQUEST_COOLDOWN)
        token_history = self.get_user_token_history(user_id, since)
        if not token_history:
            return True, None, None

        # The daily limit counts only tokens created since today's UTC start.
        tokens_today = [token for token in token_history
                        if token['created_at'] >= day_start]
        if len(tokens_today) >= self.__MAX_DAILY_TOKENS:
            next_day = day_start + timedelta(days=1)
            return False, next_day, f'Daily limit reached. Next eligible: {next_day.isoformat()}'

        # History is ordered newest first, so index 0 is the latest token.
        last_request_time = token_history[0]['created_at']
        next_eligible_time = last_request_time + self.__JWT_REQUEST_COOLDOWN
        if current_time < next_eligible_time:
            return False, next_eligible_time, f'Please wait until {next_eligible_time.isoformat()}'
        return True, None, None

    def _encrypt_token_data(self, token_data: dict, user_id: str, salt: bytes) -> dict:
        """
        Return a copy of ``token_data`` with its sensitive fields encrypted.

        ``access_token``, ``refresh_token`` and ``expires_at`` (when present)
        are passed to ``cipher.encrypt`` and stored as UTF-8 decoded text;
        all other fields are copied unchanged.

        Raises:
            AuthError: 500 if any encryption call fails.
        """
        encrypted_data = token_data.copy()
        sensitive_fields = ['access_token', 'refresh_token', 'expires_at']
        try:
            for field in sensitive_fields:
                if field in token_data:
                    encrypted_value = self.__cipher.encrypt(
                        token_data[field], user_id, salt)
                    encrypted_data[field] = encrypted_value.decode('utf-8')
            return encrypted_data
        except Exception as e:
            logging.error(f"Token encryption error: {str(e)}")
            raise AuthError("Token encryption failed", 500) from e

    def _decrypt_token_data(self, encrypted_data: dict, user_id: str, salt: bytes) -> dict:
        """
        Return a copy of a stored token document with sensitive fields decrypted.

        Decrypted values returned as bytes are decoded to UTF-8 text.

        Raises:
            AuthError: 500 if any decryption call fails.
        """
        decrypted_data = encrypted_data.copy()
        sensitive_fields = ['access_token', 'refresh_token', 'expires_at']
        try:
            for field in sensitive_fields:
                if field in encrypted_data:
                    decrypted_value = self.__cipher.decrypt(
                        encrypted_data[field], user_id, salt)
                    if isinstance(decrypted_value, bytes):
                        decrypted_value = decrypted_value.decode('utf-8')
                    decrypted_data[field] = decrypted_value
            return decrypted_data
        except Exception as e:
            logging.error(f"Token decryption error: {str(e)}")
            raise AuthError("Token decryption failed", 500) from e

    def _store_encrypted_token(self, user_id: str, token_data: dict) -> None:
        """
        Encrypt ``token_data`` and add it to ``user_tokens`` as an active token.

        ``created_at`` is overwritten with the Firestore server timestamp and
        ``status`` with ``'active'``.

        Raises:
            AuthError: 500 if the salt is missing, encryption fails, or the
                write fails.
        """
        try:
            salt_bytes = self._load_user_salt(user_id)
            if salt_bytes is None:
                raise LookupError("user_security document not found")

            encrypted_data = self._encrypt_token_data(
                token_data, user_id, salt_bytes)

            self.__db.collection('user_tokens').add({
                'user_id': user_id,
                **encrypted_data,
                'created_at': firestore.SERVER_TIMESTAMP,
                'status': 'active'
            })
        except Exception as e:
            logging.error(f"Token storage error: {str(e)}")
            raise AuthError("Failed to store token", 500) from e

    def expire_previous_tokens(self, user_id: str) -> None:
        """
        Mark the user's active tokens that are already past expiry as expired.

        Tokens that are still within their lifetime are left active (the
        daily-limit check relies on them staying active). Expired tokens are
        updated in one batch and a single ``audit_logs`` entry is written.
        Nothing happens when the user has no ``user_security`` document or no
        token has expired. Must run inside a Flask request context (the
        audit entry records the request's IP and user agent).

        Args:
            user_id (str): User whose tokens are checked.

        Raises:
            AuthError: 500 if any lookup, decryption or write fails.
        """
        try:
            current_time = datetime.now(timezone.utc)
            salt_bytes = self._load_user_salt(user_id)
            if salt_bytes is None:
                return None

            active_tokens = (self.__db.collection('user_tokens')
                             .where('user_id', '==', user_id)
                             .where('status', '==', 'active')
                             .stream())

            # expires_at is stored encrypted, so the expiry test has to run
            # client-side after decryption rather than in the query.
            expired_ids = []
            for token_doc in active_tokens:
                decrypted_token = self._decrypt_token_data(
                    token_doc.to_dict(), user_id, salt_bytes)
                expires_at = datetime.fromisoformat(decrypted_token['expires_at'])
                if current_time > expires_at:
                    expired_ids.append(token_doc.id)

            if not expired_ids:
                return None

            batch = self.__db.batch()
            for token_id in expired_ids:
                doc_ref = self.__db.collection('user_tokens').document(token_id)
                batch.update(doc_ref, {
                    'status': 'expired',
                    'expired_at': current_time,
                    'expired_reason': 'new_token_generated'
                })
            batch.commit()

            self.__db.collection('audit_logs').add({
                'user_id': user_id,
                'action': 'expire_tokens',
                'tokens_expired': len(expired_ids),
                'reason': 'new_token_generated',
                'timestamp': current_time,
                'ip_address': request.remote_addr,
                'user_agent': request.user_agent.string
            })
        except Exception as e:
            logging.error(f"Error expiring tokens: {str(e)}")
            raise AuthError("Failed to expire tokens", 500) from e

    def get_active_token(self, user_id: str) -> Optional[Dict[str, Any]]:
        """
        Return the user's most recently created active token, decrypted.

        Args:
            user_id (str): User whose token is looked up.

        Returns:
            Optional[Dict[str, Any]]: The decrypted token document, or None
            when there is no salt, no active token, the newest token is past
            its expiry, or any error occurs (errors are logged, not raised).
        """
        try:
            current_time = datetime.now(timezone.utc)
            salt_bytes = self._load_user_salt(user_id)
            if salt_bytes is None:
                return None

            # Order by the plaintext server timestamp: expires_at is stored
            # encrypted, so ordering on it would pick an arbitrary token.
            token_list = list(self.__db.collection('user_tokens')
                              .where('user_id', '==', user_id)
                              .where('status', '==', 'active')
                              .order_by('created_at', direction='DESCENDING')
                              .limit(1)
                              .stream())
            if not token_list:
                return None

            decrypted_token = self._decrypt_token_data(
                token_list[0].to_dict(), user_id, salt_bytes)

            if current_time > datetime.fromisoformat(decrypted_token['expires_at']):
                return None
            return decrypted_token
        except Exception as e:
            logging.error(f"Error retrieving active token: {str(e)}")
            return None

    def generate_tokens(self, user_id: str) -> Dict[str, Any]:
        """
        Issue a new signed access token for a user and persist it encrypted.

        Must run inside a Flask request context (the stored record and audit
        log capture the request's IP and user agent).

        Args:
            user_id (str): User the token is issued to; embedded in the JWT
                claims and used for storage and encryption.

        Returns:
            Dict[str, Any]: ``access_token``, ``token_created_at``,
            ``expires_in`` (seconds), ``expires_at`` and
            ``next_token_request`` (timestamps as ISO-8601 strings).

        Raises:
            AuthError: 429 when the daily limit or cooldown denies the
                request; 500 for any other failure.
        """
        try:
            is_eligible, _, error_message = self.check_token_request_eligibility(
                user_id)
            if not is_eligible:
                raise AuthError(error_message, 429)

            # Housekeeping: flag active tokens already past their expiry.
            self.expire_previous_tokens(user_id)

            current_time = datetime.now(timezone.utc)
            token_exp = current_time + self.__JWT_TOKEN_EXPIRATION

            access_token = jwt.encode({
                'exp': token_exp,
                'iat': current_time,
                'created_at': current_time.isoformat(),
                'user_id': user_id,
                'type': 'access'
            }, self.__JWT_SECRET_KEY, algorithm='HS256')

            token_data = {
                'user_id': user_id,
                'access_token': access_token,
                'created_at': current_time,
                'expires_at': token_exp.isoformat(),
                'status': 'active',
                'device_info': {
                    'ip_address': request.remote_addr,
                    'user_agent': request.user_agent.string
                }
            }
            self._store_encrypted_token(user_id, token_data)

            next_token_time = current_time + self.__JWT_REQUEST_COOLDOWN
            return {
                'access_token': access_token,
                'token_created_at': current_time.isoformat(),
                'expires_in': int(self.__JWT_TOKEN_EXPIRATION.total_seconds()),
                'expires_at': token_exp.isoformat(),
                'next_token_request': next_token_time.isoformat()
            }
        except Exception as e:
            # Let the rate-limit rejection reach the caller with its 429
            # status instead of being masked as a generic 500 below.
            if isinstance(e, AuthError) and e.status_code == 429:
                raise
            logging.error(f"Token generation error: {getattr(e, 'error', e)}")
            raise AuthError("Failed to generate tokens", 500) from e

    def get_token_creation_time(self, token: str) -> Optional[datetime]:
        """
        Return the ``created_at`` claim of a validly signed token.

        Args:
            token (str): Encoded JWT.

        Returns:
            Optional[datetime]: The parsed claim, or None if the token fails
            signature/expiry validation, lacks the claim, or the claim is not
            an ISO-8601 string.
        """
        try:
            payload = jwt.decode(
                token, self.__JWT_SECRET_KEY, algorithms=['HS256'])
            created_at = payload.get('created_at')
            return datetime.fromisoformat(created_at) if created_at else None
        except (jwt.InvalidTokenError, ValueError, TypeError):
            return None

    def _verify_access_token(self, token: str) -> str:
        """
        Validate a bearer token end to end and return its ``user_id`` claim.

        Checks the JWT signature and expiry, the ``type`` claim, and that
        the token matches one of the user's active stored tokens whose stored
        expiry has not passed. The stored match is marked expired if its
        stored expiry has passed.

        Raises:
            jwt.InvalidTokenError: Signature/expiry validation failed.
            AuthError: 401 for expected rejections; 500 from decryption.
        """
        payload = jwt.decode(
            token, self.__JWT_SECRET_KEY, algorithms=['HS256'])
        if payload.get('type') != 'access':
            raise AuthError('Invalid token type', 401)
        user_id = payload['user_id']

        salt_bytes = self._load_user_salt(user_id)
        if salt_bytes is None:
            raise AuthError('Security data not found', 401)

        active_docs = (self.__db.collection('user_tokens')
                       .where('user_id', '==', user_id)
                       .where('status', '==', 'active')
                       .stream())

        # Several tokens can be active at once (only already-expired ones are
        # retired), so compare against every active record, not just one.
        provided = token.encode('utf-8')
        found_any = False
        for snapshot in active_docs:
            found_any = True
            stored_token = self._decrypt_token_data(
                snapshot.to_dict(), user_id, salt_bytes)
            stored_access = stored_token.get('access_token')
            if stored_access is None or not hmac.compare_digest(
                    str(stored_access).encode('utf-8'), provided):
                continue

            expires_at = stored_token.get('expires_at')
            expiry = datetime.fromisoformat(expires_at) if expires_at else None
            if expiry is None or datetime.now(timezone.utc) > expiry:
                snapshot.reference.update({'status': 'expired'})
                raise AuthError('Token expired', 401)
            return user_id

        if not found_any:
            raise AuthError('No active tokens found', 401)
        raise AuthError('Token not found in database', 401)

    def _record_verification_error(self, err: Exception) -> None:
        """Best-effort write of an unexpected verification failure to ``error_logs``."""
        message = err.error if isinstance(err, AuthError) else str(err)
        try:
            self.__db.collection('error_logs').add({
                'error_type': 'token_verification_error',
                'error_message': message,
                'timestamp': datetime.now(timezone.utc),
                'request_path': request.path,
                'request_method': request.method
            })
        except Exception:
            logging.exception("Failed to record token verification error")

    def jwt_required(self, f):
        """
        Decorate a Flask view so it runs only for a valid bearer access token.

        On success ``request.user_id`` is set to the token's ``user_id``
        claim and the view is called. Exceptions raised by the view itself
        propagate unchanged.

        Args:
            f (Callable): The view function to protect.

        Returns:
            Callable: The wrapped view.

        Raises:
            AuthError: 401 with one of: 'Missing or invalid token format',
                'Token expired', 'Invalid token', 'Invalid token type',
                'Security data not found', 'No active tokens found',
                'Token not found in database', or 'Token verification failed'
                (unexpected errors, which are also written to ``error_logs``).
        """
        @wraps(f)
        def decorated(*args, **kwargs):
            auth_header = request.headers.get('Authorization')
            if not auth_header or not auth_header.startswith('Bearer '):
                raise AuthError('Missing or invalid token format', 401)
            token = auth_header.split(' ')[1]

            try:
                user_id = self._verify_access_token(token)
            except jwt.ExpiredSignatureError:
                raise AuthError('Token expired', 401)
            except jwt.InvalidTokenError:
                raise AuthError('Invalid token', 401)
            except AuthError as err:
                # Deliberate 401 rejections keep their specific message.
                if err.status_code == 401:
                    raise
                self._record_verification_error(err)
                raise AuthError('Token verification failed', 401) from err
            except Exception as err:
                self._record_verification_error(err)
                raise AuthError('Token verification failed', 401) from err

            request.user_id = user_id
            # Called outside the try so errors raised by the view are not
            # rewritten into authentication failures.
            return f(*args, **kwargs)
        return decorated