#!/usr/bin/env python3
"""
Ballistic Calculator Reticle View Comparison Script
Detects and compares reticle shapes: DOT, CIRCLE, CROSSHAIR
"""

import cv2
import numpy as np
import argparse
import sys
import os
from pathlib import Path
from typing import Tuple, Dict, Optional, List

# ROI configurations - larger to capture full crosshair.
# Coordinates are absolute pixel offsets into the source frame; 'normal' is the
# default used by detect_reticle_view(), 'small' is a tighter fallback crop.
ROI_CONFIGS = {
    'normal': {'x': 605, 'y': 477, 'width': 70, 'height': 70},  # Center reticle area
    'small': {'x': 620, 'y': 492, 'width': 40, 'height': 40}    # Fallback for weak signals
}

# Detection thresholds
THERMAL_NOISE_PERCENTILE = 10  # Bottom 10% of brightness = thermal noise
MIN_SIGNAL_STRENGTH = 12       # Minimum brightness after denoising

# Shape-specific thresholds (all sizes in pixels within the ROI crop)
DOT_MAX_SIZE = 15              # Dot blob should be ≤15x15 pixels
CIRCLE_MIN_SIZE = 6            # Circle should be ≥6 pixels diameter
CIRCLE_MAX_SIZE = 13           # Circle should be ≤13 pixels diameter
CIRCLE_MIN_CIRCULARITY = 0.50  # Minimum circularity for EDGE method only
CROSSHAIR_SEGMENT_MIN = 6      # Each crosshair segment ≥6 pixels
CROSSHAIR_SEGMENT_MAX = 18     # NOTE(review): not referenced by the visible detection code — confirm intent
CROSSHAIR_ASPECT_MIN = 1.4     # Line aspect ratio (relaxed for thermal fuzziness)
CROSSHAIR_LINE_THICKNESS = 8   # Max line thickness
CROSSHAIR_CENTER_GAP_MIN = 2   # Center gap should be ≥2 pixels
CROSSHAIR_CENTER_GAP_MAX = 15  # Center gap should be ≤15 pixels


def load_and_crop_roi(image_path: str, roi_name: str = 'normal') -> Tuple[np.ndarray, np.ndarray, Dict]:
    """Load image and extract reticle ROI.

    Args:
        image_path: Path to the photo on disk.
        roi_name: Key into ROI_CONFIGS ('normal' or 'small').

    Returns:
        (full_image, roi_crop, roi_config) — the full BGR image, the cropped
        ROI view into it, and the config dict used for the crop.

    Raises:
        FileNotFoundError: If image_path does not exist.
        ValueError: If roi_name is unknown, the image cannot be decoded, or
            the ROI does not fit inside the image.
    """
    # Validate the ROI name before touching the filesystem so a typo fails
    # with a clear message instead of a bare KeyError.
    if roi_name not in ROI_CONFIGS:
        raise ValueError(f"Unknown ROI name: {roi_name!r} (expected one of {sorted(ROI_CONFIGS)})")

    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread returns None (rather than raising) for unreadable files.
        raise ValueError(f"Failed to load image: {image_path}")

    height, width = img.shape[:2]
    roi_config = ROI_CONFIGS[roi_name]
    x, y, w, h = roi_config['x'], roi_config['y'], roi_config['width'], roi_config['height']

    # Guard negative offsets too: numpy slicing would silently wrap them.
    if x < 0 or y < 0 or x + w > width or y + h > height:
        raise ValueError(f"ROI out of bounds for image {image_path}")

    roi = img[y:y+h, x:x+w]

    if roi.size == 0:
        raise ValueError(f"Empty ROI extracted from {image_path}")

    return img, roi, roi_config


def denoise_thermal_image(roi: np.ndarray, debug: bool = False) -> Tuple[np.ndarray, float]:
    """
    Remove thermal noise from ROI using percentile-based approach.

    The bottom THERMAL_NOISE_PERCENTILE of pixel brightness is taken as the
    thermal noise floor and subtracted from every pixel, clipped at zero, so
    only signal above the ambient thermal background remains.

    Args:
        roi: BGR image crop containing the reticle area.
        debug: When True, print the noise floor and residual signal strength.

    Returns:
        (denoised_gray, noise_floor)
    """
    # Convert to grayscale (cv2's BGR→gray is a luminance-weighted mix, not a plain channel average)
    gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY).astype(np.float32)
    
    # Calculate thermal noise floor (bottom percentile)
    noise_floor = np.percentile(gray, THERMAL_NOISE_PERCENTILE)
    
    # Subtract noise floor and clip at zero (no negative brightness)
    denoised = np.maximum(gray - noise_floor, 0)
    
    if debug:
        signal_strength = np.max(denoised)
        print(f"    Thermal noise floor: {noise_floor:.1f}")
        print(f"    Signal strength: {signal_strength:.1f}")
    
    # Values are bounded by 255 after subtraction, so uint8 cast is lossless
    return denoised.astype(np.uint8), noise_floor


