import struct

_PCAP_MAGICS = (0xA1B2C3D4, 0xA1B23C4D)  # classic PCAP: microsecond / nanosecond timestamps


def _pcap_byte_order(global_header):
    """Return the struct byte-order prefix ('<' or '>') for a PCAP global header.

    The capture writer stores the magic number in its own byte order, so the
    interpretation that yields a known magic tells us how to read every other
    header field in the file.

    Raises:
        ValueError: if the magic is not a classic PCAP magic (e.g. pcapng).
    """
    for prefix in ('<', '>'):
        (magic,) = struct.unpack(prefix + 'I', global_header[:4])
        if magic in _PCAP_MAGICS:
            return prefix
    raise ValueError(
        f"unrecognised PCAP magic 0x{global_header[:4].hex()} "
        "(pcapng and other capture formats are not supported)"
    )


def _parse_ipv4(packet_data):
    """Decode IPv4 and ICMP/UDP/TCP fields from an Ethernet frame carrying IPv4.

    Expects ``len(packet_data) >= 34`` (14-byte Ethernet header plus a minimal
    20-byte IPv4 header). Returns a dict of fields to merge into the packet
    record.
    """
    ip_header = packet_data[14:34]
    ihl = (ip_header[0] & 0x0F) * 4  # IPv4 header length in bytes, options included
    protocol = ip_header[9]
    fields = {
        'ip_ttl': ip_header[8],
        'ip_protocol': protocol,
        'ip_id': struct.unpack('!H', ip_header[4:6])[0],
        'src_ip': '.'.join(map(str, ip_header[12:16])),
        'dst_ip': '.'.join(map(str, ip_header[16:20])),
        'ip_header_len': ihl,
    }
    if ihl < 20:
        # Malformed IHL: the transport offset would point back into the IP
        # header itself, so don't decode any transport fields.
        return fields

    l4 = 14 + ihl  # offset of the transport header within the frame
    if protocol == 1 and len(packet_data) >= l4 + 8:  # ICMP
        icmp_type, icmp_code, icmp_checksum = struct.unpack_from('!BBH', packet_data, l4)
        fields.update({
            'icmp_type': icmp_type,
            'icmp_code': icmp_code,
            'icmp_checksum': icmp_checksum,
        })
        # ICMP payload follows the 8-byte ICMP header.
        if len(packet_data) > l4 + 8:
            fields['icmp_data'] = packet_data[l4 + 8:]
    elif protocol == 17 and len(packet_data) >= l4 + 8:  # UDP (DNS rides on UDP)
        src_port, dst_port = struct.unpack_from('!HH', packet_data, l4)
        fields.update({'udp_src_port': src_port, 'udp_dst_port': dst_port})
        if 53 in (src_port, dst_port):
            fields['is_dns'] = True
            # DNS message follows the 8-byte UDP header.
            if len(packet_data) > l4 + 8:
                fields['dns_data'] = packet_data[l4 + 8:]
    elif protocol == 6 and len(packet_data) >= l4 + 20:  # TCP
        src_port, dst_port, seq_num = struct.unpack_from('!HHI', packet_data, l4)
        fields.update({
            'tcp_src_port': src_port,
            'tcp_dst_port': dst_port,
            'tcp_seq': seq_num,
        })
    return fields


def parse_pcap(filename):
    """Parse a classic PCAP file and extract detailed per-packet information.

    Each returned record is a dict with ``num`` (1-based position of the
    record in the capture), ``data`` (captured bytes), ``length`` (captured
    length) and ``eth_type``. IPv4 frames also get the IP/ICMP/UDP/TCP fields
    decoded by ``_parse_ipv4``.

    Little- and big-endian captures with microsecond or nanosecond timestamps
    are accepted. The link type in the global header is NOT checked: every
    frame is assumed to start with a 14-byte Ethernet header.

    Records shorter than 14 bytes are omitted from the result but still
    counted, so ``num`` stays aligned with the record's position in the file.
    A truncated trailing record ends parsing.

    Args:
        filename: Path to the capture file.

    Returns:
        List of packet dicts. It is empty when the file is too short to hold a
        global header.

    Raises:
        ValueError: if the file does not start with a classic PCAP magic.
    """
    packets = []
    with open(filename, 'rb') as f:
        global_header = f.read(24)
        if len(global_header) < 24:
            return packets
        record_header = struct.Struct(_pcap_byte_order(global_header) + 'IIII')

        packet_num = 0
        while True:
            header = f.read(record_header.size)
            if len(header) < record_header.size:
                break

            _ts_sec, _ts_usec, incl_len, _orig_len = record_header.unpack(header)
            packet_data = f.read(incl_len)
            if len(packet_data) < incl_len:
                break
            packet_num += 1

            # Too short for an Ethernet header: nothing to decode.
            if len(packet_data) < 14:
                continue

            eth_type = struct.unpack('!H', packet_data[12:14])[0]
            packet_info = {
                'num': packet_num,
                'data': packet_data,
                'length': incl_len,
                'eth_type': eth_type,
            }
            if eth_type == 0x0800 and len(packet_data) >= 34:  # IPv4
                packet_info.update(_parse_ipv4(packet_data))

            packets.append(packet_info)

    return packets

