functional.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615
  1. r""" Functional interface (quantized)."""
  import warnings
  from typing import List, Optional

  import torch
  from torch import Tensor
  from torch.jit.annotations import BroadcastingList2, BroadcastingList3
  from torch.nn.modules.utils import _pair, _triple
  from torch.nn.quantized.modules.utils import _pair_from_first
  9. # Although some of the functions and docstrings are mirrored from the torch.nn,
  10. # we want to have them here for future changes.
  def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False,
                 count_include_pad=True, divisor_override=None):
      r"""
      Quantized 2D average pooling over :math:`kH \times kW` windows moved by
      :math:`sH \times sW` steps. The output has as many planes as the input.

      .. note:: The input quantization parameters propagate to the output.

      See :class:`~torch.nn.quantized.AvgPool2d` for details and output shape.

      Args:
          input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
          kernel_size: size of the pooling window; a single number or a
              tuple `(kH, kW)`
          stride: step of the pooling window; a single number or a tuple
              `(sH, sW)`. Default: :attr:`kernel_size`
          padding: implicit zero padding added on both sides of the input; a
              single number or a tuple `(padH, padW)`. Default: 0
          ceil_mode: if True, `ceil` is used instead of `floor` when computing
              the output shape. Default: ``False``
          count_include_pad: if True, the zero padding is counted when
              averaging. Default: ``True``
          divisor_override: when given, used as the divisor instead of the
              pooling-window size. Default: None

      Raises:
          ValueError: if ``input`` is not a quantized tensor.
      """
      if input.is_quantized:
          # Delegate to the float functional; it dispatches on the quantized input.
          return torch.nn.functional.avg_pool2d(
              input, kernel_size, stride, padding,
              ceil_mode, count_include_pad, divisor_override)
      raise ValueError("Input to 'quantized.avg_pool2d' must be quantized!")
  def avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False,
                 count_include_pad=True, divisor_override=None):
      r"""
      Applies 3D average-pooling operation in :math:`kD \times kH \times kW` regions by step size
      :math:`sD \times sH \times sW` steps. The number of output features is equal to the number of
      input planes.

      .. note:: The input quantization parameters propagate to the output.

      Args:
          input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iD , iH , iW)`
          kernel_size: size of the pooling region. Can be a single number or a
              tuple `(kD, kH, kW)`
          stride: stride of the pooling operation. Can be a single number or a
              tuple `(sD, sH, sW)`. Default: :attr:`kernel_size`
          padding: implicit zero paddings on both sides of the input. Can be a
              single number or a tuple `(padD, padH, padW)`. Default: 0
          ceil_mode: when True, will use `ceil` instead of `floor` in the formula
              to compute the output shape. Default: ``False``
          count_include_pad: when True, will include the zero-padding in the
              averaging calculation. Default: ``True``
          divisor_override: if specified, it will be used as divisor, otherwise
              size of the pooling region will be used. Default: None

      Raises:
          ValueError: if ``input`` is not a quantized tensor.
      """
      if not input.is_quantized:
          raise ValueError("Input to 'quantized.avg_pool3d' must be quantized!")
      return torch.nn.functional.avg_pool3d(input, kernel_size, stride, padding,
                                            ceil_mode, count_include_pad,
                                            divisor_override)
  def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
      r"""
      Quantized 2D adaptive average pooling: every quantized input plane is
      averaged down to ``output_size``.

      .. note:: The input quantization parameters propagate to the output.

      See :class:`~torch.nn.quantized.AdaptiveAvgPool2d` for details and output shape.

      Args:
          input: quantized input tensor
          output_size: target output size, either a single integer or a pair
              of integers

      Raises:
          ValueError: if ``input`` is not a quantized tensor.
      """
      if input.is_quantized:
          return torch.nn.functional.adaptive_avg_pool2d(input, output_size)
      raise ValueError("Input to 'quantized.functional.adaptive_avg_pool2d' must be quantized!")
  def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> Tensor:
      r"""
      Applies a 3D adaptive average pooling over a quantized input signal composed
      of several quantized input planes.

      .. note:: The input quantization parameters propagate to the output.

      See :class:`~torch.nn.quantized.AdaptiveAvgPool3d` for details and output shape.

      Args:
          input: quantized input tensor
          output_size: the target output size (single integer or
              triple-integer tuple)

      Raises:
          ValueError: if ``input`` is not a quantized tensor.
      """
      # ``BroadcastingList3`` (not 2) so that, under TorchScript, a single int is
      # broadcast to the three spatial dims expected by the 3D pooling op.
      if not input.is_quantized:
          raise ValueError(
              "Input to 'quantized.functional.adaptive_avg_pool3d' must be quantized!")
      return torch.nn.functional.adaptive_avg_pool3d(input, output_size)
  def conv1d(input, weight, bias,
             stride=1, padding=0, dilation=1, groups=1,
             padding_mode='zeros',
             scale=1.0, zero_point=0,
             dtype=torch.quint8):
      r"""
      Applies a 1D convolution over a quantized 1D input composed of several input
      planes.

      See :class:`~torch.nn.quantized.Conv1d` for details and output shape.

      .. note::
          The weight is prepacked on every call, which adds per-call overhead.

      Args:
          input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`.
              Must have dtype ``torch.quint8``.
          weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kW)`.
              Must have dtype ``torch.qint8``.
          bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
          stride: the stride of the convolving kernel. Can be a single number or a
              tuple `(sW,)`. Default: 1
          padding: implicit paddings on both sides of the input. Can be a
              single number or a tuple `(padW,)`. Default: 0
          dilation: the spacing between kernel elements. Can be a single number or
              a tuple `(dW,)`. Default: 1
          groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
              number of groups. Default: 1
          padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros"
          scale: quantization scale for the output. Default: 1.0
          zero_point: quantization zero_point for the output. Default: 0
          dtype: currently not read by this function; the activation must
              already be ``torch.quint8``. Default: ``torch.quint8``

      Raises:
          NotImplementedError: if ``padding_mode`` is not ``'zeros'``, if ``input``
              is not ``torch.quint8``, or if ``weight`` is not ``torch.qint8``.
          ValueError: if ``input`` is not 3-dimensional.

      Examples::

          >>> from torch.nn.quantized import functional as qF
          >>> filters = torch.randn(33, 16, 3, dtype=torch.float)
          >>> inputs = torch.randn(20, 16, 50, dtype=torch.float)
          >>> bias = torch.randn(33, dtype=torch.float)
          >>>
          >>> scale, zero_point = 1.0, 0
          >>> dtype_inputs = torch.quint8
          >>> dtype_filters = torch.qint8
          >>>
          >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
          >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
          >>> qF.conv1d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
      """  # noqa: E501
      if padding_mode != 'zeros':
          raise NotImplementedError("Only zero-padding is supported!")
      if input.dtype != torch.quint8:
          raise NotImplementedError("Only torch.quint8 is supported for activation tensor!")
      if weight.dtype != torch.qint8:
          raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
      if input.ndim != 3:
          raise ValueError("Input shape must be `(N, C, L)`!")
      # The 1D op is fed 2-element parameters built from the first (only) value.
      stride = _pair_from_first(stride)
      padding = _pair_from_first(padding)
      dilation = _pair_from_first(dilation)
      packed_params = torch.ops.quantized.conv1d_prepack(
          weight, bias, stride, padding, dilation, groups)
      return torch.ops.quantized.conv1d(input, packed_params, scale, zero_point)
  def conv2d(input, weight, bias,
             stride=1, padding=0, dilation=1, groups=1,
             padding_mode='zeros',
             scale=1.0, zero_point=0,
             dtype=torch.quint8):
      r"""
      Quantized 2D convolution over an input made of several input planes.

      The weight is prepacked on every call before the quantized kernel runs.
      See :class:`~torch.nn.quantized.Conv2d` for details and output shape.

      Args:
          input: quantized (``torch.quint8``) input of shape
              :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
          weight: quantized (``torch.qint8``) filters of shape
              :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)`
          bias: **non-quantized** ``torch.float`` bias of shape :math:`(\text{out\_channels})`
          stride: stride of the kernel; a single number or a tuple `(sH, sW)`. Default: 1
          padding: implicit padding on both sides of the input; a single number
              or a tuple `(padH, padW)`. Default: 0
          dilation: spacing between kernel elements; a single number or a tuple
              `(dH, dW)`. Default: 1
          groups: number of groups the input is split into;
              :math:`\text{in\_channels}` should be divisible by it. Default: 1
          padding_mode: padding mode; only ``"zeros"`` is currently supported for
              quantized convolution. Default: ``"zeros"``
          scale: quantization scale of the output. Default: 1.0
          zero_point: quantization zero point of the output. Default: 0
          dtype: not read by this function; the activation dtype must already be
              ``torch.quint8``. Default: ``torch.quint8``

      Raises:
          NotImplementedError: for a ``padding_mode`` other than ``'zeros'``, a
              non-``quint8`` input, or a non-``qint8`` weight (checked in that order).
          ValueError: if ``input`` is not 4-dimensional.

      Examples::

          >>> from torch.nn.quantized import functional as qF
          >>> filters = torch.randn(8, 4, 3, 3, dtype=torch.float)
          >>> inputs = torch.randn(1, 4, 5, 5, dtype=torch.float)
          >>> bias = torch.randn(8, dtype=torch.float)
          >>>
          >>> scale, zero_point = 1.0, 0
          >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, torch.qint8)
          >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, torch.quint8)
          >>> qF.conv2d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
      """  # noqa: E501
      # Validation order matters: each check assumes the previous ones passed.
      if padding_mode != 'zeros':
          raise NotImplementedError("Only zero-padding is supported!")
      if input.dtype != torch.quint8:
          raise NotImplementedError("Only torch.quint8 is supported for activation tensor!")
      if weight.dtype != torch.qint8:
          raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
      if input.ndim != 4:
          raise ValueError("Input shape must be `(N, C, H, W)`!")
      prepacked = torch.ops.quantized.conv2d_prepack(
          weight, bias, _pair(stride), _pair(padding), _pair(dilation), groups)
      return torch.ops.quantized.conv2d(input, prepacked, scale, zero_point)
  def conv3d(input, weight, bias, stride=1, padding=0, dilation=1, groups=1,
             padding_mode='zeros', scale=1.0, zero_point=0, dtype=torch.quint8):
      r"""
      Quantized 3D convolution over an input made of several input planes.

      The weight is prepacked on every call before the quantized kernel runs.
      See :class:`~torch.nn.quantized.Conv3d` for details and output shape.

      Args:
          input: quantized (``torch.quint8``) input of shape
              :math:`(\text{minibatch} , \text{in\_channels} , iD , iH , iW)`
          weight: quantized (``torch.qint8``) filters of shape
              :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kD , kH , kW)`
          bias: **non-quantized** ``torch.float`` bias of shape
              :math:`(\text{out\_channels})`
          stride: stride of the kernel; a single number or a tuple
              `(sD, sH, sW)`. Default: 1
          padding: implicit padding on both sides of the input; a single number
              or a tuple `(padD, padH, padW)`. Default: 0
          dilation: spacing between kernel elements; a single number or a tuple
              `(dD, dH, dW)`. Default: 1
          groups: number of groups the input is split into;
              :math:`\text{in\_channels}` should be divisible by it. Default: 1
          padding_mode: padding mode; only ``"zeros"`` is currently supported for
              quantized convolution. Default: ``"zeros"``
          scale: quantization scale of the output. Default: 1.0
          zero_point: quantization zero point of the output. Default: 0
          dtype: not read by this function; the activation dtype must already be
              ``torch.quint8``. Default: ``torch.quint8``

      Raises:
          NotImplementedError: for a ``padding_mode`` other than ``'zeros'``, a
              non-``quint8`` input, or a non-``qint8`` weight (checked in that order).
          ValueError: if ``input`` is not 5-dimensional.

      Examples::

          >>> from torch.nn.quantized import functional as qF
          >>> filters = torch.randn(8, 4, 3, 3, 3, dtype=torch.float)
          >>> inputs = torch.randn(1, 4, 5, 5, 5, dtype=torch.float)
          >>> bias = torch.randn(8, dtype=torch.float)
          >>>
          >>> scale, zero_point = 1.0, 0
          >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, torch.qint8)
          >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, torch.quint8)
          >>> qF.conv3d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
      """  # noqa: E501
      # Validation order matters: each check assumes the previous ones passed.
      if padding_mode != 'zeros':
          raise NotImplementedError("Only zero-padding is supported!")
      if input.dtype != torch.quint8:
          raise NotImplementedError("Only torch.quint8 is supported for activation tensor!")
      if weight.dtype != torch.qint8:
          raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
      if input.ndim != 5:
          raise ValueError("Input shape must be `(N, C, D, H, W)`!")
      prepacked = torch.ops.quantized.conv3d_prepack(
          weight, bias, _triple(stride), _triple(padding), _triple(dilation), groups)
      return torch.ops.quantized.conv3d(input, prepacked, scale, zero_point)
  def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None):
      r"""Resamples a quantized input to the given :attr:`size` or by the given
      :attr:`scale_factor`.

      This forwards to :func:`torch.nn.functional.interpolate`; see it for the
      implementation details. Input dimensions are read as
      `mini-batch x channels x [optional depth] x [optional height] x width`.

      .. note:: The input quantization parameters propagate to the output.

      .. note:: Only 2D/3D input is supported for quantized inputs

      .. note:: For quantized inputs only the `bilinear` and `nearest` modes
          are supported.

      Args:
          input (Tensor): the quantized input tensor
          size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
              output spatial size.
          scale_factor (float or Tuple[float]): multiplier for the spatial size;
              if a tuple, its length has to match the input's spatial dims.
          mode (str): upsampling algorithm, ``'nearest'`` or ``'bilinear'``
          align_corners (bool, optional): with ``True``, input and output are
              aligned on the centers of their corner pixels, preserving corner
              values; with ``False``, they are aligned on the outer corners of the
              corner pixels and out-of-range samples use edge-value padding, which
              makes the result independent of input size for a fixed
              :attr:`scale_factor`. Only meaningful for ``'bilinear'``.
              Default: ``False``

      Raises:
          ValueError: if ``input`` is not a quantized tensor.
      """
      if input.is_quantized:
          return torch.nn.functional.interpolate(
              input, size=size, scale_factor=scale_factor,
              mode=mode, align_corners=align_corners)
      raise ValueError("Input to 'quantized.interpolate' must be quantized!")
  def linear(
      input: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
      scale: Optional[float] = None, zero_point: Optional[int] = None
  ) -> Tensor:
      r"""
      Quantized linear transformation :math:`y = xA^T + b`.

      See :class:`~torch.nn.quantized.Linear`

      .. note::
          The weight is prepacked on every call, which costs performance.
          Use :class:`~torch.nn.quantized.Linear` to avoid that overhead.

      Args:
          input (Tensor): quantized input of type `torch.quint8`
          weight (Tensor): quantized weight of type `torch.qint8`
          bias (Tensor): None, or an fp32 bias of type `torch.float`
          scale (double): output scale; when None, ``input.q_scale()`` is used
          zero_point (long): output zero point; when None,
              ``input.q_zero_point()`` is used

      Shape:
          - Input: :math:`(N, *, in\_features)`, `*` being any number of extra dims
          - Weight: :math:`(out\_features, in\_features)`
          - Bias: :math:`(out\_features)`
          - Output: :math:`(N, *, out\_features)`
      """
      out_scale = input.q_scale() if scale is None else scale
      out_zero_point = input.q_zero_point() if zero_point is None else zero_point
      prepacked = torch.ops.quantized.linear_prepack(weight, bias)
      return torch.ops.quantized.linear(input, prepacked, out_scale, out_zero_point)
  def max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
                 ceil_mode=False, return_indices=False):
      r"""Quantized 1D max pooling over an input made of several quantized planes.

      .. note:: The input quantization parameters are propagated to the output.

      See :class:`~torch.nn.quantized.MaxPool1d` for details.

      Raises:
          NotImplementedError: if ``return_indices`` is requested.
      """
      if return_indices:
          raise NotImplementedError("return_indices is not yet implemented!")
      if stride is None:
          # torch.jit.annotate gives the empty list an explicit List[int] type.
          stride = torch.jit.annotate(List[int], [])
      return torch.nn.functional.max_pool1d(
          input, kernel_size, stride=stride, padding=padding, dilation=dilation,
          ceil_mode=ceil_mode, return_indices=return_indices)
  def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
                 ceil_mode=False, return_indices=False):
      r"""Quantized 2D max pooling over an input made of several quantized planes.

      .. note:: The input quantization parameters are propagated to the output.

      See :class:`~torch.nn.quantized.MaxPool2d` for details.

      Raises:
          NotImplementedError: if ``return_indices`` is requested.
      """
      if return_indices:
          raise NotImplementedError("return_indices is not yet implemented!")
      if stride is None:
          # torch.jit.annotate gives the empty list an explicit List[int] type.
          stride = torch.jit.annotate(List[int], [])
      return torch.nn.functional.max_pool2d(
          input, kernel_size, stride=stride, padding=padding, dilation=dilation,
          ceil_mode=ceil_mode, return_indices=return_indices)
  def celu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.) -> Tensor:
      r"""celu(input, scale, zero_point, alpha=1.) -> Tensor

      Quantized CELU, applied element-wise:

      .. math::
          \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x / \alpha) - 1))

      Args:
          input: quantized input
          scale: quantization scale of the output tensor
          zero_point: quantization zero point of the output tensor
          alpha: the :math:`\alpha` of the CELU formula. Default: 1.0

      Raises:
          ValueError: if ``input`` is not a quantized tensor.
      """
      if input.is_quantized:
          return torch.ops.quantized.celu(input, scale, zero_point, alpha)
      raise ValueError("Input to 'quantized.celu' must be quantized!")
  def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = False,
                 scale: Optional[float] = None, zero_point: Optional[int] = None) -> Tensor:
      r"""
      leaky_relu(input, negative_slope=0.01, inplace=False, scale=None, zero_point=None) -> Tensor

      Quantized version of :func:`~torch.nn.functional.leaky_relu`. Applies
      element-wise
      :math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)`

      Args:
          input: quantized input
          negative_slope: the slope applied to negative inputs. Default: 0.01
          inplace: if True, ``input`` is modified in place. Default: ``False``
          scale, zero_point: scale and zero point of the output tensor. The
              result is written into a new tensor with these parameters only
              when *both* are given; if either is None, both are ignored.
              Cannot be combined with ``inplace=True``.

      Raises:
          ValueError: if ``scale`` and ``zero_point`` are both given together
              with ``inplace=True``.

      See :class:`~torch.nn.LeakyReLU` for more details.
      """
      if scale is not None and zero_point is not None:
          # Explicit check instead of ``assert``: asserts are stripped under
          # ``python -O``, which would silently drop the in-place request.
          if inplace:
              raise ValueError("Cannot rescale with `inplace`")
          output = torch._empty_affine_quantized(
              input.shape, scale=scale, zero_point=int(zero_point), dtype=input.dtype)
          torch._C._nn.leaky_relu(input, negative_slope, out=output)
          return output
      if inplace:
          return torch._C._nn.leaky_relu_(input, negative_slope)
      return torch._C._nn.leaky_relu(input, negative_slope)
  def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: bool = False) -> Tensor:
      r"""Quantized counterpart of :func:`~torch.nn.functional.hardtanh`.

      Args:
          input: quantized input
          min_val: lower bound of the linear region. Default: -1.0
          max_val: upper bound of the linear region. Default: 1.0
          inplace: if True, ``input`` is modified in place. Default: ``False``

      Raises:
          ValueError: if ``input`` is not a quantized tensor.
      """
      if not input.is_quantized:
          raise ValueError("Input to 'quantized.hardtanh' must be quantized!")
      if not inplace:
          return torch._C._nn.hardtanh(input, min_val, max_val)
      return torch._C._nn.hardtanh_(input, min_val, max_val)
  def hardswish(input: Tensor, scale: float, zero_point: int) -> Tensor:
      r"""Quantized counterpart of :func:`~torch.nn.functional.hardswish`.

      Args:
          input: quantized input
          scale: quantization scale of the output tensor
          zero_point: quantization zero point of the output tensor

      Raises:
          ValueError: if ``input`` is not a quantized tensor.
      """
      if input.is_quantized:
          # ``torch.ops`` is the same registry object as ``torch._ops.ops``.
          return torch.ops.quantized.hardswish(input, scale, zero_point)
      raise ValueError("Input to 'quantized.hardswish' must be quantized!")
  def threshold(input: Tensor, threshold: float, value: float) -> Tensor:
      r"""Quantized threshold function, applied element-wise:

      .. math::
          y = \begin{cases}
          x & \text{if~} x > \text{threshold} \\
          \text{value} & \text{otherwise}
          \end{cases}

      See :class:`~torch.nn.Threshold` for more details.

      Args:
          input: quantized input
          threshold: value to compare each element against
          value: replacement for elements not above ``threshold``

      Raises:
          ValueError: if ``input`` is not quantized, or if ``threshold`` or
              ``value`` is ``None`` (checked in that order).
      """
      if not input.is_quantized:
          raise ValueError("Input to 'quantized.threshold' must be quantized!")
      if threshold is None:
          raise ValueError("Input to 'threshold' must be specified!")
      if value is None:
          raise ValueError("Input to 'value' must be specified!")
      # ``torch.ops`` is the same registry object as ``torch._ops.ops``.
      return torch.ops.quantized.threshold(input, threshold, value)
  def elu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.) -> Tensor:
      r"""Quantized counterpart of :func:`~torch.nn.functional.elu`.

      Args:
          input: quantized input
          scale: quantization scale of the output tensor
          zero_point: quantization zero point of the output tensor
          alpha: the ELU alpha constant. Default: 1.0

      Raises:
          ValueError: if ``input`` is not a quantized tensor.
      """
      if input.is_quantized:
          return torch.ops.quantized.elu(input, scale, zero_point, alpha)
      raise ValueError("Input to 'quantized.elu' must be quantized!")
  def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor:
      r"""Quantized counterpart of :func:`~torch.nn.functional.hardsigmoid`.

      Args:
          input: quantized input
          inplace: if True, ``input`` is modified in place. Default: ``False``

      Raises:
          ValueError: if ``input`` is not a quantized tensor.
      """
      if not input.is_quantized:
          raise ValueError("Input to 'quantized.hardsigmoid' must be quantized!")
      if not inplace:
          return torch._C._nn.hardsigmoid(input)
      return torch._C._nn.hardsigmoid_(input)  # type: ignore[attr-defined]
  def clamp(input: Tensor, min_: float, max_: float) -> Tensor:
      r"""clamp(input, min\_, max\_) -> Tensor

      Applies the clamp function element-wise to a quantized tensor.
      See :func:`torch.clamp` for more details.

      Args:
          input: quantized input
          min_: minimum value for clamping
          max_: maximum value for clamping

      Raises:
          ValueError: if ``input`` is not a quantized tensor.
      """
      if not input.is_quantized:
          raise ValueError("Input to 'quantized.clamp' must be quantized!")
      return torch.clamp(input, min_, max_)
  def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=None):
      r"""Upsamples a quantized input to the given :attr:`size` or by the given
      :attr:`scale_factor`.

      .. warning::
          Deprecated; use :func:`torch.nn.quantized.functional.interpolate`
          instead, which this function calls with the same arguments
          (``nn.quantized.functional.interpolate(...)``). A warning is emitted
          on every call.

      See :func:`torch.nn.functional.interpolate` for implementation details.
      Input dimensions are read as
      `mini-batch x channels x [optional depth] x [optional height] x width`.

      .. note:: The input quantization parameters propagate to the output.

      .. note:: Only 2D input is supported for quantized inputs

      .. note:: For quantized inputs only the `bilinear` and `nearest` modes
          are supported.

      Args:
          input (Tensor): quantized input tensor
          size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
              output spatial size.
          scale_factor (float or Tuple[float]): multiplier for the spatial size.
              Has to be an integer.
          mode (str): upsampling algorithm, ``'nearest'`` or ``'bilinear'``
          align_corners (bool, optional): with ``True``, input and output are
              aligned on the centers of their corner pixels, preserving corner
              values; with ``False``, they are aligned on the outer corners of the
              corner pixels and out-of-range samples use edge-value padding, which
              makes the result independent of input size for a fixed
              :attr:`scale_factor`. Only meaningful for ``'bilinear'``.
              Default: ``False``

      .. warning::
          With ``align_corners = True``, linearly interpolating modes
          (`bilinear`) do not proportionally align output and input pixels, so
          output values can depend on the input size. This was the default
          behavior for these modes up to version 0.3.1; since then the default
          is ``align_corners = False``. See :class:`~torch.nn.Upsample` for
          concrete examples of how this affects the outputs.
      """
      # No category is given, so this is a UserWarning (shown by default),
      # unlike DeprecationWarning which is hidden by default.
      warnings.warn("nn.quantized.functional.upsample is deprecated. Use nn.quantized.functional.interpolate instead.")
      return interpolate(input, size=size, scale_factor=scale_factor,
                         mode=mode, align_corners=align_corners)
  def upsample_bilinear(input, size=None, scale_factor=None):
      r"""Upsamples a quantized input with bilinear interpolation.

      .. warning::
          Deprecated; use :func:`torch.nn.quantized.functional.interpolate`
          instead. This call is equivalent to
          ``nn.quantized.functional.interpolate(..., mode='bilinear', align_corners=True)``.

      .. note:: The input quantization parameters propagate to the output.

      .. note:: Only 2D inputs are supported

      Args:
          input (Tensor): quantized input
          size (int or Tuple[int, int]): output spatial size.
          scale_factor (int or Tuple[int, int]): multiplier for the spatial size
      """
      # A plain (UserWarning) warning is used because DeprecationWarning is
      # ignored by default.
      warnings.warn("nn.quantized.functional.upsample_bilinear is deprecated. Use nn.quantized.functional.interpolate instead.")
      return interpolate(input, size=size, scale_factor=scale_factor,
                         mode='bilinear', align_corners=True)
  def upsample_nearest(input, size=None, scale_factor=None):
      r"""Upsamples a quantized input by repeating nearest-neighbour pixel values.

      .. warning::
          Deprecated; use :func:`torch.nn.quantized.functional.interpolate`
          instead. This call is equivalent to
          ``nn.quantized.functional.interpolate(..., mode='nearest')``.

      .. note:: The input quantization parameters propagate to the output.

      .. note:: Only 2D inputs are supported

      Args:
          input (Tensor): quantized input
          size (int or Tuple[int, int] or Tuple[int, int, int]): output spatial
              size.
          scale_factor (int): multiplier for the spatial size. Has to be an integer.
      """
      # A plain (UserWarning) warning is used because DeprecationWarning is
      # ignored by default.
      warnings.warn("nn.quantized.functional.upsample_nearest is deprecated. Use nn.quantized.functional.interpolate instead.")
      return interpolate(input, size=size, scale_factor=scale_factor, mode='nearest')