#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/rf_detr/modular_rf_detr.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_rf_detr.py file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
import math
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from ... import initialization as init
from ...activations import ACT2CLS, ACT2FN
from ...backbone_utils import BackboneMixin, filter_output_hidden_states
from ...integrations import use_kernel_forward_from_hub
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithCrossAttentions
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, torch_compilable_check, torch_int
from ...utils.generic import ModelOutput, TransformersKwargs, can_return_tuple, merge_with_config_defaults
from ...utils.output_capturing import OutputRecorder, capture_outputs
from .configuration_rf_detr import RfDetrConfig, RfDetrDinov2Config


class RfDetrDinov2PatchEmbeddings(nn.Module):
    """
    Patchify images with a strided convolution.

    Maps `pixel_values` of shape `(batch_size, num_channels, height, width)` to patch embeddings of shape
    `(batch_size, num_patches, hidden_size)`: every non-overlapping `patch_size` tile becomes one token.
    """

    def __init__(self, config):
        super().__init__()

        def _as_pair(value):
            # Scalars are broadcast to (value, value); iterables are kept as given.
            return value if isinstance(value, collections.abc.Iterable) else (value, value)

        image_size = _as_pair(config.image_size)
        patch_size = _as_pair(config.patch_size)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        # Patch count at the configured resolution (used to size the position embeddings).
        self.num_patches = (image_size[0] // patch_size[0]) * (image_size[1] // patch_size[1])

        self.projection = nn.Conv2d(
            config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return `(batch_size, num_patches, hidden_size)` patch tokens; raises `ValueError` on a channel mismatch."""
        channels_in = pixel_values.shape[1]
        if channels_in != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {channels_in}."
            )
        # (B, C, H, W) -> (B, hidden, H/p, W/p) -> (B, hidden, N) -> (B, N, hidden)
        patch_grid = self.projection(pixel_values)
        return patch_grid.flatten(2).transpose(1, 2)


class RfDetrDinov2Embeddings(nn.Module):
    """
    Construct the CLS token, mask token, position and patch embeddings.

    Difference from Dinov2: when `config.num_windows > 1` the output is window-batched. Each image's patch grid is
    split into `num_windows x num_windows` tiles and every tile gets its own copy of the [CLS] token, so `forward`
    returns `(batch_size * num_windows**2, 1 + tokens_per_window, hidden_size)` instead of
    `(batch_size, 1 + num_patches, hidden_size)`.
    """

    def __init__(self, config: RfDetrDinov2Config) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        if config.use_mask_token:
            self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
        self.patch_embeddings = RfDetrDinov2PatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size
        self.use_mask_token = config.use_mask_token
        self.config = config

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing and interpolation at torch.float32 precision.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        target_dtype = patch_pos_embed.dtype
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.to(torch.float32),
            size=(new_height, new_width),
            mode="bicubic",
            # Difference from Dinov2, we use align_corners=False and antialias=True
            align_corners=False,
            antialias=True,
        ).to(dtype=target_dtype)

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        target_dtype = self.patch_embeddings.projection.weight.dtype
        embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))

        if bool_masked_pos is not None and self.use_mask_token:
            embeddings = torch.where(
                bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
            )

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)

        # Difference from Dinov2, we use window partitioning
        if self.config.num_windows > 1:
            embeddings = self.window_partition(embeddings, height, width)
        embeddings = self.dropout(embeddings)

        return embeddings

    def window_partition(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Split each image's patch-token grid into `num_windows**2` local windows.

        Args:
            embeddings: `(batch_size, 1 + num_height_patches * num_width_patches, hidden_size)` with the [CLS] token
                first and the patch tokens in row-major (height, width) order.
            height: pixel height of the input image.
            width: pixel width of the input image.

        Returns:
            `(batch_size * num_windows**2, 1 + num_height_patches_per_window * num_width_patches_per_window,
            hidden_size)`. The windows of one image are consecutive along the batch axis, in row-major window order,
            and each window starts with its own copy of the [CLS] token.
        """
        batch_size = embeddings.shape[0]
        num_windows = self.config.num_windows
        patch_size = self.patch_size
        num_height_patches = height // patch_size
        num_width_patches = width // patch_size
        num_height_patches_per_window = num_height_patches // num_windows
        num_width_patches_per_window = num_width_patches // num_windows

        # Split the embeddings into the [CLS] token and the pixel tokens
        cls_token_with_pos_embed = embeddings[:, :1]
        pixel_tokens_with_pos_embed = embeddings[:, 1:].view(batch_size, num_height_patches, num_width_patches, -1)

        # Split the height axis into (num_windows, rows per window) and the width axis into
        # (num_windows, columns per window). Height must come first because the grid above is (height, width);
        # this is exactly the layout `RfDetrDinov2Backbone.window_unpartition` inverts.
        windowed_pixel_tokens = pixel_tokens_with_pos_embed.view(
            batch_size, num_windows, num_height_patches_per_window, num_windows, num_width_patches_per_window, -1
        )
        # -> (batch, window_row, window_col, rows_per_window, cols_per_window, hidden)
        windowed_pixel_tokens = windowed_pixel_tokens.transpose(2, 3)
        windowed_pixel_tokens = windowed_pixel_tokens.reshape(
            batch_size * num_windows**2, num_height_patches_per_window * num_width_patches_per_window, -1
        )

        # Replicate the [CLS] token per window, grouped per image (window i belongs to image i // num_windows**2) to
        # match the pixel windows above. All [CLS] rows are identical at this point, so this only makes the grouping
        # explicit.
        windowed_cls_token_with_pos_embed = cls_token_with_pos_embed.repeat_interleave(num_windows**2, dim=0)

        # Concatenate the [CLS] token with the windowed pixel tokens to get the final embeddings
        embeddings = torch.cat((windowed_cls_token_with_pos_embed, windowed_pixel_tokens), dim=1)
        return embeddings


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain (eager) scaled dot-product attention.

    Args:
        module: the calling module; only its `training` flag is read, to gate dropout.
        query, key, value: per-head tensors. `RfDetrDinov2SelfAttention` passes them as
            `(batch, heads, seq, head_dim)`.
        attention_mask: optional additive mask, added to the raw scores before the softmax.
        scaling: score multiplier; defaults to `query.size(-1) ** -0.5`.
        dropout: dropout probability applied to the attention probabilities.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has dims 1 and 2 swapped (made contiguous).
    """
    scale = query.size(-1) ** -0.5 if scaling is None else scaling

    scores = torch.matmul(query, key.transpose(2, 3)) * scale
    if attention_mask is not None:
        scores = scores + attention_mask

    probs = nn.functional.softmax(scores, dim=-1)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs


# Todo - Refactor as part of vision refactor. Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->RfDetrDinov2
class RfDetrDinov2SelfAttention(nn.Module):
    """Multi-head self-attention that dispatches to the backend selected by `config._attn_implementation`."""

    def __init__(self, config: RfDetrDinov2Config):
        super().__init__()
        heads = config.num_attention_heads
        if config.hidden_size % heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.config = config
        self.num_attention_heads = heads
        self.attention_head_size = int(config.hidden_size / heads)
        self.all_head_size = heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

        # Creation order (query, key, value) is kept so random initialisation stays reproducible.
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return `(context, attention_probs)`; `context` has `all_head_size` features per token."""
        batch_size = hidden_states.shape[0]

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (batch, seq, all_head_size) -> (batch, heads, seq, head_size)
            return projected.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)

        query_layer = split_heads(self.query(hidden_states))
        key_layer = split_heads(self.key(hidden_states))
        value_layer = split_heads(self.value(hidden_states))

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            None,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=self.dropout_prob if self.training else 0.0,
            **kwargs,
        )

        # Merge the head axis back into the feature axis.
        context_layer = context_layer.reshape(*context_layer.shape[:-2], self.all_head_size)
        return context_layer, attention_probs


