#!/usr/bin/env python3
"""
PRAGMATIC SOLVER - Tests if exploit works, with realistic expectations
"""

import json
import subprocess
import sys
import hashlib
import random
import time
from binascii import unhexlify

# GF(2^128) arithmetic in the bit-reflected convention of AES-GCM
# (NIST SP 800-38D, Algorithm 1). A 16-byte block maps to an int via
# int.from_bytes(block, "big"); the coefficient of x^0 is the int's MSB,
# so the multiplicative identity is 1 << 127 (not 1).
R = 0xE1000000000000000000000000000000
GF_ONE = 1 << 127


def gf_mul(x, y):
    """Return the GCM field product x * y of two 128-bit elements."""
    z, v = 0, x
    for i in range(128):
        if (y >> (127 - i)) & 1:
            z ^= v
        # Multiplying v by x is a RIGHT shift in the reflected convention;
        # the coefficient that falls off (int bit 0) is folded back with R.
        if v & 1:
            v = (v >> 1) ^ R
        else:
            v >>= 1
    return z


def gf_pow(a, e):
    """Return a**e in GF(2^128) by square-and-multiply (e >= 0)."""
    res, base = GF_ONE, a
    while e:
        if e & 1:
            res = gf_mul(res, base)
        base = gf_mul(base, base)
        e >>= 1
    return res


def gf_div(a, b):
    """Return a / b in GF(2^128).

    Raises:
        ZeroDivisionError: if b == 0 (it has no inverse).
    """
    if b == 0:
        raise ZeroDivisionError("division by zero in GF(2^128)")
    # The multiplicative group has order 2^128 - 1, so b^(2^128 - 2) == b^-1.
    return gf_mul(a, gf_pow(b, (1 << 128) - 2))


def ghash(H, C1, C2):
    """Return GHASH_H for a 32-byte ciphertext (blocks C1, C2) with no AAD.

    The final GHASH block is len(A) || len(C) in bits = 0 || 256, i.e. the
    integer 256; like every block, it is XORed in and then multiplied by H.
    """
    x = gf_mul(C1, H)
    x = gf_mul(x ^ C2, H)
    return gf_mul(x ^ 256, H)


def tag(H, S, C):
    """Return the 16-byte GCM tag S ^ GHASH_H(C) for a 32-byte ciphertext C.

    S is the tag mask E_K(J0) as an int.

    Raises:
        ValueError: if C is not exactly 32 bytes (the length block in
            ghash() is hard-coded for 256 bits).
    """
    if len(C) != 32:
        raise ValueError(f"tag() expects a 32-byte ciphertext, got {len(C)}")
    c1 = int.from_bytes(C[:16], "big")
    c2 = int.from_bytes(C[16:], "big")
    return (S ^ ghash(H, c1, c2)).to_bytes(16, "big")

# Banner: print() with sep="\n" emits the same lines the separate calls
# did, and the trailing "" reproduces the final blank line.
_RULE = "=" * 70
print(
    _RULE,
    "FL1PPER ZER0 - REALISTIC SOLVER",
    _RULE,
    "\nNOTE: ECDSA nonce reuse with Python's random() can take",
    "anywhere from 100 to 100,000+ signatures due to birthday paradox.",
    "Expected time: 1-10 minutes with fast CPU.",
    _RULE,
    "",
    sep="\n",
)

# Launch the challenge as a child process; commands go to its stdin and
# replies are read from its stdout, both as line-buffered text pipes.
proc = subprocess.Popen(
    [sys.executable, "chall_ecdsa.py"],
    bufsize=1,
    text=True,
    stdout=subprocess.PIPE,
    stdin=subprocess.PIPE,
)

def rj(max_lines=20):
    """Read server output until a JSON object can be parsed from it.

    Lines are accumulated. Once the buffer holds a '{' followed somewhere by a
    '}', the text from the first '{' through the last '}' is decoded.

    Args:
        max_lines: Give up after reading this many lines (default 20).

    Returns:
        The decoded object, or None on EOF or if nothing decoded within
        *max_lines* lines.
    """
    buf = ""
    for _ in range(max_lines):
        line = proc.stdout.readline()
        if not line:
            return None  # EOF: the child closed stdout / exited
        buf += line
        start, end = buf.find('{'), buf.rfind('}')
        if start == -1 or end < start:
            continue
        try:
            return json.loads(buf[start:end + 1])
        except json.JSONDecodeError:
            # Object likely spans more lines; keep reading. Only decode
            # errors are swallowed, so Ctrl+C still propagates.
            pass
    return None

