from __future__ import annotations

import itertools
import json
from dataclasses import dataclass
from typing import Any

from livekit.agents import llm

from .utils import convert_mid_conversation_instructions, group_tool_calls


@dataclass
class BedrockFormatData:
    """Extra data returned by ``to_chat_ctx`` alongside the message list."""

    # Text of the system messages that ``to_chat_ctx`` pulled out of the chat
    # context, in the order they were encountered. ``to_chat_ctx`` always
    # passes a list here (possibly empty); ``None`` is allowed by the type.
    system_messages: list[str] | None


def to_chat_ctx(
    chat_ctx: llm.ChatContext, *, inject_dummy_user_message: bool = True
) -> tuple[list[dict], BedrockFormatData]:
    """Convert a chat context into role-grouped messages for AWS Bedrock.

    The context is first passed through ``convert_mid_conversation_instructions``
    and then visited in the order produced by ``group_tool_calls`` (each group
    flattened). Consecutive items that map to the same effective role are
    merged into a single message whose ``content`` list holds all their blocks.

    Role mapping:
      * system messages that have text are collected separately and never
        appear in the returned message list;
      * assistant messages and function calls become ``"assistant"`` turns;
      * every other message and all function call outputs become ``"user"``
        turns.

    Items of any other type, and items that produce no content blocks (for
    example a message whose only content is an empty string), are skipped
    without affecting the turn that is currently being assembled.

    Args:
        chat_ctx: The chat context to convert.
        inject_dummy_user_message: When true, a placeholder ``"(empty)"`` user
            message is inserted at the front if the result is empty or does
            not start with a ``"user"`` turn.

    Returns:
        A tuple ``(messages, format_data)`` where ``messages`` is a list of
        ``{"role": ..., "content": [...]}`` dicts and ``format_data`` carries
        the collected system message texts.

    Raises:
        json.JSONDecodeError: If a function call's ``arguments`` is not valid
            JSON.
        ValueError: Propagated from ``_build_image`` for images that can only
            be referenced by external URL.
    """
    chat_ctx = convert_mid_conversation_instructions(chat_ctx)

    messages: list[dict] = []
    system_messages: list[str] = []
    current_role: str | None = None
    current_content: list[dict] = []

    for msg in itertools.chain.from_iterable(
        group.flatten() for group in group_tool_calls(chat_ctx)
    ):
        if msg.type == "message" and msg.role == "system" and (text := msg.text_content):
            system_messages.append(text)
            continue

        # Build this item's blocks *before* touching the role state, so that
        # items contributing nothing cannot split a turn in two.
        role: str
        blocks: list[dict] = []
        if msg.type == "message":
            role = "assistant" if msg.role == "assistant" else "user"
            for content in msg.content:
                if content and isinstance(content, str):
                    blocks.append({"text": content})
                elif isinstance(content, llm.ImageContent):
                    blocks.append(_build_image(content))
        elif msg.type == "function_call":
            role = "assistant"
            blocks.append(
                {
                    "toolUse": {
                        "toolUseId": msg.call_id,
                        "name": msg.name,
                        "input": json.loads(msg.arguments or "{}"),
                    }
                }
            )
        elif msg.type == "function_call_output":
            role = "user"
            blocks.append(
                {
                    "toolResult": {
                        "toolUseId": msg.call_id,
                        "content": [
                            {"json": msg.output}
                            if isinstance(msg.output, dict)
                            else {"text": msg.output}
                        ],
                        # Report failed tool calls as such when the item
                        # carries an ``is_error`` flag; default to success.
                        "status": "error" if getattr(msg, "is_error", False) else "success",
                    }
                }
            )
        else:
            # Unknown item kinds have no role/content mapping; previously they
            # reused a stale ``role`` (or raised if they came first).
            continue

        if not blocks:
            continue

        # If the effective role changed, finalize the previous turn.
        if role != current_role:
            if current_content and current_role is not None:
                messages.append({"role": current_role, "content": current_content})
            current_content = []
            current_role = role

        current_content.extend(blocks)

    # Finalize the last turn if there is any content left.
    if current_role is not None and current_content:
        messages.append({"role": current_role, "content": current_content})

    # Ensure the message list starts with a "user" message.
    if inject_dummy_user_message and (not messages or messages[0]["role"] != "user"):
        messages.insert(0, {"role": "user", "content": [{"text": "(empty)"}]})

    return messages, BedrockFormatData(system_messages=system_messages)


