# solve.py
# Exploit for the provided "Fl1pper Zer0 – Signing Service!" challenge.
# Python 3.10+. Requires: pycryptodome (imported as `Crypto`) for the final AES-ECB decryption and unpadding of the flag.

import json
import os
import re
import subprocess
import hashlib
import random
from binascii import unhexlify, hexlify

# --------- GF(2^128) helpers for GCM (polynomial: x^128 + x^7 + x^2 + x + 1) ----------
# GCM stores field elements bit-reflected: reading a 16-byte block as a
# big-endian integer, the MOST significant bit is the coefficient of x^0 and
# the least significant bit is the coefficient of x^127 (NIST SP 800-38D).
R = 0xE1000000000000000000000000000000  # x^7 + x^2 + x + 1 (= x^128 mod poly), reflected
GF_ONE = 1 << 127  # the multiplicative identity "1" in this bit order


def gf_mul(x, y):
    """Multiply two GF(2^128) elements in GCM bit order.

    Implements NIST SP 800-38D Algorithm 1: walk ``y`` from its most
    significant bit (coefficient of the indeterminate's 0th power) while ``v``
    runs through ``x`` times successive powers of the indeterminate.

    Args:
        x: first factor, a 128-bit integer (big-endian block value).
        y: second factor, same representation.
    Returns:
        The product ``x * y`` in the same representation.
    """
    z = 0
    v = x
    for i in range(128):
        if (y >> (127 - i)) & 1:
            z ^= v
        # Multiplying by the indeterminate is a RIGHT shift in reflected order;
        # the top coefficient (bit 0) falls off and is folded back in via R.
        if v & 1:
            v = (v >> 1) ^ R
        else:
            v >>= 1
    return z


def gf_square(x):
    """Return ``x * x`` in GF(2^128) (GCM bit order)."""
    return gf_mul(x, x)


def gf_pow(a, e):
    """Return ``a ** e`` in GF(2^128) by square-and-multiply.

    ``a ** 0`` is ``GF_ONE`` (the field's 1 in GCM bit order, NOT the integer 1).
    Non-positive exponents also return ``GF_ONE``.
    """
    result = GF_ONE
    base = a
    while e > 0:
        if e & 1:
            result = gf_mul(result, base)
        e >>= 1
        if e:  # skip the useless final squaring
            base = gf_square(base)
    return result


def gf_inv(x):
    """Return the multiplicative inverse of ``x`` in GF(2^128).

    Raises:
        ZeroDivisionError: if ``x`` is 0.
    """
    if x == 0:
        raise ZeroDivisionError("inv(0)")
    # The multiplicative group has order 2^128 - 1, so a^(2^128 - 2) = a^(-1).
    return gf_pow(x, (1 << 128) - 2)


def gf_div(a, b):
    """Return ``a / b`` in GF(2^128); raises ZeroDivisionError if ``b`` is 0."""
    return gf_mul(a, gf_inv(b))


def be_bytes_to_int(b):
    """Interpret the byte string ``b`` as an unsigned big-endian integer."""
    return int.from_bytes(b, byteorder="big")


def int_to_be_bytes(x, n):
    """Encode the non-negative integer ``x`` as exactly ``n`` big-endian bytes.

    Raises OverflowError when ``x`` is negative or does not fit in ``n`` bytes.
    """
    return x.to_bytes(length=n, byteorder="big")


def ghash_len_block(a_bits, c_bits):
    """Build GHASH's final length block: len(A) in the high 64 bits, len(C) in the low 64.

    Both lengths are in bits and are truncated to 64 bits each.
    """
    mask64 = (1 << 64) - 1
    high = (a_bits & mask64) << 64
    low = c_bits & mask64
    return high | low


def ghash_two_blocks(H, C1, C2, Lblk):
    """Compute GCM's GHASH_H over two ciphertext blocks with no AAD.

    Per NIST SP 800-38D: X1 = C1*H, X2 = (X1 ^ C2)*H, GHASH = (X2 ^ Lblk)*H,
    i.e. C1*H^3 ^ C2*H^2 ^ Lblk*H. All values are 128-bit integers in GCM
    bit order; ``Lblk`` is the length block from ``ghash_len_block``.
    """
    x1 = gf_mul(C1, H)
    x2 = gf_mul(x1 ^ C2, H)
    # The length block is absorbed like any other block: XOR, then multiply by H.
    return gf_mul(x2 ^ Lblk, H)

# --------- GCM tag helper given recovered H,S (no AAD, 2 blocks) ----------


