from __future__ import annotations

import asyncio
import base64
import json
import os
import weakref
from dataclasses import dataclass, replace
from typing import Any, Literal, TypedDict, overload

import aiohttp
from typing_extensions import Required

from livekit import rtc

from .. import stt, utils, vad
from .._exceptions import (
    APIConnectionError,
    APIStatusError,
    APITimeoutError,
    create_api_error_from_http,
)
from ..language import LanguageCode
from ..log import logger
from ..types import (
    DEFAULT_API_CONNECT_OPTIONS,
    NOT_GIVEN,
    APIConnectOptions,
    NotGivenOr,
    TimedString,
)
from ..utils import is_given
from ._utils import create_access_token, get_default_inference_url, get_inference_headers

# Model identifiers accepted by LiveKit Inference STT, grouped by provider.
# Each value has the form "provider/model"; when a model is passed as a string
# it may additionally carry a ":language" suffix (see _parse_model_string).
# The per-provider aliases also select the matching typed `extra_kwargs`
# overload of STT.__init__.
DeepgramModels = Literal[
    "deepgram/nova-3",
    "deepgram/nova-3-medical",
    "deepgram/nova-2",
    "deepgram/nova-2-medical",
    "deepgram/nova-2-conversationalai",
    "deepgram/nova-2-phonecall",
]
DeepgramFluxModels = Literal[
    "deepgram/flux-general",
    "deepgram/flux-general-en",
]
CartesiaModels = Literal["cartesia/ink-whisper",]
AssemblyAIModels = Literal[
    "assemblyai/universal-streaming",
    "assemblyai/universal-streaming-multilingual",
    "assemblyai/u3-rt-pro",
]
ElevenlabsModels = Literal["elevenlabs/scribe_v2_realtime",]
XaiModels = Literal["xai/stt-1",]
# Speechmatics models are the only ones that get a client-side VAD
# (see _resolve_vad_for_model).
SpeechmaticsModels = Literal[
    "speechmatics/enhanced",
    "speechmatics/standard",
]


class CartesiaOptions(TypedDict, total=False):
    """Typed ``extra_kwargs`` for Cartesia models.

    The dict is forwarded verbatim to the inference gateway as ``settings.extra``.
    """

    min_volume: float  # default: not specified
    max_silence_duration_secs: float  # default: not specified


class DeepgramOptions(TypedDict, total=False):
    """Typed ``extra_kwargs`` for Deepgram (non-Flux) models.

    The dict is forwarded verbatim to the inference gateway as ``settings.extra``.
    A truthy ``diarize`` also marks the STT as diarization-capable.
    """

    filler_words: bool  # default: True
    interim_results: bool  # default: True
    endpointing: int  # default: 25 (ms)
    punctuate: bool  # default: True
    smart_format: bool
    keywords: list[tuple[str, float]]
    keyterm: str | list[str]
    profanity_filter: bool
    numerals: bool
    mip_opt_out: bool  # default: False
    vad_events: bool  # default: False
    diarize: bool  # when True, enables speaker diarization (default off)
    dictation: bool
    detect_language: bool
    no_delay: bool  # default: True
    utterance_end: bool
    redact: str | list[str]
    replace: str | list[str]
    search: str | list[str]
    tag: str | list[str]
    channels: int
    version: str
    callback: str
    callback_method: str
    extra: str


class DeepgramFluxOptions(TypedDict, total=False):
    """Typed ``extra_kwargs`` for Deepgram Flux models.

    The dict is forwarded verbatim to the inference gateway as ``settings.extra``.
    """

    eager_eot_threshold: float  # range 0.3-0.9, default: 0.5
    eot_threshold: float  # range 0.5-0.9
    eot_timeout_ms: int
    keyterm: str | list[str]
    mip_opt_out: bool  # default: False
    tag: str | list[str]
    detect_language: bool


class AssemblyaiOptions(TypedDict, total=False):
    """Typed ``extra_kwargs`` for AssemblyAI models.

    The dict is forwarded verbatim to the inference gateway as ``settings.extra``.
    A truthy ``speaker_labels`` also marks the STT as diarization-capable.
    """

    format_turns: bool  # default: False
    end_of_turn_confidence_threshold: float  # default: 0.01
    min_end_of_turn_silence_when_confident: int  # default: 0
    max_turn_silence: int  # default: not specified
    keyterms_prompt: list[str]  # default: not specified
    language_detection: bool
    inactivity_timeout: float  # seconds
    prompt: str  # default: not specified (u3-rt-pro only, mutually exclusive with keyterms_prompt)
    speaker_labels: bool  # when True, enables speaker diarization (default off)


class ElevenlabsOptions(TypedDict, total=False):
    """Typed ``extra_kwargs`` for ElevenLabs models.

    The dict is forwarded verbatim to the inference gateway as ``settings.extra``.
    """

    commit_strategy: Literal["manual", "vad"]
    include_timestamps: bool
    vad_silence_threshold_secs: float
    vad_threshold: float
    min_speech_duration_ms: int
    min_silence_duration_ms: int
    language_code: str