def _build_image(image: llm.ImageContent) -> dict:
    """Build an ``{"image": ...}`` content block from ``image``.

    The serialized form of the image is computed once with
    ``llm.utils.serialize_image`` and stored in ``image._cache`` under the
    ``"serialized_image"`` key; later calls reuse the cached value.

    The block's ``format`` is derived from the serialized image's
    ``mime_type`` (``image/png``, ``image/jpeg``, ``image/gif`` or
    ``image/webp``). When the mime type is missing or not one of those,
    ``"jpeg"`` is used, which was the previous unconditional value.

    Raises:
        ValueError: If the serialized image only provides an external URL.
    """
    cache_key = "serialized_image"
    if cache_key not in image._cache:
        image._cache[cache_key] = llm.utils.serialize_image(image)
    img: llm.utils.SerializedImage = image._cache[cache_key]

    if img.external_url:
        raise ValueError("external_url is not supported by AWS Bedrock.")

    formats_by_mime = {
        "image/jpeg": "jpeg",
        "image/jpg": "jpeg",
        "image/png": "png",
        "image/gif": "gif",
        "image/webp": "webp",
    }
    # ``mime_type`` may be absent or None; normalise before the lookup.
    mime_type = (getattr(img, "mime_type", None) or "").split(";")[0].strip().lower()
    image_format = formats_by_mime.get(mime_type, "jpeg")

    return {
        "image": {
            "format": image_format,
            "source": {"bytes": img.data_bytes},
        }
    }


def to_fnc_ctx(tool_ctx: llm.ToolContext) -> list[dict[str, Any]]:
    """Return one tool spec dict per function tool registered in ``tool_ctx``.

    Specs are produced by ``_build_tool_spec`` and keep the iteration order of
    ``tool_ctx.function_tools``.
    """
    specs: list[dict[str, Any]] = []
    for function_tool in tool_ctx.function_tools.values():
        specs.append(_build_tool_spec(function_tool))
    return specs


def _build_tool_spec(tool: llm.FunctionTool | llm.RawFunctionTool) -> dict:
    """Build the ``{"toolSpec": ...}`` entry describing ``tool``.

    For an ``llm.FunctionTool`` the name, description and parameters come from
    ``llm.utils.build_legacy_openai_schema``; for an ``llm.RawFunctionTool``
    they come from ``tool.info`` and its ``raw_schema``.

    Both kinds are normalised the same way: a missing or empty description is
    left out of the spec entirely, and missing or empty parameters become an
    empty ``{}`` schema.

    Raises:
        ValueError: If ``tool`` is neither of the two supported tool kinds.
    """
    if isinstance(tool, llm.FunctionTool):
        fnc = llm.utils.build_legacy_openai_schema(tool, internally_tagged=True)
        name = fnc["name"]
        description = fnc["description"]
        parameters = fnc["parameters"]
    elif isinstance(tool, llm.RawFunctionTool):
        info = tool.info
        name = info.name
        description = info.raw_schema.get("description")
        parameters = info.raw_schema.get("parameters")
    else:
        raise ValueError("Invalid function tool")

    return {
        "toolSpec": _strip_nones(
            {
                "name": name,
                # An empty description is mapped to None so that it gets
                # stripped instead of being sent as an empty string.
                "description": description or None,
                "inputSchema": {"json": parameters or {}},
            }
        )
    }


def _strip_nones(d: dict) -> dict:
    """Return a new dict with the entries of ``d`` whose value is not ``None``.

    Only ``None`` is dropped; other falsy values such as ``0`` or ``""`` are
    kept. ``d`` itself is not modified.
    """
    kept = {}
    for key, value in d.items():
        if value is None:
            continue
        kept[key] = value
    return kept
