from __future__ import annotations

from typing import Literal, Protocol

from typing_extensions import TypedDict

from ..language import LanguageCode
from ..llm import ChatContext
from ..types import NOT_GIVEN, NotGivenOr
from ..utils import is_given


class _TurnDetector(Protocol):
    """Structural interface for objects usable as a custom turn detector.

    Any object exposing these members satisfies the protocol; such an object
    is one of the alternatives accepted by ``TurnDetectionMode``.
    """

    @property
    def model(self) -> str:
        """Name of the underlying model; ``"unknown"`` unless overridden."""
        return "unknown"

    @property
    def provider(self) -> str:
        """Name of the model provider; ``"unknown"`` unless overridden."""
        return "unknown"

    # TODO: Move those two functions to EOU ctor (capabilities dataclass)
    # Returns a threshold for ``language``, or ``None``.
    # NOTE(review): presumably the end-of-turn probability below which the turn
    # is considered unlikely to be over, with ``None`` meaning "no threshold
    # for this language" — confirm against the implementations.
    async def unlikely_threshold(self, language: LanguageCode | None) -> float | None: ...
    # Reports whether the detector can handle ``language`` (which may be ``None``).
    async def supports_language(self, language: LanguageCode | None) -> bool: ...

    # Scores the conversation in ``chat_ctx`` and returns a float.
    # NOTE(review): presumably the probability that the user's turn has ended,
    # with ``timeout`` in seconds — confirm against the implementations.
    async def predict_end_of_turn(
        self, chat_ctx: ChatContext, *, timeout: float | None = None
    ) -> float: ...


TurnDetectionMode = Literal["stt", "vad", "realtime_llm", "manual"] | _TurnDetector
"""
Selects how the end of the user's turn is detected.

- "stt": rely on the speech-to-text result to detect the end of the user's turn
- "vad": rely on VAD to detect both the start and the end of the user's turn
- "realtime_llm": rely on the server-side turn detection of the realtime LLM
- "manual": turn detection is managed manually
- _TurnDetector: use the default mode together with the supplied turn detector

When no mode is provided (the default), the best mode is chosen automatically
from the available models, in the order realtime_llm -> vad -> stt -> manual.
When the model a mode needs (VAD, STT, or RealtimeModel) is not provided,
the session falls back to the default mode.
"""


class EndpointingOptions(TypedDict, total=False):
    """Options controlling when the user's turn is considered finished.

    Every key may be omitted. On an ``Agent``, an omitted key inherits the
    session default; on an ``AgentSession``, an omitted key takes the default
    documented for that key.
    """

    mode: Literal["fixed", "dynamic"]
    """How the endpointing delay is chosen: ``"fixed"`` uses a fixed delay,
    ``"dynamic"`` uses a dynamic delay. Default: ``"fixed"``."""
    min_delay: float
    """Shortest time, in seconds, that must elapse after the last detected
    speech before the user's turn is declared complete. Default: ``0.5``."""
    max_delay: float
    """Longest time, in seconds, that the agent waits before terminating the
    turn. Default: ``3.0``."""
    alpha: float
    """Coefficient of the exponential moving average used by dynamic
    endpointing; a larger value gives more weight to the history. Only
    applies when ``mode`` is ``"dynamic"``. Default: ``0.9``."""


# Baseline endpointing values. ``_resolve_endpointing`` copies these and then
# overlays whatever keys the caller supplied.
_ENDPOINTING_DEFAULTS: EndpointingOptions = {
    "mode": "fixed",
    "min_delay": 0.5,
    "max_delay": 3.0,
    "alpha": 0.9,
}


