"""
Hit Detector Processor

Detects "hit" events from accelerometer Z-axis peaks.

Algorithm:
1. Buffer recent Z-axis readings per device
2. Calculate peak and baseline (median)
3. If (peak - baseline) >= threshold and outside refractory period:
   - Emit hit with strength value
   - Clear buffer to detect next hit

Use cases:
- Drum hit detection
- Shake/tap gestures
- Impact detection
"""

import math
import time
from collections import deque
from typing import Dict, Any, Optional

from .base import Processor


class HitDetectorProcessor(Processor):
    """
    Detect per-device "hits" from accelerometer Z-axis values.

    Inputs:
      Sensor data dict with 'accelerometer': {'x': ..., 'y': ..., 'z': ...}

    Outputs on hit:
      {'device': str, 'device_id': phone_id, 'device_index': phone_index,
       'hit_strength': float, 'timestamp': ms}

      'device' is phone_id when truthy, otherwise str(phone_index), otherwise
      'unknown'.

    The hit_strength is the raw delta (peak - baseline).
    Downstream scaling to MIDI range (0-127) should be done separately.
    """

    def __init__(self, name: str = "Hit Detector", enabled: bool = False):
        super().__init__(name=name, enabled=enabled)

        # Default parameters
        self.params = {
            'buffer_length': 50,      # Number of samples to buffer
            'stride': 1,              # Process every Nth sample (1 = every sample)
            'min_threshold': 10.0,    # Minimum delta to trigger hit
            'refractory': 0.05,       # Minimum seconds between hits (50ms)
            'use_numpy': True,        # Use numpy for faster computation (if available)
            'debug': False,           # Print debug messages
        }

        # Per-device state
        self._buffers: Dict[str, deque] = {}      # Z-axis sample buffers
        self._counters: Dict[str, int] = {}       # Sample counters (for stride)
        # time.monotonic() value of each device's last hit. A missing key means
        # the device has never hit, so no refractory period applies to it.
        self._last_hit: Dict[str, float] = {}

        # Statistics
        self._total_hits = 0
        self._hits_by_device: Dict[str, int] = {}

        # NumPy is optional; the pure-Python path is used when it is missing.
        try:
            import numpy as np
            self._np = np
            self._has_numpy = True
        except ImportError:
            self._np = None
            self._has_numpy = False

    def get_ui_config(self) -> Dict[str, Any]:
        """Return UI configuration for OSC tab."""
        return {
            'name': self.name,
            'description': 'Detects hits/taps from accelerometer Z-axis peaks',
            'output_address': '/crowd/hit',
            'params': [
                {
                    'key': 'buffer_length',
                    'label': 'Buffer Length',
                    'type': 'int',
                    'default': 50,
                    'min': 10,
                    'max': 200,
                    'hint': 'Number of samples to analyze (higher = smoother but slower response)'
                },
                {
                    'key': 'stride',
                    'label': 'Stride',
                    'type': 'int',
                    'default': 1,
                    'min': 1,
                    'max': 10,
                    'hint': 'Process every Nth sample (higher = less CPU, lower sensitivity)'
                },
                {
                    'key': 'min_threshold',
                    'label': 'Min Threshold',
                    'type': 'float',
                    'default': 10.0,
                    'min': 1.0,
                    'max': 50.0,
                    'step': 0.5,
                    'hint': 'Minimum peak-baseline delta to trigger hit (m/s²)'
                },
                {
                    'key': 'refractory',
                    'label': 'Refractory Period (s)',
                    'type': 'float',
                    'default': 0.05,
                    'min': 0.01,
                    'max': 1.0,
                    'step': 0.01,
                    'hint': 'Minimum time between hits (prevents double-triggers)'
                },
                {
                    'key': 'use_numpy',
                    'label': 'Use NumPy',
                    'type': 'bool',
                    'default': True,
                    'hint': f'Use NumPy for faster computation (available: {self._has_numpy})'
                },
                {
                    'key': 'debug',
                    'label': 'Debug Output',
                    'type': 'bool',
                    'default': False,
                    'hint': 'Print debug messages to console'
                }
            ]
        }

    def process(self, data: Dict[str, Any], phone_id: Optional[str] = None,
                phone_index: Optional[int] = None,
                timestamp: Optional[int] = None) -> Optional[Dict[str, Any]]:
        """
        Process accelerometer data and detect hits.

        Args:
            data: Sensor data dict with 'accelerometer': {'x', 'y', 'z'}
            phone_id: Phone identifier (preferred device key)
            phone_index: Phone index (device key when phone_id is falsy)
            timestamp: Unix timestamp in milliseconds; when None, the current
                wall-clock time is used for the emitted event

        Returns:
            Hit event dict (see class docstring), or None when the processor
            is disabled, the input is unusable, the sample is skipped by
            stride, or no hit fired.
        """
        if not self.enabled:
            return None

        accel = data.get('accelerometer')
        if not accel or not isinstance(accel, dict):
            return None

        z = self._read_z_from_accel(accel)
        if z is None:
            return None

        # Use phone_id as device identifier (fallback to phone_index)
        device = phone_id if phone_id else str(phone_index) if phone_index is not None else 'unknown'

        buf = self._get_buffer(device)
        buf.append(float(z))
        self._counters[device] = self._counters.get(device, 0) + 1

        # Stride check - only analyse every Nth sample. Clamped to >= 1 so a
        # stride written straight into self.params cannot cause modulo by zero.
        stride = max(1, int(self.params.get('stride', 1)))
        if (self._counters[device] % stride) != 0:
            return None

        # After the append the buffer is normally non-empty; it can only be
        # empty if it was created with maxlen=0, where max()/median would fail.
        if not buf:
            return None

        peak, baseline = self._calculate_peak_baseline(buf)
        delta = peak - baseline

        min_thresh = float(self.params.get('min_threshold', 10.0))
        refractory = float(self.params.get('refractory', 0.05))
        debug = self.params.get('debug', False)

        if debug:
            print(f"[{self.name}] device={device} peak={peak:.3f} baseline={baseline:.3f} "
                  f"delta={delta:.3f} threshold={min_thresh}")

        # Monotonic clock for the refractory window: unlike time.time() it
        # never jumps backwards (NTP sync, manual clock changes), which would
        # otherwise suppress hits until the wall clock caught up.
        now = time.monotonic()
        last_hit_t = self._last_hit.get(device)
        in_refractory = last_hit_t is not None and (now - last_hit_t) < refractory

        # Positive form on purpose: a NaN delta/threshold must never fire.
        if delta >= min_thresh and not in_refractory:
            # HIT DETECTED!
            self._last_hit[device] = now
            self._total_hits += 1
            self._hits_by_device[device] = self._hits_by_device.get(device, 0) + 1

            wall_now = time.time()
            if debug:
                print(f"[{self.name}] HIT → device={device} strength={delta:.3f} "
                      f"time={wall_now:.3f} (total hits: {self._total_hits})")

            # Clear buffer so the next hit is measured against fresh samples
            buf.clear()
            self._counters[device] = 0

            return {
                'device': device,
                'device_id': phone_id,
                'device_index': phone_index,
                'hit_strength': float(delta),
                # `is not None` so a caller-supplied timestamp of 0 is kept
                'timestamp': timestamp if timestamp is not None else int(wall_now * 1000)
            }

        return None

    @staticmethod
    def _to_finite_float(value: Any) -> Optional[float]:
        """Convert value to float; return None if unconvertible, NaN or infinite."""
        try:
            number = float(value)
        except (ValueError, TypeError, OverflowError):
            return None
        return number if math.isfinite(number) else None

    @staticmethod
    def _to_bool(value: Any) -> bool:
        """Coerce value to bool, treating strings like 'false'/'0'/'off' as False."""
        if isinstance(value, str):
            return value.strip().lower() not in ('', '0', 'false', 'no', 'off')
        return bool(value)

    def _read_z_from_accel(self, accel: dict) -> Optional[float]:
        """
        Extract Z-axis value from accelerometer dict.

        Tries multiple key names in order: 'z', 'accelerometer_z', 'accel_z',
        'acc_z'; a key whose value is not a finite number is skipped. If none
        matches and the dict holds exactly one entry, that value is used.

        Non-finite readings (NaN/inf) are rejected because a single NaN in the
        buffer would poison the max/median for a whole buffer length.

        Args:
            accel: Accelerometer data dict

        Returns:
            Z-axis value as float, or None if not found
        """
        for key in ('z', 'accelerometer_z', 'accel_z', 'acc_z'):
            if key in accel:
                value = self._to_finite_float(accel[key])
                if value is not None:
                    return value

        # If only one value in dict, assume it's Z
        if len(accel) == 1:
            return self._to_finite_float(next(iter(accel.values())))

        return None

    def _get_buffer(self, device: str) -> deque:
        """
        Get or create buffer for a device.

        A new device also gets its stride counter initialised to 0; its
        last-hit entry is left absent, meaning "never hit".

        Args:
            device: Device identifier

        Returns:
            Deque buffer with maxlen set to buffer_length
        """
        if device not in self._buffers:
            buffer_length = int(self.params.get('buffer_length', 50))
            self._buffers[device] = deque(maxlen=buffer_length)
            self._counters[device] = 0

        return self._buffers[device]

    def _calculate_peak_baseline(self, buf: deque) -> tuple:
        """
        Calculate peak (max) and baseline (median) from buffer.

        Uses numpy if available and enabled, otherwise pure Python.

        Args:
            buf: Non-empty buffer of Z-axis samples

        Returns:
            (peak, baseline) tuple of floats
        """
        use_numpy = self.params.get('use_numpy', True) and self._has_numpy

        if use_numpy and self._np is not None:
            # NumPy implementation (faster)
            arr = self._np.array(buf, dtype=float)
            peak = float(self._np.max(arr))
            baseline = float(self._np.median(arr))
        else:
            # Pure Python implementation (slower but no dependencies)
            sorted_samples = sorted(buf)
            n = len(sorted_samples)
            peak = float(sorted_samples[-1])
            mid = n // 2
            if n % 2 == 1:
                baseline = float(sorted_samples[mid])
            else:
                baseline = float((sorted_samples[mid - 1] + sorted_samples[mid]) / 2.0)

        return peak, baseline

    def on_param_change(self, key: str, value: Any):
        """
        Validate, coerce and store a parameter change.

        Invalid values are ignored and leave the stored parameter unchanged.
        Integer parameters are clamped to their UI ranges. Thresholds must be
        finite and are clamped at 0. Booleans accept string forms such as
        'false'/'0'/'off'.
        """
        if key == 'buffer_length':
            try:
                new_length = max(10, min(200, int(value)))
            except (ValueError, TypeError, OverflowError):
                new_length = None
            if new_length is not None:
                self.params[key] = new_length

                # Resize existing buffers, keeping the most recent samples
                for device in list(self._buffers):
                    self._buffers[device] = deque(self._buffers[device], maxlen=new_length)

        elif key == 'stride':
            try:
                self.params[key] = max(1, min(10, int(value)))
            except (ValueError, TypeError, OverflowError):
                pass

        elif key in ('min_threshold', 'refractory'):
            number = self._to_finite_float(value)
            if number is not None:
                # delta is never negative and the refractory check uses a
                # monotonic clock, so negative values act exactly like 0.
                self.params[key] = max(0.0, number)

        elif key in ('use_numpy', 'debug'):
            self.params[key] = self._to_bool(value)

        print(f"[{self.name}] param '{key}' updated → {self.params.get(key)}")

    def reset(self):
        """Reset all buffers and counters."""
        self._buffers.clear()
        self._counters.clear()
        self._last_hit.clear()
        self._total_hits = 0
        self._hits_by_device.clear()
        print(f"[{self.name}] Reset complete")

    def get_stats(self) -> Dict[str, Any]:
        """Return processor statistics."""
        return {
            'enabled': self.enabled,
            'name': self.name,
            'total_hits': self._total_hits,
            'active_devices': len(self._buffers),
            'hits_by_device': dict(self._hits_by_device),
            'has_numpy': self._has_numpy,
            'using_numpy': bool(self.params.get('use_numpy', True) and self._has_numpy),
        }