def detect_dot_reticle(denoised: np.ndarray, debug: bool = False) -> Tuple[bool, float, Dict]:
    """
    Detect DOT reticle - small bright blob or tight pixel cluster.

    Uses two strategies:
    1. Blob detection (after morphological closing)
    2. Radial clustering (for fragmented dots)

    Args:
        denoised: Grayscale uint8 ROI with the thermal noise floor removed.
        debug: When True, print per-candidate diagnostics.

    Returns:
        (is_dot, confidence, analysis_info)
    """
    analysis_info = {'shape': 'dot', 'detections': []}
    
    roi_center_y = denoised.shape[0] // 2
    roi_center_x = denoised.shape[1] // 2
    
    # Threshold to find bright regions
    _, binary = cv2.threshold(denoised, MIN_SIGNAL_STRENGTH, 255, cv2.THRESH_BINARY)
    
    # Strategy 1: Blob detection with morphological closing
    # (closing merges fragmented bright pixels into contiguous blobs)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    binary_closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
    
    contours, _ = cv2.findContours(binary_closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    if debug:
        print(f"    DOT check (blob): found {len(contours)} bright regions (after morphological closing)")
    
    # Look for small compact blob
    blob_detections = []
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        area = cv2.contourArea(cnt)
        center_x = x + w // 2
        center_y = y + h // 2
        
        # Relaxed size check to handle thermal fuzziness
        is_small = (w <= DOT_MAX_SIZE and h <= DOT_MAX_SIZE)
        is_compact = (area >= 6)
        
        # Must be near ROI center (within 60% of the half-extent)
        near_center_x = abs(center_x - roi_center_x) <= roi_center_x * 0.6
        near_center_y = abs(center_y - roi_center_y) <= roi_center_y * 0.6
        
        # REJECT elongated blobs that look like crosshair segments
        # Crosshair segments have high aspect ratio (long and thin)
        aspect_ratio = max(w, h) / min(w, h) if min(w, h) > 0 else 0
        is_elongated = aspect_ratio >= 2.0  # Line segment, not a dot
        
        if is_small and is_compact and near_center_x and near_center_y and not is_elongated:
            # Compactness: blob area relative to its bounding box (1.0 = filled box)
            bbox_area = w * h
            compactness = area / bbox_area if bbox_area > 0 else 0
            
            detection = {
                'method': 'blob',
                'x': int(x), 'y': int(y), 'w': int(w), 'h': int(h),
                'area': float(area),
                'compactness': float(compactness),
                'aspect_ratio': float(aspect_ratio)
            }
            blob_detections.append(detection)
            
            if debug:
                print(f"      Blob candidate: {w}x{h} px at ({x},{y}), area={area:.0f}, compactness={compactness:.2f}, aspect={aspect_ratio:.1f}")
        elif is_elongated and debug:
            print(f"      Rejected elongated blob: {w}x{h} px (aspect={aspect_ratio:.1f} ≥2.0, likely crosshair segment)")
    
    if blob_detections:
        analysis_info['detections'].extend(blob_detections)
    
    # Strategy 2: Radial clustering detection (uses the raw binary image,
    # not the closed one, so fragmented dots are still counted pixel-wise)
    bright_points = np.column_stack(np.where(binary > 0))
    
    if len(bright_points) >= 8:
        # Vectorized distance of every bright pixel from the ROI center
        # (replaces a per-pixel Python loop; identical math)
        dy = bright_points[:, 0] - roi_center_y
        dx = bright_points[:, 1] - roi_center_x
        distances = np.sqrt(dx * dx + dy * dy)
        
        median_dist = np.median(distances)
        std_dist = np.std(distances)
        max_dist = np.max(distances)
        percentile_95 = np.percentile(distances, 95)
        
        # A dot is a tight cluster hugging the center: small median radius,
        # few outliers, and low spread.
        is_small_cluster = (median_dist <= 5.0 and percentile_95 <= 10.0)
        is_tight = (std_dist <= 3.0)
        
        if debug:
            print(f"    DOT check (radial): {len(bright_points)} bright pixels")
            print(f"      Median distance: {median_dist:.1f} px, 95th percentile: {percentile_95:.1f} px, max: {max_dist:.1f} px")
            print(f"      Distance std: {std_dist:.1f} (tight cluster if <3.0)")
        
        if is_small_cluster and is_tight:
            # Quality scales linearly from 1.0 (zero spread) to 0.0 (std = 3.0)
            cluster_quality = 1.0 - (std_dist / 3.0)
            
            detection = {
                'method': 'radial',
                'pixel_count': len(bright_points),
                'median_dist': float(median_dist),
                'percentile_95': float(percentile_95),
                'max_dist': float(max_dist),
                'std_dist': float(std_dist),
                'cluster_quality': float(cluster_quality)
            }
            analysis_info['detections'].append(detection)
            
            if debug:
                print(f"      Radial dot: tight cluster (quality={cluster_quality:.2f})")
    
    # Verdict: radial evidence is preferred over blob evidence
    if len(analysis_info['detections']) >= 1:
        radial_dets = [d for d in analysis_info['detections'] if d.get('method') == 'radial']
        blob_dets = [d for d in analysis_info['detections'] if d.get('method') == 'blob']
        
        if radial_dets:
            det = max(radial_dets, key=lambda d: d['cluster_quality'])
            confidence = min(100.0, 60.0 + det['cluster_quality'] * 40.0)
            
            if debug:
                print(f"    → DOT detected (radial method, confidence: {confidence:.1f}%)")
        else:
            det = max(blob_dets, key=lambda d: d['compactness'])
            confidence = min(100.0, 60.0 + det['compactness'] * 40.0)
            
            if debug:
                if len(blob_dets) > 1:
                    print(f"    → DOT detected (blob method, picked best of {len(blob_dets)} candidates, confidence: {confidence:.1f}%)")
                else:
                    print(f"    → DOT detected (blob method, confidence: {confidence:.1f}%)")
        
        return True, confidence, analysis_info
    
    if debug:
        print(f"    → Not DOT (no small bright blob or tight cluster found)")
    
    return False, 0.0, analysis_info


def detect_circle_reticle(denoised: np.ndarray, debug: bool = False) -> Tuple[bool, float, Dict]:
    """
    Detect CIRCLE reticle - ring shape.
    Uses multiple detection strategies for robustness:

    1. Edge-based: Canny edges fitted with a minimum enclosing circle.
    2. Radial: bright pixels concentrated at one radius from the ROI center.
    3. Blob-based: a few small blobs lying on a common radius.

    Args:
        denoised: Grayscale uint8 ROI with the thermal noise floor removed.
        debug: When True, print per-candidate diagnostics.

    Returns:
        (is_circle, confidence, analysis_info)
    """
    analysis_info = {'shape': 'circle', 'detections': []}
    
    roi_center_y = denoised.shape[0] // 2
    roi_center_x = denoised.shape[1] // 2
    
    # Strategy 1: Edge-based circle detection
    edges = cv2.Canny(denoised, 30, 100)
    contours_edge, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    if debug:
        print(f"    CIRCLE check (edges): found {len(contours_edge)} edge contours")
    
    for cnt in contours_edge:
        area = cv2.contourArea(cnt)
        if area < 15:
            continue  # too small to be a meaningful ring contour
        
        # minEnclosingCircle needs at least 5 points for a stable fit
        if len(cnt) >= 5:
            (x, y), radius = cv2.minEnclosingCircle(cnt)
            diameter = radius * 2
            is_right_size = (CIRCLE_MIN_SIZE <= diameter <= CIRCLE_MAX_SIZE)
            
            if is_right_size:
                # Isoperimetric circularity: 1.0 for a perfect circle
                perimeter = cv2.arcLength(cnt, True)
                if perimeter > 0:
                    circularity = 4 * np.pi * area / (perimeter * perimeter)
                else:
                    circularity = 0
                
                detection = {
                    'method': 'edge',
                    'x': float(x), 'y': float(y),
                    'radius': float(radius),
                    'diameter': float(diameter),
                    'area': float(area),
                    'circularity': float(circularity)
                }
                analysis_info['detections'].append(detection)
                
                if debug:
                    print(f"      Edge candidate: diameter={diameter:.1f} px, circularity={circularity:.2f}")
    
    # Strategy 2: Radial pattern detection
    _, binary = cv2.threshold(denoised, MIN_SIGNAL_STRENGTH, 255, cv2.THRESH_BINARY)
    
    bright_points = np.column_stack(np.where(binary > 0))
    
    if len(bright_points) >= 10:
        # Vectorized distance of every bright pixel from the ROI center
        # (replaces a per-pixel Python loop; identical math)
        dy = bright_points[:, 0] - roi_center_y
        dx = bright_points[:, 1] - roi_center_x
        distances = np.sqrt(dx * dx + dy * dy)
        
        median_dist = np.median(distances)
        std_dist = np.std(distances)
        
        # Fraction of pixels within ±2 px of the median radius: a high
        # fraction means the bright pixels form a thin ring.
        near_median = np.sum((distances >= median_dist - 2) & (distances <= median_dist + 2))
        pixel_fraction = near_median / len(distances)
        
        diameter = median_dist * 2
        is_right_size = (CIRCLE_MIN_SIZE <= diameter <= CIRCLE_MAX_SIZE)
        
        if debug:
            print(f"    CIRCLE check (radial): {len(bright_points)} bright pixels")
            print(f"      Median distance from center: {median_dist:.1f} px (diameter={diameter:.1f})")
            print(f"      Distance std: {std_dist:.1f}, pixels near median: {pixel_fraction:.1%}")
        
        if is_right_size and pixel_fraction >= 0.5 and std_dist < 3.0:
            circularity = pixel_fraction * (1.0 - std_dist / 5.0)
            
            # Bonus for very clean rings (almost all pixels on the ring)
            if pixel_fraction >= 0.9:
                circularity = min(1.0, circularity * 1.2)
            
            detection = {
                'method': 'radial',
                'x': float(roi_center_x), 'y': float(roi_center_y),
                'radius': float(median_dist),
                'diameter': float(diameter),
                'circularity': float(circularity),
                'pixel_fraction': float(pixel_fraction),
                'dist_std': float(std_dist)
            }
            analysis_info['detections'].append(detection)
            
            if debug:
                print(f"      Radial circle: diameter={diameter:.1f}, circularity={circularity:.2f}")
    
    # Strategy 3: Blob-based circle detection (a fragmented ring shows up as
    # a few small blobs all roughly equidistant from the center)
    contours_blob, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    if debug:
        print(f"    CIRCLE check (blobs): found {len(contours_blob)} bright regions")
    
    if 2 <= len(contours_blob) <= 5:
        blob_centers = []
        for cnt in contours_blob:
            x, y, w, h = cv2.boundingRect(cnt)
            if w <= 4 and h <= 4:
                center_x = x + w // 2
                center_y = y + h // 2
                blob_centers.append((center_x, center_y))
        
        if len(blob_centers) >= 2:
            points = np.array(blob_centers, dtype=np.float32)
            
            # Vectorized blob-center distances from the ROI center
            deltas = points - np.array([roi_center_x, roi_center_y], dtype=np.float32)
            distances_from_roi_center = np.sqrt(np.sum(deltas * deltas, axis=1))
            
            avg_radius = np.mean(distances_from_roi_center)
            radius_std = np.std(distances_from_roi_center)
            
            diameter = avg_radius * 2
            is_right_size = (CIRCLE_MIN_SIZE <= diameter <= CIRCLE_MAX_SIZE)
            
            if debug:
                print(f"      Blob pattern: {len(blob_centers)} blobs, diameter={diameter:.1f} px, radius_std={radius_std:.2f}")
            
            if is_right_size and radius_std < 2.5:
                circularity = 1.0 - (radius_std / avg_radius) if avg_radius > 0 else 0
                
                detection = {
                    'method': 'blob',
                    'x': float(roi_center_x), 'y': float(roi_center_y),
                    'radius': float(avg_radius),
                    'diameter': float(diameter),
                    'circularity': float(circularity),
                    'blob_count': len(blob_centers)
                }
                analysis_info['detections'].append(detection)
                
                if debug:
                    print(f"      Blob circle: diameter={diameter:.1f}, circularity={circularity:.2f}")
    
    # Verdict - PRIORITIZE RADIAL METHOD, with a per-method circularity
    # threshold (radial is most trustworthy, edge is noisiest).
    if len(analysis_info['detections']) > 0:
        radial_detections = [d for d in analysis_info['detections'] if d['method'] == 'radial']
        if radial_detections:
            best = max(radial_detections, key=lambda d: d['circularity'])
        else:
            best = max(analysis_info['detections'], key=lambda d: d['circularity'])
        
        if best['method'] == 'radial':
            threshold = 0.35
        elif best['method'] == 'blob':
            threshold = 0.40
        else:
            threshold = CIRCLE_MIN_CIRCULARITY
        
        if best['circularity'] >= threshold:
            # Very clean radial rings earn a higher confidence baseline
            if best['method'] == 'radial' and best.get('pixel_fraction', 0) >= 0.9:
                confidence = min(100.0, 70.0 + best['circularity'] * 30.0)
            else:
                confidence = min(100.0, 60.0 + best['circularity'] * 40.0)
            
            if debug:
                print(f"    → CIRCLE detected ({best['method']} method, confidence: {confidence:.1f}%)")
            
            return True, confidence, analysis_info
        else:
            if debug:
                print(f"    → Not CIRCLE (best circularity {best['circularity']:.2f} < {threshold:.2f})")
    else:
        if debug:
            print(f"    → Not CIRCLE (no candidates found)")
    
    return False, 0.0, analysis_info


def detect_crosshair_reticle(denoised: np.ndarray, debug: bool = False) -> Tuple[bool, float, Dict]:
    """
    Detect CROSSHAIR reticle - 4 segments with center gap.

    Uses two strategies:
    1. 4-blob pattern: four small blobs placed symmetrically around the
       ROI center (one per quadrant direction), equidistant from their
       common center of mass.
    2. Line segments: directional morphological closing to reconstruct
       horizontal/vertical line segments, then validation of segment
       alignment and the center gap between opposing segments.

    Args:
        denoised: Grayscale uint8 ROI with the thermal noise floor removed.
        debug: When True, print per-candidate diagnostics.

    Returns:
        (is_crosshair, confidence, analysis_info)
    """
    analysis_info = {'shape': 'crosshair', 'segments': [], 'center_gap': None}
    
    _, binary = cv2.threshold(denoised, MIN_SIGNAL_STRENGTH, 255, cv2.THRESH_BINARY)
    
    contours_raw, _ = cv2.findContours(binary.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    if debug:
        print(f"    CROSSHAIR check: found {len(contours_raw)} raw bright regions")
    
    # Strategy 1: 4-blob cross pattern (only plausible with a handful of regions)
    if 3 <= len(contours_raw) <= 6:
        roi_center_y = denoised.shape[0] // 2
        roi_center_x = denoised.shape[1] // 2
        
        # Collect small blobs near the ROI center
        blobs = []
        for cnt in contours_raw:
            x, y, w, h = cv2.boundingRect(cnt)
            center_x = x + w // 2
            center_y = y + h // 2
            
            is_small = (w <= 4 and h <= 4)
            # Within 60% of the half-extent of the ROI in each axis
            near_center_x = abs(center_x - roi_center_x) <= roi_center_x * 0.6
            near_center_y = abs(center_y - roi_center_y) <= roi_center_y * 0.6
            
            if is_small and near_center_x and near_center_y:
                blob_info = {
                    'x': int(x), 'y': int(y), 'w': int(w), 'h': int(h),
                    'center_x': int(center_x), 'center_y': int(center_y)
                }
                blobs.append(blob_info)
                
                if debug:
                    print(f"      Small blob: {w}x{h} px at ({x},{y}), center=({center_x},{center_y})")
        
        if len(blobs) == 4:
            # A cross should put exactly 2 blobs on each side of both axes
            left = [b for b in blobs if b['center_x'] < roi_center_x]
            right = [b for b in blobs if b['center_x'] >= roi_center_x]
            top = [b for b in blobs if b['center_y'] < roi_center_y]
            bottom = [b for b in blobs if b['center_y'] >= roi_center_y]
            
            has_lr = len(left) == 2 and len(right) == 2
            has_tb = len(top) == 2 and len(bottom) == 2
            
            if debug:
                print(f"      Quadrant distribution: L={len(left)}, R={len(right)}, T={len(top)}, B={len(bottom)}")
            
            if has_lr and has_tb:
                centers = [(b['center_x'], b['center_y']) for b in blobs]
                
                # Center of mass of the 4 blobs = candidate crosshair center
                com_x = sum(c[0] for c in centers) / 4
                com_y = sum(c[1] for c in centers) / 4
                
                # All 4 blobs should sit at a near-constant radius from it
                distances = [((c[0] - com_x)**2 + (c[1] - com_y)**2)**0.5 for c in centers]
                avg_dist = sum(distances) / 4
                dist_variation = max(distances) - min(distances)
                
                analysis_info['blob_pattern'] = {
                    'blobs': blobs,
                    'center_of_mass': (float(com_x), float(com_y)),
                    'avg_distance': float(avg_dist),
                    'dist_variation': float(dist_variation)
                }
                
                if debug:
                    print(f"      Center of mass: ({com_x:.1f}, {com_y:.1f})")
                    print(f"      Avg distance from center: {avg_dist:.1f} ± {dist_variation:.1f}")
                
                if dist_variation <= 3:
                    # More symmetric pattern (smaller variation) scores higher
                    confidence = min(100.0, 70.0 - dist_variation * 5.0)
                    
                    if debug:
                        print(f"    → CROSSHAIR detected (4-blob pattern, confidence: {confidence:.1f}%)")
                    
                    return True, confidence, analysis_info
    
    # Strategy 2: Line segment detection.
    # Directional closing bridges small gaps along each axis separately, then
    # the two results are OR-ed so both orientations survive.
    h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 1))
    v_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 3))
    binary_h = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, h_kernel)
    binary_v = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, v_kernel)
    binary_closed = cv2.bitwise_or(binary_h, binary_v)
    
    contours, _ = cv2.findContours(binary_closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    if debug:
        print(f"    CROSSHAIR check (line segments): found {len(contours)} regions after closing")
    
    h_segments = []
    v_segments = []
    
    roi_center_y = denoised.shape[0] // 2
    roi_center_x = denoised.shape[1] // 2
    
    # Classify each region as a horizontal and/or vertical line segment.
    # A region can qualify for both orientations; both records are kept.
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        center_x = x + w // 2
        center_y = y + h // 2
        
        if debug:
            print(f"      Region at ({x},{y}): {w}x{h} px")
        
        # Horizontal: long enough, thin enough, elongated along x
        if w >= CROSSHAIR_SEGMENT_MIN and h <= CROSSHAIR_LINE_THICKNESS:
            aspect_ratio = w / h if h > 0 else 0
            
            if debug:
                print(f"        H-candidate: aspect={aspect_ratio:.1f}, threshold={CROSSHAIR_ASPECT_MIN}")
            
            if aspect_ratio >= CROSSHAIR_ASPECT_MIN:
                segment_info = {
                    'x': int(x), 'y': int(y), 'w': int(w), 'h': int(h),
                    'center_x': int(center_x), 'center_y': int(center_y),
                    'aspect_ratio': float(aspect_ratio),
                    'orientation': 'horizontal',
                    'side': 'left' if center_x < roi_center_x else 'right'
                }
                h_segments.append(segment_info)
                analysis_info['segments'].append(segment_info)
                
                if debug:
                    print(f"      ✓ H-segment ({segment_info['side']}): {w}x{h} px at ({x},{y}), aspect={aspect_ratio:.1f}")
        
        # Vertical: long enough, thin enough, elongated along y
        if h >= CROSSHAIR_SEGMENT_MIN and w <= CROSSHAIR_LINE_THICKNESS:
            aspect_ratio = h / w if w > 0 else 0
            
            if debug:
                print(f"        V-candidate: aspect={aspect_ratio:.1f}, threshold={CROSSHAIR_ASPECT_MIN}")
            
            if aspect_ratio >= CROSSHAIR_ASPECT_MIN:
                segment_info = {
                    'x': int(x), 'y': int(y), 'w': int(w), 'h': int(h),
                    'center_x': int(center_x), 'center_y': int(center_y),
                    'aspect_ratio': float(aspect_ratio),
                    'orientation': 'vertical',
                    'side': 'top' if center_y < roi_center_y else 'bottom'
                }
                v_segments.append(segment_info)
                analysis_info['segments'].append(segment_info)
                
                if debug:
                    print(f"      ✓ V-segment ({segment_info['side']}): {w}x{h} px at ({x},{y}), aspect={aspect_ratio:.1f}")
    
    # Bucket segments by which side of the crosshair they belong to
    left_segments = [s for s in h_segments if s['side'] == 'left']
    right_segments = [s for s in h_segments if s['side'] == 'right']
    top_segments = [s for s in v_segments if s['side'] == 'top']
    bottom_segments = [s for s in v_segments if s['side'] == 'bottom']
    
    has_left = len(left_segments) > 0
    has_right = len(right_segments) > 0
    has_top = len(top_segments) > 0
    has_bottom = len(bottom_segments) > 0
    
    segment_count = sum([has_left, has_right, has_top, has_bottom])
    
    if debug:
        print(f"      Segments found: L={len(left_segments)}, R={len(right_segments)}, T={len(top_segments)}, B={len(bottom_segments)} (total={segment_count}/4)")
    
    # Accept either 4/4 segments (perfect) or 3/4 segments (acceptable for thermal)
    if segment_count >= 3:
        # Pick best segment from each side (or None if missing)
        left = max(left_segments, key=lambda s: s['aspect_ratio']) if has_left else None
        right = max(right_segments, key=lambda s: s['aspect_ratio']) if has_right else None
        top = max(top_segments, key=lambda s: s['aspect_ratio']) if has_top else None
        bottom = max(bottom_segments, key=lambda s: s['aspect_ratio']) if has_bottom else None
        
        # Check alignment for segments that exist
        h_segments_present = [s for s in [left, right] if s is not None]
        v_segments_present = [s for s in [top, bottom] if s is not None]
        
        # Calculate alignment only if we have 2+ segments in each direction
        # (a single segment is trivially "aligned")
        h_aligned = True
        v_aligned = True
        
        if len(h_segments_present) >= 2:
            # Left/right segments should sit on roughly the same row
            h_y_diff = abs(h_segments_present[0]['center_y'] - h_segments_present[1]['center_y'])
            h_aligned = h_y_diff <= 8
            if debug:
                print(f"      H-alignment: Y-diff={h_y_diff} ({'OK' if h_aligned else 'FAIL'})")
        
        if len(v_segments_present) >= 2:
            # Top/bottom segments should sit on roughly the same column
            v_x_diff = abs(v_segments_present[0]['center_x'] - v_segments_present[1]['center_x'])
            v_aligned = v_x_diff <= 8
            if debug:
                print(f"      V-alignment: X-diff={v_x_diff} ({'OK' if v_aligned else 'FAIL'})")
        
        if h_aligned and v_aligned:
            # Calculate center gap (if both segments in direction exist)
            h_gap = None
            v_gap = None
            
            if left and right:
                h_gap = right['x'] - (left['x'] + left['w'])
            if top and bottom:
                v_gap = bottom['y'] - (top['y'] + top['h'])
            
            if debug:
                gap_str = []
                if h_gap is not None:
                    gap_str.append(f"H={h_gap}px")
                if v_gap is not None:
                    gap_str.append(f"V={v_gap}px")
                print(f"      Center gap: {', '.join(gap_str) if gap_str else 'N/A (missing segments)'}")
            
            # Store gap info
            if h_gap is not None or v_gap is not None:
                analysis_info['center_gap'] = {}
                if h_gap is not None:
                    analysis_info['center_gap']['horizontal'] = float(h_gap)
                if v_gap is not None:
                    analysis_info['center_gap']['vertical'] = float(v_gap)
            
            # Verify gap is in expected range (if calculable)
            h_gap_ok = (h_gap is None) or (CROSSHAIR_CENTER_GAP_MIN <= h_gap <= CROSSHAIR_CENTER_GAP_MAX)
            v_gap_ok = (v_gap is None) or (CROSSHAIR_CENTER_GAP_MIN <= v_gap <= CROSSHAIR_CENTER_GAP_MAX)
            
            # Accept if at least one gap is valid OR we only have 3 segments
            if h_gap_ok and v_gap_ok:
                # Calculate confidence
                present_segments = [s for s in [left, right, top, bottom] if s is not None]
                avg_aspect = np.mean([s['aspect_ratio'] for s in present_segments])
                
                # Penalty for missing segment
                segment_penalty = (4 - segment_count) * 10  # 10% penalty per missing segment
                
                # Alignment bonus (only if we could calculate it)
                alignment_score = 0.0
                if len(h_segments_present) >= 2 and len(v_segments_present) >= 2:
                    alignment_score = 1.0 - (h_y_diff + v_x_diff) / 30.0
                else:
                    alignment_score = 0.5  # Assume moderate alignment if can't verify
                
                # Gap bonus
                gap_count = sum([h_gap is not None and h_gap_ok, v_gap is not None and v_gap_ok])
                gap_bonus = gap_count * 10  # 10% per valid gap
                
                confidence = min(100.0, 50.0 + alignment_score * 30.0 + gap_bonus - segment_penalty)
                
                if debug:
                    print(f"    → CROSSHAIR detected ({segment_count}/4 segments, confidence: {confidence:.1f}%)")
                
                return True, confidence, analysis_info
            else:
                if debug:
                    fail_reasons = []
                    if not h_gap_ok and h_gap is not None:
                        fail_reasons.append(f"H-gap={h_gap}")
                    if not v_gap_ok and v_gap is not None:
                        fail_reasons.append(f"V-gap={v_gap}")
                    print(f"    → Not CROSSHAIR (gaps out of range: {', '.join(fail_reasons)}, expected {CROSSHAIR_CENTER_GAP_MIN}-{CROSSHAIR_CENTER_GAP_MAX})")
        else:
            if debug:
                print(f"    → Not CROSSHAIR (segments not aligned)")
    
    if debug:
        if segment_count < 3:
            missing = []
            if not has_left: missing.append("left")
            if not has_right: missing.append("right")
            if not has_top: missing.append("top")
            if not has_bottom: missing.append("bottom")
            print(f"    → Not CROSSHAIR (only {segment_count}/4 segments, missing: {', '.join(missing)})")
        else:
            print(f"    → Not CROSSHAIR ({segment_count}/4 segments found but validation failed)")
    
    return False, 0.0, analysis_info


def detect_reticle_view(image_path: str, debug: bool = False) -> Tuple[str, float, Dict]:
    """
    Detect reticle view type: DOT, CIRCLE, or CROSSHAIR.

    Loads the image, crops the 'normal' ROI, removes the thermal noise
    floor, then runs all three shape detectors and returns the match with
    the highest confidence.

    Args:
        image_path: Path to the photo on disk.
        debug: When True, print per-detector diagnostics.

    Returns:
        (view_type, confidence, analysis_info) where view_type is one of
        'dot', 'circle', 'crosshair', 'weak_signal', 'unknown', or 'error'.
    """
    try:
        full_img, roi, roi_config = load_and_crop_roi(image_path, 'normal')
        
        if debug:
            print(f"\n  Processing ROI: {roi_config['width']}x{roi_config['height']} at ({roi_config['x']}, {roi_config['y']})")
        
        denoised, noise_floor = denoise_thermal_image(roi, debug)
        
        # Bail out early when nothing rises above the noise floor — the
        # detectors would only produce spurious candidates.
        signal_strength = np.max(denoised)
        if signal_strength < MIN_SIGNAL_STRENGTH:
            if debug:
                print(f"    ⚠ Weak signal: {signal_strength:.1f} < {MIN_SIGNAL_STRENGTH}")
            return 'weak_signal', 0.0, {'error': 'weak_signal', 'signal_strength': float(signal_strength)}
        
        # Run all three shape detectors and collect every positive match
        results = []
        
        is_dot, dot_conf, dot_info = detect_dot_reticle(denoised, debug)
        if is_dot:
            results.append(('dot', dot_conf, dot_info))
        
        is_circle, circle_conf, circle_info = detect_circle_reticle(denoised, debug)
        if is_circle:
            results.append(('circle', circle_conf, circle_info))
        
        is_crosshair, cross_conf, cross_info = detect_crosshair_reticle(denoised, debug)
        if is_crosshair:
            results.append(('crosshair', cross_conf, cross_info))
        
        if len(results) == 0:
            if debug:
                print(f"    → UNKNOWN (no shape detected)")
            return 'unknown', 0.0, {'error': 'no_shape_detected'}
        
        # Highest confidence wins; keep the losers for diagnostics
        results.sort(key=lambda x: x[1], reverse=True)
        view_type, confidence, analysis_info = results[0]
        
        if len(results) > 1:
            analysis_info['ambiguous'] = True
            analysis_info['other_candidates'] = [(r[0], r[1]) for r in results[1:]]
            if debug:
                print(f"    ⚠ Ambiguous: multiple shapes detected, choosing {view_type.upper()}")
        
        return view_type, confidence, analysis_info
        
    except Exception as e:
        # Boundary handler: any failure becomes an 'error' result. Fix: the
        # traceback was previously printed unconditionally — keep it behind
        # the debug flag so non-debug runs stay quiet, matching the error print.
        if debug:
            print(f"    Error: {e}")
            import traceback
            traceback.print_exc()
        return 'error', 0.0, {'error': str(e)}


def save_debug_images(photo_name: str, roi: np.ndarray, denoised: np.ndarray, output_dir: str) -> None:
    """Save debug images for each processing stage of one photo.

    Writes the raw ROI, the denoised grayscale, the thresholded binary, the
    morphologically closed binary, and the Canny edge map into output_dir,
    each prefixed with its stage name.
    """
    os.makedirs(output_dir, exist_ok=True)

    _, binary = cv2.threshold(denoised, MIN_SIGNAL_STRENGTH, 255, cv2.THRESH_BINARY)
    closing_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

    # Stage name → image, written in pipeline order
    stages = {
        'roi': roi,
        'denoised': denoised,
        'binary': binary,
        'morph_closed': cv2.morphologyEx(binary, cv2.MORPH_CLOSE, closing_kernel),
        'edges': cv2.Canny(denoised, 30, 100),
    }
    for stage_name, image in stages.items():
        cv2.imwrite(os.path.join(output_dir, f"{stage_name}_{photo_name}"), image)


def compare_reticle_views(photo_dir: str, photo1: str, photo2: str, photo3: str,
                          expected_sequence: Optional[List[str]] = None,
                          test_persistence: bool = False,
                          debug: bool = False, save_debug: bool = False) -> Dict:
    """
    Compare reticle views across three photos.

    Runs detect_reticle_view on each photo, prints a human-readable report,
    and (on full completion) emits a single-line ``JSON_RESULT:`` summary to
    stdout for downstream (n8n) parsing.

    Verdict priority when all three photos are processed:
      1. ERROR  - any photo detected as 'error' / 'weak_signal' / 'unknown'
      2. expected_sequence given -> PASS/FAIL on exact order match
      3. test_persistence        -> PASS only if all three views are the same
      4. default (switching)     -> PASS only if all three views differ

    Args:
        photo_dir: Directory containing the three photos.
        photo1, photo2, photo3: Photo filenames within photo_dir.
        expected_sequence: Optional list of three lowercase view names
            ('dot'/'circle'/'crosshair') the photos must match in order.
        test_persistence: Save/Cancel validation mode (all views must match).
        debug: Forwarded to detection for verbose per-photo output.
        save_debug: When True, write intermediate images to a
            'debug_views' subdirectory of photo_dir.

    Returns:
        Result dict; 'verdict' is 'PASS'/'FAIL'/'ERROR', or '' if processing
        aborted before the analysis stage.
    """
    # Result skeleton; 'verdict' stays '' if we bail out early (e.g. a photo
    # fails to load), in which case no JSON_RESULT line is printed.
    results = {
        'success': False,
        'views_match_expected': False,
        'all_different': False,
        'all_same': False,
        'photos': {},
        'verdict': '',
        'details': []
    }
    
    try:
        photos = [photo1, photo2, photo3]
        
        # --- Report header: announce which of the three test modes applies ---
        print("=" * 70)
        print("RETICLE VIEW COMPARISON")
        print("=" * 70)
        print(f"\nPhoto Directory: {photo_dir}")
        
        if test_persistence:
            print(f"\nTest Mode: PERSISTENCE (Save/Cancel validation)")
            print(f"Expected: All three photos should show SAME view")
        elif expected_sequence:
            print(f"\nTest Mode: EXPECTED SEQUENCE")
            print(f"Expected: {' → '.join([s.upper() for s in expected_sequence])}")
        else:
            print(f"\nTest Mode: VIEW SWITCHING")
            print(f"Expected: All three photos should show DIFFERENT views")
        
        print(f"\nDetection Method:")
        print(f"  1. Thermal noise removal (percentile-based)")
        print(f"  2. Shape detection:")
        print(f"     - DOT: Small bright blob (≤{DOT_MAX_SIZE}x{DOT_MAX_SIZE} px) or tight cluster")
        print(f"     - CIRCLE: Ring shape ({CIRCLE_MIN_SIZE}-{CIRCLE_MAX_SIZE} px diameter)")
        print(f"     - CROSSHAIR: 3-4 segments with center gap (≥{CROSSHAIR_SEGMENT_MIN} px, aspect≥{CROSSHAIR_ASPECT_MIN})")
        print(f"  3. Confidence scoring based on shape quality")
        print()
        
        # --- Per-photo detection (photos are numbered 1..3 in the report) ---
        for i, photo_name in enumerate(photos, 1):
            photo_path = os.path.join(photo_dir, photo_name)
            print(f"Processing Photo {i}: {photo_name}")
            
            try:
                view_type, confidence, analysis_info = detect_reticle_view(photo_path, debug)
                
                results['photos'][f'photo{i}'] = {
                    'filename': photo_name,
                    'view': view_type,
                    'confidence': confidence,
                    'analysis_info': analysis_info
                }
                
                print(f"  → {view_type.upper()} (confidence: {confidence:.1f}%)")
                
                if save_debug:
                    # Re-derive the ROI/denoised images solely for the debug dump;
                    # detection above already did its own processing internally.
                    _, roi, _ = load_and_crop_roi(photo_path, 'normal')
                    denoised, _ = denoise_thermal_image(roi)
                    save_debug_images(photo_name, roi, denoised, 
                                    os.path.join(photo_dir, 'debug_views'))
                
            except Exception as e:
                # Abort on the first photo that fails; partial results are
                # returned with verdict '' and no JSON_RESULT line.
                error_msg = f"Error processing {photo_name}: {str(e)}"
                print(f"  ERROR: {error_msg}")
                results['details'].append(error_msg)
                return results
        
        print("\n" + "-" * 70)
        print("COMPARISON ANALYSIS")
        print("-" * 70)
        
        view1 = results['photos']['photo1']['view']
        view2 = results['photos']['photo2']['view']
        view3 = results['photos']['photo3']['view']
        
        detected_views = [view1, view2, view3]
        
        # Pairwise comparisons: all_same and all_different are both False
        # when exactly two views match.
        all_same = (view1 == view2 == view3)
        all_different = (view1 != view2 and view2 != view3 and view1 != view3)
        
        results['all_same'] = all_same
        results['all_different'] = all_different
        
        print(f"Detected views: {view1} / {view2} / {view3}")
        print(f"All same: {'YES' if all_same else 'NO'}")
        print(f"All different: {'YES' if all_different else 'NO'}")
        
        if expected_sequence:
            # Order matters: element-wise equality against the expected list.
            matches_expected = (detected_views == expected_sequence)
            results['views_match_expected'] = matches_expected
            
            print(f"\nExpected: {' / '.join(expected_sequence)}")
            print(f"Matches expected: {'YES' if matches_expected else 'NO'}")
            
            if not matches_expected:
                # Point out exactly which positions diverged.
                for i, (detected, expected) in enumerate(zip(detected_views, expected_sequence), 1):
                    if detected != expected:
                        print(f"  Photo {i}: expected {expected.upper()}, got {detected.upper()}")
        
        print("\n" + "-" * 70)
        print("VERDICT")
        print("-" * 70)
        
        # Detection failures trump every test mode.
        has_errors = any(v in ['error', 'weak_signal', 'unknown'] for v in detected_views)
        
        if has_errors:
            results['verdict'] = 'ERROR'
            error_views = [v for v in detected_views if v in ['error', 'weak_signal', 'unknown']]
            print(f"✗ ERROR: Failed to detect views ({', '.join(error_views)})")
            results['details'].append(f"Detection failed: {error_views}")
        
        elif expected_sequence:
            # Mode 2: explicit expected sequence (takes precedence over
            # test_persistence, matching the warning printed in main()).
            if results['views_match_expected']:
                results['verdict'] = 'PASS'
                print(f"✓ PASS: Views match expected sequence")
                results['details'].append(f"Detected: {' → '.join([v.upper() for v in detected_views])}")
            else:
                results['verdict'] = 'FAIL'
                print(f"✗ FAIL: Views do NOT match expected sequence")
                results['details'].append(f"Expected: {' → '.join([s.upper() for s in expected_sequence])}")
                results['details'].append(f"Detected: {' → '.join([v.upper() for v in detected_views])}")
        
        elif test_persistence:
            # Mode 3: Save/Cancel persistence — the view must never change.
            if all_same:
                results['verdict'] = 'PASS'
                print(f"✓ PASS: All three views are SAME ({view1.upper()})")
                print(f"  → Save/Cancel functionality working correctly")
                results['details'].append(f"Consistent view: {view1.upper()}")
                results['details'].append("Save operation preserved the view")
                results['details'].append("Cancel operation prevented unwanted changes")
            else:
                results['verdict'] = 'FAIL'
                print(f"✗ FAIL: Views are NOT consistent - Save or Cancel failed")
                
                # Attribute each transition to the operation it validates
                # (1→2 = save, 2→3 = cancel); details go to the JSON output.
                if view1 != view2:
                    results['details'].append(f"Photo 1→2 changed: {view1.upper()} → {view2.upper()}")
                    results['details'].append("⚠ SAVE operation may have failed to persist the view")
                
                if view2 != view3:
                    results['details'].append(f"Photo 2→3 changed: {view2.upper()} → {view3.upper()}")
                    results['details'].append("⚠ CANCEL operation may have failed (view changed unexpectedly)")
                
                if view1 != view3:
                    results['details'].append(f"Photo 1 vs 3: {view1.upper()} → {view3.upper()}")
                    results['details'].append("⚠ Overall state inconsistent across save/cancel cycle")
        
        else:
            # Mode 4 (default): view switching — all three must differ.
            if all_different:
                results['verdict'] = 'PASS'
                print(f"✓ PASS: All three views are different (switching confirmed)")
                results['details'].append(f"Views: {view1.upper()} / {view2.upper()} / {view3.upper()}")
            else:
                results['verdict'] = 'FAIL'
                print(f"✗ FAIL: Views are NOT all different")
                
                # Name every coinciding pair so the failure is actionable.
                if view1 == view2:
                    results['details'].append(f"Photo 1 and 2 both show: {view1.upper()}")
                if view2 == view3:
                    results['details'].append(f"Photo 2 and 3 both show: {view2.upper()}")
                if view1 == view3:
                    results['details'].append(f"Photo 1 and 3 both show: {view1.upper()}")
        
        results['success'] = True
        print("=" * 70)
        
        # Output JSON for n8n parsing (on a single line)
        import json
        json_output = {
            'verdict': results['verdict'],
            'all_same': results['all_same'],
            'all_different': results['all_different'],
            'views': detected_views,
            'details': results['details']
        }
        print(f"\nJSON_RESULT: {json.dumps(json_output)}")
        
    except Exception as e:
        # Unexpected failure anywhere above: record it, dump the traceback,
        # and fall through to return the partially-filled results dict.
        error_msg = f"Critical error: {str(e)}"
        print(f"\nERROR: {error_msg}")
        results['details'].append(error_msg)
        import traceback
        traceback.print_exc()
    
    return results


def main():
    """Command-line entry point.

    Builds the argument parser, validates inputs, and runs the three-photo
    comparison. Returns 1 for an invalid directory or an invalid --expect
    view name; otherwise returns 0 regardless of the test verdict, so that
    n8n can parse the verdict from stdout instead of the exit code.
    """
    parser = argparse.ArgumentParser(
        description='Detect and compare reticle views in thermal images',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument('photo_dir', help='Directory containing photos')
    parser.add_argument('photo1', help='First photo filename')
    parser.add_argument('photo2', help='Second photo filename')
    parser.add_argument('photo3', help='Third photo filename')
    parser.add_argument('--expect', nargs=3, metavar=('VIEW1', 'VIEW2', 'VIEW3'),
                        help='Expected view sequence (dot/circle/crosshair)')
    parser.add_argument('--test-persistence', action='store_true',
                        help='Test save/cancel (expect all same view)')
    parser.add_argument('--debug', action='store_true', help='Enable debug output')
    parser.add_argument('--save-debug', action='store_true', help='Save debug images')
    args = parser.parse_args()

    # Guard: the photo directory must exist before any processing starts.
    if not os.path.isdir(args.photo_dir):
        print(f"Error: Directory not found: {args.photo_dir}", file=sys.stderr)
        return 1

    expected_sequence = None
    if args.expect:
        valid_views = ['dot', 'circle', 'crosshair']
        # Normalize case so --expect DOT Circle crosshair all work.
        expected_sequence = [v.lower() for v in args.expect]
        # Reject on the first unrecognized view name.
        bad_view = next((v for v in expected_sequence if v not in valid_views), None)
        if bad_view is not None:
            print(f"Error: Invalid view '{bad_view}'. Must be one of: {', '.join(valid_views)}", 
                  file=sys.stderr)
            return 1

    if args.expect and args.test_persistence:
        print("Warning: --expect takes precedence over --test-persistence", file=sys.stderr)

    compare_reticle_views(
        args.photo_dir,
        args.photo1,
        args.photo2,
        args.photo3,
        expected_sequence=expected_sequence,
        test_persistence=args.test_persistence,
        debug=args.debug,
        save_debug=args.save_debug
    )

    # Always exit with 0 for n8n compatibility
    # n8n can parse the verdict from stdout
    return 0


# Script entry point: propagate main()'s return value as the process exit code
# (main() deliberately returns 0 even on FAIL verdicts for n8n compatibility).
if __name__ == '__main__':
    sys.exit(main())