dropout.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. from .module import Module
  2. from .. import functional as F
  3. from torch import Tensor
  4. class _DropoutNd(Module):
  5. __constants__ = ['p', 'inplace']
  6. p: float
  7. inplace: bool
  8. def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
  9. super(_DropoutNd, self).__init__()
  10. if p < 0 or p > 1:
  11. raise ValueError("dropout probability has to be between 0 and 1, "
  12. "but got {}".format(p))
  13. self.p = p
  14. self.inplace = inplace
  15. def extra_repr(self) -> str:
  16. return 'p={}, inplace={}'.format(self.p, self.inplace)
  17. class Dropout(_DropoutNd):
  18. r"""During training, randomly zeroes some of the elements of the input
  19. tensor with probability :attr:`p` using samples from a Bernoulli
  20. distribution. Each channel will be zeroed out independently on every forward
  21. call.
  22. This has proven to be an effective technique for regularization and
  23. preventing the co-adaptation of neurons as described in the paper
  24. `Improving neural networks by preventing co-adaptation of feature
  25. detectors`_ .
  26. Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during
  27. training. This means that during evaluation the module simply computes an
  28. identity function.
  29. Args:
  30. p: probability of an element to be zeroed. Default: 0.5
  31. inplace: If set to ``True``, will do this operation in-place. Default: ``False``
  32. Shape:
  33. - Input: :math:`(*)`. Input can be of any shape
  34. - Output: :math:`(*)`. Output is of the same shape as input
  35. Examples::
  36. >>> m = nn.Dropout(p=0.2)
  37. >>> input = torch.randn(20, 16)
  38. >>> output = m(input)
  39. .. _Improving neural networks by preventing co-adaptation of feature
  40. detectors: https://arxiv.org/abs/1207.0580
  41. """
  42. def forward(self, input: Tensor) -> Tensor:
  43. return F.dropout(input, self.p, self.training, self.inplace)
  44. class Dropout1d(_DropoutNd):
  45. r"""Randomly zero out entire channels (a channel is a 1D feature map,
  46. e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
  47. batched input is a 1D tensor :math:`\text{input}[i, j]`).
  48. Each channel will be zeroed out independently on every forward call with
  49. probability :attr:`p` using samples from a Bernoulli distribution.
  50. Usually the input comes from :class:`nn.Conv1d` modules.
  51. As described in the paper
  52. `Efficient Object Localization Using Convolutional Networks`_ ,
  53. if adjacent pixels within feature maps are strongly correlated
  54. (as is normally the case in early convolution layers) then i.i.d. dropout
  55. will not regularize the activations and will otherwise just result
  56. in an effective learning rate decrease.
  57. In this case, :func:`nn.Dropout1d` will help promote independence between
  58. feature maps and should be used instead.
  59. Args:
  60. p (float, optional): probability of an element to be zero-ed.
  61. inplace (bool, optional): If set to ``True``, will do this operation
  62. in-place
  63. Shape:
  64. - Input: :math:`(N, C, L)` or :math:`(C, L)`.
  65. - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).
  66. Examples::
  67. >>> m = nn.Dropout1d(p=0.2)
  68. >>> input = torch.randn(20, 16, 32)
  69. >>> output = m(input)
  70. .. _Efficient Object Localization Using Convolutional Networks:
  71. https://arxiv.org/abs/1411.4280
  72. """
  73. def forward(self, input: Tensor) -> Tensor:
  74. return F.dropout1d(input, self.p, self.training, self.inplace)
  75. class Dropout2d(_DropoutNd):
  76. r"""Randomly zero out entire channels (a channel is a 2D feature map,
  77. e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
  78. batched input is a 2D tensor :math:`\text{input}[i, j]`).
  79. Each channel will be zeroed out independently on every forward call with
  80. probability :attr:`p` using samples from a Bernoulli distribution.
  81. Usually the input comes from :class:`nn.Conv2d` modules.
  82. As described in the paper
  83. `Efficient Object Localization Using Convolutional Networks`_ ,
  84. if adjacent pixels within feature maps are strongly correlated
  85. (as is normally the case in early convolution layers) then i.i.d. dropout
  86. will not regularize the activations and will otherwise just result
  87. in an effective learning rate decrease.
  88. In this case, :func:`nn.Dropout2d` will help promote independence between
  89. feature maps and should be used instead.
  90. Args:
  91. p (float, optional): probability of an element to be zero-ed.
  92. inplace (bool, optional): If set to ``True``, will do this operation
  93. in-place
  94. .. warning ::
  95. Due to historical reasons, this class will perform 1D channel-wise dropout
  96. for 3D inputs (as done by :class:`nn.Dropout1d`). Thus, it currently does NOT
  97. support inputs without a batch dimension of shape :math:`(C, H, W)`. This
  98. behavior will change in a future release to interpret 3D inputs as no-batch-dim
  99. inputs. To maintain the old behavior, switch to :class:`nn.Dropout1d`.
  100. Shape:
  101. - Input: :math:`(N, C, H, W)` or :math:`(N, C, L)`.
  102. - Output: :math:`(N, C, H, W)` or :math:`(N, C, L)` (same shape as input).
  103. Examples::
  104. >>> m = nn.Dropout2d(p=0.2)
  105. >>> input = torch.randn(20, 16, 32, 32)
  106. >>> output = m(input)
  107. .. _Efficient Object Localization Using Convolutional Networks:
  108. https://arxiv.org/abs/1411.4280
  109. """
  110. def forward(self, input: Tensor) -> Tensor:
  111. return F.dropout2d(input, self.p, self.training, self.inplace)
  112. class Dropout3d(_DropoutNd):
  113. r"""Randomly zero out entire channels (a channel is a 3D feature map,
  114. e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
  115. batched input is a 3D tensor :math:`\text{input}[i, j]`).
  116. Each channel will be zeroed out independently on every forward call with
  117. probability :attr:`p` using samples from a Bernoulli distribution.
  118. Usually the input comes from :class:`nn.Conv3d` modules.
  119. As described in the paper
  120. `Efficient Object Localization Using Convolutional Networks`_ ,
  121. if adjacent pixels within feature maps are strongly correlated
  122. (as is normally the case in early convolution layers) then i.i.d. dropout
  123. will not regularize the activations and will otherwise just result
  124. in an effective learning rate decrease.
  125. In this case, :func:`nn.Dropout3d` will help promote independence between
  126. feature maps and should be used instead.
  127. Args:
  128. p (float, optional): probability of an element to be zeroed.
  129. inplace (bool, optional): If set to ``True``, will do this operation
  130. in-place
  131. Shape:
  132. - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
  133. - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
  134. Examples::
  135. >>> m = nn.Dropout3d(p=0.2)
  136. >>> input = torch.randn(20, 16, 4, 32, 32)
  137. >>> output = m(input)
  138. .. _Efficient Object Localization Using Convolutional Networks:
  139. https://arxiv.org/abs/1411.4280
  140. """
  141. def forward(self, input: Tensor) -> Tensor:
  142. return F.dropout3d(input, self.p, self.training, self.inplace)
  143. class AlphaDropout(_DropoutNd):
  144. r"""Applies Alpha Dropout over the input.
  145. Alpha Dropout is a type of Dropout that maintains the self-normalizing
  146. property.
  147. For an input with zero mean and unit standard deviation, the output of
  148. Alpha Dropout maintains the original mean and standard deviation of the
  149. input.
  150. Alpha Dropout goes hand-in-hand with SELU activation function, which ensures
  151. that the outputs have zero mean and unit standard deviation.
  152. During training, it randomly masks some of the elements of the input
  153. tensor with probability *p* using samples from a bernoulli distribution.
  154. The elements to masked are randomized on every forward call, and scaled
  155. and shifted to maintain zero mean and unit standard deviation.
  156. During evaluation the module simply computes an identity function.
  157. More details can be found in the paper `Self-Normalizing Neural Networks`_ .
  158. Args:
  159. p (float): probability of an element to be dropped. Default: 0.5
  160. inplace (bool, optional): If set to ``True``, will do this operation
  161. in-place
  162. Shape:
  163. - Input: :math:`(*)`. Input can be of any shape
  164. - Output: :math:`(*)`. Output is of the same shape as input
  165. Examples::
  166. >>> m = nn.AlphaDropout(p=0.2)
  167. >>> input = torch.randn(20, 16)
  168. >>> output = m(input)
  169. .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
  170. """
  171. def forward(self, input: Tensor) -> Tensor:
  172. return F.alpha_dropout(input, self.p, self.training)
  173. class FeatureAlphaDropout(_DropoutNd):
  174. r"""Randomly masks out entire channels (a channel is a feature map,
  175. e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input
  176. is a tensor :math:`\text{input}[i, j]`) of the input tensor). Instead of
  177. setting activations to zero, as in regular Dropout, the activations are set
  178. to the negative saturation value of the SELU activation function. More details
  179. can be found in the paper `Self-Normalizing Neural Networks`_ .
  180. Each element will be masked independently for each sample on every forward
  181. call with probability :attr:`p` using samples from a Bernoulli distribution.
  182. The elements to be masked are randomized on every forward call, and scaled
  183. and shifted to maintain zero mean and unit variance.
  184. Usually the input comes from :class:`nn.AlphaDropout` modules.
  185. As described in the paper
  186. `Efficient Object Localization Using Convolutional Networks`_ ,
  187. if adjacent pixels within feature maps are strongly correlated
  188. (as is normally the case in early convolution layers) then i.i.d. dropout
  189. will not regularize the activations and will otherwise just result
  190. in an effective learning rate decrease.
  191. In this case, :func:`nn.AlphaDropout` will help promote independence between
  192. feature maps and should be used instead.
  193. Args:
  194. p (float, optional): probability of an element to be zeroed. Default: 0.5
  195. inplace (bool, optional): If set to ``True``, will do this operation
  196. in-place
  197. Shape:
  198. - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
  199. - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
  200. Examples::
  201. >>> m = nn.FeatureAlphaDropout(p=0.2)
  202. >>> input = torch.randn(20, 16, 4, 32, 32)
  203. >>> output = m(input)
  204. .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
  205. .. _Efficient Object Localization Using Convolutional Networks:
  206. https://arxiv.org/abs/1411.4280
  207. """
  208. def forward(self, input: Tensor) -> Tensor:
  209. return F.feature_alpha_dropout(input, self.p, self.training)