"""Latency diagnostic widget for measuring system performance"""

from textual.widgets import DataTable
from textual.coordinate import Coordinate
from rich.text import Text
import time


class LatencyDiagnosticTab(DataTable):
    """Live table display for latency diagnostics - one row per device.

    Every latency column shows the most recent sample (in ms) on the first line
    and the running per-device average in parentheses on the second line.
    """

    # Latency cells are two lines tall ("<latest>ms\n(<avg>)"). DataTable.add_row
    # defaults to height=1, which crops everything after the first line, so the
    # averages would never be visible without an explicit row height.
    _ROW_HEIGHT = 2

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.device_rows = {}  # phoneId -> row_key mapping
        self.device_stats = {}  # phoneId -> running latency sums and sample count
        self.cursor_type = "none"
        self.zebra_stripes = True
        # NTP-style clock offset in ms (server_time - client_time). add_timing
        # falls back to this value when the caller does not pass an offset.
        self.clock_offset = 0.0

    def on_mount(self) -> None:
        """Set up table columns"""
        self.add_columns(
            "Phone",
            "Ph→Srv",
            "Srv⏱",
            "Srv→Cl",
            "Handler⏱",
            "Thread⏱",
            "UI⏱",
            "Total",
            "Count"
        )

    @staticmethod
    def _latency_style(ms, warn_threshold=100, critical_threshold=500) -> str:
        """Return a Rich style: green, yellow above *warn_threshold*, bold red above *critical_threshold*."""
        if ms > critical_threshold:
            return "red bold"
        if ms > warn_threshold:
            return "yellow"
        return "green"

    @staticmethod
    def _elapsed(end, start, offset=0):
        """Return ``end + offset - start`` (ms), or 0 if either timestamp is missing.

        A timestamp counts as missing when it is falsy (absent key -> 0, or None).
        The check is made on the *raw* timestamps so that a non-zero clock offset
        cannot turn a missing client-side timestamp into a bogus latency.
        """
        if not (end and start):
            return 0
        return end + offset - start

    def add_timing(self, phone_id: str, timing_data: dict, clock_offset: "float | None" = None) -> None:
        """
        Add or update timing data for a device

        timing_data should contain (timestamps in ms; a missing or zero timestamp
        makes every latency that depends on it report 0):
        - phone_timestamp: when phone sent data (ms)
        - server_receive_time: when socket server received from phone (ms)
        - server_send_time: when socket server sent to Python client (ms)
        - socket_receive_time: when Python client socket.io received (ms)
        - handler_call_time: when on_socket_data called (ms)
        - thread_exec_time: when UI thread executed (ms)
        - render_complete_time: when render finished (ms)
        - count: message count (shown as-is in the "Count" column)

        clock_offset: NTP-style clock offset (server_time - client_time) in ms.
            It is added to client-side timestamps for the cross-machine
            latencies (Srv→Cl and Total). When omitted (None), the offset last
            stored via set_clock_offset() is used.
        """
        from textual.widgets.data_table import CellDoesNotExist, RowDoesNotExist

        if clock_offset is None:
            clock_offset = self.clock_offset

        phone_ts = timing_data.get('phone_timestamp', 0)
        server_rx = timing_data.get('server_receive_time', 0)
        server_send = timing_data.get('server_send_time', 0)
        socket_rx = timing_data.get('socket_receive_time', 0)
        handler_call = timing_data.get('handler_call_time', 0)
        thread_exec = timing_data.get('thread_exec_time', 0)
        render_complete = timing_data.get('render_complete_time', 0)

        # One entry per latency column, in display order after "Phone":
        # (latest latency, running-sum key in device_stats, warn ms, critical ms)
        stages = [
            # Phone -> Server: phone send -> server receive
            (self._elapsed(server_rx, phone_ts), 'phone_to_server_sum', 50, 200),
            # Server processing: server receive -> server send
            (self._elapsed(server_send, server_rx), 'server_processing_sum', 5, 20),
            # Server -> Client: client receive converted to server time via the offset
            (self._elapsed(socket_rx, server_send, clock_offset), 'server_to_client_sum', 100, 1000),
            # Handler: socket receive -> handler call (both client clock)
            (self._elapsed(handler_call, socket_rx), 'handler_sum', 10, 50),
            # Thread queue: handler call -> thread execution (both client clock)
            (self._elapsed(thread_exec, handler_call), 'thread_sum', 50, 200),
            # UI render: thread exec -> render complete (both client clock)
            (self._elapsed(render_complete, thread_exec), 'ui_sum', 10, 50),
            # Total: phone send -> render complete (client clock converted via the offset)
            (self._elapsed(render_complete, phone_ts, clock_offset), 'total_sum', 100, 500),
        ]

        # Update running stats for this device
        stats = self.device_stats.get(phone_id)
        if stats is None:
            stats = {key: 0 for _, key, _, _ in stages}
            stats['count'] = 0
            self.device_stats[phone_id] = stats
        for latency, key, _, _ in stages:
            stats[key] += latency
        stats['count'] += 1
        n = stats['count']

        # Build cells: latest value on line 1, running average on line 2
        cells = [Text(phone_id[:8], style="cyan")]
        for latency, key, warn, critical in stages:
            cells.append(
                Text(f"{latency:.0f}ms\n({stats[key] / n:.0f})",
                     style=self._latency_style(latency, warn, critical))
            )
        cells.append(Text(str(timing_data.get('count', 0)), style="dim"))

        # Update the existing row if it is still in the table
        row_key = self.device_rows.get(phone_id)
        if row_key is not None:
            try:
                row_index = self.get_row_index(row_key)
            except RowDoesNotExist:
                # The row was removed behind our back (e.g. the table was
                # cleared); forget the stale key and add a fresh row below.
                del self.device_rows[phone_id]
            else:
                try:
                    for col_index, cell in enumerate(cells):
                        self.update_cell_at(
                            Coordinate(row_index, col_index),
                            cell,
                            update_width=False
                        )
                except CellDoesNotExist:
                    pass  # Best-effort refresh: skip if the column layout doesn't match
                return

        # Add a new row, tall enough to show both lines of each latency cell
        self.device_rows[phone_id] = self.add_row(*cells, height=self._ROW_HEIGHT)

    def set_clock_offset(self, offset: float) -> None:
        """
        Set the NTP-style clock offset and update display.

        The stored offset is used by add_timing() whenever it is called without
        an explicit clock_offset argument.

        Args:
            offset: Clock offset in milliseconds (server_time - client_time)
        """
        self.clock_offset = offset

        # Update table border title to show the offset
        if abs(offset) < 1:
            offset_text = f"Clock Offset: {offset:.1f}ms ✓"
        elif abs(offset) < 10:
            offset_text = f"Clock Offset: {offset:.1f}ms"
        elif abs(offset) < 50:
            offset_text = f"Clock Offset: {offset:.1f}ms ⚠"
        else:
            offset_text = f"Clock Offset: {offset:.0f}ms ⚠⚠"

        # Set border title (appears at top of table)
        self.border_title = offset_text
