from __future__ import annotations

import asyncio
import os
from dataclasses import dataclass
from typing import Any, Literal, cast

import httpx
import openai
from openai.types.chat import (
    ChatCompletionChunk,
    ChatCompletionMessageParam,
    ChatCompletionPredictionContentParam,
    ChatCompletionToolChoiceOptionParam,
    ChatCompletionToolParam,
    completion_create_params,
)
from openai.types.chat.chat_completion_chunk import Choice
from openai.types.shared.reasoning_effort import ReasoningEffort
from openai.types.shared_params import Metadata
from typing_extensions import TypedDict

from .. import llm
from .._exceptions import APIConnectionError, APIStatusError, APITimeoutError
from ..llm import ToolChoice, utils as llm_utils
from ..llm.chat_context import ChatContext
from ..llm.tool_context import Tool
from ..log import logger
from ..types import DEFAULT_API_CONNECT_OPTIONS, NOT_GIVEN, APIConnectOptions, NotGivenOr
from ..utils import is_given
from ._utils import (
    HEADER_INFERENCE_PRIORITY,
    HEADER_INFERENCE_PROVIDER,
    create_access_token,
    get_default_inference_url,
    get_inference_headers,
)

# Any non-zero integer turns on debug logging of outgoing chat-completion requests.
lk_oai_debug = int(os.environ.get("LK_OPENAI_DEBUG", 0))

# Sampling-related parameters that reasoning models do not support.
# See: https://platform.openai.com/docs/guides/reasoning
_REASONING_UNSUPPORTED_PARAMS: set[str] = {
    "frequency_penalty",
    "logit_bias",
    "logprobs",
    "n",
    "presence_penalty",
    "temperature",
    "top_logprobs",
    "top_p",
}

# xAI reasoning models are narrower: only presence_penalty, frequency_penalty and
# stop are restricted, while temperature and top_p remain supported.
_XAI_REASONING_UNSUPPORTED_PARAMS: set[str] = {
    "frequency_penalty",
    "presence_penalty",
    "stop",
}

# Model-name prefix -> parameter names to drop for models starting with that prefix.
# Every prefix of a family maps to the same shared set object.
_UNSUPPORTED_PARAMS: dict[str, set[str]] = {
    **dict.fromkeys(("o1", "o3", "o4", "gpt-5"), _REASONING_UNSUPPORTED_PARAMS),
    **dict.fromkeys(
        (
            "grok-4-1-fast-reasoning",
            "grok-4.20-0309-reasoning",
            "grok-4.20-multi-agent",
        ),
        _XAI_REASONING_UNSUPPORTED_PARAMS,
    ),
}

# Model-name prefixes that reject reasoning_effort when function tools are present.
_REASONING_EFFORT_TOOL_INCOMPATIBLE_PREFIXES: set[str] = {"gpt-5.2", "gpt-5.4"}


def drop_unsupported_params(
    model: str, params: dict[str, Any], tools: list[Any] | None = None
) -> dict[str, Any]:
    """Filter out request parameters the target model is known to reject.

    Everything up to and including the last ``/`` in ``model`` is treated as a
    provider prefix and ignored (``openai/o3-pro`` is matched as ``o3-pro``).
    The first entry of ``_UNSUPPORTED_PARAMS`` whose key is a prefix of the bare
    model name decides which keys are removed. In addition, ``reasoning_effort``
    is removed when ``tools`` is non-empty and the model matches one of
    ``_REASONING_EFFORT_TOOL_INCOMPATIBLE_PREFIXES``.

    Args:
        model: Model identifier, optionally provider-prefixed.
        params: Candidate request parameters; never mutated.
        tools: Tools that will accompany the request, if any.

    Returns:
        ``params`` itself when no rule applied, otherwise a new filtered dict.
    """
    bare_model = model.rpartition("/")[2]

    # Only the first matching prefix (in table order) is applied.
    blocked = next(
        (
            names
            for known_prefix, names in _UNSUPPORTED_PARAMS.items()
            if bare_model.startswith(known_prefix)
        ),
        None,
    )
    if blocked is not None:
        params = {key: value for key, value in params.items() if key not in blocked}

    if tools and bare_model.startswith(tuple(_REASONING_EFFORT_TOOL_INCOMPATIBLE_PREFIXES)):
        params = {key: value for key, value in params.items() if key != "reasoning_effort"}

    return params


