# Copyright 2025 The HuggingFace Inc. team and Google LLC. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from collections.abc import Callable

import torch
from huggingface_hub.dataclasses import strict
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import Unigram
from torch import nn

from ...audio_utils import AudioInput, make_list_of_audio
from ...masking_utils import create_bidirectional_mask
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...tokenization_utils_tokenizers import TokenizersBackend
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from ..llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding, apply_rotary_pos_emb, eager_attention_forward
from ..parakeet.configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
from ..parakeet.modeling_parakeet import (
    ParakeetEncoderBlock,
    ParakeetEncoderConvolutionModule,
    ParakeetEncoderModelOutput,
    ParakeetForCTC,
    ParakeetPreTrainedModel,
)
from ..t5.tokenization_t5 import T5Tokenizer


# Module-level logger named after this module; used by `LasrProcessor.__call__` for its sampling-rate warning.
logger = logging.get_logger(__name__)


class LasrTokenizer(T5Tokenizer, TokenizersBackend):
    # Unigram tokenizer for LASR. `__init__` assembles the `tokenizers` pipeline by hand and `_decode` applies
    # CTC-style post-processing (collapse repeated ids, drop the pad/blank id) before delegating to the backend.

    def __init__(
        self,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        _spm_precompiled_charsmap=None,
        extra_ids=100,
        additional_special_tokens=None,
        vocab=None,
        vocab_file=None,
        **kwargs,
    ):
        """
        Build the Unigram model, pre-tokenizer, decoder and post-processor.

        Args:
            eos_token: End-of-sequence token; appended to every encoded sequence by the post-processor.
            unk_token: Unknown token.
            pad_token: Padding token; `_decode` also treats its id as the CTC blank.
            _spm_precompiled_charsmap: If not `None`, installed as a `normalizers.Precompiled` normalizer.
            extra_ids: Number of `<extra_id_{i}>` sentinel tokens.
            additional_special_tokens: Optional extra special tokens. If none of them contains `<extra_id_`, the
                sentinel tokens are appended to a copy (the caller's sequence is left untouched).
            vocab: Optional list of `(token, score)` pairs for the Unigram model. A minimal default is built when
                `None`.
            vocab_file: Accepted for signature compatibility; not read by this constructor.
            **kwargs: Forwarded to `TokenizersBackend.__init__`.

        Raises:
            ValueError: If `additional_special_tokens` already contains `<extra_id_*>` entries whose count differs
                from a positive `extra_ids`.
        """
        self._extra_ids = extra_ids

        # Handle extra_ids and additional_special_tokens
        if additional_special_tokens is not None:
            extra_tokens = [x for x in additional_special_tokens if "<extra_id_" in str(x)]
            if len(extra_tokens) < 1:
                # Build a new list instead of `+=` so the caller's list is never mutated in place
                # (this also accepts tuples).
                additional_special_tokens = [
                    *additional_special_tokens,
                    *(f"<extra_id_{i}>" for i in range(extra_ids)),
                ]
            elif extra_ids > 0 and extra_ids != len(extra_tokens):
                raise ValueError(
                    f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
                    " provided to LasrTokenizer. In this case the additional_special_tokens must include the extra_ids"
                    " tokens"
                )
        else:
            extra_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
            additional_special_tokens = extra_tokens

        # LASR vocab structure: <pad>=0, </s>=1, <unk>=2, then regular vocab, then extra_ids in reverse
        if vocab is not None:
            self._vocab_scores = vocab
        else:
            self._vocab_scores = [
                (str(pad_token), 0.0),
                (str(eos_token), 0.0),
                (str(unk_token), 0.0),
                ("▁", -2.0),  # Space token
            ]
            for i in range(extra_ids - 1, -1, -1):
                self._vocab_scores.append((f"<extra_id_{i}>", 0.0))
        # NOTE(review): the layout comment above places <unk> at index 2, yet `unk_id=3` is passed here (index 3 of
        # the default vocab is "▁"). Left unchanged — confirm against the released vocabulary before touching it.
        self._tokenizer = Tokenizer(
            Unigram(
                self._vocab_scores,
                unk_id=3,
                byte_fallback=False,
            )
        )

        if _spm_precompiled_charsmap is not None:
            self._tokenizer.normalizer = normalizers.Precompiled(_spm_precompiled_charsmap)

        self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.WhitespaceSplit(),
                pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True),
            ]
        )
        self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)

        # Call the backend initializer explicitly (bypassing `T5Tokenizer.__init__`). `self` is passed explicitly,
        # as required for a call through the class and as already done for `TokenizersBackend._decode` below.
        TokenizersBackend.__init__(
            self,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            extra_ids=extra_ids,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

        # Installed after the backend init because it needs `self.eos_token_id`.
        self._tokenizer.post_processor = processors.TemplateProcessing(
            single=["$A", "</s>"],
            pair=["$A", "</s>", "$B", "</s>"],
            special_tokens=[
                ("</s>", self.eos_token_id),
            ],
        )

    def _decode(
        self,
        token_ids: int | list[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool | None = None,
        group_tokens: bool = True,
        **kwargs,
    ) -> str:
        """
        Decode ids to text with CTC-style post-processing.

        Args:
            token_ids: A single id or a list of ids.
            skip_special_tokens: Forwarded to `TokenizersBackend._decode`.
            clean_up_tokenization_spaces: Forwarded to `TokenizersBackend._decode`.
            group_tokens: If `True`, runs of consecutive identical ids are collapsed to one id first.
            **kwargs: Forwarded to `TokenizersBackend._decode`.

        Returns:
            The decoded string.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        if group_tokens:
            # Collapse repeats BEFORE removing blanks, so a blank between two equal ids keeps them distinct.
            token_ids = [token_group[0] for token_group in itertools.groupby(token_ids)]

        # for CTC we filter out the blank token, which is the pad token
        token_ids = [token for token in token_ids if token != self.pad_token_id]

        return TokenizersBackend._decode(
            self,
            token_ids=token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )


class LasrProcessorKwargs(ProcessingKwargs, total=False):
    # Default keyword arguments, grouped by destination. `LasrProcessor.__call__` forwards the merged
    # "audio_kwargs" to the feature extractor and the merged "text_kwargs" to the tokenizer.
    _defaults = {
        "audio_kwargs": {
            # Rate (Hz) that a caller-supplied `sampling_rate` is validated against in `LasrProcessor.__call__`.
            "sampling_rate": 16000,
            "padding": "longest",
            "return_attention_mask": True,
        },
        "text_kwargs": {
            "padding": True,
            "padding_side": "right",
            "add_special_tokens": False,
        },
        # NOTE(review): presumably merged into both groups by `_merge_kwargs` — confirm in `ProcessorMixin`.
        "common_kwargs": {"return_tensors": "pt"},
    }


@auto_docstring
class LasrProcessor(ProcessorMixin):
    def __init__(self, feature_extractor, tokenizer):
        super().__init__(feature_extractor, tokenizer)

    @auto_docstring
    def __call__(
        self,
        audio: AudioInput,
        text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
        sampling_rate: int | None = None,
        **kwargs: Unpack[LasrProcessorKwargs],
    ):
        r"""
        sampling_rate (`int`, *optional*):
            The sampling rate of the input audio in Hz. This should match the sampling rate expected by the feature
            extractor (defaults to 16000 Hz). If provided, it will be validated against the processor's expected
            sampling rate, and an error will be raised if they don't match. If not provided, a warning will be
            issued and the default sampling rate will be assumed.
        """
        audio = make_list_of_audio(audio)

        output_kwargs = self._merge_kwargs(
            LasrProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # A declared rate must equal the rate the feature extractor will run with; an undeclared rate only warns.
        if sampling_rate is not None:
            if sampling_rate != output_kwargs["audio_kwargs"]["sampling_rate"]:
                raise ValueError(
                    f"The sampling rate of the audio ({sampling_rate}) does not match the sampling rate of the processor ({output_kwargs['audio_kwargs']['sampling_rate']}). Please provide resampled the audio to the expected sampling rate."
                )
        else:
            logger.warning_once(
                f"You've provided audio without specifying the sampling rate. It will be assumed to be {output_kwargs['audio_kwargs']['sampling_rate']}, which can result in silent errors."
            )

        if audio is not None:
            features = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])

        # Audio-only call: return the extracted features as they are.
        if text is None:
            return features

        # With transcripts, the tokenized ids are attached to the audio features under "labels".
        tokenized = self.tokenizer(text, **output_kwargs["text_kwargs"])
        features["labels"] = tokenized["input_ids"]
        return features

    @property
    def model_input_names(self):
        # The feature extractor's input names, extended with the "labels" key added by `__call__`.
        return self.feature_extractor.model_input_names + ["labels"]


@auto_docstring(checkpoint="google/medasr")
@strict
class LasrEncoderConfig(ParakeetEncoderConfig):
    r"""
    convolution_bias (`bool`, *optional*, defaults to `False`):
        Whether to use bias in convolutions of the conformer's convolution module.
    conv_kernel_size (`int`, *optional*, defaults to 32):
        The kernel size of the convolution layers in the Conformer block.
    subsampling_conv_channels (`int`, *optional*, defaults to 256):
        The number of channels in the subsampling convolution layers.
    subsampling_conv_kernel_size (`int`, *optional*, defaults to 5):
        The kernel size of the subsampling convolution layers.
    subsampling_conv_stride (`int`, *optional*, defaults to 2):
        The stride of the subsampling convolution layers.
    dropout_positions (`float`, *optional*, defaults to 0.0):
        The dropout ratio for the positions in the input sequence.
    feed_forward_residual_weights (`tuple[float, float]`, *optional*, defaults to `(1.5, 0.5)`):
        The residual weights for the feed forward layers, as `(weight of the residual, weight of the layer output)`.
    conv_residual_weights (`tuple[float, float]`, *optional*, defaults to `(2.0, 1.0)`):
        The residual weights for the convolution layers, as `(weight of the residual, weight of the layer output)`.
    batch_norm_momentum (`float`, *optional*, defaults to 0.01):
        The momentum for the batch normalization layers.

    Example:
    ```python
    >>> from transformers import LasrEncoder, LasrEncoderConfig

    >>> # Initializing a `LasrEncoder` configuration
    >>> configuration = LasrEncoderConfig()

    >>> # Initializing a model from the configuration
    >>> model = LasrEncoder(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```

    This configuration class is based on the LasrEncoder architecture from Google Health AI. You can find more details
    and pre-trained models at [google/medasr](https://huggingface.co/google/medasr).
    """

    hidden_size: int = 512
    num_hidden_layers: int = 17
    intermediate_size: int = 2048
    attention_bias: bool = False
    convolution_bias: bool = False
    conv_kernel_size: int = 32
    subsampling_conv_kernel_size: int = 5
    num_mel_bins: int = 128
    max_position_embeddings: int = 10000
    layer_norm_eps: float = 1e-6
    feed_forward_residual_weights: list[float] | tuple[float, ...] = (1.5, 0.5)
    conv_residual_weights: list[float] | tuple[float, ...] = (2.0, 1.0)
    batch_norm_momentum: float = 0.01
    rope_parameters: dict | None = None

    # NOTE(review): assigning `AttributeError()` presumably marks these inherited fields as removed for the
    # modular converter — confirm against the converter's documentation.
    subsampling_factor = AttributeError()
    scale_input = AttributeError()


@auto_docstring(checkpoint="google/medasr")
@strict
class LasrCTCConfig(ParakeetCTCConfig):
    r"""
    ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`):
        Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
        instance of [`LasrForCTC`].
    ctc_zero_infinity (`bool`, *optional*, defaults to `True`):
        Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
        occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
        of [`LasrForCTC`].

    Example:
    ```python
    >>> from transformers import LasrForCTC, LasrCTCConfig
    >>> # Initializing a Lasr configuration
    >>> configuration = LasrCTCConfig()
    >>> # Initializing a model from the configuration
    >>> model = LasrForCTC(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    This configuration class is based on the Lasr CTC architecture from Google Health AI. You can find more details
    and pre-trained models at [google/medasr](https://huggingface.co/google/medasr).
    """

    vocab_size: int = 512
    # `LasrTokenizer._decode` filters out the pad token id as the CTC blank.
    pad_token_id: int = 0

    @property
    def inputs_to_logits_ratio(self):
        """Return the square of the encoder's subsampling stride.

        `LasrEncoderSubsampling` applies two convolutions that each use `subsampling_conv_stride`, hence the
        exponent of 2.
        """
        return self.encoder_config.subsampling_conv_stride**2


class LasrEncoderSubsampling(nn.Module):
    """Project input features to the hidden size and shorten dim 1 with two strided 1D convolutions."""

    def __init__(self, config: LasrEncoderConfig):
        super().__init__()
        hidden_size = config.hidden_size
        conv_channels = config.subsampling_conv_channels
        # Both convolutions share the same kernel size and stride.
        conv_options = {
            "kernel_size": config.subsampling_conv_kernel_size,
            "stride": config.subsampling_conv_stride,
        }

        self.dense_0 = nn.Linear(config.num_mel_bins, hidden_size)
        self.conv_0 = nn.Conv1d(hidden_size, hidden_size, **conv_options)
        self.conv_1 = nn.Conv1d(hidden_size, conv_channels, **conv_options)
        self.dense_1 = nn.Linear(conv_channels, hidden_size)
        self.act_fn = nn.ReLU()

    def forward(self, input_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input_features: Tensor whose last dimension has size `config.num_mel_bins`
                (assumed layout `(batch, time, num_mel_bins)` — TODO confirm against the feature extractor).

        Returns:
            Tensor with last dimension `config.hidden_size` and a shortened dim 1.
        """
        projected = self.act_fn(self.dense_0(input_features))

        # Conv1d convolves over the last dimension, so dims 1 and 2 are swapped around the conv stack.
        conv_states = projected.transpose(1, 2)
        for conv in (self.conv_0, self.conv_1):
            conv_states = self.act_fn(conv(conv_states))

        return self.dense_1(conv_states.transpose(1, 2))


# Rotary embedding used by `LasrEncoder`; reuses `LlamaRotaryEmbedding` without any change.
class LasrEncoderRotaryEmbedding(LlamaRotaryEmbedding): ...


class LasrEncoderAttention(LlamaAttention):
    def __init__(self, config: LasrEncoderConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # The encoder attends in both directions.
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Self-attention with rotary position embeddings applied to queries and keys.

        Args:
            hidden_states: Input tensor; its last dimension is projected by `q_proj`/`k_proj`/`v_proj`.
            position_embeddings: `(cos, sin)` pair passed to `apply_rotary_pos_emb`.
            attention_mask: Mask forwarded unchanged to the attention implementation.
            **kwargs: Extra arguments forwarded to the attention implementation.

        Returns:
            `(attn_output, attn_weights)` where `attn_output` has gone through `o_proj`.
        """
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, self.head_dim)

        # Project, split the last dimension into heads of size `head_dim`, then swap dims 1 and 2.
        query_states, key_states, value_states = (
            projection(hidden_states).view(per_head_shape).transpose(1, 2)
            for projection in (self.q_proj, self.k_proj, self.v_proj)
        )

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Use the configured attention implementation; eager attention is the fallback.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge the heads back into a single last dimension before the output projection.
        merged_heads = attn_output.reshape(*leading_shape, -1).contiguous()
        return self.o_proj(merged_heads), attn_weights


class LasrEncoderConvolutionModule(ParakeetEncoderConvolutionModule):
    # Parakeet convolution module with two overrides applied after the parent constructor runs:
    # the `padding` attribute and the normalization layer.
    def __init__(self, config: LasrEncoderConfig, module_config=None):
        super().__init__(config, module_config)
        # NOTE(review): how `self.padding` is consumed lives in the parent class — confirm there.
        self.padding = "same"
        # Batch norm over `hidden_size` channels with the momentum taken from the config.
        self.norm = nn.BatchNorm1d(config.hidden_size, momentum=config.batch_norm_momentum)


class LasrEncoderBlock(ParakeetEncoderBlock):
    def __init__(self, config: LasrEncoderConfig, layer_idx: int):
        super().__init__(config, layer_idx)

        # Each pair is (weight of the residual stream, weight of the sub-layer output).
        self.feed_forward_residual_weights = config.feed_forward_residual_weights
        self.conv_residual_weights = config.conv_residual_weights

        # Bias-free layer norms: one in front of each sub-layer plus one on the block output.
        self.norm_feed_forward1 = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
        self.norm_self_att = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
        self.norm_conv = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
        self.norm_feed_forward2 = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
        self.norm_out = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Run feed-forward -> self-attention -> convolution -> feed-forward, then the output norm.

        The feed-forward and convolution sub-layers are combined with the stream through the configured
        residual weights; self-attention uses a plain (unweighted) residual connection.
        """
        ff_stream_weight = self.feed_forward_residual_weights[0]
        ff_branch_weight = self.feed_forward_residual_weights[1]
        conv_stream_weight = self.conv_residual_weights[0]
        conv_branch_weight = self.conv_residual_weights[1]

        # First feed-forward sub-layer.
        ff1_output = self.feed_forward1(self.norm_feed_forward1(hidden_states))
        hidden_states = ff_stream_weight * hidden_states + ff_branch_weight * ff1_output

        # Self-attention sub-layer (attention weights are discarded).
        attn_output, _ = self.self_attn(
            hidden_states=self.norm_self_att(hidden_states),
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Convolution sub-layer.
        conv_output = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
        hidden_states = conv_stream_weight * hidden_states + conv_branch_weight * conv_output

        # Second feed-forward sub-layer.
        ff2_output = self.feed_forward2(self.norm_feed_forward2(hidden_states))
        hidden_states = ff_stream_weight * hidden_states + ff_branch_weight * ff2_output

        return self.norm_out(hidden_states)


class LasrPreTrainedModel(ParakeetPreTrainedModel):
    # padding is incompatible with flex attention as the resulting mask cannot be used to apply padding
    _supports_flex_attn = False

    def _init_weights(self, module):
        """Initialize `module` with `PreTrainedModel._init_weights`, skipping any `ParakeetPreTrainedModel` override."""
        # `self` must be passed explicitly when calling through the class; without it `module` would be bound to
        # `self` and the call would fail with a missing-argument TypeError.
        PreTrainedModel._init_weights(self, module)

    def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
        """
        Compute sequence lengths after the two strided convolutions of the subsampler.

        Args:
            input_lengths: Lengths before subsampling.

        Returns:
            Lengths after applying `(length - kernel_size) // stride + 1` once per convolution.
        """
        # A CTC config nests the encoder config; an encoder config is used directly.
        encoder_config = self.config.encoder_config if isinstance(self.config, LasrCTCConfig) else self.config
        kernel_size = encoder_config.subsampling_conv_kernel_size
        stride = encoder_config.subsampling_conv_stride

        # `LasrEncoderSubsampling` stacks two convolutions (`conv_0`, `conv_1`) with these settings.
        num_layers = 2
        for _ in range(num_layers):
            input_lengths = (input_lengths - kernel_size) // stride + 1

        return input_lengths


# Output container returned by `LasrEncoder.forward`; identical to the Parakeet encoder output.
class LasrEncoderModelOutput(ParakeetEncoderModelOutput):
    pass


@auto_docstring(
    custom_intro="""
    The LasrEncoder model, based on the [Conformer architecture](https://arxiv.org/abs/2005.08100).
    """
)
class LasrEncoder(LasrPreTrainedModel):
    config: LasrEncoderConfig
    base_model_prefix = "encoder"

    def __init__(self, config: LasrEncoderConfig):
        super().__init__(config)
        self.gradient_checkpointing = False

        # Dropout on the subsampled features, dropout on the rotary cos/sin tables, and per-layer skip probability.
        self.dropout = config.dropout
        self.dropout_positions = config.dropout_positions
        self.layerdrop = config.layerdrop

        self.subsampler = LasrEncoderSubsampling(config)
        self.rotary_emb = LasrEncoderRotaryEmbedding(config)
        self.layers = nn.ModuleList(
            [LasrEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.out_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)

        self.post_init()

    @auto_docstring
    @merge_with_config_defaults
    @capture_outputs
    @can_return_tuple
    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        output_attention_mask: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> LasrEncoderModelOutput:
        r"""
        output_attention_mask (`bool`, *optional*):
            Whether to return the output attention mask.

        Example:

        ```python
        >>> from transformers import AutoProcessor, LasrEncoder
        >>> from datasets import load_dataset, Audio

        >>> model_id = "google/medasr"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> encoder = LasrEncoder.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"])
        >>> encoder_outputs = encoder(**inputs)

        >>> print(encoder_outputs.last_hidden_state.shape)
        ```
        """

        hidden_states = self.subsampler(input_features)
        # Rotary tables for positions 0..T-1, where T is the subsampled sequence length.
        cos, sin = self.rotary_emb(
            hidden_states, torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        cos = nn.functional.dropout(cos, p=self.dropout_positions, training=self.training)
        sin = nn.functional.dropout(sin, p=self.dropout_positions, training=self.training)

        # Bring the input-level mask down to the subsampled length; this 2D mask is what may be returned.
        output_mask = None
        if attention_mask is not None:
            output_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
            attention_mask = output_mask

        attention_mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
        )

        for encoder_layer in self.layers:
            # LayerDrop (see https://huggingface.co/papers/1909.11556 for description): while training, skip the
            # layer with probability `self.layerdrop`. The random draw only happens in training mode.
            if self.training and torch.rand([]) < self.layerdrop:
                continue

            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_embeddings=(cos, sin),
                **kwargs,
            )

        hidden_states = self.out_norm(hidden_states)

        return LasrEncoderModelOutput(
            last_hidden_state=hidden_states,
            attention_mask=output_mask.int() if output_attention_mask and output_mask is not None else None,
        )


class LasrForCTC(ParakeetForCTC):
    # `self` is declared explicitly: without it the method cannot be called on an instance and the
    # zero-argument `super()` below has no first argument to resolve against.
    def generate(self, **super_kwargs):
        r"""
        Example:

        ```python
        >>> from transformers import AutoProcessor, LasrForCTC
        >>> from datasets import load_dataset, Audio

        >>> model_id = "google/medasr"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> model = LasrForCTC.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
        >>> predicted_ids = model.generate(**inputs)
        >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

        >>> print(transcription)
        ```
        """
        return super().generate(**super_kwargs)


# Public names exported by this module.
__all__ = [
    "LasrForCTC",
    "LasrEncoder",
    "LasrPreTrainedModel",
    "LasrProcessor",
    "LasrEncoderConfig",
    "LasrCTCConfig",
    "LasrTokenizer",
]
