from __future__ import annotations

import asyncio
import contextlib
import time
from collections.abc import AsyncGenerator, AsyncIterable, Coroutine, Generator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar

from livekit import rtc

from .. import inference, llm, stt, tokenize, tts, utils, vad
from ..llm import ChatContext, RealtimeModel, ToolError, find_function_tools
from ..llm.chat_context import Instructions, _ReadOnlyChatContext
from ..log import logger
from ..types import NOT_GIVEN, FlushSentinel, NotGivenOr
from ..utils import is_given, misc
from .events import UserTurnExceededEvent
from .speech_handle import SpeechHandle
from .turn import TurnHandlingOptions, _migrate_turn_handling

if TYPE_CHECKING:
    from ..inference import LLMModels, STTModels, TTSModels
    from ..llm import mcp
    from .agent_activity import AgentActivity
    from .agent_session import AgentSession
    from .io import TimedString
    from .turn import TurnDetectionMode


@dataclass
class ModelSettings:
    """Settings handed to the pipeline nodes (``stt_node``, ``llm_node``, ``tts_node``, ...).

    Every field defaults to ``NOT_GIVEN`` so that nodes can tell "not specified" apart
    from an explicit value.
    """

    tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN
    """Tool choice forwarded to the LLM call; stays ``NOT_GIVEN`` when unspecified."""