def cmd(c):
    """Send *c* to the child as a single JSON line and return rj()'s reply."""
    payload = json.dumps(c) + '\n'
    pipe = proc.stdin
    pipe.write(payload)
    pipe.flush()
    return rj()

def sk():
    """Read and discard the next seven lines of the child's stdout.

    NOTE(review): presumably the menu the challenge reprints after each
    command; verify the count against chall_ecdsa.py.
    """
    remaining = 7
    while remaining:
        proc.stdout.readline()
        remaining -= 1

def _forge_flipped_signkey(samples, delta):
    """Forge a signkey: sample 0's ciphertext XOR *delta* with a valid tag.

    All samples are assumed to be (tag, ciphertext) pairs produced under the
    same GCM key and nonce (the nonce reuse this solver targets), so their
    tags differ only in the GHASH term. For 32-byte ciphertexts split into
    blocks (C1, C2) that difference is linear:

        T ^ T' = (C1 ^ C1') * A  ^  (C2 ^ C2') * B     (GCM: A = H^3, B = H^2)

    Two independent differences give a 2x2 system over GF(2^128), solved
    with Cramer's rule. The forged tag then follows from the same relation,
    so neither H nor the tag mask S is needed for the forgery itself.

    Args:
        samples: At least three (16-byte tag, 32-byte ciphertext) pairs;
            only the first three are used.
        delta: Integer XOR mask (< 2**256) applied to the ciphertext.

    Returns:
        (H, signkey_hex): H = A / B, and hex(forged_tag || flipped_ct).

    Raises:
        RuntimeError: if the samples do not determine A and B uniquely.
    """
    def b2i(b):
        return int.from_bytes(b, "big")

    def blocks(c):
        return b2i(c[:16]), b2i(c[16:32])

    (t0, c0), *others = samples[:3]
    t0_int = b2i(t0)
    c0_hi, c0_lo = blocks(c0)
    rows = []
    for t, c in others:
        hi, lo = blocks(c)
        rows.append((c0_hi ^ hi, c0_lo ^ lo, t0_int ^ b2i(t)))
    (x1, y1, dt1), (x2, y2, dt2) = rows

    det = gf_mul(x1, y2) ^ gf_mul(x2, y1)
    if det == 0:
        raise RuntimeError("GCM samples are linearly dependent; cannot solve for H")
    coef_a = gf_div(gf_mul(dt1, y2) ^ gf_mul(dt2, y1), det)
    coef_b = gf_div(gf_mul(x1, dt2) ^ gf_mul(x2, dt1), det)

    db = delta.to_bytes(32, "big")
    flipped = bytes(p ^ q for p, q in zip(c0, db))
    d_hi, d_lo = blocks(db)
    forged = t0_int ^ gf_mul(d_hi, coef_a) ^ gf_mul(d_lo, coef_b)
    return gf_div(coef_a, coef_b), (forged.to_bytes(16, "big") + flipped).hex()


def _ecdsa_key_from_nonce_reuse(z1, s1, z2, s2, r, n):
    """Return the signing key from two ECDSA signatures that share r.

    A shared r means a shared nonce k. From s_i = k^-1 (z_i + r*d) mod n:
    k = (z1 - z2) / (s1 - s2) and d = (s1*k - z1) / r, all mod n.
    """
    k = (z1 - z2) * pow((s1 - s2) % n, -1, n) % n
    return (s1 * k - z1) * pow(r, -1, n) % n


