from __future__ import annotations

import asyncio
import time
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, Awaitable
from dataclasses import dataclass
from types import TracebackType
from typing import Generic, Literal, TypeVar

from pydantic import BaseModel, ConfigDict, Field

from livekit import rtc

from ..log import logger
from ..types import NOT_GIVEN, NotGivenOr
from ..utils import is_given
from .chat_context import ChatContext, ChatItem, FunctionCall
from .tool_context import Tool, ToolChoice, ToolContext


@dataclass
class InputSpeechStartedEvent:
    """Marker for the start of input speech (server-side VAD); carries no fields.

    Presumably the payload of the ``input_speech_started`` session event.
    """


@dataclass
class InputSpeechStoppedEvent:
    """Marker for the end of input speech (server-side VAD).

    Presumably the payload of the ``input_speech_stopped`` session event.
    """

    # flag reported alongside the stop; presumably whether user transcription is enabled
    user_transcription_enabled: bool


@dataclass
class MessageGeneration:
    """One assistant message produced by a realtime generation.

    Text and audio are delivered as separate asynchronous streams, and
    ``modalities`` is an awaitable resolving to a list drawn from
    ``"text"`` and ``"audio"``.
    """

    message_id: str
    text_stream: AsyncIterable[str]  # could be io.TimedString
    audio_stream: AsyncIterable[rtc.AudioFrame]
    modalities: Awaitable[list[Literal["text", "audio"]]]


@dataclass
class GenerationCreatedEvent:
    """A new generation started by the realtime model.

    The messages and function calls belonging to the generation are delivered
    through ``message_stream`` and ``function_stream`` respectively.
    """

    message_stream: AsyncIterable[MessageGeneration]
    function_stream: AsyncIterable[FunctionCall]
    user_initiated: bool
    """True if the generation was requested through generate_reply()"""
    response_id: str | None = None
    """Response ID associated with this generation, used for metrics attribution"""


class RealtimeModelError(BaseModel):
    # Error report for a realtime model, defined as a pydantic model. The raw
    # exception is kept on the instance but left out of serialized output.
    model_config = ConfigDict(arbitrary_types_allowed=True)  # lets `error` hold a plain Exception
    type: Literal["realtime_model_error"] = "realtime_model_error"  # constant tag naming this record
    timestamp: float  # presumably a time.time() value — confirm at construction sites
    label: str  # presumably the failing model's RealtimeModel.label — confirm
    error: Exception = Field(..., exclude=True)  # required; excluded from serialization
    recoverable: bool  # whether the error is reported as recoverable


@dataclass
class RealtimeCapabilities:
    """Feature flags a RealtimeModel reports for its backend.

    The first six flags must always be supplied. The ``mutable_*``,
    ``per_response_tool_choice`` and ``supports_say`` flags default to ``False``.
    """

    message_truncation: bool
    """Whether generated assistant messages can be truncated after interruption"""
    turn_detection: bool
    """Whether the model emits server-side speech start and stop events for turn taking"""
    user_transcription: bool
    """Whether the model emits user audio transcription events"""
    auto_tool_reply_generation: bool
    """Whether the model automatically generates a reply after receiving tool results"""
    audio_output: bool
    """Whether the model can produce audio output directly"""
    manual_function_calls: bool
    """Whether function call items already in the chat context can be resumed"""
    mutable_chat_context: bool = False
    """Whether the chat context can be updated mid-session"""
    mutable_instructions: bool = False
    """Whether the instructions can be updated mid-session"""
    mutable_tools: bool = False
    """Whether the tools can be updated mid-session"""
    per_response_tool_choice: bool = False
    """Whether the tool and tool choice can be specified per response"""
    supports_say: bool = False
    """Whether session.say() can use the realtime session directly, without TTS.

    When used through a RealtimeModel, add_to_chat_ctx=False is ignored and the
    message is still added to the chat context.
    """


class RealtimeError(Exception):
    """Failure of a realtime session operation, e.g. a timeout.

    Args:
        message: human-readable description of what went wrong.
    """

    def __init__(self, message: str) -> None:
        # forwarded as the only positional arg, so str(err) is the message itself
        super().__init__(message)


