from __future__ import annotations

import time
from enum import Enum, unique
from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TypeVar

from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from typing_extensions import Self

from ..inference.interruption import (
    AdaptiveInterruptionDetector,
    InterruptionDetectionError,
    OverlappingSpeechEvent,
)
from ..language import LanguageCode
from ..llm import (
    LLM,
    AgentHandoff,
    ChatMessage,
    FunctionCall,
    FunctionCallOutput,
    LLMError,
    RealtimeModel,
    RealtimeModelError,
)
from ..log import logger
from ..metrics import AgentMetrics, AgentSessionUsage
from ..stt import STT, STTError
from ..tts import TTS, TTSError
from .speech_handle import SpeechHandle

if TYPE_CHECKING:
    from .agent_session import AgentSession


Userdata_T = TypeVar("Userdata_T")


class RunContext(Generic[Userdata_T]):
    """Context object tied to a single function call made during a session.

    It exposes the ``AgentSession``, ``SpeechHandle`` and ``FunctionCall`` it was
    constructed with, and records ``speech_handle.num_steps - 1`` at construction
    time so that ``wait_for_playout`` can wait on that specific step.
    """

    # private ctor
    def __init__(
        self,
        *,
        session: AgentSession[Userdata_T],
        speech_handle: SpeechHandle,
        function_call: FunctionCall,
    ) -> None:
        # Captured once, here; `wait_for_playout` later waits on exactly this step index.
        self._initial_step_idx = speech_handle.num_steps - 1

        self._function_call = function_call
        self._speech_handle = speech_handle
        self._session = session

    @property
    def session(self) -> AgentSession[Userdata_T]:
        """The session this context was created with."""
        return self._session

    @property
    def speech_handle(self) -> SpeechHandle:
        """The speech handle this context was created with."""
        return self._speech_handle

    @property
    def function_call(self) -> FunctionCall:
        """The function call this context was created with."""
        return self._function_call

    @property
    def userdata(self) -> Userdata_T:
        """Shortcut for ``self.session.userdata``."""
        owner = self.session
        return owner.userdata

    def disallow_interruptions(self) -> None:
        """Turn off interruptions for this FunctionCall.

        This assigns ``False`` through the ``SpeechHandle.allow_interruptions``
        setter, which raises a RuntimeError when the handle has already been
        interrupted.

        Raises:
            RuntimeError: If the SpeechHandle is already interrupted.
        """
        handle = self.speech_handle
        handle.allow_interruptions = False

    async def wait_for_playout(self) -> None:
        """Wait for the speech playout belonging to this function call's step.

        `SpeechHandle.wait_for_playout` waits for the whole assistant turn to
        complete (including all function tools). This method instead only waits
        for the assistant's spoken response that preceded this tool call to
        finish playing.
        """
        handle = self.speech_handle
        step = self._initial_step_idx
        await handle._wait_for_generation(step_idx=step)


# String names of the event types. Every entry except "overlapping_speech" equals the
# default `type` value of an event model defined in this file.
# NOTE(review): "overlapping_speech" presumably corresponds to the imported
# OverlappingSpeechEvent — confirm against its definition.
# NOTE(review): "user_turn_exceeded" (the `type` of UserTurnExceededEvent) is not listed
# here — confirm that omission is intentional.
EventTypes = Literal[
    "user_state_changed",
    "agent_state_changed",
    "user_input_transcribed",
    "conversation_item_added",
    "agent_false_interruption",
    "overlapping_speech",
    "function_tools_executed",
    "metrics_collected",
    "session_usage_updated",
    "speech_created",
    "error",
    "close",
]

# Allowed values for the `old_state` / `new_state` fields of UserStateChangedEvent.
UserState = Literal["speaking", "listening", "away"]
# Allowed values for the `old_state` / `new_state` fields of AgentStateChangedEvent.
AgentState = Literal["initializing", "idle", "listening", "thinking", "speaking"]


class UserStateChangedEvent(BaseModel):
    """Payload describing a change of the user's state from ``old_state`` to ``new_state``."""

    # Discriminator value; fixed to "user_state_changed".
    type: Literal["user_state_changed"] = "user_state_changed"
    # State before the change.
    old_state: UserState
    # State after the change.
    new_state: UserState
    # Defaults to `time.time()` evaluated when the model instance is created.
    created_at: float = Field(default_factory=time.time)


class AgentStateChangedEvent(BaseModel):
    """Payload describing a change of the agent's state from ``old_state`` to ``new_state``."""

    # Discriminator value; fixed to "agent_state_changed".
    type: Literal["agent_state_changed"] = "agent_state_changed"
    # State before the change.
    old_state: AgentState
    # State after the change.
    new_state: AgentState
    # Defaults to `time.time()` evaluated when the model instance is created.
    created_at: float = Field(default_factory=time.time)


