#!/usr/bin/env python3
"""
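Exploit for the Fl1pper Zer0 challenge.
Abuses AES-GCM nonce reuse to recover the GHASH key H and the tag mask S, forges
valid tags for bit-flipped encrypted signing keys, and then farms ECDSA
signatures until a repeated nonce leaks the private key.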
"""

import json
import subprocess
import sys
import hashlib
import random
from binascii import unhexlify

# GF(2^128) arithmetic for GCM.
#
# Field elements use GCM's bit-reflected convention (NIST SP 800-38D): a
# 16-byte block read as a big-endian integer holds the coefficient of x^0 in
# its most significant bit and the coefficient of x^127 in its least
# significant bit. Multiplying by x is therefore a RIGHT shift, and the
# reduction polynomial x^128 + x^7 + x^2 + x + 1 appears as R = 0xE1 || 0^120.
R = 0xE1000000000000000000000000000000

# The multiplicative identity (the polynomial "1") in GCM bit order.
GF_ONE = 1 << 127

def gf_mul(x, y):
    """Return the GCM product x * y in GF(2^128) (SP 800-38D, Algorithm 1).

    Both operands and the result are 128-bit integers in GCM bit order.
    """
    z = 0
    v = y
    for i in range(128):
        # Bit i of x, counted from the MSB, is the coefficient of x^i.
        if (x >> (127 - i)) & 1:
            z ^= v
        # v <- v * x; when the x^127 coefficient (LSB) shifts out, reduce.
        if v & 1:
            v = (v >> 1) ^ R
        else:
            v >>= 1
    return z

def gf_square(x):
    """Return x * x in GF(2^128)."""
    return gf_mul(x, x)

def gf_pow(a, e):
    """Return a ** e in GF(2^128) by square-and-multiply (e >= 0).

    The accumulator starts at GF_ONE (1 << 127), the field identity in GCM
    bit order, not at the integer 1.
    """
    res = GF_ONE
    base = a
    while e > 0:
        if e & 1:
            res = gf_mul(res, base)
        base = gf_square(base)
        e >>= 1
    return res

def gf_inv(x):
    """Return the multiplicative inverse of x as x^(2^128 - 2) (Fermat).

    Raises:
        ZeroDivisionError: if x is zero.
    """
    if x == 0:
        raise ZeroDivisionError("inv(0)")
    return gf_pow(x, (1 << 128) - 2)

def gf_div(a, b):
    """Return a / b in GF(2^128); raises ZeroDivisionError if b is zero."""
    return gf_mul(a, gf_inv(b))

def be_bytes_to_int(b):
    """Decode the byte string *b* as an unsigned big-endian integer."""
    value = int.from_bytes(b, byteorder="big")
    return value

def int_to_be_bytes(x, n):
    """Encode non-negative *x* as exactly *n* big-endian bytes.

    Raises OverflowError when *x* does not fit in *n* bytes.
    """
    encoded = x.to_bytes(length=n, byteorder="big")
    return encoded

def ghash_two_blocks(H, C1, C2):
    """Return GHASH_H(A="", C=C1||C2) for a 32-byte ciphertext.

    Per NIST SP 800-38D every 128-bit block -- C1, C2 and finally the length
    block len(A)||len(C) -- is absorbed as X <- (X ^ block) * H, so the result
    is C1*H^3 ^ C2*H^2 ^ L*H. Assumes the service uses standard AES-GCM with
    no associated data. All values are 128-bit integers in GCM bit order.
    """
    L = 256  # 64-bit len(A) = 0 || 64-bit len(C) = 256 bits
    x = gf_mul(C1, H)
    x = gf_mul(x ^ C2, H)
    return gf_mul(x ^ L, H)

def compute_tag(H, S, C_bytes):
    """Return the 16-byte GCM tag S ^ GHASH_H(C) for a 32-byte ciphertext.

    Raises:
        ValueError: if C_bytes is not exactly 32 bytes long.
    """
    if len(C_bytes) != 32:
        raise ValueError(f"expected a 32-byte ciphertext, got {len(C_bytes)} bytes")
    C1 = be_bytes_to_int(C_bytes[:16])
    C2 = be_bytes_to_int(C_bytes[16:])
    g = ghash_two_blocks(H, C1, C2)
    T = S ^ g
    return int_to_be_bytes(T, 16)

class Exploit:
    """Drives a local chall_ecdsa.py over JSON-lines stdin/stdout."""

    def __init__(self):
        self.proc = None

    def start(self):
        """Launch chall_ecdsa.py as a line-buffered text subprocess."""
        print("[*] Starting challenge service...")
        self.proc = subprocess.Popen(
            [sys.executable, "chall_ecdsa.py"],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            # stderr is never read; a PIPE could fill up and block the child.
            stderr=subprocess.DEVNULL,
            text=True,
            bufsize=1
        )

    def read_json_from_output(self):
        """Read stdout lines until they contain a parseable JSON object.

        Non-JSON lines (menus, prompts) before the object are skipped.
        Returns the decoded object, or None if stdout reaches EOF first.
        """
        buffer = ""
        while True:
            line = self.proc.stdout.readline()
            if not line:
                return None
            buffer += line
            if '{' in buffer and '}' in buffer:
                start = buffer.find('{')
                end = buffer.rfind('}') + 1
                try:
                    return json.loads(buffer[start:end])
                except json.JSONDecodeError:
                    # Object may be incomplete (or braces came from plain
                    # text); keep reading more lines.
                    continue

    def send(self, cmd):
        """Write `cmd` as one JSON line and return the next JSON response."""
        self.proc.stdin.write(json.dumps(cmd) + '\n')
        self.proc.stdin.flush()
        return self.read_json_from_output()

    def skip_menu(self):
        """Discard the menu printed after each response.

        NOTE(review): assumes the service prints exactly seven menu lines --
        confirm against chall_ecdsa.py.
        """
        for _ in range(7):
            self.proc.stdout.readline()

    def collect_samples(self, count=3, max_requests=None):
        """Collect up to `count` encrypted signkeys as (tag, ciphertext) pairs.

        The first sample is read from the service's initial output; more are
        requested with "generate_key". Only 48-byte signkeys (16-byte tag
        followed by 32-byte ciphertext) are kept. Returns early with fewer
        samples if stdout hits EOF or `max_requests` requests (default
        10 * count) do not yield enough samples.
        """
        print(f"[*] Collecting {count} samples...")
        samples = []

        def _add(data):
            # Keep the signkey from `data` if it has the expected layout.
            if data and 'signkey' in data:
                sk = unhexlify(data['signkey'])
                if len(sk) == 48:  # 16 tag + 32 ct
                    samples.append((sk[:16], sk[16:]))
                    print(f"[+] Sample {len(samples)}/{count}")

        # Get initial
        _add(self.read_json_from_output())
        self.skip_menu()

        if max_requests is None:
            max_requests = 10 * count
        for _ in range(max_requests):
            if len(samples) >= count:
                break
            data = self.send({"option": "generate_key"})
            if data is not None and 'signkey' not in data:
                # Presumably the signkey arrives in a following JSON object.
                data = self.read_json_from_output()
            if data is None:
                print("[!] Service closed its output")
                break
            _add(data)
            self.skip_menu()

        return samples

    def recover_H_S(self, samples):
        """Recover the GHASH key H and tag mask S from nonce-reused samples.

        All samples share key and nonce, so for 32-byte ciphertexts
        tag_j = S ^ C1_j*H^3 ^ C2_j*H^2 ^ L*H. XOR-ing with sample 0 cancels
        S and L*H:  dT_j / dC1_j = H^3 ^ (dC2_j / dC1_j) * H^2.
        Combining j = 1 and j = 2 eliminates H^3 and leaves H^2; squaring is
        a bijection in GF(2^128), so H = (H^2)^(2^127).

        Raises:
            RuntimeError: if the samples are degenerate, or the recovered
                H and S do not reproduce the tags of samples 1 and 2.
        """
        print("\n[*] Recovering GCM parameters H and S...")

        (T0, C0), (T1, C1), (T2, C2) = samples[:3]

        C10 = be_bytes_to_int(C0[:16])
        C20 = be_bytes_to_int(C0[16:])
        C11 = be_bytes_to_int(C1[:16])
        C21 = be_bytes_to_int(C1[16:])
        C12 = be_bytes_to_int(C2[:16])
        C22 = be_bytes_to_int(C2[16:])

        dT1 = be_bytes_to_int(T0) ^ be_bytes_to_int(T1)
        dC11 = C10 ^ C11
        dC21 = C20 ^ C21
        dT2 = be_bytes_to_int(T0) ^ be_bytes_to_int(T2)
        dC12 = C10 ^ C12
        dC22 = C20 ^ C22

        if dC11 == 0 or dC12 == 0:
            raise RuntimeError("Bad samples")

        alpha1 = gf_div(dC21, dC11)
        beta1 = gf_div(dT1, dC11)
        alpha2 = gf_div(dC22, dC12)
        beta2 = gf_div(dT2, dC12)

        denom = alpha1 ^ alpha2
        if denom == 0:
            raise RuntimeError("Bad samples")

        H_squared = gf_div(beta1 ^ beta2, denom)
        H = gf_pow(H_squared, 1 << 127)  # unique square root
        g0 = ghash_two_blocks(H, C10, C20)
        S = be_bytes_to_int(T0) ^ g0

        # Only beta1 ^ beta2 was used above, so matching T1 and T2 is a real
        # check that the GHASH model fits the service.
        if compute_tag(H, S, C1) != T1 or compute_tag(H, S, C2) != T2:
            raise RuntimeError("Recovered H/S do not reproduce sample tags")

        print(f"[+] H = 0x{H:032x}")
        print(f"[+] S = 0x{S:032x}")
        return H, S

    def forge_and_sign(self, H, S, ct0, msg, delta_bytes):
        """Sign `msg` with ct0 ^ delta_bytes under a freshly forged tag.

        Returns (r, s) as ints, or (None, None) if the response lacks them.
        """
        ct_forged = bytes(a ^ b for a, b in zip(ct0, delta_bytes))
        tag_forged = compute_tag(H, S, ct_forged)

        response = self.send({
            "option": "sign",
            "msg": msg.hex(),
            "signkey": (tag_forged + ct_forged).hex()
        })

        if response and 'r' in response and 's' in response:
            return int(response['r'], 16), int(response['s'], 16)
        return None, None

    def exploit_ecdsa_nonce_reuse(self, H, S, tag0, ct0):
        """Farm signatures under a forged signkey until an ECDSA nonce repeats.

        ct0 is flipped by delta (top bit of the 32-byte plaintext), which
        flips the same plaintext bit because GCM encrypts in CTR mode. Two
        signatures with equal r and different hashes z, z2 give
        k = (z - z2) / (s - s2) and d = (s*k - z) / r (mod n), where d is the
        flipped key; XOR-ing delta back yields the real key. NOTE(review):
        assumes equal r means equal k (k and n - k also share r) and that the
        service signs SHA-256 hashes on P-256. `tag0` is unused.

        Raises:
            RuntimeError: if no usable nonce reuse is found.
        """
        print("\n[*] Farming ECDSA signatures for nonce reuse...")

        n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551
        delta = (1 << (8*32 - 1))
        delta_bytes = int_to_be_bytes(delta, 32)

        bucket = {}
        rnd = random.Random(1337)

        for i in range(1, 100000):
            if i % 500 == 0:
                print(f"    Attempt {i}...", end='\r')

            msg = rnd.randbytes(24)
            z = int(hashlib.sha256(msg).hexdigest(), 16)

            r, s = self.forge_and_sign(H, S, ct0, msg, delta_bytes)
            if r is None:
                continue

            for z2, s2 in bucket.get(r, ()):
                # Equal hashes or equal s carry no information (and s == s2
                # would make the inverse below undefined).
                if (z - z2) % n == 0 or (s - s2) % n == 0:
                    continue
                print(f"\n[+] Found nonce reuse at attempt {i}!")
                kinv = pow((s - s2) % n, -1, n)
                k = ((z - z2) * kinv) % n
                rinv = pow(r, -1, n)
                sk_xor_delta = ((s * k - z) * rinv) % n
                return sk_xor_delta ^ delta

            bucket.setdefault(r, []).append((z, s))

        raise RuntimeError("No nonce reuse found")

    def decrypt_flag(self, sk):
        """Fetch the encrypted flag and decrypt it with the recovered key.

        Decrypts with AES-ECB keyed by SHA-256(sk as 32 big-endian bytes)[:16]
        and strips PKCS#7 padding -- NOTE(review): this mirrors the assumed
        challenge scheme. Returns the flag string, or None if the response
        lacks a "flag" field. Requires pycryptodomex (Cryptodome).
        """
        print("\n[*] Getting encrypted flag...")

        response = self.send({"option": "get_flag"})
        if not response or 'flag' not in response:
            return None

        enc = unhexlify(response['flag'])
        key = hashlib.sha256(int_to_be_bytes(sk, 32)).digest()[:16]

        from Cryptodome.Cipher import AES
        from Cryptodome.Util.Padding import unpad

        flag = unpad(AES.new(key, AES.MODE_ECB).decrypt(enc), 16)
        return flag.decode()

    def cleanup(self):
        """Ask the service to quit; kill and reap it if that fails or hangs."""
        if not self.proc:
            return
        try:
            self.proc.stdin.write('{"option": "quit"}\n')
            self.proc.stdin.flush()
            self.proc.wait(timeout=2)
        except (OSError, ValueError, subprocess.TimeoutExpired):
            # Broken/closed pipe, or the service ignored "quit".
            self.proc.kill()
            self.proc.wait()

    def run(self):
        """Run the full attack; return the flag string or None on failure."""
        try:
            self.start()
            samples = self.collect_samples(3)

            if len(samples) < 3:
                print("[!] Not enough samples")
                return None

            tag0, ct0 = samples[0]
            H, S = self.recover_H_S(samples)

            sk = self.exploit_ecdsa_nonce_reuse(H, S, tag0, ct0)
            print(f"\n[+] Recovered private key: 0x{sk:x}")

            flag = self.decrypt_flag(sk)
            return flag

        finally:
            self.cleanup()

def main():
    """Entry point: print a banner, run the exploit, and report the flag.

    For local testing a placeholder secret.py is created first, but only if
    none exists yet, so an existing secret.py is never overwritten.
    """
    print("="*70)
    print(" "*20 + "FL1PPER ZER0 EXPLOIT")
    print("="*70)
    print()

    # Create secret.py for the local service; mode "x" refuses to overwrite.
    try:
        with open('secret.py', 'x', encoding='utf-8') as f:
            f.write('FLAG = "flag{GCM_n0nc3_r3us3_1s_fatal!}"\n')
    except FileExistsError:
        pass  # keep the existing secret.py untouched
    except OSError as e:
        # Best effort: report the problem and continue.
        print(f"[!] Could not create secret.py: {e}")

    exploit = Exploit()
    flag = exploit.run()

    if flag:
        print("\n" + "="*70)
        print(f"🚩 FLAG: {flag}")
        print("="*70)
    else:
        print("\n[!] Exploit failed")

if __name__ == "__main__":
    main()