class RealtimeModel:
    """Base class for realtime (speech-to-speech) model integrations.

    A model advertises its feature set through ``capabilities`` and creates
    sessions via :meth:`session`. It can be used as an async context manager;
    leaving the ``async with`` block awaits :meth:`aclose`.

    NOTE(review): this class does not inherit from ``abc.ABC``, so the
    ``@abstractmethod`` markers below are not enforced when a subclass (or this
    class itself) is instantiated.
    """

    def __init__(self, *, capabilities: RealtimeCapabilities) -> None:
        cls = type(self)
        # label identifies the concrete subclass as "module.ClassName"
        self._label = f"{cls.__module__}.{cls.__name__}"
        self._capabilities = capabilities

    @property
    def model(self) -> str:
        """Model name; the base implementation reports ``"unknown"``."""
        return "unknown"

    @property
    def provider(self) -> str:
        """Provider name; the base implementation reports ``"unknown"``."""
        return "unknown"

    @property
    def capabilities(self) -> RealtimeCapabilities:
        """The capabilities supplied to the constructor."""
        return self._capabilities

    @property
    def label(self) -> str:
        """``module.ClassName`` of the concrete model class."""
        return self._label

    @abstractmethod
    def session(self) -> RealtimeSession:
        """Create a new session bound to this model."""

    @abstractmethod
    async def aclose(self) -> None:
        """Release whatever resources the model holds."""

    async def __aenter__(self) -> RealtimeModel:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # close unconditionally; returning None means exceptions are not suppressed
        await self.aclose()


# Names of the events a RealtimeSession can emit through its rtc.EventEmitter
# base. A subclass may accept additional, provider-specific names via TEvent.
EventTypes = Literal[
    "input_speech_started",  # serverside VAD (also used for interruptions)
    "input_speech_stopped",  # serverside VAD
    "input_audio_transcription_completed",  # presumably carries InputTranscriptionCompleted
    "generation_created",  # presumably carries GenerationCreatedEvent
    "session_reconnected",  # presumably carries RealtimeSessionReconnectedEvent
    "metrics_collected",  # e.g. RealtimeModelMetrics from _report_connection_acquired
    "remote_item_added",  # presumably carries RemoteItemAddedEvent
    "error",  # presumably carries RealtimeModelError — confirm in implementations
]

# Type parameter for extra event names a RealtimeSession subclass emits in
# addition to EventTypes (RealtimeSession is Generic[TEvent]).
TEvent = TypeVar("TEvent")


@dataclass
class InputTranscriptionCompleted:
    """Transcript of the user's input audio for one item."""

    item_id: str
    """id of the item the transcript belongs to"""
    transcript: str
    """transcript of the input audio"""
    is_final: bool
    """whether the transcript is final; presumably False for interim (partial) results"""
    confidence: float | None = None
    """confidence score of the transcript (0.0 to 1.0), derived from model logprobs"""


@dataclass
class RealtimeSessionReconnectedEvent:
    """Marker that the realtime session re-established its connection; no fields.

    Presumably the payload of the ``session_reconnected`` session event.
    """


@dataclass
class RemoteItemAddedEvent:
    """A chat item reported as added on the remote side of the session.

    ``previous_item_id`` presumably names the item it follows, or is ``None``
    when there is no predecessor.
    """

    previous_item_id: str | None
    item: ChatItem