def tag_for(H, S, C_bytes):
    """Compute the 16-byte tag ``S ^ GHASH_H(C)`` for a 32-byte, AAD-free ciphertext.

    ``H`` and ``S`` are the integers returned by ``recover_H_S``; the result is
    big-endian bytes ready to prepend to the ciphertext.
    """
    assert len(C_bytes) == 32
    first, second = (be_bytes_to_int(C_bytes[i:i + 16]) for i in (0, 16))
    lengths = ghash_len_block(0, 256)  # |A| = 0 bits, |C| = 256 bits
    return int_to_be_bytes(S ^ ghash_two_blocks(H, first, second, lengths), 16)


# --------- Service IO ----------
# Banner text followed (possibly on later lines) by the JSON object holding
# the initial signkey; DOTALL lets the match span line breaks.
PROMPT_RE = re.compile(
    rb"Here is your encrypted signing key.*?(\{.*\})",
    flags=re.DOTALL,
)


def launch_service():
    """Start ``python3 service.py`` with unbuffered binary pipes.

    stderr is merged into stdout so every service message arrives on one stream.
    """
    argv = ["python3", "service.py"]
    return subprocess.Popen(
        argv,
        bufsize=0,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )


def read_until_json_line(p):
    """Block until the service's signing-key banner appears, then return its JSON.

    Lines from ``p.stdout`` are accumulated until ``PROMPT_RE`` matches the
    banner text followed by a ``{...}`` object; that object is parsed and returned.

    Raises:
        RuntimeError: stdout reached EOF before the banner JSON was seen.
    """
    seen = b""
    while (chunk := p.stdout.readline()):
        seen += chunk
        hit = PROMPT_RE.search(seen)
        if hit is not None:
            return json.loads(hit.group(1).decode())
    raise RuntimeError("Service terminated unexpectedly.")


def send_cmd(p, obj):
    """Send one JSON command to the service and return its JSON-object reply.

    The command is written as a single JSON line. The service may print other
    text (menus, banners such as "Here is your *NEW* ...") before its reply, so
    lines are read and skipped until one parses as a JSON object.

    Args:
        p: process with binary ``stdin``/``stdout`` pipes (see ``launch_service``).
        obj: JSON-serialisable command, e.g. ``{"option": "sign", ...}``.
    Returns:
        The first subsequent output line that decodes to a JSON object (dict).
    Raises:
        RuntimeError: stdout reached EOF before a JSON object was read.
    """
    p.stdin.write((json.dumps(obj) + "\n").encode())
    p.stdin.flush()
    while True:
        line = p.stdout.readline()
        if not line:
            raise RuntimeError("Service terminated.")
        try:
            reply = json.loads(line.decode())
        except (UnicodeDecodeError, json.JSONDecodeError):
            continue  # menu/banner text, not the reply
        # Callers index the reply by key, so only a JSON object counts.
        if isinstance(reply, dict):
            return reply

# --------- Phase 1: collect samples (T, C) for 32-byte signkeys ----------


def collect_signkeys_2blocks(p, needed=3, max_requests=10_000):
    """Gather ``needed`` (tag, ciphertext) samples whose ciphertext is 32 bytes.

    Each hex ``signkey`` from the service is split as 16-byte tag || ciphertext.
    Only two-block (32-byte) ciphertexts are kept, because ``recover_H_S``
    works on exactly two GHASH blocks. The first candidate comes from the
    initial banner; further ones are requested with ``{"option": "generate_key"}``.

    Args:
        p: service process (see ``launch_service``).
        needed: number of 32-byte samples to return (``recover_H_S`` uses 3).
        max_requests: upper bound on ``generate_key`` requests, so a service
            that never returns a 32-byte ciphertext cannot loop forever.
    Returns:
        List of ``(tag, ciphertext)`` byte pairs, the banner sample first if usable.
    Raises:
        RuntimeError: not enough samples after ``max_requests`` requests, or
            the service exited (raised by the I/O helpers).
    """
    samples = []

    def keep_if_two_blocks(signkey_hex):
        # Split 16-byte tag || ciphertext and keep only 32-byte ciphertexts.
        raw = unhexlify(signkey_hex)
        tag, ct = raw[:16], raw[16:]
        if len(ct) == 32:
            samples.append((tag, ct))

    # The banner prints one signkey before any command is sent.
    keep_if_two_blocks(read_until_json_line(p)["signkey"])

    requests = 0
    while len(samples) < needed:
        if requests >= max_requests:
            raise RuntimeError(
                f"Only {len(samples)}/{needed} 32-byte signkeys after "
                f"{requests} generate_key requests.")
        # send_cmd already skips the non-JSON "Here is your *NEW* ..." text and
        # returns the JSON reply; reading an extra line here would consume the
        # next reply or block waiting for one.
        reply = send_cmd(p, {"option": "generate_key"})
        requests += 1
        keep_if_two_blocks(reply["signkey"])
    return samples

