#!/usr/bin/env python3
"""
Decode and analyze JWT tokens from captured traffic
"""

import json
import base64
import sys
import argparse
from datetime import datetime

def _decode_segment(segment):
    """Base64url-decode one JWT segment and parse the result as JSON.

    Raises ValueError (including binascii.Error, UnicodeDecodeError and
    json.JSONDecodeError, which all subclass it) if the segment is not valid
    base64url-encoded JSON.
    """
    # JWTs strip the '=' padding from base64url; restore it to a multiple of 4.
    padded = segment + '=' * (-len(segment) % 4)
    return json.loads(base64.urlsafe_b64decode(padded))


def decode_jwt(token):
    """Decode a JWT's header and payload without verifying its signature.

    Args:
        token: The compact-serialized JWT (``header.payload.signature``).
            Leading/trailing whitespace is ignored.

    Returns:
        A dict with ``"header"`` and ``"payload"`` (the decoded JSON objects)
        and ``"signature"`` (the raw signature segment, truncated to 20
        characters plus ``"..."`` when longer), or ``None`` if the token cannot
        be decoded. Failures are reported on stdout rather than raised.
    """
    if not isinstance(token, str):
        print(f"[!] Error decoding JWT: expected a string, got {type(token).__name__}")
        return None

    # JWT format: header.payload.signature
    parts = token.strip().split('.')
    if len(parts) != 3:
        print("[!] Invalid JWT format (expected 3 parts separated by dots)")
        return None

    header, payload, signature = parts

    try:
        header_json = _decode_segment(header)
        payload_json = _decode_segment(payload)
    except ValueError as e:
        print(f"[!] Error decoding JWT: {e}")
        return None

    # Header and claims set must be JSON objects; callers iterate them with
    # .items(), so reject arrays/strings/numbers here instead of crashing later.
    if not isinstance(header_json, dict) or not isinstance(payload_json, dict):
        print("[!] Error decoding JWT: header and payload must be JSON objects")
        return None

    return {
        "header": header_json,
        "payload": payload_json,
        "signature": signature[:20] + "..." if len(signature) > 20 else signature,
    }

def _claim_datetime(value):
    """Convert a timestamp claim (seconds since the epoch) to a local datetime.

    The value is passed through ``int()`` first, so numeric strings are
    accepted. Returns ``None`` when the value cannot be converted or lies
    outside the range ``datetime.fromtimestamp`` accepts on this platform.
    """
    try:
        return datetime.fromtimestamp(int(value))
    except (TypeError, ValueError, OverflowError, OSError):
        return None


def analyze_token(token):
    """Decode *token* and print a human-readable report about it.

    Prints the header, the payload claims (``exp``/``iat``/``nbf`` also shown
    as local dates; ``sub``/``user_id``/``uid`` flagged as user IDs), the
    expiry status, a heuristic token-type guess and suggested next steps.
    Returns ``None``. If the token cannot be decoded, only the banner and the
    decoder's error message are printed.
    """
    print("\n" + "="*80)
    print("  JWT TOKEN ANALYSIS")
    print("="*80)

    decoded = decode_jwt(token)
    if not decoded:
        return

    header = decoded['header']
    payload = decoded['payload']
    # The report iterates both with .items(); a JSON array/string payload
    # would otherwise raise AttributeError mid-report.
    if not isinstance(header, dict) or not isinstance(payload, dict):
        print("\n[!] Header and payload must be JSON objects; cannot analyze")
        return

    print("\n[📋] Header:")
    for key, value in header.items():
        print(f"  {key}: {value}")

    print("\n[📊] Payload:")
    for key, value in payload.items():
        if key in ('exp', 'iat', 'nbf'):
            # Show timestamp claims as readable dates when they convert cleanly
            when = _claim_datetime(value)
            if when is None:
                print(f"  {key}: {value}")
            else:
                print(f"  {key}: {value} ({when})")
        elif key in ('sub', 'user_id', 'uid'):
            print(f"  {key}: {value} [⭐ USER ID]")
        elif isinstance(value, dict):
            print(f"  {key}: {json.dumps(value, indent=4)}")
        else:
            # Truncate long values to keep the report readable
            val_str = str(value)
            if len(val_str) > 60:
                print(f"  {key}: {val_str[:60]}...")
            else:
                print(f"  {key}: {val_str}")

    # Check expiration using the same tolerant conversion as the claim loop,
    # so a non-numeric or out-of-range 'exp' no longer aborts the report.
    if 'exp' in payload:
        exp_time = _claim_datetime(payload['exp'])
        if exp_time is None:
            print(f"\n[⏰] Could not interpret 'exp' claim: {payload['exp']!r}")
        else:
            now = datetime.now()
            if exp_time > now:
                remaining = exp_time - now
                print(f"\n[⏰] Token expires in: {remaining}")
            else:
                print(f"\n[⏰] Token EXPIRED at: {exp_time}")

    # Identify token type (heuristic, based only on which claims are present)
    print("\n[🔍] Token Type:")
    if 'user_id' in payload or 'sub' in payload:
        print("  Type: User Session Token (can access user data)")
    elif 'app' in payload or 'client' in payload:
        print("  Type: App/Client Token (limited access)")
    elif 'firebase' in str(payload.get('iss', '')):
        # Inspect the decoded issuer; the raw token text is base64 and never
        # meaningfully contains the plain word "firebase".
        print("  Type: Firebase Token")
    else:
        print("  Type: Unknown")

    # Suggest next actions
    print("\n[📍] Next Steps:")
    if 'user_id' in payload:
        user_id = payload['user_id']
        print(f"  1. Use this token to access user {user_id}'s data")
        print(f"  2. Try API: GET /api/v1/users/{user_id}")
    print("  3. Use this token in Authorization header for API calls")
    print("  4. Can enumerate other user IDs if ID is sequential")

