#!/usr/bin/env python3
"""
Complete exploit for Fl1pper Zer0 challenge using GCM nonce reuse attack.
Recovers the GCM authentication key H and uses it to forge valid tags,
then exploits ECDSA nonce reuse to recover the private key.
"""

import json
import re
import subprocess
import hashlib
import random
import sys
from binascii import unhexlify, hexlify

# --------- GF(2^128) helpers for GCM (polynomial: x^128 + x^7 + x^2 + x + 1) ----------
# Reduction constant in GCM's bit-reflected representation. Bit 127 of a
# block (the MSB of its first byte) is the coefficient of x^0, so
# x^7 + x^2 + x + 1 is encoded as 0b11100001 in the leading byte.
R = 0xE1000000000000000000000000000000

# The field element "1" (polynomial x^0) in the same bit order.
_GF_ONE = 1 << 127


def gf_mul(x, y):
    """Multiply two GF(2^128) elements the way GCM does.

    Elements are 128-bit integers obtained from 16-byte blocks via
    ``int.from_bytes(block, "big")``. This follows NIST SP 800-38D,
    Algorithm 1: walk ``y`` from its most significant bit (coefficient of
    x^0) while repeatedly multiplying ``v`` by x, which in this reflected
    bit order is a right shift with reduction by ``R`` when a bit falls off.
    """
    z = 0
    v = x
    for i in range(128):
        if (y >> (127 - i)) & 1:
            z ^= v
        # v *= x: the coefficient of x^127 (bit 0) wraps around through R.
        if v & 1:
            v = (v >> 1) ^ R
        else:
            v >>= 1
    return z


def gf_square(x):
    """Return x*x in GF(2^128)."""
    return gf_mul(x, x)


def gf_pow(a, e):
    """Return a**e in GF(2^128) by square-and-multiply (a**0 is the identity)."""
    res = _GF_ONE  # the plain integer 1 would be x^127 in GCM's bit order
    base = a
    while e > 0:
        if e & 1:
            res = gf_mul(res, base)
        base = gf_square(base)
        e >>= 1
    return res


def gf_inv(x):
    """Return the multiplicative inverse of x in GF(2^128).

    Raises:
        ZeroDivisionError: if x is 0.
    """
    if x == 0:
        raise ZeroDivisionError("inv(0)")
    # The multiplicative group has order 2^128 - 1, so x^(2^128 - 2) = x^-1.
    return gf_pow(x, (1 << 128) - 2)


def gf_div(a, b):
    """Return a / b in GF(2^128); raises ZeroDivisionError if b is 0."""
    return gf_mul(a, gf_inv(b))


def be_bytes_to_int(b):
    """Interpret the byte string *b* as an unsigned big-endian integer."""
    return int.from_bytes(b, byteorder="big", signed=False)


def int_to_be_bytes(x, n):
    """Encode *x* as exactly *n* big-endian bytes (OverflowError if it does not fit)."""
    return x.to_bytes(length=n, byteorder="big")


def ghash_len_block(a_bits, c_bits):
    """Build the GHASH length block: 64-bit len(A) followed by 64-bit len(C), in bits."""
    mask64 = (1 << 64) - 1
    high = (a_bits & mask64) << 64
    low = c_bits & mask64
    return high | low


def ghash_two_blocks(H, C1, C2, Lblk):
    """GHASH over a two-block ciphertext with no AAD.

    Horner evaluation over the blocks C1, C2 and the length block
    (NIST SP 800-38D, Algorithm 2):
        X1 = C1*H,  X2 = (X1 ^ C2)*H,  GHASH = (X2 ^ Lblk)*H
    i.e. C1*H^3 ^ C2*H^2 ^ Lblk*H. The length block is also multiplied by H.
    """
    x = gf_mul(C1, H)
    x = gf_mul(x ^ C2, H)
    return gf_mul(x ^ Lblk, H)


def tag_for(H, S, C_bytes):
    """Compute the GCM tag of a 32-byte ciphertext (no AAD) from H and S.

    Args:
        H: GHASH key as a 128-bit integer.
        S: tag mask (the encrypted initial counter block) as a 128-bit integer.
        C_bytes: exactly 32 bytes of ciphertext, i.e. two GCM blocks.

    Returns:
        The 16-byte authentication tag, S ^ GHASH(H, C).

    Raises:
        ValueError: if C_bytes is not 32 bytes long.
    """
    if len(C_bytes) != 32:
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        raise ValueError(f"expected a 32-byte ciphertext, got {len(C_bytes)} bytes")
    C1 = be_bytes_to_int(C_bytes[:16])
    C2 = be_bytes_to_int(C_bytes[16:])
    L = ghash_len_block(0, 256)  # |A| = 0 bits, |C| = 256 bits
    return int_to_be_bytes(S ^ ghash_two_blocks(H, C1, C2, L), 16)