class SpeechmaticsOptions(TypedDict, total=False):
    """Typed ``extra_kwargs`` for Speechmatics models.

    The dict is forwarded verbatim to the inference gateway as ``settings.extra``.
    A ``diarization`` value other than ``"none"`` (case-insensitive) marks the
    STT as diarization-capable.
    """

    domain: str  # e.g. "finance"
    output_locale: str  # BCP-47 locale for output formatting
    max_delay: float  # 0.7-4.0 seconds, default 1.0
    max_delay_mode: str  # "flexible" | "fixed"
    diarization: str  # "none" | "speaker" | "channel" | "channel_and_speaker_change" | "speaker_change"; non-"none" enables diarization
    speaker_sensitivity: float  # 0.0-1.0
    max_speakers: int
    prefer_current_speaker: bool
    enable_partials: bool  # default True (overridden by gateway)
    enable_entities: bool
    punctuation_overrides: dict[str, Any]
    additional_vocab: list[dict[str, Any]]
    end_of_utterance_silence_trigger: float  # seconds of silence before final
    audio_filtering_config: dict[str, Any]
    transcript_filtering_config: dict[str, Any]


class XaiOptions(TypedDict, total=False):
    """Typed ``extra_kwargs`` for xAI models.

    The dict is forwarded verbatim to the inference gateway as ``settings.extra``.
    A truthy ``diarize`` also marks the STT as diarization-capable.
    """

    diarize: bool  # when True, enables speaker diarization (default off)
    endpointing: int  # silence duration in ms before utterance-final (0-5000)
    format: bool  # enables Inverse Text Normalization (e.g. "one hundred dollars" -> "$100"); requires language
    interim_results: bool  # default True; set False to opt out of interim transcripts


# Diarization is requested via different extra_kwargs keys across
# providers. Keep this list in one place so adding a new provider is a
# single-line change and there's no divergence between __init__ and
# update_options capability inference. Consumed by _diarization_enabled.
_DIARIZATION_EXTRA_KEYS: tuple[str, ...] = (
    "diarize",  # Deepgram, xAI
    "speaker_labels",  # AssemblyAI
    "diarization",  # Speechmatics (string; "none" means off)
)


def _diarization_enabled(extra_kwargs: dict[str, Any] | None) -> bool:
    """Tell whether ``extra_kwargs`` switches on speaker diarization.

    Every key listed in ``_DIARIZATION_EXTRA_KEYS`` is inspected. A key counts
    as "on" when its value is truthy, with one exception: a string equal to
    ``"none"`` (compared case-insensitively) counts as "off", which is how
    Speechmatics spells a disabled ``diarization`` setting.

    Args:
        extra_kwargs: Provider-specific options, or ``None``.

    Returns:
        ``True`` if at least one diarization key is enabled, else ``False``.
    """
    if not extra_kwargs:
        return False

    def _flag_is_on(flag: Any) -> bool:
        # Strings are on only when non-empty and not the explicit "none".
        if isinstance(flag, str):
            return bool(flag) and flag.lower() != "none"
        return bool(flag)

    return any(_flag_is_on(extra_kwargs.get(name)) for name in _DIARIZATION_EXTRA_KEYS)


# Language codes suggested by the type hints of STT.stream / update_options.
# Those signatures also accept any other string (`STTLanguages | str`).
STTLanguages = Literal["multi", "en", "de", "es", "fr", "ja", "pt", "zh", "hi"]


class FallbackModel(TypedDict, total=False):
    """Inference Fallback Adapter: configuration for a fallback STT model that runs server-side in LiveKit Inference, providing automatic fallback between providers.

    ``extra_kwargs`` is forwarded to the gateway as this fallback model's
    ``extra`` settings. A plain model-name string may be used instead of a
    FallbackModel (see ``FallbackModelType``); in that case any ``:language``
    suffix on the string is dropped.

    Example:
        >>> FallbackModel(model="deepgram/nova-3", extra_kwargs={"keyterm": ["livekit"]})
    """

    model: Required[str]
    """Model name (e.g. "deepgram/nova-3", "assemblyai/universal-streaming", "cartesia/ink-whisper")."""

    extra_kwargs: dict[str, Any]
    """Extra configuration for the model."""


# A fallback entry is either a full FallbackModel or just a model-name string.
FallbackModelType = FallbackModel | str


def _parse_model_string(model: str) -> tuple[str, NotGivenOr[LanguageCode]]:
    """Split a ``"provider/model[:language]"`` string into its two parts.

    The text after the *last* colon is taken as the language; everything
    before it is the model name. Without any colon the model string is
    returned unchanged together with ``NOT_GIVEN``.

    Returns:
        A ``(model_name, language)`` tuple.
    """
    head, separator, tail = model.rpartition(":")
    if not separator:
        return model, NOT_GIVEN
    return head, LanguageCode(tail)