class InterruptionOptions(TypedDict, total=False):
    """Options controlling how user interruptions of the agent are handled.

    Every key may be omitted. On an ``Agent``, an omitted key inherits the
    session default; on an ``AgentSession``, an omitted key takes the default
    documented for that key.

    Leaving ``mode`` out lets the session pick the best available strategy.
    """

    enabled: bool
    """Turns interruption handling on or off. Default: ``True``."""
    mode: Literal["adaptive", "vad"]
    """Strategy used to handle interruptions: ``"adaptive"`` is ML-based
    detection, ``"vad"`` is simple voice-activity detection. When the key is
    absent the strategy is auto-detected."""
    discard_audio_if_uninterruptible: bool
    """When ``True``, buffered audio is dropped while the agent is speaking
    and cannot be interrupted. Default: ``True``."""
    min_duration: float
    """Shortest speech length, in seconds, that registers as an interruption.
    Default: ``0.5``."""
    min_words: int
    """Smallest word count for speech to be considered an interruption
    (STT only). Default: ``0``."""
    resume_false_interruption: bool
    """When ``True``, the agent resumes its speech after a false
    interruption. Default: ``True``."""
    false_interruption_timeout: float | None
    """Amount of silence, in seconds, following an interruption after which
    the interruption is classified as false. ``None`` disables the check.
    Default: ``2.0``."""
    backchannel_boundary: float | tuple[float, float] | None
    """Window, in seconds, near the start/end of each agent turn in which
    overlapping speech that the adaptive detector classifies as a backchannel
    is suppressed; events flagged as interruptions still pass through. Pass a
    tuple to use separate values for the start and the end, or ``None`` to
    disable. Default: ``(1.0, 3.5)``. The end value should be the higher one
    to account for STT transcript timestamp inaccuracy."""


# Baseline interruption values. ``_resolve_interruption`` copies these and then
# overlays whatever keys the caller supplied. ``mode`` is intentionally not
# listed, so it stays absent from the resolved options unless the caller sets it.
_INTERRUPTION_DEFAULTS: InterruptionOptions = {
    "enabled": True,
    "discard_audio_if_uninterruptible": True,
    "min_duration": 0.5,
    "min_words": 0,
    "resume_false_interruption": True,
    "false_interruption_timeout": 2.0,
    # (start, end) boundaries in seconds
    "backchannel_boundary": (
        1.0,
        3.5,  # higher value for the end as STT timestamps aren't very reliable
    ),
}


class PreemptiveGenerationOptions(TypedDict, total=False):
    """Options controlling preemptive generation. Every key may be omitted."""

    enabled: bool
    """Turns preemptive generation on or off. Default: ``True``."""

    preemptive_tts: bool
    """Also run TTS preemptively, before the turn is confirmed. With the
    default, ``False``, only the LLM runs preemptively and TTS starts once
    the turn is confirmed and the speech is scheduled."""

    max_speech_duration: float
    """Longest user speech duration, in seconds, for which preemptive
    generation is still attempted. Past this threshold it is skipped, because
    long utterances are more likely to change and users may expect slower
    responses. Default: ``10.0``."""

    max_retries: int
    """Upper bound on preemptive generation attempts within one user turn.
    The counter resets when the turn completes. Default: ``3``."""


# Baseline preemptive-generation values. ``_resolve_preemptive_generation``
# copies these and then overlays whatever keys the caller supplied.
_PREEMPTIVE_GENERATION_DEFAULTS: PreemptiveGenerationOptions = {
    "enabled": True,
    "preemptive_tts": False,
    "max_speech_duration": 10.0,
    "max_retries": 3,
}


class UserTurnLimitOptions(TypedDict, total=False):
    """Options for detecting that a user has been speaking for too long
    without the agent successfully responding.

    The framework accumulates a word count and a wall-clock duration across
    consecutive user turns. The counters reset only when the agent moves to
    the ``speaking`` state (i.e., produces audio output).

    The feature is disabled while both thresholds are ``None``, which is the
    default for each; set at least one of them to enable it.
    """

    max_words: int | None
    """Accumulated word count at which the limit triggers. Words are counted
    with the framework's WordTokenizer. ``None`` disables word-based limiting.
    Default: ``None``."""

    max_duration: float | None
    """Longest wall-clock duration, in seconds, measured from when the user
    first started speaking in the current accumulation window. ``None``
    disables duration-based limiting. Default: ``None``."""


# Baseline user-turn-limit values: both thresholds are ``None`` (disabled).
# ``_resolve_user_turn_limit`` copies these and then overlays whatever keys
# the caller supplied.
_USER_TURN_LIMIT_DEFAULTS: UserTurnLimitOptions = {
    "max_words": None,
    "max_duration": None,
}