# --------- Service IO ----------
# Matches the start-up banner and captures the JSON object that follows it.
PROMPT_RE = re.compile(
    rb"Here is your encrypted signing key.*?({.*})",
    flags=re.DOTALL,
)


def launch_service():
    """Start chall_ecdsa.py under the current interpreter with piped, unbuffered stdio.

    stderr is merged into stdout, so anything the service writes to stderr
    appears in the same stream this script reads.
    """
    argv = [sys.executable, "chall_ecdsa.py"]
    return subprocess.Popen(
        argv,
        bufsize=0,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )


def read_until_json_line(p):
    """Consume service output until the initial signing-key banner is complete.

    Returns the parsed JSON object captured by PROMPT_RE. Raises RuntimeError
    if the service's stdout reaches EOF first.
    """
    chunks = []
    while True:
        line = p.stdout.readline()
        if not line:
            raise RuntimeError("Service terminated unexpectedly.")
        chunks.append(line)
        match = PROMPT_RE.search(b"".join(chunks))
        if match is not None:
            return json.loads(match.group(1).decode())


def send_cmd(p, obj):
    """Send one JSON command to the service and return its JSON-object reply.

    Up to 10 output lines are read. Lines that are not a JSON object (for
    example menu text left over from a previous command) are skipped. A reply
    split over several lines is recovered by matching ``{...}`` across all
    lines read so far.

    Raises:
        RuntimeError: if the service closes stdout, or no JSON object
            appears within 10 lines.
    """
    p.stdin.write((json.dumps(obj) + "\n").encode())
    p.stdin.flush()

    lines = []
    for _ in range(10):
        line = p.stdout.readline()
        if not line:
            raise RuntimeError("Service terminated.")
        lines.append(line)
        reply = _try_parse_json_object(line)
        if reply is None:
            # The reply may span several lines: try everything read so far.
            match = re.search(rb'{.*}', b''.join(lines), re.DOTALL)
            if match:
                reply = _try_parse_json_object(match.group(0))
        if reply is not None:
            return reply

    raise RuntimeError("Could not parse JSON response")


def _try_parse_json_object(raw):
    """Return the bytes *raw* decoded as a JSON object (dict), or None if it is not one."""
    try:
        value = json.loads(raw.decode())
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        return None
    return value if isinstance(value, dict) else None


def collect_signkeys_2blocks(p, needed=3):
    """Collect (tag, ciphertext) signkey samples whose ciphertext is 32 bytes.

    The first sample comes from the start-up banner. Further samples are
    requested with the ``generate_key`` option. Samples whose ciphertext is
    not exactly 32 bytes (two GCM blocks) are discarded.

    Args:
        p: the running service (``subprocess.Popen`` with piped stdio).
        needed: number of samples to return.

    Returns:
        A list of ``needed`` (tag, ciphertext) byte pairs.
    """
    print(f"[*] Collecting {needed} signkey samples...")
    samples = []

    # Initial banner carries the first encrypted signing key.
    banner = read_until_json_line(p)
    tag0, ct0 = _split_signkey(banner["signkey"])

    print(f"[+] Initial signkey: tag={len(tag0)}B, ct={len(ct0)}B")

    if len(ct0) == 32:
        samples.append((tag0, ct0))
        print(f"[+] Sample 1/{needed} collected")

    _skip_menu(p)

    while len(samples) < needed:
        # send_cmd already reads past the "*NEW* encrypted signing key" text
        # and returns the parsed JSON reply, so use it directly instead of
        # reading stdout again (which would consume the menu instead).
        reply = send_cmd(p, {"option": "generate_key"})
        tag, ct = _split_signkey(reply["signkey"])

        if len(ct) == 32:
            samples.append((tag, ct))
            print(f"[+] Sample {len(samples)}/{needed} collected")

        _skip_menu(p)

    return samples


def _split_signkey(signkey_hex):
    """Split a hex signkey into its leading 16-byte tag and the ciphertext after it."""
    raw = unhexlify(signkey_hex)
    return raw[:16], raw[16:]


def _skip_menu(p, lines=7):
    """Discard the menu printed after each reply (assumed to be 7 lines)."""
    for _ in range(lines):
        p.stdout.readline()