# Known model identifiers, grouped by namespace and written as ``<namespace>/<model>``.
# These aliases exist for static typing only; ``LLM`` also accepts any plain ``str``.
OpenAIModels = Literal[
    "openai/gpt-4o",
    "openai/gpt-4o-mini",
    "openai/gpt-4.1",
    "openai/gpt-4.1-mini",
    "openai/gpt-4.1-nano",
    "openai/gpt-5",
    "openai/gpt-5-mini",
    "openai/gpt-5-nano",
    "openai/gpt-5.1",
    "openai/gpt-5.1-chat-latest",
    "openai/gpt-5.2",
    "openai/gpt-5.2-chat-latest",
    "openai/gpt-5.3-chat-latest",
    "openai/gpt-5.4",
    "openai/gpt-5.4-mini",
    "openai/gpt-oss-120b",
]

GoogleModels = Literal[
    "google/gemini-3-pro",
    "google/gemini-3-flash",
    "google/gemini-2.5-pro",
    "google/gemini-2.5-flash",
    "google/gemini-2.5-flash-lite",
]

KimiModels = Literal["moonshotai/kimi-k2-instruct"]

DeepSeekModels = Literal[
    "deepseek-ai/deepseek-v3",
    "deepseek-ai/deepseek-v3.2",
]

XAIModels = Literal[
    "xai/grok-4-1-fast-non-reasoning",
    "xai/grok-4-1-fast-reasoning",
    "xai/grok-4.20-0309-non-reasoning",
    "xai/grok-4.20-0309-reasoning",
    "xai/grok-4.20-multi-agent-0309",
]

# Union of every model identifier listed above.
LLMModels = OpenAIModels | GoogleModels | KimiModels | DeepSeekModels | XAIModels

# Allowed values for the ``inference_class`` option; when set, the value is sent
# in the ``HEADER_INFERENCE_PRIORITY`` request header.
InferenceClass = Literal["priority", "standard"]


class ChatCompletionOptions(TypedDict, total=False):
    """Typed shape of the ``extra_kwargs`` forwarded to the chat-completions request.

    Declared with ``total=False``, so every key may be omitted. The field names
    appear to mirror the OpenAI chat-completions request parameters (several
    value types are imported from ``openai.types``) — confirm against the SDK
    before relying on a specific key.
    """

    frequency_penalty: float | None
    logit_bias: dict[str, int] | None
    logprobs: bool | None
    max_completion_tokens: int | None
    max_tokens: int | None
    metadata: Metadata | None
    modalities: list[Literal["text", "audio"]] | None
    n: int | None
    parallel_tool_calls: bool
    prediction: ChatCompletionPredictionContentParam | None
    presence_penalty: float | None
    prompt_cache_key: str
    prompt_cache_retention: Literal["in_memory", "24h"] | None
    reasoning_effort: ReasoningEffort | None
    safety_identifier: str
    seed: int | None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None
    stop: str | None | list[str] | None
    store: bool | None
    temperature: float | None
    top_logprobs: int | None
    top_p: float | None
    user: str
    verbosity: Literal["low", "medium", "high"] | None
    web_search_options: completion_create_params.WebSearchOptions

    # livekit-typed arguments
    # (``tool_choice`` uses the LiveKit ``ToolChoice`` type, not the OpenAI param type)
    tool_choice: ToolChoice
    # TODO(theomonnomn): support response format
    # response_format: completion_create_params.ResponseFormat


@dataclass
class _LLMOptions:
    """Resolved configuration held by an ``LLM`` instance."""

    # model identifier passed as ``model`` on each request
    model: LLMModels | str
    # optional provider name; sent in the HEADER_INFERENCE_PROVIDER header when set
    provider: str | None
    # endpoint the OpenAI client is pointed at
    base_url: str
    # credentials passed to ``create_access_token``
    api_key: str
    api_secret: str
    # default inference class; sent in the HEADER_INFERENCE_PRIORITY header when set
    inference_class: InferenceClass | None
    # default request options merged into every ``chat()`` call
    extra_kwargs: ChatCompletionOptions | dict[str, Any]


