#!/usr/bin/env python3
"""
ADVANCED FLAG RECOVERY WITH BABY-STEP GIANT-STEP

This implements a more sophisticated attack including:
1. Baby-step Giant-step (BSGS) algorithm for moderate-sized keys
2. Pollard's Rho for larger keys
3. Checking for weak/predictable keys
"""

from Cryptodome.Util.number import long_to_bytes, bytes_to_long
from Cryptodome.Cipher import AES
from Cryptodome.Util.Padding import unpad
import hashlib
import json
import subprocess
import sys
import math
import ecdsa
from ecdsa import NIST256p
from ecdsa.ellipticcurve import Point


def baby_step_giant_step(G, target, order, max_steps=2**20):
    """
    Baby-step Giant-step algorithm to solve a small ECDLP instance.

    Finds k in [0, max_steps] such that target = k * G.

    Args:
        G: Base point. Must support ``G * int``, ``+`` with another point,
            and the ``x()`` / ``y()`` coordinate accessors.
        target: Point whose discrete logarithm with respect to G is wanted.
        order: Order of G; used to turn the giant-step stride ``-m`` into a
            non-negative scalar.
        max_steps: Upper bound on the k values that are searched.

    Returns:
        The scalar k if one is found (and verified), otherwise None.

    Raises:
        ValueError: If max_steps is smaller than 1.

    Complexity: O(sqrt(max_steps)) point additions and O(sqrt(max_steps))
    stored points.
    """
    if max_steps < 1:
        raise ValueError("max_steps must be a positive integer")

    print(f"[*] Running Baby-step Giant-step for up to 2^{int(math.log2(max_steps))} steps...")

    # Every k in [0, m*m) can be written as i*m + j with 0 <= i, j < m.
    m = int(math.sqrt(max_steps)) + 1

    # Baby steps: store j*G for j = 0..m-1, keyed by affine coordinates.
    print(f"[*] Baby steps: computing {m:,} points...")
    baby_steps = {}

    # Start at 0*G and advance by one addition per step instead of
    # recomputing a full scalar multiplication for every j.
    current = G * 0
    for j in range(m):
        if j % 10000 == 0 and j > 0:
            print(f"    Baby step {j:,}/{m:,}", end='\r')

        baby_steps[(current.x(), current.y())] = j
        current = G + current

    print(f"\n[+] Baby steps complete: {len(baby_steps):,} points stored")

    # Giant steps: walk gamma = target - i*m*G for i = 0..m-1.
    print(f"[*] Giant steps: checking {m:,} positions...")

    # -m*G, computed once; each giant step is then a single addition.
    stride = G * ((-m) % order)
    gamma = target

    for i in range(m):
        if i % 1000 == 0 and i > 0:
            print(f"    Giant step {i:,}/{m:,}", end='\r')

        j = baby_steps.get((gamma.x(), gamma.y()))
        if j is not None:
            k = i * m + j

            # Only report success once k really maps G onto the target.
            verify = G * k
            if verify.x() == target.x() and verify.y() == target.y():
                print(f"\n[+] Found! k = {k}")
                return k

        gamma = stride + gamma

    print(f"\n[!] Key not found in range [0, {max_steps:,}]")
    return None