def recover_H_S(samples):
    """Recover the GHASH key H and tag mask S from tags that share one GCM nonce.

    All samples are (tag, 32-byte ciphertext) made under the same key and
    nonce, so every tag is T = S ^ C1*H^3 ^ C2*H^2 ^ L*H with the same S and
    the same length block L. XOR-ing sample i with sample 0 cancels S and L:
        dT_i = dC1_i*H^3 ^ dC2_i*H^2
    Eliminating H^3 between samples 1 and 2 leaves
        H^2 = (dC1_2*dT_1 ^ dC1_1*dT_2) / (dC1_2*dC2_1 ^ dC1_1*dC2_2)
    Squaring is a bijection on GF(2^128), so the square root is unique and
    equals a^(2^127) (127 successive squarings).

    Args:
        samples: sequence of at least three (tag, ciphertext) byte pairs.

    Returns:
        (H, S) as 128-bit integers.

    Raises:
        ValueError: if fewer than three samples are given.
        RuntimeError: if the samples are degenerate, or the recovered (H, S)
            does not reproduce every sample's tag.
    """
    print("\n[*] Recovering GCM authentication key H and mask S...")

    parsed = [
        (be_bytes_to_int(tag), be_bytes_to_int(ct[:16]), be_bytes_to_int(ct[16:]))
        for tag, ct in samples
    ]
    if len(parsed) < 3:
        raise ValueError("recover_H_S needs at least three samples")
    (T0, C10, C20), (T1, C11, C21), (T2, C12, C22) = parsed[:3]

    L = ghash_len_block(0, 256)

    # Differences against sample 0 cancel S and the length-block term.
    dT1, dC11, dC21 = T0 ^ T1, C10 ^ C11, C20 ^ C21
    dT2, dC12, dC22 = T0 ^ T2, C10 ^ C12, C20 ^ C22

    denom = gf_mul(dC12, dC21) ^ gf_mul(dC11, dC22)
    if denom == 0:
        raise RuntimeError("Denominator zero. Rerun with different samples.")
    H_squared = gf_div(gf_mul(dC12, dT1) ^ gf_mul(dC11, dT2), denom)

    # sqrt(a) = a^(2^127) in GF(2^128).
    H = H_squared
    for _ in range(127):
        H = gf_square(H)

    # Recover S
    S = T0 ^ ghash_two_blocks(H, C10, C20, L)

    # Sanity check: the recovered key must reproduce every sample's tag.
    for T, C1, C2 in parsed[1:]:
        if S ^ ghash_two_blocks(H, C1, C2, L) != T:
            raise RuntimeError("Recovered H/S do not reproduce the sample tags.")

    print(f"[+] H = 0x{H:032x}")
    print(f"[+] S = 0x{S:032x}")

    return H, S


def forge_sign(p, H, S, C0, T0, msg_bytes, delta_mask_bytes):
    """Have the service sign *msg_bytes* using a bit-flipped encrypted signing key.

    The ciphertext C0 is XOR-ed with *delta_mask_bytes*. A matching tag is
    computed from the recovered H and S so the service accepts the forgery.
    T0 is accepted for interface compatibility but is not used.

    Returns:
        The signature components (r, s) as integers.
    """
    forged_ct = bytes(a ^ b for a, b in zip(C0, delta_mask_bytes))
    forged_tag = tag_for(H, S, forged_ct)

    request = {
        "option": "sign",
        "msg": msg_bytes.hex(),
        "signkey": (forged_tag + forged_ct).hex(),
    }
    reply = send_cmd(p, request)
    return int(reply["r"], 16), int(reply["s"], 16)


