#!/usr/bin/env python3
"""
Exploit for the Fl1pper Zer0 challenge (runs chall_ecdsa.py locally).
Note: reads from the service are blocking readline() calls with no timeout.
"""

import json
import subprocess
import sys
import hashlib
import random
from binascii import unhexlify

# GF(2^128) arithmetic in the GCM bit order (NIST SP 800-38D): the most
# significant bit of the 128-bit integer is the coefficient of x^0, so the
# multiplicative identity is 1 << 127 and multiplying by x is a right shift.
R = 0xE1000000000000000000000000000000
GF_ONE = 1 << 127

def gf_mul(x, y):
    """Return x*y in GF(2^128) using the GCM bit order.

    Follows Algorithm 1 of NIST SP 800-38D: the bits of *y* are scanned from
    the most significant end while *v* steps through x, x*X, x*X^2, ...
    """
    z, v = 0, x
    for i in range(128):
        if (y >> (127 - i)) & 1:
            z ^= v
        # Multiply v by X: a right shift in the reflected order; the term
        # shifted out (X^128) is folded back in via R.
        if v & 1:
            v = (v >> 1) ^ R
        else:
            v >>= 1
    return z

def gf_pow(a, e):
    """Return a**e in GF(2^128) by square-and-multiply (e >= 0)."""
    res, base = GF_ONE, a  # start from the field identity, not integer 1
    while e:
        if e & 1:
            res = gf_mul(res, base)
        base = gf_mul(base, base)
        e >>= 1
    return res

def gf_inv(x):
    """Return the multiplicative inverse of x, computed as x**(2^128 - 2).

    Raises:
        ZeroDivisionError: if x == 0, which has no inverse.
    """
    if x == 0:
        raise ZeroDivisionError("0 has no inverse in GF(2^128)")
    return gf_pow(x, (1 << 128) - 2)

def gf_div(a, b):
    """Return a / b in GF(2^128); raises ZeroDivisionError when b == 0."""
    return gf_mul(a, gf_inv(b))

def ghash_two_blocks(H, C1, C2):
    """Return GHASH_H over two ciphertext blocks with empty AAD.

    Per NIST SP 800-38D the length block is hashed too, so the result is
    C1*H^3 ^ C2*H^2 ^ L*H.
    """
    L = 256  # len(A) || len(C) in bits: |A| = 0, |C| = 256
    y = gf_mul(C1, H)
    y = gf_mul(y ^ C2, H)
    return gf_mul(y ^ L, H)

def compute_tag(H, S, C_bytes):
    """Return the 16-byte tag S ^ GHASH_H(C) for a 32-byte ciphertext.

    Args:
        H: GHASH key as a 128-bit integer.
        S: tag mask as a 128-bit integer (E_K(J0) in standard GCM).
        C_bytes: ciphertext; bytes [0:16] are block 1, the rest block 2.
    """
    C1 = int.from_bytes(C_bytes[:16], "big")
    C2 = int.from_bytes(C_bytes[16:], "big")
    g = ghash_two_blocks(H, C1, C2)
    return (S ^ g).to_bytes(16, "big")

print("="*70)
print(" "*20 + "FL1PPER ZER0 SOLVER")
print("="*70)
print()

# Start service
print("[*] Starting service...")
# stderr is inherited rather than piped: nothing ever reads a stderr pipe,
# so a chatty or crashing child could fill it and block, and its tracebacks
# would be invisible.
proc = subprocess.Popen(
    [sys.executable, "chall_ecdsa.py"],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    text=True,
    bufsize=1,  # line-buffered pipes to match the line-oriented protocol
)

def read_json(stream=None, max_lines=20):
    """Read lines until they contain a parseable JSON object and return it.

    Non-JSON text around the object (banners, prompts) is tolerated: the
    accumulated buffer is sliced from its first '{' to its last '}'.

    Args:
        stream: text stream to read from; defaults to the service's stdout.
        max_lines: maximum number of lines to read before giving up.

    Returns:
        The decoded object, or None on EOF or if nothing parsed within
        *max_lines* lines.
    """
    if stream is None:
        stream = proc.stdout
    buf = ""
    for _ in range(max_lines):
        line = stream.readline()
        if not line:
            return None
        buf += line
        start, end = buf.find('{'), buf.rfind('}')
        if start != -1 and end > start:
            try:
                return json.loads(buf[start:end + 1])
            except ValueError:
                # Incomplete or not JSON yet; keep accumulating lines.
                pass
    return None