class TurnHandlingOptions(TypedDict, total=False):
    """Top-level configuration for the turn handling system.

    A plain dict is accepted wherever this type is expected::

        AgentSession(
            turn_handling={
                "endpointing": {"min_delay": 0.3},
                "interruption": {"enabled": False},
                "preemptive_generation": {"preemptive_tts": True},
            },
        )

    Every key is optional and defaults to a sensible value. Each nested
    options dict may itself be partial; keys left out of it fall back to the
    defaults documented on the corresponding options type.
    """

    turn_detection: TurnDetectionMode | None
    """Strategy for deciding when the user has finished speaking.
    Absent means the session auto-selects."""
    endpointing: EndpointingOptions
    """Endpointing configuration. Defaults to
    ``{"mode": "fixed", "min_delay": 0.5, "max_delay": 3.0, "alpha": 0.9}``."""
    interruption: InterruptionOptions
    """Interruption handling configuration. Use ``{"enabled": False}`` to disable."""
    preemptive_generation: PreemptiveGenerationOptions
    """Preemptive generation configuration. Use ``{"enabled": False}`` to disable."""
    user_turn_limit: UserTurnLimitOptions
    """User turn limit configuration. Disabled by default; use e.g.
    ``{"max_words": 50}`` to enable."""


def _resolve_preemptive_generation(
    config: PreemptiveGenerationOptions | None = None,
) -> PreemptiveGenerationOptions:
    """Return a new options dict with every default filled in.

    Args:
        config: Partial options from the caller, or ``None`` for none.

    Returns:
        A fresh ``PreemptiveGenerationOptions`` holding the module defaults
        overlaid with the keys present in ``config``.
    """
    # ``None`` is treated as "no overrides", which yields a copy of the defaults.
    overrides: PreemptiveGenerationOptions = {} if config is None else config
    return PreemptiveGenerationOptions(**{**_PREEMPTIVE_GENERATION_DEFAULTS, **overrides})


def _resolve_endpointing(config: EndpointingOptions | None = None) -> EndpointingOptions:
    """Return a new options dict with every default filled in.

    Args:
        config: Partial options from the caller, or ``None`` for none.

    Returns:
        A fresh ``EndpointingOptions`` holding the module defaults overlaid
        with the keys present in ``config``.
    """
    # ``None`` is treated as "no overrides", which yields a copy of the defaults.
    overrides: EndpointingOptions = {} if config is None else config
    return EndpointingOptions(**{**_ENDPOINTING_DEFAULTS, **overrides})


def _resolve_interruption(
    config: InterruptionOptions | None = None,
) -> InterruptionOptions:
    """Return a new options dict with every default filled in.

    The defaults carry no ``mode`` entry, so ``mode`` is present in the
    result only when ``config`` provides it.

    Args:
        config: Partial options from the caller, or ``None`` for none.

    Returns:
        A fresh ``InterruptionOptions`` holding the module defaults overlaid
        with the keys present in ``config``.
    """
    # ``None`` is treated as "no overrides", which yields a copy of the defaults.
    overrides: InterruptionOptions = {} if config is None else config
    return InterruptionOptions(**{**_INTERRUPTION_DEFAULTS, **overrides})


def _resolve_user_turn_limit(
    config: UserTurnLimitOptions | None = None,
) -> UserTurnLimitOptions:
    """Return a new options dict with every default filled in.

    Args:
        config: Partial options from the caller, or ``None`` for none.

    Returns:
        A fresh ``UserTurnLimitOptions`` holding the module defaults overlaid
        with the keys present in ``config``.
    """
    # ``None`` is treated as "no overrides", which yields a copy of the defaults.
    overrides: UserTurnLimitOptions = {} if config is None else config
    return UserTurnLimitOptions(**{**_USER_TURN_LIMIT_DEFAULTS, **overrides})