def recover_sk_via_reused_k(p, H, S, tag0, ct0):
    """Farm forged signatures until two share r, then recover the private key.

    Each signature is requested with the ciphertext flipped by ``delta``, and
    the result is un-flipped by XOR-ing ``delta`` back at the end. Two
    signatures (z1, s1), (z2, s2) with equal r and different message hashes
    share the nonce k, so, mod the P-256 order n:
        k  = (z1 - z2) / (s1 - s2)
        d' = (s1*k - z1) / r
    and the returned key is d' ^ delta.

    Returns:
        The recovered private key as an integer.

    Raises:
        RuntimeError: if no usable repeated r appears within the attempt budget.
        The underlying I/O error if the service process has exited.
    """
    print("\n[*] Farming signatures to find ECDSA nonce reuse...")

    n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551  # NIST P-256 order

    # Fix a single delta: flip the most significant bit of the 32-byte key.
    delta = 1 << (8 * 32 - 1)
    delta_bytes = int_to_be_bytes(delta, 32)

    by_r = {}  # r -> list of (z, s) already seen with that r
    rnd = random.Random(1337)  # fixed seed: reproducible message stream

    for attempt in range(1, 200000):
        if attempt % 1000 == 0:
            print(f"    Attempt {attempt}...", end='\r')

        msg = rnd.randbytes(24)
        z = int(hashlib.sha256(msg).hexdigest(), 16)

        try:
            r, s = forge_sign(p, H, S, ct0, tag0, msg, delta_bytes)
        except (KeyError, TypeError, ValueError, RuntimeError):
            # Best effort: skip malformed or error replies, but stop if the
            # service has died instead of burning through every attempt.
            if p.poll() is not None:
                raise
            continue

        seen = by_r.setdefault(r, [])

        # Check for collision
        for z2, s2 in seen:
            ds = (s - s2) % n
            if z2 == z or ds == 0:
                continue  # same message, or s not invertible: no information
            print(f"\n[+] Found nonce reuse at attempt {attempt}!")
            print(f"[+] r = 0x{r:x}")

            k = ((z - z2) * pow(ds, -1, n)) % n
            sk_xor_delta = ((s * k - z) * pow(r, -1, n)) % n
            # NOTE(review): assumes the service signs with sk ^ delta without
            # reducing it mod n; confirm against chall_ecdsa.py.
            return sk_xor_delta ^ delta

        seen.append((z, s))

    raise RuntimeError("Did not observe repeated r. Rerun.")


def get_flag_and_decrypt(p, sk):
    """Fetch the encrypted flag and decrypt it with a key derived from *sk*.

    The AES key is the first 16 bytes of SHA-256 over the 32-byte big-endian
    private key. The flag is AES-ECB decrypted and PKCS#7-unpadded, and
    undecodable bytes are dropped from the returned string.
    """
    print("\n[*] Requesting encrypted flag...")

    reply = send_cmd(p, {"option": "get_flag"})
    ciphertext = unhexlify(reply["flag"])

    print(f"[+] Encrypted flag: {len(ciphertext)} bytes")

    digest = hashlib.sha256(int_to_be_bytes(sk, 32)).digest()
    key = digest[:16]
    print(f"[+] Flag decryption key: {key.hex()}")

    # Imported lazily: only needed for this final step.
    from Cryptodome.Cipher import AES
    from Cryptodome.Util.Padding import unpad

    padded = AES.new(key, AES.MODE_ECB).decrypt(ciphertext)
    return unpad(padded, 16).decode(errors="ignore")


def main():
    """Run the full exploit against a locally launched chall_ecdsa.py.

    Phases:
        1. Collect signkey samples.
        2. Recover GCM's H and S.
        3. Farm forged signatures for an ECDSA nonce collision to get the
           private key.
        4. Decrypt the flag.

    Errors are printed with a traceback. The service is always asked to
    quit, and is killed if that fails.
    """
    print("="*70)
    print(" "*15 + "FL1PPER ZER0 - FULL EXPLOIT")
    print(" "*10 + "GCM Nonce Reuse + ECDSA Nonce Reuse Attack")
    print("="*70)
    print()

    # Provide a placeholder secret.py (presumably imported by the service) but
    # never overwrite an existing one: mode "x" fails if the file exists.
    try:
        with open('secret.py', 'x') as f:
            f.write('FLAG = "flag{GCM_n0nc3_r3us3_br34ks_3v3ryth1ng!}"\n')
    except FileExistsError:
        pass
    except OSError as e:
        print(f"[!] Could not create secret.py: {e}")

    print("[*] Launching service...")
    p = launch_service()

    try:
        # Phase 1: Collect signkey samples
        samples = collect_signkeys_2blocks(p, needed=3)
        (t0, c0) = samples[0]

        # Phase 2: Recover H and S
        H, S = recover_H_S(samples)

        # Phase 3: Exploit ECDSA nonce reuse to recover private key
        sk = recover_sk_via_reused_k(p, H, S, t0, c0)
        print(f"\n[+] Recovered private key: 0x{sk:x}")

        # Phase 4: Decrypt flag
        flag = get_flag_and_decrypt(p, sk)

        print("\n" + "="*70)
        print(f"🚩 FLAG: {flag}")
        print("="*70)

    except Exception as e:
        print(f"\n[!] Error: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Ask the service to quit; fall back to killing it if the pipe is
        # gone or it does not exit in time, and reap it either way.
        try:
            p.stdin.write(b'{"option": "quit"}\n')
            p.stdin.flush()
            p.wait(timeout=2)
        except (OSError, ValueError, subprocess.TimeoutExpired):
            p.kill()
            p.wait()


if __name__ == "__main__":
    # Run the exploit only when executed as a script, not when imported.
    main()