def send(cmd):
    """Write *cmd* to the service as one JSON line and return the reply.

    The reply is parsed by read_json(); None means no JSON object arrived.
    """
    pipe = proc.stdin
    pipe.write(f"{json.dumps(cmd)}\n")
    pipe.flush()
    return read_json()

def skip_menu(lines=7, stream=None):
    """Read and discard the menu the service prints after each reply.

    Args:
        lines: number of lines to consume (the service's menu is assumed to
            be 7 lines long — the value this script was written against).
        stream: text stream to read from; defaults to the service's stdout.
    """
    if stream is None:
        stream = proc.stdout
    for _ in range(lines):
        stream.readline()

try:
    # --- Collect three encrypted signing keys as (tag, ciphertext) ---
    print("[*] Collecting samples...")
    samples = []

    def add_sample(reply):
        """Record (tag, ciphertext) from a reply that carries 'signkey'."""
        if reply and 'signkey' in reply:
            blob = unhexlify(reply['signkey'])
            samples.append((blob[:16], blob[16:]))
            print(f"[+] Sample {len(samples)}/3")

    # The service prints an initial signkey before its first menu.
    add_sample(read_json())
    skip_menu()

    for _ in range(2):
        # send() already returns the parsed reply; calling read_json() again
        # would consume the following menu (or block waiting for output).
        add_sample(send({"option": "generate_key"}))
        skip_menu()

    if len(samples) < 3:
        print("[!] Failed to collect enough samples")
        sys.exit(1)

    # --- Recover the GHASH key H and the tag mask S ---
    print("\n[*] Recovering GCM auth key H and mask S...")
    (T0, C0), (T1, C1), (T2, C2) = samples

    def b2i(b):
        return int.from_bytes(b, "big")

    t0, t1, t2 = b2i(T0), b2i(T1), b2i(T2)
    C10, C20 = b2i(C0[:16]), b2i(C0[16:])
    C11, C21 = b2i(C1[:16]), b2i(C1[16:])
    C12, C22 = b2i(C2[:16]), b2i(C2[16:])

    # XORing two tags cancels S and the length-block term, leaving a
    # polynomial in H whose coefficients are the ciphertext differences.
    dT1, dC11, dC21 = t0 ^ t1, C10 ^ C11, C20 ^ C21
    dT2, dC12, dC22 = t0 ^ t2, C10 ^ C12, C20 ^ C22

    alpha1, beta1 = gf_div(dC21, dC11), gf_div(dT1, dC11)
    alpha2, beta2 = gf_div(dC22, dC12), gf_div(dT2, dC12)
    Q = gf_div(beta1 ^ beta2, alpha1 ^ alpha2)

    def matches_samples(h):
        """True if h reproduces both observed tag differences."""
        g0 = ghash_two_blocks(h, C10, C20)
        return (dT1 == g0 ^ ghash_two_blocks(h, C11, C21)
                and dT2 == g0 ^ ghash_two_blocks(h, C12, C22))

    # For standard GHASH over two blocks plus the length block
    # (C1*H^3 ^ C2*H^2 ^ L*H) the elimination above gives Q == H^2, so
    # H = sqrt(Q) = Q^(2^127).  A GHASH that multiplies by H only twice gives
    # Q == H directly.  Keep the candidate that agrees with ghash_two_blocks
    # on the samples, so a wrong H is caught now rather than after farming.
    H = next((h for h in (gf_pow(Q, 1 << 127), Q) if matches_samples(h)), None)
    if H is None:
        print("[!] Recovered H does not reproduce the sample tags")
        sys.exit(1)
    S = t0 ^ ghash_two_blocks(H, C10, C20)

    print(f"[+] H = 0x{H:032x}")
    print(f"[+] S = 0x{S:032x}")

    # --- Farm signatures under a bit-flipped key, looking for k reuse ---
    print("\n[*] Farming ECDSA signatures (this may take a minute)...")
    print("[*] Looking for repeated 'r' values (k reuse)...")

    n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551
    delta = (1 << 255)
    delta_bytes = delta.to_bytes(32, "big")

    # Flipping a ciphertext bit flips the same bit of the decrypted key
    # (GCM's CTR keystream is XORed); re-tag it with the recovered H and S.
    # The forged signkey does not depend on the message, so build it once.
    ct0 = samples[0][1]
    ct_forge = bytes(a ^ b for a, b in zip(ct0, delta_bytes))
    tag_forge = compute_tag(H, S, ct_forge)
    forged_signkey = (tag_forge + ct_forge).hex()

    bucket = {}  # r -> list of (z, s) observed with that r
    rnd = random.Random(42)

    MAX_ATTEMPTS = 50000  # Limit attempts

    for attempt in range(1, MAX_ATTEMPTS + 1):
        if attempt % 100 == 0:
            print(f"    Attempt {attempt}/{MAX_ATTEMPTS} ({len(bucket)} unique r values)...", end='\r')

        # Generate random message
        msg = rnd.randbytes(16)
        z = int(hashlib.sha256(msg).hexdigest(), 16)

        resp = send({
            "option": "sign",
            "msg": msg.hex(),
            "signkey": forged_signkey,
        })
        skip_menu()

        if not resp or 'r' not in resp:
            continue

        r = int(resp['r'], 16)
        s = int(resp['s'], 16)

        for z2, s2 in bucket.get(r, ()):
            if z2 == z:
                continue
            print(f"\n\n[+] FOUND NONCE REUSE at attempt {attempt}!")
            print(f"[+] r = 0x{r:064x}")

            # Same k and key: s - s2 = k^-1 (z - z2), so k = (z - z2)/(s - s2),
            # then the (flipped) key is (s*k - z)/r.
            ds_inv = pow((s - s2) % n, -1, n)
            k = ((z - z2) * ds_inv) % n
            rinv = pow(r, -1, n)
            sk_xor_delta = ((s * k - z) * rinv) % n
            sk = sk_xor_delta ^ delta

            print(f"[+] Recovered private key: 0x{sk:064x}")

            # Get and decrypt flag
            print("\n[*] Decrypting flag...")
            flag_resp = send({"option": "get_flag"})
            if not flag_resp or 'flag' not in flag_resp:
                # The key is already recovered; farming more cannot help.
                print(f"[!] Unexpected get_flag response: {flag_resp}")
                sys.exit(1)

            enc = unhexlify(flag_resp['flag'])
            key = hashlib.sha256(sk.to_bytes(32, "big")).digest()[:16]

            from Cryptodome.Cipher import AES
            from Cryptodome.Util.Padding import unpad

            flag = unpad(AES.new(key, AES.MODE_ECB).decrypt(enc), 16).decode()

            print("\n" + "="*70)
            print(f"🚩 FLAG: {flag}")
            print("="*70)

            # The finally clause below shuts the service down.
            sys.exit(0)

        bucket.setdefault(r, []).append((z, s))

        # Check if we're making progress
        if attempt % 5000 == 0:
            print(f"\n[*] Progress: {len(bucket)} unique r values found so far")

    print(f"\n\n[!] No nonce reuse found in {MAX_ATTEMPTS} attempts")
    print("[!] This might indicate:")
    print("    1. Need more attempts (increase MAX_ATTEMPTS)")
    print("    2. RNG is actually secure")
    print("    3. Issue with the exploit")

except KeyboardInterrupt:
    print("\n\n[!] Interrupted by user")
except Exception as e:
    print(f"\n\n[!] Error: {e}")
    import traceback
    traceback.print_exc()
finally:
    # Best-effort shutdown: ask the service to stop, force-kill if it lingers.
    try:
        proc.terminate()
        proc.wait(timeout=1)
    except subprocess.TimeoutExpired:
        proc.kill()
