"""
OSC Service

Main OSC orchestration service providing bidirectional OSC communication.

Responsibilities:
- Send outbound OSC messages (sensor data, interactions, aggregated data, events)
- Receive inbound OSC commands (color, navigate, alert, frequency, reset)
- Manage stream enable/disable toggles
- Coordinate aggregation service
- Coordinate signal processors (hit detection, filters, etc.)
- Thread-safe callback integration with Textual app
"""

import threading
from typing import Any, Dict, Optional

from pythonosc.dispatcher import Dispatcher
from pythonosc.osc_server import BlockingOSCUDPServer

from utils.osc_sender import OSCSender
from utils.osc_aggregator import OSCAggregator
from utils.osc_validators import (
    validate_color_command,
    validate_frequency,
    sanitize_string
)
from utils.processors import HitDetectorProcessor


class OSCService:
    """
    Bidirectional OSC communication service with stream control.

    Integrates with CrowdSourceApp via callbacks for:
    - Outbound: Convert Socket.IO events to OSC messages
    - Inbound: Convert OSC commands to Socket.IO actions

    Inbound handlers run on the OSC server thread; calls into the app are
    marshalled with ``app.call_from_thread``.
    """

    # Stream toggles that gate the raw /crowd/sensor/* messages.
    _RAW_SENSOR_STREAMS = (
        'sensor_all',
        'sensor_accel',
        'sensor_gyro',
        'sensor_orientation',
    )

    def __init__(self, app):
        """
        Initialize OSC service.

        Args:
            app: Reference to CrowdSourceApp instance (for callbacks)
        """
        self.app = app

        # Components
        self.sender = OSCSender('127.0.0.1', 7000)
        self.aggregator = OSCAggregator()

        # OSC Server (receiving); created by start(), torn down by stop()
        self.dispatcher = Dispatcher()
        self.server = None
        self.server_thread = None

        # Stream enable/disable toggles (all disabled by default)
        self.streams = {
            'sensor_all': False,
            'sensor_accel': False,
            'sensor_gyro': False,
            'sensor_orientation': False,
            'interaction': False,
            'phone_states': False,
            'phone_names': False,
            'votes': False,
            'totals': False,
            'phone_joined': False,
            'phone_left': False,
            'phone_count': False
        }

        # Processor Registry (all disabled by default)
        # New processors can be added here
        self.processors: Dict[str, Any] = {
            'hit_detector': HitDetectorProcessor(name="Hit Detector", enabled=False),
            # Add more processors here as needed:
            # 'smoother': SmootherProcessor(enabled=False),
            # 'scaler': ScalerProcessor(enabled=False),
        }

        # Statistics
        self.inbound_count = 0
        self.last_received = None

        # Register OSC handlers
        self._register_handlers()

    def _register_handlers(self):
        """Register inbound OSC address handlers."""
        # Register exact paths for base commands (all phones)
        self.dispatcher.map('/crowd/page/color', self._handle_color_all)
        self.dispatcher.map('/crowd/page/navigate', self._handle_navigate_all)
        self.dispatcher.map('/crowd/page/alert', self._handle_alert_all)
        self.dispatcher.map('/crowd/sensor/frequency', self._handle_frequency)
        self.dispatcher.map('/crowd/system/reset', self._handle_reset)

        # Unmatched addresses (e.g. /crowd/page/color/<identifier>) fall
        # through to the dynamic-path parser.
        self.dispatcher.set_default_handler(self._handle_dynamic_path)

    # ========================================
    # Outbound Methods (Socket.IO → OSC)
    # ========================================

    def _resolve_identifier(self, phone_id, phone_index):
        """
        Pick the identifier to put in outbound messages.

        Returns phone_index when the app's index mode is on AND an index is
        known; otherwise falls back to phone_id (so index mode never emits None).
        """
        if self.app.osc_use_device_index and phone_index is not None:
            return phone_index
        return phone_id

    def send_sensor_data(self, data: dict, phone_index: Optional[int] = None):
        """
        Send sensor OSC messages and run processors.

        OPTIMIZATION: Early exit if nothing is enabled (zero overhead when disabled).

        Sends individual sensor messages based on stream toggles:
        - /crowd/sensor/all (all sensors combined)
        - /crowd/sensor/accel (accelerometer only)
        - /crowd/sensor/gyro (gyroscope only)
        - /crowd/sensor/orientation (orientation only)

        Runs enabled processors:
        - hit_detector: Detects hits from accelerometer Z-axis peaks

        Args:
            data: Sensor data from Socket.IO event
            phone_index: Phone index (optional, for index mode)
        """
        # CRITICAL OPTIMIZATION: Early exit if nothing needs processing
        if not self._needs_sensor_processing():
            return

        phone_id = data.get('phoneId', '')
        timestamp = data.get('timestamp', 0)
        # `or {}` also covers an explicit null payload, not only a missing key.
        sensor_data = data.get('data') or {}

        identifier = self._resolve_identifier(phone_id, phone_index)

        if self._any_raw_sensor_streams_enabled():
            self._send_raw_sensor_streams(sensor_data, identifier, timestamp)

        self._run_processors(sensor_data, phone_id, phone_index, timestamp)

    def _send_raw_sensor_streams(self, sensor_data: dict, identifier, timestamp):
        """Send the enabled /crowd/sensor/* messages for one sensor event."""
        accel = sensor_data.get('accelerometer') or {}
        gyro = sensor_data.get('gyroscope')
        orientation = sensor_data.get('orientation')

        # Missing axes default to 0.0 so message arity stays fixed.
        ax = accel.get('x', 0.0)
        ay = accel.get('y', 0.0)
        az = accel.get('z', 0.0)

        # Gyroscope and orientation are optional sensors.
        g_alpha = gyro.get('alpha', 0.0) if gyro else 0.0
        g_beta = gyro.get('beta', 0.0) if gyro else 0.0
        g_gamma = gyro.get('gamma', 0.0) if gyro else 0.0

        o_alpha = orientation.get('alpha', 0.0) if orientation else 0.0
        o_beta = orientation.get('beta', 0.0) if orientation else 0.0
        o_gamma = orientation.get('gamma', 0.0) if orientation else 0.0

        if self.streams['sensor_all']:
            self.sender.send('/crowd/sensor/all', identifier, timestamp,
                             ax, ay, az, g_alpha, g_beta, g_gamma,
                             o_alpha, o_beta, o_gamma)

        if self.streams['sensor_accel']:
            self.sender.send('/crowd/sensor/accel', identifier, timestamp, ax, ay, az)

        # The dedicated gyro/orientation streams are only sent when the sensor
        # actually reported data.
        if gyro and self.streams['sensor_gyro']:
            self.sender.send('/crowd/sensor/gyro', identifier, timestamp,
                             g_alpha, g_beta, g_gamma)

        if orientation and self.streams['sensor_orientation']:
            self.sender.send('/crowd/sensor/orientation', identifier, timestamp,
                             o_alpha, o_beta, o_gamma)

    def _run_processors(self, sensor_data: dict, phone_id, phone_index, timestamp):
        """Feed one sensor event to every enabled processor and forward results."""
        for proc_id, processor in self.processors.items():
            if not processor.enabled:
                continue
            try:
                result = processor.process(
                    data=sensor_data,
                    phone_id=phone_id,
                    phone_index=phone_index,
                    timestamp=timestamp
                )

                # A falsy result means "nothing to report" for this event.
                if result:
                    self._send_processor_output(proc_id, result)

            except Exception as e:
                # Best-effort: a faulty processor must not break sensor streaming.
                print(f"[OSCService] Processor '{proc_id}' error: {e}")

    def _needs_sensor_processing(self) -> bool:
        """
        Check if ANY sensor processing is needed.

        This is the critical early-exit check for performance.
        If this returns False, we skip ALL processing (zero overhead).

        Returns:
            True if any raw streams or processors are enabled
        """
        return (self._any_raw_sensor_streams_enabled()
                or any(p.enabled for p in self.processors.values()))

    def _any_raw_sensor_streams_enabled(self) -> bool:
        """Check if any raw sensor streams are enabled."""
        return any(self.streams[name] for name in self._RAW_SENSOR_STREAMS)

    def _send_processor_output(self, proc_id: str, result: dict):
        """
        Send processor output via OSC.

        Args:
            proc_id: Processor ID (e.g., 'hit_detector')
            result: Result dict from processor (structure varies by processor)
        """
        processor = self.processors.get(proc_id)
        if not processor:
            return

        # HIT DETECTOR: /crowd/hit <identifier> <timestamp> <strength>
        if proc_id == 'hit_detector':
            identifier = self._resolve_identifier(result.get('device_id'),
                                                  result.get('device_index'))
            hit_strength = result.get('hit_strength', 0.0)
            timestamp = result.get('timestamp', 0)

            self.sender.send('/crowd/hit', identifier, timestamp, hit_strength)

        # Add more processor output handlers here as needed
        # elif proc_id == 'smoother':
        #     ...

    def send_interaction(self, data: dict, phone_index: Optional[int] = None):
        """
        Send interaction and trigger aggregated updates.

        Sends:
        - /crowd/interaction (individual button press)
        - Updates aggregator and sends aggregated data

        Args:
            data: Interaction event from Socket.IO
            phone_index: Phone index (optional, for index mode)
        """
        phone_id = data.get('phoneId', '')
        timestamp = data.get('timestamp', 0)
        # Use pageInstanceName if provided, otherwise fall back to page type
        page_instance_name = data.get('pageInstanceName') or data.get('page', '')
        button_index = data.get('buttonIndex', -1)
        button_name = data.get('buttonText', '')

        identifier = self._resolve_identifier(phone_id, phone_index)

        if self.streams['interaction']:
            self.sender.send('/crowd/interaction', identifier, timestamp,
                             page_instance_name, button_index, button_name)

        # The aggregator is updated regardless of stream toggles so its
        # counts stay correct if a stream is enabled later.
        self.aggregator.on_interaction(data)

        # Send aggregated data updates (pass pageInstanceName for totals)
        self._send_aggregated_data(page_instance_name)

    def send_phone_joined(self, phone_id: str, phone_index: int, timestamp: int, total_count: int):
        """
        Send phone joined events.

        Sends:
        - /crowd/phone/joined
        - /crowd/phone/count
        - Updated metadata for phone states

        Args:
            phone_id: Phone identifier
            phone_index: Phone index (None falls back to phone_id in index mode)
            timestamp: Unix timestamp in milliseconds
            total_count: Total number of connected phones
        """
        self.aggregator.on_phone_joined(phone_id, phone_index)

        identifier = self._resolve_identifier(phone_id, phone_index)

        if self.streams['phone_joined']:
            self.sender.send('/crowd/phone/joined', identifier, timestamp, total_count)

        if self.streams['phone_count']:
            self.sender.send('/crowd/phone/count', total_count)

        self._send_phone_states_metadata()

    def send_phone_left(self, phone_id: str, phone_index: int, timestamp: int, total_count: int):
        """
        Send phone left events.

        Sends:
        - /crowd/phone/left
        - /crowd/phone/count
        - Updated metadata and aggregated data

        Args:
            phone_id: Phone identifier
            phone_index: Phone index (None falls back to phone_id in index mode)
            timestamp: Unix timestamp in milliseconds
            total_count: Total number of connected phones
        """
        # Resolve before removing the phone from the aggregator.
        identifier = self._resolve_identifier(phone_id, phone_index)

        self.aggregator.on_phone_left(phone_id)

        if self.streams['phone_left']:
            self.sender.send('/crowd/phone/left', identifier, timestamp, total_count)

        if self.streams['phone_count']:
            self.sender.send('/crowd/phone/count', total_count)

        self._send_phone_states_metadata()
        # Pass current_page_instance_name for totals updates
        self._send_aggregated_data(self.aggregator.current_page_instance_name)

    def send_page_change(self, page: str, page_instance_id: str, config: dict):
        """
        Send metadata on page change.

        Sends:
        - /crowd/state/phones/meta
        - /crowd/votes/meta
        - /crowd/totals/meta

        Args:
            page: Page path/identifier
            page_instance_id: Unique ID for this page instance
            config: Page configuration dict
        """
        self.aggregator.on_page_change(page, page_instance_id, config)

        self._send_phone_states_metadata()
        self._send_votes_metadata()
        self._send_totals_metadata()

    def _send_aggregated_data(self, page_instance_name: Optional[str] = None):
        """
        Send current aggregated data (called after interactions).

        Args:
            page_instance_name: The pageInstanceName from the interaction (for totals)
        """
        use_index = self.app.osc_use_device_index

        if self.streams['phone_states']:
            page_id, _, _, states = self.aggregator.get_phone_states_data(use_index)
            if page_id:
                self.sender.send('/crowd/state/phones/data', page_id, *states)

        if self.streams['phone_names']:
            page_id, _, _, names = self.aggregator.get_phone_names_data(use_index)
            if page_id:
                self.sender.send('/crowd/state/phones/names', page_id, *names)

        if self.streams['votes']:
            page_id, button_count, _, counts = self.aggregator.get_votes_data()
            if page_id and button_count > 0:
                self.sender.send('/crowd/votes/data', page_id, *counts)

        # Totals are keyed by pageInstanceName
        if self.streams['totals'] and page_instance_name:
            instance_name, button_count, button_names, totals = \
                self.aggregator.get_totals_data_by_instance_name(page_instance_name)

            if instance_name and button_count > 0:
                # Metadata precedes the data on every update so receivers
                # always know the button layout. Reuse the values fetched above
                # instead of querying the aggregator a second time.
                self.sender.send('/crowd/totals/meta', instance_name, button_count, *button_names)
                self.sender.send('/crowd/totals/data', instance_name, *totals)

    def _send_phone_states_metadata(self):
        """Send phone states metadata."""
        if self.streams['phone_states']:
            use_index = self.app.osc_use_device_index
            page_id, phone_count, identifiers, _ = self.aggregator.get_phone_states_data(use_index)
            if page_id:
                self.sender.send('/crowd/state/phones/meta', page_id, phone_count, *identifiers)

    def _send_votes_metadata(self):
        """Send votes metadata."""
        if self.streams['votes']:
            page_id, button_count, button_names, _ = self.aggregator.get_votes_data()
            if page_id and button_count > 0:
                self.sender.send('/crowd/votes/meta', page_id, button_count, *button_names)

    def _send_totals_metadata(self):
        """Send totals metadata."""
        if self.streams['totals']:
            page_instance_id, button_count, button_names, _ = self.aggregator.get_totals_data()
            if page_instance_id and button_count > 0:
                self.sender.send('/crowd/totals/meta', page_instance_id, button_count, *button_names)

    def _send_totals_metadata_for_instance(self, page_instance_name: str):
        """
        Send totals metadata for a specific pageInstanceName.

        Args:
            page_instance_name: The pageInstanceName to send metadata for
        """
        if self.streams['totals']:
            instance_name, button_count, button_names, _ = \
                self.aggregator.get_totals_data_by_instance_name(page_instance_name)

            if instance_name and button_count > 0:
                self.sender.send('/crowd/totals/meta', instance_name, button_count, *button_names)

    # ========================================
    # Inbound Handlers (OSC → Socket.IO)
    # ========================================

    def _handle_color_all(self, address: str, *args):
        """Handle /crowd/page/color - set all phones to color."""
        self.inbound_count += 1
        self.last_received = f"{address} {args}"

        result = validate_color_command(list(args))
        if result and len(result) == 3:
            r, g, b = result
            self.app.call_from_thread(self.app.handle_osc_color, None, r, g, b)

    def _handle_navigate_all(self, address: str, *args):
        """Handle /crowd/page/navigate - navigate all phones."""
        self.inbound_count += 1
        self.last_received = f"{address} {args}"

        if len(args) >= 1:
            page_path = str(args[0])
            self.app.call_from_thread(self.app.handle_osc_navigate, None, page_path)

    def _handle_alert_all(self, address: str, *args):
        """Handle /crowd/page/alert - send alert to all phones."""
        self.inbound_count += 1
        self.last_received = f"{address} {args}"

        if len(args) >= 1:
            message = sanitize_string(str(args[0]))
            self.app.call_from_thread(self.app.handle_osc_alert, None, message)

    def _handle_frequency(self, address: str, *args):
        """Handle /crowd/sensor/frequency - set sensor frequency."""
        self.inbound_count += 1
        self.last_received = f"{address} {args}"

        if len(args) >= 1:
            try:
                hz = int(args[0])
            except (ValueError, TypeError, OverflowError):
                # Non-numeric, NaN or infinite input: ignore the command.
                return
            ms = validate_frequency(hz)
            if ms:
                self.app.call_from_thread(self.app.handle_osc_frequency, ms)

    def _handle_reset(self, address: str, *args):
        """Handle /crowd/system/reset - reset aggregations."""
        self.inbound_count += 1
        self.last_received = f"{address} {args}"

        # NOTE(review): this runs on the OSC server thread and mutates the
        # aggregator directly (not via call_from_thread); confirm that
        # OSCAggregator tolerates concurrent use from the app thread.
        self.aggregator.reset()

    def _handle_dynamic_path(self, address: str, *args):
        """
        Handle dynamic OSC paths with identifiers embedded in the address.

        Supports patterns like:
        - /crowd/page/color/<identifier> <r> <g> <b>
        - /crowd/page/navigate/<identifier> <page_path>
        - /crowd/page/alert/<identifier> <message>

        Args:
            address: OSC address string (e.g., "/crowd/page/color/1")
            *args: OSC arguments
        """
        # parts = ['', 'crowd', 'page', 'command', 'identifier', ...]
        parts = address.split('/')

        if len(parts) < 5:
            # Not enough parts for a dynamic command, ignore
            return

        if parts[1] != 'crowd' or parts[2] != 'page':
            # Not a crowd/page command, ignore
            return

        command = parts[3]
        device_identifier = self._parse_identifier(parts[4])

        if command == 'color' and len(args) >= 3:
            self._handle_color_with_id(device_identifier, args)
        elif command == 'navigate' and len(args) >= 1:
            self._handle_navigate_with_id(device_identifier, args)
        elif command == 'alert' and len(args) >= 1:
            self._handle_alert_with_id(device_identifier, args)
        # Unknown commands are silently ignored.

    def _parse_identifier(self, identifier_str: str):
        """
        Parse identifier from path segment.

        Tries to convert to int if numeric, otherwise keeps as string.

        Args:
            identifier_str: Identifier string from path (e.g., "1", "phone_abc")

        Returns:
            int if numeric, str otherwise
        """
        try:
            return int(identifier_str)
        except ValueError:
            return identifier_str

    def _handle_color_with_id(self, device_identifier, args):
        """Handle color command with identifier from path."""
        self.inbound_count += 1
        self.last_received = f"/crowd/page/color/{device_identifier} {args}"

        result = validate_color_command(list(args))
        if result and len(result) == 3:
            r, g, b = result
            self.app.call_from_thread(self.app.handle_osc_color, device_identifier, r, g, b)

    def _handle_navigate_with_id(self, device_identifier, args):
        """Handle navigate command with identifier from path."""
        self.inbound_count += 1
        self.last_received = f"/crowd/page/navigate/{device_identifier} {args}"

        if len(args) >= 1:
            page_path = str(args[0])
            self.app.call_from_thread(self.app.handle_osc_navigate, device_identifier, page_path)

    def _handle_alert_with_id(self, device_identifier, args):
        """Handle alert command with identifier from path."""
        self.inbound_count += 1
        self.last_received = f"/crowd/page/alert/{device_identifier} {args}"

        if len(args) >= 1:
            message = sanitize_string(str(args[0]))
            self.app.call_from_thread(self.app.handle_osc_alert, device_identifier, message)

    # ========================================
    # Processor Management Methods
    # ========================================

    def set_processor_enabled(self, proc_id: str, enabled: bool):
        """
        Enable or disable a processor (unknown IDs are ignored).

        Args:
            proc_id: Processor ID (e.g., 'hit_detector')
            enabled: True to enable, False to disable
        """
        if proc_id in self.processors:
            self.processors[proc_id].enabled = enabled
            print(f"[OSCService] Processor '{proc_id}' {'enabled' if enabled else 'disabled'}")

    def set_processor_param(self, proc_id: str, param_key: str, value):
        """
        Set a processor parameter (unknown IDs are ignored).

        Args:
            proc_id: Processor ID (e.g., 'hit_detector')
            param_key: Parameter key (e.g., 'min_threshold')
            value: New parameter value
        """
        if proc_id in self.processors:
            processor = self.processors[proc_id]
            if processor.set_param(param_key, value):
                print(f"[OSCService] Processor '{proc_id}' param '{param_key}' set to {value}")
            else:
                print(f"[OSCService] Processor '{proc_id}' has no param '{param_key}'")

    def reset_processor(self, proc_id: str):
        """
        Reset a processor's state via its reset() method (unknown IDs are ignored).

        Args:
            proc_id: Processor ID (e.g., 'hit_detector')
        """
        if proc_id in self.processors:
            self.processors[proc_id].reset()
            print(f"[OSCService] Processor '{proc_id}' reset")

    def get_processor_stats(self, proc_id: str) -> dict:
        """
        Get processor statistics.

        Args:
            proc_id: Processor ID (e.g., 'hit_detector')

        Returns:
            Dict with processor stats, or {} for an unknown processor
        """
        if proc_id in self.processors:
            return self.processors[proc_id].get_stats()
        return {}

    # ========================================
    # Lifecycle Methods
    # ========================================

    def start(self, receive_port: int = 7001):
        """
        Start the OSC server on a daemon thread (no-op if already running).

        Failures are logged, not raised; on failure any partially created
        server is closed so its port is not left bound.

        Args:
            receive_port: UDP port to listen on
        """
        if self.server_thread is not None:
            return

        server = None
        try:
            server = BlockingOSCUDPServer(('0.0.0.0', receive_port), self.dispatcher)
            thread = threading.Thread(target=server.serve_forever, daemon=True)
            thread.start()
        except Exception as e:
            if server is not None:
                server.server_close()
            # Log error but don't crash
            print(f"Failed to start OSC server: {e}")
            return

        # Publish only once both the server and its thread are running, so
        # stop() never calls shutdown() on a server that is not serving
        # (which would block forever).
        self.server = server
        self.server_thread = thread

    def stop(self):
        """
        Stop the OSC server, release its UDP socket and wait for the thread.

        Closing the socket is required so that a subsequent start() (e.g. from
        update_config) can bind the same port again.
        """
        server, thread = self.server, self.server_thread
        self.server = None
        self.server_thread = None

        if server is not None:
            server.shutdown()
            server.server_close()

        # Joining from the server thread itself would deadlock.
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout=1.0)

    def update_config(self, send_host: str, send_port: int, receive_port: int):
        """
        Update OSC configuration.

        Args:
            send_host: Target host for outbound messages
            send_port: Target port for outbound messages
            receive_port: Port to listen on for inbound messages
        """
        self.sender.update_target(send_host, send_port)

        # Restart server with new port
        self.stop()
        self.start(receive_port)

    def enable(self):
        """Enable OSC sending."""
        self.sender.enable()

    def disable(self):
        """Disable OSC sending."""
        self.sender.disable()

    def set_stream_enabled(self, stream_name: str, enabled: bool):
        """
        Enable/disable specific stream (unknown stream names are ignored).

        Args:
            stream_name: Name of stream to toggle
            enabled: True to enable, False to disable
        """
        if stream_name in self.streams:
            self.streams[stream_name] = enabled

    def get_stats(self) -> dict:
        """
        Get statistics for UI display.

        Returns:
            Dictionary with enabled status, message counts, last messages, etc.
        """
        processor_stats = {
            proc_id: processor.get_stats()
            for proc_id, processor in self.processors.items()
        }

        return {
            'enabled': self.sender.enabled,
            'outbound_count': self.sender.message_count,
            'inbound_count': self.inbound_count,
            'last_sent': self.sender.last_message,
            'last_received': self.last_received,
            'last_error': self.sender.last_error,
            'streams': self.streams.copy(),
            'aggregator_state': self.aggregator.get_state_summary(),
            'processors': processor_stats
        }