class UserInputTranscribedEvent(BaseModel):
    """Payload carrying a transcript of user input and whether it is final."""

    # Discriminator value; fixed to "user_input_transcribed".
    type: Literal["user_input_transcribed"] = "user_input_transcribed"
    # The transcribed text.
    transcript: str
    # True for a final transcript; presumably False for interim results — confirm with emitter.
    is_final: bool
    # Optional speaker identifier; None when not provided.
    speaker_id: str | None = None
    # Optional language of the transcript; None when not provided.
    language: LanguageCode | None = None
    # Defaults to `time.time()` evaluated when the model instance is created.
    created_at: float = Field(default_factory=time.time)


class AgentFalseInterruptionEvent(BaseModel):
    """Payload reporting a false interruption of the agent.

    ``message`` and ``extra_instructions`` are deprecated; reading either one
    logs a warning (see ``__getattribute__``).
    """

    type: Literal["agent_false_interruption"] = "agent_false_interruption"
    resumed: bool
    """Whether the false interruption was resumed automatically."""
    created_at: float = Field(default_factory=time.time)

    # deprecated
    message: ChatMessage | None = None
    extra_instructions: str | None = None

    def __getattribute__(self, name: str) -> Any:
        """Return attribute ``name``, logging a deprecation warning for deprecated fields.

        The warning is emitted on every lookup of ``message`` or
        ``extra_instructions`` that goes through normal attribute access, before
        the lookup itself is delegated to the parent implementation.
        """
        touches_deprecated_field = name in ("message", "extra_instructions")
        if touches_deprecated_field:
            logger.warning(
                f"AgentFalseInterruptionEvent.{name} is deprecated, automatic resume is now supported"
            )

        value = super().__getattribute__(name)
        return value


class MetricsCollectedEvent(BaseModel):
    """Deprecated: use session_usage_updated for usage tracking.
    Per-turn latency metrics are available on ChatMessage.metrics."""

    # Discriminator value; fixed to "metrics_collected".
    type: Literal["metrics_collected"] = "metrics_collected"
    # The metrics value carried by this event.
    metrics: AgentMetrics
    # Defaults to `time.time()` evaluated when the model instance is created.
    created_at: float = Field(default_factory=time.time)


class SessionUsageUpdatedEvent(BaseModel):
    """Payload carrying an ``AgentSessionUsage`` value for the ``session_usage_updated`` event."""

    # Discriminator value; fixed to "session_usage_updated".
    type: Literal["session_usage_updated"] = "session_usage_updated"
    # The usage value carried by this event.
    # NOTE(review): whether this is a cumulative total or a delta is not visible here — confirm.
    usage: AgentSessionUsage
    # Defaults to `time.time()` evaluated when the model instance is created.
    created_at: float = Field(default_factory=time.time)


# Placeholder model whose `type` can only be "unknown". It appears as a member of the
# `ConversationItemAddedEvent.item` union.
# NOTE(review): presumably its presence makes consumers branch on `item.type` instead of
# assuming a fixed set of item classes — confirm intent.
class _TypeDiscriminator(BaseModel):
    type: Literal["unknown"] = "unknown"  # force user to use the type discriminator


class ConversationItemAddedEvent(BaseModel):
    """Payload carrying a conversation item for the ``conversation_item_added`` event."""

    # Discriminator value; fixed to "conversation_item_added".
    type: Literal["conversation_item_added"] = "conversation_item_added"
    # The item: a ChatMessage or an AgentHandoff. The declared union also includes the
    # `_TypeDiscriminator` placeholder (type "unknown").
    item: ChatMessage | AgentHandoff | _TypeDiscriminator
    # Defaults to `time.time()` evaluated when the model instance is created.
    created_at: float = Field(default_factory=time.time)


class FunctionToolsExecutedEvent(BaseModel):
    """Emitted after a batch of function tools finishes executing.

    ``function_calls`` and ``function_call_outputs`` are index-aligned lists:
    the output stored at position ``i`` belongs to the call stored at position
    ``i``, and a present output carries the same ``call_id`` as its call. An
    output of ``None`` means the call produced nothing that should be sent back
    to the LLM, for example when a tool raises ``StopResponse`` or returns an
    invalid output.

    Construction fails with a ``ValueError`` when the two lists differ in length.
    """

    type: Literal["function_tools_executed"] = "function_tools_executed"
    function_calls: list[FunctionCall]
    function_call_outputs: list[FunctionCallOutput | None]
    created_at: float = Field(default_factory=time.time)
    # Private flags (not model fields); both start out False.
    _reply_required: bool = PrivateAttr(default=False)
    _handoff_required: bool = PrivateAttr(default=False)

    @model_validator(mode="after")
    def verify_lists_length(self) -> Self:
        """Reject instances whose call and output lists have different lengths.

        Raises:
            ValueError: If ``function_calls`` and ``function_call_outputs`` differ in length.
        """
        call_count = len(self.function_calls)
        output_count = len(self.function_call_outputs)
        if call_count == output_count:
            return self

        raise ValueError("The number of function_calls and function_call_outputs must match.")

    def zipped(self) -> list[tuple[FunctionCall, FunctionCallOutput | None]]:
        """Return ``(call, output)`` tuples, pairing the two lists position by position."""
        pairs = zip(self.function_calls, self.function_call_outputs, strict=False)
        return list(pairs)

    @property
    def has_tool_reply(self) -> bool:
        """Current value of the private reply-required flag."""
        return self._reply_required

    @property
    def has_agent_handoff(self) -> bool:
        """Current value of the private handoff-required flag."""
        return self._handoff_required

    def cancel_tool_reply(self) -> None:
        """Clear the reply-required flag, so ``has_tool_reply`` becomes False."""
        self._reply_required = False

    def cancel_agent_handoff(self) -> None:
        """Clear the handoff-required flag, so ``has_agent_handoff`` becomes False."""
        self._handoff_required = False