class RealtimeSession(ABC, rtc.EventEmitter[EventTypes | TEvent], Generic[TEvent]):
    """A live session opened against a :class:`RealtimeModel`.

    Provider integrations subclass this and implement the abstract members.
    Events are published through the ``rtc.EventEmitter`` base. Their names come
    from ``EventTypes`` plus any provider-specific names given by ``TEvent``.
    """

    def __init__(self, realtime_model: RealtimeModel) -> None:
        super().__init__()
        self._realtime_model = realtime_model

    def _report_connection_acquired(self, acquire_time: float) -> None:
        """Emit ``metrics_collected`` with connection timing and zero token usage.

        Args:
            acquire_time: time spent acquiring the connection, as measured by the
                caller (presumably seconds — confirm at call sites).
        """
        # imported lazily, presumably to avoid an import cycle with the metrics package
        from ..metrics.base import Metadata, RealtimeModelMetrics

        model = self._realtime_model
        timing_only = RealtimeModelMetrics(
            request_id="",
            timestamp=time.time(),
            acquire_time=acquire_time,
            connection_reused=False,
            input_token_details=RealtimeModelMetrics.InputTokenDetails(),
            output_token_details=RealtimeModelMetrics.OutputTokenDetails(),
            metadata=Metadata(model_name=model.model, model_provider=model.provider),
        )
        self.emit("metrics_collected", timing_only)

    @property
    def realtime_model(self) -> RealtimeModel:
        """The model passed to the constructor."""
        return self._realtime_model

    @property
    @abstractmethod
    def chat_ctx(self) -> ChatContext:
        """The session's current chat context."""

    @property
    @abstractmethod
    def tools(self) -> ToolContext:
        """The tools currently available to the session."""

    @abstractmethod
    async def update_instructions(self, instructions: str) -> None:
        """Update the session's instructions."""

    @abstractmethod
    async def update_chat_ctx(self, chat_ctx: ChatContext) -> None:
        """Update the session's chat context.

        Can raise :class:`RealtimeError` on timeout.
        """

    @abstractmethod
    async def update_tools(self, tools: list[Tool]) -> None:
        """Update the tools available to the session."""

    @abstractmethod
    def update_options(self, *, tool_choice: NotGivenOr[ToolChoice | None] = NOT_GIVEN) -> None:
        """Update session options (currently only ``tool_choice``)."""

    @abstractmethod
    def push_audio(self, frame: rtc.AudioFrame) -> None:
        """Send one input audio frame to the model."""

    @abstractmethod
    def push_video(self, frame: rtc.VideoFrame) -> None:
        """Send one input video frame to the model."""

    @abstractmethod
    def generate_reply(
        self,
        *,
        instructions: NotGivenOr[str] = NOT_GIVEN,
        tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
        tools: NotGivenOr[list[Tool]] = NOT_GIVEN,
    ) -> asyncio.Future[GenerationCreatedEvent]:
        """Ask the model for a reply.

        Returns a future resolving to the ``GenerationCreatedEvent`` of the new
        generation. Can fail with :class:`RealtimeError` on timeout.
        """

    @abstractmethod
    def commit_audio(self) -> None:
        """Commit the input audio buffer to the server."""

    @abstractmethod
    def clear_audio(self) -> None:
        """Clear the input audio buffer on the server."""

    @abstractmethod
    def interrupt(self) -> None:
        """Cancel the current generation; does nothing if none is in progress."""

    @abstractmethod
    def truncate(
        self,
        *,
        message_id: str,
        modalities: list[Literal["text", "audio"]],
        audio_end_ms: int,
        audio_transcript: NotGivenOr[str] = NOT_GIVEN,
    ) -> None:
        """Truncate a generated assistant message.

        Args:
            message_id: ID of the message to truncate, as found in the chat context.
            modalities: modalities of that message.
            audio_end_ms: position in the audio, in milliseconds, at which to cut
                (presumably how much was actually played — confirm).
            audio_transcript: presumably the transcript of the audio that is kept.
        """

    @abstractmethod
    async def aclose(self) -> None:
        """Close the session."""

    async def _update_session(
        self,
        *,
        instructions: NotGivenOr[str] = NOT_GIVEN,
        chat_ctx: NotGivenOr[ChatContext] = NOT_GIVEN,
        tools: NotGivenOr[list[Tool]] = NOT_GIVEN,
    ) -> None:
        """Apply whichever of instructions, chat context and tools were given.

        The updates run one after another in that order. A ``RealtimeError`` from
        any step is logged and the remaining steps still run; any other exception
        propagates to the caller.
        """
        if is_given(instructions):
            try:
                await self.update_instructions(instructions)
            except RealtimeError:
                logger.exception("failed to update the instructions")

        if is_given(chat_ctx):
            try:
                await self.update_chat_ctx(chat_ctx)
            except RealtimeError:
                logger.exception("failed to update the chat_ctx")

        if is_given(tools):
            try:
                await self.update_tools(tools)
            except RealtimeError:
                logger.exception("failed to update the tools")

    def start_user_activity(self) -> None:
        """Notify the model that user activity has started (no-op by default)."""

    def say(
        self,
        text: str | AsyncIterable[str],
    ) -> asyncio.Future[GenerationCreatedEvent]:
        """Speak ``text`` directly through the realtime session.

        The base implementation always raises ``NotImplementedError``. Sessions
        whose model reports ``capabilities.supports_say`` are expected to
        override it.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement say(). use a TTS model instead"
        )
