Source code for torch.ao.nn.quantizable.modules.rnn

"""
We will recreate all the RNN modules, as we require the modules to be decomposed
into their building blocks so that they can be observed.
"""

# mypy: allow-untyped-defs
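
# A minimal usage sketch (illustrative, not part of the module itself): attach a
# qconfig to a float ``torch.nn.LSTM`` and swap it for the quantizable version
# below via ``LSTM.from_float``, which also runs the observer-insertion
# (``prepare``) step:
#
#     float_lstm = torch.nn.LSTM(10, 20, num_layers=2).eval()
#     float_lstm.qconfig = torch.ao.quantization.default_qconfig
#     observed = LSTM.from_float(float_lstm)
#     out, (h, c) = observed(torch.randn(5, 3, 10))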

import numbers
import warnings
from typing import Optional, Tuple

import torch
from torch import Tensor


__all__ = ["LSTMCell", "LSTM"]


class LSTMCell(torch.nn.Module):
    r"""A quantizable long short-term memory (LSTM) cell.

    For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell`

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTMCell(10, 20)
        >>> input = torch.randn(6, 3, 10)  # (time_steps, batch, input_size)
        >>> hx = torch.randn(3, 20)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx, cx = rnn(input[i], (hx, cx))
        ...     output.append(hx)
    """
    _FLOAT_MODULE = torch.nn.LSTMCell

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.input_size = input_dim
        self.hidden_size = hidden_dim
        self.bias = bias

        self.igates = torch.nn.Linear(
            input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
        )
        self.hgates = torch.nn.Linear(
            hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
        )
        self.gates = torch.ao.nn.quantized.FloatFunctional()

        self.input_gate = torch.nn.Sigmoid()
        self.forget_gate = torch.nn.Sigmoid()
        self.cell_gate = torch.nn.Tanh()
        self.output_gate = torch.nn.Sigmoid()

        self.fgate_cx = torch.ao.nn.quantized.FloatFunctional()
        self.igate_cgate = torch.ao.nn.quantized.FloatFunctional()
        self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional()

        self.ogate_cy = torch.ao.nn.quantized.FloatFunctional()

        self.initial_hidden_state_qparams: Tuple[float, int] = (1.0, 0)
        self.initial_cell_state_qparams: Tuple[float, int] = (1.0, 0)
        self.hidden_state_dtype: torch.dtype = torch.quint8
        self.cell_state_dtype: torch.dtype = torch.quint8

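    # The cell computes the standard LSTM update, but every elementwise op is a
    # separate submodule (Sigmoid/Tanh/FloatFunctional) so observers can be
    # attached to each intermediate tensor. With ``gates = igates(x) + hgates(h)``
    # split into four chunks (i, f, g, o), the update is:
    #
    #   i = sigmoid(gates_i)        # input_gate
    #   f = sigmoid(gates_f)        # forget_gate
    #   g = tanh(gates_g)           # cell_gate
    #   o = sigmoid(gates_o)        # output_gate
    #   c' = f * c + i * g
    #   h' = o * tanh(c')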
    def forward(
        self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tensor]:
        if hidden is None or hidden[0] is None or hidden[1] is None:
            hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
        hx, cx = hidden

        igates = self.igates(x)
        hgates = self.hgates(hx)
        gates = self.gates.add(igates, hgates)

        input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)

        input_gate = self.input_gate(input_gate)
        forget_gate = self.forget_gate(forget_gate)
        cell_gate = self.cell_gate(cell_gate)
        out_gate = self.output_gate(out_gate)

        fgate_cx = self.fgate_cx.mul(forget_gate, cx)
        igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
        fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate)
        cy = fgate_cx_igate_cgate

        # TODO: make this tanh a member of the module so its qparams can be configured
        tanh_cy = torch.tanh(cy)
        hy = self.ogate_cy.mul(out_gate, tanh_cy)
        return hy, cy

    def initialize_hidden(
        self, batch_size: int, is_quantized: bool = False
    ) -> Tuple[Tensor, Tensor]:
        h = torch.zeros((batch_size, self.hidden_size))
        c = torch.zeros((batch_size, self.hidden_size))
        if is_quantized:
            (h_scale, h_zp) = self.initial_hidden_state_qparams
            (c_scale, c_zp) = self.initial_cell_state_qparams
            h = torch.quantize_per_tensor(
                h, scale=h_scale, zero_point=h_zp, dtype=self.hidden_state_dtype
            )
            c = torch.quantize_per_tensor(
                c, scale=c_scale, zero_point=c_zp, dtype=self.cell_state_dtype
            )
        return h, c

    def _get_name(self):
        return "QuantizableLSTMCell"

    @classmethod
    def from_params(cls, wi, wh, bi=None, bh=None):
        """Uses the weights and biases to create a new LSTM cell.

        Args:
            wi, wh: Weights for the input and hidden layers
            bi, bh: Biases for the input and hidden layers
        """
        assert (bi is None) == (bh is None)  # Either both None or both have values
        input_size = wi.shape[1]
        hidden_size = wh.shape[1]
        cell = cls(input_dim=input_size, hidden_dim=hidden_size, bias=(bi is not None))
        cell.igates.weight = torch.nn.Parameter(wi)
        if bi is not None:
            cell.igates.bias = torch.nn.Parameter(bi)
        cell.hgates.weight = torch.nn.Parameter(wh)
        if bh is not None:
            cell.hgates.bias = torch.nn.Parameter(bh)
        return cell
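
    # Example shapes (illustrative): ``wi`` is (4 * hidden_size, input_size) and
    # ``wh`` is (4 * hidden_size, hidden_size), so
    #
    #     cell = LSTMCell.from_params(torch.randn(80, 10), torch.randn(80, 20))
    #
    # creates a cell with input_size=10, hidden_size=20 and bias=False.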

    @classmethod
    def from_float(cls, other, use_precomputed_fake_quant=False):
        assert type(other) == cls._FLOAT_MODULE
        assert hasattr(other, "qconfig"), "The float module must have 'qconfig'"
        observed = cls.from_params(
            other.weight_ih, other.weight_hh, other.bias_ih, other.bias_hh
        )
        observed.qconfig = other.qconfig
        observed.igates.qconfig = other.qconfig
        observed.hgates.qconfig = other.qconfig
        return observed


class _LSTMSingleLayer(torch.nn.Module):
    r"""A single one-directional LSTM layer.

    The difference between a layer and a cell is that the layer can process a
    sequence, while the cell only expects an instantaneous value.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)

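    # ``forward`` unrolls the cell over the leading (time) dimension of ``x``,
    # which has shape (seq_len, batch, input_size), and stacks the per-step
    # hidden states into a (seq_len, batch, hidden_size) output tensor.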
    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        result = []
        seq_len = x.shape[0]
        for i in range(seq_len):
            hidden = self.cell(x[i], hidden)
            result.append(hidden[0])  # type: ignore[index]
        result_tensor = torch.stack(result, 0)
        return result_tensor, hidden

    @classmethod
    def from_params(cls, *args, **kwargs):
        cell = LSTMCell.from_params(*args, **kwargs)
        layer = cls(cell.input_size, cell.hidden_size, cell.bias)
        layer.cell = cell
        return layer


class _LSTMLayer(torch.nn.Module):
    r"""A single LSTM layer (optionally bidirectional)."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        bias: bool = True,
        batch_first: bool = False,
        bidirectional: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.layer_fw = _LSTMSingleLayer(
            input_dim, hidden_dim, bias=bias, **factory_kwargs
        )
        if self.bidirectional:
            self.layer_bw = _LSTMSingleLayer(
                input_dim, hidden_dim, bias=bias, **factory_kwargs
            )

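    # ``forward`` runs the forward-direction layer and, when ``bidirectional``,
    # also the backward layer on the time-reversed input. The two outputs are
    # concatenated along the feature dimension, and the final (h, c) states are
    # stacked along a leading direction dimension.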
    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        if self.batch_first:
            x = x.transpose(0, 1)
        if hidden is None:
            hx_fw, cx_fw = (None, None)
        else:
            hx_fw, cx_fw = hidden
        hidden_bw: Optional[Tuple[Tensor, Tensor]] = None
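        # In the bidirectional case the incoming (h, c) are stacked along a
        # leading direction dimension: index 0 is the forward direction and
        # index 1 is the backward direction.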
        if self.bidirectional:
            if hx_fw is None:
                hx_bw = None
            else:
                hx_bw = hx_fw[1]
                hx_fw = hx_fw[0]
            if cx_fw is None:
                cx_bw = None
            else:
                cx_bw = cx_fw[1]
                cx_fw = cx_fw[0]
            if hx_bw is not None and cx_bw is not None:
                hidden_bw = hx_bw, cx_bw
        if hx_fw is None and cx_fw is None:
            hidden_fw = None
        else:
            hidden_fw = (
                torch.jit._unwrap_optional(hx_fw),
                torch.jit._unwrap_optional(cx_fw),
            )
        result_fw, hidden_fw = self.layer_fw(x, hidden_fw)

        if hasattr(self, "layer_bw") and self.bidirectional:
            x_reversed = x.flip(0)
            result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
            result_bw = result_bw.flip(0)

            result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
            if hidden_fw is None and hidden_bw is None:
                h = None
                c = None
            elif hidden_fw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_bw)
            elif hidden_bw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_fw)
            else:
                h = torch.stack([hidden_fw[0], hidden_bw[0]], 0)  # type: ignore[list-item]
                c = torch.stack([hidden_fw[1], hidden_bw[1]], 0)  # type: ignore[list-item]
        else:
            result = result_fw
            h, c = torch.jit._unwrap_optional(hidden_fw)  # type: ignore[assignment]

        if self.batch_first:
            result.transpose_(0, 1)

        return result, (h, c)

    @classmethod
    def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
        r"""
        There is no FP equivalent of this class. This function is here just to
        mimic the behavior of the `prepare` step within the
        `torch.ao.quantization` flow.
        """
        assert hasattr(other, "qconfig") or (qconfig is not None)

        input_size = kwargs.get("input_size", other.input_size)
        hidden_size = kwargs.get("hidden_size", other.hidden_size)
        bias = kwargs.get("bias", other.bias)
        batch_first = kwargs.get("batch_first", other.batch_first)
        bidirectional = kwargs.get("bidirectional", other.bidirectional)

        layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
        layer.qconfig = getattr(other, "qconfig", qconfig)
        wi = getattr(other, f"weight_ih_l{layer_idx}")
        wh = getattr(other, f"weight_hh_l{layer_idx}")
        bi = getattr(other, f"bias_ih_l{layer_idx}", None)
        bh = getattr(other, f"bias_hh_l{layer_idx}", None)

        layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)

        if other.bidirectional:
            wi = getattr(other, f"weight_ih_l{layer_idx}_reverse")
            wh = getattr(other, f"weight_hh_l{layer_idx}_reverse")
            bi = getattr(other, f"bias_ih_l{layer_idx}_reverse", None)
            bh = getattr(other, f"bias_hh_l{layer_idx}_reverse", None)
            layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
        return layer
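
    # Illustrative only: extract layer ``k`` of a float ``torch.nn.LSTM`` that
    # already carries a qconfig (or pass ``qconfig=...`` explicitly), e.g.
    #
    #     float_lstm = torch.nn.LSTM(10, 20, num_layers=2)
    #     float_lstm.qconfig = torch.ao.quantization.default_qconfig
    #     layer0 = _LSTMLayer.from_float(float_lstm, layer_idx=0)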


class LSTM(torch.nn.Module):
    r"""A quantizable long short-term memory (LSTM).

    For the description and the argument types, please, refer to :class:`~torch.nn.LSTM`

    Attributes:
        layers : instances of the `_LSTMLayer`

    .. note::
        To access the weights and biases, you need to access them per layer.
        See examples below.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
        >>> # To get the weights:
        >>> # xdoctest: +SKIP
        >>> print(rnn.layers[0].weight_ih)
        tensor([[...]])
        >>> print(rnn.layers[0].weight_hh)
        AssertionError: There is no reverse path in the non-bidirectional layer
    """
    _FLOAT_MODULE = torch.nn.LSTM

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        # Default to eval mode. If we want to train, we will explicitly set to training.
        self.training = False
        num_directions = 2 if bidirectional else 1

        if (
            not isinstance(dropout, numbers.Number)
            or not 0 <= dropout <= 1
            or isinstance(dropout, bool)
        ):
            raise ValueError(
                "dropout should be a number in range [0, 1] "
                "representing the probability of an element being "
                "zeroed"
            )
        if dropout > 0:
            warnings.warn(
                "dropout option for quantizable LSTM is ignored. "
                "If you are training, please, use nn.LSTM version "
                "followed by `prepare` step."
            )
            if num_layers == 1:
                warnings.warn(
                    "dropout option adds dropout after all but last "
                    "recurrent layer, so non-zero dropout expects "
                    f"num_layers greater than 1, but got dropout={dropout} "
                    f"and num_layers={num_layers}"
                )

        layers = [
            _LSTMLayer(
                self.input_size,
                self.hidden_size,
                self.bias,
                batch_first=False,
                bidirectional=self.bidirectional,
                **factory_kwargs,
            )
        ]
        for layer in range(1, num_layers):
            layers.append(
                _LSTMLayer(
                    self.hidden_size,
                    self.hidden_size,
                    self.bias,
                    batch_first=False,
                    bidirectional=self.bidirectional,
                    **factory_kwargs,
                )
            )
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        if self.batch_first:
            x = x.transpose(0, 1)

        max_batch_size = x.size(1)
        num_directions = 2 if self.bidirectional else 1
        if hidden is None:
            zeros = torch.zeros(
                num_directions,
                max_batch_size,
                self.hidden_size,
                dtype=torch.float,
                device=x.device,
            )
            zeros.squeeze_(0)
            if x.is_quantized:
                zeros = torch.quantize_per_tensor(
                    zeros, scale=1.0, zero_point=0, dtype=x.dtype
                )
            hxcx = [(zeros, zeros) for _ in range(self.num_layers)]
        else:
            hidden_non_opt = torch.jit._unwrap_optional(hidden)
            if isinstance(hidden_non_opt[0], Tensor):
                hx = hidden_non_opt[0].reshape(
                    self.num_layers, num_directions, max_batch_size, self.hidden_size
                )
                cx = hidden_non_opt[1].reshape(
                    self.num_layers, num_directions, max_batch_size, self.hidden_size
                )
                hxcx = [
                    (hx[idx].squeeze(0), cx[idx].squeeze(0))
                    for idx in range(self.num_layers)
                ]
            else:
                hxcx = hidden_non_opt

        hx_list = []
        cx_list = []
        for idx, layer in enumerate(self.layers):
            x, (h, c) = layer(x, hxcx[idx])
            hx_list.append(torch.jit._unwrap_optional(h))
            cx_list.append(torch.jit._unwrap_optional(c))
        hx_tensor = torch.stack(hx_list)
        cx_tensor = torch.stack(cx_list)

        # We are creating another dimension for bidirectional case
        # need to collapse it
        hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1])
        cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1])

        if self.batch_first:
            x = x.transpose(0, 1)

        return x, (hx_tensor, cx_tensor)

    def _get_name(self):
        return "QuantizableLSTM"

    @classmethod
    def from_float(cls, other, qconfig=None):
        assert isinstance(other, cls._FLOAT_MODULE)
        assert hasattr(other, "qconfig") or qconfig
        observed = cls(
            other.input_size,
            other.hidden_size,
            other.num_layers,
            other.bias,
            other.batch_first,
            other.dropout,
            other.bidirectional,
        )
        observed.qconfig = getattr(other, "qconfig", qconfig)
        for idx in range(other.num_layers):
            observed.layers[idx] = _LSTMLayer.from_float(
                other, idx, qconfig, batch_first=False
            )

        # Prepare the model
        if other.training:
            observed.train()
            observed = torch.ao.quantization.prepare_qat(observed, inplace=True)
        else:
            observed.eval()
            observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed

    @classmethod
    def from_observed(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does float -> observed only
        raise NotImplementedError(
            "It looks like you are trying to convert a "
            "non-quantizable LSTM module. Please, see "
            "the examples on quantizable LSTMs."
        )