Source code for diaparser.modules.scalar_mix

# -*- coding: utf-8 -*-

import torch
import torch.nn as nn


class ScalarMix(nn.Module):
    r"""
    Computes a parameterised scalar mixture of :math:`N` tensors,
    :math:`mixture = \gamma * \sum_{k}(s_k * tensor_k)`
    where :math:`s = \mathrm{softmax}(w)`, with :math:`w` and :math:`\gamma` scalar parameters.

    Args:
        n_layers (int):
            The number of layers to be mixed, i.e., :math:`N`.
        dropout (float):
            The dropout ratio applied to the normalized layer weights :math:`s`.
            A standard :class:`~torch.nn.Dropout` is applied *after* the softmax: in training mode
            each :math:`s_k` is independently zeroed with this probability and the surviving
            weights are scaled by :math:`1 / (1 - dropout)`. The dropped mass is not
            renormalized over the remaining weights. In eval mode the weights are used as-is.
            Default: 0.
    """

    def __init__(self, n_layers: int, dropout: float = 0.0):
        super().__init__()

        self.n_layers = n_layers

        # One unnormalized weight per layer; all-zeros yields a uniform softmax at init.
        self.weights = nn.Parameter(torch.zeros(n_layers))
        # Learned global scale applied to the weighted sum.
        self.gamma = nn.Parameter(torch.tensor([1.0]))
        self.dropout = nn.Dropout(dropout)

    def __repr__(self):
        s = f"n_layers={self.n_layers}"
        # Only mention dropout when it is actually active.
        if self.dropout.p > 0:
            s += f", dropout={self.dropout.p}"

        return f"{self.__class__.__name__}({s})"

    def forward(self, tensors):
        r"""
        Args:
            tensors (list[~torch.Tensor]):
                :math:`N` tensors to be mixed. They are summed after weighting, so they must be
                broadcast-compatible with each other.

        Returns:
            The mixture of :math:`N` tensors, scaled by :math:`\gamma`.

        Raises:
            ValueError: if the number of tensors differs from ``n_layers``.
        """
        tensors = list(tensors)
        # zip() would silently truncate on a mismatch, leaving weights that no longer sum to 1.
        if len(tensors) != self.n_layers:
            raise ValueError(f"Expected {self.n_layers} tensors to mix, got {len(tensors)}")

        normed_weights = self.dropout(self.weights.softmax(-1))
        weighted_sum = sum(w * h for w, h in zip(normed_weights, tensors))

        return self.gamma * weighted_sum