activation.py 49 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450
  1. import warnings
  2. from typing import Optional, Tuple
  3. import torch
  4. from torch import Tensor
  5. from .linear import NonDynamicallyQuantizableLinear
  6. from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
  7. from torch.nn.parameter import Parameter
  8. from .module import Module
  9. from .. import functional as F
  10. class Threshold(Module):
  11. r"""Thresholds each element of the input Tensor.
  12. Threshold is defined as:
  13. .. math::
  14. y =
  15. \begin{cases}
  16. x, &\text{ if } x > \text{threshold} \\
  17. \text{value}, &\text{ otherwise }
  18. \end{cases}
  19. Args:
  20. threshold: The value to threshold at
  21. value: The value to replace with
  22. inplace: can optionally do the operation in-place. Default: ``False``
  23. Shape:
  24. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  25. - Output: :math:`(*)`, same shape as the input.
  26. Examples::
  27. >>> m = nn.Threshold(0.1, 20)
  28. >>> input = torch.randn(2)
  29. >>> output = m(input)
  30. """
  31. __constants__ = ['threshold', 'value', 'inplace']
  32. threshold: float
  33. value: float
  34. inplace: bool
  35. def __init__(self, threshold: float, value: float, inplace: bool = False) -> None:
  36. super(Threshold, self).__init__()
  37. self.threshold = threshold
  38. self.value = value
  39. self.inplace = inplace
  40. # TODO: check in THNN (if inplace == True, then assert value <= threshold)
  41. def forward(self, input: Tensor) -> Tensor:
  42. return F.threshold(input, self.threshold, self.value, self.inplace)
  43. def extra_repr(self):
  44. inplace_str = ', inplace=True' if self.inplace else ''
  45. return 'threshold={}, value={}{}'.format(
  46. self.threshold, self.value, inplace_str
  47. )
  48. class ReLU(Module):
  49. r"""Applies the rectified linear unit function element-wise:
  50. :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`
  51. Args:
  52. inplace: can optionally do the operation in-place. Default: ``False``
  53. Shape:
  54. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  55. - Output: :math:`(*)`, same shape as the input.
  56. .. image:: ../scripts/activation_images/ReLU.png
  57. Examples::
  58. >>> m = nn.ReLU()
  59. >>> input = torch.randn(2)
  60. >>> output = m(input)
  61. An implementation of CReLU - https://arxiv.org/abs/1603.05201
  62. >>> m = nn.ReLU()
  63. >>> input = torch.randn(2).unsqueeze(0)
  64. >>> output = torch.cat((m(input),m(-input)))
  65. """
  66. __constants__ = ['inplace']
  67. inplace: bool
  68. def __init__(self, inplace: bool = False):
  69. super(ReLU, self).__init__()
  70. self.inplace = inplace
  71. def forward(self, input: Tensor) -> Tensor:
  72. return F.relu(input, inplace=self.inplace)
  73. def extra_repr(self) -> str:
  74. inplace_str = 'inplace=True' if self.inplace else ''
  75. return inplace_str
  76. class RReLU(Module):
  77. r"""Applies the randomized leaky rectified liner unit function, element-wise,
  78. as described in the paper:
  79. `Empirical Evaluation of Rectified Activations in Convolutional Network`_.
  80. The function is defined as:
  81. .. math::
  82. \text{RReLU}(x) =
  83. \begin{cases}
  84. x & \text{if } x \geq 0 \\
  85. ax & \text{ otherwise }
  86. \end{cases}
  87. where :math:`a` is randomly sampled from uniform distribution
  88. :math:`\mathcal{U}(\text{lower}, \text{upper})`.
  89. See: https://arxiv.org/pdf/1505.00853.pdf
  90. Args:
  91. lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
  92. upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
  93. inplace: can optionally do the operation in-place. Default: ``False``
  94. Shape:
  95. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  96. - Output: :math:`(*)`, same shape as the input.
  97. .. image:: ../scripts/activation_images/RReLU.png
  98. Examples::
  99. >>> m = nn.RReLU(0.1, 0.3)
  100. >>> input = torch.randn(2)
  101. >>> output = m(input)
  102. .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
  103. https://arxiv.org/abs/1505.00853
  104. """
  105. __constants__ = ['lower', 'upper', 'inplace']
  106. lower: float
  107. upper: float
  108. inplace: bool
  109. def __init__(
  110. self,
  111. lower: float = 1. / 8,
  112. upper: float = 1. / 3,
  113. inplace: bool = False
  114. ):
  115. super(RReLU, self).__init__()
  116. self.lower = lower
  117. self.upper = upper
  118. self.inplace = inplace
  119. def forward(self, input: Tensor) -> Tensor:
  120. return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)
  121. def extra_repr(self):
  122. inplace_str = ', inplace=True' if self.inplace else ''
  123. return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)
  124. class Hardtanh(Module):
  125. r"""Applies the HardTanh function element-wise.
  126. HardTanh is defined as:
  127. .. math::
  128. \text{HardTanh}(x) = \begin{cases}
  129. \text{max\_val} & \text{ if } x > \text{ max\_val } \\
  130. \text{min\_val} & \text{ if } x < \text{ min\_val } \\
  131. x & \text{ otherwise } \\
  132. \end{cases}
  133. Args:
  134. min_val: minimum value of the linear region range. Default: -1
  135. max_val: maximum value of the linear region range. Default: 1
  136. inplace: can optionally do the operation in-place. Default: ``False``
  137. Keyword arguments :attr:`min_value` and :attr:`max_value`
  138. have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.
  139. Shape:
  140. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  141. - Output: :math:`(*)`, same shape as the input.
  142. .. image:: ../scripts/activation_images/Hardtanh.png
  143. Examples::
  144. >>> m = nn.Hardtanh(-2, 2)
  145. >>> input = torch.randn(2)
  146. >>> output = m(input)
  147. """
  148. __constants__ = ['min_val', 'max_val', 'inplace']
  149. min_val: float
  150. max_val: float
  151. inplace: bool
  152. def __init__(
  153. self,
  154. min_val: float = -1.,
  155. max_val: float = 1.,
  156. inplace: bool = False,
  157. min_value: Optional[float] = None,
  158. max_value: Optional[float] = None
  159. ) -> None:
  160. super(Hardtanh, self).__init__()
  161. if min_value is not None:
  162. warnings.warn("keyword argument min_value is deprecated and rename to min_val")
  163. min_val = min_value
  164. if max_value is not None:
  165. warnings.warn("keyword argument max_value is deprecated and rename to max_val")
  166. max_val = max_value
  167. self.min_val = min_val
  168. self.max_val = max_val
  169. self.inplace = inplace
  170. assert self.max_val > self.min_val
  171. def forward(self, input: Tensor) -> Tensor:
  172. return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
  173. def extra_repr(self) -> str:
  174. inplace_str = ', inplace=True' if self.inplace else ''
  175. return 'min_val={}, max_val={}{}'.format(
  176. self.min_val, self.max_val, inplace_str
  177. )
  178. class ReLU6(Hardtanh):
  179. r"""Applies the element-wise function:
  180. .. math::
  181. \text{ReLU6}(x) = \min(\max(0,x), 6)
  182. Args:
  183. inplace: can optionally do the operation in-place. Default: ``False``
  184. Shape:
  185. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  186. - Output: :math:`(*)`, same shape as the input.
  187. .. image:: ../scripts/activation_images/ReLU6.png
  188. Examples::
  189. >>> m = nn.ReLU6()
  190. >>> input = torch.randn(2)
  191. >>> output = m(input)
  192. """
  193. def __init__(self, inplace: bool = False):
  194. super(ReLU6, self).__init__(0., 6., inplace)
  195. def extra_repr(self) -> str:
  196. inplace_str = 'inplace=True' if self.inplace else ''
  197. return inplace_str
  198. class Sigmoid(Module):
  199. r"""Applies the element-wise function:
  200. .. math::
  201. \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
  202. Shape:
  203. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  204. - Output: :math:`(*)`, same shape as the input.
  205. .. image:: ../scripts/activation_images/Sigmoid.png
  206. Examples::
  207. >>> m = nn.Sigmoid()
  208. >>> input = torch.randn(2)
  209. >>> output = m(input)
  210. """
  211. def forward(self, input: Tensor) -> Tensor:
  212. return torch.sigmoid(input)
  213. class Hardsigmoid(Module):
  214. r"""Applies the Hardsigmoid function element-wise.
  215. Hardsigmoid is defined as:
  216. .. math::
  217. \text{Hardsigmoid}(x) = \begin{cases}
  218. 0 & \text{if~} x \le -3, \\
  219. 1 & \text{if~} x \ge +3, \\
  220. x / 6 + 1 / 2 & \text{otherwise}
  221. \end{cases}
  222. Args:
  223. inplace: can optionally do the operation in-place. Default: ``False``
  224. Shape:
  225. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  226. - Output: :math:`(*)`, same shape as the input.
  227. .. image:: ../scripts/activation_images/Hardsigmoid.png
  228. Examples::
  229. >>> m = nn.Hardsigmoid()
  230. >>> input = torch.randn(2)
  231. >>> output = m(input)
  232. """
  233. __constants__ = ['inplace']
  234. inplace: bool
  235. def __init__(self, inplace : bool = False) -> None:
  236. super(Hardsigmoid, self).__init__()
  237. self.inplace = inplace
  238. def forward(self, input: Tensor) -> Tensor:
  239. return F.hardsigmoid(input, self.inplace)
  240. class Tanh(Module):
  241. r"""Applies the Hyperbolic Tangent (Tanh) function element-wise.
  242. Tanh is defined as:
  243. .. math::
  244. \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}
  245. Shape:
  246. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  247. - Output: :math:`(*)`, same shape as the input.
  248. .. image:: ../scripts/activation_images/Tanh.png
  249. Examples::
  250. >>> m = nn.Tanh()
  251. >>> input = torch.randn(2)
  252. >>> output = m(input)
  253. """
  254. def forward(self, input: Tensor) -> Tensor:
  255. return torch.tanh(input)
  256. class SiLU(Module):
  257. r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
  258. The SiLU function is also known as the swish function.
  259. .. math::
  260. \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}
  261. .. note::
  262. See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
  263. where the SiLU (Sigmoid Linear Unit) was originally coined, and see
  264. `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
  265. in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
  266. a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
  267. where the SiLU was experimented with later.
  268. Shape:
  269. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  270. - Output: :math:`(*)`, same shape as the input.
  271. .. image:: ../scripts/activation_images/SiLU.png
  272. Examples::
  273. >>> m = nn.SiLU()
  274. >>> input = torch.randn(2)
  275. >>> output = m(input)
  276. """
  277. __constants__ = ['inplace']
  278. inplace: bool
  279. def __init__(self, inplace: bool = False):
  280. super(SiLU, self).__init__()
  281. self.inplace = inplace
  282. def forward(self, input: Tensor) -> Tensor:
  283. return F.silu(input, inplace=self.inplace)
  284. def extra_repr(self) -> str:
  285. inplace_str = 'inplace=True' if self.inplace else ''
  286. return inplace_str
  287. class Mish(Module):
  288. r"""Applies the Mish function, element-wise.
  289. Mish: A Self Regularized Non-Monotonic Neural Activation Function.
  290. .. math::
  291. \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
  292. .. note::
  293. See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_
  294. Shape:
  295. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  296. - Output: :math:`(*)`, same shape as the input.
  297. .. image:: ../scripts/activation_images/Mish.png
  298. Examples::
  299. >>> m = nn.Mish()
  300. >>> input = torch.randn(2)
  301. >>> output = m(input)
  302. """
  303. __constants__ = ['inplace']
  304. inplace: bool
  305. def __init__(self, inplace: bool = False):
  306. super(Mish, self).__init__()
  307. self.inplace = inplace
  308. def forward(self, input: Tensor) -> Tensor:
  309. return F.mish(input, inplace=self.inplace)
  310. def extra_repr(self) -> str:
  311. inplace_str = 'inplace=True' if self.inplace else ''
  312. return inplace_str
  313. class Hardswish(Module):
  314. r"""Applies the hardswish function, element-wise, as described in the paper:
  315. `Searching for MobileNetV3`_.
  316. .. math::
  317. \text{Hardswish}(x) = \begin{cases}
  318. 0 & \text{if~} x \le -3, \\
  319. x & \text{if~} x \ge +3, \\
  320. x \cdot (x + 3) /6 & \text{otherwise}
  321. \end{cases}
  322. Args:
  323. inplace: can optionally do the operation in-place. Default: ``False``
  324. Shape:
  325. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  326. - Output: :math:`(*)`, same shape as the input.
  327. .. image:: ../scripts/activation_images/Hardswish.png
  328. Examples::
  329. >>> m = nn.Hardswish()
  330. >>> input = torch.randn(2)
  331. >>> output = m(input)
  332. .. _`Searching for MobileNetV3`:
  333. https://arxiv.org/abs/1905.02244
  334. """
  335. __constants__ = ['inplace']
  336. inplace: bool
  337. def __init__(self, inplace : bool = False) -> None:
  338. super(Hardswish, self).__init__()
  339. self.inplace = inplace
  340. def forward(self, input: Tensor) -> Tensor:
  341. return F.hardswish(input, self.inplace)
  342. class ELU(Module):
  343. r"""Applies the Exponential Linear Unit (ELU) function, element-wise, as described
  344. in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear
  345. Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.
  346. ELU is defined as:
  347. .. math::
  348. \text{ELU}(x) = \begin{cases}
  349. x, & \text{ if } x > 0\\
  350. \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
  351. \end{cases}
  352. Args:
  353. alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
  354. inplace: can optionally do the operation in-place. Default: ``False``
  355. Shape:
  356. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  357. - Output: :math:`(*)`, same shape as the input.
  358. .. image:: ../scripts/activation_images/ELU.png
  359. Examples::
  360. >>> m = nn.ELU()
  361. >>> input = torch.randn(2)
  362. >>> output = m(input)
  363. """
  364. __constants__ = ['alpha', 'inplace']
  365. alpha: float
  366. inplace: bool
  367. def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
  368. super(ELU, self).__init__()
  369. self.alpha = alpha
  370. self.inplace = inplace
  371. def forward(self, input: Tensor) -> Tensor:
  372. return F.elu(input, self.alpha, self.inplace)
  373. def extra_repr(self) -> str:
  374. inplace_str = ', inplace=True' if self.inplace else ''
  375. return 'alpha={}{}'.format(self.alpha, inplace_str)
  376. class CELU(Module):
  377. r"""Applies the element-wise function:
  378. .. math::
  379. \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))
  380. More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .
  381. Args:
  382. alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
  383. inplace: can optionally do the operation in-place. Default: ``False``
  384. Shape:
  385. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  386. - Output: :math:`(*)`, same shape as the input.
  387. .. image:: ../scripts/activation_images/CELU.png
  388. Examples::
  389. >>> m = nn.CELU()
  390. >>> input = torch.randn(2)
  391. >>> output = m(input)
  392. .. _`Continuously Differentiable Exponential Linear Units`:
  393. https://arxiv.org/abs/1704.07483
  394. """
  395. __constants__ = ['alpha', 'inplace']
  396. alpha: float
  397. inplace: bool
  398. def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
  399. super(CELU, self).__init__()
  400. self.alpha = alpha
  401. self.inplace = inplace
  402. def forward(self, input: Tensor) -> Tensor:
  403. return F.celu(input, self.alpha, self.inplace)
  404. def extra_repr(self) -> str:
  405. inplace_str = ', inplace=True' if self.inplace else ''
  406. return 'alpha={}{}'.format(self.alpha, inplace_str)
  407. class SELU(Module):
  408. r"""Applied element-wise, as:
  409. .. math::
  410. \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))
  411. with :math:`\alpha = 1.6732632423543772848170429916717` and
  412. :math:`\text{scale} = 1.0507009873554804934193349852946`.
  413. .. warning::
  414. When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation,
  415. ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'``
  416. in order to get `Self-Normalizing Neural Networks`_.
  417. See :func:`torch.nn.init.calculate_gain` for more information.
  418. More details can be found in the paper `Self-Normalizing Neural Networks`_ .
  419. Args:
  420. inplace (bool, optional): can optionally do the operation in-place. Default: ``False``
  421. Shape:
  422. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  423. - Output: :math:`(*)`, same shape as the input.
  424. .. image:: ../scripts/activation_images/SELU.png
  425. Examples::
  426. >>> m = nn.SELU()
  427. >>> input = torch.randn(2)
  428. >>> output = m(input)
  429. .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
  430. """
  431. __constants__ = ['inplace']
  432. inplace: bool
  433. def __init__(self, inplace: bool = False) -> None:
  434. super(SELU, self).__init__()
  435. self.inplace = inplace
  436. def forward(self, input: Tensor) -> Tensor:
  437. return F.selu(input, self.inplace)
  438. def extra_repr(self) -> str:
  439. inplace_str = 'inplace=True' if self.inplace else ''
  440. return inplace_str
  441. class GLU(Module):
  442. r"""Applies the gated linear unit function
  443. :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
  444. of the input matrices and :math:`b` is the second half.
  445. Args:
  446. dim (int): the dimension on which to split the input. Default: -1
  447. Shape:
  448. - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
  449. dimensions
  450. - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
  451. Examples::
  452. >>> m = nn.GLU()
  453. >>> input = torch.randn(4, 2)
  454. >>> output = m(input)
  455. """
  456. __constants__ = ['dim']
  457. dim: int
  458. def __init__(self, dim: int = -1) -> None:
  459. super(GLU, self).__init__()
  460. self.dim = dim
  461. def forward(self, input: Tensor) -> Tensor:
  462. return F.glu(input, self.dim)
  463. def extra_repr(self) -> str:
  464. return 'dim={}'.format(self.dim)
  465. class GELU(Module):
  466. r"""Applies the Gaussian Error Linear Units function:
  467. .. math:: \text{GELU}(x) = x * \Phi(x)
  468. where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
  469. When the approximate argument is 'tanh', Gelu is estimated with:
  470. :math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt(2 / \pi) * (x + 0.044715 * x^3)))
  471. Args:
  472. approximate (string, optional): the gelu approximation algorithm to use:
  473. ``'none'`` | ``'tanh'``. Default: ``'none'``
  474. Shape:
  475. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  476. - Output: :math:`(*)`, same shape as the input.
  477. .. image:: ../scripts/activation_images/GELU.png
  478. Examples::
  479. >>> m = nn.GELU()
  480. >>> input = torch.randn(2)
  481. >>> output = m(input)
  482. """
  483. __constants__ = ['approximate']
  484. approximate: str
  485. def __init__(self, approximate: str = 'none') -> None:
  486. super(GELU, self).__init__()
  487. self.approximate = approximate
  488. def forward(self, input: Tensor) -> Tensor:
  489. return F.gelu(input, approximate=self.approximate)
  490. def extra_repr(self) -> str:
  491. return 'approximate={}'.format(self.approximate)
  492. class Hardshrink(Module):
  493. r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
  494. Hardshrink is defined as:
  495. .. math::
  496. \text{HardShrink}(x) =
  497. \begin{cases}
  498. x, & \text{ if } x > \lambda \\
  499. x, & \text{ if } x < -\lambda \\
  500. 0, & \text{ otherwise }
  501. \end{cases}
  502. Args:
  503. lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
  504. Shape:
  505. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  506. - Output: :math:`(*)`, same shape as the input.
  507. .. image:: ../scripts/activation_images/Hardshrink.png
  508. Examples::
  509. >>> m = nn.Hardshrink()
  510. >>> input = torch.randn(2)
  511. >>> output = m(input)
  512. """
  513. __constants__ = ['lambd']
  514. lambd: float
  515. def __init__(self, lambd: float = 0.5) -> None:
  516. super(Hardshrink, self).__init__()
  517. self.lambd = lambd
  518. def forward(self, input: Tensor) -> Tensor:
  519. return F.hardshrink(input, self.lambd)
  520. def extra_repr(self) -> str:
  521. return '{}'.format(self.lambd)
  522. class LeakyReLU(Module):
  523. r"""Applies the element-wise function:
  524. .. math::
  525. \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)
  526. or
  527. .. math::
  528. \text{LeakyRELU}(x) =
  529. \begin{cases}
  530. x, & \text{ if } x \geq 0 \\
  531. \text{negative\_slope} \times x, & \text{ otherwise }
  532. \end{cases}
  533. Args:
  534. negative_slope: Controls the angle of the negative slope. Default: 1e-2
  535. inplace: can optionally do the operation in-place. Default: ``False``
  536. Shape:
  537. - Input: :math:`(*)` where `*` means, any number of additional
  538. dimensions
  539. - Output: :math:`(*)`, same shape as the input
  540. .. image:: ../scripts/activation_images/LeakyReLU.png
  541. Examples::
  542. >>> m = nn.LeakyReLU(0.1)
  543. >>> input = torch.randn(2)
  544. >>> output = m(input)
  545. """
  546. __constants__ = ['inplace', 'negative_slope']
  547. inplace: bool
  548. negative_slope: float
  549. def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
  550. super(LeakyReLU, self).__init__()
  551. self.negative_slope = negative_slope
  552. self.inplace = inplace
  553. def forward(self, input: Tensor) -> Tensor:
  554. return F.leaky_relu(input, self.negative_slope, self.inplace)
  555. def extra_repr(self) -> str:
  556. inplace_str = ', inplace=True' if self.inplace else ''
  557. return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)
  558. class LogSigmoid(Module):
  559. r"""Applies the element-wise function:
  560. .. math::
  561. \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)
  562. Shape:
  563. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  564. - Output: :math:`(*)`, same shape as the input.
  565. .. image:: ../scripts/activation_images/LogSigmoid.png
  566. Examples::
  567. >>> m = nn.LogSigmoid()
  568. >>> input = torch.randn(2)
  569. >>> output = m(input)
  570. """
  571. def forward(self, input: Tensor) -> Tensor:
  572. return F.logsigmoid(input)
  573. class Softplus(Module):
  574. r"""Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} *
  575. \log(1 + \exp(\beta * x))` element-wise.
  576. SoftPlus is a smooth approximation to the ReLU function and can be used
  577. to constrain the output of a machine to always be positive.
  578. For numerical stability the implementation reverts to the linear function
  579. when :math:`input \times \beta > threshold`.
  580. Args:
  581. beta: the :math:`\beta` value for the Softplus formulation. Default: 1
  582. threshold: values above this revert to a linear function. Default: 20
  583. Shape:
  584. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  585. - Output: :math:`(*)`, same shape as the input.
  586. .. image:: ../scripts/activation_images/Softplus.png
  587. Examples::
  588. >>> m = nn.Softplus()
  589. >>> input = torch.randn(2)
  590. >>> output = m(input)
  591. """
  592. __constants__ = ['beta', 'threshold']
  593. beta: int
  594. threshold: int
  595. def __init__(self, beta: int = 1, threshold: int = 20) -> None:
  596. super(Softplus, self).__init__()
  597. self.beta = beta
  598. self.threshold = threshold
  599. def forward(self, input: Tensor) -> Tensor:
  600. return F.softplus(input, self.beta, self.threshold)
  601. def extra_repr(self) -> str:
  602. return 'beta={}, threshold={}'.format(self.beta, self.threshold)
  603. class Softshrink(Module):
  604. r"""Applies the soft shrinkage function elementwise:
  605. .. math::
  606. \text{SoftShrinkage}(x) =
  607. \begin{cases}
  608. x - \lambda, & \text{ if } x > \lambda \\
  609. x + \lambda, & \text{ if } x < -\lambda \\
  610. 0, & \text{ otherwise }
  611. \end{cases}
  612. Args:
  613. lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5
  614. Shape:
  615. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  616. - Output: :math:`(*)`, same shape as the input.
  617. .. image:: ../scripts/activation_images/Softshrink.png
  618. Examples::
  619. >>> m = nn.Softshrink()
  620. >>> input = torch.randn(2)
  621. >>> output = m(input)
  622. """
  623. __constants__ = ['lambd']
  624. lambd: float
  625. def __init__(self, lambd: float = 0.5) -> None:
  626. super(Softshrink, self).__init__()
  627. self.lambd = lambd
  628. def forward(self, input: Tensor) -> Tensor:
  629. return F.softshrink(input, self.lambd)
  630. def extra_repr(self) -> str:
  631. return str(self.lambd)
  632. class MultiheadAttention(Module):
  633. r"""Allows the model to jointly attend to information
  634. from different representation subspaces as described in the paper:
  635. `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
  636. Multi-Head Attention is defined as:
  637. .. math::
  638. \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
  639. where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
  640. ``forward()`` will use a special optimized implementation if all of the following
  641. conditions are met:
  642. - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
  643. restriction will be loosened in the future.)
  644. - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
  645. - training is disabled (using ``.eval()``)
  646. - dropout is 0
  647. - ``add_bias_kv`` is ``False``
  648. - ``add_zero_attn`` is ``False``
  649. - ``batch_first`` is ``True`` and the input is batched
  650. - ``kdim`` and ``vdim`` are equal to ``embed_dim``
  651. - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
  652. - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
  653. nor ``attn_mask`` is passed
  654. If the optimized implementation is in use, a
  655. `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
  656. ``query``/``key``/``value`` to represent padding more efficiently than using a
  657. padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
  658. will be returned, and an additional speedup proportional to the fraction of the input
  659. that is padding can be expected.
  660. Args:
  661. embed_dim: Total dimension of the model.
  662. num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
  663. across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
  664. dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
  665. bias: If specified, adds bias to input / output projection layers. Default: ``True``.
  666. add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
  667. add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
  668. Default: ``False``.
  669. kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
  670. vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
  671. batch_first: If ``True``, then the input and output tensors are provided
  672. as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
  673. Examples::
  674. >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
  675. >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
  676. """
  677. __constants__ = ['batch_first']
  678. bias_k: Optional[torch.Tensor]
  679. bias_v: Optional[torch.Tensor]
  680. def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
  681. kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
  682. factory_kwargs = {'device': device, 'dtype': dtype}
  683. super(MultiheadAttention, self).__init__()
  684. self.embed_dim = embed_dim
  685. self.kdim = kdim if kdim is not None else embed_dim
  686. self.vdim = vdim if vdim is not None else embed_dim
  687. self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
  688. self.num_heads = num_heads
  689. self.dropout = dropout
  690. self.batch_first = batch_first
  691. self.head_dim = embed_dim // num_heads
  692. assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
  693. if self._qkv_same_embed_dim is False:
  694. self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
  695. self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
  696. self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
  697. self.register_parameter('in_proj_weight', None)
  698. else:
  699. self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
  700. self.register_parameter('q_proj_weight', None)
  701. self.register_parameter('k_proj_weight', None)
  702. self.register_parameter('v_proj_weight', None)
  703. if bias:
  704. self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
  705. else:
  706. self.register_parameter('in_proj_bias', None)
  707. self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
  708. if add_bias_kv:
  709. self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
  710. self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
  711. else:
  712. self.bias_k = self.bias_v = None
  713. self.add_zero_attn = add_zero_attn
  714. self._reset_parameters()
  715. def _reset_parameters(self):
  716. if self._qkv_same_embed_dim:
  717. xavier_uniform_(self.in_proj_weight)
  718. else:
  719. xavier_uniform_(self.q_proj_weight)
  720. xavier_uniform_(self.k_proj_weight)
  721. xavier_uniform_(self.v_proj_weight)
  722. if self.in_proj_bias is not None:
  723. constant_(self.in_proj_bias, 0.)
  724. constant_(self.out_proj.bias, 0.)
  725. if self.bias_k is not None:
  726. xavier_normal_(self.bias_k)
  727. if self.bias_v is not None:
  728. xavier_normal_(self.bias_v)
  729. def __setstate__(self, state):
  730. # Support loading old MultiheadAttention checkpoints generated by v1.1.0
  731. if '_qkv_same_embed_dim' not in state:
  732. state['_qkv_same_embed_dim'] = True
  733. super(MultiheadAttention, self).__setstate__(state)
  734. def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
  735. need_weights: bool = True, attn_mask: Optional[Tensor] = None,
  736. average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
  737. r"""
  738. Args:
  739. query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
  740. or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
  741. :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
  742. Queries are compared against key-value pairs to produce the output.
  743. See "Attention Is All You Need" for more details.
  744. key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
  745. or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
  746. :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
  747. See "Attention Is All You Need" for more details.
  748. value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
  749. ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
  750. sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
  751. See "Attention Is All You Need" for more details.
  752. key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
  753. to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
  754. Binary and byte masks are supported.
  755. For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
  756. the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding ``key``
  757. value will be ignored.
  758. need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
  759. Default: ``True``.
  760. attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
  761. :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
  762. :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
  763. broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
  764. Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
  765. corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
  766. corresponding position is not allowed to attend. For a float mask, the mask values will be added to
  767. the attention weight.
  768. average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
  769. heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
  770. effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
  771. Outputs:
  772. - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
  773. :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
  774. where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
  775. embedding dimension ``embed_dim``.
  776. - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
  777. returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
  778. :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
  779. :math:`S` is the source sequence length. If ``average_weights=False``, returns attention weights per
  780. head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
  781. .. note::
  782. `batch_first` argument is ignored for unbatched inputs.
  783. """
  784. is_batched = query.dim() == 3
  785. why_not_fast_path = ''
  786. if not is_batched:
  787. why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
  788. elif query is not key or key is not value:
  789. # When lifting this restriction, don't forget to either
  790. # enforce that the dtypes all match or test cases where
  791. # they don't!
  792. why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
  793. elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
  794. why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
  795. elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
  796. # this case will fail anyway, but at least they'll get a useful error message.
  797. why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
  798. elif self.training:
  799. why_not_fast_path = "training is enabled"
  800. elif not self.batch_first:
  801. why_not_fast_path = "batch_first was not True"
  802. elif self.bias_k is not None:
  803. why_not_fast_path = "self.bias_k was not None"
  804. elif self.bias_v is not None:
  805. why_not_fast_path = "self.bias_v was not None"
  806. elif self.dropout:
  807. why_not_fast_path = f"dropout was {self.dropout}, required zero"
  808. elif self.add_zero_attn:
  809. why_not_fast_path = "add_zero_attn was enabled"
  810. elif not self._qkv_same_embed_dim:
  811. why_not_fast_path = "_qkv_same_embed_dim was not True"
  812. elif attn_mask is not None:
  813. why_not_fast_path = "attn_mask was not None"
  814. elif query.is_nested and key_padding_mask is not None:
  815. why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
  816. if not why_not_fast_path:
  817. tensor_args = (
  818. query,
  819. key,
  820. value,
  821. self.in_proj_weight,
  822. self.in_proj_bias,
  823. self.out_proj.weight,
  824. self.out_proj.bias,
  825. )
  826. # We have to use list comprehensions below because TorchScript does not support
  827. # generator expressions.
  828. if torch.overrides.has_torch_function(tensor_args):
  829. why_not_fast_path = "some Tensor argument has_torch_function"
  830. elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
  831. why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
  832. elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]):
  833. why_not_fast_path = ("grad is enabled and at least one of query or the "
  834. "input/output projection weights or biases requires_grad")
  835. if not why_not_fast_path:
  836. return torch._native_multi_head_attention(
  837. query,
  838. key,
  839. value,
  840. self.embed_dim,
  841. self.num_heads,
  842. self.in_proj_weight,
  843. self.in_proj_bias,
  844. self.out_proj.weight,
  845. self.out_proj.bias,
  846. key_padding_mask if key_padding_mask is not None else attn_mask,
  847. need_weights,
  848. average_attn_weights)
  849. any_nested = query.is_nested or key.is_nested or value.is_nested
  850. assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
  851. f"The fast path was not hit because {why_not_fast_path}")
  852. if self.batch_first and is_batched:
  853. # make sure that the transpose op does not affect the "is" property
  854. if key is value:
  855. if query is key:
  856. query = key = value = query.transpose(1, 0)
  857. else:
  858. query, key = [x.transpose(1, 0) for x in (query, key)]
  859. value = key
  860. else:
  861. query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
  862. if not self._qkv_same_embed_dim:
  863. attn_output, attn_output_weights = F.multi_head_attention_forward(
  864. query, key, value, self.embed_dim, self.num_heads,
  865. self.in_proj_weight, self.in_proj_bias,
  866. self.bias_k, self.bias_v, self.add_zero_attn,
  867. self.dropout, self.out_proj.weight, self.out_proj.bias,
  868. training=self.training,
  869. key_padding_mask=key_padding_mask, need_weights=need_weights,
  870. attn_mask=attn_mask, use_separate_proj_weight=True,
  871. q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
  872. v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights)
  873. else:
  874. attn_output, attn_output_weights = F.multi_head_attention_forward(
  875. query, key, value, self.embed_dim, self.num_heads,
  876. self.in_proj_weight, self.in_proj_bias,
  877. self.bias_k, self.bias_v, self.add_zero_attn,
  878. self.dropout, self.out_proj.weight, self.out_proj.bias,
  879. training=self.training,
  880. key_padding_mask=key_padding_mask, need_weights=need_weights,
  881. attn_mask=attn_mask, average_attn_weights=average_attn_weights)
  882. if self.batch_first and is_batched:
  883. return attn_output.transpose(1, 0), attn_output_weights
  884. else:
  885. return attn_output, attn_output_weights
  886. class PReLU(Module):
  887. r"""Applies the element-wise function:
  888. .. math::
  889. \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
  890. or
  891. .. math::
  892. \text{PReLU}(x) =
  893. \begin{cases}
  894. x, & \text{ if } x \geq 0 \\
  895. ax, & \text{ otherwise }
  896. \end{cases}
  897. Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
  898. parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
  899. a separate :math:`a` is used for each input channel.
  900. .. note::
  901. weight decay should not be used when learning :math:`a` for good performance.
  902. .. note::
  903. Channel dim is the 2nd dim of input. When input has dims < 2, then there is
  904. no channel dim and the number of channels = 1.
  905. Args:
  906. num_parameters (int): number of :math:`a` to learn.
  907. Although it takes an int as input, there is only two values are legitimate:
  908. 1, or the number of channels at input. Default: 1
  909. init (float): the initial value of :math:`a`. Default: 0.25
  910. Shape:
  911. - Input: :math:`( *)` where `*` means, any number of additional
  912. dimensions.
  913. - Output: :math:`(*)`, same shape as the input.
  914. Attributes:
  915. weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).
  916. .. image:: ../scripts/activation_images/PReLU.png
  917. Examples::
  918. >>> m = nn.PReLU()
  919. >>> input = torch.randn(2)
  920. >>> output = m(input)
  921. """
  922. __constants__ = ['num_parameters']
  923. num_parameters: int
  924. def __init__(self, num_parameters: int = 1, init: float = 0.25,
  925. device=None, dtype=None) -> None:
  926. factory_kwargs = {'device': device, 'dtype': dtype}
  927. self.num_parameters = num_parameters
  928. super(PReLU, self).__init__()
  929. self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs).fill_(init))
  930. def forward(self, input: Tensor) -> Tensor:
  931. return F.prelu(input, self.weight)
  932. def extra_repr(self) -> str:
  933. return 'num_parameters={}'.format(self.num_parameters)
  934. class Softsign(Module):
  935. r"""Applies the element-wise function:
  936. .. math::
  937. \text{SoftSign}(x) = \frac{x}{ 1 + |x|}
  938. Shape:
  939. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  940. - Output: :math:`(*)`, same shape as the input.
  941. .. image:: ../scripts/activation_images/Softsign.png
  942. Examples::
  943. >>> m = nn.Softsign()
  944. >>> input = torch.randn(2)
  945. >>> output = m(input)
  946. """
  947. def forward(self, input: Tensor) -> Tensor:
  948. return F.softsign(input)
  949. class Tanhshrink(Module):
  950. r"""Applies the element-wise function:
  951. .. math::
  952. \text{Tanhshrink}(x) = x - \tanh(x)
  953. Shape:
  954. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  955. - Output: :math:`(*)`, same shape as the input.
  956. .. image:: ../scripts/activation_images/Tanhshrink.png
  957. Examples::
  958. >>> m = nn.Tanhshrink()
  959. >>> input = torch.randn(2)
  960. >>> output = m(input)
  961. """
  962. def forward(self, input: Tensor) -> Tensor:
  963. return F.tanhshrink(input)
  964. class Softmin(Module):
  965. r"""Applies the Softmin function to an n-dimensional input Tensor
  966. rescaling them so that the elements of the n-dimensional output Tensor
  967. lie in the range `[0, 1]` and sum to 1.
  968. Softmin is defined as:
  969. .. math::
  970. \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
  971. Shape:
  972. - Input: :math:`(*)` where `*` means, any number of additional
  973. dimensions
  974. - Output: :math:`(*)`, same shape as the input
  975. Args:
  976. dim (int): A dimension along which Softmin will be computed (so every slice
  977. along dim will sum to 1).
  978. Returns:
  979. a Tensor of the same dimension and shape as the input, with
  980. values in the range [0, 1]
  981. Examples::
  982. >>> m = nn.Softmin()
  983. >>> input = torch.randn(2, 3)
  984. >>> output = m(input)
  985. """
  986. __constants__ = ['dim']
  987. dim: Optional[int]
  988. def __init__(self, dim: Optional[int] = None) -> None:
  989. super(Softmin, self).__init__()
  990. self.dim = dim
  991. def __setstate__(self, state):
  992. super().__setstate__(state)
  993. if not hasattr(self, 'dim'):
  994. self.dim = None
  995. def forward(self, input: Tensor) -> Tensor:
  996. return F.softmin(input, self.dim, _stacklevel=5)
  997. def extra_repr(self):
  998. return 'dim={dim}'.format(dim=self.dim)
  999. class Softmax(Module):
  1000. r"""Applies the Softmax function to an n-dimensional input Tensor
  1001. rescaling them so that the elements of the n-dimensional output Tensor
  1002. lie in the range [0,1] and sum to 1.
  1003. Softmax is defined as:
  1004. .. math::
  1005. \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
  1006. When the input Tensor is a sparse tensor then the unspecifed
  1007. values are treated as ``-inf``.
  1008. Shape:
  1009. - Input: :math:`(*)` where `*` means, any number of additional
  1010. dimensions
  1011. - Output: :math:`(*)`, same shape as the input
  1012. Returns:
  1013. a Tensor of the same dimension and shape as the input with
  1014. values in the range [0, 1]
  1015. Args:
  1016. dim (int): A dimension along which Softmax will be computed (so every slice
  1017. along dim will sum to 1).
  1018. .. note::
  1019. This module doesn't work directly with NLLLoss,
  1020. which expects the Log to be computed between the Softmax and itself.
  1021. Use `LogSoftmax` instead (it's faster and has better numerical properties).
  1022. Examples::
  1023. >>> m = nn.Softmax(dim=1)
  1024. >>> input = torch.randn(2, 3)
  1025. >>> output = m(input)
  1026. """
  1027. __constants__ = ['dim']
  1028. dim: Optional[int]
  1029. def __init__(self, dim: Optional[int] = None) -> None:
  1030. super(Softmax, self).__init__()
  1031. self.dim = dim
  1032. def __setstate__(self, state):
  1033. super().__setstate__(state)
  1034. if not hasattr(self, 'dim'):
  1035. self.dim = None
  1036. def forward(self, input: Tensor) -> Tensor:
  1037. return F.softmax(input, self.dim, _stacklevel=5)
  1038. def extra_repr(self) -> str:
  1039. return 'dim={dim}'.format(dim=self.dim)
  1040. class Softmax2d(Module):
  1041. r"""Applies SoftMax over features to each spatial location.
  1042. When given an image of ``Channels x Height x Width``, it will
  1043. apply `Softmax` to each location :math:`(Channels, h_i, w_j)`
  1044. Shape:
  1045. - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
  1046. - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
  1047. Returns:
  1048. a Tensor of the same dimension and shape as the input with
  1049. values in the range [0, 1]
  1050. Examples::
  1051. >>> m = nn.Softmax2d()
  1052. >>> # you softmax over the 2nd dimension
  1053. >>> input = torch.randn(2, 3, 12, 13)
  1054. >>> output = m(input)
  1055. """
  1056. def forward(self, input: Tensor) -> Tensor:
  1057. assert input.dim() == 4 or input.dim() == 3, 'Softmax2d requires a 3D or 4D tensor as input'
  1058. return F.softmax(input, -3, _stacklevel=5)
  1059. class LogSoftmax(Module):
  1060. r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
  1061. input Tensor. The LogSoftmax formulation can be simplified as:
  1062. .. math::
  1063. \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
  1064. Shape:
  1065. - Input: :math:`(*)` where `*` means, any number of additional
  1066. dimensions
  1067. - Output: :math:`(*)`, same shape as the input
  1068. Args:
  1069. dim (int): A dimension along which LogSoftmax will be computed.
  1070. Returns:
  1071. a Tensor of the same dimension and shape as the input with
  1072. values in the range [-inf, 0)
  1073. Examples::
  1074. >>> m = nn.LogSoftmax()
  1075. >>> input = torch.randn(2, 3)
  1076. >>> output = m(input)
  1077. """
  1078. __constants__ = ['dim']
  1079. dim: Optional[int]
  1080. def __init__(self, dim: Optional[int] = None) -> None:
  1081. super(LogSoftmax, self).__init__()
  1082. self.dim = dim
  1083. def __setstate__(self, state):
  1084. super().__setstate__(state)
  1085. if not hasattr(self, 'dim'):
  1086. self.dim = None
  1087. def forward(self, input: Tensor) -> Tensor:
  1088. return F.log_softmax(input, self.dim, _stacklevel=5)
  1089. def extra_repr(self):
  1090. return 'dim={dim}'.format(dim=self.dim)