# --------- Phase 2: recover H and S using the two-equation trick ----------


def recover_H_S(samples):
    """Recover the GHASH key H and the tag mask S from nonce-reused GCM outputs.

    Assumes (NOTE(review): confirm against service.py) that every sample was
    produced by standard AES-GCM under one key AND one nonce, with no AAD and
    a 32-byte (two-block) ciphertext. Then each tag satisfies
        T = S ^ C1*H^3 ^ C2*H^2 ^ L*H        (L = length block, |A|=0, |C|=256)
    with the same S for every sample.

    XOR-ing sample i with sample 0 cancels S and L*H:
        dT = dC1*H^3 ^ dC2*H^2
    Dividing by dC1 (alpha = dC2/dC1, beta = dT/dC1) gives beta = H^3 ^ alpha*H^2.
    The equations for samples 1 and 2 together eliminate H^3:
        beta1 ^ beta2 = (alpha1 ^ alpha2) * H^2
    Squaring is a bijection on GF(2^128), so H is the unique square root of
    H^2, namely (H^2)^(2^127), obtained by squaring 127 times.

    Args:
        samples: sequence of (16-byte tag, 32-byte ciphertext) pairs. At least
            three are required; any beyond the third only validate the result.
    Returns:
        (H, S) as 128-bit integers in GCM bit order.
    Raises:
        ValueError: fewer than three samples.
        RuntimeError: degenerate samples, or the recovered (H, S) fail to
            reproduce every sample's tag (the GCM / nonce-reuse model is wrong).
    """
    if len(samples) < 3:
        raise ValueError(f"need at least 3 samples, got {len(samples)}")

    len_blk = ghash_len_block(0, 256)  # |A| = 0, |C| = 256 bits

    # (tag, first ciphertext block, second ciphertext block) as integers.
    parsed = [
        (be_bytes_to_int(t), be_bytes_to_int(c[:16]), be_bytes_to_int(c[16:]))
        for t, c in samples
    ]
    (t0, c10, c20), (t1, c11, c21), (t2, c12, c22) = parsed[:3]

    dT1, dC11, dC21 = t0 ^ t1, c10 ^ c11, c20 ^ c21
    dT2, dC12, dC22 = t0 ^ t2, c10 ^ c12, c20 ^ c22
    if dC11 == 0 or dC12 == 0:
        raise RuntimeError(
            "Degenerate ΔC1 encountered. Rerun to get different samples.")
    alpha1 = gf_div(dC21, dC11)
    beta1 = gf_div(dT1, dC11)
    alpha2 = gf_div(dC22, dC12)
    beta2 = gf_div(dT2, dC12)
    denom = alpha1 ^ alpha2
    if denom == 0:
        raise RuntimeError("Denominator zero; unlucky sample set. Rerun.")

    H = gf_div(beta1 ^ beta2, denom)  # this is H^2 ...
    for _ in range(127):              # ... and (H^2)^(2^127) = H
        H = gf_mul(H, H)

    def ghash(c1, c2):
        # GHASH over two ciphertext blocks plus the length block, no AAD.
        x1 = gf_mul(c1, H)
        x2 = gf_mul(x1 ^ c2, H)
        return gf_mul(x2 ^ len_blk, H)

    S = t0 ^ ghash(c10, c20)
    # Independent check: the solution must reproduce every other sample's tag.
    for t, c1, c2 in parsed[1:]:
        if S ^ ghash(c1, c2) != t:
            raise RuntimeError(
                "Recovered H/S do not reproduce the sample tags; the service "
                "may not be reusing a GCM nonce.")
    return H, S

# --------- Phase 3: chosen-CT signing oracle with forged tags ----------


def forge_sign(p, H, S, C0, T0, msg_bytes, delta_mask_bytes):
    """Ask the service to sign ``msg_bytes`` with a bit-flipped signkey.

    The ciphertext ``C0`` is XORed with ``delta_mask_bytes``, and a tag for
    the result is computed with ``tag_for`` from the recovered ``H``/``S``,
    so that (assuming standard GCM) the forged signkey passes the tag check.
    ``T0`` (the original tag) is not used.

    Returns:
        ``(r, s)`` parsed as integers from the service's hex strings.
    """
    flipped = bytes(a ^ b for a, b in zip(C0, delta_mask_bytes))
    forged_key = tag_for(H, S, flipped) + flipped
    request = {
        "option": "sign",
        "msg": msg_bytes.hex(),
        "signkey": forged_key.hex(),
    }
    reply = send_cmd(p, request)
    return int(reply["r"], 16), int(reply["s"], 16)