def _migrate_turn_handling(
    min_endpointing_delay: NotGivenOr[float] = NOT_GIVEN,
    max_endpointing_delay: NotGivenOr[float] = NOT_GIVEN,
    false_interruption_timeout: NotGivenOr[float | None] = NOT_GIVEN,
    turn_detection: NotGivenOr[TurnDetectionMode | None] = NOT_GIVEN,
    discard_audio_if_uninterruptible: NotGivenOr[bool] = NOT_GIVEN,
    min_interruption_duration: NotGivenOr[float] = NOT_GIVEN,
    min_interruption_words: NotGivenOr[int] = NOT_GIVEN,
    allow_interruptions: NotGivenOr[bool] = NOT_GIVEN,
    resume_false_interruption: NotGivenOr[bool] = NOT_GIVEN,
    agent_false_interruption_timeout: NotGivenOr[float | None] = NOT_GIVEN,
    preemptive_generation: NotGivenOr[bool] = NOT_GIVEN,
) -> TurnHandlingOptions:
    """Translate the deprecated flat keyword arguments into ``TurnHandlingOptions``.

    Only arguments that were supplied contribute keys, and a nested section
    (``endpointing`` / ``interruption``) is emitted only when it ends up
    non-empty, so the result can be partial or even empty.

    Mapping of deprecated argument to new key:

    - ``min_endpointing_delay`` -> ``endpointing.min_delay``
    - ``max_endpointing_delay`` -> ``endpointing.max_delay``
    - ``allow_interruptions`` -> ``interruption.enabled`` (only when ``False``)
    - ``discard_audio_if_uninterruptible`` -> same key under ``interruption``
    - ``min_interruption_duration`` -> ``interruption.min_duration``
    - ``min_interruption_words`` -> ``interruption.min_words``
    - ``false_interruption_timeout`` / ``agent_false_interruption_timeout``
      -> ``interruption.false_interruption_timeout``; the ``agent_`` variant
      wins when it is given
    - ``resume_false_interruption`` -> same key under ``interruption``
    - ``turn_detection`` -> ``turn_detection``
    - ``preemptive_generation`` -> ``preemptive_generation.enabled``

    Returns:
        A ``TurnHandlingOptions`` dict containing only the migrated keys.
    """
    # The agent-scoped timeout takes precedence whenever it was supplied.
    effective_timeout = (
        agent_false_interruption_timeout
        if is_given(agent_false_interruption_timeout)
        else false_interruption_timeout
    )

    # Endpointing section: one key per supplied argument.
    endpointing: EndpointingOptions = {}
    if is_given(min_endpointing_delay):
        endpointing["min_delay"] = min_endpointing_delay
    if is_given(max_endpointing_delay):
        endpointing["max_delay"] = max_endpointing_delay

    # Interruption section: one key per supplied argument.
    interruption_opts: InterruptionOptions = {}
    # NOTE(review): unlike the other arguments, an explicit
    # ``allow_interruptions=True`` is not recorded; only ``False`` is. This
    # looks deliberate, but confirm an explicit ``True`` never needs to
    # override an inherited ``enabled=False``.
    if allow_interruptions is False:
        interruption_opts["enabled"] = False
    if is_given(discard_audio_if_uninterruptible):
        interruption_opts["discard_audio_if_uninterruptible"] = discard_audio_if_uninterruptible
    if is_given(min_interruption_duration):
        interruption_opts["min_duration"] = min_interruption_duration
    if is_given(min_interruption_words):
        interruption_opts["min_words"] = min_interruption_words
    if is_given(effective_timeout):
        interruption_opts["false_interruption_timeout"] = effective_timeout
    if is_given(resume_false_interruption):
        interruption_opts["resume_false_interruption"] = resume_false_interruption

    # Assemble the top-level dict, skipping sections that stayed empty.
    migrated: TurnHandlingOptions = {}
    if endpointing:
        migrated["endpointing"] = endpointing
    if interruption_opts:
        migrated["interruption"] = interruption_opts
    if is_given(turn_detection):
        migrated["turn_detection"] = turn_detection
    if is_given(preemptive_generation):
        migrated["preemptive_generation"] = {"enabled": preemptive_generation}
    return migrated