class AdvancedFlagRecovery:
    """
    Drive the local challenge service and try to recover the flag.

    The service is started as a child process and spoken to with one JSON
    document per line. After the initial public key and the encrypted flag
    have been collected, the private key is searched for in stages (weak
    keys, small brute force, Baby-step Giant-step) and, if found, used to
    decrypt the flag.
    """

    def __init__(self):
        self.proc = None                 # subprocess.Popen handle of the service
        self.curve = NIST256p
        self.G = self.curve.generator
        self.order = self.curve.order
        self.initial_signkey = None      # raw bytes of the 'signkey' field
        self.initial_pubkey = None       # public key as a curve point
        self.encrypted_flag = None       # raw bytes of the encrypted flag

    @staticmethod
    def _same_point(a, b):
        """Return True when both coordinates of points a and b agree."""
        # Comparing x alone is not enough: k*G and (order - k)*G share x.
        return a.x() == b.x() and a.y() == b.y()

    def start_service(self):
        """Start the challenge service and discard its first output line."""
        self.proc = subprocess.Popen(
            [sys.executable, 'chall_ecdsa.py'],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            # stderr is never read by this script; discarding it avoids the
            # child blocking on a full, undrained pipe.
            stderr=subprocess.DEVNULL,
            text=True,
            bufsize=1
        )
        self.proc.stdout.readline()

    def send_command(self, cmd):
        """Send cmd to the service as a single line of JSON."""
        self.proc.stdin.write(json.dumps(cmd) + '\n')
        self.proc.stdin.flush()

    def read_json(self):
        """
        Read service output until it contains a parseable JSON document.

        Lines are accumulated; after each line the text between the first
        '{' and the last '}' seen so far is parsed.

        Returns:
            The decoded JSON value, or None if the output ends first.
        """
        buffer = ""
        while True:
            line = self.proc.stdout.readline()
            if not line:
                # EOF: the service closed its stdout.
                return None
            buffer += line

            start = buffer.find('{')
            end = buffer.rfind('}')
            if start == -1 or end < start:
                continue

            try:
                return json.loads(buffer[start:end + 1])
            except json.JSONDecodeError:
                # Probably an incomplete document; keep reading.
                continue

    def skip_menu(self):
        """Discard the next seven lines of output (presumably the menu)."""
        for _ in range(7):
            self.proc.stdout.readline()

    def collect_data(self):
        """
        Collect the initial key material and the encrypted flag.

        Returns:
            True when both the initial data and the encrypted flag were
            received, False otherwise.
        """
        print("[*] Collecting challenge data...")

        initial = self.read_json()
        if not initial:
            return False

        self.initial_signkey = bytes.fromhex(initial['signkey'])
        pubkey_x = int(initial['pubkey']['x'], 16)
        pubkey_y = int(initial['pubkey']['y'], 16)
        self.initial_pubkey = Point(self.curve.curve, pubkey_x, pubkey_y)

        print(f"[+] Initial public key: ({hex(pubkey_x)[:20]}..., {hex(pubkey_y)[:20]}...)")

        self.skip_menu()

        # Get encrypted flag
        self.send_command({"option": "get_flag"})
        flag_data = self.read_json()
        if not flag_data or 'flag' not in flag_data:
            # Without a ciphertext there is nothing to decrypt later on.
            print("[!] Service did not return an encrypted flag")
            return False

        self.encrypted_flag = bytes.fromhex(flag_data['flag'])
        print(f"[+] Encrypted flag ({len(self.encrypted_flag)} bytes)")

        return True

    def try_common_weak_keys(self):
        """
        Try a short list of common weak private key values.

        Returns:
            The matching key, or None if no candidate maps G onto the
            initial public key.
        """
        print("\n[*] Checking for weak/common private keys...")

        weak_keys = [
            1, 2, 3, 4, 5, 10, 100, 1000, 10000,
            0xDEADBEEF, 0xCAFEBABE, 0x12345678,
            2**16, 2**20, 2**24,
        ]

        for k in weak_keys:
            if self._same_point(self.G * k, self.initial_pubkey):
                print(f"[+] Found weak key: {k}")
                return k

        print("[!] No weak keys found")
        return None

    def _brute_force_small_keys(self, limit=1000000):
        """Linearly search keys 1..limit-1; return the key or None."""
        print("\n[*] Trying brute force for small keys...")

        # candidate always equals i*G; one addition per step replaces a
        # full scalar multiplication.
        candidate = self.G
        for i in range(1, min(limit, self.order)):
            if i % 100000 == 0:
                print(f"    Tried {i:,}...", end='\r')
            if self._same_point(candidate, self.initial_pubkey):
                print(f"\n[+] Found via brute force: {i}")
                return i
            candidate = self.G + candidate

        return None

    def solve_with_bsgs(self, max_bits=24):
        """
        Try to solve for the private key using Baby-step Giant-step.

        Args:
            max_bits: Keys up to 2**max_bits are searched.

        Returns:
            The private key, or None if it was not found.
        """
        print(f"\n[*] Attempting BSGS for keys up to {max_bits} bits...")
        max_steps = 2**max_bits

        return baby_step_giant_step(self.G, self.initial_pubkey, self.order, max_steps)

    def decrypt_flag(self, privkey):
        """
        Decrypt the collected flag with a key derived from privkey.

        The AES key is the first 16 bytes of SHA-256(long_to_bytes(privkey));
        the ciphertext is decrypted in ECB mode and unpadded to 16-byte
        blocks.

        Returns:
            The decoded flag string, or None if privkey is None, no
            ciphertext was collected, or decryption fails.
        """
        if privkey is None:
            return None

        print(f"\n[+] Recovered private key: {privkey}")
        print(f"[+] Hex: {hex(privkey)}")

        if self.encrypted_flag is None:
            print("[!] Decryption failed: no encrypted flag collected")
            return None

        flag_key = hashlib.sha256(long_to_bytes(privkey)).digest()[:16]

        try:
            cipher = AES.new(flag_key, AES.MODE_ECB)
            decrypted = unpad(cipher.decrypt(self.encrypted_flag), 16)
            flag = decrypted.decode()
        except (ValueError, TypeError) as e:
            # Wrong key: bad length/padding or undecodable plaintext.
            print(f"[!] Decryption failed: {e}")
            return None

        print()
        print("="*70)
        print(f"🚩 FLAG: {flag}")
        print("="*70)

        return flag

    def cleanup(self):
        """Ask the service to quit, force-stop it if needed, release pipes."""
        if self.proc is None:
            return

        try:
            self.send_command({"option": "quit"})
            self.proc.wait(timeout=2)
        except (OSError, ValueError, subprocess.TimeoutExpired):
            # Pipe already closed/broken or the service ignored "quit".
            self.proc.kill()
            self.proc.wait()  # reap the child so it does not linger
        finally:
            for stream in (self.proc.stdin, self.proc.stdout, self.proc.stderr):
                if stream is None:
                    continue
                try:
                    stream.close()
                except OSError:
                    # Best effort: flushing into a dead child may fail.
                    pass

    def run(self):
        """
        Main exploit routine.

        Returns:
            The decrypted flag string, or None if any stage fails.
        """
        try:
            self.start_service()

            if not self.collect_data():
                return None

            # Try different recovery methods
            print("\n" + "="*70)
            print("PRIVATE KEY RECOVERY")
            print("="*70)

            # Method 1: Check for weak keys
            privkey = self.try_common_weak_keys()

            if privkey is None:
                # Method 2: Try small brute force
                privkey = self._brute_force_small_keys()

            if privkey is None:
                # Method 3: Baby-step Giant-step
                privkey = self.solve_with_bsgs(max_bits=24)  # Try up to 2^24

            if privkey is not None:
                return self.decrypt_flag(privkey)

            print("\n[!] Could not recover private key")
            print("[*] The key might be too large. You would need:")
            print("    - Pollard's Rho (more time)")
            print("    - Access to the actual RNG seed")
            print("    - Or the challenge has a smaller key space")
            return None

        finally:
            self.cleanup()


