from __future__ import annotations

import asyncio
import contextlib
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING

from livekit import api, rtc

from ... import llm, stt, tts, utils, vad
from ...job import DEFAULT_PARTICIPANT_KINDS, get_job_context
from ...llm.chat_context import Instructions
from ...llm.tool_context import ToolError, ToolFlag, function_tool
from ...log import logger
from ...types import NOT_GIVEN, NotGivenOr
from ...utils import is_given
from ...voice import room_io
from ...voice.agent import Agent, AgentTask
from ...voice.agent_session import AgentSession
from ...voice.background_audio import (
    AudioConfig,
    AudioSource,
    BackgroundAudioPlayer,
    BuiltinAudioClip,
    PlayHandle,
)
from .utils import InstructionParts

if TYPE_CHECKING:
    from ...voice.turn import TurnDetectionMode


@dataclass
class WarmTransferResult:
    """Successful outcome of a ``WarmTransferTask``."""

    # participant identity of the human agent that was moved into the caller's room
    human_agent_identity: str


class WarmTransferTask(AgentTask[WarmTransferResult]):
    def __init__(
        self,
        sip_call_to: NotGivenOr[str] = NOT_GIVEN,
        *,
        sip_trunk_id: NotGivenOr[str | None] = NOT_GIVEN,
        sip_connection: NotGivenOr[api.SIPOutboundConfig] = NOT_GIVEN,
        sip_number: NotGivenOr[str] = NOT_GIVEN,
        sip_headers: NotGivenOr[dict[str, str]] = NOT_GIVEN,
        dtmf: NotGivenOr[str | None] = NOT_GIVEN,
        ringing_timeout: NotGivenOr[float | None] = NOT_GIVEN,
        hold_audio: NotGivenOr[AudioSource | AudioConfig | list[AudioConfig] | None] = NOT_GIVEN,
        instructions: NotGivenOr[InstructionParts | Instructions | str] = NOT_GIVEN,
        chat_ctx: NotGivenOr[llm.ChatContext] = NOT_GIVEN,
        turn_detection: NotGivenOr[TurnDetectionMode | None] = NOT_GIVEN,
        tools: NotGivenOr[list[llm.Tool | llm.Toolset]] = NOT_GIVEN,
        stt: NotGivenOr[stt.STT | None] = NOT_GIVEN,
        vad: NotGivenOr[vad.VAD | None] = NOT_GIVEN,
        llm: NotGivenOr[llm.LLM | llm.RealtimeModel | None] = NOT_GIVEN,
        tts: NotGivenOr[tts.TTS | None] = NOT_GIVEN,
        allow_interruptions: NotGivenOr[bool] = NOT_GIVEN,
        # deprecated
        extra_instructions: str = "",
        target_phone_number: NotGivenOr[str] = NOT_GIVEN,
    ) -> None:
        """Initialize a WarmTransferTask to dial a human agent via SIP.

        Args:
            sip_call_to: The phone number or SIP URI to dial for the human agent
                (e.g. ``"+15105550123"`` or ``"sip:user@example.com"``).
            sip_trunk_id: ID of a pre-configured LiveKit SIP outbound trunk used to
                originate the call. Falls back to the ``LIVEKIT_SIP_OUTBOUND_TRUNK``
                environment variable when not provided.
            sip_connection: Low-level SIP connection config (``api.SIPOutboundConfig``)
                for originating calls from a **custom SIP domain** instead of through a
                saved trunk. Use this when you need to specify a custom hostname,
                transport, or authentication credentials directly, bypassing the
                trunk-based configuration.
            dtmf: DTMF tones to send once the human agent's call is answered, e.g. to dial
                an extension or navigate an IVR menu (``"1234#"``). Insert ``w`` characters
                to pause ~0.5s each before/between digits (``"wwww1234#"`` waits ~2s, useful
                when the destination plays a greeting before accepting input).
            ringing_timeout: How long to wait, in seconds, for the human agent to answer
                before giving up on the call. When the timeout elapses the task completes
                with a ``ToolError`` and the caller conversation resumes.
            hold_audio: Audio played to the caller while they are on hold during the
                    transfer.
            extra_instructions: Extra instructions to append to the base instructions
                that are used to summarize the conversation history.
        """

        if not is_given(instructions):
            instructions = InstructionParts(persona=PERSONA, extra=extra_instructions)
        elif extra_instructions:
            logger.warning("`extra_instructions` will be ignored when `instructions` is provided")

        if isinstance(instructions, InstructionParts):
            conversation_history = self._format_conversation_history(chat_ctx)
            instructions = Instructions(INSTRUCTIONS_TEMPLATE).format(
                persona=instructions.persona if is_given(instructions.persona) else PERSONA,
                extra=instructions.extra,
                _conversation_history=conversation_history,
            )

        assert is_given(instructions)  # for type checking
        super().__init__(
            instructions=instructions,
            chat_ctx=NOT_GIVEN,  # don't pass the chat_ctx
            turn_detection=turn_detection,
            tools=tools or [],
            stt=stt,
            vad=vad,
            llm=llm,
            tts=tts,
            allow_interruptions=allow_interruptions,
        )

        self._caller_room: rtc.Room | None = None
        self._human_agent_sess: AgentSession | None = None
        self._human_agent_failed_fut: asyncio.Future[None] = asyncio.Future()
        self._human_agent_identity = "human-agent-sip"

        if target_phone_number:
            logger.warning("`target_phone_number` is deprecated, use `sip_call_to` instead")
            if not sip_call_to:
                sip_call_to = target_phone_number

        if not sip_call_to:
            raise ValueError("`sip_call_to` must be set")

        self._sip_call_to = sip_call_to
        self._sip_connection = sip_connection if is_given(sip_connection) else None
        if is_given(sip_trunk_id):
            self._sip_trunk_id = sip_trunk_id
        elif self._sip_connection is not None:
            # explicit sip_connection: don't override with the env var trunk
            self._sip_trunk_id = None
        else:
            self._sip_trunk_id = os.getenv("LIVEKIT_SIP_OUTBOUND_TRUNK", None)
        if self._sip_trunk_id is None and self._sip_connection is None:
            raise ValueError(
                "`LIVEKIT_SIP_OUTBOUND_TRUNK` environment variable, `sip_trunk_id`,"
                " or `sip_connection` must be set"
            )

        self._sip_number = (
            sip_number if is_given(sip_number) else os.getenv("LIVEKIT_SIP_NUMBER", "")
        )
        self._sip_headers = sip_headers if is_given(sip_headers) else {}
        self._dtmf = dtmf if is_given(dtmf) else None
        self._ringing_timeout = ringing_timeout if is_given(ringing_timeout) else None

        # background audio and io
        self._background_audio = BackgroundAudioPlayer()
        self._hold_audio_handle: PlayHandle | None = None
        self._hold_audio = (
            hold_audio
            if is_given(hold_audio)
            else AudioConfig(BuiltinAudioClip.HOLD_MUSIC, volume=0.8)
        )

        self._original_io_state: dict[str, bool] = {}

    @staticmethod
    def _format_conversation_history(chat_ctx: NotGivenOr[llm.ChatContext]) -> str:
        """Render the caller conversation as ``"<Speaker>: <text>"`` lines.

        Only ``user`` (labelled ``Caller``) and ``assistant`` (labelled ``Assistant``)
        messages with non-empty text are included; each line ends with a newline.

        Args:
            chat_ctx: The conversation to render, if any.

        Returns:
            The rendered transcript, or an empty string when ``chat_ctx`` is not
            given or falsy.
        """
        if not is_given(chat_ctx):
            return ""
        if not chat_ctx:
            return ""

        speaker_labels = {"user": "Caller", "assistant": "Assistant"}
        return "".join(
            f"{speaker_labels[msg.role]}: {msg.text_content}\n"
            for msg in chat_ctx.messages()
            if msg.role in speaker_labels and msg.text_content
        )

    async def on_enter(self) -> None:
        job_ctx = get_job_context()
        self._caller_room = job_ctx.room

        # start the background audio
        if self._hold_audio is not None:
            await self._background_audio.start(room=self._caller_room)
            self._hold_audio_handle = self._background_audio.play(self._hold_audio, loop=True)

        self._set_io_enabled(False)

        try:
            dial_human_agent_task = asyncio.create_task(self._dial_human_agent())
            done, _ = await asyncio.wait(
                (dial_human_agent_task, self._human_agent_failed_fut),
                return_when=asyncio.FIRST_COMPLETED,
            )
            if dial_human_agent_task not in done:
                raise RuntimeError()

            self._human_agent_sess = dial_human_agent_task.result()
            # let the human speak first

        except Exception:
            logger.exception("could not dial human agent")
            self._set_result(ToolError("could not dial human agent"))
            return

        finally:
            await utils.aio.cancel_and_wait(dial_human_agent_task)

    @function_tool(flags=ToolFlag.IGNORE_ON_ENTER)
    async def connect_to_caller(self) -> None:
        """Called when the human agent wants to connect to the caller."""
        logger.debug("connecting to caller")
        assert self._caller_room is not None

        await self._merge_calls()
        self._set_result(WarmTransferResult(human_agent_identity=self._human_agent_identity))

        # when the caller or human agent leaves the room, we'll delete the room
        self._caller_room.on("participant_disconnected", self._on_caller_participant_disconnected)

    @function_tool(flags=ToolFlag.IGNORE_ON_ENTER)
    async def decline_transfer(self, reason: str) -> None:
        """Handles the case when the human agent explicitly declines to connect to the caller.

        Args:
            reason: A short explanation of why the human agent declined to connect to the caller
        """
        self._set_result(ToolError(f"human agent declined to connect: {reason}"))

    @function_tool(flags=ToolFlag.IGNORE_ON_ENTER)
    async def voicemail_detected(self) -> None:
        """Called when the call reaches voicemail. Use this tool AFTER you hear the voicemail greeting"""
        self._set_result(ToolError("voicemail detected"))

    def _on_human_agent_room_close(self, reason: rtc.DisconnectReason.ValueType) -> None:
        logger.debug(
            "human agent's room closed",
            extra={"reason": rtc.DisconnectReason.Name(reason)},
        )
        with contextlib.suppress(asyncio.InvalidStateError):
            self._human_agent_failed_fut.set_result(None)

        self._set_result(ToolError(f"room closed: {rtc.DisconnectReason.Name(reason)}"))

    def _on_caller_participant_disconnected(self, participant: rtc.RemoteParticipant) -> None:
        if participant.kind not in DEFAULT_PARTICIPANT_KINDS:
            return

        logger.info(f"participant disconnected from caller room: {participant.identity}, closing")

        assert self._caller_room is not None
        self._caller_room.off("participant_disconnected", self._on_caller_participant_disconnected)
        job_ctx = get_job_context()
        job_ctx.delete_room(room_name=self._caller_room.name)

    def _set_result(self, result: WarmTransferResult | Exception) -> None:
        if self.done():
            return

        if self._human_agent_sess:
            self._human_agent_sess.shutdown()
            self._human_agent_sess = None

        if self._hold_audio_handle:
            self._hold_audio_handle.stop()
            self._hold_audio_handle = None

        self._set_io_enabled(True)
        self.complete(result)

    async def _dial_human_agent(self) -> AgentSession:
        assert self._caller_room is not None

        job_ctx = get_job_context()
        ws_url = job_ctx._info.url

        # create a new room for the human agent
        human_agent_room_name = self._caller_room.name + "-human-agent"
        room = rtc.Room()
        token = (
            api.AccessToken()
            .with_identity(self._caller_room.local_participant.identity)
            .with_grants(
                api.VideoGrants(
                    room_join=True,
                    room=human_agent_room_name,
                    can_update_own_metadata=True,
                    can_publish=True,
                    can_subscribe=True,
                )
            )
            .with_kind("agent")
        ).to_jwt()

        logger.debug(
            "connecting to human agent room",
            extra={"ws_url": ws_url, "human_agent_room_name": human_agent_room_name},
        )
        await room.connect(ws_url, token)

        # if human agent hung up for whatever reason, we'd resume the caller conversation
        room.on("disconnected", self._on_human_agent_room_close)

        human_agent_sess: AgentSession = AgentSession(
            vad=self.session.vad or NOT_GIVEN,
            llm=self.session.llm or NOT_GIVEN,
            stt=self.session.stt or NOT_GIVEN,
            tts=self.session.tts or NOT_GIVEN,
            turn_detection=self.session.turn_detection or NOT_GIVEN,
        )
        # create a copy of this AgentTask
        human_agent_agent = Agent(
            instructions=self.instructions,
            turn_detection=self.turn_detection,
            stt=self.stt,
            vad=self.vad,
            llm=self.llm,
            tts=self.tts,
            tools=self.tools,
            chat_ctx=self.chat_ctx,
            allow_interruptions=self.allow_interruptions,
        )
        await human_agent_sess.start(
            agent=human_agent_agent,
            room=room,
            room_options=room_io.RoomOptions(
                close_on_disconnect=True,
                delete_room_on_close=True,
                participant_identity=self._human_agent_identity,
            ),
            record=False,  # TODO: support recording on multiple sessions?
        )

        # dial the human agent
        sip_request = api.CreateSIPParticipantRequest(
            sip_trunk_id=self._sip_trunk_id,
            sip_call_to=self._sip_call_to,
            room_name=human_agent_room_name,
            participant_identity=self._human_agent_identity,
            wait_until_answered=True,
            sip_number=self._sip_number or None,
            headers=self._sip_headers,
            dtmf=self._dtmf or "",
        )
        if self._ringing_timeout is not None:
            sip_request.ringing_timeout.FromNanoseconds(int(self._ringing_timeout * 1e9))
        if self._sip_connection is not None:
            sip_request.trunk.CopyFrom(self._sip_connection)
        await job_ctx.api.sip.create_sip_participant(sip_request)

        return human_agent_sess

    async def _merge_calls(self) -> None:
        assert self._caller_room is not None
        assert self._human_agent_sess is not None

        job_ctx = get_job_context()
        human_agent_room = self._human_agent_sess.room_io.room
        # we no longer care about the human agent session. it's supposed to be over
        human_agent_room.off("disconnected", self._on_human_agent_room_close)

        logger.debug(f"moving {self._human_agent_identity} to caller room {self._caller_room.name}")
        await job_ctx.api.room.move_participant(
            api.MoveParticipantRequest(
                room=human_agent_room.name,
                identity=self._human_agent_identity,
                destination_room=self._caller_room.name,
            )
        )

    def _set_io_enabled(self, enabled: bool) -> None:
        input = self.session.input
        output = self.session.output

        if not self._original_io_state:
            self._original_io_state = {
                "audio_input": input.audio_enabled,
                "video_input": input.video_enabled,
                "audio_output": output.audio_enabled,
                "transcription_output": output.transcription_enabled,
                "video_output": output.video_enabled,
            }

        if input.audio:
            input.set_audio_enabled(enabled and self._original_io_state["audio_input"])
        if input.video:
            input.set_video_enabled(enabled and self._original_io_state["video_input"])
        if output.audio:
            output.set_audio_enabled(enabled and self._original_io_state["audio_output"])
        if output.transcription:
            output.set_transcription_enabled(
                enabled and self._original_io_state["transcription_output"]
            )
        if output.video:
            output.set_video_enabled(enabled and self._original_io_state["video_output"])


