File size: 5,914 Bytes
71eeb37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import copy
import math
from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers.models.m2m_100.modeling_m2m_100 import (
    M2M100Config,
    M2M100ScaledWordEmbedding,
    M2M100Decoder,
    M2M100PreTrainedModel,
    GenerationMixin,
    Seq2SeqLMOutput,
    BaseModelOutput,
    shift_tokens_right,
    Cache,
    CrossEntropyLoss,
)

# override model type to register AutoModels
class SonarDecoderConfig(M2M100Config):
    """Configuration for ``SonarDecoderModel``.

    Inherits every field from ``M2M100Config`` unchanged; only ``model_type``
    is overridden so the model is registered under its own name.
    """

    model_type = "SonarDecoderModel"


class SonarDecoderModel(M2M100PreTrainedModel, GenerationMixin):
    """Encoder-less M2M100 decoder with a language-modelling head.

    The model owns no encoder: every forward call must receive the encoder
    states in ``encoder_outputs`` (presumably SONAR sentence embeddings computed
    elsewhere -- confirm against callers), and the decoder cross-attends to them.
    When ``config.tie_word_embeddings`` is set, the decoder token embedding and
    ``lm_head`` share the weight of ``self.shared``.
    """

    # override config class to register AutoModels
    config_class = SonarDecoderConfig
    # Parameters tied to `shared.weight` (tied parameter name -> source name).
    _tied_weights_keys = {
        "decoder.embed_tokens.weight": "shared.weight",
        "lm_head.weight": "shared.weight",
    }
    # This model has no encoder, so checkpoint keys matching "encoder" are
    # ignored on load instead of being reported as unexpected.
    _keys_to_ignore_on_load_unexpected = [r"encoder"]

    def __init__(self, config: M2M100Config):
        super().__init__(config)
        # M2M100Decoder builds its own M2M100ScaledWordEmbedding (scaled by
        # sqrt(d_model) when `config.scale_embedding` is set, padded at
        # `pad_token_id`). Give `shared` the same scale and padding index, as
        # M2M100Model does, so that `set_input_embeddings` -- which installs
        # `shared` as the decoder's embedding -- does not silently drop the scale.
        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
        self.shared = M2M100ScaledWordEmbedding(
            config.vocab_size, config.d_model, config.pad_token_id, embed_scale=embed_scale
        )

        # The decoder gets its own copy of the config so these overrides do not
        # leak into `self.config`.
        decoder_config = copy.deepcopy(config)
        decoder_config.use_cache = False
        decoder_config.is_encoder_decoder = False
        self.decoder = M2M100Decoder(decoder_config)
        self.lm_head = nn.Linear(config.d_model, self.shared.num_embeddings, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the shared token embedding module."""
        return self.shared

    def set_input_embeddings(self, value):
        """Replace the shared token embedding and install it in the decoder."""
        self.shared = value
        self.decoder.embed_tokens = self.shared

    def _tie_weights(self):
        """Tie the decoder embedding and ``lm_head`` to ``self.shared``.

        Only done when ``config.tie_word_embeddings`` is true.
        """
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
            # in SONAR models, input and output projections are tied (ideally, this should be configurable)
            self._tie_or_clone_weights(self.lm_head, self.shared)

    def get_decoder(self):
        """Return the underlying ``M2M100Decoder``."""
        return self.decoder

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.LongTensor | None = None,
        encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor] | Seq2SeqLMOutput:
        """Decode from externally supplied encoder states.

        Args:
            input_ids: Accepted for seq2seq API compatibility; ignored.
            attention_mask: Mask over the encoder states; forwarded to the
                decoder as ``encoder_attention_mask``.
            decoder_input_ids: Decoder token ids. When ``labels`` is given and
                neither these nor ``decoder_inputs_embeds`` are, they are built
                with ``shift_tokens_right(labels, pad_token_id, decoder_start_token_id)``.
            decoder_attention_mask: Mask over the decoder tokens.
            encoder_outputs: Required. ``encoder_outputs[0]`` is used as the
                encoder hidden states for cross-attention. A plain tuple is
                wrapped in ``BaseModelOutput`` when ``return_dict`` is true.
            past_key_values: Forwarded to the decoder.
            inputs_embeds: Accepted for seq2seq API compatibility; ignored.
            decoder_inputs_embeds: Decoder input embeddings, used in place of
                ``decoder_input_ids``.
            labels: Target token ids. When given, a cross-entropy loss over all
                positions is computed (``CrossEntropyLoss`` defaults, so label
                value -100 is ignored).
            use_cache, output_attentions, output_hidden_states, cache_position:
                Forwarded to the decoder.
            return_dict: Return a ``Seq2SeqLMOutput`` instead of a tuple;
                defaults to ``config.use_return_dict``.
            **kwargs: Ignored.

        Returns:
            ``Seq2SeqLMOutput``, or the tuple ``(loss, logits, *decoder_outputs[1:])``
            (``loss`` omitted when no ``labels``) when ``return_dict`` is false.

        Raises:
            ValueError: If ``encoder_outputs`` is None.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Teacher forcing: derive decoder inputs from the labels only when the
        # caller gave neither ids nor embeddings (mirrors
        # M2M100ForConditionalGeneration; the decoder cannot take both at once).
        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        if encoder_outputs is None:
            raise ValueError(
                f"{type(self).__name__} expects the `encoder_outputs` to be always present."
            )

        if return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consist of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        lm_logits = self.lm_head(decoder_outputs[0])

        masked_lm_loss = None
        if labels is not None:
            # move labels to the correct device to enable PP
            labels = labels.to(lm_logits.device)
            loss_fct = CrossEntropyLoss()
            # Flatten using the actual logits width; reshape also tolerates
            # non-contiguous labels, unlike view.
            masked_lm_loss = loss_fct(lm_logits.reshape(-1, lm_logits.size(-1)), labels.reshape(-1))

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder a legacy tuple-of-tuples cache for beam search.

        Every tensor of every layer is indexed with ``beam_idx`` along dim 0
        (presumably the batch*beam dimension); ``beam_idx`` is moved to each
        tensor's device first.
        """
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past

    @classmethod
    def _can_set_experts_implementation(cls) -> bool:
        """Return False: this model does not allow an experts implementation to be set.

        Presumably this opts out of transformers' mixture-of-experts kernel
        selection, since the model has no expert layers -- confirm.
        """
        return False