try:
    # Phase 1: Collect samples
    print("[1/4] Collecting GCM nonce reuse samples...")
    S = []
    d = rj()
    if d:
        s = unhexlify(d['signkey'])
        S.append((s[:16], s[16:]))
        print(f"  ✓ Sample 1/3")
    sk()

    for _ in range(2):
        cmd({"option": "generate_key"})
        d = rj()
        if d:
            s = unhexlify(d['signkey'])
            S.append((s[:16], s[16:]))
            print(f"  ✓ Sample {len(S)}/3")
        sk()

    if len(S) < 3:
        raise RuntimeError(f"only collected {len(S)}/3 signkey samples")

    # Phase 2: Recover GHASH coefficients and forge the bit-flipped signkey.
    # This does not depend on the loop, so it is computed once, not per signature.
    print("\n[2/4] Recovering GCM authentication key...")
    n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551
    delta = 1 << 255  # XOR mask on the 32-byte signkey ciphertext (its MSB)
    H, forged_signkey = _forge_flipped_signkey(S, delta)
    print(f"  ✓ H recovered: {hex(H)[:20]}...")

    # Phase 3: Farm for ECDSA k reuse
    print("\n[3/4] Farming ECDSA signatures for nonce collision...")
    print("  This uses birthday paradox - expect √(2^256) ≈ 2^128 operations")
    print("  BUT Python's MT has only 2^19937 period, so collisions happen faster")
    print("  Typical: 5,000-50,000 signatures\n")

    seen = {}  # r -> [(z, s), ...] for every signature received so far
    rng = random.Random(12345)  # fixed seed: reproducible message sequence
    start_time = time.time()
    CHECK_INTERVAL = 250
    MAX_SIGS = 200_000
    priv_key = None

    for i in range(1, MAX_SIGS + 1):
        m = rng.randbytes(16)
        z = int(hashlib.sha256(m).hexdigest(), 16)

        resp = cmd({"option": "sign", "msg": m.hex(), "signkey": forged_signkey})
        if not resp or 'r' not in resp:
            sk()
            continue
        rv = int(resp['r'], 16)
        sv = int(resp['s'], 16)
        sk()

        for z2, s2 in seen.get(rv, ()):
            # Differing z with equal s would make s1 - s2 non-invertible.
            if z2 != z and (sv - s2) % n:
                elapsed = time.time() - start_time
                print(f"\n  🎯 COLLISION FOUND at signature #{i}!")
                print(f"  ⏱️  Time: {elapsed:.1f} seconds")
                print(f"  📊 Collision rate: {i/elapsed:.1f} sigs/sec\n")
                # Assumes the server signed with the decrypted, bit-flipped
                # key d ^ delta; undo the flip to get the real key.
                priv_key = _ecdsa_key_from_nonce_reuse(z, sv, z2, s2, rv, n) ^ delta
                break
        if priv_key is not None:
            break

        seen.setdefault(rv, []).append((z, sv))

        if i % CHECK_INTERVAL == 0:
            elapsed = time.time() - start_time
            rate = i / elapsed
            unique = len(seen)
            print(f"  Progress: {i:,} sigs | {unique:,} unique r | {rate:.1f} sig/s | {elapsed:.0f}s", end='\r')

    if priv_key is None:
        print(f"\n\n⚠️  Reached {MAX_SIGS:,} attempts without collision")
        print("This is unusual but possible. Try running again.")
    else:
        # Named priv_key (not sk) so the sk() helper is never shadowed.
        print(f"[4/4] Decrypting flag...")
        print(f"  ✓ Private key: {hex(priv_key)[:20]}...\n")

        fr = cmd({"option": "get_flag"})
        if fr and 'flag' in fr:
            from Cryptodome.Cipher import AES
            from Cryptodome.Util.Padding import unpad

            enc = unhexlify(fr['flag'])
            key = hashlib.sha256(priv_key.to_bytes(32, "big")).digest()[:16]
            flag = unpad(AES.new(key, AES.MODE_ECB).decrypt(enc), 16).decode()

            print("="*70)
            print(f"🚩 FLAG: {flag}")
            print("="*70)
            sys.exit(0)  # child is terminated in `finally`
        print(f"  ❌ get_flag returned no flag: {fr!r}")
        print(f"  Full private key: {hex(priv_key)}")

except KeyboardInterrupt:
    print("\n\n⚠️  Interrupted by user (Ctrl+C)")
    print("\nTip: Let it run longer. Collisions are probabilistic.")
except Exception as e:
    print(f"\n\n❌ Error: {e}")
    import traceback
    traceback.print_exc()
finally:
    # terminate() is a no-op on an already-reaped child; escalate to kill()
    # if it does not exit promptly, and always reap it.
    proc.terminate()
    try:
        proc.wait(timeout=5)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()

print("\n" + "="*70)