def analyze_ttl(packets):
    """Extract a potential message hidden in IPv4 TTL values.

    Every packet with an ``ip_ttl`` whose value is printable ASCII (32-126)
    contributes one character, in packet order. Each contributing packet is
    printed, followed by the assembled message.

    Args:
        packets: Packet dicts as produced by ``parse_pcap``.

    Returns:
        The first complete ``HTB{...}`` flag in the extracted text, or None.
        An ``HTB{`` without a closing brace is reported but not returned, so
        the remaining analyzers still get a chance to run.
    """
    print("=" * 80)
    print("ANALYZING IP TTL VALUES")
    print("=" * 80)

    ttl_chars = []
    for pkt in packets:
        if 'ip_ttl' not in pkt:
            continue
        ttl = pkt['ip_ttl']
        # Only printable ASCII TTLs can carry message characters.
        if 32 <= ttl <= 126:
            ttl_chars.append(chr(ttl))
            print(f"Packet {pkt['num']:3d}: TTL={ttl:3d} (0x{ttl:02x}) = '{chr(ttl)}'")

    message = ''.join(ttl_chars)
    print(f"\nExtracted message: {message}")

    start = message.find('HTB{')
    if start == -1:
        return None
    end = message.find('}', start)
    if end == -1:
        print(f"\nIncomplete flag in TTL (no closing brace): {message[start:]}")
        return None
    flag = message[start:end + 1]
    print(f"\n🚩 FLAG FOUND IN TTL: {flag}")
    return flag

def analyze_ip_id(packets):
    """Extract a potential message hidden in the IPv4 Identification field.

    For each packet carrying an ``ip_id``, the low byte and then the high
    byte are kept when printable ASCII (32-126). They are concatenated in
    packet order, and the first 20 are printed.

    Args:
        packets: Packet dicts as produced by ``parse_pcap``.

    Returns:
        The first complete ``HTB{...}`` flag in the extracted text, or None.
        An ``HTB{`` without a closing brace is reported but not returned.
    """
    print("\n" + "=" * 80)
    print("ANALYZING IP ID VALUES")
    print("=" * 80)

    ip_id_chars = []
    for pkt in packets:
        if 'ip_id' not in pkt:
            continue
        ip_id = pkt['ip_id']
        # NOTE(review): the low byte is emitted before the high byte, i.e. the
        # reverse of network byte order; confirm this matches how the message
        # was encoded if two characters share one ID.
        for byte_type, value in (('low', ip_id & 0xFF), ('high', (ip_id >> 8) & 0xFF)):
            if 32 <= value <= 126:
                ip_id_chars.append((pkt['num'], chr(value), byte_type))

    if not ip_id_chars:
        return None

    print(f"Found {len(ip_id_chars)} printable bytes in IP ID fields")
    print("First 20:")
    for num, char, byte_type in ip_id_chars[:20]:
        print(f"Packet {num:3d}: {byte_type} byte = '{char}'")

    message = ''.join(char for _, char, _ in ip_id_chars)
    start = message.find('HTB{')
    if start == -1:
        return None
    end = message.find('}', start)
    if end == -1:
        print(f"\nIncomplete flag in IP ID (no closing brace): {message[start:]}")
        return None
    flag = message[start:end + 1]
    print(f"\n🚩 FLAG FOUND IN IP ID: {flag}")
    return flag

def analyze_icmp_data(packets):
    """Extract a potential message from the first payload byte of ICMP packets.

    Only packets with an ``icmp_data`` entry are considered. The first byte of
    each payload is kept when it is printable ASCII (32-126), and the kept
    bytes are concatenated in packet order.

    Args:
        packets: Packet dicts as produced by ``parse_pcap``.

    Returns:
        The first complete ``HTB{...}`` flag in the extracted text, or None.
        An ``HTB{`` without a closing brace is reported but not returned.
    """
    print("\n" + "=" * 80)
    print("ANALYZING ICMP DATA FIELD")
    print("=" * 80)

    icmp_packets = [pkt for pkt in packets if 'icmp_data' in pkt]
    print(f"Found {len(icmp_packets)} ICMP packets with data")

    if not icmp_packets:
        return None

    # First byte of each ICMP payload; non-printable bytes are dropped.
    message = ''.join(
        chr(pkt['icmp_data'][0])
        for pkt in icmp_packets
        if pkt['icmp_data'] and 32 <= pkt['icmp_data'][0] <= 126
    )
    print(f"First byte of each ICMP data: {message[:50]}...")

    start = message.find('HTB{')
    if start == -1:
        return None
    end = message.find('}', start)
    if end == -1:
        print(f"\nIncomplete flag in ICMP DATA (no closing brace): {message[start:]}")
        return None
    flag = message[start:end + 1]
    print(f"\n🚩 FLAG FOUND IN ICMP DATA: {flag}")
    return flag

