#!/usr/bin/env python3
"""
Extract and decode the dot pattern from the detected region
"""
from PIL import Image, ImageDraw
import numpy as np

def _load_grayscale(image_path):
    """Load *image_path* and return its pixels as a 2-D numpy array.

    Multi-band (RGB, RGBA, LA, ...), palette ('P') and bilevel ('1') images
    are converted to 8-bit grayscale ('L') first: ``np.where`` on a 3-D array
    yields three index arrays that cannot be unpacked into (y, x), and
    palette indices / booleans make an 8-bit brightness threshold
    meaningless. Other single-band images are returned unconverted. The
    underlying file is closed before this function returns.
    """
    with Image.open(image_path) as img:
        needs_gray = img.mode in ('1', 'P') or len(img.getbands()) > 1
        return np.array(img.convert('L') if needs_gray else img)


def _save_zoomed_visualization(region, xs, ys, out_path, zoom):
    """Write *region* upscaled *zoom* times with each (xs[i], ys[i]) marked.

    Each marker is a red ellipse with a yellow outline, radius *zoom*,
    centred on the top-left corner of the corresponding zoomed pixel.
    """
    height, width = region.shape[:2]
    canvas = Image.fromarray(region).convert('RGB')
    canvas = canvas.resize((width * zoom, height * zoom), Image.NEAREST)
    draw = ImageDraw.Draw(canvas)
    for x, y in zip(xs, ys):
        cx, cy = x * zoom, y * zoom
        draw.ellipse([cx - zoom, cy - zoom, cx + zoom, cy + zoom],
                     fill='red', outline='yellow', width=2)
    canvas.save(out_path)


def _median_gap(coords, lo=5, hi=100):
    """Return the median gap between consecutive sorted *coords*, or None.

    Only gaps strictly between *lo* and *hi* are considered, so steps
    between neighbouring pixels and very large jumps are ignored. Returns
    None when no gap qualifies (including when there are fewer than two
    coordinates).
    """
    gaps = np.diff(np.sort(coords))
    gaps = gaps[(gaps > lo) & (gaps < hi)]
    return float(np.median(gaps)) if gaps.size else None


def _sample_grid(xs, ys, origin, spacing, shape):
    """Return an int 0/1 array of *shape* marking occupied lattice points.

    Lattice point (row, col) lies at
    ``(origin_x + col * x_spacing, origin_y + row * y_spacing)``; it is set
    to 1 when some (xs[i], ys[i]) is strictly closer to it than
    ``max(x_spacing, y_spacing) / 2.5``.
    """
    x0, y0 = origin
    x_step, y_step = spacing
    tolerance = max(x_step, y_step) / 2.5
    grid = np.zeros(shape, dtype=int)
    for row, col in np.ndindex(*shape):
        dist = np.hypot(xs - (x0 + col * x_step), ys - (y0 + row * y_step))
        if dist.min() < tolerance:
            grid[row, col] = 1
    return grid


def _print_grid(grid):
    """Print *grid* under a 1-based column header aligned with its cells."""
    print("\n" + "="*70)
    print("EXTRACTED GRID (1=dot present, 0=no dot):")
    print("="*70)
    # Pad the header by the width of the "Row N: " label so every column
    # number sits directly above its 2-character " 0"/" 1" cell.
    header_pad = " " * len("Row 0: ")
    print(header_pad + "".join(f"{i:2}" for i in range(1, grid.shape[1] + 1)))
    for i, row in enumerate(grid):
        print(f"Row {i}: " + "".join(f" {val}" for val in row))


def _decode_column(grid, col_idx):
    """Decode column *col_idx* of *grid* as an unsigned Python int.

    Row 0 is skipped; the remaining rows are read top to bottom, MSB first
    (for an 8-row grid that is a 7-bit value in 0..127).
    """
    bits = grid[1:, col_idx]
    weights = 2 ** np.arange(bits.size - 1, -1, -1)  # ..., 4, 2, 1
    return int(bits @ weights)


