from __future__ import annotations

import asyncio
import json
import math
import time
from collections import deque
from collections.abc import AsyncIterable, Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Protocol

from opentelemetry import trace
from opentelemetry.sdk.trace import ReadableSpan

from livekit import rtc

from .. import inference, llm, stt, tokenize, utils, vad
from .._exceptions import APIError
from ..inference.interruption import (
    _AgentSpeechEndedSentinel,
    _AgentSpeechStartedSentinel,
    _OverlapSpeechEndedSentinel,
    _OverlapSpeechStartedSentinel,
)
from ..language import LanguageCode
from ..log import logger
from ..stt import SpeechEvent
from ..telemetry import trace_types, tracer
from ..types import NOT_GIVEN, NotGivenOr
from ..utils import aio, is_given
from . import io
from ._utils import _set_participant_attributes
from .endpointing import BaseEndpointing
from .events import UserTurnExceededEvent
from .turn import TurnDetectionMode as TurnDetectionMode

if TYPE_CHECKING:
    from .agent_session import AgentSession

# NOTE(review): not referenced in this part of the file; presumably the minimum
# transcript length required before a detected language is used — confirm the
# unit (characters vs. words) at the use site.
MIN_LANGUAGE_DETECTION_LENGTH = 5
# Mirrors turn_detector.base.MAX_HISTORY_TURNS for tracing (keep the two in sync)
_EOU_MAX_HISTORY_TURNS = 6


@dataclass
class _EndOfTurnInfo:
    """Description of a finished user turn, passed to ``RecognitionHooks.on_end_of_turn``.

    The trailing timing fields are reported as metrics; any of them may be ``None``.
    """

    # True when a reply was already triggered, so it should be skipped after
    # end-of-turn detection.
    skip_reply: bool
    new_transcript: str
    transcript_confidence: float

    # --- metrics report ---
    started_speaking_at: float | None
    stopped_speaking_at: float | None
    transcription_delay: float | None
    end_of_turn_delay: float | None


@dataclass
class _PreemptiveGenerationInfo:
    """Snapshot handed to ``RecognitionHooks.on_preemptive_generation``."""

    # transcript that the preemptive generation would be based on
    new_transcript: str
    transcript_confidence: float
    # may be None when the speech start time is unknown
    started_speaking_at: float | None


@dataclass
class _UserTurnTracker:
    """Accumulated user-turn state.

    AudioRecognition swaps in a fresh instance whenever the agent starts speaking,
    so the counters span consecutive user turns until then.
    """

    words: int = 0  # word count (presumably checked against a user-turn limit)
    transcript: str = ""
    started_at: float | None = None


class RecognitionHooks(Protocol):
    """Callbacks through which AudioRecognition reports recognition events to its owner."""

    def on_interruption(self, ev: inference.OverlappingSpeechEvent) -> None:
        """Receive an overlapping-speech (interruption) event."""

    def on_start_of_speech(self, ev: vad.VADEvent | None, speech_start_time: float) -> None:
        """User speech started; ``ev`` may be ``None``."""

    def on_vad_inference_done(self, ev: vad.VADEvent) -> None:
        """A VAD inference step finished."""

    def on_end_of_speech(self, ev: vad.VADEvent | None) -> None:
        """User speech ended; ``ev`` may be ``None``."""

    def on_interim_transcript(self, ev: stt.SpeechEvent, *, speaking: bool | None) -> None:
        """An interim (non-final) transcript is available."""

    def on_final_transcript(self, ev: stt.SpeechEvent, *, speaking: bool | None = None) -> None:
        """A final transcript is available.

        Also invoked by a manual turn commit with the pending interim transcript
        promoted to final.
        """

    def on_end_of_turn(self, info: _EndOfTurnInfo) -> bool:
        """The user turn ended. NOTE(review): meaning of the returned bool is not visible here."""

    def on_preemptive_generation(self, info: _PreemptiveGenerationInfo) -> None:
        """A preemptive reply generation may be started."""

    def on_user_turn_exceeded(self, ev: UserTurnExceededEvent) -> None:
        """The user turn exceeded its configured limit."""

    def retrieve_chat_ctx(self) -> llm.ChatContext:
        """Return the current chat context (AudioRecognition copies it before use)."""


class _STTPipeline:
    """Transferable STT pipeline that survives agent handoff.

    A pump task runs the STT node over ``audio_ch`` and forwards every event it
    yields into ``event_ch``. The pump is never cancelled during handoff — only
    the consumer of ``event_ch`` is swapped; ``aclose()`` is what stops it.
    """

    def __init__(self, stt_node: io.STTNode) -> None:
        self._stt_node = stt_node
        self._audio_ch = aio.Chan[rtc.AudioFrame]()
        self._event_ch = aio.Chan[stt.SpeechEvent]()
        self._pump_task = asyncio.create_task(self._stt_pump())
        # close event_ch however the pump finishes (exhausted, failed or cancelled)
        self._pump_task.add_done_callback(lambda _: self._event_ch.close())

    @property
    def audio_ch(self) -> aio.Chan[rtc.AudioFrame]:
        """Input channel: audio frames fed to the STT node."""
        return self._audio_ch

    @property
    def event_ch(self) -> aio.Chan[stt.SpeechEvent]:
        """Output channel: speech events produced by the STT node."""
        return self._event_ch

    @utils.log_exceptions(logger=logger)
    async def _stt_pump(self) -> None:
        """Iterate the STT generator and forward events into *event_ch*.

        This task owns the generator lifecycle and is never cancelled during
        handoff — only the consumer is swapped.

        Raises:
            TypeError: if the STT node yields something other than a ``SpeechEvent``.
        """
        from .agent import ModelSettings

        # the node may return the event stream directly or via a coroutine (or None)
        node = self._stt_node(self._audio_ch, ModelSettings())
        if asyncio.iscoroutine(node):
            node = await node

        if node is None:
            return

        if isinstance(node, AsyncIterable):
            async for ev in node:
                # explicit check rather than ``assert``: asserts are stripped under ``python -O``
                if not isinstance(ev, stt.SpeechEvent):
                    raise TypeError(f"STT node must yield SpeechEvent, got: {type(ev)}")
                self._event_ch.send_nowait(ev)

    async def aclose(self) -> None:
        """Cancel the pump task and wait for it to finish."""
        await aio.cancel_and_wait(self._pump_task)