def analyze_dns_queries(packets):
    """Print a short hex preview of the DNS payloads found in *packets*.

    A packet qualifies when it is flagged ``is_dns`` and also carries
    ``dns_data``. At most the first ten qualifying packets are shown, each
    with its payload length and the hex of its first 50 bytes. The DNS
    message itself is not decoded (no query names are extracted), so this
    analyzer never reports a flag.

    Args:
        packets: Packet dicts as produced by ``parse_pcap``.

    Returns:
        Always None.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("ANALYZING DNS QUERIES")
    print(banner)

    candidates = [p for p in packets if 'dns_data' in p and p.get('is_dns')]
    print(f"Found {len(candidates)} DNS packets")

    if candidates:
        for p in candidates[:10]:
            payload = p['dns_data']
            print(f"Packet {p['num']}: DNS data length = {len(payload)} bytes")
            print(f"  Data: {payload[:50].hex()}...")

    return None

def analyze_tcp_seq(packets):
    """Extract a potential message from the low byte of TCP sequence numbers.

    Every TCP packet whose sequence-number low byte is printable ASCII
    (32-126) contributes one character, in packet order. To keep the output
    short, only hits among the first 20 TCP packets are printed individually.
    The flag search covers all TCP packets.

    Args:
        packets: Packet dicts as produced by ``parse_pcap``.

    Returns:
        The first complete ``HTB{...}`` flag in the extracted text, or None.
        An ``HTB{`` without a closing brace is reported but not returned.
    """
    print("\n" + "=" * 80)
    print("ANALYZING TCP SEQUENCE NUMBERS")
    print("=" * 80)

    tcp_packets = [pkt for pkt in packets if 'tcp_seq' in pkt]
    print(f"Found {len(tcp_packets)} TCP packets")

    if not tcp_packets:
        return None

    display_limit = 20  # per-packet lines to print; extraction is not limited
    chars = []
    for index, pkt in enumerate(tcp_packets):
        seq = pkt['tcp_seq']
        low_byte = seq & 0xFF
        if 32 <= low_byte <= 126:
            chars.append(chr(low_byte))
            if index < display_limit:
                print(f"Packet {pkt['num']:3d}: Seq={seq} low byte={low_byte} = '{chr(low_byte)}'")

    message = ''.join(chars)
    print(f"Low byte of each TCP sequence number: {message[:50]}...")

    start = message.find('HTB{')
    if start == -1:
        return None
    end = message.find('}', start)
    if end == -1:
        print(f"\nIncomplete flag in TCP SEQ (no closing brace): {message[start:]}")
        return None
    flag = message[start:end + 1]
    print(f"\n🚩 FLAG FOUND IN TCP SEQ: {flag}")
    return flag

def packet_summary(packets):
    """Print how many packets carry each IP protocol.

    Packets without an ``ip_protocol`` entry are ignored. Protocol numbers 1,
    6 and 17 are labelled ICMP, TCP and UDP; any other number is shown as
    ``Unknown(<n>)``. Lines are printed in alphabetical order of the label.

    Args:
        packets: Packet dicts as produced by ``parse_pcap``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("PACKET SUMMARY")
    print(rule)

    known_names = {1: 'ICMP', 6: 'TCP', 17: 'UDP'}
    tally = {}
    for record in filter(lambda r: 'ip_protocol' in r, packets):
        number = record['ip_protocol']
        label = known_names.get(number, f'Unknown({number})')
        tally.setdefault(label, 0)
        tally[label] += 1

    for label in sorted(tally):
        print(f"{label}: {tally[label]} packets")

if __name__ == "__main__":
    import sys

    # Capture path: first command-line argument, defaulting to the original
    # hard-coded file name so existing invocations keep working.
    pcap_file = sys.argv[1] if len(sys.argv) > 1 else "sniffed.pcap"

    print(f"Comprehensive analysis of {pcap_file}\n")
    packets = parse_pcap(pcap_file)
    print(f"Total packets: {len(packets)}\n")

    packet_summary(packets)

    # Test all common steganography locations in order, stopping at the first
    # analyzer that returns a flag (each returns the flag string or None).
    flag = None
    for analyzer in (
        analyze_ttl,
        analyze_ip_id,
        analyze_icmp_data,
        analyze_dns_queries,
        analyze_tcp_seq,
    ):
        flag = analyzer(packets)
        if flag:
            break

    if not flag:
        print("\n" + "=" * 80)
        print("❌ NO FLAG FOUND IN STANDARD LOCATIONS")
        print("=" * 80)
        print("The writeup's method (last byte of raw packet) may be the only way.")