def decode_printer_dots(image_path='yd_Zard_uncompressed_dots.png', *,
                        bbox=(3400, 4020, 1000, 1640), threshold=200,
                        viz_path='dots_region_zoom.png', zoom=3):
    """Extract and decode the printer tracking dots.

    The function runs in these steps:

    1. Crop *bbox* out of the image.
    2. Treat every pixel brighter than *threshold* as part of a dot.
    3. Save a zoomed visualization.
    4. Fit a 15x8 lattice to those pixels.
    5. Decode the date/time and serial columns, printing progress along
       the way.

    Args:
        image_path: Image containing the isolated dot pattern.
        bbox: ``(x_min, x_max, y_min, y_max)`` crop, half-open like a slice.
            The default pads the detected dot region (X 3410-4009,
            Y 1013-1623) by a few pixels on each side.
        threshold: Pixels with a value strictly greater than this are dots.
        viz_path: Output path for the zoomed visualization.
        zoom: Upscale factor for the visualization.

    Returns:
        A dict with keys 'year' (str, '20YY'), 'month', 'day', 'hour' and
        'minutes' (int), and 'serial' (str). Returns None if the lattice
        spacing could not be estimated.
    """
    grid_shape = (8, 15)  # rows x columns
    # 0-indexed grid columns (the spec numbers columns from 1).
    minutes_col, hour_col, day_col, month_col, year_col = 1, 3, 4, 5, 6
    serial_cols = (9, 10, 11, 12)  # spec columns 10-13

    print("Loading dots image...")
    arr = _load_grayscale(image_path)

    x_min, x_max, y_min, y_max = bbox
    region = arr[y_min:y_max, x_min:x_max]
    print(f"Cropped region size: {region.shape}")

    # NOTE: these are individual above-threshold *pixels*, not clustered
    # dots; a dot several pixels wide contributes several entries.
    y_coords, x_coords = np.where(region > threshold)
    print(f"Found {len(x_coords)} dots in region")

    _save_zoomed_visualization(region, x_coords, y_coords, viz_path, zoom)
    print(f"Saved visualization to: {viz_path}")

    x_spacing = _median_gap(x_coords)
    y_spacing = _median_gap(y_coords)
    if x_spacing is None or y_spacing is None:
        print("Could not determine grid spacing")
        return None

    print(f"\nEstimated spacing: X={x_spacing:.1f}, Y={y_spacing:.1f}")

    # Top-left-most bright pixel anchors the lattice.
    start_x, start_y = x_coords.min(), y_coords.min()
    print(f"Grid origin: ({start_x}, {start_y})")

    grid = _sample_grid(x_coords, y_coords, (start_x, start_y),
                        (x_spacing, y_spacing), grid_shape)
    _print_grid(grid)

    print("\n" + "="*70)
    print("DECODING:")
    print("="*70)

    minutes = _decode_column(grid, minutes_col)
    hour = _decode_column(grid, hour_col)
    day = _decode_column(grid, day_col)
    month = _decode_column(grid, month_col)
    year = _decode_column(grid, year_col)
    serial_digits = [_decode_column(grid, col) for col in serial_cols]

    print(f"Column 2 (Minutes):  {minutes:02d}")
    print(f"Column 5 (Day):      {day:02d}")
    print(f"Column 6 (Month):    {month:02d}")
    print(f"Column 7 (Year):     {year:02d} (20{year:02d})")
    print(f"Column 4 (Hour):     {hour:02d}")
    print(f"Columns 10-13 (Serial): {serial_digits}")

    year_full = f"20{year:02d}"
    # NOTE(review): column values are joined without zero padding, so a
    # column value below 10 contributes a single character — confirm
    # against the spec whether each column should be padded to two digits.
    serial = "".join(str(d) for d in serial_digits)

    print("\n" + "="*70)
    print("FINAL RESULT:")
    print("="*70)
    print(f"Date/Time: {year_full}_{month:02d}_{day:02d}_{hour:02d}:{minutes:02d}")
    print(f"Serial Number: {serial}")
    print(f"\nFlag: uoftctf{{{year_full}_{month:02d}_{day:02d}_{hour:02d}:{minutes:02d}_{serial}}}")
    print("="*70)

    return {
        'year': year_full,
        'month': month,
        'day': day,
        'hour': hour,
        'minutes': minutes,
        'serial': serial,
    }

if __name__ == "__main__":
    # Script entry point: decode with all defaults. The returned dict (None
    # when the dot grid could not be fitted) stays bound at module level so
    # it can be inspected interactively, e.g. with `python -i`.
    result = decode_printer_dots()