# --------- Phase 4: farm signatures until a repeated r (nonce reuse) under the SAME Δ ----------


def recover_sk_via_reused_k(p, H, S, tag0, ct0):
    """Recover the signing scalar by waiting for two signatures with equal ``r``.

    Every request signs a fresh pseudo-random 24-byte message with the same
    forged signkey: ``ct0`` with the top bit of its 32 bytes flipped, which
    presumably makes the service sign with ``sk ^ delta``. When two signatures
    share ``r`` but have different ``z = int(sha256(msg))``, they are treated
    as sharing the nonce k:
        k = (z1 - z2) / (s1 - s2),   sk ^ delta = (s*k - z) / r   (mod n)
    and the flip is then undone with XOR.

    NOTE(review): this only terminates if the service's ECDSA nonce actually
    repeats within ~200k signatures, and it assumes the service signs
    ``int(sha256(msg))`` over the P-256 order. Confirm both against service.py.

    Raises:
        RuntimeError: no repeated ``r`` within the attempt budget.
    """
    n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551  # NIST P-256 order
    delta_bytes = int_to_be_bytes(1 << (8 * 32 - 1), 32)  # flip only the MSB
    delta_int = int.from_bytes(delta_bytes, "big")

    seen = {}  # r -> [(z, s), ...] in arrival order
    rng = random.Random(1337)  # fixed seed: reproducible message sequence
    for _ in range(1, 200000):
        msg = rng.randbytes(24)
        z = int(hashlib.sha256(msg).hexdigest(), 16)
        r, s = forge_sign(p, H, S, ct0, tag0, msg, delta_bytes)
        earlier = seen.setdefault(r, [])
        match = next(((z2, s2) for z2, s2 in earlier if z2 != z), None)
        if match is not None:
            z2, s2 = match
            k = (z - z2) * pow((s - s2) % n, -1, n) % n
            sk_xor_delta = (s * k - z) * pow(r, -1, n) % n
            return sk_xor_delta ^ delta_int
        earlier.append((z, s))
    raise RuntimeError(
        "Did not observe a repeated r. Re-run (MT will collide eventually).")

# --------- Phase 5: decrypt flag ----------


def get_flag_and_decrypt(p, sk):
    """Fetch the encrypted flag and decrypt it with a key derived from ``sk``.

    The AES key is the first 16 bytes of SHA-256 over ``sk`` encoded as 32
    big-endian bytes. The flag is decrypted with AES-ECB and PKCS#7-unpadded.
    Works with either pycryptodome (``Crypto``) or pycryptodomex (``Cryptodome``).

    Returns:
        The flag as text; undecodable bytes are silently dropped.
    """
    # Import before any I/O so a missing dependency fails before we ask for the flag.
    try:
        from Crypto.Cipher import AES
        from Crypto.Util.Padding import unpad
    except ImportError:  # pycryptodomex installs under the ``Cryptodome`` name
        from Cryptodome.Cipher import AES
        from Cryptodome.Util.Padding import unpad
    j = send_cmd(p, {"option": "get_flag"})
    enc = unhexlify(j["flag"])
    key = hashlib.sha256(int_to_be_bytes(sk, 32)).digest()[:16]
    flag = unpad(AES.new(key, AES.MODE_ECB).decrypt(enc), 16)
    return flag.decode(errors="ignore")


def main():
    """Run the full exploit against a locally launched ``service.py`` and print the flag.

    The child process is always killed and reaped, even if a phase fails.
    """
    p = launch_service()
    try:
        # Phase 1+2: sample nonce-reused GCM outputs and recover H, S.
        samples = collect_signkeys_2blocks(p, needed=3)
        t0, c0 = samples[0]
        H, S = recover_H_S(samples)
        # Phase 3+4: forged-tag chosen-ciphertext signing plus the repeated-r trick.
        sk = recover_sk_via_reused_k(p, H, S, t0, c0)
        print(f"[+] Recovered sk = 0x{sk:x}")
        flag = get_flag_and_decrypt(p, sk)
        print(f"[+] FLAG: {flag}")
    finally:
        if p.poll() is None:
            p.kill()
        p.wait()
        p.stdin.close()
        p.stdout.close()


if __name__ == "__main__":
    main()