class Agent:
    """An agent: instructions, tools, optional pipeline components and overridable pipeline nodes.

    Components and options that are left ``NOT_GIVEN`` at construction are taken from the
    ``AgentSession`` at runtime (see the individual properties below).
    """

    def __init__(
        self,
        *,
        instructions: str | Instructions,
        id: str | None = None,
        chat_ctx: NotGivenOr[llm.ChatContext | None] = NOT_GIVEN,
        tools: list[llm.Tool | llm.Toolset] | None = None,
        stt: NotGivenOr[stt.STT | STTModels | str | None] = NOT_GIVEN,
        vad: NotGivenOr[vad.VAD | None] = NOT_GIVEN,
        turn_handling: NotGivenOr[TurnHandlingOptions] = NOT_GIVEN,
        llm: NotGivenOr[llm.LLM | llm.RealtimeModel | LLMModels | str | None] = NOT_GIVEN,
        tts: NotGivenOr[tts.TTS | TTSModels | str | None] = NOT_GIVEN,
        min_consecutive_speech_delay: NotGivenOr[float] = NOT_GIVEN,
        use_tts_aligned_transcript: NotGivenOr[bool] = NOT_GIVEN,
        # deprecated
        turn_detection: NotGivenOr[TurnDetectionMode | None] = NOT_GIVEN,
        min_endpointing_delay: NotGivenOr[float] = NOT_GIVEN,
        max_endpointing_delay: NotGivenOr[float] = NOT_GIVEN,
        allow_interruptions: NotGivenOr[bool] = NOT_GIVEN,
        mcp_servers: NotGivenOr[list[mcp.MCPServer] | None] = NOT_GIVEN,
    ) -> None:
        """Create an agent.

        Args:
            instructions: Core instructions that guide the agent's behavior.
            id: Identifier of the agent. When omitted, a plain ``Agent`` uses
                ``"default_agent"`` and a subclass uses its class name in snake_case.
            chat_ctx: Initial chat context; it is copied with ``tools=`` set to this
                agent's tools. An empty context is used when not provided.
            tools: Additional tools/toolsets. Function tools found on the agent itself
                (``find_function_tools(self)``) are appended after them.
            stt, llm, tts: Components, or model strings resolved via ``inference``.
            vad: Voice activity detector.
            turn_handling: Turn handling options. When not given, they are built from the
                deprecated ``turn_detection``, ``min_endpointing_delay``,
                ``max_endpointing_delay`` and ``allow_interruptions`` arguments.
            min_consecutive_speech_delay: See the ``min_consecutive_speech_delay`` property.
            use_tts_aligned_transcript: See the ``use_tts_aligned_transcript`` property.
            mcp_servers: Deprecated, use ``MCPToolset`` instead. An empty list is treated
                as ``None``.
        """
        tools = tools or []
        if id:
            # an explicit id always wins, including for a plain ``Agent``
            self._id = id
        elif type(self) is Agent:
            self._id = "default_agent"
        else:
            self._id = misc.camel_to_snake_case(type(self).__name__)

        turn_handling = (
            _migrate_turn_handling(
                min_endpointing_delay=min_endpointing_delay,
                max_endpointing_delay=max_endpointing_delay,
                turn_detection=turn_detection,
                allow_interruptions=allow_interruptions,
            )
            if not is_given(turn_handling)
            else turn_handling
        )

        self._instructions = instructions
        self._tools = [*tools, *find_function_tools(self)]
        self._chat_ctx = chat_ctx.copy(tools=self._tools) if chat_ctx else ChatContext.empty()
        self._turn_detection = turn_handling.get("turn_detection", NOT_GIVEN)

        if isinstance(stt, str):
            stt = inference.STT.from_model_string(stt)

        if isinstance(llm, str):
            llm = inference.LLM.from_model_string(llm)

        if isinstance(tts, str):
            tts = inference.TTS.from_model_string(tts)

        self._stt = stt
        self._llm = llm
        self._tts = tts
        self._vad = vad

        self._allow_interruptions: NotGivenOr[bool] = NOT_GIVEN
        self._interruption_detection: NotGivenOr[Literal["adaptive", "vad"]] = NOT_GIVEN
        if is_given(raw_interruption := turn_handling.get("interruption", NOT_GIVEN)):
            if "enabled" in raw_interruption:
                self._allow_interruptions = raw_interruption["enabled"]
            if "mode" in raw_interruption:
                self._interruption_detection = raw_interruption["mode"]

        # guard like the interruption options above: a missing/unset entry means "no overrides"
        endpointing = turn_handling.get("endpointing", NOT_GIVEN)
        if not is_given(endpointing) or endpointing is None:
            endpointing = {}
        self._min_consecutive_speech_delay = min_consecutive_speech_delay
        self._use_tts_aligned_transcript = use_tts_aligned_transcript
        self._min_endpointing_delay = endpointing.get("min_delay", NOT_GIVEN)
        self._max_endpointing_delay = endpointing.get("max_delay", NOT_GIVEN)
        self._turn_handling = turn_handling

        if isinstance(mcp_servers, list) and len(mcp_servers) == 0:
            mcp_servers = None  # treat empty list as None (but keep NOT_GIVEN)

        self._mcp_servers = mcp_servers
        if self._mcp_servers:
            logger.warning(
                "passing MCP servers to AgentSession or Agent is deprecated "
                "and will be removed in a future version. Use `MCPToolset` instead."
            )
        self._activity: AgentActivity | None = None

    @property
    def id(self) -> str:
        """Identifier of the agent (explicit ``id``, ``"default_agent"``, or snake_case class name)."""
        return self._id

    @property
    def label(self) -> str:
        """Label of the agent; by default identical to ``id``."""
        return self.id

    @property
    def instructions(self) -> str | Instructions:
        """
        Returns:
            str | Instructions: The core instructions that guide the agent's behavior.
        """
        return self._instructions

    @property
    def tools(self) -> list[llm.Tool | llm.Toolset]:
        """
        Returns:
            list[llm.Tool | llm.Toolset]:
                A copy of the list of tools available to the agent.
        """
        return self._tools.copy()

    @property
    def chat_ctx(self) -> llm.ChatContext:
        """
        Provides a read-only view of the agent's current chat context.

        Returns:
            llm.ChatContext: A read-only version of the agent's conversation history.

        See Also:
            update_chat_ctx: Method to update the internal chat context.
        """
        return _ReadOnlyChatContext(self._chat_ctx.items)

    @property
    def interruption_detection(self) -> NotGivenOr[Literal["adaptive", "vad"]]:
        """Interruption detection mode taken from ``turn_handling["interruption"]["mode"]``.

        Returns:
            NotGivenOr[Literal["adaptive", "vad"]]: The configured mode, or ``NOT_GIVEN``
            when no mode was configured on this agent.
        """
        return self._interruption_detection

    async def update_instructions(self, instructions: str) -> None:
        """
        Updates the agent's instructions.

        If the agent is running in realtime mode, this method also updates
        the instructions for the ongoing realtime session.

        Args:
            instructions (str):
                The new instructions to set for the agent.

        Raises:
            llm.RealtimeError: If updating the realtime session instructions fails.
        """
        if self._activity is None:
            self._instructions = instructions
            return

        await self._activity.update_instructions(instructions)

    async def update_tools(self, tools: list[llm.Tool | llm.Toolset]) -> None:
        """
        Updates the agent's available function tools.

        If the agent is running in realtime mode, this method also updates
        the tools for the ongoing realtime session.

        Args:
            tools (list[llm.Tool | llm.Toolset]):
                The new list of function tools available to the agent.

        Raises:
            TypeError: If an entry is neither a tool/toolset nor a resolvable wrapped tool.
            llm.RealtimeError: If updating the realtime session tools fails.
        """
        valid_tools: list[llm.Tool | llm.Toolset] = []
        for tool in tools:
            if isinstance(tool, (llm.Tool, llm.Toolset)):
                valid_tools.append(tool)
            elif resolved_tool := llm.tool_context._resolve_wrapped_tool(tool):
                valid_tools.append(resolved_tool)
            else:
                raise TypeError(f"Invalid tool type: {type(tool)}. Expected Tool or ToolSet.")

        tools = valid_tools
        if self._activity is None:
            # de-duplicate by tool id (last one wins), keeping first-seen order
            self._tools = list({tool.id: tool for tool in tools}.values())
            self._chat_ctx = self._chat_ctx.copy(tools=self._tools)
            return

        await self._activity.update_tools(tools)

    async def update_chat_ctx(
        self, chat_ctx: llm.ChatContext, *, exclude_invalid_function_calls: bool = True
    ) -> None:
        """
        Updates the agent's chat context.

        If the agent is running in realtime mode, this method also updates
        the chat context for the ongoing realtime session.

        Args:
            chat_ctx (llm.ChatContext):
                The new or updated chat context for the agent.
            exclude_invalid_function_calls (bool): Whether to exclude function calls
                and outputs not from the agent's tools.

        Raises:
            llm.RealtimeError: If updating the realtime session chat context fails.
        """
        if self._activity is None:
            self._chat_ctx = chat_ctx.copy(
                tools=self._tools if exclude_invalid_function_calls else NOT_GIVEN
            )
            return

        await self._activity.update_chat_ctx(
            chat_ctx, exclude_invalid_function_calls=exclude_invalid_function_calls
        )

    # -- Pipeline nodes --
    # They can all be overridden by subclasses, by default they use the STT/LLM/TTS specified in
    # the constructor of the Agent

    async def on_enter(self) -> None:
        """Called when the task is entered"""
        pass

    async def on_exit(self) -> None:
        """Called when the task is exited"""
        pass

    async def on_user_turn_completed(
        self, turn_ctx: llm.ChatContext, new_message: llm.ChatMessage
    ) -> None:
        """Called when the user has finished speaking, and the LLM is about to respond

        This is a good opportunity to update the chat context or edit the new message before it is
        sent to the LLM.
        """
        pass

    async def on_user_turn_exceeded(self, ev: UserTurnExceededEvent) -> None:
        """Called when the user turn has exceeded the configured limit.

        The user has been speaking for too long without the agent successfully
        responding. By default, generates a reply using the current turn's
        transcript (previous turns are already in the chat context).

        Override to customize (e.g., use session.say() with a canned message,
        or skip the interruption entirely).
        """
        await self.session.generate_reply(
            user_input=ev.transcript,
            instructions=(
                "The user has been speaking too long without giving a chance to reply. "
                "Politely cut in with a short reply or notice. Keep it short since the user cannot interrupt it."
            ),
            allow_interruptions=False,
            tool_choice="none",
        )

    def stt_node(
        self, audio: AsyncIterable[rtc.AudioFrame], model_settings: ModelSettings
    ) -> (
        AsyncIterable[stt.SpeechEvent | str]
        | Coroutine[Any, Any, AsyncIterable[stt.SpeechEvent | str]]
        | Coroutine[Any, Any, None]
    ):
        """
        A node in the processing pipeline that transcribes audio frames into speech events.

        By default, this node uses a Speech-To-Text (STT) capability from the current agent.
        If the STT implementation does not support streaming natively, a VAD (Voice Activity
        Detection) mechanism is required to wrap the STT.

        You can override this node with your own implementation for more flexibility (e.g.,
        custom pre-processing of audio, additional buffering, or alternative STT strategies).

        Args:
            audio (AsyncIterable[rtc.AudioFrame]): An asynchronous stream of audio frames.
            model_settings (ModelSettings): Configuration and parameters for model execution.

        Yields:
            stt.SpeechEvent: An event containing transcribed text or other STT-related data.
        """
        return Agent.default.stt_node(self, audio, model_settings)

    def llm_node(
        self,
        chat_ctx: llm.ChatContext,
        tools: list[llm.Tool],
        model_settings: ModelSettings,
    ) -> (
        AsyncIterable[llm.ChatChunk | str | FlushSentinel]
        | Coroutine[Any, Any, AsyncIterable[llm.ChatChunk | str | FlushSentinel]]
        | Coroutine[Any, Any, str]
        | Coroutine[Any, Any, llm.ChatChunk]
        | Coroutine[Any, Any, None]
    ):
        """
        A node in the processing pipeline that processes text generation with an LLM.

        By default, this node uses the agent's LLM to process the provided context. It may yield
        plain text (as `str`) for straightforward text generation, or `llm.ChatChunk` objects that
        can include text and optional tool calls. `ChatChunk` is helpful for capturing more complex
        outputs such as function calls, usage statistics, or other metadata.

        You can override this node to customize how the LLM is used or how tool invocations
        and responses are handled.

        Args:
            chat_ctx (llm.ChatContext): The context for the LLM (the conversation history).
            tools (list[FunctionTool]): A list of callable tools that the LLM may invoke.
            model_settings (ModelSettings): Configuration and parameters for model execution.

        Yields/Returns:
            str: Plain text output from the LLM.
            llm.ChatChunk: An object that can contain both text and optional tool calls.
        """
        return Agent.default.llm_node(self, chat_ctx, tools, model_settings)

    def transcription_node(
        self, text: AsyncIterable[str | TimedString], model_settings: ModelSettings
    ) -> (
        AsyncIterable[str | TimedString]
        | Coroutine[Any, Any, AsyncIterable[str | TimedString]]
        | Coroutine[Any, Any, None]
    ):
        """
        A node in the processing pipeline that finalizes transcriptions from text segments.

        This node can be used to adjust or post-process text coming from an LLM (or any other
        source) into a final transcribed form. For instance, you might clean up formatting, fix
        punctuation, or perform any other text transformations here.

        You can override this node to customize post-processing logic according to your needs.

        Args:
            text (AsyncIterable[str | TimedString]): An asynchronous stream of text segments.
            model_settings (ModelSettings): Configuration and parameters for model execution.

        Yields:
            str: Finalized or post-processed text segments.
        """
        return Agent.default.transcription_node(self, text, model_settings)

    def tts_node(
        self, text: AsyncIterable[str], model_settings: ModelSettings
    ) -> (
        AsyncIterable[rtc.AudioFrame]
        | Coroutine[Any, Any, AsyncIterable[rtc.AudioFrame]]
        | Coroutine[Any, Any, None]
    ):
        """
        A node in the processing pipeline that synthesizes audio from text segments.

        By default, this node converts incoming text into audio frames using the Text-To-Speech
        from the agent.
        If the TTS implementation does not support streaming natively, it uses a sentence tokenizer
        to split text for incremental synthesis.

        You can override this node to provide different text chunking behavior, a custom TTS engine,
        or any other specialized processing.

        Args:
            text (AsyncIterable[str]): An asynchronous stream of text segments to be synthesized.
            model_settings (ModelSettings): Configuration and parameters for model execution.

        Yields:
            rtc.AudioFrame: Audio frames synthesized from the provided text.
        """
        return Agent.default.tts_node(self, text, model_settings)

    def realtime_audio_output_node(
        self, audio: AsyncIterable[rtc.AudioFrame], model_settings: ModelSettings
    ) -> (
        AsyncIterable[rtc.AudioFrame]
        | Coroutine[Any, Any, AsyncIterable[rtc.AudioFrame]]
        | Coroutine[Any, Any, None]
    ):
        """A node processing the audio from the realtime LLM session before it is played out.

        By default the frames are passed through unchanged.
        """
        return Agent.default.realtime_audio_output_node(self, audio, model_settings)

    def _get_activity_or_raise(self) -> AgentActivity:
        """Get the current activity context for this task (internal).

        Raises:
            RuntimeError: If the agent is not running (no activity attached).
        """
        if self._activity is None:
            raise RuntimeError("no activity context found, the agent is not running")

        return self._activity

    class default:
        """Default node implementations, callable from overrides as ``Agent.default.<node>``."""

        @staticmethod
        async def stt_node(
            agent: Agent, audio: AsyncIterable[rtc.AudioFrame], model_settings: ModelSettings
        ) -> AsyncGenerator[stt.SpeechEvent, None]:
            """Default implementation for `Agent.stt_node`"""
            activity = agent._get_activity_or_raise()
            assert activity.stt is not None, "stt_node called but no STT node is available"

            wrapped_stt = activity.stt

            if not activity.stt.capabilities.streaming:
                if not activity.vad:
                    raise RuntimeError(
                        f"The STT ({activity.stt.label}) does not support streaming, add a VAD to the AgentTask/VoiceAgent to enable streaming. "  # noqa: E501
                        "Or manually wrap your STT in a stt.StreamAdapter"
                    )

                wrapped_stt = stt.StreamAdapter(stt=wrapped_stt, vad=activity.vad)

            conn_options = activity.session.conn_options.stt_conn_options
            async with wrapped_stt.stream(conn_options=conn_options) as stream:
                # Reference point for the stream's time offset: the audio recognition input
                # start, else the recording start, else the session start, else "now".
                _audio_input_started_at: float = (
                    activity._audio_recognition._input_started_at
                    if activity._audio_recognition is not None
                    and activity._audio_recognition._input_started_at is not None
                    else (
                        activity.session._recorder_io.recording_started_at
                        if activity.session._recorder_io
                        and activity.session._recorder_io.recording_started_at
                        else activity.session._started_at
                        if activity.session._started_at
                        else time.time()
                    )
                )
                stream.start_time_offset = time.time() - _audio_input_started_at

                @utils.log_exceptions(logger=logger)
                async def _forward_input() -> None:
                    async for frame in audio:
                        stream.push_frame(frame)

                forward_task = asyncio.create_task(_forward_input())
                try:
                    async for event in stream:
                        yield event
                finally:
                    await utils.aio.cancel_and_wait(forward_task)

        @staticmethod
        async def llm_node(
            agent: Agent,
            chat_ctx: llm.ChatContext,
            tools: list[llm.Tool],
            model_settings: ModelSettings,
        ) -> AsyncGenerator[llm.ChatChunk | str | FlushSentinel, None]:
            """Default implementation for `Agent.llm_node`"""
            activity = agent._get_activity_or_raise()
            assert activity.llm is not None, "llm_node called but no LLM node is available"
            assert isinstance(activity.llm, llm.LLM), (
                "llm_node should only be used with LLM (non-multimodal/realtime APIs) nodes"
            )

            tool_choice = model_settings.tool_choice if model_settings else NOT_GIVEN
            activity_llm = activity.llm

            conn_options = activity.session.conn_options.llm_conn_options
            async with activity_llm.chat(
                chat_ctx=chat_ctx, tools=tools, tool_choice=tool_choice, conn_options=conn_options
            ) as stream:
                async for chunk in stream:
                    yield chunk

        @staticmethod
        async def tts_node(
            agent: Agent, text: AsyncIterable[str], model_settings: ModelSettings
        ) -> AsyncGenerator[rtc.AudioFrame, None]:
            """Default implementation for `Agent.tts_node`"""
            activity = agent._get_activity_or_raise()
            if activity.tts is None:
                raise RuntimeError(
                    "`tts_node` called but no TTS node is available. If audio output is not needed, disable it using "
                    "`session.output.set_audio_enabled(False)`."
                )

            wrapped_tts = activity.tts

            if not activity.tts.capabilities.streaming:
                wrapped_tts = tts.StreamAdapter(
                    tts=wrapped_tts,
                    sentence_tokenizer=tokenize.blingfire.SentenceTokenizer(retain_format=True),
                )

            conn_options = activity.session.conn_options.tts_conn_options
            async with wrapped_tts.stream(conn_options=conn_options) as stream:

                # decorated like the stt_node forwarder so failures are logged, not lost
                @utils.log_exceptions(logger=logger)
                async def _forward_input() -> None:
                    async for chunk in text:
                        stream.push_text(chunk)

                    stream.end_input()

                forward_task = asyncio.create_task(_forward_input())
                try:
                    async for ev in stream:
                        yield ev.frame
                finally:
                    await utils.aio.cancel_and_wait(forward_task)

        @staticmethod
        async def transcription_node(
            agent: Agent, text: AsyncIterable[str | TimedString], model_settings: ModelSettings
        ) -> AsyncGenerator[str | TimedString, None]:
            """Default implementation for `Agent.transcription_node`"""
            async for delta in text:
                yield delta

        @staticmethod
        async def realtime_audio_output_node(
            agent: Agent, audio: AsyncIterable[rtc.AudioFrame], model_settings: ModelSettings
        ) -> AsyncGenerator[rtc.AudioFrame, None]:
            """Default implementation for `Agent.realtime_audio_output_node`"""
            activity = agent._get_activity_or_raise()
            assert activity.realtime_llm_session is not None, (
                "realtime_audio_output_node called but no realtime LLM session is available"
            )

            async for frame in audio:
                yield frame

    @property
    def realtime_llm_session(self) -> llm.RealtimeSession:
        """
        Retrieve the realtime LLM session associated with the current agent.

        Raises:
            RuntimeError: If the agent is not running or the realtime LLM session is not available
        """
        if (rt_session := self._get_activity_or_raise().realtime_llm_session) is None:
            raise RuntimeError("no realtime LLM session")

        return rt_session

    @property
    def turn_detection(self) -> NotGivenOr[TurnDetectionMode | None]:
        """
        Retrieves the turn detection mode for identifying conversational turns.

        If this property was not set at Agent creation, but an ``AgentSession`` provides a turn detection,
        the session's turn detection mode will be used at runtime instead.

        Returns:
            NotGivenOr[TurnDetectionMode | None]: An optional turn detection mode for managing conversation flow.
        """  # noqa: E501
        return self._turn_detection

    @turn_detection.setter
    def turn_detection(self, value: TurnDetectionMode | None) -> None:
        self._turn_detection = value

        # propagate to the running activity, if any
        if self._activity is not None:
            self._activity.update_options(turn_detection=value)

    @property
    def stt(self) -> NotGivenOr[stt.STT | None]:
        """
        Retrieves the Speech-To-Text component for the agent.

        If this property was not set at Agent creation, but an ``AgentSession`` provides an STT component,
        the session's STT will be used at runtime instead.

        Returns:
            NotGivenOr[stt.STT | None]: An optional STT component.
        """  # noqa: E501
        return self._stt

    @property
    def llm(self) -> NotGivenOr[llm.LLM | llm.RealtimeModel | None]:
        """
        Retrieves the Language Model or RealtimeModel used for text generation.

        If this property was not set at Agent creation, but an ``AgentSession`` provides an LLM or RealtimeModel,
        the session's model will be used at runtime instead.

        Returns:
            NotGivenOr[llm.LLM | llm.RealtimeModel | None]: The language model for text generation.
        """  # noqa: E501
        return self._llm

    @property
    def tts(self) -> NotGivenOr[tts.TTS | None]:
        """
        Retrieves the Text-To-Speech component for the agent.

        If this property was not set at Agent creation, but an ``AgentSession`` provides a TTS component,
        the session's TTS will be used at runtime instead.

        Returns:
            NotGivenOr[tts.TTS | None]: An optional TTS component for generating audio output.
        """  # noqa: E501
        return self._tts

    @property
    def mcp_servers(self) -> NotGivenOr[list[mcp.MCPServer] | None]:
        """
        Retrieves the list of Model Context Protocol (MCP) servers providing external tools.

        If this property was not set at Agent creation, but an ``AgentSession`` provides MCP servers,
        the session's MCP servers will be used at runtime instead.

        Returns:
            NotGivenOr[list[mcp.MCPServer] | None]: An optional list of MCP servers
                (an empty list passed at creation is stored as ``None``).
        """  # noqa: E501
        return self._mcp_servers

    @property
    def vad(self) -> NotGivenOr[vad.VAD | None]:
        """
        Retrieves the Voice Activity Detection component for the agent.

        If this property was not set at Agent creation, but an ``AgentSession`` provides a VAD component,
        the session's VAD will be used at runtime instead.

        Returns:
            NotGivenOr[vad.VAD | None]: An optional VAD component for detecting voice activity.
        """  # noqa: E501
        return self._vad

    @property
    def allow_interruptions(self) -> NotGivenOr[bool]:
        """
        Indicates whether interruptions (e.g., stopping TTS playback) are allowed.

        If this property was not set at Agent creation, but an ``AgentSession`` provides a value for
        allowing interruptions, the session's value will be used at runtime instead.

        Returns:
            NotGivenOr[bool]: Whether interruptions are permitted.
        """
        return self._allow_interruptions

    @property
    def min_endpointing_delay(self) -> NotGivenOr[float]:
        """
        Minimum time-in-seconds since the last detected speech before the agent
        declares the user’s turn complete. In VAD mode this effectively behaves
        like max(VAD silence, min_endpointing_delay); in STT mode it is applied
        after the STT end-of-speech signal, so it can be additive with the STT
        provider’s endpointing delay.

        If this property was set at Agent creation, it will be used at runtime instead of the session's value.
        """
        return self._min_endpointing_delay

    @property
    def max_endpointing_delay(self) -> NotGivenOr[float]:
        """
        Maximum time-in-seconds the agent will wait before terminating the turn.

        If this property was set at Agent creation, it will be used at runtime instead of the session's value.
        """
        return self._max_endpointing_delay

    @property
    def min_consecutive_speech_delay(self) -> NotGivenOr[float]:
        """
        Retrieves the minimum consecutive speech delay for the agent.

        If this property was not set at Agent creation, but an ``AgentSession`` provides a value for
        the minimum consecutive speech delay, the session's value will be used at runtime instead.

        Returns:
            NotGivenOr[float]: The minimum consecutive speech delay.
        """
        return self._min_consecutive_speech_delay

    @property
    def use_tts_aligned_transcript(self) -> NotGivenOr[bool]:
        """
        Indicates whether to use TTS-aligned transcript as the input of
        the ``transcription_node``.

        If this property was not set at Agent creation, but an ``AgentSession`` provides a value for
        the use of TTS-aligned transcript, the session's value will be used at runtime instead.

        Returns:
            NotGivenOr[bool]: Whether to use TTS-aligned transcript.
        """
        return self._use_tts_aligned_transcript

    @property
    def session(self) -> AgentSession:
        """
        Retrieve the ``AgentSession`` the agent is currently running in.

        Raises:
            RuntimeError: If the agent is not running
        """
        return self._get_activity_or_raise().session