class LLM(llm.LLM):
    """Chat LLM that talks to the LiveKit inference endpoint through an OpenAI client.

    The API key/secret pair is turned into an access token with
    ``create_access_token`` and that token is used as the OpenAI client's API key.
    """

    def __init__(
        self,
        model: LLMModels | str,
        *,
        provider: str | None = None,
        base_url: str | None = None,
        api_key: str | None = None,
        api_secret: str | None = None,
        inference_class: InferenceClass | None = None,
        extra_kwargs: ChatCompletionOptions | dict[str, Any] | None = None,
    ) -> None:
        """Create the LLM and its underlying HTTP client.

        Args:
            model: Model identifier sent with every request.
            provider: Optional provider name; when set it is sent in the
                ``HEADER_INFERENCE_PROVIDER`` request header.
            base_url: Endpoint URL. Defaults to ``get_default_inference_url()``.
            api_key: API key. Falls back to ``LIVEKIT_INFERENCE_API_KEY`` and then
                ``LIVEKIT_API_KEY`` from the environment.
            api_secret: API secret. Falls back to ``LIVEKIT_INFERENCE_API_SECRET``
                and then ``LIVEKIT_API_SECRET`` from the environment.
            inference_class: Default inference class; can be overridden per
                ``chat()`` call.
            extra_kwargs: Default request options applied to every ``chat()`` call.

        Raises:
            ValueError: If no API key or no API secret can be resolved.
        """
        super().__init__()

        lk_base_url = base_url if base_url else get_default_inference_url()

        lk_api_key = (
            api_key
            if api_key
            else os.getenv("LIVEKIT_INFERENCE_API_KEY", os.getenv("LIVEKIT_API_KEY", ""))
        )
        if not lk_api_key:
            raise ValueError(
                "api_key is required, either as argument or set LIVEKIT_API_KEY environmental variable"
            )

        lk_api_secret = (
            api_secret
            if api_secret
            else os.getenv("LIVEKIT_INFERENCE_API_SECRET", os.getenv("LIVEKIT_API_SECRET", ""))
        )
        if not lk_api_secret:
            raise ValueError(
                "api_secret is required, either as argument or set LIVEKIT_API_SECRET environmental variable"
            )

        self._opts = _LLMOptions(
            model=model,
            provider=provider,
            base_url=lk_base_url,
            api_key=lk_api_key,
            api_secret=lk_api_secret,
            inference_class=inference_class,
            extra_kwargs=extra_kwargs or {},
        )
        self._client = openai.AsyncClient(
            api_key=create_access_token(self._opts.api_key, self._opts.api_secret),
            base_url=self._opts.base_url,
            http_client=httpx.AsyncClient(
                timeout=httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
                follow_redirects=True,
                limits=httpx.Limits(
                    max_connections=50, max_keepalive_connections=50, keepalive_expiry=120
                ),
            ),
        )

    async def aclose(self) -> None:
        """Close the underlying OpenAI client."""
        await self._client.close()

    @classmethod
    def from_model_string(cls, model: str) -> LLM:
        """Create a LLM instance from a model string"""
        return cls(model)

    def update_options(
        self,
        *,
        model: NotGivenOr[LLMModels | str] = NOT_GIVEN,
        extra_kwargs: NotGivenOr[ChatCompletionOptions | dict[str, Any]] = NOT_GIVEN,
    ) -> None:
        """Update LLM configuration options.

        Each option is read on the next ``chat()`` call, so a swap
        takes effect on the agent's next turn without recreating the
        LLM. ``extra_kwargs`` *replaces* the persistent kwargs dict
        rather than merging — pass ``{}`` to clear it.
        """
        if is_given(model):
            self._opts.model = model
        if is_given(extra_kwargs):
            # Store a shallow copy so later changes to the caller's dict do not
            # alter the stored options.
            self._opts.extra_kwargs = dict(extra_kwargs)

    @property
    def model(self) -> str:
        """Get the model name for this LLM instance."""
        return self._opts.model

    @property
    def provider(self) -> str:
        """Provider label of this LLM implementation (always ``"livekit"``)."""
        return "livekit"

    def chat(
        self,
        *,
        chat_ctx: ChatContext,
        tools: list[Tool] | None = None,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
        parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
        tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
        response_format: NotGivenOr[
            completion_create_params.ResponseFormat | type[llm_utils.ResponseFormatT]
        ] = NOT_GIVEN,
        inference_class: NotGivenOr[InferenceClass] = NOT_GIVEN,
        extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
    ) -> LLMStream:
        """Start a streaming chat completion and return its ``LLMStream``.

        Request options are layered, lowest to highest precedence:

        1. the instance-level ``extra_kwargs`` (constructor / ``update_options``),
        2. the per-call ``extra_kwargs``,
        3. the explicit keyword arguments of this call (``parallel_tool_calls``,
           ``tool_choice``, ``response_format``).

        Args:
            chat_ctx: Conversation to send.
            tools: Tools offered to the model; ``None`` is treated as no tools.
            conn_options: Connection options handed to the stream.
            parallel_tool_calls: Overrides any configured ``parallel_tool_calls``.
            tool_choice: Overrides any configured ``tool_choice``. A dict value is
                rewritten into the OpenAI named-function form; of the string values
                only ``"auto"``, ``"required"`` and ``"none"`` are applied.
            response_format: Converted with ``llm_utils.to_openai_response_format``.
            inference_class: Overrides the instance-level inference class for this
                call only.
            extra_kwargs: Additional request options for this call only.

        Returns:
            The stream object that performs the request.
        """
        # Start from the persistent defaults and layer per-call values on top, so
        # that an explicit per-call value is never overwritten by a default.
        extra: dict[str, Any] = dict(self._opts.extra_kwargs)
        if is_given(extra_kwargs):
            extra.update(extra_kwargs)

        if is_given(parallel_tool_calls):
            extra["parallel_tool_calls"] = parallel_tool_calls

        resolved_tool_choice = (
            tool_choice if is_given(tool_choice) else extra.get("tool_choice", NOT_GIVEN)
        )
        if is_given(resolved_tool_choice):
            oai_tool_choice: ChatCompletionToolChoiceOptionParam
            if isinstance(resolved_tool_choice, dict):
                # Keep only the function name, in the OpenAI named-tool-choice shape.
                oai_tool_choice = {
                    "type": "function",
                    "function": {"name": resolved_tool_choice["function"]["name"]},
                }
                extra["tool_choice"] = oai_tool_choice
            elif resolved_tool_choice in ("auto", "required", "none"):
                oai_tool_choice = resolved_tool_choice
                extra["tool_choice"] = oai_tool_choice

        if is_given(response_format):
            extra["response_format"] = llm_utils.to_openai_response_format(response_format)  # type: ignore

        effective_inference_class = (
            inference_class if is_given(inference_class) else self._opts.inference_class
        )

        # A new access token is created for every call (presumably because tokens
        # are short-lived — confirm against ``create_access_token``).
        self._client.api_key = create_access_token(self._opts.api_key, self._opts.api_secret)
        return LLMStream(
            self,
            model=self._opts.model,
            provider=self._opts.provider,
            inference_class=effective_inference_class,
            strict_tool_schema=True,
            client=self._client,
            chat_ctx=chat_ctx,
            tools=tools or [],
            conn_options=conn_options,
            extra_kwargs=extra,
        )


