#!/usr/bin/env python3
"""
Try to decode the dot grid using year = 2024 (encoded as the two-digit value 24) as a known constraint.
"""
from PIL import Image
import numpy as np
from itertools import permutations

# Load the dots image. Image.open is lazy and keeps the file handle open until
# closed, so use it as a context manager; the pixels are copied into a NumPy
# array before the file is released.
with Image.open('yd_Zard_uncompressed_dots.png') as img:
    arr = np.array(img)
    # The detection below unpacks np.where() into exactly (y, x), which only
    # works on a 2-D array. Collapse colour/alpha images to single-channel
    # luminance; single-channel images are used unchanged.
    if arr.ndim == 3:
        arr = np.array(img.convert('L'))

# Focus on the region with dots (half-open pixel bounds [min, max))
x_min, x_max = 3400, 4020
y_min, y_max = 1000, 1640
region = arr[y_min:y_max, x_min:x_max]

# Every pixel brighter than the threshold is treated as part of a dot.
# Note: the count printed below is of such pixels, not of distinct dots.
threshold = 200
dots_mask = region > threshold
y_coords, x_coords = np.where(dots_mask)

print(f"Working with {len(x_coords)} detected dots")
print("Constraint: Year = 2024 (encoded as 24)")
print("Binary for 24 = 0011000 (7 bits, rows 1-7)")
print()

# 24 in binary (7 bits, MSB to LSB for rows 1-7)
year_24_binary = [0, 0, 1, 1, 0, 0, 0]

# Cluster the coordinates to find actual dot positions
def cluster_positions(coords, tolerance=5):
    """Collapse nearby 1-D coordinates into representative positions.

    The coordinates are sorted and split wherever the gap between two
    neighbouring values is at least ``tolerance``. Each resulting run is
    reduced to the mean of its members, truncated to ``int``. Runs are split on
    gaps between *consecutive* values, so a single cluster may span more than
    ``tolerance`` in total.

    Args:
        coords: Sequence or array of numeric coordinates (for example one of
            the arrays returned by ``np.where``). Duplicates are allowed.
        tolerance: Neighbouring values whose gap is strictly smaller than this
            belong to the same cluster.

    Returns:
        list[int]: One position per cluster, in ascending order. An empty
        input yields an empty list.
    """
    sorted_coords = np.sort(np.asarray(coords))
    if sorted_coords.size == 0:
        return []

    # Indices where a new cluster starts: right after every gap >= tolerance.
    breaks = np.flatnonzero(np.diff(sorted_coords) >= tolerance) + 1
    return [int(np.mean(group)) for group in np.split(sorted_coords, breaks)]

# Cluster the bright-pixel coordinates into candidate column (X) and
# row (Y) centres.
x_unique = cluster_positions(x_coords)
y_unique = cluster_positions(y_coords)

print(f"Found {len(y_unique)} unique Y positions (rows)")
print(f"Found {len(x_unique)} unique X positions (columns)")
print()

# Snap every bright pixel to its nearest row/column centre; the set of
# (row, column) centre pairs that received at least one pixel is the dot grid.
dot_set = set()
for x, y in zip(x_coords, y_coords):
    x_cluster = min(x_unique, key=lambda xc: abs(xc - x))
    y_cluster = min(y_unique, key=lambda yc: abs(yc - y))
    dot_set.add((y_cluster, x_cluster))

# Try to interpret this as part of a 15x8 grid
# We need to find which of our columns could be column 7 (year)
print("Trying to match detected columns to standard 15-column format...")
print()