def _resolve_vad_for_model(
    model: NotGivenOr[STTModels | str],
    vad_instance: vad.VAD | None,
) -> vad.VAD | None:
    """Pick the client-side VAD to use for ``model``.

    Only Speechmatics models use a client-side VAD. For any other model a
    supplied VAD is dropped with a warning. For Speechmatics, the supplied VAD
    is kept, or a Silero VAD is loaded when none was supplied.

    Raises:
        ImportError: a Speechmatics model needs a VAD, none was given, and
            ``livekit-plugins-silero`` cannot be imported.
    """
    uses_client_vad = (
        is_given(model) and isinstance(model, str) and model.startswith("speechmatics/")
    )

    if not uses_client_vad:
        if vad_instance is not None:
            logger.warning(
                "`vad` will be ignored: model %r handles endpointing server-side.",
                model,
            )
        return None

    if vad_instance is not None:
        return vad_instance

    try:
        from livekit.plugins.silero import VAD as SileroVAD
    except ImportError as e:
        raise ImportError(
            "livekit-plugins-silero is required: model "
            f"{model!r} does not handle endpointing server-side."
        ) from e
    return SileroVAD.load()


def _normalize_fallback(
    fallback: list[FallbackModelType] | FallbackModelType,
) -> list[FallbackModel]:
    """Turn the user-facing ``fallback`` argument into a list of FallbackModel.

    A single entry is wrapped in a list. String entries become
    ``FallbackModel(model=name)``; any ``:language`` suffix on them is
    discarded. FallbackModel entries are passed through unchanged.
    """
    entries = fallback if isinstance(fallback, list) else [fallback]

    normalized: list[FallbackModel] = []
    for entry in entries:
        if isinstance(entry, str):
            model_name, _ignored_language = _parse_model_string(entry)
            normalized.append(FallbackModel(model=model_name))
        else:
            normalized.append(entry)
    return normalized


# Every model id known to this client. STT(model=...) also accepts arbitrary
# strings, so this union mainly drives editor completion and overload typing.
STTModels = (
    DeepgramModels
    | DeepgramFluxModels
    | CartesiaModels
    | AssemblyAIModels
    | ElevenlabsModels
    | XaiModels
    | SpeechmaticsModels
    | Literal["auto"]  # automatically select a provider based on the language
)
# The only audio encoding this client announces to the gateway.
STTEncoding = Literal["pcm_s16le"]


DEFAULT_ENCODING: STTEncoding = "pcm_s16le"
# Hz; audio is chunked into sample_rate // 20 samples (50 ms) before sending.
DEFAULT_SAMPLE_RATE: int = 16000


@dataclass
class STTOptions:
    """Resolved configuration held by STT; each SpeechStream receives its own copy."""

    model: NotGivenOr[STTModels | str]  # "auto" is not sent in the session.create message
    language: NotGivenOr[LanguageCode]  # sent as settings.language when truthy
    encoding: STTEncoding  # sent as settings.encoding
    sample_rate: int  # Hz; sent as a string in settings.sample_rate
    base_url: str  # http(s):// prefixes are rewritten to ws(s):// on connect
    api_key: str  # used with api_secret to mint the bearer token per connection
    api_secret: str
    extra_kwargs: dict[str, Any]  # provider-specific options, sent as settings.extra
    fallback: NotGivenOr[list[FallbackModel]]  # sent as fallback.models when set
    conn_options: NotGivenOr[APIConnectOptions]  # timeout/retries forwarded as "connection"