class LLMStream(llm.LLMStream):
    """Streaming chat completion that converts OpenAI chunks into ``llm.ChatChunk``s."""

    def __init__(
        self,
        llm_v: LLM | llm.LLM,
        *,
        model: LLMModels | str,
        provider: str | None = None,
        inference_class: InferenceClass | None = None,
        strict_tool_schema: bool,
        client: openai.AsyncClient,
        chat_ctx: llm.ChatContext,
        tools: list[Tool],
        conn_options: APIConnectOptions,
        extra_kwargs: dict[str, Any],
        provider_fmt: str = "openai",  # used internally for chat_ctx format
    ) -> None:
        """Store the request configuration; the request itself is made in ``_run``.

        Args:
            llm_v: The LLM instance that created this stream.
            model: Model identifier sent with the request.
            provider: Optional provider name, sent as a request header when set.
            inference_class: Optional inference class, sent as a request header
                when set.
            strict_tool_schema: Passed as ``strict`` when building tool schemas.
            client: OpenAI async client used to perform the request.
            chat_ctx: Conversation to send.
            tools: Tools offered to the model.
            conn_options: Connection options; ``timeout`` is used for the request.
            extra_kwargs: Additional request keyword arguments. Parameters the
                model does not support are dropped. The dict is copied, so the
                caller's dict is not modified by this stream.
            provider_fmt: Format name passed to ``chat_ctx.to_provider_format``.
        """
        super().__init__(llm_v, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
        self._model = model
        self._provider = provider
        self._inference_class = inference_class
        self._provider_fmt = provider_fmt
        self._strict_tool_schema = strict_tool_schema
        self._client = client
        self._llm = llm_v
        # ``drop_unsupported_params`` may return the caller's dict unchanged; copy it
        # because ``_run`` pops/sets keys on ``self._extra_kwargs``.
        self._extra_kwargs = dict(drop_unsupported_params(model, extra_kwargs, tools=tools))
        self._tool_ctx = llm.ToolContext(tools)

    async def _run(self) -> None:
        """Perform the streaming request and forward parsed chunks to ``_event_ch``.

        Raises:
            APITimeoutError: When the OpenAI client reports a timeout.
            APIStatusError: When the OpenAI client reports an HTTP status error.
            APIConnectionError: For any other exception raised while streaming.

        Each raised error carries ``retryable=False`` once at least one chunk has
        already been forwarded, and ``retryable=True`` otherwise.
        """
        # current function call that we're waiting for full completion (args are streamed)
        # (defined inside the _run method to make sure the state is reset for each run/attempt)
        self._oai_stream: openai.AsyncStream[ChatCompletionChunk] | None = None
        self._tool_call_id: str | None = None
        self._fnc_name: str | None = None
        self._fnc_raw_arguments: str | None = None
        self._tool_extra: dict[str, Any] | None = None
        self._tool_index: int | None = None
        retryable = True

        try:
            chat_ctx, _ = self._chat_ctx.to_provider_format(format=self._provider_fmt)
            tool_schemas = cast(
                list[ChatCompletionToolParam],
                self._tool_ctx.parse_function_tools("openai", strict=self._strict_tool_schema),
            )
            if lk_oai_debug:
                tool_choice = self._extra_kwargs.get("tool_choice", NOT_GIVEN)
                logger.debug(
                    "chat.completions.create",
                    extra={
                        "fnc_ctx": tool_schemas,
                        "tool_choice": tool_choice,
                        "chat_ctx": chat_ctx,
                    },
                )
            if not self._tools:
                # remove tool_choice from extra_kwargs if no tools are provided
                self._extra_kwargs.pop("tool_choice", None)

            # Build the headers on a copy: a caller-supplied ``extra_headers`` dict may
            # be shared with the LLM's persistent options, and updating it in place
            # would leak provider/priority headers into later requests.
            extra_headers = dict(self._extra_kwargs.get("extra_headers") or {})
            extra_headers.update(get_inference_headers())
            if self._provider:
                extra_headers[HEADER_INFERENCE_PROVIDER] = self._provider
            if self._inference_class:
                extra_headers[HEADER_INFERENCE_PRIORITY] = self._inference_class
            self._extra_kwargs["extra_headers"] = extra_headers

            self._oai_stream = stream = await self._client.chat.completions.create(
                messages=cast(list[ChatCompletionMessageParam], chat_ctx),
                tools=tool_schemas or openai.omit,
                model=self._model,
                stream_options={"include_usage": True},
                stream=True,
                timeout=httpx.Timeout(self._conn_options.timeout),
                **self._extra_kwargs,
            )

            # Shared across chunks so ``strip_thinking_tokens`` can keep state
            # between successive deltas.
            thinking = asyncio.Event()
            async with stream:
                async for chunk in stream:
                    for choice in chunk.choices:
                        chat_chunk = self._parse_choice(chunk.id, choice, thinking)
                        if chat_chunk is not None:
                            retryable = False
                            self._event_ch.send_nowait(chat_chunk)

                    if chunk.usage is not None:
                        retryable = False
                        tokens_details = chunk.usage.prompt_tokens_details
                        cached_tokens = tokens_details.cached_tokens if tokens_details else 0
                        usage_chunk = llm.ChatChunk(
                            id=chunk.id,
                            usage=llm.CompletionUsage(
                                completion_tokens=chunk.usage.completion_tokens,
                                prompt_tokens=chunk.usage.prompt_tokens,
                                prompt_cached_tokens=cached_tokens or 0,
                                total_tokens=chunk.usage.total_tokens,
                                service_tier=getattr(chunk, "service_tier", None),
                            ),
                        )
                        self._event_ch.send_nowait(usage_chunk)

        except openai.APITimeoutError:
            raise APITimeoutError(retryable=retryable) from None
        except openai.APIStatusError as e:
            raise APIStatusError(
                e.message,
                status_code=e.status_code,
                request_id=e.request_id,
                body=e.body,
                retryable=retryable,
            ) from None
        except Exception as e:
            raise APIConnectionError(retryable=retryable) from e

    def _parse_choice(
        self, id: str, choice: Choice, thinking: asyncio.Event
    ) -> llm.ChatChunk | None:
        """Convert one streamed choice into a ``ChatChunk``, or ``None`` if nothing to emit.

        Tool-call fragments are accumulated on the instance (``_tool_call_id``,
        ``_fnc_name``, ``_fnc_raw_arguments``, ``_tool_extra``, ``_tool_index``).
        A completed tool call is emitted either when a tool call with a different
        index starts, or when the choice finishes with ``"tool_calls"``/``"stop"``
        while a call is pending. Otherwise the text content (after
        ``strip_thinking_tokens``) and/or the delta's ``extra_content`` is emitted.

        Args:
            id: Chunk id copied onto the produced ``ChatChunk``.
            choice: The streamed choice to parse.
            thinking: Event passed through to ``strip_thinking_tokens``.
        """
        delta = choice.delta

        # https://github.com/livekit/agents/issues/688
        # the delta can be None when using Azure OpenAI (content filtering)
        if delta is None:
            return None

        if delta.tool_calls:
            for tool in delta.tool_calls:
                if not tool.function:
                    continue

                call_chunk = None
                # A new call id at a different index means the pending call is
                # complete: package it before starting to track the new one.
                if self._tool_call_id and tool.id and tool.index != self._tool_index:
                    call_chunk = llm.ChatChunk(
                        id=id,
                        delta=llm.ChoiceDelta(
                            role="assistant",
                            content=delta.content,
                            tool_calls=[
                                llm.FunctionToolCall(
                                    arguments=self._fnc_raw_arguments or "",
                                    name=self._fnc_name or "",
                                    call_id=self._tool_call_id or "",
                                    extra=self._tool_extra,
                                )
                            ],
                        ),
                    )
                    self._tool_call_id = self._fnc_name = self._fnc_raw_arguments = None
                    self._tool_extra = None

                if tool.function.name:
                    # First fragment of a call: carries the name (and possibly arguments).
                    self._tool_index = tool.index
                    self._tool_call_id = tool.id
                    self._fnc_name = tool.function.name
                    self._fnc_raw_arguments = tool.function.arguments or ""
                    # Extract extra from tool call (e.g., Google thought signatures)
                    self._tool_extra = getattr(tool, "extra_content", None)
                elif tool.function.arguments:
                    # Continuation fragment: append to the arguments collected so far.
                    self._fnc_raw_arguments += tool.function.arguments  # type: ignore

                if call_chunk is not None:
                    return call_chunk

        if choice.finish_reason in ("tool_calls", "stop") and self._tool_call_id:
            # Stream finished while a tool call is still pending: flush it.
            finish_extra = getattr(delta, "extra_content", None)
            call_chunk = llm.ChatChunk(
                id=id,
                delta=llm.ChoiceDelta(
                    role="assistant",
                    content=delta.content,
                    extra=finish_extra,
                    tool_calls=[
                        llm.FunctionToolCall(
                            arguments=self._fnc_raw_arguments or "",
                            name=self._fnc_name or "",
                            call_id=self._tool_call_id or "",
                            extra=self._tool_extra,
                        )
                    ],
                ),
            )
            self._tool_call_id = self._fnc_name = self._fnc_raw_arguments = None
            self._tool_extra = None
            return call_chunk

        delta.content = llm_utils.strip_thinking_tokens(delta.content, thinking)

        # Extract extra from delta (e.g., Google thought signatures on text parts)
        delta_extra = getattr(delta, "extra_content", None)

        if not delta.content and not delta_extra:
            return None

        return llm.ChatChunk(
            id=id,
            delta=llm.ChoiceDelta(
                content=delta.content,
                role="assistant",
                extra=delta_extra,
            ),
        )