# Treat each detected column in turn as the year column and check whether
# its bits spell the constrained year (year_24_binary).
for col_idx, x_pos in enumerate(x_unique):
    col_dots = [1 if (y, x_pos) in dot_set else 0 for y in y_unique]

    print(f"Detected column {col_idx} (X={x_pos}): {col_dots}")

    # Row 0 is skipped; rows 1-7 carry the 7 value bits, MSB first. With only
    # 7 detected rows just 6 bits are available, and the missing least
    # significant bit is padded with 0. Columns with fewer rows are not decoded.
    if len(col_dots) < 7:
        continue

    year_bits = (col_dots[1:8] + [0] * 7)[:7]
    year_value = sum(bit * (2 ** (6 - idx)) for idx, bit in enumerate(year_bits))

    print(f"  If this is year column: bits={year_bits}, value={year_value}")

    if year_bits == year_24_binary:
        print("  *** MATCH! This could be column 7 (year)!")
print("\n" + "="*70)
print("Let's try a different approach - use more aggressive dot detection")
print("="*70)

# Percentile thresholds are computed over the non-black pixels of the crop.
# That set does not change between iterations, so extract it once.
nonzero_pixels = region[region > 0]

# np.percentile cannot give a meaningful threshold for an empty array
# (depending on the NumPy version it raises or returns NaN), so skip the
# retry entirely for an all-black crop.
if nonzero_pixels.size == 0:
    print("\nRegion contains no non-zero pixels; skipping percentile thresholds")
    threshold_pcts = []
else:
    threshold_pcts = [99, 98, 95, 90]

# Try progressively lower thresholds to find more (fainter) dot pixels
for threshold_pct in threshold_pcts:
    threshold = np.percentile(nonzero_pixels, threshold_pct)
    dots_mask = region > threshold
    y_coords_new, x_coords_new = np.where(dots_mask)

    # Only pursue thresholds that pick up more pixels than the fixed one.
    if len(x_coords_new) <= len(x_coords):
        continue

    print(f"\nTrying {threshold_pct}th percentile (threshold={threshold:.1f}): {len(x_coords_new)} dots")

    # Re-cluster with a looser tolerance, since more pixels are now included
    x_unique_new = cluster_positions(x_coords_new, tolerance=10)
    y_unique_new = cluster_positions(y_coords_new, tolerance=10)

    print(f"  Rows: {len(y_unique_new)}, Columns: {len(x_unique_new)}")

    # Need row 0 plus 7 bit rows, and enough columns to resemble the
    # 15-column layout, before attempting to decode.
    if len(y_unique_new) < 8 or len(x_unique_new) < 10:
        continue

    print("  This looks more promising! Creating grid...")

    # Snap each bright pixel to its nearest row/column centre
    dot_set_new = set()
    for x, y in zip(x_coords_new, y_coords_new):
        x_cluster = min(x_unique_new, key=lambda xc: abs(xc - x))
        y_cluster = min(y_unique_new, key=lambda yc: abs(yc - y))
        dot_set_new.add((y_cluster, x_cluster))

    # Print grid (header digits are column indices mod 10)
    print(f"\n  Grid ({len(y_unique_new)} rows x {len(x_unique_new)} cols):")
    print("       " + "".join(f"{i%10}" for i in range(len(x_unique_new))))

    for row_idx, y_pos in enumerate(y_unique_new):
        cells = "".join(
            "█" if (y_pos, x_pos) in dot_set_new else "·" for x_pos in x_unique_new
        )
        print(f"  R{row_idx:2d}: " + cells)

    # The size check above guarantees at least 8 rows, so every column has
    # row 0 (skipped) plus 7 value bits in rows 1-7, MSB first.
    print("\n  Decoding (assuming first 8 rows, looking for year=24):")

    for col_idx, x_pos in enumerate(x_unique_new[:15]):
        value_bits = [1 if (y_pos, x_pos) in dot_set_new else 0
                      for y_pos in y_unique_new[1:8]]
        value = sum(bit * (2 ** (6 - idx)) for idx, bit in enumerate(value_bits))

        marker = " <-- Year?" if value == 24 else ""
        print(f"  Col {col_idx:2d}: {''.join(str(b) for b in value_bits)} = {value:3d}{marker}")
