Feature Extraction
Transformers
Safetensors
English
usad
automatic-speech-recognition
audio-classification
audio
speech
music
custom_code
Instructions to use MIT-SLS/USAD-Base with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use MIT-SLS/USAD-Base with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="MIT-SLS/USAD-Base", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("MIT-SLS/USAD-Base", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| # Copyright (c) 2021, Soohwan Kim. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import contextlib | |
| import math | |
| from collections import defaultdict | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
class SamePad(nn.Module):
    """Drop trailing frames on the last axis of a 3-D tensor.

    The number of frames dropped is ``kernel_size - 1`` in causal mode,
    otherwise 1 for an even ``kernel_size`` and 0 for an odd one.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            trailing = kernel_size - 1
        elif kernel_size % 2 == 0:
            trailing = 1
        else:
            trailing = 0
        self.remove = trailing

    def forward(self, x):
        """Return ``x`` without its last ``self.remove`` entries on axis 2."""
        if self.remove <= 0:
            return x
        return x[:, :, : -self.remove]
class TransposeLast(nn.Module):
    """Swap a chosen axis (default: second to last) with the last axis.

    If ``deconstruct_idx`` is given, the input is first indexed with it
    (e.g. to pick one element out of a tuple) before transposing.
    """

    def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx
        self.tranpose_dim = tranpose_dim

    def forward(self, x):
        """Return the (optionally indexed) input with the two axes swapped."""
        picked = x if self.deconstruct_idx is None else x[self.deconstruct_idx]
        return picked.transpose(self.tranpose_dim, -1)
class Swish(nn.Module):
    """Swish activation: the input multiplied by its own sigmoid."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply ``x * sigmoid(x)`` element-wise."""
        gate = torch.sigmoid(inputs)
        return inputs * gate
class GLU(nn.Module):
    """Gated linear unit that splits the input in two along ``dim``.

    The first half is the value and the second half is the gate; the
    output is ``value * sigmoid(gate)``, so the size along ``dim`` halves.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return the gated first half of ``inputs``."""
        value_half, gate_half = torch.chunk(inputs, 2, dim=self.dim)
        return value_half * torch.sigmoid(gate_half)
class ResidualConnectionModule(nn.Module):
    """Wrap a module with a weighted residual connection.

    Computes ``module(x) * module_factor + x * input_factor``.
    """

    def __init__(
        self,
        module: nn.Module,
        module_factor: float = 1.0,
        input_factor: float = 1.0,
    ):
        super().__init__()
        self.module = module
        self.module_factor = module_factor
        self.input_factor = input_factor

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return the scaled module output plus the scaled input."""
        branch = self.module(inputs) * self.module_factor
        shortcut = inputs * self.input_factor
        return branch + shortcut
class Linear(nn.Module):
    """``nn.Linear`` wrapper with Xavier-uniform weights and a zero bias.

    The wrapped layer is stored as ``self.linear``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__()
        layer = nn.Linear(in_features, out_features, bias=bias)
        nn.init.xavier_uniform_(layer.weight)
        if bias:
            nn.init.zeros_(layer.bias)
        self.linear = layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped linear projection."""
        return self.linear(x)
class View(nn.Module):
    """Reshape a tensor with ``Tensor.view`` inside a module pipeline.

    When ``contiguous`` is true the tensor is made contiguous first.
    """

    def __init__(self, shape: tuple, contiguous: bool = False):
        super().__init__()
        self.shape = shape
        self.contiguous = contiguous

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` viewed with ``self.shape``."""
        source = x.contiguous() if self.contiguous else x
        return source.view(*self.shape)
class Transpose(nn.Module):
    """Swap the two axes named in ``shape`` inside a module pipeline."""

    def __init__(self, shape: tuple):
        super().__init__()
        self.shape = shape

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with the configured pair of axes transposed."""
        return torch.transpose(x, *self.shape)
class FeedForwardModule(nn.Module):
    """Feed-forward sub-layer.

    Stages, in order: LayerNorm -> Linear (expand by ``expansion_factor``)
    -> Swish -> Dropout -> Linear (project back) -> Dropout.
    """

    def __init__(
        self,
        encoder_dim: int = 512,
        expansion_factor: int = 4,
        dropout_p: float = 0.1,
    ) -> None:
        super().__init__()
        hidden_dim = encoder_dim * expansion_factor
        stages = [
            nn.LayerNorm(encoder_dim),
            Linear(encoder_dim, hidden_dim, bias=True),
            Swish(),
            nn.Dropout(p=dropout_p),
            Linear(hidden_dim, encoder_dim, bias=True),
            nn.Dropout(p=dropout_p),
        ]
        self.sequential = nn.Sequential(*stages)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run the feed-forward stack; the last dimension is unchanged."""
        return self.sequential(inputs)
class DepthwiseConv1d(nn.Module):
    """1-D convolution with ``groups == in_channels``.

    ``out_channels`` must be an integer multiple of ``in_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
    ) -> None:
        super().__init__()
        is_multiple = out_channels % in_channels == 0
        assert is_multiple, "out_channels should be constant multiple of in_channels"
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=bias,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply the grouped convolution to a (batch, channels, time) tensor."""
        return self.conv(inputs)
class PointwiseConv1d(nn.Module):
    """1-D convolution with kernel size 1 (a per-frame channel projection)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            1,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply the kernel-size-1 convolution to a (batch, channels, time) tensor."""
        return self.conv(inputs)
class ConformerConvModule(nn.Module):
    """Convolution sub-layer of a Conformer block.

    Stages, in order: LayerNorm -> transpose to (batch, channels, time) ->
    pointwise conv (expand) -> GLU -> depthwise conv -> BatchNorm1d ->
    Swish -> pointwise conv -> Dropout. ``forward`` transposes the result
    back to (batch, time, channels).

    ``kernel_size`` must be odd and ``expansion_factor`` must be 2.
    """

    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 31,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        super().__init__()
        kernel_is_odd = (kernel_size - 1) % 2 == 0
        assert kernel_is_odd, "kernel_size should be a odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"

        expanded_channels = in_channels * expansion_factor
        same_padding = (kernel_size - 1) // 2
        stages = [
            nn.LayerNorm(in_channels),
            Transpose(shape=(1, 2)),
            PointwiseConv1d(
                in_channels, expanded_channels, stride=1, padding=0, bias=True
            ),
            # GLU halves the channel axis back to in_channels.
            GLU(dim=1),
            DepthwiseConv1d(
                in_channels, in_channels, kernel_size, stride=1, padding=same_padding
            ),
            nn.BatchNorm1d(in_channels),
            Swish(),
            PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
            nn.Dropout(p=dropout_p),
        ]
        self.sequential = nn.Sequential(*stages)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run the conv stack and swap axes 1 and 2 of its output."""
        channels_first = self.sequential(inputs)
        return channels_first.transpose(1, 2)
class FramewiseConv2dSubampling(nn.Module):
    """Two-layer Conv2d front end that subsamples the time axis by 2 or 4.

    Both convolutions use stride 2 on the feature axis; the second one
    uses stride 2 on time only when ``subsample_rate`` is 4.
    """

    def __init__(self, out_channels: int, subsample_rate: int = 2) -> None:
        super().__init__()
        assert subsample_rate in {2, 4}, "subsample_rate should be 2 or 4"
        self.subsample_rate = subsample_rate

        if subsample_rate == 4:
            time_stride, time_padding = 2, 0
        else:
            time_stride, time_padding = 1, 1

        first_conv = nn.Conv2d(1, out_channels, kernel_size=3, stride=2)
        second_conv = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=(time_stride, 2),
            padding=(time_padding, 0),
        )
        self.cnn = nn.Sequential(first_conv, nn.ReLU(), second_conv, nn.ReLU())

    def forward(
        self, inputs: torch.Tensor, input_lengths: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Subsample ``inputs`` and compute the matching output lengths.

        Args:
            inputs: tensor of shape (B, T, C).
            input_lengths: per-sample lengths along T.

        Returns:
            ``(outputs, output_lengths)`` where ``outputs`` has shape
            (B, T', channels * C') and ``output_lengths`` is the subsampled
            version of ``input_lengths``.
        """
        # At rate 2 an even T gets one zero frame appended on the time axis.
        needs_pad = self.subsample_rate == 2 and inputs.shape[1] % 2 == 0
        if needs_pad:
            inputs = F.pad(inputs, (0, 0, 0, 1), "constant", 0)

        # (B, T, C) -> (B, 1, T, C) -> (B, channels, T', C')
        feature_maps = self.cnn(inputs.unsqueeze(1))
        n_batch, n_channels, n_frames, n_bins = feature_maps.size()

        # Fold channels and feature bins into a single per-frame vector.
        frames_first = feature_maps.permute(0, 2, 1, 3).contiguous()
        outputs = frames_first.view(n_batch, n_frames, n_channels * n_bins)

        if self.subsample_rate != 4:
            output_lengths = input_lengths >> 1
        else:
            output_lengths = (((input_lengths - 1) >> 1) - 1) >> 1
        return outputs, output_lengths
class PatchwiseConv2dSubampling(nn.Module):
    """Patch-embedding front end for (batch, time, freq) inputs.

    A strided Conv2d cuts the input into non-overlapping
    ``patch_size_time x patch_size_freq`` patches, two 3x3 convolutions
    refine the patch grid, and the grid is flattened into one token
    sequence.
    """

    def __init__(
        self,
        mel_dim: int,
        out_channels: int,
        patch_size_time: int = 16,
        patch_size_freq: int = 16,
    ) -> None:
        super().__init__()
        self.mel_dim = mel_dim
        self.patch_size_time = patch_size_time
        self.patch_size_freq = patch_size_freq

        patch_shape = (patch_size_time, patch_size_freq)
        self.proj = nn.Conv2d(
            1, out_channels, kernel_size=patch_shape, stride=patch_shape, padding=0
        )
        refine_layers = [
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        ]
        self.cnn = nn.Sequential(*refine_layers)

    def subsample_rate(self) -> int:
        """Return ``patch_size_time * patch_size_freq // mel_dim``."""
        patch_area = self.patch_size_time * self.patch_size_freq
        return patch_area // self.mel_dim

    def forward(
        self, inputs: torch.Tensor, input_lengths: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Embed patches and compute per-sample token counts.

        Args:
            inputs: tensor of shape (B, Time, Freq); Freq must equal ``mel_dim``.
            input_lengths: per-sample lengths along Time.

        Returns:
            ``(outputs, output_lengths)`` where ``outputs`` has shape
            (B, (Time // patch_size_time) * (Freq // patch_size_freq), channels).
        """
        freq_matches = inputs.shape[2] == self.mel_dim
        assert freq_matches, "inputs.shape[2] should be equal to mel_dim"

        # (B, Time, Freq) -> (B, 1, Time, Freq) -> (B, channels, time patches, freq patches)
        patch_grid = self.cnn(self.proj(inputs.unsqueeze(1)))
        # Flatten the patch grid and move channels last.
        outputs = patch_grid.flatten(2, 3).transpose(1, 2)

        tokens_per_time_patch = self.mel_dim // self.patch_size_freq
        output_lengths = input_lengths // self.patch_size_time * tokens_per_time_patch
        return outputs, output_lengths
class RelPositionalEncoding(nn.Module):
    """Sinusoidal relative positional encoding table.

    ``self.pe`` holds encodings for relative offsets ``L-1, ..., 0, ...,
    -(L-1)`` (shape ``(1, 2L-1, d_model)``) and is rebuilt whenever a
    longer input is seen. It is a plain attribute, not a registered
    buffer, so ``extend_pe`` moves it to the input's dtype and device.
    """

    def __init__(self, d_model: int, max_len: int = 10000) -> None:
        super().__init__()
        self.d_model = d_model
        self.pe = None
        # Build the initial table using a dummy input of length max_len.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor) -> None:
        """Ensure the table covers ``x.size(1)`` and matches ``x``'s dtype/device."""
        seq_len = x.size(1)
        if self.pe is not None and self.pe.size(1) >= seq_len * 2 - 1:
            # Table is large enough; only cast or move it if needed.
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return

        position = torch.arange(0, seq_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        angles = position * div_term

        pe_positive = torch.zeros(seq_len, self.d_model)
        pe_negative = torch.zeros(seq_len, self.d_model)
        pe_positive[:, 0::2] = torch.sin(angles)
        pe_positive[:, 1::2] = torch.cos(angles)
        pe_negative[:, 0::2] = torch.sin(-angles)
        pe_negative[:, 1::2] = torch.cos(-angles)

        # Positive offsets are reversed so offset 0 sits in the middle;
        # the negative half skips offset 0 so it is not duplicated.
        positive_half = torch.flip(pe_positive, [0]).unsqueeze(0)
        negative_half = pe_negative[1:].unsqueeze(0)
        table = torch.cat([positive_half, negative_half], dim=1)
        self.pe = table.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ``(1, 2T-1, d_model)`` slice of the table centred on offset 0.

        Args:
            x: tensor of shape (B, T, C); only T, dtype and device are used.
        """
        self.extend_pe(x)
        center = self.pe.size(1) // 2
        steps = x.size(1)
        return self.pe[:, center - steps + 1 : center + steps]
class RelativeMultiHeadAttention(nn.Module):
    """Multi-head attention with relative positional scores.

    The attention logits are the sum of a content term
    ``(query + u_bias) @ key^T`` and a position term
    ``(query + v_bias) @ pos^T`` (shifted to align relative offsets),
    divided by ``sqrt(d_head)``. ``u_bias`` and ``v_bias`` are learned
    per-head vectors.

    Args:
        d_model: model dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout_p: dropout probability applied to the attention weights.
    """

    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 16,
        dropout_p: float = 0.1,
    ):
        super().__init__()
        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = d_model
        self.d_head = d_model // num_heads
        self.num_heads = num_heads
        self.sqrt_dim = math.sqrt(self.d_head)

        self.query_proj = Linear(d_model, d_model)
        self.key_proj = Linear(d_model, d_model)
        self.value_proj = Linear(d_model, d_model)
        self.pos_proj = Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(p=dropout_p)
        # torch.empty replaces the legacy torch.Tensor(n, m) constructor;
        # both are uninitialised and overwritten by xavier_uniform_ below.
        self.u_bias = nn.Parameter(torch.empty(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.empty(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)
        torch.nn.init.xavier_uniform_(self.v_bias)

        self.out_proj = Linear(d_model, d_model)

    @staticmethod
    def _mask_fill_value(dtype: torch.dtype) -> float:
        """Return the logit written into masked positions for ``dtype``.

        -1e9 is kept wherever it is representable, so float32 results are
        unchanged. For narrower dtypes such as float16, -1e9 overflows and
        ``masked_fill_`` raises, so the dtype's most negative finite value
        is used instead.
        """
        return max(-1e9, torch.finfo(dtype).min)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_embedding: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute relative-position attention.

        Args:
            query, key, value: tensors of shape (B, T, d_model).
            pos_embedding: relative positional encodings of shape
                (B, L, d_model); ``_relative_shift`` keeps the first
                ``L // 2 + 1`` columns, so L is presumably ``2T - 1``.
            mask: optional boolean tensor, True at positions to suppress.
                It is unsqueezed on dim 1 and must then broadcast against
                the (B, num_heads, T, T) scores.

        Returns:
            ``(output, attn)``: the projected context of shape
            (B, T, d_model) and the attention weights (after dropout) of
            shape (B, num_heads, T, T).
        """
        batch_size = value.size(0)

        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = (
            self.key_proj(key)
            .view(batch_size, -1, self.num_heads, self.d_head)
            .permute(0, 2, 1, 3)
        )
        value = (
            self.value_proj(value)
            .view(batch_size, -1, self.num_heads, self.d_head)
            .permute(0, 2, 1, 3)
        )
        pos_embedding = self.pos_proj(pos_embedding).view(
            batch_size, -1, self.num_heads, self.d_head
        )

        content_score = torch.matmul(
            (query + self.u_bias).transpose(1, 2), key.transpose(2, 3)
        )
        pos_score = torch.matmul(
            (query + self.v_bias).transpose(1, 2),
            pos_embedding.permute(0, 2, 3, 1),
        )
        pos_score = self._relative_shift(pos_score)

        score = (content_score + pos_score) / self.sqrt_dim

        if mask is not None:
            mask = mask.unsqueeze(1)
            score.masked_fill_(mask, self._mask_fill_value(score.dtype))

        attn = F.softmax(score, -1)
        attn = self.dropout(attn)

        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(context), attn

    def _relative_shift(self, pos_score: torch.Tensor) -> torch.Tensor:
        """Shift position scores so columns line up with key positions.

        A zero column is prepended and the tensor is re-viewed with the
        last two axes swapped in size, which shifts every row by a
        different amount. The first row is then dropped and the first
        ``seq_length2 // 2 + 1`` columns are kept.
        """
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)

        padded_pos_score = padded_pos_score.view(
            batch_size, num_heads, seq_length2 + 1, seq_length1
        )
        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)[
            :, :, :, : seq_length2 // 2 + 1
        ]

        return pos_score
class MultiHeadedSelfAttentionModule(nn.Module):
    """Self-attention sub-layer with relative positional encoding.

    Applies LayerNorm to the input, runs relative multi-head attention
    with the normalised tensor as query, key and value, then applies
    dropout to the attention output.
    """

    def __init__(self, d_model: int, num_heads: int, dropout_p: float = 0.1):
        super().__init__()
        self.positional_encoding = RelPositionalEncoding(d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(
        self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(dropout(attention_output), attention_weights)``.

        Args:
            inputs: tensor whose dim 0 is the batch and dim 1 the sequence.
            mask: optional attention mask, forwarded unchanged.
        """
        # The positional table has batch size 1; tile it across the batch.
        rel_pos = self.positional_encoding(inputs).repeat(inputs.size(0), 1, 1)

        normed = self.layer_norm(inputs)
        context, attn = self.attention(
            normed, normed, normed, pos_embedding=rel_pos, mask=mask
        )
        return self.dropout(context), attn
class ConformerBlock(nn.Module):
    """One encoder block: FFN -> attention -> (conv -> FFN) -> LayerNorm.

    With ``transformer_style=True`` the convolution and the second
    feed-forward module are not created and the block ends after
    attention and LayerNorm.

    Args:
        encoder_dim: model dimension.
        attention_type: ``"mhsa"`` or ``"mamba"``. Only ``"mhsa"`` can be
            built from this file; ``"mamba"`` raises ``NotImplementedError``.
        num_attention_heads: number of heads for ``"mhsa"``.
        mamba_d_state, mamba_d_conv, mamba_expand, mamba_bidirectional:
            accepted for configuration compatibility; not used by this file.
        feed_forward_expansion_factor: hidden-size multiplier of both FFNs.
        conv_expansion_factor: expansion factor of the conv module.
        feed_forward_dropout_p, attention_dropout_p, conv_dropout_p:
            dropout probabilities of the respective sub-layers.
        conv_kernel_size: depthwise kernel size of the conv module.
        half_step_residual: scale FFN outputs by 0.5 before the residual
            add (ignored when ``transformer_style`` is true).
        transformer_style: drop the conv module and second FFN.

    Raises:
        AssertionError: if ``attention_type`` is not ``"mhsa"``/``"mamba"``.
        NotImplementedError: if ``attention_type`` is ``"mamba"``.
    """

    def __init__(
        self,
        encoder_dim: int = 512,
        attention_type: str = "mhsa",
        num_attention_heads: int = 8,
        mamba_d_state: int = 16,
        mamba_d_conv: int = 4,
        mamba_expand: int = 2,
        mamba_bidirectional: bool = True,
        feed_forward_expansion_factor: int = 4,
        conv_expansion_factor: int = 2,
        feed_forward_dropout_p: float = 0.1,
        attention_dropout_p: float = 0.1,
        conv_dropout_p: float = 0.1,
        conv_kernel_size: int = 31,
        half_step_residual: bool = True,
        transformer_style: bool = False,
    ):
        super().__init__()

        assert attention_type in ["mhsa", "mamba"]
        if attention_type == "mamba":
            # The original code left `attention` unbound on this path and
            # crashed with UnboundLocalError; fail with a clear message.
            raise NotImplementedError(
                "attention_type='mamba' is not available: no Mamba attention "
                "module is defined in this file; use attention_type='mhsa'"
            )

        self.transformer_style = transformer_style
        self.attention_type = attention_type

        if half_step_residual and not transformer_style:
            self.feed_forward_residual_factor = 0.5
        else:
            self.feed_forward_residual_factor = 1

        # Built before ffn_1 but registered after it, as in the original,
        # so RNG consumption and state-dict order are unchanged.
        attention = MultiHeadedSelfAttentionModule(
            d_model=encoder_dim,
            num_heads=num_attention_heads,
            dropout_p=attention_dropout_p,
        )

        self.ffn_1 = FeedForwardModule(
            encoder_dim=encoder_dim,
            expansion_factor=feed_forward_expansion_factor,
            dropout_p=feed_forward_dropout_p,
        )
        self.attention = attention

        if not transformer_style:
            self.conv = ConformerConvModule(
                in_channels=encoder_dim,
                kernel_size=conv_kernel_size,
                expansion_factor=conv_expansion_factor,
                dropout_p=conv_dropout_p,
            )
            self.ffn_2 = FeedForwardModule(
                encoder_dim=encoder_dim,
                expansion_factor=feed_forward_expansion_factor,
                dropout_p=feed_forward_dropout_p,
            )

        self.layernorm = nn.LayerNorm(encoder_dim)

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, Union[torch.Tensor, None]]]:
        """Run the block.

        Args:
            x: input tensor; the sub-modules treat it as (B, T, encoder_dim).

        Returns:
            ``(x, other)`` where ``x`` is the block output and ``other``
            maps ``"ffn_1"``, ``"attn"``, ``"conv"`` and ``"ffn_2"`` to the
            raw sub-layer outputs (``"attn"`` holds the attention weights).
            Entries for sub-layers that did not run are ``None``.
        """
        # FFN 1
        ffn_1_out = self.ffn_1(x)
        x = ffn_1_out * self.feed_forward_residual_factor + x

        # Attention
        if not isinstance(self.attention, MultiHeadedSelfAttentionModule):
            # Any other attention module (e.g. an externally assigned Mamba
            # layer) is expected to return a single tensor.
            attn_out = self.attention(x)
            attn = None
        else:
            attn_out, attn = self.attention(x)
        x = attn_out + x

        if self.transformer_style:
            x = self.layernorm(x)
            return x, {
                "ffn_1": ffn_1_out,
                "attn": attn,
                "conv": None,
                "ffn_2": None,
            }

        # Convolution
        conv_out = self.conv(x)
        x = conv_out + x

        # FFN 2
        ffn_2_out = self.ffn_2(x)
        x = ffn_2_out * self.feed_forward_residual_factor + x
        x = self.layernorm(x)

        other = {
            "ffn_1": ffn_1_out,
            "attn": attn,
            "conv": conv_out,
            "ffn_2": ffn_2_out,
        }

        return x, other
class ConformerEncoder(nn.Module):
    """Conformer encoder: subsampling front end(s), optional conv positional
    embedding, and a stack of ``ConformerBlock`` layers.

    ``cfg`` is read for: ``use_framewise_subsample``,
    ``use_patchwise_subsample``, ``conv_subsample_channels``,
    ``conv_subsample_rate``, ``input_dim``, ``encoder_dim``,
    ``input_dropout_p``, ``patch_size_time``, ``patch_size_freq``,
    ``attention_type``, ``num_attention_heads``, the ``mamba_*`` options,
    the FFN/conv/attention dropout and expansion options,
    ``conv_kernel_size``, ``half_step_residual`` and ``num_layers``; and
    optionally ``subsample_normalization``, ``conv_pos`` (with
    ``conv_pos_depth``, ``conv_pos_width``, ``conv_pos_groups``) and
    ``transformer_style``.

    NOTE(review): the source this class was recovered from had lost its
    indentation. Block nesting was reconstructed from semantics; the spots
    where more than one nesting was plausible are marked below.
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg
        self.framewise_subsample = None
        self.patchwise_subsample = None
        self.framewise_in_proj = None
        self.patchwise_in_proj = None

        assert (
            cfg.use_framewise_subsample or cfg.use_patchwise_subsample
        ), "At least one subsampling method should be used"

        if cfg.use_framewise_subsample:
            self.framewise_subsample = FramewiseConv2dSubampling(
                out_channels=cfg.conv_subsample_channels,
                subsample_rate=cfg.conv_subsample_rate,
            )
            self.framewise_in_proj = nn.Sequential(
                Linear(
                    # Feature bins left after two kernel-3, stride-2 convs.
                    cfg.conv_subsample_channels * (((cfg.input_dim - 1) // 2 - 1) // 2),
                    cfg.encoder_dim,
                ),
                nn.Dropout(p=cfg.input_dropout_p),
            )

        if cfg.use_patchwise_subsample:
            self.patchwise_subsample = PatchwiseConv2dSubampling(
                mel_dim=cfg.input_dim,
                out_channels=cfg.conv_subsample_channels,
                patch_size_time=cfg.patch_size_time,
                patch_size_freq=cfg.patch_size_freq,
            )
            self.patchwise_in_proj = nn.Sequential(
                Linear(
                    cfg.conv_subsample_channels,
                    cfg.encoder_dim,
                ),
                nn.Dropout(p=cfg.input_dropout_p),
            )

            # `subsample_rate` is a plain method on the patchwise module, so
            # comparing the attribute itself to an int was always False.
            # Resolve it to a number first (a property also works here).
            patch_rate = self.patchwise_subsample.subsample_rate
            if callable(patch_rate):
                patch_rate = patch_rate()
            assert not cfg.use_framewise_subsample or (
                cfg.conv_subsample_rate == patch_rate
            ), (
                f"conv_subsample_rate ({cfg.conv_subsample_rate}) != "
                f"patchwise_subsample.subsample_rate ({patch_rate})"
            )

        self.framewise_norm, self.patchwise_norm = None, None
        if getattr(cfg, "subsample_normalization", False):
            if cfg.use_framewise_subsample:
                self.framewise_norm = nn.LayerNorm(cfg.encoder_dim)
            if cfg.use_patchwise_subsample:
                self.patchwise_norm = nn.LayerNorm(cfg.encoder_dim)

        self.conv_pos = None
        if getattr(cfg, "conv_pos", False):
            num_pos_layers = cfg.conv_pos_depth
            k = max(3, cfg.conv_pos_width // num_pos_layers)
            self.conv_pos = nn.Sequential(
                TransposeLast(),
                *[
                    nn.Sequential(
                        nn.Conv1d(
                            cfg.encoder_dim,
                            cfg.encoder_dim,
                            kernel_size=k,
                            padding=k // 2,
                            groups=cfg.conv_pos_groups,
                        ),
                        SamePad(k),
                        TransposeLast(),
                        nn.LayerNorm(cfg.encoder_dim, elementwise_affine=False),
                        TransposeLast(),
                        nn.GELU(),
                    )
                    for _ in range(num_pos_layers)
                ],
                TransposeLast(),
            )
            # NOTE(review): assumed to belong to the conv_pos branch, since
            # forward only uses it together with conv_pos - confirm.
            self.conv_pos_post_ln = nn.LayerNorm(cfg.encoder_dim)

        self.layers = nn.ModuleList(
            [
                ConformerBlock(
                    encoder_dim=cfg.encoder_dim,
                    attention_type=cfg.attention_type,
                    num_attention_heads=cfg.num_attention_heads,
                    mamba_d_state=cfg.mamba_d_state,
                    mamba_d_conv=cfg.mamba_d_conv,
                    mamba_expand=cfg.mamba_expand,
                    mamba_bidirectional=cfg.mamba_bidirectional,
                    feed_forward_expansion_factor=cfg.feed_forward_expansion_factor,
                    conv_expansion_factor=cfg.conv_expansion_factor,
                    feed_forward_dropout_p=cfg.feed_forward_dropout_p,
                    attention_dropout_p=cfg.attention_dropout_p,
                    conv_dropout_p=cfg.conv_dropout_p,
                    conv_kernel_size=cfg.conv_kernel_size,
                    half_step_residual=cfg.half_step_residual,
                    transformer_style=getattr(cfg, "transformer_style", False),
                )
                for _ in range(cfg.num_layers)
            ]
        )

    def count_parameters(self) -> int:
        """Count the trainable parameters of the encoder."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def update_dropout(self, dropout_p: float) -> None:
        """Set the probability of every ``nn.Dropout`` inside the encoder.

        The whole module tree is walked. Every dropout layer sits inside a
        nested container, so looking only at direct children (as the
        previous implementation did) never changed anything.
        """
        for module in self.modules():
            if isinstance(module, nn.Dropout):
                module.p = dropout_p

    def forward(
        self,
        inputs: torch.Tensor,
        input_lengths: Optional[torch.Tensor] = None,
        return_hidden: bool = False,
        freeze_input_layers: bool = False,
        target_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, List[torch.Tensor]]]:
        """Encode a batch of feature sequences.

        Args:
            inputs: tensor of shape (B, T, input_dim).
            input_lengths: per-sample lengths along T; defaults to T for
                every sample.
            return_hidden: also collect each layer's output and its
                intermediate sub-layer outputs.
            freeze_input_layers: run the subsampling front end under
                ``torch.no_grad()``.
            target_layer: stop after the layer with this 0-based index.

        Returns:
            ``(outputs, output_lengths, layer_results)``. ``layer_results``
            is a ``defaultdict(list)``; when ``return_hidden`` is true it
            holds ``"hidden_states"`` plus one list per key returned by the
            blocks, otherwise it is empty.
        """
        if input_lengths is None:
            input_lengths = torch.full(
                (inputs.size(0),),
                inputs.size(1),
                dtype=torch.long,
                device=inputs.device,
            )

        with torch.no_grad() if freeze_input_layers else contextlib.nullcontext():
            frame_feat, patch_feat = None, None
            if self.framewise_subsample is not None:
                frame_feat, frame_lengths = self.framewise_subsample(
                    inputs, input_lengths
                )
                frame_feat = self.framewise_in_proj(frame_feat)
                if self.framewise_norm is not None:
                    frame_feat = self.framewise_norm(frame_feat)

            if self.patchwise_subsample is not None:
                patch_feat, patch_lengths = self.patchwise_subsample(
                    inputs, input_lengths
                )
                patch_feat = self.patchwise_in_proj(patch_feat)
                if self.patchwise_norm is not None:
                    patch_feat = self.patchwise_norm(patch_feat)

            if frame_feat is not None and patch_feat is not None:
                # Sum the two streams over their common prefix.
                min_len = min(frame_feat.size(1), patch_feat.size(1))
                frame_feat = frame_feat[:, :min_len]
                patch_feat = patch_feat[:, :min_len]
                features = frame_feat + patch_feat
                output_lengths = (
                    frame_lengths
                    if frame_lengths.max().item() < patch_lengths.max().item()
                    else patch_lengths
                )
            elif frame_feat is not None:
                features = frame_feat
                output_lengths = frame_lengths
            else:
                features = patch_feat
                output_lengths = patch_lengths

        # NOTE(review): conv_pos is assumed to sit outside the
        # freeze_input_layers context (only the subsampling front end is
        # frozen) - confirm against the upstream file.
        if self.conv_pos is not None:
            features = features + self.conv_pos(features)
            features = self.conv_pos_post_ln(features)

        layer_results = defaultdict(list)

        outputs = features
        for i, layer in enumerate(self.layers):
            outputs, other = layer(outputs)
            if return_hidden:
                layer_results["hidden_states"].append(outputs)
                for k, v in other.items():
                    layer_results[k].append(v)

            if target_layer is not None and i == target_layer:
                break

        return outputs, output_lengths, layer_results