#!/usr/bin/env python3
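"""
Convert the grayscale yellow-channel analysis image into a form DEDA can
process: bright (white) dots are redrawn as yellow dots on a white background,
and variants at several brightness-percentile thresholds are written as well.
"""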
from PIL import Image
import numpy as np

def _render_yellow_dots(brightness, threshold):
    """Render pixels brighter than ``threshold`` as yellow on a white canvas.

    Args:
        brightness: 2-D array of per-pixel brightness values.
        threshold: Pixels whose brightness is strictly greater than this
            value become dots.

    Returns:
        ``(rgb, dots_mask)``: ``rgb`` is an ``(H, W, 3)`` uint8 array that is
        white (255, 255, 255) except at dots, which are yellow (255, 255, 0);
        ``dots_mask`` is the boolean ``(H, W)`` dot mask.
    """
    dots_mask = brightness > threshold
    rgb = np.full(brightness.shape + (3,), 255, dtype=np.uint8)
    rgb[dots_mask, 2] = 0  # remove blue where there is a dot: white -> yellow
    return rgb, dots_mask


def _save_yellow_dots(brightness, percentile, path):
    """Threshold ``brightness`` at ``percentile``, save the PNG to ``path``.

    Returns the number of pixels marked as dots.
    """
    threshold = np.percentile(brightness, percentile)
    rgb, dots_mask = _render_yellow_dots(brightness, threshold)
    Image.fromarray(rgb).save(path, 'PNG')
    return int(np.sum(dots_mask))


def create_yellow_dots_image(input_path="analysis_yellow_channel.png",
                             output_path="yellow_dots_for_deda.png",
                             threshold_percentile=95,
                             extra_percentiles=(90, 95, 98, 99)):
    """Convert white dots in a grayscale image to yellow dots in RGB.

    Pixels brighter than the ``threshold_percentile``-th percentile of the
    image's brightness become yellow (R=255, G=255, B=0); all other pixels
    become white. Because the comparison is strict (``>``), if more than
    ``100 - threshold_percentile`` percent of the pixels share the maximum
    brightness, the threshold equals that maximum and no dots are detected.

    Args:
        input_path: Image to read. Non-grayscale input is reduced to
            luminance before thresholding.
        output_path: Where the primary yellow-dot PNG is written.
        threshold_percentile: Percentile used for the primary output.
        extra_percentiles: Additional percentiles to try; each result is
            written to ``yellow_dots_{pct}pct.png`` in the current directory.

    Returns:
        Number of dot pixels in the primary output.

    Raises:
        Propagates errors from ``Image.open`` (e.g. a missing or unreadable
        file).
    """
    print(f"Converting {input_path} to yellow dot format...")

    # The context manager closes the file handle. convert() forces the pixel
    # data to load while the file is still open. For grayscale input the
    # luminance equals the gray value, so this matches the old channel-0 read.
    with Image.open(input_path) as img:
        brightness = np.asarray(img.convert('L'), dtype=float)

    dot_pixels = _save_yellow_dots(brightness, threshold_percentile, output_path)
    print(f"Saved to: {output_path}")
    print(f"Dots detected: {dot_pixels} pixels")

    # Also try different thresholds
    for pct in extra_percentiles:
        variant_path = f"yellow_dots_{pct}pct.png"
        count = _save_yellow_dots(brightness, pct, variant_path)
        print(f"  {pct}th percentile: {count} pixels -> {variant_path}")

    return dot_pixels

if __name__ == "__main__":
    # Script entry point: process the default input file and write the
    # default output file (plus the percentile variants).
    create_yellow_dots_image(
        input_path="analysis_yellow_channel.png",
        output_path="yellow_dots_for_deda.png",
    )