class SpeechCreatedEvent(BaseModel):
    """Payload announcing a newly created ``SpeechHandle`` and how it was created."""

    # Needed so the non-pydantic-annotated `speech_handle` field is accepted.
    # NOTE(review): assumes SpeechHandle is not itself a pydantic model — confirm.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Discriminator value; fixed to "speech_created".
    type: Literal["speech_created"] = "speech_created"
    user_initiated: bool
    """True if the speech was created using public methods like `say` or `generate_reply`"""
    source: Literal["say", "generate_reply"]
    """Source indicating how the speech handle was created"""
    # Required field; `exclude=True` keeps it out of serialized output.
    speech_handle: SpeechHandle = Field(..., exclude=True)
    """The speech handle that was created"""
    # Defaults to `time.time()` evaluated when the model instance is created.
    created_at: float = Field(default_factory=time.time)


class UserTurnExceededEvent(BaseModel):
    """Payload for ``user_turn_exceeded``: transcript, word count and duration of a user turn."""

    # NOTE(review): "user_turn_exceeded" is not listed in `EventTypes`, and this model is not
    # a member of the `AgentEvent` union in this file — confirm that is intentional.
    # NOTE(review): which limit is "exceeded" is not visible in this file — confirm with emitter.
    type: Literal["user_turn_exceeded"] = "user_turn_exceeded"
    transcript: str
    """Transcript from the current (uncommitted) user turn only.
    Previous turns in the accumulation window are already in the chat context."""
    accumulated_transcript: str
    """Full transcript since the start of user speaking."""
    accumulated_word_count: int
    """Total word count since the start of user speaking."""
    duration: float
    """Duration of the user turn in seconds."""
    # Defaults to `time.time()` evaluated when the model instance is created.
    created_at: float = Field(default_factory=time.time)


class ErrorEvent(BaseModel):
    """Payload pairing an ``error`` with the ``source`` object it is associated with."""

    # Allows field annotations that are not pydantic-native types.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # Discriminator value; fixed to "error".
    type: Literal["error"] = "error"
    # One of the listed error models; the trailing `Any` means other values are accepted too.
    error: LLMError | STTError | TTSError | RealtimeModelError | InterruptionDetectionError | Any
    # The component associated with the error (presumably the one that produced it — confirm);
    # the trailing `Any` means other values are accepted too.
    source: LLM | STT | TTS | RealtimeModel | AdaptiveInterruptionDetector | Any
    # Defaults to `time.time()` evaluated when the model instance is created.
    created_at: float = Field(default_factory=time.time)


# `@unique` rejects two members sharing the same value at class-creation time.
@unique
class CloseReason(str, Enum):
    """Reason attached to a ``CloseEvent``.

    Members are also ``str`` instances, so each compares equal to its string value.
    """

    ERROR = "error"
    JOB_SHUTDOWN = "job_shutdown"
    PARTICIPANT_DISCONNECTED = "participant_disconnected"
    USER_INITIATED = "user_initiated"
    TASK_COMPLETED = "task_completed"


class CloseEvent(BaseModel):
    """Payload for the ``close`` event, carrying a ``CloseReason`` and an optional error."""

    # Discriminator value; fixed to "close".
    type: Literal["close"] = "close"
    # Optional error, restricted to the listed error models; defaults to None.
    # NOTE(review): presumably set when `reason` is CloseReason.ERROR — not enforced here.
    error: (
        LLMError | STTError | TTSError | RealtimeModelError | InterruptionDetectionError | None
    ) = None
    # Required reason for the close.
    reason: CloseReason
    # Defaults to `time.time()` evaluated when the model instance is created.
    created_at: float = Field(default_factory=time.time)


# Union of event models, annotated so pydantic selects the member by its `type` field
# (`Field(discriminator="type")`).
# NOTE(review): UserTurnExceededEvent is defined in this file but is not a member of this
# union — confirm that is intentional.
# NOTE(review): OverlappingSpeechEvent is imported; it must expose a literal `type` field for
# the discriminator to work — confirm against its definition.
AgentEvent = Annotated[
    UserInputTranscribedEvent
    | UserStateChangedEvent
    | AgentStateChangedEvent
    | AgentFalseInterruptionEvent
    | MetricsCollectedEvent
    | SessionUsageUpdatedEvent
    | ConversationItemAddedEvent
    | FunctionToolsExecutedEvent
    | SpeechCreatedEvent
    | ErrorEvent
    | CloseEvent
    | OverlappingSpeechEvent,
    Field(discriminator="type"),
]
