#!/usr/bin/env python3
"""
Parse Saleae Logic binary format based on official documentation:
https://support.saleae.com/faq/technical-faq/binary-export-format-logic-2

Binary format:
- byte[8] identifier: "<SALEAE>"
- int32 version
- int32 type (0=Digital, 1=Analog)

Digital format:
- int64 initial_state
- double begin_time
- double end_time
- int64 num_transitions
- Then pairs of: double transition_time, int64 state
"""

import struct
import sys

def parse_saleae_digital(filename):
    """Parse a Saleae Logic 2 digital binary export.

    The file stores only the channel's initial level and the *times* of its
    edges; every edge toggles the level. This function rebuilds the level
    after each edge so callers receive ``(time, state)`` pairs.

    Layout read here (little-endian): 8-byte ``<SALEAE>`` identifier,
    int32 version, int32 type, then for digital files uint32 initial_state,
    double begin_time, double end_time, uint64 num_transitions, followed by
    ``num_transitions`` doubles holding the transition times.

    Args:
        filename: Path to the exported ``.bin`` file.

    Returns:
        A list of ``(transition_time, state)`` tuples, where ``state`` is 0
        or 1 and is the channel level immediately *after* that edge; or
        ``None`` if the file is not a Saleae digital export or is truncated.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(filename, 'rb') as f:
        identifier = f.read(8)
        if identifier != b'<SALEAE>':
            print(f"[!] Not a valid Saleae file: {identifier}")
            return None

        header = f.read(8)
        if len(header) < 8:
            print("[!] Truncated Saleae header")
            return None
        version, file_type = struct.unpack('<ii', header)
        type_name = {0: 'Digital', 1: 'Analog'}.get(file_type, 'Unknown')
        print(f"[+] Saleae file version {version}, type {file_type} ({type_name})")

        if file_type != 0:
            print("[!] Not a digital file")
            return None

        # Digital header: uint32 initial_state, double begin_time,
        # double end_time, uint64 num_transitions -> 4 + 8 + 8 + 8 = 28 bytes.
        digital_header = f.read(28)
        if len(digital_header) < 28:
            print("[!] Truncated digital header")
            return None
        initial_state, begin_time, end_time, num_transitions = struct.unpack(
            '<IddQ', digital_header)

        print(f"[+] Initial state: {initial_state}")
        print(f"[+] Time range: {begin_time} to {end_time}")
        print(f"[+] Transitions: {num_transitions}")

        # Read the remainder rather than f.read(8 * n): a corrupt, huge
        # num_transitions must not trigger a giant up-front allocation.
        body = f.read()

    needed = 8 * num_transitions
    if len(body) < needed:
        print(f"[!] Truncated data: expected {num_transitions} transitions, "
              f"found {len(body) // 8}")
        return None
    times = struct.unpack_from(f'<{num_transitions}d', body)

    transitions = []
    state = initial_state & 1
    for t in times:
        state ^= 1  # every recorded edge flips the channel level
        transitions.append((t, state))
    return transitions

def extract_spi_data(channels, max_samples=10000):
    """Combine per-channel states into 4-bit values, one per transition index.

    NOTE: despite the name, this does not decode SPI frames. For each
    transition index ``i`` it looks at the ``i``-th recorded transition of
    each channel independently. The iteration stops at the shortest
    non-empty channel and examines at most ``max_samples`` indices. Bit
    ``k`` of the result is the low bit of channel ``k``'s state there. Only
    the first four channels contribute bits. The assumed wiring (typical
    MFRC522 SPI: CLK, MOSI, MISO, CS) is unconfirmed. Transition times are
    ignored, so the same index on different channels is generally not the
    same instant. Treat the output as a heuristic fingerprint, not bus data.

    Args:
        channels: Sequence of per-channel transition lists of
            ``(time, state)`` tuples; ``None`` or empty entries are skipped.
        max_samples: Maximum number of transition indices to examine.

    Returns:
        List of the non-zero combined values (each in the range 1..15).
        Empty if no channel has any transitions.
    """
    print("\n[*] Analyzing for SPI patterns...")

    populated = [ch for ch in channels if ch]
    if not populated:
        return []
    limit = min(min(len(ch) for ch in populated), max_samples)

    combined = []
    for i in range(limit):
        value = 0
        for bit, ch in enumerate(channels[:4]):
            # Every populated channel has at least `limit` entries, so
            # indexing with i < limit is always in range.
            if ch:
                value |= (ch[i][1] & 1) << bit
        if value:  # keep only non-zero values, as before
            combined.append(value)
    return combined

def find_mifare_blocks(data, block_size=16):
    """Scan every ``block_size``-long window of *data* for printable content.

    A window is reported when it has at least 4 printable-ASCII values
    (32..126) and at least 3 distinct values. Each hit is printed as hex
    plus an ASCII rendering, with non-printable values shown as ``.``.
    The default of 16 matches the original assumption that MIFARE blocks
    are 16 bytes.

    Args:
        data: Sequence of integer byte values (e.g. ``bytes`` or a list of
            ints in 0..255).
        block_size: Window length to examine.

    Returns:
        List of the offsets whose window matched (empty if *data* is
        shorter than one window).
    """
    print("\n[*] Searching for MIFARE data blocks...")

    matches = []
    # "+ 1" so the last full window, ending exactly at len(data), is checked.
    for offset in range(len(data) - block_size + 1):
        block = data[offset:offset + block_size]

        ascii_count = sum(1 for b in block if 32 <= b < 127)
        unique_bytes = len(set(block))

        if ascii_count >= 4 and unique_bytes >= 3:
            matches.append(offset)
            hex_str = ''.join(f'{b:02x}' for b in block)
            ascii_str = ''.join(chr(b) if 32 <= b < 127 else '.' for b in block)
            print(f"\n  Offset {offset:06x}: {hex_str}")
            print(f"              {ascii_str}")
    return matches

# Main
print("="*70)
print("Saleae Binary Parser for MIFARE RFID")
print("="*70)

# Directory holding digital-0.bin .. digital-3.bin. It can be overridden by
# the first command-line argument and defaults to the original "extracted".
input_dir = sys.argv[1] if len(sys.argv) > 1 else "extracted"

channels = []
for i in range(4):
    filename = f"{input_dir}/digital-{i}.bin"
    print(f"\n[*] Parsing {filename}...")
    try:
        transitions = parse_saleae_digital(filename)
    except (OSError, struct.error) as exc:
        # Missing, unreadable or malformed file: report it and keep going so
        # the remaining channels are still summarized.
        print(f"[!] Could not parse {filename}: {exc}")
        transitions = None
    channels.append(transitions)

    if transitions:
        print("[+] Sample transitions:")
        for j, (t, state) in enumerate(transitions[:5]):
            print(f"    {j}: time={t:.9f}s, state={state}")

# Only combine channels when every one parsed and has at least one transition.
if all(channels):
    extracted = extract_spi_data(channels)
    print(f"\n[+] Extracted {len(extracted)} bytes from combined channels")

    # Show first bytes
    if extracted:
        print("\n[*] First 128 bytes:")
        for i in range(0, min(128, len(extracted)), 16):
            chunk = extracted[i:i+16]
            hex_str = ' '.join(f'{b:02x}' for b in chunk)
            print(f"  {i:04x}: {hex_str}")

    # NOTE(review): extract_spi_data ORs at most 4 bits per value (< 16),
    # while the block search counts printable ASCII (32..126). As written,
    # this scan is unlikely to report anything.
    find_mifare_blocks(extracted)