class STT(stt.STT):
    """LiveKit Cloud Inference speech-to-text.

    Recognition is streaming-only: audio is sent over a WebSocket to the LiveKit
    Inference gateway, which forwards it to the selected provider model.
    """

    @overload
    def __init__(
        self,
        model: CartesiaModels,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        base_url: NotGivenOr[str] = NOT_GIVEN,
        encoding: NotGivenOr[STTEncoding] = NOT_GIVEN,
        sample_rate: NotGivenOr[int] = NOT_GIVEN,
        api_key: NotGivenOr[str] = NOT_GIVEN,
        api_secret: NotGivenOr[str] = NOT_GIVEN,
        http_session: aiohttp.ClientSession | None = None,
        extra_kwargs: NotGivenOr[CartesiaOptions] = NOT_GIVEN,
        fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
        conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
    ) -> None: ...

    @overload
    def __init__(
        self,
        model: DeepgramModels,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        base_url: NotGivenOr[str] = NOT_GIVEN,
        encoding: NotGivenOr[STTEncoding] = NOT_GIVEN,
        sample_rate: NotGivenOr[int] = NOT_GIVEN,
        api_key: NotGivenOr[str] = NOT_GIVEN,
        api_secret: NotGivenOr[str] = NOT_GIVEN,
        http_session: aiohttp.ClientSession | None = None,
        extra_kwargs: NotGivenOr[DeepgramOptions] = NOT_GIVEN,
        fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
        conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
    ) -> None: ...

    @overload
    def __init__(
        self,
        model: DeepgramFluxModels,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        base_url: NotGivenOr[str] = NOT_GIVEN,
        encoding: NotGivenOr[STTEncoding] = NOT_GIVEN,
        sample_rate: NotGivenOr[int] = NOT_GIVEN,
        api_key: NotGivenOr[str] = NOT_GIVEN,
        api_secret: NotGivenOr[str] = NOT_GIVEN,
        http_session: aiohttp.ClientSession | None = None,
        extra_kwargs: NotGivenOr[DeepgramFluxOptions] = NOT_GIVEN,
        fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
        conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
    ) -> None: ...

    @overload
    def __init__(
        self,
        model: AssemblyAIModels,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        base_url: NotGivenOr[str] = NOT_GIVEN,
        encoding: NotGivenOr[STTEncoding] = NOT_GIVEN,
        sample_rate: NotGivenOr[int] = NOT_GIVEN,
        api_key: NotGivenOr[str] = NOT_GIVEN,
        api_secret: NotGivenOr[str] = NOT_GIVEN,
        http_session: aiohttp.ClientSession | None = None,
        extra_kwargs: NotGivenOr[AssemblyaiOptions] = NOT_GIVEN,
        fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
        conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
    ) -> None: ...

    @overload
    def __init__(
        self,
        model: ElevenlabsModels,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        base_url: NotGivenOr[str] = NOT_GIVEN,
        encoding: NotGivenOr[STTEncoding] = NOT_GIVEN,
        sample_rate: NotGivenOr[int] = NOT_GIVEN,
        api_key: NotGivenOr[str] = NOT_GIVEN,
        api_secret: NotGivenOr[str] = NOT_GIVEN,
        http_session: aiohttp.ClientSession | None = None,
        extra_kwargs: NotGivenOr[ElevenlabsOptions] = NOT_GIVEN,
        fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
        conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
    ) -> None: ...

    @overload
    def __init__(
        self,
        model: XaiModels,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        base_url: NotGivenOr[str] = NOT_GIVEN,
        encoding: NotGivenOr[STTEncoding] = NOT_GIVEN,
        sample_rate: NotGivenOr[int] = NOT_GIVEN,
        api_key: NotGivenOr[str] = NOT_GIVEN,
        api_secret: NotGivenOr[str] = NOT_GIVEN,
        http_session: aiohttp.ClientSession | None = None,
        extra_kwargs: NotGivenOr[XaiOptions] = NOT_GIVEN,
        fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
        conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
    ) -> None: ...

    @overload
    def __init__(
        self,
        model: SpeechmaticsModels,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        base_url: NotGivenOr[str] = NOT_GIVEN,
        encoding: NotGivenOr[STTEncoding] = NOT_GIVEN,
        sample_rate: NotGivenOr[int] = NOT_GIVEN,
        api_key: NotGivenOr[str] = NOT_GIVEN,
        api_secret: NotGivenOr[str] = NOT_GIVEN,
        http_session: aiohttp.ClientSession | None = None,
        extra_kwargs: NotGivenOr[SpeechmaticsOptions] = NOT_GIVEN,
        fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
        conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
        vad: NotGivenOr[vad.VAD | None] = NOT_GIVEN,
    ) -> None: ...

    @overload
    def __init__(
        self,
        model: str,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        base_url: NotGivenOr[str] = NOT_GIVEN,
        encoding: NotGivenOr[STTEncoding] = NOT_GIVEN,
        sample_rate: NotGivenOr[int] = NOT_GIVEN,
        api_key: NotGivenOr[str] = NOT_GIVEN,
        api_secret: NotGivenOr[str] = NOT_GIVEN,
        http_session: aiohttp.ClientSession | None = None,
        extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
        fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
        conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
    ) -> None: ...

    def __init__(
        self,
        model: NotGivenOr[STTModels | str] = NOT_GIVEN,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        base_url: NotGivenOr[str] = NOT_GIVEN,
        encoding: NotGivenOr[STTEncoding] = NOT_GIVEN,
        sample_rate: NotGivenOr[int] = NOT_GIVEN,
        api_key: NotGivenOr[str] = NOT_GIVEN,
        api_secret: NotGivenOr[str] = NOT_GIVEN,
        http_session: aiohttp.ClientSession | None = None,
        extra_kwargs: NotGivenOr[
            dict[str, Any]
            | CartesiaOptions
            | DeepgramOptions
            | DeepgramFluxOptions
            | AssemblyaiOptions
            | ElevenlabsOptions
            | XaiOptions
            | SpeechmaticsOptions
        ] = NOT_GIVEN,
        fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
        conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
        vad: NotGivenOr[vad.VAD | None] = NOT_GIVEN,
    ) -> None:
        """Livekit Cloud Inference STT

        Args:
            model (STTModels | str, optional): STT model to use, in "provider/model[:language]" format.
                A ":language" suffix is used as the language unless ``language`` is also given.
            language (str, optional): Language of the STT model.
            encoding (STTEncoding, optional): Encoding of the STT model. Defaults to "pcm_s16le".
            sample_rate (int, optional): Sample rate of the STT model in Hz. Defaults to 16000.
            base_url (str, optional): Inference gateway URL. Defaults to ``get_default_inference_url()``.
            api_key (str, optional): API key. If not provided, read from the
                LIVEKIT_INFERENCE_API_KEY or LIVEKIT_API_KEY environment variable.
            api_secret (str, optional): API secret. If not provided, read from the
                LIVEKIT_INFERENCE_API_SECRET or LIVEKIT_API_SECRET environment variable.
            http_session (aiohttp.ClientSession, optional): HTTP session to use.
            extra_kwargs (dict, optional): Extra kwargs to pass to the STT model.
            fallback (FallbackModelType, optional): Fallback models - either a single model name
                or FallbackModel, or a list of them.
            conn_options (APIConnectOptions, optional): Connection options for request attempts.
            vad (VAD, optional): External Voice Activity Detector. When provided, each audio
                frame is forwarded to the VAD and `session.finalize` is sent to the inference
                gateway on end of speech. Only applicable to Speechmatics models.

        Raises:
            ValueError: If no API key or API secret can be resolved.
            ImportError: If a Speechmatics model is used without ``vad`` and
                livekit-plugins-silero is not installed.
        """
        # Infer diarization capability from provider-specific extra_kwargs
        # keys (see _DIARIZATION_EXTRA_KEYS). xAI uses "diarize" (same as
        # Deepgram); AssemblyAI uses "speaker_labels".
        diarization_enabled = _diarization_enabled(
            dict(extra_kwargs) if is_given(extra_kwargs) else None
        )

        # Parse language from model string if provided: "provider/model:language"
        if is_given(model) and isinstance(model, str):
            parsed_model, parsed_language = _parse_model_string(model)
            model = parsed_model
            if is_given(parsed_language) and not is_given(language):
                language = parsed_language

        # Note: from here on `vad` is the resolved VAD instance (or None), not
        # the raw argument.
        vad = _resolve_vad_for_model(model, vad if is_given(vad) else None)

        super().__init__(
            capabilities=stt.STTCapabilities(
                streaming=True,
                interim_results=True,
                diarization=diarization_enabled,
                aligned_transcript="word",
                offline_recognize=False,
            ),
        )

        lk_base_url = base_url if is_given(base_url) else get_default_inference_url()

        lk_api_key = (
            api_key
            if is_given(api_key)
            else os.getenv("LIVEKIT_INFERENCE_API_KEY", os.getenv("LIVEKIT_API_KEY", ""))
        )
        if not lk_api_key:
            raise ValueError(
                "api_key is required, either as argument or set LIVEKIT_API_KEY environmental variable"
            )

        lk_api_secret = (
            api_secret
            if is_given(api_secret)
            else os.getenv("LIVEKIT_INFERENCE_API_SECRET", os.getenv("LIVEKIT_API_SECRET", ""))
        )
        if not lk_api_secret:
            raise ValueError(
                "api_secret is required, either as argument or set LIVEKIT_API_SECRET environmental variable"
            )
        fallback_models: NotGivenOr[list[FallbackModel]] = NOT_GIVEN
        if is_given(fallback):
            fallback_models = _normalize_fallback(fallback)

        self._opts = STTOptions(
            model=model,
            language=LanguageCode(language) if isinstance(language, str) else language,
            encoding=encoding if is_given(encoding) else DEFAULT_ENCODING,
            sample_rate=sample_rate if is_given(sample_rate) else DEFAULT_SAMPLE_RATE,
            base_url=lk_base_url,
            api_key=lk_api_key,
            api_secret=lk_api_secret,
            # Copy so later update_options() calls never mutate the caller's dict.
            extra_kwargs=dict(extra_kwargs) if is_given(extra_kwargs) else {},
            fallback=fallback_models,
            conn_options=conn_options if is_given(conn_options) else DEFAULT_API_CONNECT_OPTIONS,
        )

        self._session = http_session
        self._vad = vad
        # Live streams, tracked weakly so update_options() can reach them
        # without keeping finished streams alive.
        self._streams = weakref.WeakSet[SpeechStream]()

    @classmethod
    def from_model_string(cls, model: str) -> STT:
        """Create a STT instance from a model string

        Args:
            model (str): STT model to use, in "provider/model[:language]" format

        Returns:
            STT: STT instance
        """
        model_name, language = _parse_model_string(model)
        return cls(model=model_name, language=language)

    @property
    def model(self) -> str:
        """Configured model id, or ``"unknown"`` when none was given."""
        return self._opts.model if is_given(self._opts.model) else "unknown"

    @property
    def provider(self) -> str:
        """Provider name reported for this STT."""
        return "livekit"

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, lazily falling back to the shared http_context session."""
        if not self._session:
            self._session = utils.http_context.http_session()
        return self._session

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        """Batch recognition is not supported.

        Raises:
            NotImplementedError: Always; use :meth:`stream` instead.
        """
        raise NotImplementedError(
            "LiveKit Inference STT does not support batch recognition, use stream() instead"
        )

    def stream(
        self,
        *,
        language: NotGivenOr[STTLanguages | str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> SpeechStream:
        """Create a streaming transcription session.

        Args:
            language: Optional language override for this stream only.
            conn_options: Connection options for this stream.

        Returns:
            A new SpeechStream working on its own copy of the current options.
        """
        options = self._sanitize_options(language=language)
        stream = SpeechStream(
            stt=self,
            opts=options,
            conn_options=conn_options,
            vad_instance=self._vad,
        )
        self._streams.add(stream)
        return stream

    def update_options(
        self,
        *,
        model: NotGivenOr[STTModels | str] = NOT_GIVEN,
        language: NotGivenOr[STTLanguages | str] = NOT_GIVEN,
        extra: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
    ) -> None:
        """Update STT configuration options.

        The new values apply to streams created afterwards and are also pushed
        to every live stream. ``extra`` is merged into the existing extra
        kwargs and diarization capability is re-derived from the result.
        Live streams keep the VAD they were created with.

        Raises:
            ImportError: If switching to a Speechmatics model needs a Silero VAD
                and livekit-plugins-silero is not installed. In that case no
                option is changed.
        """
        if is_given(model):
            # Mirror __init__: strip ":language" suffix and apply if not overridden.
            if isinstance(model, str):
                parsed_model, parsed_language = _parse_model_string(model)
                model = parsed_model
                if is_given(parsed_language) and not is_given(language):
                    language = parsed_language

            # Resolve the VAD *before* touching any state so that a failure
            # (e.g. missing silero plugin) leaves the previous configuration
            # fully intact instead of half-applied.
            new_vad = _resolve_vad_for_model(model, self._vad)
            self._opts.model = model
            self._vad = new_vad
        if is_given(language):
            self._opts.language = LanguageCode(language)
        if is_given(extra):
            self._opts.extra_kwargs.update(extra)
            self._capabilities = replace(
                self._capabilities,
                diarization=_diarization_enabled(self._opts.extra_kwargs),
            )

        for stream in self._streams:
            stream.update_options(model=model, language=language, extra=extra)

    def _sanitize_options(
        self, *, language: NotGivenOr[STTLanguages | str] = NOT_GIVEN
    ) -> STTOptions:
        """Create a sanitized copy of options with language override if provided."""
        options = replace(self._opts)
        # replace() is shallow; copy extra_kwargs so per-stream updates stay local.
        options.extra_kwargs = dict(options.extra_kwargs)

        if is_given(language):
            options.language = LanguageCode(language)

        return options


class SpeechStream(stt.SpeechStream):
    """One streaming transcription session against LiveKit Inference STT.

    Pushed audio frames are re-chunked to 50 ms, base64-encoded and sent as
    ``input_audio`` messages. Transcript messages from the gateway are turned
    into :class:`stt.SpeechEvent` objects on the stream's event channel.
    """

    def __init__(
        self,
        *,
        stt: STT,
        opts: STTOptions,
        conn_options: APIConnectOptions,
        vad_instance: vad.VAD | None = None,
    ) -> None:
        super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
        self._stt: STT = stt
        self._opts = opts
        self._request_id = str(utils.shortuuid("stt_request_"))

        self._speaking = False
        # Seconds of audio sent since the last usage report.
        self._speech_duration: float = 0
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._vad: vad.VAD | None = vad_instance
        # Strong references to in-flight session.update sends: the event loop
        # keeps only weak references to tasks, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._pending_updates: set[asyncio.Future[None]] = set()

    def update_options(
        self,
        *,
        model: NotGivenOr[STTModels | str] = NOT_GIVEN,
        language: NotGivenOr[STTLanguages | str] = NOT_GIVEN,
        extra: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
    ) -> None:
        """Update streaming transcription options.

        When the WebSocket is live, a mid-stream session.update is sent so providers
        that support it (e.g. AssemblyAI, Deepgram Flux) can apply changes without
        reconnecting. Unsupported providers ignore the message.
        """
        if is_given(model):
            self._opts.model = model
        if is_given(language):
            self._opts.language = LanguageCode(language)
        if is_given(extra):
            self._opts.extra_kwargs.update(extra)

        has_update = is_given(model) or is_given(language) or is_given(extra)
        if has_update and self._ws is not None and not self._ws.closed:
            settings: dict[str, Any] = {}
            if is_given(model):
                settings["model"] = model
            if is_given(language):
                settings["language"] = str(LanguageCode(language))
            if is_given(extra):
                settings["extra"] = extra
            update_msg = {
                "type": "session.update",
                "settings": settings,
            }
            task = asyncio.ensure_future(self._send_session_update(update_msg))
            self._pending_updates.add(task)
            task.add_done_callback(self._pending_updates.discard)

    async def _send_session_update(self, msg: dict[str, Any]) -> None:
        """Best-effort send of a session.update message on the current socket."""
        try:
            if self._ws is not None and not self._ws.closed:
                await self._ws.send_str(json.dumps(msg))
        except Exception:
            logger.debug("failed to send session.update, ws may be closing", exc_info=True)

    async def _run(self) -> None:
        """Main loop for streaming transcription.

        Opens one WebSocket session, then runs the sender, the receiver and,
        when a VAD is configured, the VAD-driven finalizer concurrently. If any
        of them raises, the others are cancelled and the error propagates.
        """
        closing_ws = False
        http_session = self._stt._ensure_session()
        vad_stream: vad.VADStream | None = self._vad.stream() if self._vad is not None else None

        @utils.log_exceptions(logger=logger)
        async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            nonlocal closing_ws

            audio_bstream = utils.audio.AudioByteStream(
                sample_rate=self._opts.sample_rate,
                num_channels=1,
                samples_per_channel=self._opts.sample_rate // 20,  # 50ms
            )

            async for ev in self._input_ch:
                frames: list[rtc.AudioFrame] = []
                if isinstance(ev, rtc.AudioFrame):
                    if vad_stream is not None:
                        vad_stream.push_frame(ev)
                    frames.extend(audio_bstream.push(ev.data))
                elif isinstance(ev, self._FlushSentinel):
                    frames.extend(audio_bstream.flush())

                for frame in frames:
                    self._speech_duration += frame.duration
                    audio_bytes = frame.data.tobytes()
                    base64_audio = base64.b64encode(audio_bytes).decode("utf-8")
                    audio_msg = {
                        "type": "input_audio",
                        "audio": base64_audio,
                    }
                    await ws.send_str(json.dumps(audio_msg))

            if vad_stream is not None:
                vad_stream.end_input()

            # Mark the close as expected before asking the gateway to finish,
            # so recv_task treats the subsequent close as normal.
            closing_ws = True
            finalize_msg = {
                "type": "session.finalize",
            }
            await ws.send_str(json.dumps(finalize_msg))

        @utils.log_exceptions(logger=logger)
        async def vad_task(ws: aiohttp.ClientWebSocketResponse, stream: vad.VADStream) -> None:
            # Each local end-of-speech asks the gateway to finalize the
            # current utterance.
            async for ev in stream:
                if ev.type != vad.VADEventType.END_OF_SPEECH:
                    continue
                if ws.closed:
                    return
                try:
                    await ws.send_str(json.dumps({"type": "session.finalize"}))
                except Exception:
                    logger.debug("failed to send session.finalize from VAD, ws may be closing")
                    return

        @utils.log_exceptions(logger=logger)
        async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            nonlocal closing_ws
            while True:
                msg = await ws.receive()
                if msg.type in (
                    aiohttp.WSMsgType.CLOSED,
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSING,
                ):
                    if closing_ws or http_session.closed:
                        return
                    raise APIStatusError(
                        message="LiveKit Inference STT connection closed unexpectedly"
                    )

                if msg.type != aiohttp.WSMsgType.TEXT:
                    logger.warning("unexpected LiveKit Inference STT message type %s", msg.type)
                    continue

                data = json.loads(msg.data)
                msg_type = data.get("type")
                if msg_type == "session.created":
                    pass
                elif msg_type == "interim_transcript":
                    self._process_transcript(data, is_final=False)
                elif msg_type == "preflight_transcript":
                    self._process_preflight_transcript(data)
                elif msg_type == "final_transcript":
                    self._process_transcript(data, is_final=True)
                elif msg_type == "session.finalized":
                    pass
                elif msg_type == "session.closed":
                    pass
                elif msg_type == "error":
                    raise APIStatusError(
                        f"LiveKit Inference STT returned error: {data.get('message')}",
                        status_code=data.get("code", -1),
                        body=data,
                    )

        ws: aiohttp.ClientWebSocketResponse | None = None
        try:
            ws = await self._connect_ws(http_session)
            self._ws = ws
            tasks = [
                asyncio.create_task(send_task(ws)),
                asyncio.create_task(recv_task(ws)),
            ]
            if vad_stream is not None:
                tasks.append(asyncio.create_task(vad_task(ws, vad_stream)))
            try:
                await asyncio.gather(*tasks)
            finally:
                await utils.aio.gracefully_cancel(*tasks)
        finally:
            self._ws = None
            if ws is not None:
                await ws.close()
            if vad_stream is not None:
                await vad_stream.aclose()

    async def _connect_ws(
        self, http_session: aiohttp.ClientSession
    ) -> aiohttp.ClientWebSocketResponse:
        """Connect to the LiveKit Inference STT WebSocket and send ``session.create``.

        Raises:
            APITimeoutError: If the handshake exceeds the stream's connect timeout.
            APIConnectionError: If the gateway cannot be reached.
            Whatever ``create_api_error_from_http`` returns for an HTTP handshake error.
        """
        params: dict[str, Any] = {
            "settings": {
                "sample_rate": str(self._opts.sample_rate),
                "encoding": self._opts.encoding,
                "extra": self._opts.extra_kwargs,
            },
        }

        # "auto" (and an unset model) let the gateway choose; no explicit model is sent.
        if self._opts.model and self._opts.model != "auto":
            params["model"] = self._opts.model

        if self._opts.language:
            params["settings"]["language"] = self._opts.language

        if self._opts.fallback:
            models = [
                {"model": m.get("model"), "extra": m.get("extra_kwargs")}
                for m in self._opts.fallback
            ]
            params["fallback"] = {"models": models}

        if self._opts.conn_options:
            params["connection"] = {
                "timeout": self._opts.conn_options.timeout,
                "retries": self._opts.conn_options.max_retry,
            }

        base_url = self._opts.base_url
        if base_url.startswith(("http://", "https://")):
            # http:// -> ws://, https:// -> wss://
            base_url = base_url.replace("http", "ws", 1)

        # Only put the model in the query string when one is configured;
        # interpolating NOT_GIVEN would yield a bogus "?model=NOT_GIVEN".
        url = f"{base_url}/stt"
        if is_given(self._opts.model):
            url = f"{url}?model={self._opts.model}"

        headers = {
            **get_inference_headers(),
            "Authorization": f"Bearer {create_access_token(self._opts.api_key, self._opts.api_secret)}",
        }
        try:
            ws = await asyncio.wait_for(
                http_session.ws_connect(url, headers=headers),
                self._conn_options.timeout,
            )
        except aiohttp.ClientResponseError as e:
            raise create_api_error_from_http(e.message, status=e.status) from e
        except asyncio.TimeoutError as e:
            raise APITimeoutError("LiveKit Inference STT connection timed out.") from e
        except aiohttp.ClientConnectorError as e:
            raise APIConnectionError("failed to connect to LiveKit Inference STT") from e

        params["type"] = "session.create"
        try:
            await ws.send_str(json.dumps(params))
        except BaseException:
            # _run only closes sockets that are returned to it, so close this
            # freshly opened one here rather than leaking it.
            await ws.close()
            raise
        return ws

    def _build_speech_data(self, data: dict) -> stt.SpeechData:
        """Convert a gateway transcript message into :class:`stt.SpeechData`.

        Times in the message are relative to the stream and are shifted by
        ``start_time_offset``. Missing or ``null`` fields fall back to defaults.
        """
        # `or` (not a .get default) so an explicit null also falls back.
        language = LanguageCode(data.get("language") or self._opts.language or "en")
        start = data.get("start") or 0
        duration = data.get("duration") or 0
        words = data.get("words", []) or []
        # The gateway carries provider-specific data on the `extra` field
        # of the transcript message. We surface it on SpeechData.metadata
        extra = data.get("extra")
        metadata = extra if isinstance(extra, dict) and extra else None
        return stt.SpeechData(
            language=language,
            start_time=self.start_time_offset + start,
            end_time=self.start_time_offset + start + duration,
            confidence=data.get("confidence", 1.0),
            text=data.get("transcript", ""),
            speaker_id=data.get("speaker_id"),
            words=[
                TimedString(
                    text=word.get("word", ""),
                    start_time=(word.get("start") or 0) + self.start_time_offset,
                    end_time=(word.get("end") or 0) + self.start_time_offset,
                    start_time_offset=self.start_time_offset,
                    confidence=word.get("confidence", 0.0),
                    speaker_id=word.get("speaker_id"),
                )
                for word in words
            ],
            metadata=metadata,
        )

    def _process_preflight_transcript(self, data: dict) -> None:
        """Emit a PREFLIGHT_TRANSCRIPT event; ignored when empty or not speaking."""
        text = data.get("transcript", "")
        if not text or not self._speaking:
            return

        speech_data = self._build_speech_data(data)
        request_id = data.get("request_id", self._request_id)
        event = stt.SpeechEvent(
            type=stt.SpeechEventType.PREFLIGHT_TRANSCRIPT,
            request_id=request_id,
            alternatives=[speech_data],
        )
        self._event_ch.send_nowait(event)

    def _process_transcript(self, data: dict, is_final: bool) -> None:
        """Emit interim/final transcript events and derived speech boundaries.

        A final transcript also reports accumulated audio usage and closes the
        current speech segment. Empty interim transcripts are dropped.
        """
        request_id = data.get("request_id", self._request_id)
        text = data.get("transcript", "")

        if not text and not is_final:
            return
        # We'll have a more accurate way of detecting when speech started when we have VAD.
        # Only non-empty text opens a segment: an empty final (sent by some
        # providers to close a turn) must not fabricate a START/END pair when
        # no speech was in progress.
        if text and not self._speaking:
            self._speaking = True
            start_event = stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
            self._event_ch.send_nowait(start_event)

        speech_data = self._build_speech_data(data)

        if is_final:
            if self._speech_duration > 0:
                self._event_ch.send_nowait(
                    stt.SpeechEvent(
                        type=stt.SpeechEventType.RECOGNITION_USAGE,
                        request_id=request_id,
                        recognition_usage=stt.RecognitionUsage(
                            audio_duration=self._speech_duration,
                        ),
                    )
                )
                self._speech_duration = 0

            event = stt.SpeechEvent(
                type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                request_id=request_id,
                alternatives=[speech_data],
            )
            self._event_ch.send_nowait(event)

            if self._speaking:
                self._speaking = False
                end_event = stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
                self._event_ch.send_nowait(end_event)
        else:
            event = stt.SpeechEvent(
                type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
                request_id=request_id,
                alternatives=[speech_data],
            )
            self._event_ch.send_nowait(event)