def extract_tokens_from_file(filename):
    """Extract candidate JWT tokens from a text file (e.g. a mitmproxy dump).

    Two passes are made over the file contents:

    1. ``Bearer <token>`` values whose token contains exactly two dots.
    2. Free-standing strings shaped like ``header.payload.signature`` whose
       header segment starts with ``eyJ`` (the base64url encoding of ``{"``).
       This keeps host names and version strings such as ``www.example.com``
       or ``1.2.3`` from being reported as tokens.

    Args:
        filename: Path of the file to scan. It is read as UTF-8 and
            undecodable bytes are ignored.

    Returns:
        The unique tokens in first-seen order (Bearer matches first), or an
        empty list if the file cannot be read.
    """
    import re

    print("[*] Searching for JWT tokens in file...")

    try:
        with open(filename, 'r', encoding='utf-8', errors='ignore') as f:
            content = f.read()
    except (OSError, ValueError) as e:
        print(f"[!] Error reading file: {e}")
        return []

    # Look for Bearer tokens; keep only JWT-shaped ones (3 dot-separated parts)
    bearer_pattern = r'Bearer\s+([A-Za-z0-9_\-\.]+)'
    candidates = [m for m in re.findall(bearer_pattern, content) if m.count('.') == 2]

    # Also look for standalone JWTs. The header must decode to a JSON object,
    # so its base64url form begins with "eyJ". The lookbehind prevents
    # matches that start mid-word. The signature may be empty (unsigned
    # "alg": "none" tokens end with a bare dot).
    jwt_pattern = r'(?<![A-Za-z0-9_\-])eyJ[A-Za-z0-9_\-]+\.[A-Za-z0-9_\-]+\.[A-Za-z0-9_\-]*'
    candidates += re.findall(jwt_pattern, content)

    # dict.fromkeys de-duplicates in O(n) while preserving first-seen order
    return list(dict.fromkeys(candidates))

def main():
    """Command-line entry point.

    With ``--file``, scans the given file for tokens. It then either lists
    them (``--list``) or analyzes the first one. Otherwise it analyzes the
    positional token, or prints usage examples when no arguments are given.
    """
    cli = argparse.ArgumentParser(description='Analyze JWT tokens from captured traffic')
    cli.add_argument('token', nargs='?', help='JWT token to analyze')
    cli.add_argument('--file', help='Extract tokens from file')
    cli.add_argument('--list', action='store_true', help='List all tokens found in file')
    opts = cli.parse_args()

    if not opts.file:
        if opts.token:
            # Analyze provided token
            analyze_token(opts.token)
            return
        usage_lines = (
            "Usage: python decode_jwt.py <token>",
            "       python decode_jwt.py --file <mitmproxy_dump> [--list]",
            "\nExample:",
            "  python decode_jwt.py eyJhbGc...",
            "  python decode_jwt.py --file mitmproxy_dump.txt --list",
        )
        for line in usage_lines:
            print(line)
        return

    found = extract_tokens_from_file(opts.file)
    if not found:
        print(f"[!] No JWT tokens found in {opts.file}")
        return

    print(f"\n[✓] Found {len(found)} potential JWT tokens")

    if opts.list:
        print("\nTokens found:")
        for idx, tok in enumerate(found, start=1):
            print(f"{idx}. {tok[:50]}...")
        return

    # Analyze only the first token; point the user at --list for the rest
    print("\nAnalyzing first token...")
    analyze_token(found[0])
    extra = len(found) - 1
    if extra > 0:
        print(f"\n[*] Found {extra} more tokens")
        print("[*] Use --list to see all tokens")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