# Todo - Refactor as part of vision refactor. Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->RfDetrDinov2
class RfDetrDinov2SelfOutput(nn.Module):
    """
    Output projection of the attention block: a dense layer followed by dropout.

    The residual connection is defined in RfDetrDinov2Layer instead of here (as is the case with other models), due to
    the layernorm applied before each block. `input_tensor` is accepted for interface compatibility but is not used.
    """

    def __init__(self, config: RfDetrDinov2Config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dense(hidden_states)
        return self.dropout(projected)


# Todo - Refactor as part of vision refactor. Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->RfDetrDinov2
class RfDetrDinov2Attention(nn.Module):
    """Self-attention plus its output projection; the residual connection is added by the caller."""

    def __init__(self, config: RfDetrDinov2Config):
        super().__init__()
        self.attention = RfDetrDinov2SelfAttention(config)
        self.output = RfDetrDinov2SelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # The attention probabilities (second element) are not needed here.
        attended = self.attention(hidden_states, **kwargs)[0]
        return self.output(attended, hidden_states)


class RfDetrDinov2LayerScale(nn.Module):
    """Learnable per-channel scale `lambda1` of size `hidden_size`, initialised to `config.layerscale_value`."""

    def __init__(self, config) -> None:
        super().__init__()
        initial_scale = torch.ones(config.hidden_size) * config.layerscale_value
        self.lambda1 = nn.Parameter(initial_scale)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        # Broadcast over the last (channel) dimension.
        return self.lambda1 * hidden_state


class RfDetrDinov2MLP(nn.Module):
    """Feed-forward block `fc2(activation(fc1(x)))` with inner width `int(hidden_size * mlp_ratio)`."""

    def __init__(self, config) -> None:
        super().__init__()
        width = config.hidden_size
        inner_width = int(width * config.mlp_ratio)
        self.fc1 = nn.Linear(width, inner_width, bias=True)
        # `hidden_act` is either a key into ACT2FN or an already-callable activation.
        act = config.hidden_act
        self.activation = ACT2FN[act] if isinstance(act, str) else act
        self.fc2 = nn.Linear(inner_width, width, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.activation(self.fc1(hidden_state)))


class RfDetrDinov2SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: `weights_out(silu(a) * b)` where `(a, b)` are the two halves of `weights_in(x)`."""

    def __init__(self, config) -> None:
        super().__init__()
        dim = config.hidden_size
        base_hidden = int(dim * config.mlp_ratio)
        # Take 2/3 of the MLP width and round it up to a multiple of 8.
        hidden_features = (int(base_hidden * 2 / 3) + 7) // 8 * 8

        self.weights_in = nn.Linear(dim, 2 * hidden_features, bias=True)
        self.weights_out = nn.Linear(hidden_features, dim, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        gate, value = self.weights_in(hidden_state).chunk(2, dim=-1)
        return self.weights_out(nn.functional.silu(gate) * value)


class RfDetrDinov2DropPath(nn.Module):
    """Per-sample stochastic depth (DropPath) for residual branches.

    Returns the input unchanged when ``drop_prob`` is 0 or the module is in eval mode. Otherwise each sample is
    zeroed with probability ``drop_prob`` and survivors are divided by ``1 - drop_prob``. See `Deep Networks with
    Stochastic Depth <https://arxiv.org/abs/1603.09382>`_.
    """

    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_prob == 0.0:
            return hidden_states

        keep_prob = 1 - self.drop_prob
        # One draw per sample, broadcast over every remaining dimension.
        mask_shape = (hidden_states.shape[0], *([1] * (hidden_states.ndim - 1)))
        uniform = torch.rand(mask_shape, dtype=hidden_states.dtype, device=hidden_states.device)
        keep_mask = torch.floor(uniform + keep_prob)  # 1 with probability keep_prob, else 0
        return hidden_states.div(keep_prob) * keep_mask

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"


class RfDetrDinov2Layer(GradientCheckpointingLayer):
    """
    Pre-norm transformer block (the `Block` class of the original implementation), extended for windowed attention.

    Inputs and outputs are window-batched. Layers whose index is NOT in `config.window_block_indexes` merge the windows
    of each image into a single sequence before attention (global attention) and split them again afterwards. The
    other layers attend within each window only.
    """

    def __init__(self, config: RfDetrDinov2Config, layer_idx: int) -> None:
        super().__init__()
        # Submodules are registered in the original order so checkpoints and initialisation are unchanged.
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = RfDetrDinov2Attention(config)
        self.layer_scale1 = RfDetrDinov2LayerScale(config)
        if config.drop_path_rate > 0.0:
            self.drop_path = RfDetrDinov2DropPath(config.drop_path_rate)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        mlp_class = RfDetrDinov2SwiGLUFFN if config.use_swiglu_ffn else RfDetrDinov2MLP
        self.mlp = mlp_class(config)
        self.layer_scale2 = RfDetrDinov2LayerScale(config)

        self.num_windows = config.num_windows
        self.global_attention = layer_idx not in config.window_block_indexes

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        shortcut = hidden_states

        attention_input = hidden_states
        if self.global_attention:
            # Merge the windows of each image so attention spans the whole image.
            attention_input = self.window_unpartition_before_attention(attention_input)

        attention_output = self.attention(self.norm1(attention_input))

        if self.global_attention:
            # Restore the window-batched layout expected by the residual and the following layers.
            attention_output = self.window_partition_after_attention(attention_input.shape, attention_output)

        # First residual connection (attention branch).
        hidden_states = shortcut + self.drop_path(self.layer_scale1(attention_output))

        # Second residual connection (MLP branch); as in Dinov2, a layernorm is applied after self-attention too.
        mlp_output = self.layer_scale2(self.mlp(self.norm2(hidden_states)))
        return hidden_states + self.drop_path(mlp_output)

    def window_unpartition_before_attention(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Merge window-batched sequences `(batch * num_windows**2, seq_len, channels)` into one sequence per image,
        `(batch, num_windows**2 * seq_len, channels)`, so attention covers all windows jointly.
        """
        windows_per_image = self.num_windows**2
        windowed_batch, tokens_per_window, channels = hidden_states.shape
        return hidden_states.view(windowed_batch // windows_per_image, windows_per_image * tokens_per_window, channels)

    def window_partition_after_attention(
        self, hidden_state_shape: tuple[int, int, int], self_attention_output: torch.Tensor
    ) -> torch.Tensor:
        """
        Inverse of `window_unpartition_before_attention`: given the merged `(batch, seq_len, channels)` shape, split
        the attention output back into `(batch * num_windows**2, seq_len // num_windows**2, channels)`.
        """
        batch_size, merged_len, channels = hidden_state_shape
        windows_per_image = self.num_windows**2
        return self_attention_output.view(batch_size * windows_per_image, merged_len // windows_per_image, channels)


@auto_docstring
class RfDetrDinov2PreTrainedModel(PreTrainedModel):
    config: RfDetrDinov2Config
    base_model_prefix = "rf_detr_dinov2"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    supports_gradient_checkpointing = True
    _no_split_modules = ["RfDetrDinov2Layer"]
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    # Output-capture hooks: per-layer hidden states and per-layer self-attention probabilities.
    _can_record_outputs = {
        "hidden_states": RfDetrDinov2Layer,
        "attentions": RfDetrDinov2SelfAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None:
        """
        Initialize the weights of one submodule.

        Linear/Conv2d weights and the embedding position/CLS parameters get a truncated normal with
        `config.initializer_range` std. Biases are zeroed, LayerNorm is reset to identity, the mask token is zeroed,
        and LayerScale is set to `config.layerscale_value`.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                init.zeros_(module.bias)
            return

        if isinstance(module, nn.LayerNorm):
            init.zeros_(module.bias)
            init.ones_(module.weight)
            return

        if isinstance(module, RfDetrDinov2Embeddings):
            # Draw order (position embeddings, then CLS token) is kept for reproducible initialisation.
            init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
            init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
            if self.config.use_mask_token:
                init.zeros_(module.mask_token)
            return

        if isinstance(module, RfDetrDinov2LayerScale):
            init.constant_(module.lambda1, self.config.layerscale_value)


class RfDetrDinov2Encoder(RfDetrDinov2PreTrainedModel):
    """Stack of `config.num_hidden_layers` RfDetrDinov2Layer blocks applied in sequence."""

    def __init__(self, config: RfDetrDinov2Config):
        super().__init__(config)
        blocks = [RfDetrDinov2Layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        self.layer = nn.ModuleList(blocks)
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs(tie_last_hidden_states=False)
    def forward(self, hidden_states: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> BaseModelOutput:
        # Intermediate outputs are collected by `capture_outputs`, not by this loop.
        for block in self.layer:
            hidden_states = block(hidden_states)
        return BaseModelOutput(last_hidden_state=hidden_states)


@auto_docstring(
    custom_intro="""
    RfDetrDinov2 backbone, to be used with frameworks like DETR and MaskFormer.
    """
)
class RfDetrDinov2Backbone(BackboneMixin, RfDetrDinov2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # One channel width per stage (num_hidden_layers + 1 entries).
        self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
        self.embeddings = RfDetrDinov2Embeddings(config)
        self.encoder = RfDetrDinov2Encoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> RfDetrDinov2PatchEmbeddings:
        return self.embeddings.patch_embeddings

    @can_return_tuple
    @filter_output_hidden_states
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BackboneOutput:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
        >>> model = AutoBackbone.from_pretrained(
        ...     "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 768, 16, 16]
        ```"""
        # Like Dinov2, we need to output the hidden states to extract the layers for the stages
        kwargs["output_hidden_states"] = True

        embedding_output = self.embeddings(pixel_values)
        output: BaseModelOutput = self.encoder(embedding_output, **kwargs)
        hidden_states = output.hidden_states

        # The patch-grid size is the same for every stage, so compute it once.
        batch_size, _, height, width = pixel_values.shape
        num_h_patches = height // self.config.patch_size
        num_w_patches = width // self.config.patch_size

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage not in self.out_features:
                continue
            if self.config.apply_layernorm:
                hidden_state = self.layernorm(hidden_state)
            if self.config.reshape_hidden_states:
                # Drop the [CLS] token (one per window when windowing is enabled).
                hidden_state = hidden_state[:, 1:]

                # Difference from Dinov2: with num_windows > 1 every stored hidden state is window-batched (layers
                # using global attention re-split their output into windows), so merge the windows back into one
                # patch grid per image before reshaping.
                if self.config.num_windows > 1:
                    hidden_state = self.window_unpartition(hidden_state, height, width)

                # (batch, num_h_patches * num_w_patches, C) -> (batch, C, num_h_patches, num_w_patches)
                hidden_state = hidden_state.reshape(batch_size, num_h_patches, num_w_patches, -1)
                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()

            feature_maps += (hidden_state,)

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=hidden_states,
            attentions=output.attentions,
        )

    def window_unpartition(self, hidden_state: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Reassemble window-batched patch tokens into their image-level 2D patch layout.

        Args:
            hidden_state: `(batch_size * num_windows**2, tokens_per_window, channels)` patch tokens without [CLS],
                with the windows of each image consecutive along the first axis in row-major window order.
            height: pixel height of the input image.
            width: pixel width of the input image.

        Returns:
            A transposed view of shape `(batch_size, num_windows, num_h_patches_per_window, num_windows,
            num_w_patches_per_window, channels)`. Flattening dims 1-4 yields the patches in row-major
            (height, width) order.
        """
        num_windows = self.config.num_windows
        patch_size = self.config.patch_size
        num_h_patches = height // patch_size
        num_w_patches = width // patch_size
        hidden_batch_size, seq_len, channels = hidden_state.shape
        num_windows_squared = num_windows**2
        num_h_patches_per_window = num_h_patches // num_windows
        num_w_patches_per_window = num_w_patches // num_windows

        # Reshape the hidden states into the original sequence length
        hidden_state = hidden_state.reshape(
            hidden_batch_size // num_windows_squared, num_windows_squared * seq_len, channels
        )
        # (batch, window_row, window_col, rows_per_window, cols_per_window, channels)
        hidden_state = hidden_state.view(
            hidden_batch_size // num_windows_squared,
            num_windows,
            num_windows,
            num_h_patches_per_window,
            num_w_patches_per_window,
            channels,
        )
        # -> (batch, window_row, rows_per_window, window_col, cols_per_window, channels)
        hidden_state = hidden_state.transpose(2, 3)
        return hidden_state


class RfDetrLayerNorm(nn.LayerNorm):
    r"""LayerNorm that accepts either channels_last (default) or channels_first inputs.

    `data_format` gives the ordering of the input dimensions: channels_last means inputs of shape
    (batch_size, height, width, channels); channels_first means inputs of shape (batch_size, channels, height, width).
    Any other value raises `NotImplementedError`.
    """

    def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
        super().__init__(normalized_shape, eps=eps, **kwargs)
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError(f"Unsupported data format: {data_format}")
        self.data_format = data_format

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
        """
        if self.data_format != "channels_first":
            return super().forward(features)
        # Normalise over channels by moving them last, then restore the channels-first layout.
        normalized = super().forward(features.permute(0, 2, 3, 1))
        return normalized.permute(0, 3, 1, 2)


class RfDetrConvNormLayer(nn.Module):
    """
    Bias-free `Conv2d` (padding `kernel_size // 2`), then a channels-first RfDetrLayerNorm, then an optional
    activation looked up in `ACT2CLS` (identity when `activation` is None).
    """

    def __init__(
        self,
        config: RfDetrConfig,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        activation: str | None = None,
    ):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=kernel_size // 2, bias=False)
        self.norm = RfDetrLayerNorm(out_channels, data_format="channels_first", eps=config.layer_norm_eps)
        if activation is None:
            self.activation = nn.Identity()
        else:
            self.activation = ACT2CLS[activation]()

    def forward(self, hidden_state):
        for stage in (self.conv, self.norm, self.activation):
            hidden_state = stage(hidden_state)
        return hidden_state


class RfDetrRepVggBlock(nn.Module):
    """Two stacked 3x3 conv-norm-activation layers at `int(d_model * hidden_expansion)` channels (no residual)."""

    def __init__(self, config: RfDetrConfig):
        super().__init__()
        channels = int(config.d_model * config.hidden_expansion)
        act = config.activation_function
        self.conv1 = RfDetrConvNormLayer(config, channels, channels, 3, 1, activation=act)
        self.conv2 = RfDetrConvNormLayer(config, channels, channels, 3, 1, activation=act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv2(self.conv1(x))


class RfDetrC2FLayer(nn.Module):
    """
    C2f-style fusion block (inspired by RTDetrCSPRepLayer).

    `conv1` produces two `hidden_channels` chunks. The last chunk is fed through a chain of RfDetrRepVggBlocks, every
    intermediate output is kept, and `conv2` fuses all `(2 + num_blocks)` chunks into `d_model` channels.
    """

    def __init__(self, config: RfDetrConfig, in_channels: int):
        super().__init__()
        activation = config.activation_function
        out_channels = config.d_model
        num_blocks = config.c2f_num_blocks
        self.hidden_channels = int(out_channels * config.hidden_expansion)

        self.conv1 = RfDetrConvNormLayer(config, in_channels, 2 * self.hidden_channels, 1, 1, activation=activation)
        self.conv2 = RfDetrConvNormLayer(
            config, (2 + num_blocks) * self.hidden_channels, out_channels, 1, 1, activation=activation
        )
        self.bottlenecks = nn.ModuleList([RfDetrRepVggBlock(config) for _ in range(num_blocks)])

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        chunks = list(self.conv1(hidden_states).split(self.hidden_channels, 1))
        current = chunks[-1].contiguous()
        for bottleneck in self.bottlenecks:
            current = bottleneck(current)
            chunks.append(current)
        return self.conv2(torch.cat(chunks, 1))


class RfDetrScaleProjector(nn.Module):
    """
    Fuses the backbone feature maps into a single `d_model`-channel map. The maps are concatenated along channels,
    passed through an `RfDetrC2FLayer`, and normalized with a channels-first `RfDetrLayerNorm`.
    """

    def __init__(self, config: RfDetrConfig):
        super().__init__()
        backbone_config = config.backbone_config
        # the expected input width assumes one `hidden_size`-channel map per entry of `out_indices`
        concat_channels: int = backbone_config.hidden_size * len(backbone_config.out_indices)
        self.projector_layer = RfDetrC2FLayer(config, concat_channels)
        self.layer_norm = RfDetrLayerNorm(config.d_model, data_format="channels_first")

    def forward(self, hidden_states: tuple[torch.Tensor]) -> torch.Tensor:
        fused = torch.cat(hidden_states, dim=1)
        return self.layer_norm(self.projector_layer(fused))


class RfDetrConvEncoder(nn.Module):
    """
    DINOv2 backbone followed by `RfDetrScaleProjector`.

    `forward` returns the projected feature map together with `pixel_mask`, resized to that map's spatial size and
    cast to bool.
    """

    def __init__(self, config: RfDetrConfig):
        super().__init__()
        self.backbone = RfDetrDinov2Backbone(config.backbone_config)
        self.projector = RfDetrScaleProjector(config)

    def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
        feature_maps = self.backbone(pixel_values).feature_maps
        projected = self.projector(feature_maps)
        # a leading dim is added so interpolate resizes the last two dims, then removed again
        # (assumes pixel_mask is (batch_size, height, width) — confirm with callers)
        resized_mask = nn.functional.interpolate(pixel_mask[None].float(), size=projected.shape[-2:])
        return projected, resized_mask.to(torch.bool)[0]


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head `n_rep` times along dim 1.

    Equivalent to `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)`: maps
    `(batch, num_key_value_heads, seqlen, head_dim)` to `(batch, num_key_value_heads * n_rep, seqlen, head_dim)`.
    When `n_rep == 1` the input tensor itself is returned.
    """
    batch, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    tiled = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(batch, kv_heads * n_rep, seq_len, dim)


class RfDetrAttention(nn.Module):
    """
    LW-DETR self-attention with the group-DETR training technique.

    Queries and keys are projected from `hidden_states + position_embeddings`. Values are projected from the raw
    `hidden_states`. During training the query sequence is split into `config.group_detr` equal groups, which are
    stacked along the batch dimension so each group only attends within itself. The groups are concatenated back
    along the sequence dimension after the output projection.
    """

    def __init__(self, config: RfDetrConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.d_model // config.decoder_self_attention_heads)
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False
        self.num_key_value_groups = 1

        self.q_proj = nn.Linear(
            config.d_model, config.decoder_self_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.d_model, config.decoder_self_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.d_model, config.decoder_self_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.decoder_self_attention_heads * self.head_dim, config.d_model, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: query features of shape `(batch_size, seq_len, d_model)`.
            position_embeddings: optional, added to `hidden_states` for the query/key projections only.

        Returns:
            `(attn_output, attn_weights)`. `attn_output` has the same shape as `hidden_states`. In training mode
            `attn_weights` is computed per group, so its batch dimension is `batch_size * config.group_detr`.
        """
        batch_size, seq_len, _ = hidden_states.shape

        # values are projected from the states without position embeddings
        hidden_states_original = hidden_states
        if position_embeddings is not None:
            hidden_states = hidden_states + position_embeddings

        if self.training:
            # at training, we use group detr technique to add more supervision by using multiple weight-sharing
            # decoders at once for faster convergence; at inference, we only use one decoder
            group_size = seq_len // self.config.group_detr
            hidden_states_original = torch.cat(hidden_states_original.split(group_size, dim=1), dim=0)
            hidden_states = torch.cat(hidden_states.split(group_size, dim=1), dim=0)

        attention_input_shape = hidden_states.shape[:-1]
        hidden_shape = (*attention_input_shape, -1, self.head_dim)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states_original).view(hidden_shape).transpose(1, 2)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=None,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )
        attn_output = attn_output.reshape(*attention_input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        if self.training:
            # undo the group split: move the groups from the batch dim back onto the sequence dim
            attn_output = torch.cat(torch.split(attn_output, batch_size, dim=0), dim=1)

        return attn_output, attn_weights


@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
class MultiScaleDeformableAttention(nn.Module):
    """
    Pure-PyTorch multiscale deformable attention core: bilinear sampling of each level's values at the given
    locations, followed by an attention-weighted sum over levels and sampling points.

    NOTE(review): the `use_kernel_forward_from_hub` decorator presumably allows an optimized hub kernel to replace
    this `forward`, making this implementation the reference/fallback path — confirm against the integration.
    """

    def forward(
        self,
        value: Tensor,
        value_spatial_shapes: Tensor,
        value_spatial_shapes_list: list[tuple],
        level_start_index: Tensor,
        sampling_locations: Tensor,
        attention_weights: Tensor,
        im2col_step: int,
    ):
        """
        Args:
            value: `(batch_size, sum(height * width), num_heads, hidden_dim)`. All levels are flattened and
                concatenated along dim 1, in the order given by `value_spatial_shapes_list`.
            value_spatial_shapes: per-level shapes as a tensor; not read by this implementation.
            value_spatial_shapes_list: per-level `(height, width)` pairs, used to split `value` into levels.
            level_start_index: not read by this implementation.
            sampling_locations: `(batch_size, num_queries, num_heads, num_levels, num_points, 2)`. They are mapped
                to grid_sample's `[-1, 1]` range with `2 * x - 1`, so they are expected in `[0, 1]`.
            attention_weights: `(batch_size, num_queries, num_heads, num_levels, num_points)`.
            im2col_step: not read by this implementation.

        Returns:
            Tensor of shape `(batch_size, num_queries, num_heads * hidden_dim)`.
        """
        batch_size, _, num_heads, hidden_dim = value.shape
        _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
        value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
        # map [0, 1] locations to grid_sample's [-1, 1] coordinate range
        sampling_grids = 2 * sampling_locations - 1
        sampling_value_list = []
        for level_id, (height, width) in enumerate(value_spatial_shapes_list):
            # batch_size, height*width, num_heads, hidden_dim
            # -> batch_size, height*width, num_heads*hidden_dim
            # -> batch_size, num_heads*hidden_dim, height*width
            # -> batch_size*num_heads, hidden_dim, height, width
            value_l_ = (
                value_list[level_id]
                .flatten(2)
                .transpose(1, 2)
                .reshape(batch_size * num_heads, hidden_dim, height, width)
            )
            # batch_size, num_queries, num_heads, num_points, 2
            # -> batch_size, num_heads, num_queries, num_points, 2
            # -> batch_size*num_heads, num_queries, num_points, 2
            sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
            # batch_size*num_heads, hidden_dim, num_queries, num_points
            sampling_value_l_ = nn.functional.grid_sample(
                value_l_,
                sampling_grid_l_,
                mode="bilinear",
                padding_mode="zeros",
                align_corners=False,
            )
            sampling_value_list.append(sampling_value_l_)
        # (batch_size, num_queries, num_heads, num_levels, num_points)
        # -> (batch_size, num_heads, num_queries, num_levels, num_points)
        # -> (batch_size*num_heads, 1, num_queries, num_levels*num_points)
        attention_weights = attention_weights.transpose(1, 2).reshape(
            batch_size * num_heads, 1, num_queries, num_levels * num_points
        )
        # stacked samples: (batch_size*num_heads, hidden_dim, num_queries, num_levels*num_points);
        # weighted sum over the last dim -> (batch_size*num_heads, hidden_dim, num_queries)
        # -> (batch_size, num_heads*hidden_dim, num_queries)
        output = (
            (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
            .sum(-1)
            .view(batch_size, num_heads * hidden_dim, num_queries)
        )
        # -> (batch_size, num_queries, num_heads*hidden_dim)
        return output.transpose(1, 2).contiguous()


class RfDetrMultiscaleDeformableAttention(nn.Module):
    """
    Multiscale deformable attention as proposed in Deformable DETR.

    For each head and each feature level, every query predicts `n_points` sampling offsets around its reference
    point(s). It also predicts one attention weight per (level, point) pair; these weights are softmax-normalized
    jointly over all levels and points of a head. The projected encoder features are sampled and combined by
    `MultiScaleDeformableAttention`, then projected with `output_proj`.

    Args:
        config: model config; `d_model`, `num_feature_levels` and `disable_custom_kernels` are read here.
        num_heads: number of attention heads; `config.d_model` must be divisible by it.
        n_points: number of sampling points per head and per level.
    """

    def __init__(self, config: RfDetrConfig, num_heads: int, n_points: int):
        super().__init__()

        self.attn = MultiScaleDeformableAttention()

        if config.d_model % num_heads != 0:
            raise ValueError(
                f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
            )
        dim_per_head = config.d_model // num_heads
        # check if dim_per_head is power of 2
        if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
            warnings.warn(
                "You'd better set embed_dim (d_model) in RfDetrMultiscaleDeformableAttention to make the"
                " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA"
                " implementation."
            )

        # forwarded to self.attn on every call; the pure-PyTorch MultiScaleDeformableAttention does not read it
        self.im2col_step = 64

        self.d_model = config.d_model
        self.n_levels = config.num_feature_levels
        self.n_heads = num_heads
        self.n_points = n_points

        # per query: a 2-D offset for every (head, level, point), and one attention logit per (head, level, point)
        self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
        self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
        self.value_proj = nn.Linear(config.d_model, config.d_model)
        self.output_proj = nn.Linear(config.d_model, config.d_model)

        # stored on the module; not read anywhere in this class
        self.disable_custom_kernels = config.disable_custom_kernels

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        position_embeddings: torch.Tensor | None = None,
        reference_points=None,
        spatial_shapes=None,
        spatial_shapes_list=None,
        level_start_index=None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: query features `(batch_size, num_queries, d_model)`.
            attention_mask: optional boolean mask over encoder positions, broadcast as `attention_mask[..., None]`
                against the projected values (so presumably `(batch_size, sequence_length)`). Values at positions
                where it is `False` are zeroed.
            encoder_hidden_states: flattened multi-level features `(batch_size, sequence_length, d_model)`, where
                `sequence_length` must equal `sum(height * width)` over `spatial_shapes_list`.
            encoder_attention_mask: not read in this method (masking uses `attention_mask`).
            position_embeddings: optional, added to `hidden_states` before the offset and weight projections.
            reference_points: `(batch_size, num_queries, n_levels, 2)` points (presumably normalized), or
                `(batch_size, num_queries, n_levels, 4)` boxes whose first two coordinates are the center and whose
                last two scale the predicted offsets.
            spatial_shapes: tensor of per-level `(height, width)`.
            spatial_shapes_list: the same per-level `(height, width)` pairs as a Python list.
            level_start_index: forwarded unchanged to the attention core.

        Returns:
            `(output, attention_weights)`. `output` is `(batch_size, num_queries, d_model)` and
            `attention_weights` is `(batch_size, num_queries, n_heads, n_levels, n_points)`.

        Raises:
            ValueError: if the last dim of `reference_points` is neither 2 nor 4.
        """
        # add position embeddings to the hidden states before projecting to queries and keys
        if position_embeddings is not None:
            hidden_states = hidden_states + position_embeddings

        batch_size, num_queries, _ = hidden_states.shape
        batch_size, sequence_length, _ = encoder_hidden_states.shape
        total_elements = sum(height * width for height, width in spatial_shapes_list)
        torch_compilable_check(
            total_elements == sequence_length,
            "Make sure to align the spatial shapes with the sequence length of the encoder hidden states",
        )

        value = self.value_proj(encoder_hidden_states)
        if attention_mask is not None:
            # we invert the attention_mask
            value = value.masked_fill(~attention_mask[..., None], float(0))
        value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
        sampling_offsets = self.sampling_offsets(hidden_states).view(
            batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
        )
        # softmax is taken jointly over all (level, point) pairs of each head
        attention_weights = self.attention_weights(hidden_states).view(
            batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
        )
        attention_weights = F.softmax(attention_weights, -1).view(
            batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
        )
        # batch_size, num_queries, n_heads, n_levels, n_points, 2
        num_coordinates = reference_points.shape[-1]
        if num_coordinates == 2:
            # offsets are expressed in feature-map cells: divide by (width, height) of each level
            offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = (
                reference_points[:, :, None, :, None, :]
                + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
            )
        elif num_coordinates == 4:
            # offsets are scaled by half the box size and averaged over the number of points
            sampling_locations = (
                reference_points[:, :, None, :, None, :2]
                + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
            )
        else:
            raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

        output = self.attn(
            value,
            spatial_shapes,
            spatial_shapes_list,
            level_start_index,
            sampling_locations,
            attention_weights,
            self.im2col_step,
        )

        output = self.output_proj(output)

        return output, attention_weights


class RfDetrMLP(nn.Module):
    """
    Residual feed-forward block of the decoder layer: fc1 -> activation -> dropout -> fc2 -> dropout, added back
    onto the input.
    """

    def __init__(self, config: RfDetrConfig):
        super().__init__()
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.decoder_activation_function]
        self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        p, training = self.dropout, self.training
        update = self.activation_fn(self.fc1(hidden_states))
        update = nn.functional.dropout(update, p=p, training=training)
        update = nn.functional.dropout(self.fc2(update), p=p, training=training)
        return hidden_states + update


class RfDetrDecoderLayer(GradientCheckpointingLayer):
    """
    One RF-DETR decoder layer in post-norm order:

    1. group-DETR self-attention, residual add, `self_attn_layer_norm`;
    2. multiscale deformable cross-attention over the encoder features, residual add, `cross_attn_layer_norm`;
    3. `RfDetrMLP` (which adds its own residual), then `layer_norm`.

    Dropout with probability `config.dropout` is applied to both attention outputs before the residual add.
    """

    def __init__(self, config: RfDetrConfig, layer_idx: int):
        # initialized as a plain nn.Module, as in the generated original
        nn.Module.__init__(self)

        # self-attention
        self.self_attn = RfDetrAttention(config, layer_idx=layer_idx)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.decoder_activation_function]
        self.activation_dropout = config.activation_dropout
        self.self_attn_layer_norm = nn.LayerNorm(config.d_model)

        # cross-attention
        self.cross_attn = RfDetrMultiscaleDeformableAttention(
            config,
            num_heads=config.decoder_cross_attention_heads,
            n_points=config.decoder_n_points,
        )
        self.cross_attn_layer_norm = nn.LayerNorm(config.d_model)

        # mlp
        self.mlp = RfDetrMLP(config)
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor | None = None,
        reference_points=None,
        spatial_shapes=None,
        spatial_shapes_list=None,
        level_start_index=None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        # (1) self-attention sub-block; the weights are captured by output recorders, not returned here
        attn_update, _ = self.self_attn(hidden_states, position_embeddings=position_embeddings, **kwargs)
        attn_update = nn.functional.dropout(attn_update, p=self.dropout, training=self.training)
        hidden_states = self.self_attn_layer_norm(hidden_states + attn_update)

        # (2) deformable cross-attention sub-block
        cross_update, _ = self.cross_attn(
            hidden_states=hidden_states,
            attention_mask=encoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            position_embeddings=position_embeddings,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            spatial_shapes_list=spatial_shapes_list,
            level_start_index=level_start_index,
            **kwargs,
        )
        cross_update = nn.functional.dropout(cross_update, p=self.dropout, training=self.training)
        hidden_states = self.cross_attn_layer_norm(hidden_states + cross_update)

        # (3) feed-forward sub-block (RfDetrMLP already includes the residual connection)
        return self.layer_norm(self.mlp(hidden_states))


@auto_docstring
class RfDetrPreTrainedModel(PreTrainedModel):
    config: RfDetrConfig
    base_model_prefix = "model"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    _no_split_modules = [
        r"RfDetrConvEncoder",
        r"RfDetrDecoderLayer",
    ]
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "attentions": [RfDetrAttention, RfDetrMultiscaleDeformableAttention],
        "hidden_states": [RfDetrDecoderLayer],
    }
    # Roboflow checkpoints use bare keys with no top-level prefix
    _checkpoint_conversion_prefix_free = True

    @torch.no_grad()
    def _init_weights(self, module):
        """
        Initialize RF-DETR specific parameters on top of the generic `PreTrainedModel` initialization.

        Every tensor is written through the `init` helpers (`...initialization`) rather than raw `torch.nn.init`, so
        all parameters follow the same initialization path. NOTE(review): those helpers presumably skip tensors that
        are already marked as initialized (e.g. loaded from a checkpoint) — confirm in `initialization`.
        """
        super()._init_weights(module)

        if isinstance(module, RfDetrMultiscaleDeformableAttention):
            init.constant_(module.sampling_offsets.weight, 0.0)
            # bias: head h points in direction 2*pi*h/n_heads (normalized so the larger |component| is 1),
            # replicated across levels, with point i placed at distance (i + 1) along that direction
            thetas = torch.arange(module.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / module.n_heads)
            grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
            grid_init = (
                (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
                .view(module.n_heads, 1, 1, 2)
                .repeat(1, module.n_levels, module.n_points, 1)
            )
            for i in range(module.n_points):
                grid_init[:, :, i, :] *= i + 1

            init.copy_(module.sampling_offsets.bias, grid_init.view(-1))
            init.constant_(module.attention_weights.weight, 0.0)
            init.constant_(module.attention_weights.bias, 0.0)
            init.xavier_uniform_(module.value_proj.weight)
            init.constant_(module.value_proj.bias, 0.0)
            init.xavier_uniform_(module.output_proj.weight)
            init.constant_(module.output_proj.bias, 0.0)
        if hasattr(module, "level_embed"):
            init.normal_(module.level_embed)
        if hasattr(module, "refpoint_embed") and module.refpoint_embed is not None:
            init.constant_(module.refpoint_embed.weight, 0)
        if hasattr(module, "class_embed") and module.class_embed is not None:
            # bias = logit(prior_prob), so the initial sigmoid class probability equals prior_prob
            prior_prob = 0.01
            bias_value = -math.log((1 - prior_prob) / prior_prob)
            init.constant_(module.class_embed.bias, bias_value)
        if hasattr(module, "bbox_embed") and module.bbox_embed is not None:
            init.constant_(module.bbox_embed.layers[-1].weight, 0)
            init.constant_(module.bbox_embed.layers[-1].bias, 0)
        if hasattr(module, "segmentation_bias") and isinstance(module.segmentation_bias, nn.Parameter):
            # use the same `init` helper as every other tensor above (was raw `nn.init.constant_`)
            init.constant_(module.segmentation_bias, 0.0)


@auto_docstring(
    custom_intro="""
    Base class for outputs of the RfDetr backbone-decoder model.
    """
)
@dataclass
class RfDetrModelOutput(ModelOutput):
    r"""
    init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Initial reference points sent through the Transformer decoder.
    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, d_model)`):
        Stacked intermediate hidden states (output of each layer of the decoder).
    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked intermediate reference points (reference points of each layer of the decoder).
    enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
        picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
        foreground and background).
    enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Logits of predicted bounding boxes coordinates in the first stage.
    backbone_features (list of `torch.FloatTensor` of shape `(batch_size, config.num_channels, config.image_size, config.image_size)`):
        Features from the backbone.
    """

    last_hidden_state: torch.FloatTensor | None = None
    init_reference_points: torch.FloatTensor | None = None
    intermediate_hidden_states: torch.FloatTensor | None = None
    intermediate_reference_points: torch.FloatTensor | None = None
    enc_outputs_class: torch.FloatTensor | None = None
    enc_outputs_coord_logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    cross_attentions: tuple[torch.FloatTensor, ...] | None = None
    # optional like every other field: the default is None
    backbone_features: list[torch.Tensor] | None = None


@auto_docstring(
    custom_intro="""
    Base class for outputs of the RfDetrDecoder. This class adds two attributes to
    BaseModelOutputWithCrossAttentions, namely:
    - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
    - a stacked tensor of intermediate reference points.
    """
)
@dataclass
class RfDetrDecoderOutput(BaseModelOutputWithCrossAttentions):
    r"""
    cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, num_levels,
        num_points)`. Attention weights of the decoder's multiscale deformable cross-attention layer, after the
        softmax over all (level, point) pairs of each head, used to weight the sampled values.
    intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`):
        Intermediate decoder activations, i.e. the output of each decoder layer after being passed through the
        decoder's final layernorm. Always populated by `RfDetrDecoder`.
    intermediate_reference_points (`torch.FloatTensor` of shape `(1, batch_size, num_queries, 4)`):
        Stacked reference points. `RfDetrDecoder` does not refine the reference points between layers, so this
        only holds the initial reference points passed to the decoder.
    """

    intermediate_hidden_states: torch.FloatTensor | None = None

    intermediate_reference_points: torch.FloatTensor | None = None


def encode_sinusoidal_position_embedding(
    pos_tensor: torch.Tensor,
    num_pos_feats: int = 128,
    temperature: int = 10000,
) -> torch.Tensor:
    """Build sinusoidal embeddings from normalized anchor coordinates.

    Every coordinate along the last axis of `pos_tensor` gets its own ``num_pos_feats``-wide embedding. Even
    channels carry ``sin`` and odd channels carry ``cos`` of ``coord * 2 * pi / temperature ** (2 * (k // 2) /
    num_pos_feats)``. The per-coordinate embeddings are concatenated. When there are at least two coordinates, the
    first two are swapped, so the result starts with ``y`` then ``x`` (DETR's ``[pos_y, pos_x, ...]`` layout).

    Args:
        pos_tensor: Normalized coordinates in ``[0, 1]``, shape ``(..., n_coords)``.
        num_pos_feats: Embedding width per coordinate.
        temperature: Base of the frequency decay.

    Returns:
        Tensor of shape ``(..., n_coords * num_pos_feats)`` cast back to ``pos_tensor.dtype``.
    """
    scale = 2 * math.pi
    channel_idx = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
    pair_idx = torch.div(channel_idx, 2, rounding_mode="floor")
    frequencies = temperature ** (2 * pair_idx / num_pos_feats)

    per_coordinate = []
    for coordinate in pos_tensor.unbind(-1):
        phases = coordinate[..., None] * scale / frequencies
        interleaved = torch.stack((phases[..., 0::2].sin(), phases[..., 1::2].cos()), dim=-1)
        per_coordinate.append(interleaved.flatten(-2))

    if len(per_coordinate) > 1:
        per_coordinate[0], per_coordinate[1] = per_coordinate[1], per_coordinate[0]

    return torch.cat(per_coordinate, dim=-1).to(pos_tensor.dtype)


class RfDetrMLPPredictionHead(nn.Module):
    """
    Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
    height and width of a bounding box w.r.t. an image.

    `num_layers` linear layers map `input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim`, with a ReLU after
    every layer except the last.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden_dims = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim, *hidden_dims]
        out_dims = [*hidden_dims, output_dim]
        self.layers = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(in_dims, out_dims)])

    def forward(self, x):
        last_index = self.num_layers - 1
        for index, linear in enumerate(self.layers):
            x = linear(x)
            if index < last_index:
                x = nn.functional.relu(x)
        return x


class RfDetrDecoder(RfDetrPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`RfDetrDecoderLayer`].

    The decoder updates the query embeddings through multiple self-attention and deformable cross-attention layers.

    Some tweaks for RfDetr:

    - it uses group detr technique at training for faster convergence.
    - the query position embedding is computed once from the valid-ratio-scaled reference points and shared by
      every layer; reference points are not refined inside the decoder.

    Args:
        config: RfDetrConfig
    """

    _can_record_outputs = {
        "hidden_states": RfDetrDecoderLayer,
        "attentions": OutputRecorder(RfDetrAttention, layer_name="self_attn", index=1),
        "cross_attentions": OutputRecorder(RfDetrMultiscaleDeformableAttention, layer_name="cross_attn", index=1),
    }

    def __init__(self, config: RfDetrConfig):
        super().__init__(config)
        self.dropout = config.dropout
        self.layers = nn.ModuleList([RfDetrDecoderLayer(config, i) for i in range(config.decoder_layers)])
        self.layernorm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False

        # input width: 4 box coordinates x (d_model // 2) sinusoidal features each
        self.ref_point_head = RfDetrMLPPredictionHead(2 * config.d_model, config.d_model, config.d_model, num_layers=2)

        self.post_init()

    def get_reference(self, reference_points, valid_ratios):
        """
        Scale the reference boxes by the per-level valid ratios and build the query position embedding.

        Args:
            reference_points: boxes of shape `(batch_size, num_queries, 4+)`; only the first 4 coordinates are used.
            valid_ratios: per-level ratios, presumably `(batch_size, num_levels, 2)`. They are concatenated with
                themselves so that they scale all four box coordinates.

        Returns:
            `(reference_points_inputs, query_pos)` of shapes `(batch_size, num_queries, num_levels, 4)` and
            `(batch_size, num_queries, d_model)`. The position embedding is built from level 0 only.
        """
        # batch_size, num_queries, 4
        obj_center = reference_points[..., :4]

        # batch_size, num_queries, num_levels, 4
        reference_points_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None]

        # batch_size, num_queries, d_model * 2
        query_sine_embed = encode_sinusoidal_position_embedding(
            reference_points_inputs[:, :, 0, :], num_pos_feats=self.config.d_model // 2
        )

        # batch_size, num_queries, d_model
        query_pos = self.ref_point_head(query_sine_embed)
        return reference_points_inputs, query_pos

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        inputs_embeds: torch.Tensor | None = None,
        reference_points: torch.Tensor | None = None,
        spatial_shapes: torch.Tensor | None = None,
        spatial_shapes_list: list[tuple[int, int]] | None = None,
        level_start_index: torch.Tensor | None = None,
        valid_ratios: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Run all decoder layers on the query features.

        Returns:
            `RfDetrDecoderOutput`. `last_hidden_state` is the layer-normed output of the last layer.
            `intermediate_hidden_states` stacks the layer-normed output of every layer as
            `(decoder_layers, batch_size, num_queries, d_model)`. `intermediate_reference_points` stacks only the
            input reference points.

        Raises:
            ValueError: if `inputs_embeds` is None (there is no other source for the initial query features).
        """
        if inputs_embeds is None:
            raise ValueError("RfDetrDecoder requires `inputs_embeds` (the initial query features).")
        hidden_states = inputs_embeds

        intermediate = ()
        intermediate_reference_points = (reference_points,)

        reference_points_inputs, query_pos = self.get_reference(reference_points, valid_ratios)

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                position_embeddings=query_pos,
                reference_points=reference_points_inputs,
                spatial_shapes=spatial_shapes,
                spatial_shapes_list=spatial_shapes_list,
                level_start_index=level_start_index,
                **kwargs,
            )
            # the un-normed hidden_states feed the next layer; the normed copy is what gets reported
            intermediate_hidden_states = self.layernorm(hidden_states)
            intermediate += (intermediate_hidden_states,)

        intermediate = torch.stack(intermediate)
        last_hidden_state = intermediate[-1]
        intermediate_reference_points = torch.stack(intermediate_reference_points)

        return RfDetrDecoderOutput(
            last_hidden_state=last_hidden_state,
            intermediate_hidden_states=intermediate,
            intermediate_reference_points=intermediate_reference_points,
        )


def refine_bboxes(reference_points, deltas):
    """
    Apply box deltas to reference boxes in `(cx, cy, w, h)` form.

    The new center is `delta_xy * wh + cxcy` and the new size is `exp(delta_wh) * wh`. `reference_points` is first
    moved to `deltas.device`.
    """
    anchors = reference_points.to(deltas.device)
    anchor_xy, anchor_wh = anchors[..., :2], anchors[..., 2:]
    refined_xy = deltas[..., :2] * anchor_wh + anchor_xy
    refined_wh = deltas[..., 2:].exp() * anchor_wh
    return torch.cat((refined_xy, refined_wh), -1)


@auto_docstring(
    custom_intro="""
    The bare RF-DETR Model (consisting of a backbone and decoder Transformer) outputting raw
    hidden-states without any specific head on top.
    """
)
class RfDetrModel(RfDetrPreTrainedModel):
    def __init__(self, config: RfDetrConfig):
        super().__init__(config)

        # Create backbone + positional encoding
        self.backbone = RfDetrConvEncoder(config)

        self.group_detr = config.group_detr
        self.num_queries = config.num_queries
        hidden_dim = config.d_model
        # Learned reference boxes and query features: `num_queries` entries for each of the `group_detr` query groups.
        self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 4)
        self.query_feat = nn.Embedding(self.num_queries * self.group_detr, hidden_dim)

        self.decoder = RfDetrDecoder(config)

        # Per-group heads used in `generate_topk_proposals` to score and refine the encoder proposals.
        self.enc_output = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(self.group_detr)])
        self.enc_output_norm = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(self.group_detr)])
        # Should normally be None and then instantiated in the ForObjectDetection class
        self.enc_out_bbox_embed = nn.ModuleList(
            [RfDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3) for _ in range(self.group_detr)]
        )
        self.enc_out_class_embed = nn.ModuleList(
            [nn.Linear(config.d_model, config.num_labels) for _ in range(self.group_detr)]
        )
        self.d_model = config.d_model

        self.post_init()

    def freeze_backbone(self):
        """Disable gradients for every parameter of `self.backbone.model`."""
        for param in self.backbone.model.parameters():
            param.requires_grad_(False)

    def unfreeze_backbone(self):
        """Re-enable gradients for every parameter of `self.backbone.model`."""
        for param in self.backbone.model.parameters():
            param.requires_grad_(True)

    def get_valid_ratio(self, mask, dtype=torch.float32):
        """Get the valid (non-padded) ratio of a feature map, per image.

        Args:
            mask (`torch.Tensor` of shape `(batch_size, height, width)`): Non-zero where the feature map is valid.
            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): dtype of the returned ratios.

        Returns:
            `torch.Tensor` of shape `(batch_size, 2)`: `(valid_width / width, valid_height / height)` per image.
        """

        _, height, width = mask.shape
        # Valid rows are counted along the first column and valid columns along the first row.
        # NOTE(review): this assumes padding only extends the bottom/right edges of the image.
        valid_height = torch.sum(mask[:, :, 0], 1)
        valid_width = torch.sum(mask[:, 0, :], 1)
        valid_ratio_height = valid_height.to(dtype) / height
        valid_ratio_width = valid_width.to(dtype) / width
        valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1)
        return valid_ratio

    def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes):
        """Generate the encoder output proposals from encoded enc_output.

        Every spatial position becomes a proposal centered on its pixel (normalized by the valid extent of the
        image) with a fixed size of `0.05 * 2**level`.

        Args:
            enc_output (Tensor[batch_size, sequence_length, hidden_size]): Output of the encoder.
            padding_mask (Tensor[batch_size, sequence_length]): Padding mask for `enc_output` (True on padding).
            spatial_shapes (list[tuple[int, int]]): Spatial shapes of the feature maps.

        Returns:
            `tuple(torch.Tensor)`: A tuple of three tensors:
                - object_query (Tensor[batch_size, sequence_length, hidden_size]): Object query features. Later used to
                  directly predict a bounding box. (without the need of a decoder)
                - output_proposals (Tensor[batch_size, sequence_length, 4]): Normalized `(cx, cy, w, h)` proposals in
                  [0, 1] space. Invalid positions (padding or out-of-bounds) are filled with 0.
                - invalid_mask (Tensor[batch_size, sequence_length, 1]): Boolean mask that is True for invalid positions
                  (padded pixels or proposals whose coordinates fall outside (0.01, 0.99)).
        """
        batch_size = enc_output.shape[0]
        proposals = []
        _cur = 0
        for level, (height, width) in enumerate(spatial_shapes):
            mask_flatten_ = padding_mask[:, _cur : (_cur + height * width)].view(batch_size, height, width, 1)
            valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
            valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

            grid_y, grid_x = torch.meshgrid(
                torch.linspace(0, height - 1, height, dtype=enc_output.dtype, device=enc_output.device),
                torch.linspace(0, width - 1, width, dtype=enc_output.dtype, device=enc_output.device),
                indexing="ij",
            )
            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

            # Pixel centers (+0.5) normalized by the number of valid columns/rows of each image.
            scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2)
            grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale
            width_height = torch.ones_like(grid) * 0.05 * (2.0**level)
            proposal = torch.cat((grid, width_height), -1).view(batch_size, -1, 4)
            proposals.append(proposal)
            _cur += height * width
        output_proposals = torch.cat(proposals, 1)
        output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
        invalid_mask = padding_mask.unsqueeze(-1) | ~output_proposals_valid
        output_proposals = output_proposals.masked_fill(invalid_mask, float(0))

        # assign each pixel as an object query
        object_query = enc_output
        object_query = object_query.masked_fill(invalid_mask, float(0))
        return object_query, output_proposals, invalid_mask

    @can_return_tuple
    @auto_docstring(
        custom_intro="""
    Forward pass of the RF-DETR model. The pipeline proceeds as follows:

        1. Generate an initial set of object query embeddings and spatial location proposals from the
           backbone's flattened output.
        2. Initialize storage for refined encoder-stage predictions (accommodating multi-group query
           structures) and iteratively refine object queries and their coordinates for each query group
           to capture the highest-confidence candidates from the encoder stage.
        3. Initialize learnable query features and spatial reference points (restricting to the primary
           group during inference for efficiency).
        4. Project the base reference points across the batch, refine them with the predicted coordinate
           refinements (shifting attention to the discovered object locations before decoding), and expand
           the target query features to match the batch dimensions.
        5. Pass the refined queries and updated reference points through the transformer decoder to
           aggregate detailed spatial context from the multi-scale features.
    """
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> RfDetrModelOutput:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, RfDetrModel
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("Roboflow/rf-detr-base")
        >>> model = RfDetrModel.from_pretrained("Roboflow/rf-detr-base")

        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 200, 256]
        ```"""
        batch_size, num_channels, height, width = pixel_values.shape
        device = pixel_values.device

        if pixel_mask is None:
            pixel_mask = torch.ones(((batch_size, height, width)), dtype=torch.long, device=device)

        features, mask = self.backbone(pixel_values, pixel_mask)

        source_flatten = features.flatten(2).transpose(1, 2)
        mask_flatten = mask.flatten(1)
        spatial_shapes_list = [features.shape[2:]]
        spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = self.get_valid_ratio(mask, dtype=source_flatten.dtype).unsqueeze(1)

        # Step 1.
        object_query_embedding, output_proposals, invalid_mask = self.gen_encoder_output_proposals(
            source_flatten, ~mask_flatten, spatial_shapes_list
        )

        # Step 2.
        group_detr = self.group_detr if self.training else 1
        topk = self.num_queries
        num_selected = topk * group_detr
        # Allocate the buffers next to the proposals they are filled from. `self.device` assumes every parameter
        # lives on one device, which does not hold when the model is split across devices.
        buffer_kwargs = {"device": output_proposals.device, "dtype": output_proposals.dtype}
        topk_coords_logits = torch.empty((batch_size, num_selected, 4), **buffer_kwargs)
        enc_outputs_coord_logits = torch.empty((batch_size, num_selected, 4), **buffer_kwargs)
        enc_outputs_class = torch.empty((batch_size, num_selected, self.config.d_model), **buffer_kwargs)
        for group_id in range(group_detr):
            object_query_undetach, group_topk_coords_logits, topk_coords_logits_undetach = (
                self.generate_topk_proposals(group_id, object_query_embedding, output_proposals, invalid_mask, topk)
            )
            group_slice = slice(group_id * topk, (group_id + 1) * topk)
            topk_coords_logits[:, group_slice] = group_topk_coords_logits
            enc_outputs_coord_logits[:, group_slice] = topk_coords_logits_undetach
            # Holds the selected encoder query features; class logits are computed from them by the detection head.
            enc_outputs_class[:, group_slice] = object_query_undetach

        # Step 3.
        if self.training:
            reference_points = self.reference_point_embed.weight
            query_feat = self.query_feat.weight
        else:
            reference_points = self.reference_point_embed.weight[: self.num_queries]
            query_feat = self.query_feat.weight[: self.num_queries]

        # Step 4.
        reference_points = reference_points.unsqueeze(0).expand(batch_size, -1, -1)
        two_stage_len = enc_outputs_coord_logits.shape[-2]
        reference_points_two_stage_subset = reference_points[..., :two_stage_len, :]
        reference_points_subset = reference_points[..., two_stage_len:, :]
        # The learned reference embeddings act as deltas applied on top of the (detached) top-k encoder boxes.
        reference_points_two_stage_subset = refine_bboxes(topk_coords_logits, reference_points_two_stage_subset)
        reference_points = torch.cat([reference_points_two_stage_subset, reference_points_subset], dim=-2)
        init_reference_points = reference_points
        target = query_feat.unsqueeze(0).expand(batch_size, -1, -1)

        # Step 5.
        decoder_outputs = self.decoder(
            inputs_embeds=target,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            spatial_shapes_list=spatial_shapes_list,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            encoder_hidden_states=source_flatten,
            encoder_attention_mask=mask_flatten,
            **kwargs,
        )

        return RfDetrModelOutput(
            init_reference_points=init_reference_points,
            last_hidden_state=decoder_outputs.last_hidden_state,
            intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
            intermediate_reference_points=decoder_outputs.intermediate_reference_points,
            backbone_features=features,
            enc_outputs_class=enc_outputs_class,
            enc_outputs_coord_logits=enc_outputs_coord_logits,
            hidden_states=decoder_outputs.hidden_states,
            attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
        )

    def generate_topk_proposals(
        self, group_id: int, object_query_embedding: Tensor, output_proposals: Tensor, invalid_mask: Tensor, topk: int
    ) -> tuple[Tensor, Tensor, Tensor]:
        """
        Generates and selects the top-k object query embeddings and bounding box proposals for a specific query group.

        The pipeline proceeds as follows:

        1. Project and normalize the base query embeddings for the specific query group.
        2. Predict classification scores and bounding box refinements for the current query features.
        3. Apply the predicted deltas to the initial proposals to obtain refined spatial coordinates.
        4. Identify the indices of the highest-confidence predictions and gather the refined coordinates for these
           top-k candidates (detached to prevent gradient flow back to the proposal generation stage).
        5. Gather the associated query features to be used as starting points for the decoder stage.

        Args:
            group_id (`int`): Index of the query group whose heads are used.
            object_query_embedding (`Tensor` of shape `(batch_size, sequence_length, d_model)`): Encoder features.
            output_proposals (`Tensor` of shape `(batch_size, sequence_length, 4)`): Initial proposals.
            invalid_mask (`Tensor` of shape `(batch_size, sequence_length, 1)`): True where a proposal is invalid;
                those positions get a `-inf` score and are never selected ahead of valid ones.
            topk (`int`): Number of proposals to keep. Must not exceed `sequence_length` (required by `torch.topk`).

        Returns:
            `tuple[Tensor, Tensor, Tensor]`: The selected query features `(batch_size, topk, d_model)`, the detached
            selected boxes `(batch_size, topk, 4)` and the same boxes with gradients attached.
        """
        # Step 1.
        object_query = self.enc_output[group_id](object_query_embedding)
        object_query = self.enc_output_norm[group_id](object_query)

        # Step 2.
        enc_outputs_class_proposals = self.enc_out_class_embed[group_id](object_query)
        delta_bbox = self.enc_out_bbox_embed[group_id](object_query)
        enc_outputs_class_proposals = enc_outputs_class_proposals.masked_fill(
            invalid_mask.to(enc_outputs_class_proposals.device), float("-inf")
        )

        # Step 3.
        enc_outputs_coord = refine_bboxes(output_proposals, delta_bbox)

        # Step 4.
        topk_proposals = torch.topk(enc_outputs_class_proposals.max(-1)[0], topk, dim=1)[1]
        topk_coords_logits_undetach = torch.gather(
            enc_outputs_coord,
            1,
            topk_proposals.unsqueeze(-1).expand(-1, -1, 4),
        )
        topk_coords_logits = topk_coords_logits_undetach.detach()

        # Step 5.
        object_query_undetach = torch.gather(
            object_query, 1, topk_proposals.unsqueeze(-1).expand(-1, -1, self.config.d_model)
        )
        return object_query_undetach, topk_coords_logits, topk_coords_logits_undetach


@auto_docstring(
    custom_intro="""
    Output type of [`RfDetrForObjectDetection`].
    """
)
@dataclass
class RfDetrObjectDetectionOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, config.num_labels)`):
        Classification logits for all queries.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~DeformableDetrProcessor.post_process_object_detection`] to retrieve the
        unnormalized bounding boxes.
    auxiliary_outputs (`list[Dict]`, *optional*):
        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
        `pred_boxes`) for each decoder layer.
    init_reference_points (`torch.FloatTensor` of shape  `(batch_size, num_queries, 4)`):
        Initial reference points sent through the Transformer decoder.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, d_model)`):
        Hidden states at the output of the last decoder layer (after the decoder's final layer norm).
    intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, d_model)`):
        Stacked intermediate hidden states (output of each layer of the decoder), stacked along the first dimension.
    intermediate_reference_points (`torch.FloatTensor` of shape `(1, batch_size, num_queries, 4)`):
        Reference points used by the decoder, stacked along a leading dimension. The decoder does not update them
        per layer, so this dimension has size 1.
    enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, num_queries * group_detr, config.num_labels)`):
        Classification logits of the top-k encoder proposals selected for each query group in the first stage.
        `group_detr` is `config.group_detr` during training and 1 otherwise.
    enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, num_queries * group_detr, 4)`):
        Refined first-stage boxes of the selected encoder proposals, as (center_x, center_y, width, height).
    backbone_features (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
        Feature map produced by the backbone.
    """

    loss: torch.FloatTensor | None = None
    loss_dict: dict | None = None
    logits: torch.FloatTensor | None = None
    pred_boxes: torch.FloatTensor | None = None
    auxiliary_outputs: list[dict] | None = None
    init_reference_points: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    intermediate_hidden_states: torch.FloatTensor | None = None
    intermediate_reference_points: torch.FloatTensor | None = None
    enc_outputs_class: torch.FloatTensor | None = None
    enc_outputs_coord_logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    cross_attentions: tuple[torch.FloatTensor, ...] | None = None

    backbone_features: torch.FloatTensor | None = None


@auto_docstring(
    custom_intro="""
    RF-DETR Model (consisting of a backbone and decoder Transformer) with object detection heads on
    top, for tasks such as COCO detection.
    """
)
class RfDetrForObjectDetection(RfDetrPreTrainedModel):
    # When using clones, all layers > 0 will be clones, but layer 0 *is* required
    # We can't initialize the model on meta device as some weights are modified during the initialization
    _no_split_modules = None
    _tied_weights_keys = None

    def __init__(self, config: RfDetrConfig):
        super().__init__(config)
        self.model = RfDetrModel(config)
        # Second-stage heads applied to the decoder hidden states.
        self.class_embed = nn.Linear(config.d_model, config.num_labels)
        self.bbox_embed = RfDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3)

        self.post_init()

    @can_return_tuple
    @auto_docstring(
        custom_intro="""
    The forward pass proceeds as follows:

        1. Process the visual input through the base RF-DETR model to obtain the transformer's last hidden state and
           the final sequence of reference points.
        2. First stage: Generate classification logits from the encoder's proposed object query embeddings.
        3. Second stage: Predict the final classification labels and refined bounding boxes using the decoder's last hidden state
           and the most recent reference points.

    """
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        pixel_mask: torch.LongTensor | None = None,
        labels: list[dict] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> RfDetrObjectDetectionOutput:
        r"""
        labels (`list[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
            respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
            in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, RfDetrForObjectDetection
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("Roboflow/rf-detr-base")
        >>> model = RfDetrForObjectDetection.from_pretrained("Roboflow/rf-detr-base")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
        >>> target_sizes = torch.tensor([image.size[::-1]])
        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[
        ...     0
        ... ]
        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(
        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
        ...         f"{round(score.item(), 3)} at location {box}"
        ...     )
        Detected cat with confidence 0.8 at location [16.5, 52.84, 318.25, 470.78]
        Detected cat with confidence 0.789 at location [342.19, 24.3, 640.02, 372.25]
        Detected remote with confidence 0.633 at location [40.79, 72.78, 176.76, 117.25]
        ```"""
        # Step 1.
        outputs = self.model(pixel_values, pixel_mask=pixel_mask, **kwargs)

        last_hidden_states = outputs.last_hidden_state
        intermediate_reference_points = outputs.intermediate_reference_points

        # Step 2.
        enc_outputs_class_logits = self.predict_encoder_class_logits(outputs.enc_outputs_class)

        # Step 3.
        logits, pred_boxes = self.predict_class_and_boxes(last_hidden_states, intermediate_reference_points[-1])

        loss, loss_dict, auxiliary_outputs = None, None, None
        if labels is not None:
            outputs_class, outputs_coord = None, None
            if self.config.auxiliary_loss:
                # Per-layer predictions; the stacked reference points broadcast against the stacked hidden states.
                outputs_class, outputs_coord = self.predict_class_and_boxes(
                    outputs.intermediate_hidden_states, intermediate_reference_points
                )

            loss, loss_dict, auxiliary_outputs = self.loss_function(
                logits,
                labels,
                self.device,
                pred_boxes,
                self.config,
                outputs_class,
                outputs_coord,
                enc_outputs_class_logits,
                outputs.enc_outputs_coord_logits,
            )

        return RfDetrObjectDetectionOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes,
            auxiliary_outputs=auxiliary_outputs,
            last_hidden_state=outputs.last_hidden_state,
            intermediate_hidden_states=outputs.intermediate_hidden_states,
            intermediate_reference_points=outputs.intermediate_reference_points,
            init_reference_points=outputs.init_reference_points,
            enc_outputs_class=enc_outputs_class_logits,
            enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
            backbone_features=outputs.backbone_features,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def predict_encoder_class_logits(self, enc_outputs_class: torch.Tensor) -> Tensor:
        """
        Predicts classification logits from encoder hidden states for each query group.

        `enc_outputs_class` holds the selected encoder query features laid out as consecutive chunks of
        `config.num_queries` along dim 1, one chunk per query group. Each chunk goes through its own group's class
        head. Only the first group is used outside of training.

        Returns:
            `Tensor`: The per-group logits concatenated along dim 1.
        """
        enc_outputs_class_list = enc_outputs_class.split(self.config.num_queries, dim=1)
        group_detr = self.config.group_detr if self.training else 1
        pred_class = [
            self.model.enc_out_class_embed[group_index](enc_outputs_class_list[group_index])
            for group_index in range(group_detr)
        ]
        return torch.cat(pred_class, dim=1)

    def predict_class_and_boxes(
        self, hidden_states: torch.Tensor, reference_points: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predicts classification logits and refined bounding boxes from transformer hidden states and reference points.

        Args:
            hidden_states (`torch.Tensor`): Decoder hidden states whose last dimension is `d_model`.
            reference_points (`torch.Tensor`): `(cx, cy, w, h)` reference boxes, broadcastable against the predicted
                box deltas.

        Returns:
            `tuple[torch.Tensor, torch.Tensor]`: The class logits and the refined `(cx, cy, w, h)` boxes.
        """
        logits = self.class_embed(hidden_states)
        boxes_delta = self.bbox_embed(hidden_states)
        boxes = refine_bboxes(reference_points, boxes_delta)
        return logits, boxes


@auto_docstring(
    custom_intro="""
    Output type of [`RfDetrForInstanceSegmentation`].
    """
)
@dataclass
class RfDetrInstanceSegmentationOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
        Classification logits (including no-object) for all queries.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~DeformableDetrProcessor.post_process_object_detection`] to retrieve the
        unnormalized bounding boxes.
    pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`):
        Segmentation masks logits for all queries. See also
        [`~RfDetrImageProcessor.post_process_instance_segmentation`] to obtain instance segmentation maps.
    auxiliary_outputs (`list[Dict]`, *optional*):
        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
        `pred_boxes`) for each decoder layer.
    init_reference_points (`torch.FloatTensor` of shape  `(batch_size, num_queries, 4)`):
        Initial reference points sent through the Transformer decoder.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, d_model)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, d_model)`):
        Stacked intermediate hidden states (output of each layer of the decoder).
    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked intermediate reference points (reference points of each layer of the decoder).
    enc_outputs_mask_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, width, height)`, *optional*):
        Mask logits from the encoder for all queries.
    """

    loss: torch.FloatTensor | None = None
    loss_dict: dict | None = None
    logits: torch.FloatTensor | None = None
    pred_boxes: torch.FloatTensor | None = None
    pred_masks: torch.FloatTensor | None = None
    auxiliary_outputs: list[dict] | None = None
    init_reference_points: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    intermediate_hidden_states: torch.FloatTensor | None = None
    intermediate_reference_points: torch.FloatTensor | None = None
    enc_outputs_mask_logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    cross_attentions: tuple[torch.FloatTensor, ...] | None = None


class RfDetrSegmentationBlock(nn.Module):
    """Residual depthwise-convolution block of the segmentation head.

    This corresponds to the `Block` class in the original implementation. The input, laid out as
    `(batch_size, channels, height, width)`, is processed in the following order:

    1. 3x3 depthwise convolution.
    2. Permute to channels-last.
    3. Layer norm.
    4. Linear projection.
    5. Activation.
    6. Permute back to channels-first.

    The result is then added to the input.

    Args:
        config ([`RfDetrConfig`]): Model configuration. `config.d_model` sets the number of channels and
            `config.segmentation_head_activation_function` selects the activation.
    """

    def __init__(self, config: RfDetrConfig):
        super().__init__()
        dim = config.d_model
        self.layernorm = RfDetrLayerNorm(dim, eps=1e-6)
        self.act = ACT2FN[config.segmentation_head_activation_function]
        self.depthwise_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)
        self.pointwise_conv = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the block to a `(batch_size, d_model, height, width)` tensor; the output has the same shape."""
        residual = hidden_states
        hidden_states = self.depthwise_conv(hidden_states)
        hidden_states = hidden_states.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        hidden_states = self.layernorm(hidden_states)
        # The "pointwise conv" is a Linear layer applied in channels-last layout.
        hidden_states = self.pointwise_conv(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = hidden_states.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        hidden_states = hidden_states + residual
        return hidden_states


class RfDetrSegmentationMLP(nn.Module):
    """Two-layer feed-forward network mapping `d_model -> intermediate_size -> d_model` with an activation between.

    Args:
        config ([`RfDetrConfig`]): Supplies `d_model`, `intermediate_size` and
            `segmentation_head_activation_function`.
    """

    def __init__(self, config: RfDetrConfig):
        super().__init__()
        model_dim, inner_dim = config.d_model, config.intermediate_size
        self.config = config
        self.activation_fn = ACT2FN[config.segmentation_head_activation_function]
        self.fc1 = nn.Linear(model_dim, inner_dim)
        self.fc2 = nn.Linear(inner_dim, model_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up, apply the activation, and project back down to `d_model`."""
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))


class RfDetrSegmentationMLPBlock(nn.Module):
    """Pre-norm residual MLP block computing `x + mlp(layer_norm(x))` over the last (`d_model`) dimension.

    Args:
        config ([`RfDetrConfig`]): Supplies `d_model` and the settings of the inner [`RfDetrSegmentationMLP`].
    """

    def __init__(self, config: RfDetrConfig):
        super().__init__()
        self.norm = nn.LayerNorm(config.d_model)
        self.mlp = RfDetrSegmentationMLP(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize, run the MLP, and add the untouched input back."""
        update = self.mlp(self.norm(hidden_states))
        return update + hidden_states


class RfDetrForInstanceSegmentation(RfDetrPreTrainedModel):
    """
    RF-DETR instance segmentation model: an `RfDetrForObjectDetection` model for classes and boxes,
    plus a mask head that turns query features into per-query mask logits over a downsampled
    spatial feature map.
    """

    # When using clones, all layers > 0 will be clones, but layer 0 *is* required
    # We can't initialize the model on meta device as some weights are modified during the initialization
    _no_split_modules = None

    def __init__(self, config: RfDetrConfig):
        super().__init__(config)

        # Detection model; `forward` uses its base model (`self.model.model`) and its class/box prediction helpers.
        self.model = RfDetrForObjectDetection(config)

        num_blocks = config.decoder_layers
        # Masks are predicted at `image_size // downsample_ratio`.
        self.downsample_ratio = config.mask_downsample_ratio
        # One spatial refinement block per decoder layer; block i refines the features used for layer i's masks.
        self.blocks = nn.ModuleList([RfDetrSegmentationBlock(config) for _ in range(num_blocks)])
        self.spatial_features_proj = nn.Conv2d(config.d_model, config.d_model, kernel_size=1)

        self.query_features_block = RfDetrSegmentationMLPBlock(config)
        self.query_features_proj = nn.Linear(config.d_model, config.d_model)

        # Single learnable scalar added to every mask logit.
        self.segmentation_bias = nn.Parameter(torch.zeros(1), requires_grad=True)

        self.post_init()

    def get_mask_logits(self, query_features: Tensor, spatial_features: Tensor) -> Tensor:
        """
        Compute per-query mask logits as a dot product between query features and spatial features.

        Query features are passed through `query_features_block` and `query_features_proj` (both keep the
        last dimension at `d_model`), multiplied with the flattened spatial features, reshaped back to the
        spatial grid, and shifted by the scalar `segmentation_bias`.

        Args:
            query_features (`torch.Tensor`): Query features of shape (batch_size, num_queries, d_model).
            spatial_features (`torch.Tensor`): Spatial features of shape (batch_size, d_model, height, width).
                The channel dimension must equal the projected query dimension for the matmul to be valid.

        Returns:
            `torch.Tensor`: Mask logits of shape (batch_size, num_queries, height, width).
        """
        batch_size, num_queries, _ = query_features.shape
        height, width = spatial_features.shape[2], spatial_features.shape[3]

        query_features = self.query_features_block(query_features)
        query_features = self.query_features_proj(query_features)
        # (B, Q, C) @ (B, C, H*W) -> (B, Q, H*W)
        mask_logits = torch.matmul(query_features, spatial_features.flatten(2))
        mask_logits = mask_logits.view(batch_size, num_queries, height, width)
        mask_logits = mask_logits + self.segmentation_bias
        return mask_logits

    def segmentation_head(
        self, spatial_features, list_query_features, image_size: torch.Size, skip_blocks: bool = False
    ) -> list[torch.Tensor] | torch.Tensor:
        """
        Compute mask logits from a spatial feature map and query features.

        The spatial feature map is first bilinearly resized to `image_size // downsample_ratio`.

        Args:
            spatial_features: Spatial feature map of shape (batch_size, num_channels, feature_height, feature_width).
            list_query_features: When `skip_blocks` is False, a sequence of query feature tensors of shape
                (batch_size, num_queries, d_model), one per decoder layer; it is zipped with `self.blocks`, so
                extra entries on either side are ignored. When `skip_blocks` is True, a single tensor of shape
                (batch_size, num_queries, d_model).
            image_size: Spatial size (height, width) of the input `pixel_values`.
            skip_blocks: If True, skip the refinement blocks and the spatial projection and compute mask logits
                directly from the resized features.

        Returns:
            When `skip_blocks` is False: list of mask logit tensors of shape
            (batch_size, num_queries, mask_height, mask_width), where mask size is image size divided
            by `downsample_ratio`. When `skip_blocks` is True: a single such tensor.
        """
        target_size = (image_size[0] // self.downsample_ratio, image_size[1] // self.downsample_ratio)
        spatial_features = F.interpolate(spatial_features, size=target_size, mode="bilinear", align_corners=False)

        if not skip_blocks:
            list_mask_logits = []
            # Blocks are applied cumulatively: layer i's masks use features refined by blocks 0..i.
            for block, query_features in zip(self.blocks, list_query_features):
                spatial_features = block(spatial_features)
                spatial_features_proj = self.spatial_features_proj(spatial_features)
                mask_logits = self.get_mask_logits(query_features, spatial_features_proj)
                list_mask_logits.append(mask_logits)
        else:
            # Encoder-stage masks use the resized features as-is: no refinement block and no spatial projection.
            list_mask_logits = self.get_mask_logits(list_query_features, spatial_features)

        return list_mask_logits

    @can_return_tuple
    @auto_docstring(
        custom_intro="""
    Forward pass of the RF-DETR model for instance segmentation. The pipeline proceeds as follows:

        1. Process the visual input through the base RF-DETR model to obtain multi-scale spatial features,
           query embeddings, and their transformation history.
        2. Generate classification logits and initial segmentation masks from the encoder's proposed
           object query embeddings (first stage).
        3. Predict the final classification labels and refined bounding boxes using the decoder's last
           hidden state (second stage).
        4. Pass the high-resolution spatial features and query hidden states through the segmentation
           head to produce the final, detailed instance masks.
    """
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        pixel_mask: torch.LongTensor | None = None,
        labels: list[dict] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | RfDetrInstanceSegmentationOutput:
        r"""
        labels (`list[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
            respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
            in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
        """
        # Mask resolution is derived from the input resolution (see `segmentation_head`).
        image_size = pixel_values.shape[-2:]

        # Step 1: run the base model wrapped by the detection model; class/box heads are applied separately below.
        outputs = self.model.model(pixel_values, pixel_mask=pixel_mask, **kwargs)

        spatial_features = outputs.backbone_features
        last_hidden_states = outputs.last_hidden_state
        intermediate_reference_points = outputs.intermediate_reference_points
        enc_outputs_class = outputs.enc_outputs_class

        # Step 2: `enc_outputs_class` serves both as input to the encoder class predictor and as query
        # features for the encoder-stage masks.
        enc_outputs_class_logits = self.model.predict_encoder_class_logits(enc_outputs_class)
        enc_outputs_masks = self.segmentation_head(spatial_features, enc_outputs_class, image_size, skip_blocks=True)

        # Step 3: final classes/boxes from the last decoder hidden state and the last reference points.
        logits, pred_boxes = self.model.predict_class_and_boxes(last_hidden_states, intermediate_reference_points[-1])

        # Step 4: per-decoder-layer masks; the last entry is the final prediction.
        outputs_masks = self.segmentation_head(spatial_features, outputs.intermediate_hidden_states, image_size)
        pred_masks = outputs_masks[-1]

        loss, loss_dict, auxiliary_outputs = None, None, None
        if labels is not None:
            # NOTE(review): mask predictions are passed to the loss, so `labels` presumably also need ground-truth
            # masks in addition to 'class_labels' and 'boxes' — confirm the expected key against the loss function.
            outputs_class, outputs_coord = None, None
            if self.config.auxiliary_loss:
                # Per-layer class/box predictions for the auxiliary (intermediate-layer) losses.
                outputs_class, outputs_coord = self.model.predict_class_and_boxes(
                    outputs.intermediate_hidden_states, intermediate_reference_points
                )
            loss, loss_dict, auxiliary_outputs = self.loss_function(
                logits,
                labels,
                self.device,
                pred_boxes,
                pred_masks,
                self.config,
                outputs_class,
                outputs_coord,
                outputs_masks,
                enc_outputs_class_logits,
                outputs.enc_outputs_coord_logits,
                enc_outputs_masks,
            )

        return RfDetrInstanceSegmentationOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes,
            pred_masks=pred_masks,
            auxiliary_outputs=auxiliary_outputs,
            last_hidden_state=outputs.last_hidden_state,
            intermediate_hidden_states=outputs.intermediate_hidden_states,
            intermediate_reference_points=outputs.intermediate_reference_points,
            init_reference_points=outputs.init_reference_points,
            enc_outputs_mask_logits=enc_outputs_masks,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


# Public API of this module. Kept as a plain literal list so tooling can read it without importing the module.
__all__ = [
    "RfDetrModel",
    "RfDetrForObjectDetection",
    "RfDetrForInstanceSegmentation",
    "RfDetrPreTrainedModel",
    "RfDetrDinov2Backbone",
    "RfDetrDinov2PreTrainedModel",
]