# Type parameter of ``AgentTask``: the type of the value held by the task's result future.
TaskResult_T = TypeVar("TaskResult_T")


class AgentTask(Agent, Generic[TaskResult_T]):
    """An ``Agent`` that is awaited inline and resolves to a ``TaskResult_T``.

    Awaiting an instance is only allowed from an asyncio task flagged as an inline task
    (function tools and an Agent's ``on_enter``/``on_exit``). While awaited, the calling
    activity is paused and this task becomes the session's active agent until
    :meth:`complete` (or :meth:`cancel`) resolves it. Afterwards, if the session's current
    agent is still this task, its chat context is merged into the previous agent's and the
    previous activity is resumed; otherwise the previous activity is closed.
    An instance can be awaited only once.
    """

    def __init__(
        self,
        *,
        instructions: str | Instructions,
        chat_ctx: NotGivenOr[llm.ChatContext] = NOT_GIVEN,
        tools: list[llm.Tool | llm.Toolset] | None = None,
        stt: NotGivenOr[stt.STT | None] = NOT_GIVEN,
        vad: NotGivenOr[vad.VAD | None] = NOT_GIVEN,
        turn_handling: NotGivenOr[TurnHandlingOptions] = NOT_GIVEN,
        llm: NotGivenOr[llm.LLM | llm.RealtimeModel | None] = NOT_GIVEN,
        tts: NotGivenOr[tts.TTS | None] = NOT_GIVEN,
        preserve_function_call_history: bool = False,
        # deprecated
        turn_detection: NotGivenOr[TurnDetectionMode | None] = NOT_GIVEN,
        allow_interruptions: NotGivenOr[bool] = NOT_GIVEN,
        min_endpointing_delay: NotGivenOr[float] = NOT_GIVEN,
        max_endpointing_delay: NotGivenOr[float] = NOT_GIVEN,
        mcp_servers: NotGivenOr[list[mcp.MCPServer] | None] = NOT_GIVEN,
    ) -> None:
        """Create an inline agent task.

        Args:
            preserve_function_call_history: if True, function calls made while this task
                ran are kept when its chat context is merged back into the previous
                agent's chat context; otherwise they are excluded from the merge.
            turn_detection, allow_interruptions, min_endpointing_delay, max_endpointing_delay:
                deprecated. They are converted with ``_migrate_turn_handling`` only when
                ``turn_handling`` is not given, and are ignored otherwise.

        All other arguments are forwarded to ``Agent.__init__``.
        """
        tools = tools or []
        turn_handling = (
            _migrate_turn_handling(
                turn_detection=turn_detection,
                allow_interruptions=allow_interruptions,
                min_endpointing_delay=min_endpointing_delay,
                max_endpointing_delay=max_endpointing_delay,
            )
            if not is_given(turn_handling)
            else turn_handling
        )
        super().__init__(
            instructions=instructions,
            chat_ctx=chat_ctx,
            tools=tools,
            stt=stt,
            vad=vad,
            llm=llm,
            tts=tts,
            mcp_servers=mcp_servers,
            turn_handling=turn_handling,
        )

        self.__started = False
        self.__fut = asyncio.Future[TaskResult_T]()
        self.__inactive_ev = asyncio.Event()
        self.__inactive_ev.set()  # set when the agent is not awaited or activity is closed
        self._preserve_function_call_history = preserve_function_call_history

        self._old_agent: Agent | None = None

    def done(self) -> bool:
        """Return True once the task has been resolved with a result or an exception."""
        return self.__fut.done()

    def cancel(self) -> None:
        """Cancel the task.

        Force-interrupts the task's activity if it has one and, unless the task already
        has a result, resolves it with a ``ToolError``.
        """
        if self._activity:
            self._activity.interrupt(force=True)
        if self.__fut.done():
            return
        self.complete(ToolError(f"AgentTask {self.id} is cancelled"))

    def complete(self, result: TaskResult_T | Exception) -> None:
        """Resolve the task with *result*.

        Passing an ``Exception`` makes the awaiting caller raise it; any other value is
        returned from ``await``. The result is also stored on the speech handle found in
        the current context (``_SpeechHandleContextVar``), if any.

        Raises:
            RuntimeError: if the task is already done.
        """
        if self.__fut.done():
            raise RuntimeError(f"{self.__class__.__name__} is already done")

        if isinstance(result, Exception):
            self.__fut.set_exception(result)
        else:
            self.__fut.set_result(result)

        self.__fut.exception()  # silence exc not retrieved warnings

        from .agent_activity import _SpeechHandleContextVar

        speech_handle = _SpeechHandleContextVar.get(None)

        if speech_handle:
            speech_handle._maybe_run_final_output = result

        # if not self.__inline_mode:
        #    session._close_soon(reason=CloseReason.TASK_COMPLETED, drain=True)

    async def __await_impl(self) -> TaskResult_T:
        """Run this task as the session's active agent and return its result.

        Pauses the calling activity, switches the session to this task, waits for
        :meth:`complete`, then hands control back (see the class docstring). The
        inactive event is always set again on exit, and the parent speech handle's
        ``allow_interruptions`` is always restored, even when setup fails or is cancelled.

        Raises:
            RuntimeError: if awaited more than once, outside an asyncio task, from a task
                not flagged as an inline task, or from an already-interrupted function tool.
            Exception: the exception the task was completed with, if any.
        """
        if self.__started:
            raise RuntimeError(f"{self.__class__.__name__} is not re-entrant, await only once")

        self.__started = True

        current_task = asyncio.current_task()
        if current_task is None:
            raise RuntimeError(
                f"{self.__class__.__name__} must be executed inside an async context"
            )

        task_info = _get_activity_task_info(current_task)
        if not task_info or not task_info.inline_task:
            raise RuntimeError(
                f"{self.__class__.__name__} should only be awaited inside tool_functions or the on_enter/on_exit methods of an Agent"  # noqa: E501
            )

        def _handle_task_done(_: asyncio.Task[Any]) -> None:
            if self.__fut.done():
                return

            # if the asyncio.Task running the InlineTask completes before the InlineTask itself, log
            # an error and attempt to recover by terminating the InlineTask.
            logger.error(
                f"The asyncio.Task finished before {self.__class__.__name__} was completed."
            )

            self.complete(
                RuntimeError(
                    f"The asyncio.Task finished before {self.__class__.__name__} was completed."
                )
            )

        current_task.add_done_callback(_handle_task_done)

        from .agent_activity import _AgentActivityContextVar, _SpeechHandleContextVar

        # TODO(theomonnom): add a global lock for inline tasks
        # This may currently break in the case we use parallel tool calls.

        speech_handle = _SpeechHandleContextVar.get(None)
        old_activity = _AgentActivityContextVar.get()
        old_agent = old_activity.agent
        session = old_activity.session
        self._old_agent = old_agent

        old_allow_interruptions = True
        if speech_handle:
            if speech_handle.interrupted:
                raise RuntimeError(
                    f"{self.__class__.__name__} cannot be awaited inside a function tool that is already interrupted"
                )

            # lock the speech handle to prevent interruptions until the task is complete
            # there should be no await before this line to avoid race conditions
            old_allow_interruptions = speech_handle.allow_interruptions
            speech_handle.allow_interruptions = False

        blocked_tasks = [current_task]
        if (
            old_activity._on_enter_task
            and not old_activity._on_enter_task.done()
            and current_task is not old_activity._on_enter_task
        ):
            blocked_tasks.append(old_activity._on_enter_task)

        if (
            task_info.function_call
            and isinstance(old_activity.llm, RealtimeModel)
            and not old_activity.llm.capabilities.manual_function_calls
        ):
            logger.error(
                f"Realtime model '{old_activity.llm.label}' does not support resuming function calls from chat context, "
                "using AgentTask inside a function tool may have unexpected behavior."
            )

        # TODO(theomonnom): could the RunResult watcher & the blocked_tasks share the same logic?
        self.__inactive_ev.clear()
        try:
            suspended_handles: list[SpeechHandle | asyncio.Task[Any]] = []
            pending_on_enter_task: asyncio.Task[None] | None = None
            try:
                # use wait_on_enter=False to avoid deadlock: on_enter may spawn nested
                # AgentTasks that require user input, but session.run() can't return until
                # all watched handles complete — creating a circular wait.
                await session._update_activity(
                    self,
                    previous_activity="pause",
                    blocked_tasks=blocked_tasks,
                    wait_on_enter=False,
                )

                if not self._activity and not self.done():
                    self.complete(
                        ToolError(
                            f"activity doesn't start for {self.id}, likely due to session closing"
                        )
                    )

                run_state = session._global_run_state

                if self._activity and (on_enter_task := self._activity._on_enter_task):
                    if run_state and not run_state.done():
                        # watch the on_enter task as a guard so RunResult won't complete
                        # before on_enter has registered its own speech handles
                        run_state._watch_handle(on_enter_task)
                        pending_on_enter_task = on_enter_task
                    else:
                        # no active run to guard — just wait for on_enter directly
                        await asyncio.shield(on_enter_task)

                # now unwatch the parent speech handle and blocked tasks that belong to the
                # old activity — they can't complete while this AgentTask is running, and
                # keeping them watched would block RunResult from completing.
                if run_state and not run_state.done():
                    if speech_handle and run_state._unwatch_handle(speech_handle):
                        suspended_handles.append(speech_handle)
                    for task in blocked_tasks:
                        if run_state._unwatch_handle(task):
                            suspended_handles.append(task)
                    if suspended_handles:
                        run_state._mark_done_if_needed(None)
            except BaseException:
                # Setup failed or was cancelled (CancelledError is a BaseException): release
                # the interruption lock taken above so the parent speech doesn't stay
                # non-interruptible forever.
                if speech_handle:
                    with contextlib.suppress(RuntimeError):
                        speech_handle.allow_interruptions = old_allow_interruptions
                raise

            try:
                return await asyncio.shield(self.__fut)

            finally:
                if speech_handle:
                    with contextlib.suppress(RuntimeError):
                        speech_handle.allow_interruptions = old_allow_interruptions

                # run_state could have changed after self.__fut
                run_state = session._global_run_state

                # re-watch the suspended handles so the resumed parent activity
                # is tracked by the current RunResult again
                if run_state and not run_state.done():
                    for handle in suspended_handles:
                        run_state._watch_handle(handle)

                if pending_on_enter_task:
                    try:
                        await asyncio.shield(pending_on_enter_task)
                    except BaseException:
                        logger.exception("error in on_enter task of agent %s", self.id)

                if session.current_agent != self:
                    logger.warning(
                        f"{self.__class__.__name__} completed, but the agent has changed in the meantime. "
                        "Ignoring handoff to the previous agent, likely due to `AgentSession.update_agent` being invoked."
                    )
                    await old_activity.aclose()
                else:
                    merged_chat_ctx = old_agent.chat_ctx.merge(
                        self.chat_ctx,
                        exclude_function_call=not self._preserve_function_call_history,
                        exclude_instructions=True,
                    )
                    # set the chat_ctx directly, `session._update_activity` will sync it to the rt_session if needed
                    old_agent._chat_ctx.items[:] = merged_chat_ctx.items

                    await session._update_activity(
                        old_agent, new_activity="resume", wait_on_enter=False
                    )
        finally:
            # Always mark the task inactive again, even if setup/teardown raised or was
            # cancelled, so `_wait_for_inactive()` callers are never left hanging.
            self.__inactive_ev.set()

    def __await__(self) -> Generator[None, None, TaskResult_T]:
        return self.__await_impl().__await__()

    async def _wait_for_inactive(self) -> None:
        """Wait until this task is not being awaited.

        The underlying event starts set, is cleared when awaiting begins, and is set
        again when awaiting finishes (successfully or not).
        """
        await self.__inactive_ev.wait()