def main():
    """
    Script entry point: print the banner, run the exploit, report the result.

    A placeholder ``secret.py`` is created for local runs only when no such
    file exists yet, so an existing secret is never overwritten.
    """
    print("="*70)
    print(" "*12 + "ADVANCED FLAG RECOVERY EXPLOIT")
    print(" "*18 + "Fl1pper Zer0 Challenge")
    print("="*70)
    print()

    # Setup: mode 'x' fails instead of clobbering an existing secret.py.
    try:
        with open('secret.py', 'x') as f:
            f.write('FLAG = "flag{GCM_n0nc3_r3us3_br34ks_3v3ryth1ng!}"\n')
    except OSError:
        # File already present or directory not writable; the setup is
        # best-effort, so carry on with whatever is there.
        pass

    exploit = AdvancedFlagRecovery()
    flag = exploit.run()

    print("\n" + "="*70)
    if flag:
        print("[+] ✅ EXPLOIT SUCCESSFUL!")
        print(f"[+] 🚩 FLAG: {flag}")
        print()
        print("[*] Vulnerability exploited: AES-GCM nonce reuse")
        print("[*] Private key recovered and flag decrypted!")
    else:
        print("[!] ❌ Could not recover flag")
        print()
        print("[*] Note: If running against real CTF server:")
        print("    - The private key might be in a specific range")
        print("    - Check challenge hints for key space")
        print("    - May need to increase BSGS max_bits parameter")
    print("="*70)


# Run the exploit only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
