# Hugging Face Hub page header (not part of the module): uploaded by cointegrated,
# commit "Upload model" (71eeb37, verified).
import copy
from typing import Optional, Tuple, Union
import torch
from torch import nn
from transformers.models.m2m_100.modeling_m2m_100 import (
M2M100Config,
M2M100ScaledWordEmbedding,
M2M100Decoder,
M2M100PreTrainedModel,
GenerationMixin,
Seq2SeqLMOutput,
BaseModelOutput,
shift_tokens_right,
Cache,
CrossEntropyLoss,
)
# Override the model type so that this config can be registered with the AutoModel classes.
class SonarDecoderConfig(M2M100Config):
    """M2M100 configuration whose only change is a distinct ``model_type`` string."""

    model_type = "SonarDecoderModel"
class SonarDecoderModel(M2M100PreTrainedModel, GenerationMixin):
    """M2M100 decoder with a language-modeling head and no encoder of its own.

    The model holds a shared embedding table (``shared``), an ``M2M100Decoder``
    and a bias-free linear ``lm_head``.  The encoder states that the decoder
    cross-attends to must always be supplied by the caller through
    ``encoder_outputs``.
    """

    # Override the config class so that the model can be registered with AutoModels.
    config_class = SonarDecoderConfig
    # Both the decoder input embedding and the LM head are declared as tied to `shared.weight`.
    _tied_weights_keys = {
        "decoder.embed_tokens.weight": "shared.weight",
        "lm_head.weight": "shared.weight",
    }
    # Unexpected checkpoint keys matching "encoder" are ignored on load
    # (this class builds no encoder module).
    _keys_to_ignore_on_load_unexpected = [r"encoder"]

    def __init__(self, config: M2M100Config):
        """Build the shared embedding, the decoder and the LM head from ``config``."""
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # The decoder gets its own copy of the config, with caching switched off
        # and the encoder-decoder flag cleared; the model-level config is untouched.
        decoder_config = copy.deepcopy(config)
        decoder_config.use_cache = False
        decoder_config.is_encoder_decoder = False
        self.decoder = M2M100Decoder(decoder_config)

        self.lm_head = nn.Linear(config.d_model, self.shared.num_embeddings, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the shared embedding module."""
        return self.shared

    def set_input_embeddings(self, value):
        """Install ``value`` as the shared embedding and hook it into the decoder.

        The decoder's ``embed_tokens`` is an ``M2M100ScaledWordEmbedding`` that
        multiplies its output by ``embed_scale``.  Replacing it with a plain
        ``nn.Embedding`` (which is what ``self.shared`` is by default) would
        silently drop that scaling, so in that case the existing scaled module
        is kept and only re-pointed at the new weight.
        """
        self.shared = value
        current = self.decoder.embed_tokens
        keeps_scaling = isinstance(value, M2M100ScaledWordEmbedding)
        if keeps_scaling or not isinstance(current, M2M100ScaledWordEmbedding):
            # Either the new module scales by itself, or there is no scaled
            # module to preserve: plain replacement is correct.
            self.decoder.embed_tokens = value
            return
        # Share the parameter object (keeps the weights tied) and keep the
        # bookkeeping attributes of the embedding in sync with the new table.
        current.weight = value.weight
        current.num_embeddings = value.num_embeddings
        current.embedding_dim = value.embedding_dim

    def _tie_weights(self):
        """Tie the decoder input embedding and the LM head to ``shared``.

        NOTE(review): both ties are applied only when
        ``config.tie_word_embeddings`` is set; confirm that the LM-head tie is
        meant to sit under the same condition.
        """
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
            # in SONAR models, input and output projections are tied (ideally, this should be configurable)
            self._tie_or_clone_weights(self.lm_head, self.shared)

    def get_decoder(self):
        """Return the underlying ``M2M100Decoder``."""
        return self.decoder

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.LongTensor | None = None,
        encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor] | Seq2SeqLMOutput:
        """Run the decoder over externally supplied encoder states and project to logits.

        Args:
            input_ids: Accepted for interface compatibility; not used here.
            attention_mask: Passed to the decoder as ``encoder_attention_mask``.
            decoder_input_ids: Decoder token ids.  When omitted and ``labels``
                is given, they are derived from ``labels`` with
                ``shift_tokens_right``.
            decoder_attention_mask: Passed to the decoder as ``attention_mask``.
            encoder_outputs: Required.  Its first element is used as the
                encoder hidden states for cross-attention.
            past_key_values: Forwarded to the decoder.
            inputs_embeds: Accepted for interface compatibility; not used here.
            decoder_inputs_embeds: Forwarded to the decoder as ``inputs_embeds``.
            labels: Target token ids; when given, a cross-entropy loss over the
                flattened logits is computed.
            use_cache, output_attentions, output_hidden_states, cache_position:
                Forwarded to the decoder unchanged.
            return_dict: Whether to return a ``Seq2SeqLMOutput``; defaults to
                ``config.use_return_dict``.
            **kwargs: Accepted and ignored.

        Returns:
            A ``Seq2SeqLMOutput`` when ``return_dict`` is true, otherwise a
            tuple ``(loss?, logits, *decoder_outputs[1:])`` where the loss is
            present only if ``labels`` was given.

        Raises:
            ValueError: If ``encoder_outputs`` is ``None``.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if labels is not None and decoder_input_ids is None:
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        if encoder_outputs is None:
            raise ValueError("SonarDecoderModel expects the `encoder_outputs` to be always present.")

        # Normalize a plain tuple into a BaseModelOutput so that the attribute
        # access in the final return works.
        if return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            n_items = len(encoder_outputs)
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if n_items > 1 else None,
                attentions=encoder_outputs[2] if n_items > 2 else None,
            )

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        lm_logits = self.lm_head(decoder_outputs[0])

        masked_lm_loss = None
        if labels is not None:
            # move labels to the correct device to enable PP
            labels = labels.to(lm_logits.device)
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:]
            if masked_lm_loss is None:
                return output
            return (masked_lm_loss,) + output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Re-index every cached state along dim 0 with ``beam_idx``.

        Returns a tuple with one inner tuple per element of
        ``past_key_values``; ``beam_idx`` is moved to each state's device
        before ``index_select`` is applied.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )

    @classmethod
    def _can_set_experts_implementation(cls) -> bool:
        """Always return ``False``."""
        return False