@dataclass
class _ActivityTaskInfo:
    """Metadata attached to an ``asyncio.Task``.

    Stored on the task object under a private attribute by ``_set_activity_task_info``
    and read back by ``_get_activity_task_info``.
    """

    # Function call associated with the task, if any (AgentTask checks it to warn about
    # realtime models that cannot resume function calls).
    function_call: llm.FunctionCall | None = None
    # Speech handle associated with the task, if any.
    speech_handle: SpeechHandle | None = None
    # Must be True for an AgentTask to be awaited from this task.
    inline_task: bool = False


def _set_activity_task_info(
    task: asyncio.Task[Any],
    *,
    function_call: NotGivenOr[llm.FunctionCall | None] = NOT_GIVEN,
    speech_handle: NotGivenOr[SpeechHandle | None] = NOT_GIVEN,
    inline_task: NotGivenOr[bool] = NOT_GIVEN,
) -> None:
    """Attach or update the ``_ActivityTaskInfo`` stored on *task*.

    Only the keyword arguments that are given are written. Omitted ones keep their current
    value, or the dataclass default when *task* had no info yet. An existing info object is
    mutated in place, so every task holding that same object observes the change.
    """
    existing = _get_activity_task_info(task)
    target = existing if existing else _ActivityTaskInfo()

    requested = {
        "function_call": function_call,
        "speech_handle": speech_handle,
        "inline_task": inline_task,
    }
    for field_name, value in requested.items():
        if is_given(value):
            setattr(target, field_name, value)

    setattr(task, "__livekit_agents_activity_task", target)


def _get_activity_task_info(task: asyncio.Task[Any]) -> _ActivityTaskInfo | None:
    return getattr(task, "__livekit_agents_activity_task", None)


def _pass_through_activity_task_info(task: asyncio.Task[Any]) -> None:
    """Attach the calling task's ``_ActivityTaskInfo`` (if any) to *task*.

    The same info object is shared rather than copied, so later in-place updates made
    through ``_set_activity_task_info`` on either task are visible to both.
    """
    parent = asyncio.current_task()
    if not parent:
        return

    inherited = _get_activity_task_info(parent)
    if not inherited:
        return

    setattr(task, "__livekit_agents_activity_task", inherited)
