functional.py 71 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682
  1. from typing import (
  2. List, Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING
  3. )
  4. import torch
  5. from torch._C import _add_docstr
  6. import torch.nn.functional as F
  7. from ._lowrank import svd_lowrank, pca_lowrank
  8. from .overrides import (
  9. has_torch_function, has_torch_function_unary, has_torch_function_variadic,
  10. handle_torch_function)
  11. from ._jit_internal import boolean_dispatch
  12. from ._jit_internal import _overload as overload
  13. Tensor = torch.Tensor
  14. from torch import _VF
# Public names of this module: these are what ``from <module> import *`` exports.
__all__ = [
    'atleast_1d',
    'atleast_2d',
    'atleast_3d',
    'align_tensors',
    'broadcast_shapes',
    'broadcast_tensors',
    'cartesian_prod',
    'block_diag',
    'cdist',
    'chain_matmul',
    'einsum',
    'istft',
    'lu',
    'norm',
    'meshgrid',
    'pca_lowrank',
    'split',
    'stft',
    'svd_lowrank',
    'tensordot',
    'unique',
    'unique_consecutive',
]
  39. def broadcast_tensors(*tensors):
  40. r"""broadcast_tensors(*tensors) -> List of Tensors
  41. Broadcasts the given tensors according to :ref:`broadcasting-semantics`.
  42. Args:
  43. *tensors: any number of tensors of the same type
  44. .. warning::
  45. More than one element of a broadcasted tensor may refer to a single
  46. memory location. As a result, in-place operations (especially ones that
  47. are vectorized) may result in incorrect behavior. If you need to write
  48. to the tensors, please clone them first.
  49. Example::
  50. >>> x = torch.arange(3).view(1, 3)
  51. >>> y = torch.arange(2).view(2, 1)
  52. >>> a, b = torch.broadcast_tensors(x, y)
  53. >>> a.size()
  54. torch.Size([2, 3])
  55. >>> a
  56. tensor([[0, 1, 2],
  57. [0, 1, 2]])
  58. """
  59. # This wrapper exists to support variadic args.
  60. if has_torch_function(tensors):
  61. return handle_torch_function(broadcast_tensors, tensors, *tensors)
  62. return _VF.broadcast_tensors(tensors) # type: ignore[attr-defined]
  63. def broadcast_shapes(*shapes):
  64. r"""broadcast_shapes(*shapes) -> Size
  65. Similar to :func:`broadcast_tensors` but for shapes.
  66. This is equivalent to
  67. ``torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape``
  68. but avoids the need create to intermediate tensors. This is useful for
  69. broadcasting tensors of common batch shape but different rightmost shape,
  70. e.g. to broadcast mean vectors with covariance matrices.
  71. Example::
  72. >>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1))
  73. torch.Size([1, 3, 2])
  74. Args:
  75. \*shapes (torch.Size): Shapes of tensors.
  76. Returns:
  77. shape (torch.Size): A shape compatible with all input shapes.
  78. Raises:
  79. RuntimeError: If shapes are incompatible.
  80. """
  81. # This wrapper exists to support variadic args.
  82. # TODO Move this to C++ once the jit has better support for torch.Size.
  83. if not torch.jit.is_tracing():
  84. max_len = 0
  85. for shape in shapes:
  86. if isinstance(shape, int):
  87. if max_len < 1:
  88. max_len = 1
  89. elif isinstance(shape, tuple) or isinstance(shape, list):
  90. s = len(shape)
  91. if max_len < s:
  92. max_len = s
  93. result = [1] * max_len
  94. for shape in shapes:
  95. if isinstance(shape, int):
  96. shape = (shape,)
  97. if isinstance(shape, tuple) or isinstance(shape, list):
  98. for i in range(-1, -1 - len(shape), -1):
  99. if shape[i] < 0:
  100. raise RuntimeError("Trying to create tensor with negative dimension ({}): ({})"
  101. .format(shape[i], shape[i]))
  102. if shape[i] == 1 or shape[i] == result[i]:
  103. continue
  104. if result[i] != 1:
  105. raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape")
  106. result[i] = shape[i]
  107. else:
  108. raise RuntimeError("Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", shape)
  109. return torch.Size(result)
  110. else:
  111. # with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail
  112. with torch.no_grad():
  113. scalar = torch.zeros((), device="cpu")
  114. tensors = [scalar.expand(shape) for shape in shapes]
  115. tensors = broadcast_tensors(*tensors)
  116. return tensors[0].shape
  117. def split(
  118. tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int = 0
  119. ) -> List[Tensor]:
  120. r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.
  121. If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
  122. be split into equally sized chunks (if possible). Last chunk will be smaller if
  123. the tensor size along the given dimension :attr:`dim` is not divisible by
  124. :attr:`split_size`.
  125. If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split
  126. into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according
  127. to :attr:`split_size_or_sections`.
  128. Args:
  129. tensor (Tensor): tensor to split.
  130. split_size_or_sections (int) or (list(int)): size of a single chunk or
  131. list of sizes for each chunk
  132. dim (int): dimension along which to split the tensor.
  133. Example::
  134. >>> a = torch.arange(10).reshape(5,2)
  135. >>> a
  136. tensor([[0, 1],
  137. [2, 3],
  138. [4, 5],
  139. [6, 7],
  140. [8, 9]])
  141. >>> torch.split(a, 2)
  142. (tensor([[0, 1],
  143. [2, 3]]),
  144. tensor([[4, 5],
  145. [6, 7]]),
  146. tensor([[8, 9]]))
  147. >>> torch.split(a, [1,4])
  148. (tensor([[0, 1]]),
  149. tensor([[2, 3],
  150. [4, 5],
  151. [6, 7],
  152. [8, 9]]))
  153. """
  154. if has_torch_function_unary(tensor):
  155. return handle_torch_function(
  156. split, (tensor,), tensor, split_size_or_sections, dim=dim)
  157. # Overwriting reason:
  158. # This dispatches to two ATen functions depending on the type of
  159. # split_size_or_sections. The branching code is in _tensor.py, which we
  160. # call here.
  161. return tensor.split(split_size_or_sections, dim)
  162. def einsum(*args: Any) -> Tensor:
  163. r"""einsum(equation, *operands) -> Tensor
  164. Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation
  165. based on the Einstein summation convention.
  166. Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them
  167. in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of
  168. this format are described below, but the general idea is to label every dimension of the input :attr:`operands`
  169. with some subscript and define which subscripts are part of the output. The output is then computed by summing
  170. the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the
  171. output. For example, matrix multiplication can be computed using einsum as `torch.einsum("ij,jk->ik", A, B)`.
  172. Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why).
  173. Equation:
  174. The :attr:`equation` string specifies the subscripts (letters in `[a-zA-Z]`) for each dimension of
  175. the input :attr:`operands` in the same order as the dimensions, separating subcripts for each operand by a
  176. comma (','), e.g. `'ij,jk'` specify subscripts for two 2D operands. The dimensions labeled with the same subscript
  177. must be broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is
  178. repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand
  179. must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that
  180. appear exactly once in the :attr:`equation` will be part of the output, sorted in increasing alphabetical order.
  181. The output is computed by multiplying the input :attr:`operands` element-wise, with their dimensions aligned based
  182. on the subscripts, and then summing out the dimensions whose subscripts are not part of the output.
  183. Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the equation
  184. followed by the subscripts for the output. For instance, the following equation computes the transpose of a
  185. matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and
  186. at most once for the output.
  187. Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis.
  188. Each input operand may contain at most one ellipsis which will cover the dimensions not covered by subscripts,
  189. e.g. for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` cover the third and fourth
  190. dimensions. The ellipsis does not need to cover the same number of dimensions across the :attr:`operands` but the
  191. 'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast together. If the output is not
  192. explicitly defined with the arrow ('->') notation, the ellipsis will come first in the output (left-most dimensions),
  193. before the subscript labels that appear exactly once for the input operands. e.g. the following equation implements
  194. batch matrix multiplication `'...ij,...jk'`.
  195. A few final notes: the equation may contain whitespaces between the different elements (subscripts, ellipsis,
  196. arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for scalar operands.
  197. .. note::
  198. ``torch.einsum`` handles ellipsis ('...') differently from NumPy in that it allows dimensions
  199. covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output.
  200. .. note::
  201. This function does not optimize the given expression, so a different formula for the same computation may
  202. run faster or consume less memory. Projects like opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/)
  203. can optimize the formula for you.
  204. .. note::
  205. As of PyTorch 1.10 :func:`torch.einsum` also supports the sublist format (see examples below). In this format,
  206. subscripts for each operand are specified by sublists, list of integers in the range [0, 52). These sublists
  207. follow their operands, and an extra sublist can appear at the end of the input to specify the output's
  208. subscripts., e.g. `torch.einsum(op1, sublist1, op2, sublist2, ..., [subslist_out])`. Python's `Ellipsis` object
  209. may be provided in a sublist to enable broadcasting as described in the Equation section above.
  210. Args:
  211. equation (string): The subscripts for the Einstein summation.
  212. operands (List[Tensor]): The tensors to compute the Einstein summation of.
  213. Examples::
  214. # trace
  215. >>> torch.einsum('ii', torch.randn(4, 4))
  216. tensor(-1.2104)
  217. # diagonal
  218. >>> torch.einsum('ii->i', torch.randn(4, 4))
  219. tensor([-0.1034, 0.7952, -0.2433, 0.4545])
  220. # outer product
  221. >>> x = torch.randn(5)
  222. >>> y = torch.randn(4)
  223. >>> torch.einsum('i,j->ij', x, y)
  224. tensor([[ 0.1156, -0.2897, -0.3918, 0.4963],
  225. [-0.3744, 0.9381, 1.2685, -1.6070],
  226. [ 0.7208, -1.8058, -2.4419, 3.0936],
  227. [ 0.1713, -0.4291, -0.5802, 0.7350],
  228. [ 0.5704, -1.4290, -1.9323, 2.4480]])
  229. # batch matrix multiplication
  230. >>> As = torch.randn(3,2,5)
  231. >>> Bs = torch.randn(3,5,4)
  232. >>> torch.einsum('bij,bjk->bik', As, Bs)
  233. tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
  234. [-1.6706, -0.8097, -0.8025, -2.1183]],
  235. [[ 4.2239, 0.3107, -0.5756, -0.2354],
  236. [-1.4558, -0.3460, 1.5087, -0.8530]],
  237. [[ 2.8153, 1.8787, -4.3839, -1.2112],
  238. [ 0.3728, -2.1131, 0.0921, 0.8305]]])
  239. # with sublist format and ellipsis
  240. >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
  241. tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
  242. [-1.6706, -0.8097, -0.8025, -2.1183]],
  243. [[ 4.2239, 0.3107, -0.5756, -0.2354],
  244. [-1.4558, -0.3460, 1.5087, -0.8530]],
  245. [[ 2.8153, 1.8787, -4.3839, -1.2112],
  246. [ 0.3728, -2.1131, 0.0921, 0.8305]]])
  247. # batch permute
  248. >>> A = torch.randn(2, 3, 4, 5)
  249. >>> torch.einsum('...ij->...ji', A).shape
  250. torch.Size([2, 3, 5, 4])
  251. # equivalent to torch.nn.functional.bilinear
  252. >>> A = torch.randn(3,5,4)
  253. >>> l = torch.randn(2,5)
  254. >>> r = torch.randn(2,4)
  255. >>> torch.einsum('bn,anm,bm->ba', l, A, r)
  256. tensor([[-0.3430, -5.2405, 0.4494],
  257. [ 0.3311, 5.5201, -3.0356]])
  258. """
  259. # This wrapper exists to support variadic args.
  260. if len(args) < 2:
  261. raise ValueError('einsum(): must specify the equation string and at least one operand, '
  262. 'or at least one operand and its subscripts list')
  263. equation = None
  264. operands = None
  265. if isinstance(args[0], torch.Tensor):
  266. # Convert the subscript list format which is an interleaving of operand and its subscripts
  267. # list with an optional output subscripts list at the end (see documentation for more details on this)
  268. # to the equation string format by creating the equation string from the subscripts list and grouping the
  269. # input operands into a tensorlist (List[Tensor]).
  270. def parse_subscript(n: int) -> str:
  271. if n == Ellipsis:
  272. return '...'
  273. if n >= 0 and n < 26:
  274. return chr(ord('A') + n)
  275. if n >= 26 and n < 52:
  276. return chr(ord('a') + n - 26)
  277. raise ValueError('einsum(): subscript in subscript list is not within the valid range [0, 52)')
  278. # Parse subscripts for input operands
  279. equation = ','.join(''.join(parse_subscript(s) for s in l) for l in args[1::2])
  280. # Parse optional output subscripts (provided when the number of arguments is odd)
  281. if len(args) % 2 == 1:
  282. equation += '->' + ''.join(parse_subscript(s) for s in args[-1])
  283. operands = args[:-1:2]
  284. else:
  285. operands = args[::2]
  286. else:
  287. equation = args[0]
  288. operands = args[1:]
  289. if has_torch_function(operands):
  290. return handle_torch_function(einsum, operands, equation, *operands)
  291. if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
  292. # the old interface of passing the operands as one list argument
  293. _operands = operands[0]
  294. # recurse incase operands contains value that has torch function
  295. # in the original implementation this line is omitted
  296. return einsum(equation, *_operands)
  297. return _VF.einsum(equation, operands) # type: ignore[attr-defined]
  298. # This wrapper exists to support variadic args.
  299. if TYPE_CHECKING:
  300. # The JIT doesn't understand Union, so only add type annotation for mypy
  301. def meshgrid(*tensors: Union[Tensor, List[Tensor]],
  302. indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
  303. return _meshgrid(*tensors, indexing=indexing)
  304. else:
  305. def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
  306. r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors.
  307. This is helpful when you want to visualize data over some
  308. range of inputs. See below for a plotting example.
  309. Given :math:`N` 1D tensors :math:`T_0 \ldots T_{N-1}` as
  310. inputs with corresponding sizes :math:`S_0 \ldots S_{N-1}`,
  311. this creates :math:`N` N-dimensional tensors :math:`G_0 \ldots
  312. G_{N-1}`, each with shape :math:`(S_0, ..., S_{N-1})` where
  313. the output :math:`G_i` is constructed by expanding :math:`T_i`
  314. to the result shape.
  315. .. note::
  316. 0D inputs are treated equivalently to 1D inputs of a
  317. single element.
  318. .. warning::
  319. `torch.meshgrid(*tensors)` currently has the same behavior
  320. as calling `numpy.meshgrid(*arrays, indexing='ij')`.
  321. In the future `torch.meshgrid` will transition to
  322. `indexing='xy'` as the default.
  323. https://github.com/pytorch/pytorch/issues/50276 tracks
  324. this issue with the goal of migrating to NumPy's behavior.
  325. .. seealso::
  326. :func:`torch.cartesian_prod` has the same effect but it
  327. collects the data in a tensor of vectors.
  328. Args:
  329. tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be
  330. treated as tensors of size :math:`(1,)` automatically
  331. indexing: (str, optional): the indexing mode, either "xy"
  332. or "ij", defaults to "ij". See warning for future changes.
  333. If "xy" is selected, the first dimension corresponds
  334. to the cardinality of the second input and the second
  335. dimension corresponds to the cardinality of the first
  336. input.
  337. If "ij" is selected, the dimensions are in the same
  338. order as the cardinality of the inputs.
  339. Returns:
  340. seq (sequence of Tensors): If the input has :math:`N`
  341. tensors of size :math:`S_0 \ldots S_{N-1}``, then the
  342. output will also have :math:`N` tensors, where each tensor
  343. is of shape :math:`(S_0, ..., S_{N-1})`.
  344. Example::
  345. >>> x = torch.tensor([1, 2, 3])
  346. >>> y = torch.tensor([4, 5, 6])
  347. Observe the element-wise pairings across the grid, (1, 4),
  348. (1, 5), ..., (3, 6). This is the same thing as the
  349. cartesian product.
  350. >>> grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
  351. >>> grid_x
  352. tensor([[1, 1, 1],
  353. [2, 2, 2],
  354. [3, 3, 3]])
  355. >>> grid_y
  356. tensor([[4, 5, 6],
  357. [4, 5, 6],
  358. [4, 5, 6]])
  359. This correspondence can be seen when these grids are
  360. stacked properly.
  361. >>> torch.equal(torch.cat(tuple(torch.dstack([grid_x, grid_y]))),
  362. ... torch.cartesian_prod(x, y))
  363. True
  364. `torch.meshgrid` is commonly used to produce a grid for
  365. plotting.
  366. >>> import matplotlib.pyplot as plt
  367. >>> xs = torch.linspace(-5, 5, steps=100)
  368. >>> ys = torch.linspace(-5, 5, steps=100)
  369. >>> x, y = torch.meshgrid(xs, ys, indexing='xy')
  370. >>> z = torch.sin(torch.sqrt(x * x + y * y))
  371. >>> ax = plt.axes(projection='3d')
  372. >>> ax.plot_surface(x.numpy(), y.numpy(), z.numpy())
  373. <mpl_toolkits.mplot3d.art3d.Poly3DCollection object at 0x7f8f30d40100>
  374. >>> plt.show()
  375. .. image:: ../_static/img/meshgrid.png
  376. :width: 512
  377. """
  378. return _meshgrid(*tensors, indexing=indexing)
  379. def _meshgrid(*tensors, indexing: Optional[str]):
  380. if has_torch_function(tensors):
  381. return handle_torch_function(meshgrid, tensors, *tensors, indexing=indexing)
  382. if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
  383. # the old interface of passing the operands as one list argument
  384. tensors = tensors[0] # type: ignore[assignment]
  385. # Continue allowing call of old method that takes no indexing
  386. # kwarg for forward compatibility reasons.
  387. #
  388. # Remove this two weeks after landing.
  389. kwargs = {} if indexing is None else {'indexing': indexing}
  390. return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
         win_length: Optional[int] = None, window: Optional[Tensor] = None,
         center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
         onesided: Optional[bool] = None,
         return_complex: Optional[bool] = None) -> Tensor:
    r"""Short-time Fourier transform (STFT).

    .. warning::
        From version 1.8.0, :attr:`return_complex` must always be given
        explicitly for real inputs and `return_complex=False` has been
        deprecated. Strongly prefer `return_complex=True` as in a future
        pytorch release, this function will only return complex tensors.

        Note that :func:`torch.view_as_real` can be used to recover a real
        tensor with an extra last dimension for real and imaginary components.

    The STFT computes the Fourier transform of short overlapping windows of the
    input. This gives the frequency components of the signal as they change over
    time. The interface of this function is modeled after (but *not* a drop-in
    replacement for) librosa_ stft function.

    .. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html

    Ignoring the optional batch dimension, this method computes the following
    expression:

    .. math::
        X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}%
                            \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ %
                            \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{win\_length}}\right),

    where :math:`m` is the index of the sliding window, and :math:`\omega` is
    the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``,
    or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``.

    * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time
      sequences.

    * If :attr:`hop_length` is ``None`` (default), it is treated as equal to
      ``floor(n_fft / 4)``.

    * If :attr:`win_length` is ``None`` (default), it is treated as equal to
      :attr:`n_fft`.

    * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from
      :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is
      treated as if having :math:`1` everywhere in the window. If
      :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on
      both sides to length :attr:`n_fft` before being applied.

    * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on
      both sides so that the :math:`t`-th frame is centered at time
      :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame
      begins at time  :math:`t \times \text{hop\_length}`.

    * :attr:`pad_mode` determines the padding method used on :attr:`input` when
      :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for
      all available options. Default is ``"reflect"``.

    * If :attr:`onesided` is ``True`` (default for real input), only values for
      :math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor
      \frac{\text{n\_fft}}{2} \right\rfloor + 1\right]` are returned because
      the real-to-complex Fourier transform satisfies the conjugate symmetry,
      i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`.
      Note if the input or window tensors are complex, then :attr:`onesided`
      output is not possible.

    * If :attr:`normalized` is ``True`` (default is ``False``), the function
      returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`.

    * If :attr:`return_complex` is ``True`` (default if input is complex), the
      return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``,
      the output is a ``input.dim() + 2`` dimensional real tensor where the last
      dimension represents the real and imaginary components.

    Returns either a complex tensor of size :math:`(* \times N \times T)` if
    :attr:`return_complex` is true, or a real tensor of size :math:`(* \times N
    \times T \times 2)`. Where :math:`*` is the optional batch size of
    :attr:`input`, :math:`N` is the number of frequencies where STFT is applied
    and :math:`T` is the total number of frames used.

    .. warning::
      This function changed signature at version 0.4.1. Calling with the
      previous signature may cause error or return incorrect result.

    Args:
        input (Tensor): the input tensor
        n_fft (int): size of Fourier transform
        hop_length (int, optional): the distance between neighboring sliding window
            frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)
        win_length (int, optional): the size of window frame and STFT filter.
            Default: ``None``  (treated as equal to :attr:`n_fft`)
        window (Tensor, optional): the optional window function.
            Default: ``None`` (treated as window of all :math:`1` s)
        center (bool, optional): whether to pad :attr:`input` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            Default: ``True``
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. Default: ``"reflect"``
        normalized (bool, optional): controls whether to return the normalized STFT results
             Default: ``False``
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy for real inputs.
            Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise.
        return_complex (bool, optional): whether to return a complex tensor, or
            a real tensor with an extra last dimension for the real and
            imaginary components.

    Returns:
        Tensor: A tensor containing the STFT result with shape described above

    """
    # If ``input`` overrides __torch_function__, forward the call (with every
    # argument) to that override instead of running the native path.
    if has_torch_function_unary(input):
        return handle_torch_function(
            stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
            window=window, center=center, pad_mode=pad_mode, normalized=normalized,
            onesided=onesided, return_complex=return_complex)
    # NOTE: Do not edit. This code will be removed once the forward-compatibility
    #       period is over for PR #73432
    if center:
        signal_dim = input.dim()
        # Prepend singleton dimensions so the tensor handed to F.pad has at
        # least 3 dims (presumably because the non-constant padding modes need
        # a batched layout -- confirm against torch.nn.functional.pad).
        extended_shape = [1] * (3 - signal_dim) + list(input.size())
        # Pad n_fft // 2 samples on each side of the last dimension.
        pad = int(n_fft // 2)
        input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
        # Keep only the trailing ``signal_dim`` dims, dropping the singleton
        # dims added above so the tensor has its original rank again.
        input = input.view(input.shape[-signal_dim:])
    # ``center``/``pad_mode`` are handled above and are not passed to the native op.
    return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
                    normalized, onesided, return_complex)
# Bind ``istft`` to the result of ``_add_docstr(torch.istft, ...)``.  The
# string argument is the signature line followed by the user-facing reference
# documentation for the op.
# NOTE(review): presumably ``_add_docstr`` returns the same callable with its
# ``__doc__`` set -- confirm against its definition.
istft = _add_docstr(
    torch.istft,
    "istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, "
    "normalized=False, onesided=None, length=None, return_complex=False) -> Tensor:\n"
    r"""
Inverse short time Fourier Transform. This is expected to be the inverse of :func:`~torch.stft`.
It has the same parameters (+ additional optional parameter of :attr:`length`) and it should return the
least squares estimation of the original signal. The algorithm will check using the NOLA condition (
nonzero overlap).

Important consideration in the parameters :attr:`window` and :attr:`center` so that the envelop
created by the summation of all the windows is never zero at certain point in time. Specifically,
:math:`\sum_{t=-\infty}^{\infty} |w|^2[n-t\times hop\_length] \cancel{=} 0`.

Since :func:`~torch.stft` discards elements at the end of the signal if they do not fit in a frame,
``istft`` may return a shorter signal than the original signal (can occur if :attr:`center` is False
since the signal isn't padded). If `length` is given in the arguments and is longer than expected,
``istft`` will pad zeros to the end of the returned signal.

If :attr:`center` is ``True``, then there will be padding e.g. ``'constant'``, ``'reflect'``, etc.
Left padding can be trimmed off exactly because they can be calculated but right padding cannot be
calculated without additional information.

Example: Suppose the last window is:
``[17, 18, 0, 0, 0]`` vs ``[18, 0, 0, 0, 0]``

The :attr:`n_fft`, :attr:`hop_length`, :attr:`win_length` are all the same which prevents the calculation
of right padding. These additional values could be zeros or a reflection of the signal so providing
:attr:`length` could be useful. If :attr:`length` is ``None`` then padding will be aggressively removed
(some loss of signal).

[1] D. W. Griffin and J. S. Lim, "Signal estimation from modified short-time Fourier transform,"
IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984.

Args:
    input (Tensor): The input tensor. Expected to be output of :func:`~torch.stft`,
        can either be complex (``channel``, ``fft_size``, ``n_frame``), or real
        (``channel``, ``fft_size``, ``n_frame``, 2) where the ``channel``
        dimension is optional.

        .. deprecated:: 1.8.0
           Real input is deprecated, use complex inputs as returned by
           ``stft(..., return_complex=True)`` instead.
    n_fft (int): Size of Fourier transform
    hop_length (Optional[int]): The distance between neighboring sliding window frames.
        (Default: ``n_fft // 4``)
    win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``)
    window (Optional[torch.Tensor]): The optional window function.
        (Default: ``torch.ones(win_length)``)
    center (bool): Whether :attr:`input` was padded on both sides so that the :math:`t`-th frame is
        centered at time :math:`t \times \text{hop\_length}`.
        (Default: ``True``)
    normalized (bool): Whether the STFT was normalized. (Default: ``False``)
    onesided (Optional[bool]): Whether the STFT was onesided.
        (Default: ``True`` if ``n_fft != fft_size`` in the input size)
    length (Optional[int]): The amount to trim the signal by (i.e. the
        original signal length). (Default: whole signal)
    return_complex (Optional[bool]):
        Whether the output should be complex, or if the input should be
        assumed to derive from a real signal and window.
        Note that this is incompatible with ``onesided=True``.
        (Default: ``False``)

Returns:
    Tensor: Least squares estimation of the original signal of size (..., signal_length)
""")
# Return annotation shared by ``_unique_impl`` and ``_unique_consecutive_impl``:
# ``Any`` while type checking, a 3-tuple of tensors at runtime.
if TYPE_CHECKING:
    # These _impl functions return a variable number of tensors as output with
    # __torch_function__; tuple unpacking is done already rather than being
    # done by the caller of the _impl function
    _unique_impl_out = Any
else:
    _unique_impl_out = Tuple[Tensor, Tensor, Tensor]
  561. def _unique_impl(input: Tensor, sorted: bool = True,
  562. return_inverse: bool = False, return_counts: bool = False,
  563. dim: Optional[int] = None) -> _unique_impl_out:
  564. r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> Tuple[Tensor, Tensor, Tensor]
  565. Returns the unique elements of the input tensor.
  566. .. note:: This function is different from :func:`torch.unique_consecutive` in the sense that
  567. this function also eliminates non-consecutive duplicate values.
  568. .. note:: Currently in the CUDA implementation and the CPU implementation when dim is specified,
  569. `torch.unique` always sort the tensor at the beginning regardless of the `sort` argument.
  570. Sorting could be slow, so if your input tensor is already sorted, it is recommended to use
  571. :func:`torch.unique_consecutive` which avoids the sorting.
  572. Args:
  573. input (Tensor): the input tensor
  574. sorted (bool): Whether to sort the unique elements in ascending order
  575. before returning as output.
  576. return_inverse (bool): Whether to also return the indices for where
  577. elements in the original input ended up in the returned unique list.
  578. return_counts (bool): Whether to also return the counts for each unique
  579. element.
  580. dim (int): the dimension to apply unique. If ``None``, the unique of the
  581. flattened input is returned. default: ``None``
  582. Returns:
  583. (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing
  584. - **output** (*Tensor*): the output list of unique scalar elements.
  585. - **inverse_indices** (*Tensor*): (optional) if
  586. :attr:`return_inverse` is True, there will be an additional
  587. returned tensor (same shape as input) representing the indices
  588. for where elements in the original input map to in the output;
  589. otherwise, this function will only return a single tensor.
  590. - **counts** (*Tensor*): (optional) if
  591. :attr:`return_counts` is True, there will be an additional
  592. returned tensor (same shape as output or output.size(dim),
  593. if dim was specified) representing the number of occurrences
  594. for each unique value or tensor.
  595. Example::
  596. >>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long))
  597. >>> output
  598. tensor([ 2, 3, 1])
  599. >>> output, inverse_indices = torch.unique(
  600. ... torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True)
  601. >>> output
  602. tensor([ 1, 2, 3])
  603. >>> inverse_indices
  604. tensor([ 0, 2, 1, 2])
  605. >>> output, inverse_indices = torch.unique(
  606. ... torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True)
  607. >>> output
  608. tensor([ 1, 2, 3])
  609. >>> inverse_indices
  610. tensor([[ 0, 2],
  611. [ 1, 2]])
  612. """
  613. if has_torch_function_unary(input):
  614. return handle_torch_function(
  615. unique, (input,), input, sorted=sorted, return_inverse=return_inverse,
  616. return_counts=return_counts, dim=dim)
  617. if dim is not None:
  618. output, inverse_indices, counts = _VF.unique_dim(
  619. input,
  620. dim,
  621. sorted=sorted,
  622. return_inverse=return_inverse,
  623. return_counts=return_counts,
  624. )
  625. else:
  626. output, inverse_indices, counts = torch._unique2(
  627. input,
  628. sorted=sorted,
  629. return_inverse=return_inverse,
  630. return_counts=return_counts,
  631. )
  632. return output, inverse_indices, counts
  633. def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False,
  634. return_counts: bool = False,
  635. dim: Optional[int] = None) -> _unique_impl_out:
  636. r"""Eliminates all but the first element from every consecutive group of equivalent elements.
  637. .. note:: This function is different from :func:`torch.unique` in the sense that this function
  638. only eliminates consecutive duplicate values. This semantics is similar to `std::unique`
  639. in C++.
  640. Args:
  641. input (Tensor): the input tensor
  642. return_inverse (bool): Whether to also return the indices for where
  643. elements in the original input ended up in the returned unique list.
  644. return_counts (bool): Whether to also return the counts for each unique
  645. element.
  646. dim (int): the dimension to apply unique. If ``None``, the unique of the
  647. flattened input is returned. default: ``None``
  648. Returns:
  649. (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing
  650. - **output** (*Tensor*): the output list of unique scalar elements.
  651. - **inverse_indices** (*Tensor*): (optional) if
  652. :attr:`return_inverse` is True, there will be an additional
  653. returned tensor (same shape as input) representing the indices
  654. for where elements in the original input map to in the output;
  655. otherwise, this function will only return a single tensor.
  656. - **counts** (*Tensor*): (optional) if
  657. :attr:`return_counts` is True, there will be an additional
  658. returned tensor (same shape as output or output.size(dim),
  659. if dim was specified) representing the number of occurrences
  660. for each unique value or tensor.
  661. Example::
  662. >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])
  663. >>> output = torch.unique_consecutive(x)
  664. >>> output
  665. tensor([1, 2, 3, 1, 2])
  666. >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True)
  667. >>> output
  668. tensor([1, 2, 3, 1, 2])
  669. >>> inverse_indices
  670. tensor([0, 0, 1, 1, 2, 3, 3, 4])
  671. >>> output, counts = torch.unique_consecutive(x, return_counts=True)
  672. >>> output
  673. tensor([1, 2, 3, 1, 2])
  674. >>> counts
  675. tensor([2, 2, 1, 2, 1])
  676. """
  677. if has_torch_function_unary(input):
  678. return handle_torch_function(
  679. unique_consecutive, (input,), input, return_inverse=return_inverse,
  680. return_counts=return_counts, dim=dim)
  681. output, inverse_indices, counts = _VF.unique_consecutive( # type: ignore[attr-defined]
  682. input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
  683. return output, inverse_indices, counts
  684. def _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
  685. # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
  686. if has_torch_function_unary(input):
  687. return _unique_impl(input, sorted, return_inverse, return_counts, dim)
  688. output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim)
  689. return output, counts
  690. def _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
  691. # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor
  692. if has_torch_function_unary(input):
  693. return _unique_impl(input, sorted, return_inverse, return_counts, dim)
  694. output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
  695. return output
  696. def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
  697. # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
  698. if has_torch_function_unary(input):
  699. return _unique_impl(input, sorted, return_inverse, return_counts, dim)
  700. output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
  701. return output, inverse_indices
# ``unique`` is assembled from two levels of boolean_dispatch: the outer level
# switches on ``return_inverse`` and each inner level switches on
# ``return_counts``.  The ``arg_index`` values follow the parameter order of
# ``_unique_impl``: ``return_inverse`` is parameter 2, ``return_counts`` is 3.

# Inner dispatcher used when return_inverse is False.
_return_inverse_false = boolean_dispatch(
    arg_name='return_counts',
    arg_index=3,
    default=False,
    if_true=_return_counts,
    if_false=_return_output,
    module_name=__name__,
    func_name='unique')

# Inner dispatcher used when return_inverse is True.
_return_inverse_true = boolean_dispatch(
    arg_name='return_counts',
    arg_index=3,
    default=False,
    if_true=_unique_impl,
    if_false=_return_inverse,
    module_name=__name__,
    func_name='unique')

# The return type of unique depends on `return_inverse`, and `return_counts` so in order to
# resolve the output type in TorchScript we need to statically know the value of both parameters

unique = boolean_dispatch(
    arg_name='return_inverse',
    arg_index=2,
    default=False,
    if_true=_return_inverse_true,
    if_false=_return_inverse_false,
    module_name=__name__,
    func_name='unique')
# Expose the implementation's documentation on the public dispatcher.
unique.__doc__ = _unique_impl.__doc__
  729. def _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None):
  730. # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
  731. if has_torch_function_unary(input):
  732. return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
  733. output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
  734. return output, counts
  735. def _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None):
  736. # type: (Tensor, bool, bool, Optional[int]) -> Tensor
  737. if has_torch_function_unary(input):
  738. return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
  739. output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
  740. return output
  741. def _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None):
  742. # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
  743. if has_torch_function_unary(input):
  744. return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
  745. output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
  746. return output, inverse_indices
  747. _consecutive_return_inverse_false = boolean_dispatch(
  748. arg_name='return_counts',
  749. arg_index=1,
  750. default=False,
  751. if_true=_consecutive_return_counts,
  752. if_false=_consecutive_return_output,
  753. module_name=__name__,
  754. func_name='unique_consecutive')
  755. _consecutive_return_inverse_true = boolean_dispatch(
  756. arg_name='return_counts',
  757. arg_index=1,
  758. default=False,
  759. if_true=_unique_consecutive_impl,
  760. if_false=_consecutive_return_inverse,
  761. module_name=__name__,
  762. func_name='unique_consecutive')
  763. # The return type of unique depends on `return_inverse`, and `return_counts` so in order to
  764. # resolve the output type in TorchScript we need to statically know the value of both parameters
  765. unique_consecutive = boolean_dispatch(
  766. arg_name='return_inverse',
  767. arg_index=2,
  768. default=False,
  769. if_true=_consecutive_return_inverse_true,
  770. if_false=_consecutive_return_inverse_false,
  771. module_name=__name__,
  772. func_name='unique_consecutive')
  773. unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__
# Overload stubs for ``tensordot``, one per accepted type of ``dims`` (int,
# tuple of two int lists, list of int lists, Tensor).  They are only declared
# when not type checking.
# NOTE(review): ``overload`` is presumably the JIT overload decorator rather
# than ``typing.overload`` -- confirm against the file's imports.
if TYPE_CHECKING:
    pass
    # There's no good way to use this type annotation without breaking JIT
    # overloads. So leave untyped for mypy for now.
else:
    @overload
    def tensordot(a, b, dims: int = 2, out: Optional[torch.Tensor] = None):
        pass

    @overload  # noqa: F811
    def tensordot(a, b, dims: Tuple[List[int], List[int]], out: Optional[torch.Tensor] = None):  # noqa: F811
        pass

    @overload  # noqa: F811
    def tensordot(a, b, dims: List[List[int]], out: Optional[torch.Tensor] = None):  # noqa: F811
        pass

    @overload  # noqa: F811
    def tensordot(a, b, dims: torch.Tensor, out: Optional[torch.Tensor] = None):  # noqa: F811
        pass
def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None):  # noqa: F811
    r"""Returns a contraction of a and b over multiple dimensions.

    :attr:`tensordot` implements a generalized matrix product.

    Args:
      a (Tensor): Left tensor to contract
      b (Tensor): Right tensor to contract
      dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor): number of dimensions to
         contract or explicit lists of dimensions for :attr:`a` and
         :attr:`b` respectively
      out (Tensor, optional): forwarded to the underlying op as ``out=`` when
         it is not ``None``.

    When called with a non-negative integer argument :attr:`dims` = :math:`d`, and
    the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`,
    respectively, :func:`~torch.tensordot` computes

    .. math::
        r_{i_0,...,i_{m-d}, i_d,...,i_n}
          = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}.

    When called with :attr:`dims` of the list form, the given dimensions will be contracted
    in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. The sizes
    in these dimensions must match, but :func:`~torch.tensordot` will deal with broadcasted
    dimensions.

    Returns:
        The result of the underlying ``tensordot`` op for :attr:`a` and
        :attr:`b` contracted over the resolved dimension lists.

    Raises:
        RuntimeError: if :attr:`dims` is not an int, tuple, list or Tensor, or
            if it is a negative int (or a single-element Tensor holding a
            negative value).
        AssertionError: if :attr:`dims` is a Tensor with more than one element
            whose first dimension is not of size 2.

    Examples::

        >>> a = torch.arange(60.).reshape(3, 4, 5)
        >>> b = torch.arange(24.).reshape(4, 3, 2)
        >>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))
        tensor([[4400., 4730.],
                [4532., 4874.],
                [4664., 5018.],
                [4796., 5162.],
                [4928., 5306.]])

        >>> a = torch.randn(3, 4, 5, device='cuda')
        >>> b = torch.randn(4, 5, 6, device='cuda')
        >>> c = torch.tensordot(a, b, dims=2).cpu()
        tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],
                [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],
                [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])

        >>> a = torch.randn(3, 5, 4, 6)
        >>> b = torch.randn(6, 4, 5, 3)
        >>> torch.tensordot(a, b, dims=([2, 1, 3], [1, 2, 0]))
        tensor([[  7.7193,  -2.4867, -10.3204],
                [  1.5513, -14.4737,  -6.5113],
                [ -0.2850,   4.2573,  -3.5997]])
    """
    if has_torch_function_variadic(a, b):
        return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out)

    if not isinstance(dims, (tuple, list, torch.Tensor, int)):
        raise RuntimeError("tensordot expects dims to be int or "
                           + "Tuple[List[int], List[int]] or "
                           + "List[List[int]] containing two lists, but got "
                           + f"dims={dims}")

    # Every accepted form of ``dims`` is normalised into these two lists of
    # dimension indices, one for ``a`` and one for ``b``.
    dims_a: List[int] = []
    dims_b: List[int] = []

    # Explicit form: a pair of dimension lists.
    if isinstance(dims, (tuple, list)):
        dims_a, dims_b = dims

    # Tensor form: more than one element means two rows of dimension indices;
    # otherwise the single value is treated like the integer form below.
    if isinstance(dims, torch.Tensor):
        num_elements = dims.numel()
        if num_elements > 1:
            assert dims.size()[0] == 2
            dims_a = torch.jit.annotate(List[int], dims[0].tolist())
            dims_b = torch.jit.annotate(List[int], dims[1].tolist())
        else:
            dims_val = int(dims.item())
            if dims_val < 0:
                raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
            dims_a = list(range(-dims_val, 0))
            dims_b = list(range(dims_val))

    # Integer form: contract the last ``dims`` dimensions of ``a`` with the
    # first ``dims`` dimensions of ``b``.
    if isinstance(dims, int):
        if dims < 0:
            raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
        dims_a = list(range(-dims, 0))
        dims_b = list(range(dims))

    # ``out`` is only passed through when the caller supplied one.
    if out is None:
        return _VF.tensordot(a, b, dims_a, dims_b)  # type: ignore[attr-defined]
    else:
        return _VF.tensordot(a, b, dims_a, dims_b, out=out)  # type: ignore[attr-defined]
  864. def cartesian_prod(*tensors):
  865. """Do cartesian product of the given sequence of tensors. The behavior is similar to
  866. python's `itertools.product`.
  867. Args:
  868. *tensors: any number of 1 dimensional tensors.
  869. Returns:
  870. Tensor: A tensor equivalent to converting all the input tensors into lists,
  871. do `itertools.product` on these lists, and finally convert the resulting list
  872. into tensor.
  873. Example::
  874. >>> a = [1, 2, 3]
  875. >>> b = [4, 5]
  876. >>> list(itertools.product(a, b))
  877. [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)]
  878. >>> tensor_a = torch.tensor(a)
  879. >>> tensor_b = torch.tensor(b)
  880. >>> torch.cartesian_prod(tensor_a, tensor_b)
  881. tensor([[1, 4],
  882. [1, 5],
  883. [2, 4],
  884. [2, 5],
  885. [3, 4],
  886. [3, 5]])
  887. """
  888. # This wrapper exists to support variadic args.
  889. if has_torch_function(tensors):
  890. return handle_torch_function(cartesian_prod, tensors, *tensors)
  891. return _VF.cartesian_prod(tensors) # type: ignore[attr-defined]
  892. def block_diag(*tensors):
  893. """Create a block diagonal matrix from provided tensors.
  894. Args:
  895. *tensors: One or more tensors with 0, 1, or 2 dimensions.
  896. Returns:
  897. Tensor: A 2 dimensional tensor with all the input tensors arranged in
  898. order such that their upper left and lower right corners are
  899. diagonally adjacent. All other elements are set to 0.
  900. Example::
  901. >>> import torch
  902. >>> A = torch.tensor([[0, 1], [1, 0]])
  903. >>> B = torch.tensor([[3, 4, 5], [6, 7, 8]])
  904. >>> C = torch.tensor(7)
  905. >>> D = torch.tensor([1, 2, 3])
  906. >>> E = torch.tensor([[4], [5], [6]])
  907. >>> torch.block_diag(A, B, C, D, E)
  908. tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
  909. [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  910. [0, 0, 3, 4, 5, 0, 0, 0, 0, 0],
  911. [0, 0, 6, 7, 8, 0, 0, 0, 0, 0],
  912. [0, 0, 0, 0, 0, 7, 0, 0, 0, 0],
  913. [0, 0, 0, 0, 0, 0, 1, 2, 3, 0],
  914. [0, 0, 0, 0, 0, 0, 0, 0, 0, 4],
  915. [0, 0, 0, 0, 0, 0, 0, 0, 0, 5],
  916. [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]])
  917. """
  918. # This wrapper exists to support variadic args.
  919. if has_torch_function(tensors):
  920. return handle_torch_function(block_diag, tensors, *tensors)
  921. return torch._C._VariableFunctions.block_diag(tensors) # type: ignore[attr-defined]
  922. def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):
  923. # type: (Tensor, Tensor, float, str) -> (Tensor)
  924. r"""Computes batched the p-norm distance between each pair of the two collections of row vectors.
  925. Args:
  926. x1 (Tensor): input tensor of shape :math:`B \times P \times M`.
  927. x2 (Tensor): input tensor of shape :math:`B \times R \times M`.
  928. p: p value for the p-norm distance to calculate between each vector pair
  929. :math:`\in [0, \infty]`.
  930. compute_mode:
  931. 'use_mm_for_euclid_dist_if_necessary' - will use matrix multiplication approach to calculate
  932. euclidean distance (p = 2) if P > 25 or R > 25
  933. 'use_mm_for_euclid_dist' - will always use matrix multiplication approach to calculate
  934. euclidean distance (p = 2)
  935. 'donot_use_mm_for_euclid_dist' - will never use matrix multiplication approach to calculate
  936. euclidean distance (p = 2)
  937. Default: use_mm_for_euclid_dist_if_necessary.
  938. If x1 has shape :math:`B \times P \times M` and x2 has shape :math:`B \times R \times M` then the
  939. output will have shape :math:`B \times P \times R`.
  940. This function is equivalent to `scipy.spatial.distance.cdist(input,'minkowski', p=p)`
  941. if :math:`p \in (0, \infty)`. When :math:`p = 0` it is equivalent to
  942. `scipy.spatial.distance.cdist(input, 'hamming') * M`. When :math:`p = \infty`, the closest
  943. scipy function is `scipy.spatial.distance.cdist(xn, lambda x, y: np.abs(x - y).max())`.
  944. Example:
  945. >>> a = torch.tensor([[0.9041, 0.0196], [-0.3108, -2.4423], [-0.4821, 1.059]])
  946. >>> a
  947. tensor([[ 0.9041, 0.0196],
  948. [-0.3108, -2.4423],
  949. [-0.4821, 1.0590]])
  950. >>> b = torch.tensor([[-2.1763, -0.4713], [-0.6986, 1.3702]])
  951. >>> b
  952. tensor([[-2.1763, -0.4713],
  953. [-0.6986, 1.3702]])
  954. >>> torch.cdist(a, b, p=2)
  955. tensor([[3.1193, 2.0959],
  956. [2.7138, 3.8322],
  957. [2.2830, 0.3791]])
  958. """
  959. if has_torch_function_variadic(x1, x2):
  960. return handle_torch_function(
  961. cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)
  962. if compute_mode == 'use_mm_for_euclid_dist_if_necessary':
  963. return _VF.cdist(x1, x2, p, None) # type: ignore[attr-defined]
  964. elif compute_mode == 'use_mm_for_euclid_dist':
  965. return _VF.cdist(x1, x2, p, 1) # type: ignore[attr-defined]
  966. elif compute_mode == 'donot_use_mm_for_euclid_dist':
  967. return _VF.cdist(x1, x2, p, 2) # type: ignore[attr-defined]
  968. else:
  969. raise ValueError(f"{compute_mode} is not a valid value for compute_mode")
  970. def atleast_1d(*tensors):
  971. r"""
  972. Returns a 1-dimensional view of each input tensor with zero dimensions.
  973. Input tensors with one or more dimensions are returned as-is.
  974. Args:
  975. input (Tensor or list of Tensors)
  976. Returns:
  977. output (Tensor or tuple of Tensors)
  978. Example::
  979. >>> x = torch.randn(2)
  980. >>> x
  981. tensor([1.4584, 0.7583])
  982. >>> torch.atleast_1d(x)
  983. tensor([1.4584, 0.7583])
  984. >>> x = torch.tensor(1.)
  985. >>> x
  986. tensor(1.)
  987. >>> torch.atleast_1d(x)
  988. tensor([1.])
  989. >>> x = torch.tensor(0.5)
  990. >>> y = torch.tensor(1.)
  991. >>> torch.atleast_1d((x,y))
  992. (tensor([0.5000]), tensor([1.]))
  993. """
  994. # This wrapper exists to support variadic args.
  995. if has_torch_function(tensors):
  996. return handle_torch_function(atleast_1d, tensors, *tensors)
  997. if len(tensors) == 1:
  998. tensors = tensors[0]
  999. return _VF.atleast_1d(tensors) # type: ignore[attr-defined]
  1000. def atleast_2d(*tensors):
  1001. r"""
  1002. Returns a 2-dimensional view of each input tensor with zero dimensions.
  1003. Input tensors with two or more dimensions are returned as-is.
  1004. Args:
  1005. input (Tensor or list of Tensors)
  1006. Returns:
  1007. output (Tensor or tuple of Tensors)
  1008. Example::
  1009. >>> x = torch.tensor(1.)
  1010. >>> x
  1011. tensor(1.)
  1012. >>> torch.atleast_2d(x)
  1013. tensor([[1.]])
  1014. >>> x = torch.randn(2,2)
  1015. >>> x
  1016. tensor([[2.2086, 2.5165],
  1017. [0.1757, 0.5194]])
  1018. >>> torch.atleast_2d(x)
  1019. tensor([[2.2086, 2.5165],
  1020. [0.1757, 0.5194]])
  1021. >>> x = torch.tensor(0.5)
  1022. >>> y = torch.tensor(1.)
  1023. >>> torch.atleast_2d((x,y))
  1024. (tensor([[0.5000]]), tensor([[1.]]))
  1025. """
  1026. # This wrapper exists to support variadic args.
  1027. if has_torch_function(tensors):
  1028. return handle_torch_function(atleast_2d, tensors, *tensors)
  1029. if len(tensors) == 1:
  1030. tensors = tensors[0]
  1031. return _VF.atleast_2d(tensors) # type: ignore[attr-defined]
  1032. def atleast_3d(*tensors):
  1033. r"""
  1034. Returns a 3-dimensional view of each input tensor with zero dimensions.
  1035. Input tensors with three or more dimensions are returned as-is.
  1036. Args:
  1037. input (Tensor or list of Tensors)
  1038. Returns:
  1039. output (Tensor or tuple of Tensors)
  1040. Example:
  1041. >>> x = torch.tensor(0.5)
  1042. >>> x
  1043. tensor(0.5000)
  1044. >>> torch.atleast_3d(x)
  1045. tensor([[[0.5000]]])
  1046. >>> y = torch.randn(2,2)
  1047. >>> y
  1048. tensor([[-0.8079, 0.7460],
  1049. [-1.1647, 1.4734]])
  1050. >>> torch.atleast_3d(y)
  1051. tensor([[[-0.8079],
  1052. [ 0.7460]],
  1053. <BLANKLINE>
  1054. [[-1.1647],
  1055. [ 1.4734]]])
  1056. >>> x = torch.randn(1,1,1)
  1057. >>> x
  1058. tensor([[[-1.5689]]])
  1059. >>> torch.atleast_3d(x)
  1060. tensor([[[-1.5689]]])
  1061. >>> x = torch.tensor(0.5)
  1062. >>> y = torch.tensor(1.)
  1063. >>> torch.atleast_3d((x,y))
  1064. (tensor([[[0.5000]]]), tensor([[[1.]]]))
  1065. """
  1066. # This wrapper exists to support variadic args.
  1067. if has_torch_function(tensors):
  1068. return handle_torch_function(atleast_3d, tensors, *tensors)
  1069. if len(tensors) == 1:
  1070. tensors = tensors[0]
  1071. return _VF.atleast_3d(tensors) # type: ignore[attr-defined]
# Overload stubs for ``norm``: one per combination of ``p`` being a string or
# a number and ``dim`` being a list of ints or a single int, as spelled out in
# the type comments.  They are only declared when not type checking.
if TYPE_CHECKING:
    pass
    # There's no good way to use this type annotation; cannot rename norm() to
    # _norm_impl() in a way that doesn't break JIT overloads. So leave untyped
    # for mypy for now.
    #    def norm(input: Tensor,
    #             p: Optional[Union[str, Number]] = "fro",
    #             dim: Optional[Union[int, List[int]]] = None,
    #             keepdim: bool = False,
    #             out: Optional[Tensor] = None,
    #             dtype: _dtype = None) -> Tensor:
    #        return _norm_impl(input, p, dim, keepdim, out, dtype)
else:
    # TODO: type dim as BroadcastingList when
    # https://github.com/pytorch/pytorch/issues/33782 is fixed
    @overload
    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
        # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
        pass

    @overload  # noqa: F811
    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
        # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
        pass

    @overload  # noqa: F811
    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
        # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
        pass

    @overload  # noqa: F811
    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
        # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
        pass
  1103. def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: F811
  1104. r"""Returns the matrix norm or vector norm of a given tensor.
  1105. .. warning::
  1106. torch.norm is deprecated and may be removed in a future PyTorch release.
  1107. Its documentation and behavior may be incorrect, and it is no longer
  1108. actively maintained.
  1109. Use :func:`torch.linalg.norm`, instead, or :func:`torch.linalg.vector_norm`
  1110. when computing vector norms and :func:`torch.linalg.matrix_norm` when
  1111. computing matrix norms. Note, however, the signature for these functions
  1112. is slightly different than the signature for torch.norm.
  1113. Args:
  1114. input (Tensor): The input tensor. Its data type must be either a floating
  1115. point or complex type. For complex inputs, the norm is calculated using the
  1116. absolute value of each element. If the input is complex and neither
  1117. :attr:`dtype` nor :attr:`out` is specified, the result's data type will
  1118. be the corresponding floating point type (e.g. float if :attr:`input` is
  1119. complexfloat).
  1120. p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'``
  1121. The following norms can be calculated:
  1122. ====== ============== ==========================
  1123. ord matrix norm vector norm
  1124. ====== ============== ==========================
  1125. 'fro' Frobenius norm --
  1126. 'nuc' nuclear norm --
  1127. Number -- sum(abs(x)**ord)**(1./ord)
  1128. ====== ============== ==========================
  1129. The vector norm can be calculated across any number of dimensions.
  1130. The corresponding dimensions of :attr:`input` are flattened into
  1131. one dimension, and the norm is calculated on the flattened
  1132. dimension.
  1133. Frobenius norm produces the same result as ``p=2`` in all cases
  1134. except when :attr:`dim` is a list of three or more dims, in which
  1135. case Frobenius norm throws an error.
  1136. Nuclear norm can only be calculated across exactly two dimensions.
  1137. dim (int, tuple of ints, list of ints, optional):
  1138. Specifies which dimension or dimensions of :attr:`input` to
  1139. calculate the norm across. If :attr:`dim` is ``None``, the norm will
  1140. be calculated across all dimensions of :attr:`input`. If the norm
  1141. type indicated by :attr:`p` does not support the specified number of
  1142. dimensions, an error will occur.
  1143. keepdim (bool, optional): whether the output tensors have :attr:`dim`
  1144. retained or not. Ignored if :attr:`dim` = ``None`` and
  1145. :attr:`out` = ``None``. Default: ``False``
  1146. out (Tensor, optional): the output tensor. Ignored if
  1147. :attr:`dim` = ``None`` and :attr:`out` = ``None``.
  1148. dtype (:class:`torch.dtype`, optional): the desired data type of
  1149. returned tensor. If specified, the input tensor is casted to
  1150. :attr:`dtype` while performing the operation. Default: None.
  1151. .. note::
  1152. Even though ``p='fro'`` supports any number of dimensions, the true
  1153. mathematical definition of Frobenius norm only applies to tensors with
  1154. exactly two dimensions. :func:`torch.linalg.norm` with ``ord='fro'`` aligns
  1155. with the mathematical definition, since it can only be applied across
  1156. exactly two dimensions.
  1157. Example::
  1158. >>> import torch
  1159. >>> a = torch.arange(9, dtype= torch.float) - 4
  1160. >>> b = a.reshape((3, 3))
  1161. >>> torch.norm(a)
  1162. tensor(7.7460)
  1163. >>> torch.norm(b)
  1164. tensor(7.7460)
  1165. >>> torch.norm(a, float('inf'))
  1166. tensor(4.)
  1167. >>> torch.norm(b, float('inf'))
  1168. tensor(4.)
  1169. >>> c = torch.tensor([[ 1, 2, 3],[-1, 1, 4]] , dtype= torch.float)
  1170. >>> torch.norm(c, dim=0)
  1171. tensor([1.4142, 2.2361, 5.0000])
  1172. >>> torch.norm(c, dim=1)
  1173. tensor([3.7417, 4.2426])
  1174. >>> torch.norm(c, p=1, dim=1)
  1175. tensor([6., 6.])
  1176. >>> d = torch.arange(8, dtype= torch.float).reshape(2,2,2)
  1177. >>> torch.norm(d, dim=(1,2))
  1178. tensor([ 3.7417, 11.2250])
  1179. >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
  1180. (tensor(3.7417), tensor(11.2250))
  1181. """
  1182. if has_torch_function_unary(input):
  1183. return handle_torch_function(
  1184. norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
  1185. ndim = input.dim()
  1186. # catch default case
  1187. if dim is None and out is None and dtype is None and p is not None:
  1188. if isinstance(p, str):
  1189. if p == "fro":
  1190. return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)
  1191. if not isinstance(p, str):
  1192. _dim = [i for i in range(ndim)] # noqa: C416 TODO: rewrite as list(range(m))
  1193. return _VF.norm(input, p, dim=_dim, keepdim=keepdim) # type: ignore[attr-defined]
  1194. # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
  1195. # remove the overloads where dim is an int and replace with BroadcastingList1
  1196. # and remove next four lines, replace _dim with dim
  1197. if dim is not None:
  1198. if isinstance(dim, int):
  1199. _dim = [dim]
  1200. else:
  1201. _dim = dim
  1202. else:
  1203. _dim = None # type: ignore[assignment]
  1204. if isinstance(p, str):
  1205. if p == "fro":
  1206. if dtype is not None:
  1207. raise ValueError("dtype argument is not supported in frobenius norm")
  1208. if _dim is None:
  1209. _dim = list(range(ndim))
  1210. if out is None:
  1211. return _VF.frobenius_norm(input, _dim, keepdim=keepdim)
  1212. else:
  1213. return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out)
  1214. elif p == "nuc":
  1215. if dtype is not None:
  1216. raise ValueError("dtype argument is not supported in nuclear norm")
  1217. if _dim is None:
  1218. if out is None:
  1219. return _VF.nuclear_norm(input, keepdim=keepdim)
  1220. else:
  1221. return _VF.nuclear_norm(input, keepdim=keepdim, out=out)
  1222. else:
  1223. if out is None:
  1224. return _VF.nuclear_norm(input, _dim, keepdim=keepdim)
  1225. else:
  1226. return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out)
  1227. raise RuntimeError(f"only valid string values are 'fro' and 'nuc', found {p}")
  1228. else:
  1229. if _dim is None:
  1230. _dim = list(range(ndim))
  1231. if out is None:
  1232. if dtype is None:
  1233. return _VF.norm(input, p, _dim, keepdim=keepdim) # type: ignore[attr-defined]
  1234. else:
  1235. return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype) # type: ignore[attr-defined]
  1236. else:
  1237. if dtype is None:
  1238. return _VF.norm(input, p, _dim, keepdim=keepdim, out=out) # type: ignore[attr-defined]
  1239. else:
  1240. return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out) # type: ignore[attr-defined]
  1241. def chain_matmul(*matrices, out=None):
  1242. r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
  1243. using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms
  1244. of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the product, :math:`N`
  1245. needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned.
  1246. If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.
  1247. .. warning::
  1248. :func:`torch.chain_matmul` is deprecated and will be removed in a future PyTorch release.
  1249. Use :func:`torch.linalg.multi_dot` instead, which accepts a list of two or more tensors
  1250. rather than multiple arguments.
  1251. Args:
  1252. matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined.
  1253. out (Tensor, optional): the output tensor. Ignored if :attr:`out` = ``None``.
  1254. Returns:
  1255. Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \times p_{i + 1}`, then the product
  1256. would be of dimensions :math:`p_{1} \times p_{N + 1}`.
  1257. Example::
  1258. >>> a = torch.randn(3, 4)
  1259. >>> b = torch.randn(4, 5)
  1260. >>> c = torch.randn(5, 6)
  1261. >>> d = torch.randn(6, 7)
  1262. >>> torch.chain_matmul(a, b, c, d)
  1263. tensor([[ -2.3375, -3.9790, -4.1119, -6.6577, 9.5609, -11.5095, -3.2614],
  1264. [ 21.4038, 3.3378, -8.4982, -5.2457, -10.2561, -2.4684, 2.7163],
  1265. [ -0.9647, -5.8917, -2.3213, -5.2284, 12.8615, -12.2816, -2.5095]])
  1266. .. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition
  1267. """
  1268. # This wrapper exists to support variadic args.
  1269. if has_torch_function(matrices):
  1270. return handle_torch_function(chain_matmul, matrices, *matrices)
  1271. if out is None:
  1272. return _VF.chain_matmul(matrices) # type: ignore[attr-defined]
  1273. else:
  1274. return _VF.chain_matmul(matrices, out=out) # type: ignore[attr-defined]
  1275. def _lu_impl(A, pivot=True, get_infos=False, out=None):
  1276. # type: (Tensor, bool, bool, Any) -> Tuple[Tensor, Tensor, Tensor]
  1277. r"""Computes the LU factorization of a matrix or batches of matrices
  1278. :attr:`A`. Returns a tuple containing the LU factorization and
  1279. pivots of :attr:`A`. Pivoting is done if :attr:`pivot` is set to
  1280. ``True``.
  1281. .. note::
  1282. * The returned permutation matrix for every matrix in the batch is
  1283. represented by a 1-indexed vector of size ``min(A.shape[-2], A.shape[-1])``.
  1284. ``pivots[i] == j`` represents that in the ``i``-th step of the algorithm,
  1285. the ``i``-th row was permuted with the ``j-1``-th row.
  1286. * LU factorization with :attr:`pivot` = ``False`` is not available
  1287. for CPU, and attempting to do so will throw an error. However,
  1288. LU factorization with :attr:`pivot` = ``False`` is available for
  1289. CUDA.
  1290. * This function does not check if the factorization was successful
  1291. or not if :attr:`get_infos` is ``True`` since the status of the
  1292. factorization is present in the third element of the return tuple.
  1293. * In the case of batches of square matrices with size less or equal
  1294. to 32 on a CUDA device, the LU factorization is repeated for
  1295. singular matrices due to the bug in the MAGMA library
  1296. (see magma issue 13).
  1297. * ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`.
  1298. .. warning::
  1299. The gradients of this function will only be finite when :attr:`A` is full rank.
  1300. This is because the LU decomposition is just differentiable at full rank matrices.
  1301. Furthermore, if :attr:`A` is close to not being full rank,
  1302. the gradient will be numerically unstable as it depends on the computation of :math:`L^{-1}` and :math:`U^{-1}`.
  1303. Args:
  1304. A (Tensor): the tensor to factor of size :math:`(*, m, n)`
  1305. pivot (bool, optional): controls whether pivoting is done. Default: ``True``
  1306. get_infos (bool, optional): if set to ``True``, returns an info IntTensor.
  1307. Default: ``False``
  1308. out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,
  1309. then the elements in the tuple are Tensor, IntTensor,
  1310. and IntTensor. If :attr:`get_infos` is ``False``, then the
  1311. elements in the tuple are Tensor, IntTensor. Default: ``None``
  1312. Returns:
  1313. (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing
  1314. - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)`
  1315. - **pivots** (*IntTensor*): the pivots of size :math:`(*, \text{min}(m, n))`.
  1316. ``pivots`` stores all the intermediate transpositions of rows.
  1317. The final permutation ``perm`` could be reconstructed by
  1318. applying ``swap(perm[i], perm[pivots[i] - 1])`` for ``i = 0, ..., pivots.size(-1) - 1``,
  1319. where ``perm`` is initially the identity permutation of :math:`m` elements
  1320. (essentially this is what :func:`torch.lu_unpack` is doing).
  1321. - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of
  1322. size :math:`(*)` where non-zero values indicate whether factorization for the matrix or
  1323. each minibatch has succeeded or failed
  1324. Example::
  1325. >>> A = torch.randn(2, 3, 3)
  1326. >>> A_LU, pivots = torch.lu(A)
  1327. >>> A_LU
  1328. tensor([[[ 1.3506, 2.5558, -0.0816],
  1329. [ 0.1684, 1.1551, 0.1940],
  1330. [ 0.1193, 0.6189, -0.5497]],
  1331. [[ 0.4526, 1.2526, -0.3285],
  1332. [-0.7988, 0.7175, -0.9701],
  1333. [ 0.2634, -0.9255, -0.3459]]])
  1334. >>> pivots
  1335. tensor([[ 3, 3, 3],
  1336. [ 3, 3, 3]], dtype=torch.int32)
  1337. >>> A_LU, pivots, info = torch.lu(A, get_infos=True)
  1338. >>> if info.nonzero().size(0) == 0:
  1339. ... print('LU factorization succeeded for all samples!')
  1340. LU factorization succeeded for all samples!
  1341. """
  1342. # If get_infos is True, then we don't need to check for errors and vice versa
  1343. return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
  1344. if TYPE_CHECKING:
  1345. _ListOrSeq = Sequence[Tensor]
  1346. else:
  1347. _ListOrSeq = List[Tensor]
  1348. def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
  1349. get_infos_int = 1 if get_infos else 0
  1350. if out_len - get_infos_int != 2:
  1351. raise TypeError(f"expected tuple of {2 + int(get_infos)} elements but got {out_len}")
  1352. if not isinstance(out, (tuple, list)):
  1353. raise TypeError(f"argument 'out' must be tuple of Tensors, not {type(out).__name__}")
  1354. def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
  1355. # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]
  1356. if has_torch_function_unary(A):
  1357. return handle_torch_function(
  1358. lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
  1359. result = _lu_impl(A, pivot, get_infos, out)
  1360. if out is not None:
  1361. _check_list_size(len(out), get_infos, out)
  1362. for i in range(len(out)):
  1363. out[i].resize_as_(result[i]).copy_(result[i])
  1364. return out
  1365. else:
  1366. return result # A_LU, pivots, infos
  1367. def _lu_no_infos(A, pivot=True, get_infos=False, out=None):
  1368. # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
  1369. # need to check for torch_function here so that we exit if
  1370. if has_torch_function_unary(A):
  1371. return handle_torch_function(
  1372. lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
  1373. result = _lu_impl(A, pivot, get_infos, out)
  1374. if out is not None:
  1375. _check_list_size(len(out), get_infos, out)
  1376. for i in range(len(out)):
  1377. out[i].resize_as_(result[i]).copy_(result[i])
  1378. return out
  1379. else:
  1380. return result[0], result[1] # A_LU, pivots
  1381. # The return type of lu depends on `get_infos`, so in order to resolve the output type
  1382. # of lu in TorchScript we need to statically know the value of `get_infos`
  1383. lu = boolean_dispatch(
  1384. arg_name='get_infos',
  1385. arg_index=2,
  1386. default=False,
  1387. if_true=_lu_with_infos,
  1388. if_false=_lu_no_infos,
  1389. module_name=__name__,
  1390. func_name='lu')
  1391. lu.__doc__ = _lu_impl.__doc__
  1392. def align_tensors(*tensors):
  1393. raise RuntimeError('`align_tensors` not yet implemented.')