# instructions

# Default persona section of the prompt; used by `WarmTransferTask.__init__` when no
# instructions are supplied or when the supplied `InstructionParts` has no persona.
PERSONA = """\
# Identity

You are an agent that is reaching out to a human agent for help. There has been a previous conversation
between you and a caller, the conversation history is included below.

# Goal

Your main goal is to give the human agent sufficient context about why the caller had called in,
so that the human agent could gain sufficient knowledge to help the caller directly."""

# Prompt skeleton rendered by `WarmTransferTask.__init__`. Placeholders:
#   {persona}               - persona section (defaults to `PERSONA`)
#   {_conversation_history} - output of `WarmTransferTask._format_conversation_history`
#   {extra}                 - caller-supplied extra instructions (may be empty)
# NOTE(review): "who's transcript" below should read "whose transcript"; left unchanged here
# because this string is runtime prompt text.
INSTRUCTIONS_TEMPLATE = """\
{persona}

# Context

In the conversation, user refers to the human agent, caller refers to the person who's transcript is included.
Remember, you are not speaking to the caller right now, you are speaking to the human agent.

## Conversation history with caller
{_conversation_history}
## End of conversation history with caller

Once the human agent has confirmed, you should call the tool `connect_to_caller` to connect them to the caller.

You are talking to the human agent now, start by giving them a summary of the conversation so far, and answer any questions they might have.

{extra}
"""
