rnn.py 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288
  1. import math
  2. import warnings
  3. import numbers
  4. from typing import List, Tuple, Optional, overload
  5. import torch
  6. from torch import Tensor
  7. from .module import Module
  8. from ..parameter import Parameter
  9. from ..utils.rnn import PackedSequence
  10. from .. import init
  11. from ... import _VF
  12. _rnn_impls = {
  13. 'RNN_TANH': _VF.rnn_tanh,
  14. 'RNN_RELU': _VF.rnn_relu,
  15. }
  16. def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
  17. return tensor.index_select(dim, permutation)
  18. class RNNBase(Module):
  19. __constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias',
  20. 'batch_first', 'dropout', 'bidirectional', 'proj_size']
  21. __jit_unused_properties__ = ['all_weights']
  22. mode: str
  23. input_size: int
  24. hidden_size: int
  25. num_layers: int
  26. bias: bool
  27. batch_first: bool
  28. dropout: float
  29. bidirectional: bool
  30. proj_size: int
  31. def __init__(self, mode: str, input_size: int, hidden_size: int,
  32. num_layers: int = 1, bias: bool = True, batch_first: bool = False,
  33. dropout: float = 0., bidirectional: bool = False, proj_size: int = 0,
  34. device=None, dtype=None) -> None:
  35. factory_kwargs = {'device': device, 'dtype': dtype}
  36. super(RNNBase, self).__init__()
  37. self.mode = mode
  38. self.input_size = input_size
  39. self.hidden_size = hidden_size
  40. self.num_layers = num_layers
  41. self.bias = bias
  42. self.batch_first = batch_first
  43. self.dropout = float(dropout)
  44. self.bidirectional = bidirectional
  45. self.proj_size = proj_size
  46. num_directions = 2 if bidirectional else 1
  47. if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
  48. isinstance(dropout, bool):
  49. raise ValueError("dropout should be a number in range [0, 1] "
  50. "representing the probability of an element being "
  51. "zeroed")
  52. if dropout > 0 and num_layers == 1:
  53. warnings.warn("dropout option adds dropout after all but last "
  54. "recurrent layer, so non-zero dropout expects "
  55. "num_layers greater than 1, but got dropout={} and "
  56. "num_layers={}".format(dropout, num_layers))
  57. if proj_size < 0:
  58. raise ValueError("proj_size should be a positive integer or zero to disable projections")
  59. if proj_size >= hidden_size:
  60. raise ValueError("proj_size has to be smaller than hidden_size")
  61. if mode == 'LSTM':
  62. gate_size = 4 * hidden_size
  63. elif mode == 'GRU':
  64. gate_size = 3 * hidden_size
  65. elif mode == 'RNN_TANH':
  66. gate_size = hidden_size
  67. elif mode == 'RNN_RELU':
  68. gate_size = hidden_size
  69. else:
  70. raise ValueError("Unrecognized RNN mode: " + mode)
  71. self._flat_weights_names = []
  72. self._all_weights = []
  73. for layer in range(num_layers):
  74. for direction in range(num_directions):
  75. real_hidden_size = proj_size if proj_size > 0 else hidden_size
  76. layer_input_size = input_size if layer == 0 else real_hidden_size * num_directions
  77. w_ih = Parameter(torch.empty((gate_size, layer_input_size), **factory_kwargs))
  78. w_hh = Parameter(torch.empty((gate_size, real_hidden_size), **factory_kwargs))
  79. b_ih = Parameter(torch.empty(gate_size, **factory_kwargs))
  80. # Second bias vector included for CuDNN compatibility. Only one
  81. # bias vector is needed in standard definition.
  82. b_hh = Parameter(torch.empty(gate_size, **factory_kwargs))
  83. layer_params: Tuple[Tensor, ...] = ()
  84. if self.proj_size == 0:
  85. if bias:
  86. layer_params = (w_ih, w_hh, b_ih, b_hh)
  87. else:
  88. layer_params = (w_ih, w_hh)
  89. else:
  90. w_hr = Parameter(torch.empty((proj_size, hidden_size), **factory_kwargs))
  91. if bias:
  92. layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr)
  93. else:
  94. layer_params = (w_ih, w_hh, w_hr)
  95. suffix = '_reverse' if direction == 1 else ''
  96. param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
  97. if bias:
  98. param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
  99. if self.proj_size > 0:
  100. param_names += ['weight_hr_l{}{}']
  101. param_names = [x.format(layer, suffix) for x in param_names]
  102. for name, param in zip(param_names, layer_params):
  103. setattr(self, name, param)
  104. self._flat_weights_names.extend(param_names)
  105. self._all_weights.append(param_names)
  106. self._flat_weights = [(lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn) for wn in self._flat_weights_names]
  107. self.flatten_parameters()
  108. self.reset_parameters()
  109. def __setattr__(self, attr, value):
  110. if hasattr(self, "_flat_weights_names") and attr in self._flat_weights_names:
  111. # keep self._flat_weights up to date if you do self.weight = ...
  112. idx = self._flat_weights_names.index(attr)
  113. self._flat_weights[idx] = value
  114. super(RNNBase, self).__setattr__(attr, value)
  115. def flatten_parameters(self) -> None:
  116. """Resets parameter data pointer so that they can use faster code paths.
  117. Right now, this works only if the module is on the GPU and cuDNN is enabled.
  118. Otherwise, it's a no-op.
  119. """
  120. # Short-circuits if _flat_weights is only partially instantiated
  121. if len(self._flat_weights) != len(self._flat_weights_names):
  122. return
  123. for w in self._flat_weights:
  124. if not isinstance(w, Tensor):
  125. return
  126. # Short-circuits if any tensor in self._flat_weights is not acceptable to cuDNN
  127. # or the tensors in _flat_weights are of different dtypes
  128. first_fw = self._flat_weights[0]
  129. dtype = first_fw.dtype
  130. for fw in self._flat_weights:
  131. if (not isinstance(fw.data, Tensor) or not (fw.data.dtype == dtype) or
  132. not fw.data.is_cuda or
  133. not torch.backends.cudnn.is_acceptable(fw.data)):
  134. return
  135. # If any parameters alias, we fall back to the slower, copying code path. This is
  136. # a sufficient check, because overlapping parameter buffers that don't completely
  137. # alias would break the assumptions of the uniqueness check in
  138. # Module.named_parameters().
  139. unique_data_ptrs = set(p.data_ptr() for p in self._flat_weights)
  140. if len(unique_data_ptrs) != len(self._flat_weights):
  141. return
  142. with torch.cuda.device_of(first_fw):
  143. import torch.backends.cudnn.rnn as rnn
  144. # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
  145. # an inplace operation on self._flat_weights
  146. with torch.no_grad():
  147. if torch._use_cudnn_rnn_flatten_weight():
  148. num_weights = 4 if self.bias else 2
  149. if self.proj_size > 0:
  150. num_weights += 1
  151. torch._cudnn_rnn_flatten_weight(
  152. self._flat_weights, num_weights,
  153. self.input_size, rnn.get_cudnn_mode(self.mode),
  154. self.hidden_size, self.proj_size, self.num_layers,
  155. self.batch_first, bool(self.bidirectional))
  156. def _apply(self, fn):
  157. ret = super(RNNBase, self)._apply(fn)
  158. # Resets _flat_weights
  159. # Note: be v. careful before removing this, as 3rd party device types
  160. # likely rely on this behavior to properly .to() modules like LSTM.
  161. self._flat_weights = [(lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn) for wn in self._flat_weights_names]
  162. # Flattens params (on CUDA)
  163. self.flatten_parameters()
  164. return ret
  165. def reset_parameters(self) -> None:
  166. stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0
  167. for weight in self.parameters():
  168. init.uniform_(weight, -stdv, stdv)
  169. def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
  170. expected_input_dim = 2 if batch_sizes is not None else 3
  171. if input.dim() != expected_input_dim:
  172. raise RuntimeError(
  173. 'input must have {} dimensions, got {}'.format(
  174. expected_input_dim, input.dim()))
  175. if self.input_size != input.size(-1):
  176. raise RuntimeError(
  177. 'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
  178. self.input_size, input.size(-1)))
  179. def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
  180. if batch_sizes is not None:
  181. mini_batch = int(batch_sizes[0])
  182. else:
  183. mini_batch = input.size(0) if self.batch_first else input.size(1)
  184. num_directions = 2 if self.bidirectional else 1
  185. if self.proj_size > 0:
  186. expected_hidden_size = (self.num_layers * num_directions,
  187. mini_batch, self.proj_size)
  188. else:
  189. expected_hidden_size = (self.num_layers * num_directions,
  190. mini_batch, self.hidden_size)
  191. return expected_hidden_size
  192. def check_hidden_size(self, hx: Tensor, expected_hidden_size: Tuple[int, int, int],
  193. msg: str = 'Expected hidden size {}, got {}') -> None:
  194. if hx.size() != expected_hidden_size:
  195. raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
  196. def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]):
  197. self.check_input(input, batch_sizes)
  198. expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
  199. self.check_hidden_size(hidden, expected_hidden_size)
  200. def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]):
  201. if permutation is None:
  202. return hx
  203. return apply_permutation(hx, permutation)
  204. def extra_repr(self) -> str:
  205. s = '{input_size}, {hidden_size}'
  206. if self.proj_size != 0:
  207. s += ', proj_size={proj_size}'
  208. if self.num_layers != 1:
  209. s += ', num_layers={num_layers}'
  210. if self.bias is not True:
  211. s += ', bias={bias}'
  212. if self.batch_first is not False:
  213. s += ', batch_first={batch_first}'
  214. if self.dropout != 0:
  215. s += ', dropout={dropout}'
  216. if self.bidirectional is not False:
  217. s += ', bidirectional={bidirectional}'
  218. return s.format(**self.__dict__)
  219. def __setstate__(self, d):
  220. super(RNNBase, self).__setstate__(d)
  221. if 'all_weights' in d:
  222. self._all_weights = d['all_weights']
  223. # In PyTorch 1.8 we added a proj_size member variable to LSTM.
  224. # LSTMs that were serialized via torch.save(module) before PyTorch 1.8
  225. # don't have it, so to preserve compatibility we set proj_size here.
  226. if 'proj_size' not in d:
  227. self.proj_size = 0
  228. if isinstance(self._all_weights[0][0], str):
  229. return
  230. num_layers = self.num_layers
  231. num_directions = 2 if self.bidirectional else 1
  232. self._flat_weights_names = []
  233. self._all_weights = []
  234. for layer in range(num_layers):
  235. for direction in range(num_directions):
  236. suffix = '_reverse' if direction == 1 else ''
  237. weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}',
  238. 'bias_hh_l{}{}', 'weight_hr_l{}{}']
  239. weights = [x.format(layer, suffix) for x in weights]
  240. if self.bias:
  241. if self.proj_size > 0:
  242. self._all_weights += [weights]
  243. self._flat_weights_names.extend(weights)
  244. else:
  245. self._all_weights += [weights[:4]]
  246. self._flat_weights_names.extend(weights[:4])
  247. else:
  248. if self.proj_size > 0:
  249. self._all_weights += [weights[:2]] + [weights[-1:]]
  250. self._flat_weights_names.extend(weights[:2] + [weights[-1:]])
  251. else:
  252. self._all_weights += [weights[:2]]
  253. self._flat_weights_names.extend(weights[:2])
  254. self._flat_weights = [(lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn) for wn in self._flat_weights_names]
  255. @property
  256. def all_weights(self) -> List[List[Parameter]]:
  257. return [[getattr(self, weight) for weight in weights] for weights in self._all_weights]
  258. def _replicate_for_data_parallel(self):
  259. replica = super(RNNBase, self)._replicate_for_data_parallel()
  260. # Need to copy these caches, otherwise the replica will share the same
  261. # flat weights list.
  262. replica._flat_weights = replica._flat_weights[:]
  263. replica._flat_weights_names = replica._flat_weights_names[:]
  264. return replica
  265. class RNN(RNNBase):
  266. r"""Applies a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to an
  267. input sequence.
  268. For each element in the input sequence, each layer computes the following
  269. function:
  270. .. math::
  271. h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh})
  272. where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
  273. the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
  274. previous layer at time `t-1` or the initial hidden state at time `0`.
  275. If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`.
  276. Args:
  277. input_size: The number of expected features in the input `x`
  278. hidden_size: The number of features in the hidden state `h`
  279. num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
  280. would mean stacking two RNNs together to form a `stacked RNN`,
  281. with the second RNN taking in outputs of the first RNN and
  282. computing the final results. Default: 1
  283. nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
  284. bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
  285. Default: ``True``
  286. batch_first: If ``True``, then the input and output tensors are provided
  287. as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
  288. Note that this does not apply to hidden or cell states. See the
  289. Inputs/Outputs sections below for details. Default: ``False``
  290. dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
  291. RNN layer except the last layer, with dropout probability equal to
  292. :attr:`dropout`. Default: 0
  293. bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``
  294. Inputs: input, h_0
  295. * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
  296. :math:`(L, N, H_{in})` when ``batch_first=False`` or
  297. :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
  298. the input sequence. The input can also be a packed variable length sequence.
  299. See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
  300. :func:`torch.nn.utils.rnn.pack_sequence` for details.
  301. * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
  302. :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden
  303. state for the input sequence batch. Defaults to zeros if not provided.
  304. where:
  305. .. math::
  306. \begin{aligned}
  307. N ={} & \text{batch size} \\
  308. L ={} & \text{sequence length} \\
  309. D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
  310. H_{in} ={} & \text{input\_size} \\
  311. H_{out} ={} & \text{hidden\_size}
  312. \end{aligned}
  313. Outputs: output, h_n
  314. * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
  315. :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
  316. :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
  317. `(h_t)` from the last layer of the RNN, for each `t`. If a
  318. :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
  319. will also be a packed sequence.
  320. * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
  321. :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state
  322. for each element in the batch.
  323. Attributes:
  324. weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
  325. of shape `(hidden_size, input_size)` for `k = 0`. Otherwise, the shape is
  326. `(hidden_size, num_directions * hidden_size)`
  327. weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
  328. of shape `(hidden_size, hidden_size)`
  329. bias_ih_l[k]: the learnable input-hidden bias of the k-th layer,
  330. of shape `(hidden_size)`
  331. bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer,
  332. of shape `(hidden_size)`
  333. .. note::
  334. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  335. where :math:`k = \frac{1}{\text{hidden\_size}}`
  336. .. note::
  337. For bidirectional RNNs, forward and backward are directions 0 and 1 respectively.
  338. Example of splitting the output layers when ``batch_first=False``:
  339. ``output.view(seq_len, batch, num_directions, hidden_size)``.
  340. .. note::
  341. ``batch_first`` argument is ignored for unbatched inputs.
  342. .. include:: ../cudnn_rnn_determinism.rst
  343. .. include:: ../cudnn_persistent_rnn.rst
  344. Examples::
  345. >>> rnn = nn.RNN(10, 20, 2)
  346. >>> input = torch.randn(5, 3, 10)
  347. >>> h0 = torch.randn(2, 3, 20)
  348. >>> output, hn = rnn(input, h0)
  349. """
  350. def __init__(self, *args, **kwargs):
  351. if 'proj_size' in kwargs:
  352. raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU")
  353. self.nonlinearity = kwargs.pop('nonlinearity', 'tanh')
  354. if self.nonlinearity == 'tanh':
  355. mode = 'RNN_TANH'
  356. elif self.nonlinearity == 'relu':
  357. mode = 'RNN_RELU'
  358. else:
  359. raise ValueError("Unknown nonlinearity '{}'".format(self.nonlinearity))
  360. super(RNN, self).__init__(mode, *args, **kwargs)
  361. @overload
  362. @torch._jit_internal._overload_method # noqa: F811
  363. def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
  364. pass
  365. @overload
  366. @torch._jit_internal._overload_method # noqa: F811
  367. def forward(self, input: PackedSequence, hx: Optional[Tensor] = None) -> Tuple[PackedSequence, Tensor]:
  368. pass
  369. def forward(self, input, hx=None): # noqa: F811
  370. orig_input = input
  371. if isinstance(orig_input, PackedSequence):
  372. input, batch_sizes, sorted_indices, unsorted_indices = input
  373. max_batch_size = int(batch_sizes[0])
  374. else:
  375. batch_sizes = None
  376. is_batched = input.dim() == 3
  377. batch_dim = 0 if self.batch_first else 1
  378. if not is_batched:
  379. input = input.unsqueeze(batch_dim)
  380. if hx is not None:
  381. if hx.dim() != 2:
  382. raise RuntimeError(
  383. f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor")
  384. hx = hx.unsqueeze(1)
  385. else:
  386. if hx is not None and hx.dim() != 3:
  387. raise RuntimeError(
  388. f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor")
  389. max_batch_size = input.size(0) if self.batch_first else input.size(1)
  390. sorted_indices = None
  391. unsorted_indices = None
  392. if hx is None:
  393. num_directions = 2 if self.bidirectional else 1
  394. hx = torch.zeros(self.num_layers * num_directions,
  395. max_batch_size, self.hidden_size,
  396. dtype=input.dtype, device=input.device)
  397. else:
  398. # Each batch of the hidden state should match the input sequence that
  399. # the user believes he/she is passing in.
  400. hx = self.permute_hidden(hx, sorted_indices)
  401. assert hx is not None
  402. self.check_forward_args(input, hx, batch_sizes)
  403. assert self.mode == 'RNN_TANH' or self.mode == 'RNN_RELU'
  404. if batch_sizes is None:
  405. if self.mode == 'RNN_TANH':
  406. result = _VF.rnn_tanh(input, hx, self._flat_weights, self.bias, self.num_layers,
  407. self.dropout, self.training, self.bidirectional,
  408. self.batch_first)
  409. else:
  410. result = _VF.rnn_relu(input, hx, self._flat_weights, self.bias, self.num_layers,
  411. self.dropout, self.training, self.bidirectional,
  412. self.batch_first)
  413. else:
  414. if self.mode == 'RNN_TANH':
  415. result = _VF.rnn_tanh(input, batch_sizes, hx, self._flat_weights, self.bias,
  416. self.num_layers, self.dropout, self.training,
  417. self.bidirectional)
  418. else:
  419. result = _VF.rnn_relu(input, batch_sizes, hx, self._flat_weights, self.bias,
  420. self.num_layers, self.dropout, self.training,
  421. self.bidirectional)
  422. output = result[0]
  423. hidden = result[1]
  424. if isinstance(orig_input, PackedSequence):
  425. output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
  426. return output_packed, self.permute_hidden(hidden, unsorted_indices)
  427. if not is_batched:
  428. output = output.squeeze(batch_dim)
  429. hidden = hidden.squeeze(1)
  430. return output, self.permute_hidden(hidden, unsorted_indices)
# XXX: LSTM and GRU are implemented differently from RNNBase because:
# 1. we want to support nn.LSTM and nn.GRU in TorchScript, and TorchScript in
#    its current state cannot express a Python Union type or Any type;
# 2. TorchScript's static typing does not allow a Function or Callable type as
#    a Dict value, so we have to call _VF directly instead of going through
#    _rnn_impls;
# 3. this is a temporary measure, adopted so the change could land in time for
#    the release.
#
# More discussion details in https://github.com/pytorch/pytorch/pull/23266
#
# TODO: remove the overriding implementations for LSTM and GRU once TorchScript
# can express these two modules generically.
  443. class LSTM(RNNBase):
  444. r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
  445. sequence.
  446. For each element in the input sequence, each layer computes the following
  447. function:
  448. .. math::
  449. \begin{array}{ll} \\
  450. i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
  451. f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
  452. g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
  453. o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
  454. c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
  455. h_t = o_t \odot \tanh(c_t) \\
  456. \end{array}
  457. where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
  458. state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}`
  459. is the hidden state of the layer at time `t-1` or the initial hidden
  460. state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`,
  461. :math:`o_t` are the input, forget, cell, and output gates, respectively.
  462. :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
  463. In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
  464. (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
  465. dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
  466. variable which is :math:`0` with probability :attr:`dropout`.
  467. If ``proj_size > 0`` is specified, LSTM with projections will be used. This changes
  468. the LSTM cell in the following way. First, the dimension of :math:`h_t` will be changed from
  469. ``hidden_size`` to ``proj_size`` (dimensions of :math:`W_{hi}` will be changed accordingly).
  470. Second, the output hidden state of each layer will be multiplied by a learnable projection
  471. matrix: :math:`h_t = W_{hr}h_t`. Note that as a consequence of this, the output
  472. of LSTM network will be of different shape as well. See Inputs/Outputs sections below for exact
  473. dimensions of all variables. You can find more details in https://arxiv.org/abs/1402.1128.
  474. Args:
  475. input_size: The number of expected features in the input `x`
  476. hidden_size: The number of features in the hidden state `h`
  477. num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
  478. would mean stacking two LSTMs together to form a `stacked LSTM`,
  479. with the second LSTM taking in outputs of the first LSTM and
  480. computing the final results. Default: 1
  481. bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
  482. Default: ``True``
  483. batch_first: If ``True``, then the input and output tensors are provided
  484. as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
  485. Note that this does not apply to hidden or cell states. See the
  486. Inputs/Outputs sections below for details. Default: ``False``
  487. dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
  488. LSTM layer except the last layer, with dropout probability equal to
  489. :attr:`dropout`. Default: 0
  490. bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``
  491. proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0
  492. Inputs: input, (h_0, c_0)
  493. * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
  494. :math:`(L, N, H_{in})` when ``batch_first=False`` or
  495. :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
  496. the input sequence. The input can also be a packed variable length sequence.
  497. See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
  498. :func:`torch.nn.utils.rnn.pack_sequence` for details.
  499. * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
  500. :math:`(D * \text{num\_layers}, N, H_{out})` containing the
  501. initial hidden state for each element in the input sequence.
  502. Defaults to zeros if (h_0, c_0) is not provided.
  503. * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
  504. :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
  505. initial cell state for each element in the input sequence.
  506. Defaults to zeros if (h_0, c_0) is not provided.
  507. where:
  508. .. math::
  509. \begin{aligned}
  510. N ={} & \text{batch size} \\
  511. L ={} & \text{sequence length} \\
  512. D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
  513. H_{in} ={} & \text{input\_size} \\
  514. H_{cell} ={} & \text{hidden\_size} \\
  515. H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\
  516. \end{aligned}
  517. Outputs: output, (h_n, c_n)
  518. * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
  519. :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
  520. :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
  521. `(h_t)` from the last layer of the LSTM, for each `t`. If a
  522. :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
  523. will also be a packed sequence. When ``bidirectional=True``, `output` will contain
  524. a concatenation of the forward and reverse hidden states at each time step in the sequence.
  525. * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
  526. :math:`(D * \text{num\_layers}, N, H_{out})` containing the
  527. final hidden state for each element in the sequence. When ``bidirectional=True``,
  528. `h_n` will contain a concatenation of the final forward and reverse hidden states, respectively.
  529. * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
  530. :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
  531. final cell state for each element in the sequence. When ``bidirectional=True``,
  532. `c_n` will contain a concatenation of the final forward and reverse cell states, respectively.
  533. Attributes:
  534. weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
  535. `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`.
  536. Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`. If
  537. ``proj_size > 0`` was specified, the shape will be
  538. `(4*hidden_size, num_directions * proj_size)` for `k > 0`
  539. weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
  540. `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`. If ``proj_size > 0``
  541. was specified, the shape will be `(4*hidden_size, proj_size)`.
  542. bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
  543. `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)`
  544. bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
  545. `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)`
  546. weight_hr_l[k] : the learnable projection weights of the :math:`\text{k}^{th}` layer
  547. of shape `(proj_size, hidden_size)`. Only present when ``proj_size > 0`` was
  548. specified.
  549. weight_ih_l[k]_reverse: Analogous to `weight_ih_l[k]` for the reverse direction.
  550. Only present when ``bidirectional=True``.
  551. weight_hh_l[k]_reverse: Analogous to `weight_hh_l[k]` for the reverse direction.
  552. Only present when ``bidirectional=True``.
  553. bias_ih_l[k]_reverse: Analogous to `bias_ih_l[k]` for the reverse direction.
  554. Only present when ``bidirectional=True``.
  555. bias_hh_l[k]_reverse: Analogous to `bias_hh_l[k]` for the reverse direction.
  556. Only present when ``bidirectional=True``.
  557. weight_hr_l[k]_reverse: Analogous to `weight_hr_l[k]` for the reverse direction.
  558. Only present when ``bidirectional=True`` and ``proj_size > 0`` was specified.
  559. .. note::
  560. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  561. where :math:`k = \frac{1}{\text{hidden\_size}}`
  562. .. note::
  563. For bidirectional LSTMs, forward and backward are directions 0 and 1 respectively.
        Example of splitting the output layers when ``batch_first=False``:
        ``output.view(seq_len, batch, num_directions, hidden_size)``
        (use ``proj_size`` in place of ``hidden_size`` when ``proj_size > 0``).
  566. .. note::
  567. For bidirectional LSTMs, `h_n` is not equivalent to the last element of `output`; the
  568. former contains the final forward and reverse hidden states, while the latter contains the
  569. final forward hidden state and the initial reverse hidden state.
  570. .. note::
  571. ``batch_first`` argument is ignored for unbatched inputs.
  572. .. include:: ../cudnn_rnn_determinism.rst
  573. .. include:: ../cudnn_persistent_rnn.rst
  574. Examples::
  575. >>> rnn = nn.LSTM(10, 20, 2)
  576. >>> input = torch.randn(5, 3, 10)
  577. >>> h0 = torch.randn(2, 3, 20)
  578. >>> c0 = torch.randn(2, 3, 20)
  579. >>> output, (hn, cn) = rnn(input, (h0, c0))
  580. """
  581. def __init__(self, *args, **kwargs):
  582. super(LSTM, self).__init__('LSTM', *args, **kwargs)
  583. def get_expected_cell_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
  584. if batch_sizes is not None:
  585. mini_batch = int(batch_sizes[0])
  586. else:
  587. mini_batch = input.size(0) if self.batch_first else input.size(1)
  588. num_directions = 2 if self.bidirectional else 1
  589. expected_hidden_size = (self.num_layers * num_directions,
  590. mini_batch, self.hidden_size)
  591. return expected_hidden_size
  592. # In the future, we should prevent mypy from applying contravariance rules here.
  593. # See torch/nn/modules/module.py::_forward_unimplemented
  594. def check_forward_args(self, # type: ignore[override]
  595. input: Tensor,
  596. hidden: Tuple[Tensor, Tensor],
  597. batch_sizes: Optional[Tensor],
  598. ):
  599. self.check_input(input, batch_sizes)
  600. self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes),
  601. 'Expected hidden[0] size {}, got {}')
  602. self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes),
  603. 'Expected hidden[1] size {}, got {}')
  604. # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
  605. def permute_hidden(self, # type: ignore[override]
  606. hx: Tuple[Tensor, Tensor],
  607. permutation: Optional[Tensor]
  608. ) -> Tuple[Tensor, Tensor]:
  609. if permutation is None:
  610. return hx
  611. return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)
  612. # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
    # Overload declaration for plain Tensor input: returns (output, (h_n, c_n)).
    # Declaration only; the real implementation is the undecorated ``forward`` below.
    @overload  # type: ignore[override]
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
                ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:  # noqa: F811
        pass

    # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
    # Overload declaration for PackedSequence input: the output is also a
    # PackedSequence, while the hidden states stay plain tensors.
    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
                ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:  # noqa: F811
        pass
    def forward(self, input, hx=None):  # noqa: F811
        """Run the LSTM over ``input``, which is a Tensor or a PackedSequence.

        Returns ``(output, (h_n, c_n))``. When ``input`` is a PackedSequence,
        ``output`` is also a PackedSequence.

        Any Tensor input that is not 3-D is treated as unbatched. A size-1
        batch axis is inserted before the kernel runs (at dim 0 if
        ``batch_first`` else dim 1). That axis is squeezed back out of
        ``output``, and dim 1 is squeezed out of ``h_n`` and ``c_n``.

        If ``hx`` is None, zero states are created. ``h_0`` uses ``proj_size``
        when ``proj_size > 0`` and ``hidden_size`` otherwise; ``c_0`` always
        uses ``hidden_size``.

        Raises:
            RuntimeError: for Tensor input with a user-supplied ``hx`` whose
                ``hx[0]``/``hx[1]`` are not 3-D (batched input) or 2-D
                (unbatched input). Full size checks are delegated to
                ``check_forward_args``.
        """
        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        batch_sizes = None
        if isinstance(orig_input, PackedSequence):
            # Unpack; the first entry of batch_sizes is the largest batch in the sequence.
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            max_batch_size = int(max_batch_size)
        else:
            batch_sizes = None
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                # Give unbatched input a batch axis of size 1 so the kernel sees 3-D data.
                input = input.unsqueeze(batch_dim)
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            # With projections enabled, h has proj_size features while c keeps hidden_size.
            real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
            h_zeros = torch.zeros(self.num_layers * num_directions,
                                  max_batch_size, real_hidden_size,
                                  dtype=input.dtype, device=input.device)
            c_zeros = torch.zeros(self.num_layers * num_directions,
                                  max_batch_size, self.hidden_size,
                                  dtype=input.dtype, device=input.device)
            hx = (h_zeros, c_zeros)
        else:
            if batch_sizes is None:  # If not PackedSequence input.
                if is_batched:
                    if (hx[0].dim() != 3 or hx[1].dim() != 3):
                        msg = ("For batched 3-D input, hx and cx should "
                               f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
                        raise RuntimeError(msg)
                else:
                    if hx[0].dim() != 2 or hx[1].dim() != 2:
                        msg = ("For unbatched 2-D input, hx and cx should "
                               f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
                        raise RuntimeError(msg)
                    # Unbatched states get their batch axis at dim 1 (the input got it at batch_dim).
                    hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))

            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        # The packed and padded kernels take different positional signatures.
        if batch_sizes is None:
            result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
                              self.dropout, self.training, self.bidirectional, self.batch_first)
        else:
            result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias,
                              self.num_layers, self.dropout, self.training, self.bidirectional)
        output = result[0]
        # Remaining entries of the result are the final (h, c) pair.
        hidden = result[1:]

        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
            # Restore the caller's original batch order for the final states.
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            if not is_batched:
                # Undo the batch axes inserted above for unbatched input.
                output = output.squeeze(batch_dim)
                hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
            return output, self.permute_hidden(hidden, unsorted_indices)
  685. class GRU(RNNBase):
  686. r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
  687. For each element in the input sequence, each layer computes the following
  688. function:
  689. .. math::
  690. \begin{array}{ll}
  691. r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
  692. z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
  693. n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
  694. h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}
  695. \end{array}
  696. where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
  697. at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
  698. at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
  699. :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
  700. :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
  701. In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
  702. (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
  703. dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
  704. variable which is :math:`0` with probability :attr:`dropout`.
  705. Args:
  706. input_size: The number of expected features in the input `x`
  707. hidden_size: The number of features in the hidden state `h`
  708. num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
  709. would mean stacking two GRUs together to form a `stacked GRU`,
  710. with the second GRU taking in outputs of the first GRU and
  711. computing the final results. Default: 1
  712. bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
  713. Default: ``True``
  714. batch_first: If ``True``, then the input and output tensors are provided
  715. as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
  716. Note that this does not apply to hidden or cell states. See the
  717. Inputs/Outputs sections below for details. Default: ``False``
  718. dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
  719. GRU layer except the last layer, with dropout probability equal to
  720. :attr:`dropout`. Default: 0
  721. bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``
  722. Inputs: input, h_0
  723. * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
  724. :math:`(L, N, H_{in})` when ``batch_first=False`` or
  725. :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
  726. the input sequence. The input can also be a packed variable length sequence.
  727. See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
  728. :func:`torch.nn.utils.rnn.pack_sequence` for details.
  729. * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or
  730. :math:`(D * \text{num\_layers}, N, H_{out})`
  731. containing the initial hidden state for the input sequence. Defaults to zeros if not provided.
  732. where:
  733. .. math::
  734. \begin{aligned}
  735. N ={} & \text{batch size} \\
  736. L ={} & \text{sequence length} \\
  737. D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
  738. H_{in} ={} & \text{input\_size} \\
  739. H_{out} ={} & \text{hidden\_size}
  740. \end{aligned}
  741. Outputs: output, h_n
  742. * **output**: tensor of shape :math:`(L, H_{in})` for unbatched input,
  743. :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
  744. :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
  745. `(h_t)` from the last layer of the GRU, for each `t`. If a
  746. :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
  747. will also be a packed sequence.
  748. * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or
  749. :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state
  750. for the input sequence.
  751. Attributes:
  752. weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
  753. (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.
  754. Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`
  755. weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
  756. (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`
  757. bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
  758. (b_ir|b_iz|b_in), of shape `(3*hidden_size)`
  759. bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
  760. (b_hr|b_hz|b_hn), of shape `(3*hidden_size)`
  761. .. note::
  762. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  763. where :math:`k = \frac{1}{\text{hidden\_size}}`
  764. .. note::
  765. For bidirectional GRUs, forward and backward are directions 0 and 1 respectively.
  766. Example of splitting the output layers when ``batch_first=False``:
  767. ``output.view(seq_len, batch, num_directions, hidden_size)``.
  768. .. note::
  769. ``batch_first`` argument is ignored for unbatched inputs.
  770. .. include:: ../cudnn_persistent_rnn.rst
  771. Examples::
  772. >>> rnn = nn.GRU(10, 20, 2)
  773. >>> input = torch.randn(5, 3, 10)
  774. >>> h0 = torch.randn(2, 3, 20)
  775. >>> output, hn = rnn(input, h0)
  776. """
  777. def __init__(self, *args, **kwargs):
  778. if 'proj_size' in kwargs:
  779. raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU")
  780. super(GRU, self).__init__('GRU', *args, **kwargs)
  781. @overload # type: ignore[override]
  782. @torch._jit_internal._overload_method # noqa: F811
  783. def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: # noqa: F811
  784. pass
  785. @overload
  786. @torch._jit_internal._overload_method # noqa: F811
  787. def forward(self, input: PackedSequence, hx: Optional[Tensor] = None) -> Tuple[PackedSequence, Tensor]: # noqa: F811
  788. pass
  789. def forward(self, input, hx=None): # noqa: F811
  790. orig_input = input
  791. # xxx: isinstance check needs to be in conditional for TorchScript to compile
  792. if isinstance(orig_input, PackedSequence):
  793. input, batch_sizes, sorted_indices, unsorted_indices = input
  794. max_batch_size = batch_sizes[0]
  795. max_batch_size = int(max_batch_size)
  796. else:
  797. batch_sizes = None
  798. is_batched = input.dim() == 3
  799. batch_dim = 0 if self.batch_first else 1
  800. if not is_batched:
  801. input = input.unsqueeze(batch_dim)
  802. if hx is not None:
  803. if hx.dim() != 2:
  804. raise RuntimeError(
  805. f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor")
  806. hx = hx.unsqueeze(1)
  807. else:
  808. if hx is not None and hx.dim() != 3:
  809. raise RuntimeError(
  810. f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor")
  811. max_batch_size = input.size(0) if self.batch_first else input.size(1)
  812. sorted_indices = None
  813. unsorted_indices = None
  814. if hx is None:
  815. num_directions = 2 if self.bidirectional else 1
  816. hx = torch.zeros(self.num_layers * num_directions,
  817. max_batch_size, self.hidden_size,
  818. dtype=input.dtype, device=input.device)
  819. else:
  820. # Each batch of the hidden state should match the input sequence that
  821. # the user believes he/she is passing in.
  822. hx = self.permute_hidden(hx, sorted_indices)
  823. self.check_forward_args(input, hx, batch_sizes)
  824. if batch_sizes is None:
  825. result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
  826. self.dropout, self.training, self.bidirectional, self.batch_first)
  827. else:
  828. result = _VF.gru(input, batch_sizes, hx, self._flat_weights, self.bias,
  829. self.num_layers, self.dropout, self.training, self.bidirectional)
  830. output = result[0]
  831. hidden = result[1]
  832. # xxx: isinstance check needs to be in conditional for TorchScript to compile
  833. if isinstance(orig_input, PackedSequence):
  834. output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
  835. return output_packed, self.permute_hidden(hidden, unsorted_indices)
  836. else:
  837. if not is_batched:
  838. output = output.squeeze(batch_dim)
  839. hidden = hidden.squeeze(1)
  840. return output, self.permute_hidden(hidden, unsorted_indices)
  841. class RNNCellBase(Module):
  842. __constants__ = ['input_size', 'hidden_size', 'bias']
  843. input_size: int
  844. hidden_size: int
  845. bias: bool
  846. weight_ih: Tensor
  847. weight_hh: Tensor
  848. # WARNING: bias_ih and bias_hh purposely not defined here.
  849. # See https://github.com/pytorch/pytorch/issues/39670
  850. def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int,
  851. device=None, dtype=None) -> None:
  852. factory_kwargs = {'device': device, 'dtype': dtype}
  853. super(RNNCellBase, self).__init__()
  854. self.input_size = input_size
  855. self.hidden_size = hidden_size
  856. self.bias = bias
  857. self.weight_ih = Parameter(torch.empty((num_chunks * hidden_size, input_size), **factory_kwargs))
  858. self.weight_hh = Parameter(torch.empty((num_chunks * hidden_size, hidden_size), **factory_kwargs))
  859. if bias:
  860. self.bias_ih = Parameter(torch.empty(num_chunks * hidden_size, **factory_kwargs))
  861. self.bias_hh = Parameter(torch.empty(num_chunks * hidden_size, **factory_kwargs))
  862. else:
  863. self.register_parameter('bias_ih', None)
  864. self.register_parameter('bias_hh', None)
  865. self.reset_parameters()
  866. def extra_repr(self) -> str:
  867. s = '{input_size}, {hidden_size}'
  868. if 'bias' in self.__dict__ and self.bias is not True:
  869. s += ', bias={bias}'
  870. if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
  871. s += ', nonlinearity={nonlinearity}'
  872. return s.format(**self.__dict__)
  873. def reset_parameters(self) -> None:
  874. stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0
  875. for weight in self.parameters():
  876. init.uniform_(weight, -stdv, stdv)
  877. class RNNCell(RNNCellBase):
  878. r"""An Elman RNN cell with tanh or ReLU non-linearity.
  879. .. math::
  880. h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})
  881. If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh.
  882. Args:
  883. input_size: The number of expected features in the input `x`
  884. hidden_size: The number of features in the hidden state `h`
  885. bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
  886. Default: ``True``
  887. nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
  888. Inputs: input, hidden
  889. - **input**: tensor containing input features
  890. - **hidden**: tensor containing the initial hidden state
  891. Defaults to zero if not provided.
  892. Outputs: h'
  893. - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state
  894. for each element in the batch
  895. Shape:
  896. - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where
  897. :math:`H_{in}` = `input_size`.
  898. - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden
  899. state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.
  900. - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.
  901. Attributes:
  902. weight_ih: the learnable input-hidden weights, of shape
  903. `(hidden_size, input_size)`
  904. weight_hh: the learnable hidden-hidden weights, of shape
  905. `(hidden_size, hidden_size)`
  906. bias_ih: the learnable input-hidden bias, of shape `(hidden_size)`
  907. bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)`
  908. .. note::
  909. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  910. where :math:`k = \frac{1}{\text{hidden\_size}}`
  911. Examples::
  912. >>> rnn = nn.RNNCell(10, 20)
  913. >>> input = torch.randn(6, 3, 10)
  914. >>> hx = torch.randn(3, 20)
  915. >>> output = []
  916. >>> for i in range(6):
  917. hx = rnn(input[i], hx)
  918. output.append(hx)
  919. """
  920. __constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity']
  921. nonlinearity: str
  922. def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh",
  923. device=None, dtype=None) -> None:
  924. factory_kwargs = {'device': device, 'dtype': dtype}
  925. super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs)
  926. self.nonlinearity = nonlinearity
  927. def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
  928. assert input.dim() in (1, 2), \
  929. f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
  930. is_batched = input.dim() == 2
  931. if not is_batched:
  932. input = input.unsqueeze(0)
  933. if hx is None:
  934. hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
  935. else:
  936. hx = hx.unsqueeze(0) if not is_batched else hx
  937. if self.nonlinearity == "tanh":
  938. ret = _VF.rnn_tanh_cell(
  939. input, hx,
  940. self.weight_ih, self.weight_hh,
  941. self.bias_ih, self.bias_hh,
  942. )
  943. elif self.nonlinearity == "relu":
  944. ret = _VF.rnn_relu_cell(
  945. input, hx,
  946. self.weight_ih, self.weight_hh,
  947. self.bias_ih, self.bias_hh,
  948. )
  949. else:
  950. ret = input # TODO: remove when jit supports exception flow
  951. raise RuntimeError(
  952. "Unknown nonlinearity: {}".format(self.nonlinearity))
  953. if not is_batched:
  954. ret = ret.squeeze(0)
  955. return ret
  956. class LSTMCell(RNNCellBase):
  957. r"""A long short-term memory (LSTM) cell.
  958. .. math::
  959. \begin{array}{ll}
  960. i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
  961. f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
  962. g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
  963. o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
  964. c' = f * c + i * g \\
  965. h' = o * \tanh(c') \\
  966. \end{array}
  967. where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
  968. Args:
  969. input_size: The number of expected features in the input `x`
  970. hidden_size: The number of features in the hidden state `h`
  971. bias: If ``False``, then the layer does not use bias weights `b_ih` and
  972. `b_hh`. Default: ``True``
  973. Inputs: input, (h_0, c_0)
  974. - **input** of shape `(batch, input_size)` or `(input_size)`: tensor containing input features
  975. - **h_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial hidden state
  976. - **c_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial cell state
  977. If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.
  978. Outputs: (h_1, c_1)
  979. - **h_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next hidden state
  980. - **c_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next cell state
  981. Attributes:
  982. weight_ih: the learnable input-hidden weights, of shape
  983. `(4*hidden_size, input_size)`
  984. weight_hh: the learnable hidden-hidden weights, of shape
  985. `(4*hidden_size, hidden_size)`
  986. bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)`
  987. bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)`
  988. .. note::
  989. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  990. where :math:`k = \frac{1}{\text{hidden\_size}}`
  991. On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
  992. Examples::
  993. >>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
  994. >>> input = torch.randn(2, 3, 10) # (time_steps, batch, input_size)
  995. >>> hx = torch.randn(3, 20) # (batch, hidden_size)
  996. >>> cx = torch.randn(3, 20)
  997. >>> output = []
  998. >>> for i in range(input.size()[0]):
  999. hx, cx = rnn(input[i], (hx, cx))
  1000. output.append(hx)
  1001. >>> output = torch.stack(output, dim=0)
  1002. """
  1003. def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
  1004. device=None, dtype=None) -> None:
  1005. factory_kwargs = {'device': device, 'dtype': dtype}
  1006. super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)
  1007. def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
  1008. assert input.dim() in (1, 2), \
  1009. f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
  1010. is_batched = input.dim() == 2
  1011. if not is_batched:
  1012. input = input.unsqueeze(0)
  1013. if hx is None:
  1014. zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
  1015. hx = (zeros, zeros)
  1016. else:
  1017. hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx
  1018. ret = _VF.lstm_cell(
  1019. input, hx,
  1020. self.weight_ih, self.weight_hh,
  1021. self.bias_ih, self.bias_hh,
  1022. )
  1023. if not is_batched:
  1024. ret = (ret[0].squeeze(0), ret[1].squeeze(0))
  1025. return ret
  1026. class GRUCell(RNNCellBase):
  1027. r"""A gated recurrent unit (GRU) cell
  1028. .. math::
  1029. \begin{array}{ll}
  1030. r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
  1031. z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
  1032. n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
  1033. h' = (1 - z) * n + z * h
  1034. \end{array}
  1035. where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
  1036. Args:
  1037. input_size: The number of expected features in the input `x`
  1038. hidden_size: The number of features in the hidden state `h`
  1039. bias: If ``False``, then the layer does not use bias weights `b_ih` and
  1040. `b_hh`. Default: ``True``
  1041. Inputs: input, hidden
  1042. - **input** : tensor containing input features
  1043. - **hidden** : tensor containing the initial hidden
  1044. state for each element in the batch.
  1045. Defaults to zero if not provided.
  1046. Outputs: h'
  1047. - **h'** : tensor containing the next hidden state
  1048. for each element in the batch
  1049. Shape:
  1050. - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where
  1051. :math:`H_{in}` = `input_size`.
  1052. - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden
  1053. state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.
  1054. - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.
  1055. Attributes:
  1056. weight_ih: the learnable input-hidden weights, of shape
  1057. `(3*hidden_size, input_size)`
  1058. weight_hh: the learnable hidden-hidden weights, of shape
  1059. `(3*hidden_size, hidden_size)`
  1060. bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)`
  1061. bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)`
  1062. .. note::
  1063. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  1064. where :math:`k = \frac{1}{\text{hidden\_size}}`
  1065. On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
  1066. Examples::
  1067. >>> rnn = nn.GRUCell(10, 20)
  1068. >>> input = torch.randn(6, 3, 10)
  1069. >>> hx = torch.randn(3, 20)
  1070. >>> output = []
  1071. >>> for i in range(6):
  1072. hx = rnn(input[i], hx)
  1073. output.append(hx)
  1074. """
  1075. def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
  1076. device=None, dtype=None) -> None:
  1077. factory_kwargs = {'device': device, 'dtype': dtype}
  1078. super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)
  1079. def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
  1080. assert input.dim() in (1, 2), \
  1081. f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
  1082. is_batched = input.dim() == 2
  1083. if not is_batched:
  1084. input = input.unsqueeze(0)
  1085. if hx is None:
  1086. hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
  1087. else:
  1088. hx = hx.unsqueeze(0) if not is_batched else hx
  1089. ret = _VF.gru_cell(
  1090. input, hx,
  1091. self.weight_ih, self.weight_hh,
  1092. self.bias_ih, self.bias_hh,
  1093. )
  1094. if not is_batched:
  1095. ret = ret.squeeze(0)
  1096. return ret