"""
OSC Aggregator

Tracks button states and votes across all phones for aggregated OSC messages.
Maintains state for:
- Individual phone button selections
- Vote counts per button
- Cumulative total presses per button per page instance
- Button metadata (names)
"""

from typing import Dict, List, Tuple, Optional


class OSCAggregator:
    """
    Aggregates interaction data across phones for OSC messages.

    Tracks three types of aggregated data:
    1. Phone Button States - which button each phone last pressed on the
       current page (cleared on page change)
    2. Vote Counts - how many phones currently select each button (one vote
       per phone, its latest selection; cleared on page change)
    3. Total Presses - cumulative presses per button across all phones, keyed
       by pageInstanceName (kept across page changes; cleared by reset())
    """

    def __init__(self):
        """Initialize aggregator with empty state."""
        self.current_page: Optional[str] = None
        self.page_instance_id: Optional[str] = None
        # Most recent pageInstanceName seen in any interaction.
        self.current_page_instance_name: Optional[str] = None
        # pageInstanceName seen since the last on_page_change(). Totals are keyed
        # by pageInstanceName, not page_instance_id, so get_totals_data() uses
        # this to find the current page's totals without picking up stale ones.
        self._instance_name_since_page_change: Optional[str] = None

        # Phone state tracking
        self.phone_button_states: Dict[str, int] = {}  # phoneId -> buttonIndex
        self.phone_button_names: Dict[str, str] = {}  # phoneId -> buttonText
        self.phone_id_order: List[str] = []  # Stable ordering for array messages
        self.phone_id_to_index: Dict[str, int] = {}  # phoneId -> index

        # Vote tracking (one vote per phone, last selection)
        self.vote_counts: Dict[int, int] = {}  # buttonIndex -> count

        # Total presses (cumulative, all-time for page instance)
        self.total_presses: Dict[str, Dict[int, int]] = {}  # pageInstanceName -> {buttonIndex: count}
        self.totals_button_metadata: Dict[str, Dict[int, str]] = {}  # pageInstanceName -> {buttonIndex: buttonName}

        # Button metadata
        self.button_metadata: Dict[str, List[str]] = {}  # pageId -> [button names]

    # ------------------------------------------------------------------ helpers

    def _remove_vote(self, button_index: Optional[int]) -> None:
        """Withdraw one vote from ``button_index``, dropping the entry at zero."""
        if button_index is None or button_index not in self.vote_counts:
            return
        self.vote_counts[button_index] -= 1
        if self.vote_counts[button_index] <= 0:
            del self.vote_counts[button_index]

    def _identifiers(self, use_index: bool) -> List:
        """Return phone identifiers in join order: IDs, or indices (-1 if unknown)."""
        if use_index:
            return [self.phone_id_to_index.get(pid, -1) for pid in self.phone_id_order]
        return list(self.phone_id_order)

    # ------------------------------------------------------------------ events

    def on_interaction(self, data: dict):
        """
        Process an interaction event and update state.

        Events missing ``phoneId`` or ``buttonIndex`` are ignored.

        Args:
            data: Interaction event data from Socket.IO
                  {phoneId, buttonIndex, buttonText, page, pageInstanceName, ...}
        """
        phone_id = data.get('phoneId')
        button_index = data.get('buttonIndex')
        if phone_id is None or button_index is None:
            return

        button_text = data.get('buttonText')
        if button_text is None:
            # An explicit null would otherwise leak None into the names arrays.
            button_text = ''
        page_instance_name = data.get('pageInstanceName') or data.get('page', '')

        # Track current pageInstanceName
        if page_instance_name:
            self.current_page_instance_name = page_instance_name
            self._instance_name_since_page_change = page_instance_name

        # Add phone to tracking if not present (it may not have sent a join)
        if phone_id not in self.phone_id_order:
            self.phone_id_order.append(phone_id)

        # Update phone's current selection (for vote counting)
        old_selection = self.phone_button_states.get(phone_id)
        self.phone_button_states[phone_id] = button_index
        self.phone_button_names[phone_id] = button_text

        # One vote per phone: move this phone's vote to its new selection.
        self._remove_vote(old_selection)
        self.vote_counts[button_index] = self.vote_counts.get(button_index, 0) + 1

        # Update total presses (cumulative, keyed by pageInstanceName)
        if page_instance_name:
            totals = self.total_presses.setdefault(page_instance_name, {})
            totals[button_index] = totals.get(button_index, 0) + 1

            # Learn button names from interactions; a blank name learned
            # earlier is replaced once a non-empty text is seen.
            names = self.totals_button_metadata.setdefault(page_instance_name, {})
            if not names.get(button_index):
                names[button_index] = button_text

    def on_page_change(self, page: str, page_instance_id: str, config: Optional[dict]):
        """
        Handle page navigation - reset per-page selections and extract button metadata.

        Clears vote counts and each phone's current selection (index and text).
        Cumulative totals are kept.

        Args:
            page: Page path/identifier
            page_instance_id: Unique ID for this page instance
            config: Page configuration dict (may contain 'buttons' array); None is
                    treated as an empty config
        """
        self.current_page = page
        self.page_instance_id = page_instance_id
        self._instance_name_since_page_change = None

        # Reset per-page selection state (new page, new votes). Names are
        # cleared together with indices so both phone arrays stay consistent.
        self.vote_counts = {}
        self.phone_button_states = {}
        self.phone_button_names = {}

        # Extract button names from config (empty list if not provided)
        buttons = (config or {}).get('buttons') or []
        self.button_metadata[page] = [btn.get('text', '') for btn in buttons]

    def on_phone_joined(self, phone_id: str, phone_index: Optional[int] = None):
        """
        Add phone to tracking when it joins.

        Args:
            phone_id: Phone identifier
            phone_index: Phone index (optional, for index mode)
        """
        if phone_id not in self.phone_id_order:
            self.phone_id_order.append(phone_id)

        if phone_index is not None:
            self.phone_id_to_index[phone_id] = phone_index

    def on_phone_left(self, phone_id: str):
        """
        Remove phone from tracking when it leaves.

        Withdraws the phone's current vote; cumulative totals are unaffected.

        Args:
            phone_id: Phone identifier
        """
        if phone_id in self.phone_id_order:
            self.phone_id_order.remove(phone_id)

        self.phone_id_to_index.pop(phone_id, None)
        self._remove_vote(self.phone_button_states.pop(phone_id, None))
        self.phone_button_names.pop(phone_id, None)

    # ------------------------------------------------------------------ queries

    def get_phone_states_data(self, use_index: bool = False) -> Tuple[Optional[str], int, List, List[int]]:
        """
        Get current phone button states for OSC message.

        Args:
            use_index: If True, return phone indices instead of IDs

        Returns:
            Tuple of (pageInstanceName, phoneCount, [identifiers], [button_states])
            identifiers are either phone IDs (str) or indices (int, -1 if unknown)
            based on use_index.
            button_states[i] is the button index that identifiers[i] last pressed (-1 if none)
        """
        identifiers = self._identifiers(use_index)
        states = [self.phone_button_states.get(pid, -1) for pid in self.phone_id_order]
        return (self.current_page_instance_name, len(identifiers), identifiers, states)

    def get_phone_names_data(self, use_index: bool = False) -> Tuple[Optional[str], int, List, List[str]]:
        """
        Get current phone button names for OSC message.

        Args:
            use_index: If True, return phone indices instead of IDs (for consistency with phone_states)

        Returns:
            Tuple of (pageInstanceName, phoneCount, [identifiers], [button_names])
            identifiers are either phone IDs (str) or indices (int, -1 if unknown)
            based on use_index.
            button_names[i] is the button text that identifiers[i] last pressed ("" if none)
        """
        identifiers = self._identifiers(use_index)
        names = [self.phone_button_names.get(pid, "") for pid in self.phone_id_order]
        return (self.current_page_instance_name, len(identifiers), identifiers, names)

    def get_votes_data(self) -> Tuple[Optional[str], int, List[str], List[int]]:
        """
        Get current vote counts for OSC message.

        Returns:
            Tuple of (pageId, buttonCount, [button_names], [vote_counts])
            vote_counts[i] is the number of phones that voted for button i.
            Only buttons listed in the current page's config are reported.
        """
        button_names = list(self.button_metadata.get(self.current_page, []))
        counts = [self.vote_counts.get(i, 0) for i in range(len(button_names))]
        return (self.current_page, len(button_names), button_names, counts)

    def get_totals_data(self) -> Tuple[Optional[str], int, List[str], List[int]]:
        """
        Get cumulative total presses for the current page instance.

        Totals are stored by pageInstanceName. They are looked up by
        page_instance_id first; otherwise the pageInstanceName seen in
        interactions since the last page change is used. If neither key has
        totals, every count is 0.

        Returns:
            Tuple of (pageInstanceId, buttonCount, [button_names], [total_presses])
            total_presses[i] is the total number of times button i was pressed (all-time for this page instance)
        """
        button_names = list(self.button_metadata.get(self.current_page, []))

        totals_dict = self.total_presses.get(self.page_instance_id)
        if totals_dict is None and self._instance_name_since_page_change:
            totals_dict = self.total_presses.get(self._instance_name_since_page_change)
        totals_dict = totals_dict or {}

        totals = [totals_dict.get(i, 0) for i in range(len(button_names))]
        return (self.page_instance_id, len(button_names), button_names, totals)

    def get_totals_data_by_instance_name(self, page_instance_name: str) -> Tuple[Optional[str], int, List[str], List[int]]:
        """
        Get cumulative total presses for a specific pageInstanceName.

        This method learns button names from interactions dynamically, without
        requiring page change events. Only buttons that have been pressed at
        least once are reported, ordered by button index.

        Args:
            page_instance_name: The pageInstanceName to get totals for

        Returns:
            Tuple of (pageInstanceName, buttonCount, [button_names], [total_presses]),
            or (None, 0, [], []) if nothing is known for that name.
            total_presses[i] is the total number of times button i was pressed
        """
        totals_dict = self.total_presses.get(page_instance_name)
        button_dict = self.totals_button_metadata.get(page_instance_name, {})
        if totals_dict is None or not button_dict:
            return (None, 0, [], [])

        # Sort by button index to ensure consistent ordering
        sorted_indices = sorted(button_dict)
        button_names = [button_dict[idx] for idx in sorted_indices]
        totals = [totals_dict.get(idx, 0) for idx in sorted_indices]

        return (page_instance_name, len(button_names), button_names, totals)

    def reset(self):
        """Reset all aggregation state (for /crowd/system/reset command)."""
        self.phone_button_states = {}
        self.phone_button_names = {}
        self.vote_counts = {}
        self.total_presses = {}
        # Clear together with total_presses so the two name-keyed maps agree.
        self.totals_button_metadata = {}
        # Note: Don't reset phone_id_order or button_metadata, as phones are still connected

    def get_state_summary(self) -> dict:
        """
        Get summary of current state for debugging/display.

        The returned containers are copies; mutating them does not affect
        the aggregator.

        Returns:
            Dictionary with current state information
        """
        return {
            'current_page': self.current_page,
            'page_instance_id': self.page_instance_id,
            'phone_count': len(self.phone_id_order),
            'vote_counts': dict(self.vote_counts),
            # Copy inner dicts too; a shallow copy would expose live totals.
            'total_presses': {name: dict(counts) for name, counts in self.total_presses.items()},
        }