class AudioRecognition:
    def __init__(
        self,
        session: AgentSession,
        *,
        hooks: RecognitionHooks,
        endpointing: BaseEndpointing,
        stt: io.STTNode | None,
        vad: vad.VAD | None,
        interruption_detection: inference.AdaptiveInterruptionDetector | None,
        turn_detection: TurnDetectionMode | None,
        stt_model: str | None = None,
        stt_provider: str | None = None,
    ) -> None:
        self._session = session
        self._hooks = hooks
        self._audio_input_atask: asyncio.Task[None] | None = None
        self._commit_user_turn_atask: asyncio.Task[None] | None = None
        self._stt_consumer_atask: asyncio.Task[None] | None = None
        self._vad_atask: asyncio.Task[None] | None = None
        self._end_of_turn_task: asyncio.Task[None] | None = None
        self._endpointing: BaseEndpointing = endpointing
        self._turn_detector = turn_detection if not isinstance(turn_detection, str) else None
        self._stt = stt
        self._vad = vad
        self._stt_model = stt_model
        self._stt_provider = stt_provider
        self._turn_detection_mode = turn_detection if isinstance(turn_detection, str) else None
        self._vad_base_turn_detection = self._turn_detection_mode in ("vad", None)
        self._user_turn_committed = False  # true if user turn ended but EOU task not done

        self._sample_rate: int | None = None
        self._speaking = False

        self._last_final_transcript_time: float | None = None
        self._last_speaking_time: float | None = None
        self._speech_start_time: float | None = None

        # used for manual commit_user_turn
        self._final_transcript_received = asyncio.Event()
        self._final_transcript_confidence: list[float] = []
        self._audio_transcript = ""
        self._audio_interim_transcript = ""
        # used for STTs that support preflight mode, so it could start preemptive generation earlier
        self._audio_preflight_transcript = ""
        self._last_language: LanguageCode | None = None

        self._stt_pipeline: _STTPipeline | None = None
        self._vad_ch: aio.Chan[rtc.AudioFrame] | None = None

        self._tasks: set[asyncio.Task[Any]] = set()

        # region: adaptive interruption detection
        self._interruption_atask: asyncio.Task[None] | None = None
        self._interruption_detection = interruption_detection
        self._interruption_ch: aio.Chan[inference.InterruptionDataFrameType] | None = None
        self._input_started_at: float | None = None
        self._ignore_user_transcript_until: NotGivenOr[float] = NOT_GIVEN
        self._transcript_buffer: deque[SpeechEvent] = deque()
        self._interruption_enabled: bool = interruption_detection is not None and vad is not None
        self._agent_speaking: bool = False

        _backchannel_boundary: float | tuple[float, float] | None = (
            session.options.interruption.get("backchannel_boundary")
        )
        self._backchannel_boundary: tuple[float, float] | None = (
            (_backchannel_boundary, _backchannel_boundary)
            if isinstance(_backchannel_boundary, int | float)
            else _backchannel_boundary
        )
        if self._backchannel_boundary and (
            len(self._backchannel_boundary) != 2 or any(x < 0.0 for x in self._backchannel_boundary)
        ):
            raise ValueError("backchannel_boundary must be a tuple of two non-negative floats")
        self._backchannel_boundary_timer: asyncio.TimerHandle | None = None
        self.backchannel_boundary_callback: Callable[[], None] | None = None
        # endregion

        self._user_turn_span: trace.Span | None = None
        self._user_turn_start: float | None = None
        self._stt_request_ids: list[str] = []
        self._closing = asyncio.Event()

        self._vad_speech_started: bool = False

        # user turn limit tracking — accumulates across turns until agent speaks
        self._turn_tracker = _UserTurnTracker()
        self._word_tokenizer = tokenize.basic.WordTokenizer()

    def update_options(
        self,
        *,
        endpointing: NotGivenOr[BaseEndpointing] = NOT_GIVEN,
        turn_detection: NotGivenOr[TurnDetectionMode | None] = NOT_GIVEN,
        # deprecated
        min_endpointing_delay: NotGivenOr[float] = NOT_GIVEN,
        max_endpointing_delay: NotGivenOr[float] = NOT_GIVEN,
    ) -> None:
        if is_given(endpointing):
            self._endpointing = endpointing

        if is_given(turn_detection):
            self._turn_detector = turn_detection if not isinstance(turn_detection, str) else None

            mode = turn_detection if isinstance(turn_detection, str) else None
            if self._turn_detection_mode != mode:
                previous_mode = self._turn_detection_mode
                self._turn_detection_mode = mode
                self._vad_base_turn_detection = self._turn_detection_mode in ("vad", None)

                if self._turn_detection_mode == "manual" or previous_mode == "manual":
                    if self._end_of_turn_task:
                        if not self._end_of_turn_task.done():
                            self._end_of_turn_task.cancel()
                    self._end_of_turn_task = None
                    self._user_turn_committed = False

    def start(self, *, stt_pipeline: _STTPipeline | None = None) -> None:
        self.update_stt(self._stt, pipeline=stt_pipeline)
        self.update_vad(self._vad)
        self.update_interruption_detection(self._interruption_detection)

    def stop(self) -> None:
        self.update_stt(None)
        self.update_vad(None)
        self.update_interruption_detection(None)

    @property
    def adaptive_interruption_active(self) -> bool:
        return (
            self._interruption_enabled
            and self._interruption_ch is not None
            and not self._interruption_ch.closed
        )

    # region: boundary for adaptive interruption detection

    @property
    def backchannel_boundary_active(self) -> bool:
        return self._backchannel_boundary_timer is not None

    def _on_backchannel_boundary_done(self) -> None:
        self._backchannel_boundary_timer = None
        cb, self.backchannel_boundary_callback = (
            self.backchannel_boundary_callback,
            None,
        )
        if cb is not None:
            cb()

    def _cancel_backchannel_boundary(self) -> None:
        if self._backchannel_boundary_timer is not None:
            self._backchannel_boundary_timer.cancel()
            self._backchannel_boundary_timer = None
        self.backchannel_boundary_callback = None

    # endregion

    def on_start_of_agent_speech(self, started_at: float) -> None:
        self._agent_speaking = True
        self._endpointing.on_start_of_agent_speech(started_at=started_at)

        # reset user turn tracker when agent starts speaking
        self._turn_tracker = _UserTurnTracker()

        if self._backchannel_boundary and (start_cooldown := self._backchannel_boundary[0]) > 0:
            self._cancel_backchannel_boundary()
            self._backchannel_boundary_timer = asyncio.get_running_loop().call_later(
                start_cooldown, self._on_backchannel_boundary_done
            )

        if self.adaptive_interruption_active:
            self._interruption_ch.send_nowait(_AgentSpeechStartedSentinel())  # type: ignore[union-attr]

    def on_end_of_agent_speech(self, *, ignore_user_transcript_until: float) -> None:
        """Handle the end of agent speech.

        Always cancels the backchannel-boundary timer and, if agent speech was being
        tracked, notifies endpointing. When adaptive interruption detection is active
        and the agent was speaking, it also ends any open overlap-speech inference,
        stores a new ignore deadline and schedules a background flush of the
        transcripts held while the agent spoke.

        Args:
            ignore_user_transcript_until: deadline that ``_flush_held_transcripts``
                compares against ``time.time()``-based event timestamps; held
                transcripts for audio ending before it may be dropped.
        """
        self._cancel_backchannel_boundary()

        if self._agent_speaking:
            self._endpointing.on_end_of_agent_speech(ended_at=time.time())

        if not self.adaptive_interruption_active:
            # no active detector: no overlap inference to end and no flush to schedule
            self._agent_speaking = False
            return

        self._interruption_ch.send_nowait(_AgentSpeechEndedSentinel())  # type: ignore[union-attr]

        if self._agent_speaking:
            # no interruption is detected, end the inference (idempotent)
            if not is_given(self._ignore_user_transcript_until):
                self.on_end_of_overlap_speech(ended_at=time.time())

            end_cooldown: float = (
                self._backchannel_boundary[1] if self._backchannel_boundary else 0.0
            )

            # keep the earlier of the new deadline and any deadline already stored
            ignore_until = (
                ignore_user_transcript_until
                if not is_given(self._ignore_user_transcript_until)
                else min(ignore_user_transcript_until, self._ignore_user_transcript_until)
            )
            logger.trace(
                "flushing held transcripts",
                extra={
                    "ignore_until": ignore_until,
                    "end_cooldown": end_cooldown,
                },
            )
            # the stored deadline is shifted back by the end-side cooldown
            self._ignore_user_transcript_until = ignore_until - end_cooldown

            # flush held transcripts if possible
            task = asyncio.create_task(self._flush_held_transcripts(cooldown=end_cooldown))
            task.add_done_callback(lambda _: self._tasks.discard(task))
            self._tasks.add(task)

        # cleared last: on_end_of_overlap_speech above only acts while _agent_speaking is True
        self._agent_speaking = False

    def on_start_of_speech(
        self,
        started_at: float,
        speech_duration: float = 0.0,
        user_speaking_span: trace.Span | None = None,
    ) -> None:
        self._endpointing.on_start_of_speech(
            started_at=started_at, overlapping=self._agent_speaking
        )
        if not self.adaptive_interruption_active or not self._agent_speaking:
            return
        self._interruption_ch.send_nowait(  # type: ignore[union-attr]
            _OverlapSpeechStartedSentinel(
                speech_duration=speech_duration,
                user_speaking_span=user_speaking_span,
                started_at=started_at,
            )
        )

    def on_end_of_speech(
        self,
        ended_at: float,
        user_speaking_span: trace.Span | None = None,
        interruption: NotGivenOr[bool] = NOT_GIVEN,
    ) -> None:
        if self._speaking:
            self._endpointing.on_end_of_speech(
                ended_at=ended_at,
                should_ignore=(
                    is_given(interruption) and not interruption and self._agent_speaking
                ),
            )

        self.on_end_of_overlap_speech(ended_at=ended_at, user_speaking_span=user_speaking_span)

    def on_end_of_overlap_speech(
        self,
        ended_at: float,
        user_speaking_span: trace.Span | None = None,
    ) -> None:
        """End interruption inference when agent is speaking and overlap speech ends."""
        if not self.adaptive_interruption_active or not self._agent_speaking:
            return

        # Only set is_interruption=false if not already set (avoid overwriting true from interruption detection)
        if user_speaking_span and user_speaking_span.is_recording():
            if isinstance(user_speaking_span, ReadableSpan):
                if (
                    user_speaking_span.attributes
                    and user_speaking_span.attributes.get(trace_types.ATTR_IS_INTERRUPTION) is None
                ):
                    user_speaking_span.set_attribute(trace_types.ATTR_IS_INTERRUPTION, "false")
            else:
                user_speaking_span.set_attribute(trace_types.ATTR_IS_INTERRUPTION, "false")

        self._interruption_ch.send_nowait(  # type: ignore[union-attr]
            _OverlapSpeechEndedSentinel(ended_at=ended_at or time.time())
        )

    @utils.log_exceptions(logger=logger)
    async def _flush_held_transcripts(self, cooldown: float, force: bool = False) -> None:
        """Flush held transcripts.

        When ``force`` is True, all buffered events are emitted unconditionally; this
        is used during interruption-detector teardown when the ignore-window gating
        can no longer be trusted.

        Otherwise the buffer is scanned in order.
        - A transcript whose *end time* (relative to ``_input_started_at``) falls
          before ``_ignore_user_transcript_until`` is stale. It is dropped together
          with everything before it.
        - Emission starts at the first non-stale transcript. It also includes any
          alternative-less events directly preceding that transcript (after the last
          stale one).
        - Every event after that point is emitted without further checks.
        - If no non-stale transcript is found, nothing is emitted.
        - If an inspected transcript has ``start_time == end_time == 0`` (no timing
          information), the whole buffer is discarded.

        In every path the buffer and the ignore deadline are reset before any
        re-emission.

        Args:
            cooldown: seconds. In this method it only feeds the logged ``added_delay``;
                emission itself is not delayed. (``on_end_of_agent_speech`` already
                subtracts its cooldown from the stored deadline.)
            force: emit every buffered event regardless of timing.
        """
        if not self._transcript_buffer:
            self._reset_interruption_detection()
            return

        if force:
            events_to_emit = list(self._transcript_buffer)
            # reset before emitting to avoid recursive calls
            self._reset_interruption_detection()
            for ev in events_to_emit:
                await self._on_stt_event(ev)
            return

        # without a deadline or a stream anchor nothing can be judged: drop the buffer
        if (
            not self._interruption_enabled
            or not is_given(self._ignore_user_transcript_until)
            or self._input_started_at is None
        ):
            self._reset_interruption_detection()
            return

        emit_from_index: int | None = None
        should_flush = False
        for i, ev in enumerate(self._transcript_buffer):
            # always try to emit from a sentinel event
            if not ev.alternatives:
                emit_from_index = min(emit_from_index, i) if emit_from_index is not None else i
                continue
            # no timing information at all: discard the whole buffer
            if ev.alternatives[0].start_time == ev.alternatives[0].end_time == 0:
                self._reset_interruption_detection()
                return

            # stale: the event's audio ended before the ignore deadline
            if (
                ev.alternatives[0].end_time > 0
                and ev.alternatives[0].end_time + self._input_started_at
                < self._ignore_user_transcript_until
            ):
                # reset the index to emit from the next valid event
                emit_from_index = None
            else:
                # break since we found a valid event to emit from
                emit_from_index = min(emit_from_index, i) if emit_from_index is not None else i
                should_flush = True
                break

        events_to_emit = (
            list(self._transcript_buffer)[int(emit_from_index) :]
            if emit_from_index is not None and should_flush
            else []
        )
        _ignore_user_transcript_until = self._ignore_user_transcript_until
        # reset before emitting to avoid recursive calls
        self._reset_interruption_detection()

        for ev in events_to_emit:
            # informational only: logged below, emission is not delayed
            added_delay = 0.0
            if ev.alternatives and ev.alternatives[0].end_time > 0:
                added_delay = max(
                    0,
                    (
                        ev.alternatives[0].end_time
                        + self._input_started_at
                        - _ignore_user_transcript_until
                    )
                    + (cooldown or 0.0),
                )
            logger.trace(
                "re-emitting held user transcript",
                extra={
                    "event": ev.type,
                    "cooldown": cooldown,
                    "added_delay": added_delay,
                },
            )
            await self._on_stt_event(ev)

    def _reset_interruption_detection(self) -> None:
        """Reset relevant states for adaptive interruption detection."""
        self._transcript_buffer.clear()
        self._ignore_user_transcript_until = NOT_GIVEN

    def _should_hold_stt_event(self, ev: stt.SpeechEvent) -> bool:
        """Test if the event should be held until the ignore_user_transcript_until timestamp."""
        if not self._interruption_enabled:
            return False

        if self._agent_speaking:
            return True

        # reset when the user starts speaking after the agent speech
        # this could let a transcript pass through if the user starts
        # speaking right before the agent speech ends, not ideal but
        # better than swallowing the transcript.
        if ev.type == stt.SpeechEventType.START_OF_SPEECH:
            self._ignore_user_transcript_until = NOT_GIVEN
            return False

        if not is_given(self._ignore_user_transcript_until):
            return False
        # sentinel events are always held until
        # we have something concrete to release them
        if not ev.alternatives:
            return True
        if (
            # most vendors don't set timestamps properly, in which case we just assume
            # it is a valid event after the ignore_user_transcript_until timestamp
            is_given(self._input_started_at)
            # check if the event should be held if
            # 1. the stt input stream has started
            # 2. the current event has a valid start and end time, relative to the input stream start time
            # 3. the event is for audio sent before the ignore_user_transcript_until timestamp
            and self._input_started_at is not None
            and not (ev.alternatives[0].start_time == ev.alternatives[0].end_time == 0)
            and ev.alternatives[0].start_time > 0
            and ev.alternatives[0].start_time + self._input_started_at
            < self._ignore_user_transcript_until
        ):
            return True

        return False

    def push_audio(self, frame: rtc.AudioFrame, *, skip_stt: bool = False) -> None:
        if self._input_started_at is None:
            self._input_started_at = time.time() - frame.duration

        self._sample_rate = frame.sample_rate
        if not skip_stt and self._stt_pipeline is not None:
            self._stt_pipeline.audio_ch.send_nowait(frame)

        if self._vad_ch is not None:
            self._vad_ch.send_nowait(frame)

        if self._session.amd is not None:
            self._session.amd.push_audio(frame)

        if self._interruption_ch is not None:
            self._interruption_ch.send_nowait(frame)

    async def aclose(self) -> None:
        self._closing.set()
        if self._commit_user_turn_atask is not None:
            try:
                await self._commit_user_turn_atask
            except asyncio.CancelledError:
                pass

        if self._stt_pipeline is not None:
            await self._stt_pipeline.aclose()
            self._stt_pipeline = None

        await aio.cancel_and_wait(*self._tasks)

        if self._stt_consumer_atask is not None:
            await aio.cancel_and_wait(self._stt_consumer_atask)

        if self._vad_atask is not None:
            await aio.cancel_and_wait(self._vad_atask)

        if self._interruption_atask is not None:
            await aio.cancel_and_wait(self._interruption_atask)

        if self._end_of_turn_task is not None:
            try:
                await self._end_of_turn_task
            except asyncio.CancelledError:
                pass

        if self._backchannel_boundary_timer is not None:
            self._backchannel_boundary_timer.cancel()
            self._backchannel_boundary_timer = None
            self.backchannel_boundary_callback = None

    def update_stt(self, stt: io.STTNode | None, *, pipeline: _STTPipeline | None = None) -> None:
        self._stt = stt
        if pipeline is None and stt is not None:
            pipeline = _STTPipeline(stt)

        if pipeline is not None:
            self._stt_consumer_atask = asyncio.create_task(
                self._stt_consumer(
                    event_ch=pipeline.event_ch,
                    old_pipeline=self._stt_pipeline,
                    old_consumer=self._stt_consumer_atask,
                )
            )
            self._stt_pipeline = pipeline
            # reset interruption handling related state
            self._transcript_buffer.clear()
            self._ignore_user_transcript_until = NOT_GIVEN
            self._input_started_at = None
        else:
            if self._stt_consumer_atask is not None:
                task = asyncio.create_task(aio.cancel_and_wait(self._stt_consumer_atask))
                task.add_done_callback(lambda _: self._tasks.discard(task))
                self._tasks.add(task)
                self._stt_consumer_atask = None

            if self._stt_pipeline is not None:
                task = asyncio.create_task(self._stt_pipeline.aclose())
                task.add_done_callback(lambda _: self._tasks.discard(task))
                self._tasks.add(task)
                self._stt_pipeline = None

    def update_vad(self, vad: vad.VAD | None) -> None:
        self._vad = vad
        if vad:
            self._vad_ch = aio.Chan[rtc.AudioFrame]()
            self._vad_atask = asyncio.create_task(
                self._vad_task(vad, self._vad_ch, self._vad_atask)
            )
        elif self._vad_atask is not None:
            task = asyncio.create_task(aio.cancel_and_wait(self._vad_atask))
            task.add_done_callback(lambda _: self._tasks.discard(task))
            self._tasks.add(task)
            self._vad_atask = None
            self._vad_ch = None

        self._interruption_enabled = (
            self._interruption_detection is not None and self._vad is not None
        )

    async def detach_stt(self) -> _STTPipeline | None:
        """Detach the STT pipeline for handoff to another AudioRecognition.

        Returns the pipeline (pump task + channels) without stopping it.
        The caller is responsible for passing it to the new AudioRecognition
        via start(..., stt_pipeline=pipeline).
        """
        pipeline = self._stt_pipeline
        self._stt_pipeline = None

        # stop the consumer — the new AudioRecognition will start its own
        if self._stt_consumer_atask is not None:
            await aio.cancel_and_wait(self._stt_consumer_atask)
            self._stt_consumer_atask = None

        return pipeline

    def update_interruption_detection(
        self, interruption_detection: inference.AdaptiveInterruptionDetector | None
    ) -> None:
        self._interruption_detection = interruption_detection
        if interruption_detection is not None:
            self._interruption_ch = aio.Chan[inference.InterruptionDataFrameType]()
            self._interruption_atask = asyncio.create_task(
                self._interruption_task(
                    interruption_detection, self._interruption_ch, self._interruption_atask
                )
            )
            self._transcript_buffer.clear()
            self._ignore_user_transcript_until = NOT_GIVEN
            self._input_started_at = None
        elif self._interruption_atask is not None:
            task = asyncio.create_task(aio.cancel_and_wait(self._interruption_atask))
            task.add_done_callback(lambda _: self._tasks.discard(task))
            self._tasks.add(task)
            self._interruption_atask = None
            self._interruption_ch = None
            self._cancel_backchannel_boundary()
            flush_task = asyncio.create_task(self._flush_held_transcripts(cooldown=0.0, force=True))
            flush_task.add_done_callback(lambda _: self._tasks.discard(flush_task))
            self._tasks.add(flush_task)

        self._interruption_enabled = (
            self._interruption_detection is not None and self._vad is not None
        )

    def clear_user_turn(self) -> None:
        self._audio_transcript = ""
        self._audio_interim_transcript = ""
        self._audio_preflight_transcript = ""
        self._final_transcript_confidence = []
        self._last_final_transcript_time = None
        self._speech_start_time = None
        self._last_speaking_time = None
        self._vad_speech_started = False
        self._user_turn_committed = False

        # end any in-progress user_turn span so the next speech starts a fresh one
        if self._user_turn_span is not None and self._user_turn_span.is_recording():
            self._user_turn_span.end()
        self._user_turn_span = None
        self._stt_request_ids = []

        # reset stt to clear the buffer from previous user turn
        stt = self._stt
        self.update_stt(None)
        self.update_stt(stt)

    def commit_user_turn(
        self,
        *,
        audio_detached: bool,
        transcript_timeout: float,
        stt_flush_duration: float = 2.0,
        skip_reply: bool = False,
    ) -> asyncio.Future[str]:
        """Manually end the current user turn and run end-of-turn detection.

        If no final transcript arrived in the last 0.5 s, waits up to
        ``transcript_timeout`` seconds for one. When ``audio_detached`` is True and a
        sample rate is known, about ``stt_flush_duration`` seconds of silence are
        pushed first to flush the STT. A pending interim transcript is then reported
        as final and appended to the turn transcript before ``_run_eou_detection``
        runs. A previous commit that is still running is cancelled, and so is its
        future.

        Args:
            audio_detached: presumably True when live audio input is detached; enables
                the silence flush.
            transcript_timeout: max seconds to wait for a final transcript.
            stt_flush_duration: seconds of silence to push when flushing.
            skip_reply: forwarded to ``_run_eou_detection``.

        Returns:
            A future resolving to the committed transcript. It resolves to ``""``
            immediately when no STT is configured or the recognizer is closing. It is
            cancelled, or carries the exception, if the commit task is cancelled or
            fails.
        """
        loop = asyncio.get_running_loop()
        fut: asyncio.Future[str] = loop.create_future()

        if not self._stt or self._closing.is_set():
            fut.set_result("")
            return fut

        async def _commit_user_turn() -> None:
            if self._last_final_transcript_time is None or (
                time.time() - self._last_final_transcript_time > 0.5
            ):
                # No final transcript in the last 0.5 s: clear the event, optionally
                # flush the STT with silence, then wait for a final transcript.

                self._final_transcript_received.clear()

                # flush the stt by pushing silence (0.2 s frames of 16-bit mono zeros)
                if audio_detached and self._sample_rate:
                    num_samples = int(self._sample_rate * 0.2)
                    silence_frame = rtc.AudioFrame(
                        b"\x00\x00" * num_samples,
                        sample_rate=self._sample_rate,
                        num_channels=1,
                        samples_per_channel=num_samples,
                    )
                    num_frames = max(0, int(math.ceil(stt_flush_duration / silence_frame.duration)))
                    for _ in range(num_frames):
                        self.push_audio(silence_frame)

                # wait for the final transcript to be available
                try:
                    await asyncio.wait_for(
                        self._final_transcript_received.wait(),
                        timeout=transcript_timeout,
                    )
                except asyncio.TimeoutError:
                    if self._audio_interim_transcript:
                        logger.warning(
                            "final transcript not received after timeout",
                            extra={
                                "transcript_timeout": transcript_timeout,
                                "interim_transcript": self._audio_interim_transcript,
                            },
                        )

            if self._audio_interim_transcript:
                # emit interim transcript as final for frontend display
                self._hooks.on_final_transcript(
                    stt.SpeechEvent(
                        type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                        alternatives=[
                            stt.SpeechData(
                                language=LanguageCode(""), text=self._audio_interim_transcript
                            )
                        ],
                    )
                )

                # append interim transcript in case the final transcript is not ready
                self._audio_transcript = (
                    f"{self._audio_transcript} {self._audio_interim_transcript}".strip()
                )

            transcript = self._audio_transcript
            self._audio_interim_transcript = ""
            chat_ctx = self._hooks.retrieve_chat_ctx().copy()
            self._run_eou_detection(chat_ctx, skip_reply=skip_reply)
            self._user_turn_committed = True
            if not fut.done():
                fut.set_result(transcript)

        # propagate cancellation / failure of the commit task to the caller's future
        def _on_task_done(task: asyncio.Task[None]) -> None:
            if fut.done():
                return
            if task.cancelled():
                fut.cancel()
            elif exc := task.exception():
                fut.set_exception(exc)

        # only the most recent commit request is allowed to run
        if self._commit_user_turn_atask is not None:
            self._commit_user_turn_atask.cancel()

        self._commit_user_turn_atask = asyncio.create_task(_commit_user_turn())
        self._commit_user_turn_atask.add_done_callback(_on_task_done)
        return fut

    @property
    def current_transcript(self) -> str:
        """
        Transcript for this turn, including interim transcript if available.

        The committed (final) transcript and the pending interim transcript are
        joined with a single space. When either part is empty, the other is
        returned unchanged. This avoids a stray leading space when only interim
        text exists, matching how the rest of this class strips joined
        transcripts.
        """
        final_text = self._audio_transcript
        interim_text = self._audio_interim_transcript
        if not interim_text:
            return final_text
        if not final_text:
            # no committed text yet: don't prefix the interim text with a separator
            return interim_text
        return f"{final_text} {interim_text}"

    async def _on_stt_event(self, ev: stt.SpeechEvent) -> None:
        """Process a single STT event for the current user turn.

        Updates the per-turn transcript state: final, preflight and interim
        text, confidences, detected language and speaking timestamps. Forwards
        transcripts to the recognition hooks. Depending on the turn detection
        mode, it also triggers preemptive generation and end-of-turn (EOU)
        detection. START_OF_SPEECH / END_OF_SPEECH events are only acted upon
        when ``self._turn_detection_mode == "stt"``.
        """
        # Collect provider-known STT ids for this user turn. The actual attribute
        # is written once when the user_turn span ends (see _bounce_eou_task in
        # _run_eou_detection), to avoid ordering issues with span creation.
        if ev.request_id and ev.request_id not in self._stt_request_ids:
            self._stt_request_ids.append(ev.request_id)

        if (
            self._turn_detection_mode == "manual"
            and self._user_turn_committed
            and (
                self._end_of_turn_task is None
                or self._end_of_turn_task.done()
                or ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT
            )
        ):
            # ignore transcript for manual turn detection when user turn already committed
            # and EOU task is done or this is an interim transcript
            return

        # handle interruption detection
        # - hold the event until the ignore_user_transcript_until expires
        # - release only relevant events
        # - allow RECOGNITION_USAGE to pass through immediately
        if ev.type != stt.SpeechEventType.RECOGNITION_USAGE and self._interruption_enabled:
            if self._should_hold_stt_event(ev):
                logger.trace(
                    "holding STT event until ignore_user_transcript_until expires",
                    extra={
                        "event": ev.type,
                        "ignore_user_transcript_until": self._ignore_user_transcript_until
                        if is_given(self._ignore_user_transcript_until)
                        else None,
                    },
                )
                self._transcript_buffer.append(ev)
                return

            if self._transcript_buffer:
                # the hold window is over: replay what was buffered before handling `ev`,
                # using the end of the backchannel boundary (if any) as the cooldown
                end_cooldown: float = (
                    self._backchannel_boundary[1] if self._backchannel_boundary else 0.0
                )
                await self._flush_held_transcripts(cooldown=end_cooldown)
                # no return here to allow the new event to be processed normally

        if ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT:
            transcript = ev.alternatives[0].text
            language = ev.alternatives[0].language
            confidence = ev.alternatives[0].confidence

            # only switch language on the first detection or when the text is long
            # enough for the detected language to be trusted
            if not self._last_language or (
                language and len(transcript) > MIN_LANGUAGE_DETECTION_LENGTH
            ):
                self._last_language = language

            # signalled before the empty-text check so that waiters are released even for
            # empty finals (presumably the transcript-timeout wait in the commit path)
            self._final_transcript_received.set()
            if not transcript:
                return

            self._hooks.on_final_transcript(
                ev,
                speaking=self._speaking
                if self._vad or self._turn_detection_mode == "stt"
                else None,
            )
            if self._session.amd is not None:
                self._session.amd._on_transcript(transcript)

            extra: dict[str, Any] = {"user_transcript": transcript, "language": self._last_language}
            if self._last_speaking_time:
                extra["transcript_delay"] = time.time() - self._last_speaking_time
            logger.debug("received user transcript", extra=extra)

            self._last_final_transcript_time = time.time()
            self._audio_transcript += f" {transcript}"
            self._audio_transcript = self._audio_transcript.lstrip()
            self._final_transcript_confidence.append(confidence)
            # if a preflight transcript already produced exactly this text, the preemptive
            # generation below would be a duplicate
            transcript_changed = self._audio_transcript != self._audio_preflight_transcript
            self._audio_interim_transcript = ""
            self._audio_preflight_transcript = ""

            if not self._vad or self._last_speaking_time is None:
                # vad disabled, use stt timestamp
                # TODO: this would screw up transcription latency metrics
                # but we'll live with it for now.
                # the correct way is to ensure STT fires SpeechEventType.END_OF_SPEECH
                # and using that timestamp for _last_speaking_time
                self._last_speaking_time = time.time()

            # check user turn limit after accumulating transcript
            self._check_user_turn_limit(transcript)

            if self._vad_base_turn_detection or self._user_turn_committed:
                if transcript_changed:
                    self._hooks.on_preemptive_generation(
                        _PreemptiveGenerationInfo(
                            new_transcript=self._audio_transcript,
                            transcript_confidence=(
                                sum(self._final_transcript_confidence)
                                / len(self._final_transcript_confidence)
                                if self._final_transcript_confidence
                                else 0
                            ),
                            started_speaking_at=self._speech_start_time,
                        )
                    )

                # while the user is still speaking, EOU detection is left to the
                # end-of-speech handlers
                if not self._speaking:
                    chat_ctx = self._hooks.retrieve_chat_ctx().copy()
                    self._run_eou_detection(chat_ctx)

        elif ev.type == stt.SpeechEventType.PREFLIGHT_TRANSCRIPT:
            # preflight text is surfaced to hooks as an interim transcript
            self._hooks.on_interim_transcript(
                ev,
                speaking=self._speaking
                if self._vad or self._turn_detection_mode == "stt"
                else None,
            )
            transcript = ev.alternatives[0].text
            language = ev.alternatives[0].language
            confidence = ev.alternatives[0].confidence

            if not self._last_language or (
                language and len(transcript) > MIN_LANGUAGE_DETECTION_LENGTH
            ):
                self._last_language = language

            if not transcript:
                return

            logger.debug(
                "received user preflight transcript",
                extra={"user_transcript": transcript, "language": self._last_language},
            )

            # still need to increment it as it's used for turn detection,
            self._last_final_transcript_time = time.time()
            # preflight transcript includes all pre-committed transcripts (including final transcript from the previous STT run)
            self._audio_preflight_transcript = (self._audio_transcript + " " + transcript).lstrip()
            self._audio_interim_transcript = transcript

            if not self._vad or self._last_speaking_time is None:
                # vad disabled, use stt timestamp
                self._last_speaking_time = time.time()

            if self._turn_detection_mode != "manual" or self._user_turn_committed:
                # average the committed final confidences together with this preflight's
                confidence_vals = list(self._final_transcript_confidence) + [confidence]
                self._hooks.on_preemptive_generation(
                    _PreemptiveGenerationInfo(
                        new_transcript=self._audio_preflight_transcript,
                        transcript_confidence=sum(confidence_vals) / len(confidence_vals),
                        started_speaking_at=self._speech_start_time,
                    )
                )

        elif ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT:
            self._hooks.on_interim_transcript(
                ev,
                speaking=self._speaking
                if self._vad or self._turn_detection_mode == "stt"
                else None,
            )
            self._audio_interim_transcript = ev.alternatives[0].text

        elif ev.type == stt.SpeechEventType.END_OF_SPEECH and self._turn_detection_mode == "stt":
            with trace.use_span(self._ensure_user_turn_span()):
                self._hooks.on_end_of_speech(None)

            # STT EOT changes user state from speaking to listening without updating VAD internal states
            # VAD EOS will also skip updating user state from listening (STT enforced) to listening (VAD detected)
            # and user state won't be updated until a new VAD SOS is received
            # reset VAD so that incorrect end of turn from STT can be corrected by VAD interruption
            # if user is still speaking (an immediate VAD SOS will interrupt the agent)
            if self._vad:
                if self._speaking:
                    logger.warning(
                        "stt end of speech received while user is speaking, resetting vad"
                    )
                self.update_vad(self._vad)

            # in "stt" mode the STT end-of-speech commits the user turn
            self._speaking = False
            self._user_turn_committed = True
            if not self._vad or self._last_speaking_time is None:
                self._last_speaking_time = time.time()

            chat_ctx = self._hooks.retrieve_chat_ctx().copy()
            self._run_eou_detection(chat_ctx)

        elif ev.type == stt.SpeechEventType.START_OF_SPEECH and self._turn_detection_mode == "stt":
            # If the plugin provided a server onset timestamp, use it;
            # otherwise fall back to message arrival time.
            if self._speech_start_time is None:
                self._speech_start_time = ev.speech_start_time or time.time()

            with trace.use_span(self._ensure_user_turn_span(start_time=self._speech_start_time)):
                self._hooks.on_start_of_speech(None, speech_start_time=self._speech_start_time)

            self._speaking = True
            self._last_speaking_time = time.time()

            # new speech invalidates any pending end-of-turn decision
            if self._end_of_turn_task is not None:
                self._end_of_turn_task.cancel()

    @utils.log_exceptions(logger=logger)
    async def _on_vad_event(self, ev: vad.VADEvent) -> None:
        """React to one VAD event.

        - START_OF_SPEECH: estimates the speech onset. The onset is stored only
          for the first start of a speech run. The handler then opens or reuses
          the user-turn span, notifies the hooks, marks the user as speaking,
          cancels any pending end-of-turn task and informs AMD when present.
        - INFERENCE_DONE: forwards the event to the hooks. When speech has been
          accumulated, it refreshes the speaking timestamps used for metrics.
        - END_OF_SPEECH: notifies the hooks and clears the speaking flags. It
          runs end-of-turn detection when VAD-based turn detection is active,
          or in "stt" mode once the user turn is committed. AMD is informed
          when present.
        """
        event_type = ev.type

        if event_type == vad.VADEventType.START_OF_SPEECH:
            # back-date the onset by the speech already detected and the inference latency
            onset = time.time() - ev.speech_duration - ev.inference_duration
            if not self._vad_speech_started:
                self._speech_start_time = onset
                self._vad_speech_started = True

            turn_span = self._ensure_user_turn_span(start_time=onset)
            with trace.use_span(turn_span):
                self._hooks.on_start_of_speech(ev, speech_start_time=onset)

            self._speaking = True

            pending_eot = self._end_of_turn_task
            if pending_eot is not None:
                pending_eot.cancel()

            detector = self._session.amd
            if detector is not None:
                detector._on_user_speech_started()

        elif event_type == vad.VADEventType.INFERENCE_DONE:
            self._hooks.on_vad_inference_done(ev)

            # metrics want the earliest possible signal that speech is happening
            if ev.raw_accumulated_speech > 0.0:
                self._last_speaking_time = time.time()
                if self._speech_start_time is None:
                    self._speech_start_time = time.time() - ev.raw_accumulated_speech

        elif event_type == vad.VADEventType.END_OF_SPEECH:
            turn_span = self._ensure_user_turn_span()
            with trace.use_span(turn_span):
                self._hooks.on_end_of_speech(ev)

            self._vad_speech_started = False
            self._speaking = False

            if self._vad_base_turn_detection or (
                self._turn_detection_mode == "stt" and self._user_turn_committed
            ):
                self._run_eou_detection(self._hooks.retrieve_chat_ctx().copy())

            detector = self._session.amd
            if detector is not None:
                detector._on_user_speech_ended(ev.silence_duration)

    async def _on_overlap_speech_event(self, ev: inference.OverlappingSpeechEvent) -> None:
        if self.backchannel_boundary_active and not ev.is_interruption:
            logger.trace(
                "ignoring backchannel event during backchannel boundary cooldown, falling back to vad"
            )
            return

        if ev.is_interruption:
            self._hooks.on_interruption(ev)

    def _run_eou_detection(self, chat_ctx: llm.ChatContext, skip_reply: bool = False) -> None:
        """Schedule (debounced) end-of-turn detection for the current user turn.

        Any pending end-of-turn task is cancelled and replaced by a new one. The
        new task:

        1. optionally asks the turn detector for an end-of-turn probability and
           picks the endpointing min or max delay,
        2. waits out that delay, measured from the last speaking time,
        3. reports the turn through ``on_end_of_turn``,
        4. if the hook reports the turn as committed, ends the ``user_turn``
           span and resets the per-turn state.

        Args:
            chat_ctx: Conversation context. It is copied here, and the current
                user transcript is appended as a user message before prediction.
            skip_reply: Forwarded as ``_EndOfTurnInfo.skip_reply``.
        """
        if self._stt and not self._audio_transcript and self._turn_detection_mode != "manual":
            # stt enabled but no transcript yet
            return

        chat_ctx = chat_ctx.copy()
        chat_ctx.add_message(role="user", content=self._audio_transcript)
        turn_detector = (
            self._turn_detector
            if self._audio_transcript and self._turn_detection_mode != "manual"
            else None  # disable EOU model if manual turn detection enabled
        )

        @utils.log_exceptions(logger=logger)
        async def _bounce_eou_task(
            last_speaking_time: float | None = None,
            last_final_transcript_time: float | None = None,
            speech_start_time: float | None = None,
        ) -> None:
            """Wait for the endpointing delay, then report the end of turn.

            The arguments are snapshots taken when the task was scheduled, so
            that concurrent updates to the instance fields can be detected.
            """
            endpointing_delay = self._endpointing.min_delay
            user_turn_span = self._ensure_user_turn_span()
            if turn_detector is not None:
                if not await turn_detector.supports_language(self._last_language):
                    logger.info("Turn detector does not support language %s", self._last_language)
                else:
                    with (
                        trace.use_span(user_turn_span),
                        tracer.start_as_current_span("eou_detection") as eou_detection_span,
                    ):
                        # if there are failures, we should not hold the pipeline up
                        end_of_turn_probability = 0.0
                        unlikely_threshold: float | None = None
                        try:
                            end_of_turn_probability = await turn_detector.predict_end_of_turn(
                                chat_ctx
                            )
                            unlikely_threshold = await turn_detector.unlikely_threshold(
                                self._last_language
                            )
                            # the user is probably not done yet: wait longer before ending the turn
                            if (
                                unlikely_threshold is not None
                                and end_of_turn_probability < unlikely_threshold
                            ):
                                endpointing_delay = self._endpointing.max_delay
                        except Exception:
                            logger.exception("Error predicting end of turn")

                        eou_detection_span.set_attributes(
                            {
                                trace_types.ATTR_CHAT_CTX: json.dumps(
                                    llm.ChatContext(chat_ctx.items[-_EOU_MAX_HISTORY_TURNS:])
                                    .copy(
                                        exclude_function_call=True,
                                        exclude_instructions=True,
                                        exclude_empty_message=True,
                                        exclude_handoff=True,
                                        exclude_config_update=True,
                                    )
                                    .to_dict(
                                        exclude_audio=True,
                                        exclude_image=True,
                                        exclude_timestamp=True,
                                        exclude_metrics=True,
                                    )
                                ),
                                trace_types.ATTR_EOU_PROBABILITY: end_of_turn_probability,
                                trace_types.ATTR_EOU_UNLIKELY_THRESHOLD: unlikely_threshold or 0,
                                trace_types.ATTR_EOU_DELAY: endpointing_delay,
                                trace_types.ATTR_EOU_LANGUAGE: self._last_language or "",
                            }
                        )

            # the delay counts from the moment the user stopped speaking, so subtract
            # the time that has already elapsed since then
            extra_sleep = endpointing_delay
            if last_speaking_time:
                extra_sleep += last_speaking_time - time.time()

            if extra_sleep > 0:
                # wakes up early if `_closing` is set; the end-of-turn hook still runs below
                try:
                    await asyncio.wait_for(self._closing.wait(), timeout=extra_sleep)
                except asyncio.TimeoutError:
                    pass

            confidence_avg = (
                sum(self._final_transcript_confidence) / len(self._final_transcript_confidence)
                if self._final_transcript_confidence
                else 0
            )

            started_speaking_at = None
            stopped_speaking_at = None
            transcription_delay = None
            end_of_turn_delay = None

            # sometimes, we can't calculate the metrics because VAD was unreliable.
            # in this case, we just ignore the calculation, it's better than providing likely wrong values
            if (
                last_final_transcript_time is not None
                and last_speaking_time is not None
                and speech_start_time is not None
            ):
                started_speaking_at = speech_start_time
                stopped_speaking_at = last_speaking_time
                transcription_delay = max(last_final_transcript_time - last_speaking_time, 0)
                end_of_turn_delay = time.time() - last_speaking_time

            committed = self._hooks.on_end_of_turn(
                _EndOfTurnInfo(
                    skip_reply=skip_reply,
                    new_transcript=self._audio_transcript,
                    transcript_confidence=confidence_avg,
                    transcription_delay=transcription_delay or 0,
                    end_of_turn_delay=end_of_turn_delay,
                    started_speaking_at=started_speaking_at,
                    stopped_speaking_at=stopped_speaking_at,
                )
            )
            if committed:
                # finalize the user_turn span, including the STT request ids collected
                # in _on_stt_event
                user_turn_span.set_attributes(
                    {
                        trace_types.ATTR_USER_TRANSCRIPT: self._audio_transcript,
                        trace_types.ATTR_TRANSCRIPT_CONFIDENCE: confidence_avg,
                        trace_types.ATTR_TRANSCRIPTION_DELAY: transcription_delay or 0,
                        trace_types.ATTR_END_OF_TURN_DELAY: end_of_turn_delay or 0,
                    }
                )
                if self._stt_request_ids:
                    user_turn_span.set_attribute(
                        trace_types.ATTR_PROVIDER_REQUEST_IDS, self._stt_request_ids
                    )
                user_turn_span.end()
                self._user_turn_span = None
                self._user_turn_start = None
                self._stt_request_ids = []

                # clear the transcript if the user turn was committed
                self._audio_transcript = ""
                self._final_transcript_confidence = []
                self._last_final_transcript_time = None
                # concurrent user speech might have changed it
                # only reset if there is no new speech
                if self._last_speaking_time == last_speaking_time:
                    self._speech_start_time = None
                    self._vad_speech_started = False
                    self._last_speaking_time = None

            self._user_turn_committed = False

        if self._end_of_turn_task is not None:
            # TODO(theomonnom): disallow cancel if the extra sleep is done
            self._end_of_turn_task.cancel()

        # copy the last_speaking_time before awaiting (the value can change)
        # note: the speech start passed to the task is `_user_turn_start` (the user_turn
        # span start), not `_speech_start_time`
        self._end_of_turn_task = asyncio.create_task(
            _bounce_eou_task(
                self._last_speaking_time,
                self._last_final_transcript_time,
                self._user_turn_start,
            )
        )

    def _check_user_turn_limit(self, transcript: str) -> None:
        """Check if the user turn exceeds configured limits.
        Called when a final transcript event is received."""
        opts = self._session.options.turn_handling["user_turn_limit"]
        max_words = opts.get("max_words")
        max_duration = opts.get("max_duration")

        if max_words is None and max_duration is None:
            return

        now = time.time()
        if self._turn_tracker.started_at is None:
            self._turn_tracker.started_at = self._speech_start_time or now

        words = self._word_tokenizer.tokenize(transcript)
        self._turn_tracker.words += len(words)
        self._turn_tracker.transcript = f"{self._turn_tracker.transcript} {transcript}".strip()

        duration = now - self._turn_tracker.started_at
        time_exceeded = max_duration is not None and duration >= max_duration
        words_exceeded = max_words is not None and self._turn_tracker.words >= max_words

        if not time_exceeded and not words_exceeded:
            return

        ev = UserTurnExceededEvent(
            transcript=self.current_transcript,
            accumulated_transcript=self._turn_tracker.transcript,
            accumulated_word_count=self._turn_tracker.words,
            duration=duration,
        )
        self._hooks.on_user_turn_exceeded(ev)

    @utils.log_exceptions(logger=logger)
    async def _stt_consumer(
        self,
        event_ch: aio.Chan[stt.SpeechEvent],
        old_pipeline: _STTPipeline | None,
        old_consumer: asyncio.Task[None] | None,
    ) -> None:
        """Feed every event from *event_ch* into ``_on_stt_event``.

        Swapped on handoff. Before consuming, the previous STT pipeline (if
        any) is closed, and then the previous consumer task (if any) is
        cancelled and awaited. This way only one consumer processes events at
        a time.
        """
        if old_pipeline is not None:
            await old_pipeline.aclose()

        if old_consumer is not None:
            await aio.cancel_and_wait(old_consumer)

        async for speech_event in event_ch:
            await self._on_stt_event(speech_event)

    @utils.log_exceptions(logger=logger)
    async def _vad_task(
        self,
        vad: vad.VAD,
        audio_input: AsyncIterable[rtc.AudioFrame],
        task: asyncio.Task[None] | None,
    ) -> None:
        """Run VAD over *audio_input* and dispatch its events to ``_on_vad_event``.

        The previous VAD task (if any) is cancelled and awaited first. Audio
        frames are pushed into a fresh VAD stream from a background task while
        this coroutine consumes the stream's events. On exit, the pusher is
        cancelled and the stream is closed. If the user is still marked as
        speaking, an end-of-speech is emitted and the speaking flags are
        cleared.
        """
        if task is not None:
            await aio.cancel_and_wait(task)

        vad_stream = vad.stream()

        @utils.log_exceptions(logger=logger)
        async def _push_audio() -> None:
            async for audio_frame in audio_input:
                vad_stream.push_frame(audio_frame)

        pusher = asyncio.create_task(_push_audio())

        try:
            async for vad_event in vad_stream:
                await self._on_vad_event(vad_event)
        finally:
            await aio.cancel_and_wait(pusher)
            await vad_stream.aclose()

            # don't leave the user stuck in the "speaking" state (e.g. across a handoff)
            if self._speaking:
                turn_span = self._ensure_user_turn_span()
                with trace.use_span(turn_span):
                    self._hooks.on_end_of_speech(None)
                self._speaking = False
                self._vad_speech_started = False

    @utils.log_exceptions(logger=logger)
    async def _interruption_task(
        self,
        interruption_detection: inference.AdaptiveInterruptionDetector,
        audio_input: AsyncIterable[inference.InterruptionDataFrameType],
        task: asyncio.Task[None] | None,
    ) -> None:
        """Run the interruption detector over *audio_input*.

        The previous interruption task (if any) is cancelled and awaited first.
        Frames are pushed into a fresh detector stream from a background task,
        and each overlap event is handed to ``_on_overlap_speech_event``. An
        ``APIError`` raised by the stream is swallowed here. The pusher task
        and the stream are always cleaned up.
        """
        if task is not None:
            await aio.cancel_and_wait(task)

        detector_stream = interruption_detection.stream()

        @utils.log_exceptions(logger=logger)
        async def _push_frames() -> None:
            async for data_frame in audio_input:
                detector_stream.push_frame(data_frame)

        pusher = asyncio.create_task(_push_frames())

        try:
            async for overlap_event in detector_stream:
                await self._on_overlap_speech_event(overlap_event)
        except APIError:
            # avoid already emitted error from the stream
            pass
        finally:
            await aio.cancel_and_wait(pusher)
            await detector_stream.aclose()

    def _ensure_user_turn_span(self, start_time: float | None = None) -> trace.Span:
        """Return the ``user_turn`` span, starting a new one when needed.

        An existing span is reused only while it is still recording. Otherwise
        a new span is started at *start_time* (epoch seconds; defaults to now).
        ``_user_turn_start`` is initialised from that time if it is unset. The
        linked participant's attributes and the STT model/provider, when
        known, are attached to the new span.
        """
        current = self._user_turn_span
        if current and current.is_recording():
            return current

        began_at = time.time() if start_time is None else start_time
        span = tracer.start_span("user_turn", start_time=int(began_at * 1_000_000_000))
        self._user_turn_span = span

        if self._user_turn_start is None:
            self._user_turn_start = began_at

        room_io = self._session._room_io
        if room_io and room_io.linked_participant:
            _set_participant_attributes(span, room_io.linked_participant)

        # STT model/provider attributes
        if self._stt_model:
            span.set_attribute(trace_types.ATTR_GEN_AI_REQUEST_MODEL, self._stt_model)
        if self._stt_provider:
            span.set_attribute(trace_types.ATTR_GEN_AI_PROVIDER_NAME, self._stt_provider)

        return span
