from __future__ import annotations

from collections import OrderedDict
from dataclasses import dataclass, field

from livekit.agents import llm
from livekit.agents.log import logger

_DEFAULT_INLINE_INSTRUCTIONS_TEMPLATE = "<instructions>\n{content}\n</instructions>"


def convert_mid_conversation_instructions(
    chat_ctx: llm.ChatContext,
    *,
    role: llm.ChatRole = "user",
    template: str = _DEFAULT_INLINE_INSTRUCTIONS_TEMPLATE,
) -> llm.ChatContext:
    """Convert mid-conversation system messages to the given role to preserve their position.

    Preamble system messages (before any user/assistant turn) are kept as-is.
    Later ones are converted using the given role and template.
    """
    seen_non_system = False
    items: list[llm.ChatItem] = []

    for item in chat_ctx.items:
        if (
            item.type == "message"
            and item.role in ("system", "developer")
            and seen_non_system
            and (text := item.text_content)
        ):
            items.append(
                llm.ChatMessage(
                    id=item.id,
                    role=role,
                    content=[template.format(content=text)],
                    created_at=item.created_at,
                )
            )
        else:
            if not seen_non_system and (
                item.type in ("function_call", "function_call_output")
                or (item.type == "message" and item.role in ("assistant", "user"))
            ):
                seen_non_system = True

            items.append(item)

    return llm.ChatContext(items)


def group_tool_calls(chat_ctx: llm.ChatContext) -> list[_ChatItemGroup]:
    """Group chat items (messages, function calls, and function outputs)
    into coherent groups based on their item IDs and call IDs.

    Each group will contain:
    - Zero or one assistant message
    - Zero or more function/tool calls
    - The corresponding function/tool outputs matched by call_id

    User and system messages are placed in their own individual groups.

    Args:
        chat_ctx: The chat context containing all conversation items

    Returns:
        A list of _ChatItemGroup objects representing the grouped conversation
    """
    item_groups: dict[str, _ChatItemGroup] = OrderedDict()  # item_id to group of items
    tool_outputs: list[llm.FunctionCallOutput] = []
    for item in chat_ctx.items:
        if (item.type == "message" and item.role == "assistant") or item.type == "function_call":
            # only assistant messages and function calls can be grouped
            # For function calls, use group_id if available (for parallel function calls),
            # otherwise fall back to id-based grouping for backwards compatibility
            if item.type == "function_call" and item.group_id:
                group_id = item.group_id
            else:
                group_id = item.id.split("/")[0]
            if group_id not in item_groups:
                item_groups[group_id] = _ChatItemGroup().add(item)
            else:
                item_groups[group_id].add(item)
        elif item.type == "function_call_output":
            tool_outputs.append(item)
        else:
            item_groups[item.id] = _ChatItemGroup().add(item)

    # add tool outputs to their corresponding groups
    call_id_to_group: dict[str, _ChatItemGroup] = {
        tool_call.call_id: group for group in item_groups.values() for tool_call in group.tool_calls
    }
    for tool_output in tool_outputs:
        if tool_output.call_id not in call_id_to_group:
            logger.warning(
                "function output missing the corresponding function call, ignoring",
                extra={"call_id": tool_output.call_id, "tool_name": tool_output.name},
            )
            continue

        call_id_to_group[tool_output.call_id].add(tool_output)

    # validate that each group and remove invalid tool calls and tool outputs
    for group in item_groups.values():
        group.remove_invalid_tool_calls()

    return list(item_groups.values())


@dataclass
class _ChatItemGroup:
    message: llm.ChatMessage | None = None
    tool_calls: list[llm.FunctionCall] = field(default_factory=list)
    tool_outputs: list[llm.FunctionCallOutput] = field(default_factory=list)

    def add(self, item: llm.ChatItem) -> _ChatItemGroup:
        if item.type == "message":
            assert self.message is None, "only one message is allowed in a group"
            self.message = item
        elif item.type == "function_call":
            self.tool_calls.append(item)
        elif item.type == "function_call_output":
            self.tool_outputs.append(item)
        return self

    def remove_invalid_tool_calls(self) -> None:
        if len(self.tool_calls) == len(self.tool_outputs):
            return

        valid_call_ids = {call.call_id for call in self.tool_calls} & {
            output.call_id for output in self.tool_outputs
        }

        valid_tool_calls = []
        valid_tool_outputs = []

        for tool_call in self.tool_calls:
            if tool_call.call_id not in valid_call_ids:
                logger.warning(
                    "function call missing the corresponding function output, ignoring",
                    extra={"call_id": tool_call.call_id, "tool_name": tool_call.name},
                )
                continue
            valid_tool_calls.append(tool_call)

        for tool_output in self.tool_outputs:
            if tool_output.call_id not in valid_call_ids:
                logger.warning(
                    "function output missing the corresponding function call, ignoring",
                    extra={"call_id": tool_output.call_id, "tool_name": tool_output.name},
                )
                continue
            valid_tool_outputs.append(tool_output)

        self.tool_calls = valid_tool_calls
        self.tool_outputs = valid_tool_outputs

    def flatten(self) -> list[llm.ChatItem]:
        items: list[llm.ChatItem] = []
        if self.message:
            items.append(self.message)
        items.extend(self.tool_calls)
        items.extend(self.tool_outputs)
        return items
