overrides.py 99 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915
  1. """
Python implementation of ``__torch_function__``

While most of the torch API and handling for ``__torch_function__`` happens
at the C++ level, some of the torch API is written in Python, so we need
Python-level handling for ``__torch_function__`` overrides as well. The main
developer-facing functions in this file are ``handle_torch_function`` and
``has_torch_function``. See torch/functional.py and test/test_overrides.py
for usage examples.
  9. Note
  10. ----
Heavily inspired by NumPy's ``__array_function__`` (see
https://github.com/pytorch/pytorch/issues/24015 and
https://www.numpy.org/neps/nep-0018-array-function-protocol.html).
  15. If changing this file in a way that can affect ``__torch_function__`` overhead,
  16. please report the benchmarks in ``benchmarks/overrides_benchmark``. See the
  17. instructions in the ``README.md`` in that directory.
  18. """
  19. import __future__
  20. import collections
  21. import functools
  22. import types
  23. import warnings
  24. from typing import Dict, Set, List, Any, Callable, Iterable, Type, Iterator, Tuple
  25. import contextlib
  26. import torch
  27. from torch._C import (
  28. _has_torch_function, _has_torch_function_unary,
  29. _has_torch_function_variadic, _add_docstr, _set_torch_function_mode, _get_torch_function_mode)
  30. from torch.utils._mode_utils import _enable_mode, _push_mode, _ModeInfo, _wrap_init, MetaInitErrorInfo
  31. __all__ = [
  32. "get_ignored_functions",
  33. "get_overridable_functions",
  34. "get_testing_overrides",
  35. "handle_torch_function",
  36. "has_torch_function",
  37. "resolve_name",
  38. "is_tensor_like",
  39. "is_tensor_method_or_property",
  40. "wrap_torch_function",
  41. "enable_reentrant_dispatch",
  42. ]
  43. @functools.lru_cache(None)
  44. def get_ignored_functions() -> Set[Callable]:
  45. """
  46. Return public functions that cannot be overridden by ``__torch_function__``.
  47. Returns
  48. -------
  49. Set[Callable]
  50. A tuple of functions that are publicly available in the torch API but cannot
  51. be overridden with ``__torch_function__``. Mostly this is because none of the
  52. arguments of these functions are tensors or tensor-likes.
  53. Examples
  54. --------
  55. >>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
  56. True
  57. >>> torch.add in torch.overrides.get_ignored_functions()
  58. False
  59. """
  60. Tensor = torch.Tensor
  61. return {
  62. torch.typename,
  63. torch.is_tensor,
  64. torch.is_storage,
  65. torch.set_default_tensor_type,
  66. torch.set_rng_state,
  67. torch.get_rng_state,
  68. torch.manual_seed,
  69. torch.initial_seed,
  70. torch.seed,
  71. torch.save,
  72. torch.load,
  73. torch.set_printoptions,
  74. torch.fork,
  75. torch.get_default_dtype,
  76. torch.get_num_interop_threads,
  77. torch.get_num_threads,
  78. torch.init_num_threads,
  79. torch.import_ir_module,
  80. torch.import_ir_module_from_buffer,
  81. torch.is_anomaly_enabled,
  82. torch.is_grad_enabled,
  83. torch.merge_type_from_type_comment,
  84. torch.parse_ir,
  85. torch.parse_schema,
  86. torch.parse_type_comment,
  87. torch.set_anomaly_enabled,
  88. torch.set_flush_denormal,
  89. torch.set_num_interop_threads,
  90. torch.set_num_threads,
  91. torch.wait,
  92. torch.as_tensor,
  93. torch.from_numpy,
  94. torch.get_device,
  95. torch.tensor,
  96. torch.default_generator,
  97. torch.has_cuda,
  98. torch.has_cudnn,
  99. torch.has_lapack,
  100. torch.device,
  101. torch.dtype,
  102. torch.finfo,
  103. torch.has_mkl,
  104. torch.has_mps,
  105. torch.has_mkldnn,
  106. torch.has_openmp,
  107. torch.iinfo,
  108. torch.memory_format,
  109. torch.qscheme,
  110. torch.set_grad_enabled,
  111. torch.no_grad,
  112. torch.enable_grad,
  113. torch.inference_mode,
  114. torch.is_inference_mode_enabled,
  115. torch.layout,
  116. torch.align_tensors,
  117. torch.arange,
  118. torch.as_strided,
  119. torch.bartlett_window,
  120. torch.blackman_window,
  121. torch.broadcast_shapes,
  122. torch.can_cast,
  123. torch.cudnn_affine_grid_generator,
  124. torch.cudnn_batch_norm,
  125. torch.cudnn_convolution,
  126. torch.cudnn_convolution_transpose,
  127. torch.cudnn_convolution_relu,
  128. torch.cudnn_convolution_add_relu,
  129. torch.cudnn_grid_sampler,
  130. torch.cudnn_is_acceptable,
  131. torch.empty,
  132. torch.empty_strided,
  133. torch.empty_quantized,
  134. torch.eye,
  135. torch.fft.fftfreq,
  136. torch.fft.rfftfreq,
  137. torch.from_file,
  138. torch.full,
  139. torch.fill,
  140. torch.hamming_window,
  141. torch.hann_window,
  142. torch.kaiser_window,
  143. torch.linspace,
  144. torch.logspace,
  145. torch.mkldnn_adaptive_avg_pool2d,
  146. torch.mkldnn_convolution,
  147. torch.mkldnn_max_pool2d,
  148. torch.mkldnn_max_pool3d,
  149. torch.mkldnn_linear_backward_weights,
  150. torch.nested_tensor,
  151. torch.normal,
  152. torch.ones,
  153. torch.promote_types,
  154. torch.rand,
  155. torch.randn,
  156. torch.randint,
  157. torch.randperm,
  158. torch.range,
  159. torch.result_type,
  160. torch.scalar_tensor,
  161. torch.sparse_coo_tensor,
  162. torch.sparse_compressed_tensor,
  163. torch.sparse_csr_tensor,
  164. torch.sparse_csc_tensor,
  165. torch.sparse_bsr_tensor,
  166. torch.sparse_bsc_tensor,
  167. torch.tril_indices,
  168. torch.triu_indices,
  169. torch.vander,
  170. torch.zeros,
  171. torch._jit_internal.boolean_dispatch,
  172. torch.nn.functional.assert_int_or_pair,
  173. torch.nn.functional.upsample,
  174. torch.nn.functional.upsample_bilinear,
  175. torch.nn.functional.upsample_nearest,
  176. torch.nn.functional.has_torch_function,
  177. torch.nn.functional.has_torch_function_unary,
  178. torch.nn.functional.has_torch_function_variadic,
  179. torch.nn.functional.handle_torch_function,
  180. torch.nn.functional.sigmoid,
  181. torch.nn.functional.hardsigmoid,
  182. torch.nn.functional.tanh,
  183. # Doesn't actually take or return tensor arguments
  184. torch.nn.init.calculate_gain,
  185. # These are deprecated; don't test them
  186. torch.nn.init.uniform,
  187. torch.nn.init.normal,
  188. torch.nn.init.constant,
  189. torch.nn.init.eye,
  190. torch.nn.init.dirac,
  191. torch.nn.init.xavier_uniform,
  192. torch.nn.init.xavier_normal,
  193. torch.nn.init.kaiming_uniform,
  194. torch.nn.init.kaiming_normal,
  195. torch.nn.init.orthogonal,
  196. torch.nn.init.sparse,
  197. has_torch_function,
  198. handle_torch_function,
  199. torch.set_autocast_enabled,
  200. torch.is_autocast_enabled,
  201. torch.clear_autocast_cache,
  202. torch.set_autocast_cpu_enabled,
  203. torch.is_autocast_cpu_enabled,
  204. torch.set_autocast_cpu_dtype,
  205. torch.get_autocast_cpu_dtype,
  206. torch.get_autocast_gpu_dtype,
  207. torch.set_autocast_gpu_dtype,
  208. torch.autocast_increment_nesting,
  209. torch.autocast_decrement_nesting,
  210. torch.is_autocast_cache_enabled,
  211. torch.set_autocast_cache_enabled,
  212. torch.nn.functional.hardswish,
  213. torch.is_vulkan_available,
  214. torch.are_deterministic_algorithms_enabled,
  215. torch.use_deterministic_algorithms,
  216. torch.is_deterministic_algorithms_warn_only_enabled,
  217. torch.set_deterministic_debug_mode,
  218. torch.get_deterministic_debug_mode,
  219. torch.set_float32_matmul_precision,
  220. torch.get_float32_matmul_precision,
  221. torch.unify_type_list,
  222. torch.is_warn_always_enabled,
  223. torch.set_warn_always,
  224. torch.vitals_enabled,
  225. torch.set_vital,
  226. torch.read_vitals,
  227. torch.frombuffer,
  228. torch.asarray,
  229. Tensor.__delitem__,
  230. Tensor.__dir__,
  231. Tensor.__getattribute__,
  232. Tensor.__init__,
  233. Tensor.__iter__,
  234. Tensor.__init_subclass__,
  235. Tensor.__delattr__,
  236. Tensor.__setattr__,
  237. Tensor.__torch_function__,
  238. Tensor.__torch_dispatch__,
  239. Tensor.__new__,
  240. Tensor.__class__,
  241. Tensor.__subclasshook__,
  242. Tensor.as_subclass,
  243. Tensor.reinforce,
  244. Tensor.new,
  245. Tensor.new_tensor,
  246. Tensor.new_empty,
  247. Tensor.new_empty_strided,
  248. Tensor.new_zeros,
  249. Tensor.new_ones,
  250. Tensor.new_full,
  251. Tensor._make_subclass,
  252. Tensor.solve,
  253. Tensor.stride,
  254. Tensor.unflatten,
  255. Tensor.to_sparse_coo,
  256. Tensor.to_sparse_csr,
  257. Tensor.to_sparse_csc,
  258. Tensor.to_sparse_bsr,
  259. Tensor.to_sparse_bsc,
  260. Tensor._reduce_ex_internal,
  261. Tensor._fix_weakref,
  262. Tensor._make_wrapper_subclass,
  263. Tensor._python_dispatch.__get__,
  264. Tensor._conj,
  265. Tensor._conj_physical,
  266. Tensor._neg_view,
  267. Tensor._is_zerotensor,
  268. Tensor._addmm_activation,
  269. Tensor._nested_tensor_layer_norm,
  270. Tensor.to_padded_tensor,
  271. }
  272. @functools.lru_cache(None)
  273. def get_default_nowrap_functions() -> Set[Callable]:
  274. """
  275. Return public functions that do not wrap in a subclass when invoked by
  276. the default ``Tensor.__torch_function__`` that preserves subclasses. Typically,
  277. these functions represent field accesses (i.e., retrieving a Tensor that
  278. is stored somewhere on the Tensor) as opposed to computation. Users of
  279. these functions expect object identity to be preserved over multiple accesses
  280. (e.g., ``a.grad is a.grad``) which cannot be upheld if we're wrapping on
  281. the fly every time (furthermore, the tensor stored here might already be
  282. the subclass, in which case wrapping really ought not to happen).
  283. Not ALL property accessors have this property; for example ``Tensor.T`` actually
  284. just creates a new transposed tensor on the fly, and so we SHOULD interpose on
  285. these calls (you need to check the implementation of the function to see if
  286. this is the case or not). Additionally, if a property accessor doesn't return a Tensor,
  287. it doesn't have to be on this list (though it is harmless if it is).
  288. """
  289. Tensor = torch.Tensor
  290. return {
  291. Tensor._base.__get__,
  292. Tensor.grad.__get__,
  293. Tensor._grad.__get__,
  294. }
  295. @functools.lru_cache(None)
  296. def get_testing_overrides() -> Dict[Callable, Callable]:
  297. """Return a dict containing dummy overrides for all overridable functions
  298. Returns
  299. -------
  300. Dict[Callable, Callable]
  301. A dictionary that maps overridable functions in the PyTorch API to
  302. lambda functions that have the same signature as the real function
  303. and unconditionally return -1. These lambda functions are useful
  304. for testing API coverage for a type that defines ``__torch_function__``.
  305. Examples
  306. --------
  307. >>> import inspect
  308. >>> my_add = torch.overrides.get_testing_overrides()[torch.add]
  309. >>> inspect.signature(my_add)
  310. <Signature (input, other, out=None)>
  311. """
  312. # Every function in the PyTorchAPI that can be overriden needs an entry
  313. # in this dict.
  314. #
  315. # Optimally we would use inspect to get the function signature and define
  316. # the lambda function procedurally but that is blocked by generating
  317. # function signatures for native kernels that can be consumed by inspect.
  318. # See Issue #28233.
  319. Tensor = torch.Tensor
  320. ret: Dict[Callable, Callable] = {
  321. torch.abs: lambda input, out=None: -1,
  322. torch.absolute: lambda input, out=None: -1,
  323. torch.adaptive_avg_pool1d: lambda input, output_size: -1,
  324. torch.adaptive_max_pool1d: lambda inputs, output_size: -1,
  325. torch.acos: lambda input, out=None: -1,
  326. torch.adjoint: lambda input: -1,
  327. torch.arccos: lambda input, out=None: -1,
  328. torch.acosh: lambda input, out=None: -1,
  329. torch.arccosh: lambda input, out=None: -1,
  330. torch.add: lambda input, other, out=None: -1,
  331. torch.addbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
  332. torch.addcdiv: lambda input, tensor1, tensor2, value=1, out=None: -1,
  333. torch.addcmul: lambda input, tensor1, tensor2, value=1, out=None: -1,
  334. torch.addmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
  335. torch.addmv: lambda input, mat, vec, beta=1, alpha=1, out=None: -1,
  336. torch.addr: lambda input, vec1, vec2, beta=1, alpha=1, out=None: -1,
  337. torch.affine_grid_generator: lambda theta, size, align_corners: -1,
  338. torch.all: lambda input, dim=None: -1,
  339. torch.allclose: lambda input, other, trol=1e-05, atol=1e-08, equal_nan=False: -1,
  340. torch.alpha_dropout: lambda input, p, train, inplace=False: -1,
  341. torch.amax: lambda input, dim=None: -1,
  342. torch.amin: lambda input, dim=None: -1,
  343. torch.aminmax: lambda input, dim=None, keepdim=False, out=None: -1,
  344. torch.angle: lambda input, out=None: -1,
  345. torch.any: lambda input, dim=None, keepdim=False, out=None: -1,
  346. torch.argmax: lambda input: -1,
  347. torch.argmin: lambda input: -1,
  348. torch.argsort: lambda input, dim=None: -1,
  349. torch.asin: lambda input, out=None: -1,
  350. torch._assert_async: lambda input: -1,
  351. torch.arcsin: lambda input, out=None: -1,
  352. torch.asinh: lambda input, out=None: -1,
  353. torch.arcsinh: lambda input, out=None: -1,
  354. torch.atan: lambda input, out=None: -1,
  355. torch.arctan: lambda input, out=None: -1,
  356. torch.atan2: lambda input, other, out=None: -1,
  357. torch.arctan2: lambda input, other, out=None: -1,
  358. torch.atanh: lambda input, out=None: -1,
  359. torch.arctanh: lambda input, out=None: -1,
  360. torch.atleast_1d: lambda *tensors: -1,
  361. torch.atleast_2d: lambda *tensors: -1,
  362. torch.atleast_3d: lambda *tensors: -1,
  363. torch.avg_pool1d: lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True: -1,
  364. torch.baddbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
  365. torch.batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled: -1,
  366. torch.batch_norm_backward_elemt: lambda grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count_tensor: -1,
  367. torch.batch_norm_backward_reduce: lambda grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g: -1,
  368. torch.batch_norm_elemt: lambda input, weight, bias, mean, invstd, eps: -1,
  369. torch.batch_norm_gather_stats: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1,
  370. torch.batch_norm_gather_stats_with_counts: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1,
  371. torch.batch_norm_stats: lambda input, eps: -1,
  372. torch.batch_norm_update_stats: lambda input, running_mean, running_var, momentum: -1,
  373. torch.bernoulli: lambda input, generator=None, out=None: -1,
  374. torch.bilinear: lambda input1, input2, weight, bias: -1,
  375. torch.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None, reduce=None,
  376. reduction='mean', pos_weight=None: -1),
  377. torch.bincount: lambda input, weights=None, minlength=0: -1,
  378. torch.binomial: lambda count, prob, generator=None: -1,
  379. torch.bitwise_and: lambda input, other, out=None: -1,
  380. torch.bitwise_not: lambda input, out=None: -1,
  381. torch.bitwise_or: lambda input, other, out=None: -1,
  382. torch.bitwise_xor: lambda input, other, out=None: -1,
  383. torch.bitwise_left_shift: lambda input, other, out=None: -1,
  384. torch.bitwise_right_shift: lambda input, other, out=None: -1,
  385. torch.block_diag: lambda *tensors: -1,
  386. torch.bmm: lambda input, mat2, out=None: -1,
  387. torch.broadcast_tensors: lambda *tensors: -1,
  388. torch.broadcast_to: lambda self, size: -1,
  389. torch.bucketize: lambda input, boundaries, out_int32=False, right=False, out=None: -1,
  390. torch.cartesian_prod: lambda *tensors: -1,
  391. torch.cat: lambda tensors, dim=0, out=None: -1,
  392. torch.concat: lambda tensors, dim=0, out=None: -1, # alias for torch.cat
  393. torch.cdist: lambda x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary': -1,
  394. torch.ceil: lambda input, out=None: -1,
  395. torch.celu: lambda input, alhpa=1., inplace=False: -1,
  396. torch.chain_matmul: lambda *matrices, out=None: -1,
  397. torch.channel_shuffle: lambda input, groups : -1,
  398. torch.cholesky: lambda input, upper=False, out=None: -1,
  399. torch.linalg.cholesky: lambda input, out=None: -1,
  400. torch.linalg.cholesky_ex: lambda input, check_errors=False, out=None: -1,
  401. torch.cholesky_inverse: lambda input, upper=False, out=None: -1,
  402. torch.cholesky_solve: lambda input1, input2, upper=False, out=None: -1,
  403. torch.choose_qparams_optimized: lambda input, numel, n_bins, ratio, bit_width: -1,
  404. torch.chunk: lambda input, chunks, dim=0: -1,
  405. torch.clamp: lambda input, min=None, max=None, out=None: -1,
  406. torch.clip: lambda input, min=None, max=None, out=None: -1,
  407. torch.clamp_min: lambda input, min, out=None: -1,
  408. torch.clamp_max: lambda input, max, out=None: -1,
  409. torch.column_stack: lambda tensors, out=None: -1,
  410. torch.cov: lambda input, correction=1, fweights=None, aweights=None: -1,
  411. torch.clone: lambda input: -1,
  412. torch.combinations: lambda input, r=2, with_replacement=False: -1,
  413. torch.complex: lambda real, imag: -1,
  414. torch.copysign: lambda input, other, out=None: -1,
  415. torch.polar: lambda abs, ang: -1,
  416. torch.linalg.cond: lambda input, ord=None: -1,
  417. torch.conj: lambda input, out=None: -1,
  418. torch.conj_physical: lambda input, out=None: -1,
  419. torch.resolve_conj: lambda input, out=None: -1,
  420. torch.resolve_neg: lambda input, out=None: -1,
  421. torch.constant_pad_nd: lambda input, pad, value=0: -1,
  422. torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
  423. torch.conv2d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
  424. torch.conv3d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
  425. torch.convolution: lambda input, weight, bias, stride, padding, dilation, transposed, output_adding, groups: -1,
  426. torch.conv_tbc: lambda input, weight, bias, pad=0: -1,
  427. torch.conv_transpose1d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
  428. torch.conv_transpose2d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
  429. torch.conv_transpose3d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
  430. torch.corrcoef: lambda input: -1,
  431. torch.cos: lambda input, out=None: -1,
  432. torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1,
  433. torch.cosh: lambda input, out=None: -1,
  434. torch.cosine_similarity: lambda x1, x2, dim=1, eps=1e-8: -1,
  435. torch.count_nonzero: lambda input: -1,
  436. torch.cross: lambda input, other, dim=None, out=None: -1,
  437. torch.linalg.cross: lambda input, other, dim=-1, out=None: -1,
  438. torch.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean',
  439. zero_infinity=False: -1),
  440. torch.cummax: lambda input, dim, out=None: -1,
  441. torch.cummin: lambda input, dim, out=None: -1,
  442. torch.cumprod: lambda input, dim, out=None, dtype=None: -1,
  443. torch.cumsum: lambda input, dim, out=None, dtype=None: -1,
  444. torch.cumulative_trapezoid: lambda y, x=None, dim=-1: -1,
  445. torch.logcumsumexp: lambda input, dim, out=None: -1,
  446. torch.deg2rad: lambda input, out=None: -1,
  447. torch.dequantize: lambda input: -1,
  448. torch.det: lambda input: -1,
  449. torch.linalg.det: lambda input: -1, # alias for torch.det # type: ignore[attr-defined]
  450. torch.detach: lambda input: -1,
  451. torch.diag: lambda input, diagonal=0, out=None: -1,
  452. torch.diag_embed: lambda input, diagonal=0, out=None: -1,
  453. torch.diagflat: lambda input, offset=0: -1,
  454. torch.diff: lambda input, n=1, dim=-1, prepend=None, append=None, out=None: -1,
  455. torch.diagonal: lambda input, offset=0, dim1=0, dim2=1: -1,
  456. torch.linalg.diagonal: lambda input, offset=0, dim1=-2, dim2=-1: -1,
  457. torch.diagonal_scatter: lambda input, src, offset=0, dim1=0, dim2=1: -1,
  458. torch.digamma: lambda input, out=None: -1,
  459. torch.dist: lambda input, other, p=2: -1,
  460. torch.div: lambda input, other, rounding_mode=None, out=None: -1,
  461. torch.divide: lambda input, other, rounding_mode=None, out=None: -1,
  462. torch.dot: lambda input, other, out=None: -1,
  463. torch.dropout: lambda input, p, train, inplace=False: -1,
  464. torch.dsmm: lambda input, mat2: -1,
  465. torch.hsmm: lambda mat1, mat2: -1,
  466. torch.dsplit: lambda input, indices_or_sections: -1,
  467. torch.dstack: lambda tensors, out=None: -1,
  468. torch.eig: lambda input, eigenvectors=False, out=None: -1,
  469. torch.linalg.eig: lambda input, out=None: -1,
  470. torch.linalg.eigvals: lambda input, out=None: -1,
  471. torch.linalg.eigh: lambda input, UPLO="L", out=None: -1,
  472. torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1,
  473. torch.einsum: lambda equation, *operands: -1,
  474. torch.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
  475. sparse=False: -1),
  476. torch.embedding_bag: (lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False,
  477. mode='mean', sparse=False, per_sample_weights=None, padding_idx=None: -1),
  478. torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
  479. torch.eq: lambda input, other, out=None: -1,
  480. torch.equal: lambda input, other: -1,
  481. torch.erf: lambda input, out=None: -1,
  482. torch.erfc: lambda input, out=None: -1,
  483. torch.erfinv: lambda input, out=None: -1,
  484. torch.exp: lambda input, out=None: -1,
  485. torch.exp2: lambda input, out=None: -1,
  486. torch.expm1: lambda input, out=None: -1,
  487. torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
  488. torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
  489. torch.fused_moving_avg_obs_fake_quant: (lambda x, observer_on, fake_quant_on, averaging_const, running_min,
  490. running_max, scale, zero_point, quant_min, quant_max, ch_axis,
  491. per_row_fake_quant=False, symmetric_quant=False: -1),
  492. torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
  493. torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
  494. torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,
  495. torch.fbgemm_linear_int8_weight_fp32_activation: (lambda input, weight, packed, col_offsets, weight_scale,
  496. weight_zero_point, bias: -1),
  497. torch.fbgemm_linear_quantize_weight: lambda input: -1,
  498. torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1,
  499. torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1,
  500. torch.feature_alpha_dropout: lambda input, p, train: -1,
  501. torch.feature_dropout: lambda input, p, train: -1,
  502. torch.fft.fft: lambda input, n=None, dim=-1, norm=None: -1,
  503. torch.fft.ifft: lambda input, n=None, dim=-1, norm=None: -1,
  504. torch.fft.rfft: lambda input, n=None, dim=-1, norm=None: -1,
  505. torch.fft.irfft: lambda input, n=None, dim=-1, norm=None: -1,
  506. torch.fft.hfft: lambda input, n=None, dim=-1, norm=None: -1,
  507. torch.fft.ihfft: lambda input, n=None, dim=-1, norm=None: -1,
  508. torch.fft.hfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
  509. torch.fft.ihfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
  510. torch.fft.hfftn: lambda input, s=None, dim=-1, norm=None: -1,
  511. torch.fft.ihfftn: lambda input, s=None, dim=-1, norm=None: -1,
  512. torch.fft.fftn: lambda input, s=None, dim=None, norm=None: -1,
  513. torch.fft.ifftn: lambda input, s=None, dim=None, norm=None: -1,
  514. torch.fft.rfftn: lambda input, s=None, dim=None, norm=None: -1,
  515. torch.fft.irfftn: lambda input, s=None, dim=None, norm=None: -1,
  516. torch.fft.fft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
  517. torch.fft.ifft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
  518. torch.fft.rfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
  519. torch.fft.irfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
  520. torch.fft.fftshift: lambda input, dim=None: -1,
  521. torch.fft.ifftshift: lambda input, dim=None: -1,
  522. torch.fft.fft: lambda input, n=None, dim=-1, norm=None: -1,
  523. torch.fix: lambda input, out=None: -1,
  524. torch.flatten: lambda input, start_dim=0, end_dim=-1: -1,
  525. torch.flip: lambda input, dims: -1,
  526. torch.fliplr: lambda input: -1,
  527. torch.flipud: lambda input: -1,
  528. torch.frobenius_norm: lambda input, dim=None, keepdim=False, out=None: -1,
  529. torch.floor: lambda input, out=None: -1,
  530. torch.floor_divide: lambda input, other: -1,
  531. torch.float_power: lambda input, exponent, out=None: -1,
  532. torch.fmod: lambda input, other, out=None: -1,
  533. torch.frac: lambda input, out=None: -1,
  534. torch.frexp: lambda input, out=None: -1,
  535. torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
  536. torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
  537. torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1,
  538. torch.gcd: lambda input, other, out=None: -1,
  539. torch.ge: lambda input, other, out=None: -1,
  540. torch.greater_equal: lambda input, other, out=None: -1,
  541. torch.geqrf: lambda input, out=None: -1,
  542. torch.i0: lambda input, out=None: -1,
  543. torch.inner: lambda input, other, out=None: -1,
  544. torch.outer: lambda input, vec2, out=None: -1,
  545. torch.ger: lambda input, vec2, out=None: -1, # alias for torch.outer
  546. torch.gradient: lambda input, spacing=None, dim=None, edge_order=1: -1,
  547. torch.grid_sampler: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
  548. torch.grid_sampler_2d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
  549. torch.grid_sampler_3d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
  550. torch.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05, cudnn_enabled=True: -1,
  551. torch.gru: lambda input, hx, params, has_biases, num_layers, gropout, train, bidirectional, batch_first: -1,
  552. torch.gru_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
  553. torch.gt: lambda input, other, out=None: -1,
  554. torch.greater: lambda input, other, out=None: -1,
  555. torch.hardshrink: lambda input, lambd=0.5: -1,
  556. torch.heaviside: lambda input, values, out=None: -1,
  557. torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='mean': -1,
  558. torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
  559. torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1,
  560. torch.histogramdd: lambda input, bins, range=None, weight=None, density=False: -1,
  561. torch.linalg.householder_product: lambda input, tau: -1,
  562. torch.hspmm: lambda mat1, mat2, out=None: -1,
  563. torch.hsplit: lambda input, indices_or_sections: -1,
  564. torch.hstack: lambda tensors, out=None: -1,
  565. torch.hypot: lambda input, other, out=None: -1,
  566. torch.igamma: lambda input, other, out=None: -1,
  567. torch.igammac: lambda input, other, out=None: -1,
  568. torch.imag: lambda input, out=None: -1,
  569. torch.index_add: lambda input, dim, index, source: -1,
  570. torch.index_copy: lambda input, dim, index, source: -1,
  571. torch.index_put: lambda input, indices, values, accumulate=False: -1,
  572. torch.index_select: lambda input, dim, index, out=None: -1,
  573. torch.index_fill: lambda input, dim, index, value: -1,
  574. torch.index_reduce: lambda input, dim, index, source, reduce, include_input=True: -1,
  575. torch.isfinite: lambda tensor: -1,
  576. torch.isin: lambda e, te, assume_unique=False, invert=False: -1,
  577. torch.isinf: lambda tensor: -1,
  578. torch.isreal: lambda tensor: -1,
  579. torch.isposinf: lambda input, out=None: -1,
  580. torch.isneginf: lambda input, out=None: -1,
  581. torch.instance_norm: (lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps,
  582. cudnn_enabled: -1),
  583. torch.int_repr: lambda input: -1,
  584. torch.inverse: lambda input, out=None: -1,
  585. torch.linalg.inv: lambda input, out=None: -1,
  586. torch.linalg.inv_ex: lambda input, check_errors=False, out=None: -1,
  587. torch.is_complex: lambda input: -1,
  588. torch.is_conj: lambda input: -1,
  589. torch.is_neg: lambda input: -1,
  590. torch.is_distributed: lambda input: -1,
  591. torch.is_inference: lambda input: -1,
  592. torch.is_floating_point: lambda input: -1,
  593. torch.is_nonzero: lambda input: -1,
  594. torch.is_same_size: lambda input, other: -1,
  595. torch.is_signed: lambda input: -1,
  596. torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
  597. torch.isnan: lambda input: -1,
  598. torch.istft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
  599. normalized=False, onesided=None, length=None, return_complex=False: -1),
  600. torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
  601. torch.kron: lambda input, other: -1,
  602. torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
  603. torch.linalg.ldl_factor_ex: lambda input, hermitian=False, check_errors=False, out=None: -1,
  604. torch.linalg.ldl_factor: lambda input, hermitian=False, out=None: -1,
  605. torch.linalg.ldl_solve: lambda LD, pivots, B, hermitian=False, out=None: -1,
  606. torch.layer_norm: lambda input, normalized_shape, weight=None, bias=None, esp=1e-05, cudnn_enabled=True: -1,
  607. torch.lcm: lambda input, other, out=None: -1,
  608. torch.ldexp: lambda input, other, out=None: -1,
  609. torch.le: lambda input, other, out=None: -1,
  610. torch.less_equal: lambda input, other, out=None: -1,
  611. torch.lerp: lambda input, end, weight, out=None: -1,
  612. torch.lgamma: lambda input, out=None: -1,
  613. torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None,
  614. tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1,
  615. torch.log: lambda input, out=None: -1,
  616. torch.log_softmax: lambda input, dim, dtype=None: -1,
  617. torch.log10: lambda input, out=None: -1,
  618. torch.log1p: lambda input, out=None: -1,
  619. torch.log2: lambda input, out=None: -1,
  620. torch.logaddexp: lambda input, other, out=None: -1,
  621. torch.logaddexp2: lambda input, other, out=None: -1,
  622. torch.logdet: lambda input: -1,
  623. torch.xlogy: lambda x, y, out=None: -1,
  624. torch.logical_and: lambda input, other, out=None: -1,
  625. torch.logical_not: lambda input, out=None: -1,
  626. torch.logical_or: lambda input, other, out=None: -1,
  627. torch.logical_xor: lambda input, other, out=None: -1,
  628. torch.logsumexp: lambda input, names, keepdim=False, out=None: -1,
  629. torch.logit: lambda input, eps=None: -1,
  630. torch.logsumexp: lambda input, names, keepdim=False, out=None: -1,
  631. torch.lstm: lambda data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional: -1,
  632. torch.lstm_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
  633. torch.lstsq: lambda input, A, out=None: -1,
  634. torch.lt: lambda input, other, out=None: -1,
  635. torch.less: lambda input, other, out=None: -1,
  636. torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1,
  637. torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1,
  638. torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1, # type: ignore[attr-defined] # noqa: B950
  639. torch.masked_fill: lambda input, mask, value: -1,
  640. torch.masked_scatter: lambda input, mask, source: -1,
  641. torch.masked_select: lambda input, mask, out=None: -1,
  642. torch.matmul: lambda input, other, out=None: -1,
  643. torch.linalg.lu: lambda input, pivot=True, out=None: -1,
  644. torch.linalg.lu_factor: lambda input, pivot=True, out=None: -1,
  645. torch.linalg.lu_factor_ex: lambda input, pivot=True, check_errors=False, out=None: -1,
  646. torch.linalg.matmul: lambda input, other, out=None: -1, # alias for torch.matmul
  647. torch.matrix_power: lambda input, n: -1,
  648. torch.linalg.matrix_power: lambda input, n, out=None: -1,
  649. torch.matrix_rank: lambda input, tol=None, symmetric=False: -1,
  650. torch.linalg.matrix_rank: lambda input, tol=None, hermitian=False: -1,
  651. torch.linalg.multi_dot: lambda tensors, out=None: -1,
  652. torch.matrix_exp: lambda input: -1,
  653. torch.linalg.matrix_exp: lambda input: -1,
  654. torch.max: lambda input, out=None: -1,
  655. torch.maximum: lambda input, other, out=None: -1,
  656. torch.fmax: lambda input, other, out=None: -1,
  657. torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
  658. torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
  659. torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
  660. torch.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
  661. return_indices=False, ceil_mode=False: -1),
  662. torch.mean: lambda input, dim=None: -1,
  663. torch.nanmean: lambda input, dim=None, keepdim=False, dtype=None, out=None: -1,
  664. torch.median: lambda input, dim=None: -1,
  665. torch.nanmedian: lambda input, dim=None: -1,
  666. torch.meshgrid: lambda *tensors, **kwargs: -1,
  667. torch.min: lambda input, out=None: -1,
  668. torch.minimum: lambda input, other, out=None: -1,
  669. torch.fmin: lambda input, other, out=None: -1,
  670. torch.miopen_batch_norm: (lambda input, weight, bias, running_mean, running_var, training,
  671. exponential_average_factor, epsilon: -1),
  672. torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1,
  673. torch.miopen_convolution_transpose: (lambda input, weight, bias, padding, output_padding, stride, dilation,
  674. groups, benchmark, deterministic: -1),
  675. torch.miopen_depthwise_convolution: (lambda input, weight, bias, padding, stride, dilation, groups, benchmark,
  676. deterministic: -1),
  677. torch.miopen_rnn: (lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first,
  678. dropout, train, bidirectional, batch_sizes, dropout_state: -1),
  679. torch.mm: lambda input, mat2, out=None: -1,
  680. torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
  681. torch.movedim: lambda input, source, destination: -1,
  682. torch.moveaxis: lambda input, source, destination: -1,
  683. torch.msort: lambda input, descending=False, out=None: -1,
  684. torch.mul: lambda input, other, out=None: -1,
  685. torch.multiply: lambda input, other, out=None: -1,
  686. torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1,
  687. torch.mv: lambda input, vec, out=None: -1,
  688. torch.mvlgamma: lambda input, p: -1,
  689. torch.narrow: lambda input, dim, start, length: -1,
  690. torch.narrow_copy: lambda input, dim, start, length: -1,
  691. torch.nan_to_num: lambda input, nan=0.0, posinf=None, neginf=None, out=None: -1,
  692. torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1,
  693. torch.native_dropout: lambda input, p, train: -1,
  694. torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
  695. torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
  696. torch.native_norm: lambda input, p=2: -1,
  697. torch.native_norm: lambda input, p=2: -1,
  698. torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
  699. torch.native_channel_shuffle: lambda input, groups : -1,
  700. torch.ne: lambda input, other, out=None: -1,
  701. torch.not_equal: lambda input, other, out=None: -1,
  702. torch.neg: lambda input, out=None: -1,
  703. torch.negative: lambda input, out=None: -1,
  704. torch.nextafter: lambda input, other, out=None: -1,
  705. torch.nn.functional.adaptive_avg_pool2d: lambda input, output_size: -1,
  706. torch.nn.functional.adaptive_avg_pool3d: lambda input, output_size: -1,
  707. torch.nn.functional.adaptive_max_pool1d: lambda input, output_size, return_indices=False: -1,
  708. torch.nn.functional.adaptive_max_pool1d_with_indices: lambda input, output_size, return_indices=False: -1,
  709. torch.nn.functional.adaptive_max_pool2d: lambda input, output_size, return_indices=False: -1,
  710. torch.nn.functional.adaptive_max_pool2d_with_indices: lambda input, output_size, return_indices=False: -1,
  711. torch.nn.functional.adaptive_max_pool3d: lambda input, output_size, return_indices=False: -1,
  712. torch.nn.functional.adaptive_max_pool3d_with_indices: lambda input, output_size, return_indices=False: -1,
  713. torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1,
  714. torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
  715. torch.nn.functional.avg_pool2d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
  716. count_include_pad=True, divisor_override=None: -1),
  717. torch.nn.functional.avg_pool3d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
  718. count_include_pad=True, divisor_override=None: -1),
  719. torch.nn.functional.batch_norm: (lambda input, running_mean, running_var, weight=None, bias=None, training=False,
  720. momentum=0.1, eps=1e-05: -1),
  721. torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1,
  722. torch.nn.functional.binary_cross_entropy: (lambda input, target, weight=None, size_average=None, reduce=None,
  723. reduction="mean": -1),
  724. torch.nn.functional.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None,
  725. reduce=None, reduction="mean", pos_weight=None: -1),
  726. torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1,
  727. torch.nn.functional.cosine_embedding_loss: (lambda input1, input2, target, margin=0, size_average=None,
  728. reduce=None, reduction='mean': -1),
  729. torch.nn.functional.cross_entropy: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
  730. reduce=None, reduction="mean", label_smoothing=0.0: -1),
  731. torch.nn.functional.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0,
  732. reduction='mean', zero_infinity=False: -1),
  733. torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1,
  734. torch.nn.functional.dropout1d: lambda input, p=0.5, training=True, inplace=False: -1,
  735. torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1,
  736. torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1,
  737. torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1,
  738. torch.nn.functional.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0,
  739. scale_grad_by_freq=False, sparse=False: -1),
  740. torch.nn.functional.embedding_bag: (lambda input, weight, offsets=None, max_norm=None, norm_type=2,
  741. scale_grad_by_freq=False, mode='mean', sparse=False, per_sample_weights=None,
  742. include_last_offset=False, padding_idx=None: -1),
  743. torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
  744. torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1,
  745. torch.nn.functional.fractional_max_pool2d: (lambda input, kernel_size, output_size=None, output_ratio=None,
  746. return_indices=False, _random_samples=None: -1),
  747. torch.nn.functional.fractional_max_pool2d_with_indices: (
  748. lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
  749. _random_samples=None: -1),
  750. torch.nn.functional.fractional_max_pool3d: (lambda input, kernel_size, output_size=None, output_ratio=None,
  751. return_indices=False, _random_samples=None: -1),
  752. torch.nn.functional.fractional_max_pool3d_with_indices: (
  753. lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
  754. _random_samples=None: -1),
  755. torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction='mean': -1,
  756. torch.nn.functional.gelu: lambda input, approximate='none': -1,
  757. torch.nn.functional.glu: lambda input, dim=-1: -1,
  758. torch.nn.functional.grid_sample: lambda input, grid, mode='bilinear', padding_mode='zeros', align_corners=None: -1,
  759. torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1,
  760. torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1,
  761. torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1,
  762. torch.nn.functional.hardtanh: lambda input, min_val=-1., max_val=1., inplace=False: -1,
  763. torch.nn.functional.hinge_embedding_loss: (lambda input, target, margin=1.0, size_average=None, reduce=None,
  764. reduction='mean': -1),
  765. torch.nn.functional.instance_norm: (lambda input, running_mean=None, running_var=None, weight=None, bias=None,
  766. use_input_stats=True, momentum=0.1, eps=1e-05: -1),
  767. torch.nn.functional.interpolate: (lambda input, size=None, scale_factor=None, mode='nearest', align_corners=None,
  768. recompute_scale_factor=None, antialias=False: -1),
  769. torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
  770. torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
  771. torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
  772. torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1,
  773. torch.nn.functional.linear: lambda input, weight, bias=None: -1,
  774. torch.nn.functional.local_response_norm: lambda input, size, alpha=0.0001, beta=0.75, k=1.0: -1,
  775. torch.nn.functional.log_softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
  776. torch.nn.functional.logsigmoid: lambda input: -1,
  777. torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
  778. torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
  779. torch.nn.functional.margin_ranking_loss: (lambda input1, input2, target, margin=0, size_average=None,
  780. reduce=None, reduction='mean': -1),
  781. torch.nn.functional.max_pool1d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
  782. ceil_mode=False, return_indices=False: -1),
  783. torch.nn.functional.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
  784. return_indices=False, ceil_mode=False: -1),
  785. torch.nn.functional.max_pool2d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
  786. ceil_mode=False, return_indices=False: -1),
  787. torch.nn.functional.max_pool2d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
  788. return_indices=False, ceil_mode=False: -1),
  789. torch.nn.functional.max_pool3d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
  790. return_indices=False, ceil_mode=False: -1),
  791. torch.nn.functional.max_pool3d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
  792. return_indices=False, ceil_mode=False: -1),
  793. torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
  794. torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
  795. torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
  796. torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
  797. torch.nn.functional.multi_head_attention_forward: (
  798. lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v,
  799. add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None,
  800. need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None,
  801. v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None: -1),
  802. torch.nn.functional.multi_margin_loss: (lambda input, target, p=1, margin=1.0, weight=None, size_average=None,
  803. reduce=None, reduction='mean': -1),
  804. torch.nn.functional.multilabel_margin_loss: (lambda input, target, size_average=None, reduce=None,
  805. reduction='mean': -1),
  806. torch.nn.functional.multilabel_soft_margin_loss: (lambda input, target, weight=None, size_average=None,
  807. reduce=None, reduction='mean': -1),
  808. torch.nn.functional.nll_loss: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
  809. reduce=None, reduction='mean': -1),
  810. torch.nn.functional.normalize: lambda input, p=2, dim=1, eps=1e-12, out=None: -1,
  811. torch.nn.functional.one_hot: lambda tensor, num_classes=-1: -1,
  812. torch.nn.functional.pad: lambda input, pad, mode='constant', value=0: -1,
  813. torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
  814. torch.nn.functional.poisson_nll_loss: (lambda input, target, log_input=True, full=False, size_average=None,
  815. eps=1e-08, reduce=None, reduction='mean': -1),
  816. torch.nn.functional.prelu: lambda input, weight: -1,
  817. torch.nn.functional.relu: lambda input, inplace=False: -1,
  818. torch.nn.functional.relu6: lambda input, inplace=False: -1,
  819. torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1,
  820. torch.nn.functional.selu: lambda input, inplace=False: -1,
  821. torch.nn.functional.silu: lambda input, inplace=False: -1,
  822. torch.nn.functional.mish: lambda input, inplace=False: -1,
  823. torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean', beta=1.: -1,
  824. torch.nn.functional.huber_loss: lambda input, target, reduction='mean', delta=1.: -1,
  825. torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
  826. torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
  827. torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
  828. torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1,
  829. torch.nn.functional.softshrink: lambda input, lambd=0.5: -1,
  830. torch.nn.functional.softsign: lambda input: -1,
  831. torch.nn.functional.tanhshrink: lambda input: -1,
  832. torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
  833. torch.nn.functional.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06,
  834. swap=False, size_average=None, reduce=None, reduction='mean': -1),
  835. torch.nn.functional.triplet_margin_with_distance_loss: (lambda anchor, positive, negative, *,
  836. distance_function=None, margin=1.0,
  837. swap=False, reduction='mean': -1),
  838. torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
  839. torch.nn.init.uniform_: lambda tensor, a=0., b=1.: -1,
  840. torch.nn.init.constant_: lambda tensor, val: -1,
  841. torch.nn.init.normal_: lambda tensor, mean=0., std=1.: -1,
  842. torch.nn.init.constant_: lambda tensor, val: -1,
  843. torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode='fan_in', nonlinearity='leaky_relu': -1,
  844. torch.nonzero: lambda input, as_tuple=False: -1,
  845. torch.argwhere: lambda input: -1,
  846. torch.norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
  847. torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1,
  848. torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1,
  849. torch.linalg.matrix_norm: lambda input, ord='fro', dim=(-2, -1), keepdim=False, out=None, dtype=None: -1,
  850. torch.norm_except_dim: lambda v, pow=2, dim=0: -1,
  851. torch.nuclear_norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
  852. torch.numel: lambda input: -1,
  853. torch.orgqr: lambda input, tau: -1,
  854. torch.ormqr: lambda input, input2, input3, left=True, transpose=False: -1,
  855. torch.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
  856. torch.permute: lambda self, dim: -1,
  857. torch.pca_lowrank: lambda input, q=None, center=True, niter=2: -1,
  858. torch.pdist: lambda input, p=2: -1,
  859. torch.pinverse: lambda input, rcond=1e-15: -1,
  860. torch.linalg.pinv: lambda input, rcond=1e-15, hermitian=False: -1,
  861. torch.pixel_shuffle: lambda input, upscale_factor: -1,
  862. torch.pixel_unshuffle: lambda input, downscale_factor: -1,
  863. torch.poisson: lambda input, generator=None: -1,
  864. torch.poisson_nll_loss: lambda input, target, log_input, full, eps, reduction: -1,
  865. torch.polygamma: lambda input, n, out=None: -1,
  866. torch.positive: lambda input, out=None: -1,
  867. torch.prelu: lambda input, weight: -1,
  868. torch.ones_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
  869. torch.pow: lambda input, exponent, out=None: -1,
  870. torch.prod: lambda input, dtype=None: -1,
  871. torch.put: lambda input, index, source, accumulate=False: -1,
  872. torch.q_per_channel_axis: lambda input: -1,
  873. torch.q_per_channel_scales: lambda input: -1,
  874. torch.q_per_channel_zero_points: lambda input: -1,
  875. torch.q_scale: lambda input: -1,
  876. torch.q_zero_point: lambda input: -1,
  877. torch.qr: lambda input, some=True, out=None: -1,
  878. torch.linalg.qr: lambda input, mode='reduced', out=None: -1,
  879. torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
  880. torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
  881. torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1,
  882. torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
  883. torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1,
  884. torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1,
  885. torch.quantized_gru_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
  886. col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
  887. torch.quantized_lstm_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
  888. col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
  889. torch.quantized_max_pool1d: (lambda input, kernel_size, stride=tuple(), padding=(0,),
  890. dilation=(1,), ceil_mode=False: -1),
  891. torch.quantized_max_pool2d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0),
  892. dilation=(1, 1), ceil_mode=False: -1),
  893. torch.quantized_rnn_relu_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
  894. col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
  895. torch.quantized_rnn_tanh_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
  896. col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
  897. torch.rad2deg: lambda input, out=None: -1,
  898. torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
  899. torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
  900. torch.randn_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
  901. torch.ravel: lambda input: -1,
  902. torch.real: lambda input, out=None: -1,
  903. torch.vdot: lambda input, other, out=None: -1,
  904. torch.view_as_real: lambda input: -1,
  905. torch.view_as_complex: lambda input: -1,
  906. torch.reciprocal: lambda input, out=None: -1,
  907. torch.relu: lambda input, inplace=False: -1,
  908. torch.remainder: lambda input, other, out=None: -1,
  909. torch.renorm: lambda input, p, dim, maxnorm, out=None: -1,
  910. torch.repeat_interleave: lambda input, dim=None: -1,
  911. torch.reshape: lambda input, shape: -1,
  912. torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
  913. torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
  914. torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
  915. torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
  916. torch.roll: lambda input, shifts, dims=None: -1,
  917. torch.rot90: lambda input, k=1, dims=(0, 1): -1,
  918. torch.round: lambda input, out=None: -1,
  919. torch.row_stack: lambda tensors, out=None: -1, # alias for torch.vstack
  920. torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
  921. torch.rrelu: lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1,
  922. torch.rsqrt: lambda input, out=None: -1,
  923. torch.rsub: lambda input, other, alpha=1: -1,
  924. torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
  925. torch.scatter: lambda input, dim, index, src: -1,
  926. torch.scatter_add: lambda input, dim, index, src: -1,
  927. torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1,
  928. torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
  929. torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, axis=0, unsafe=False: -1,
  930. torch.select: lambda input, dim, index: -1,
  931. torch.select_scatter: lambda input, src, dim, index: -1,
  932. torch.slice_scatter: lambda input, src, dim=0, start=None, end=None, step=1: -1,
  933. torch.selu: lambda input, inplace=False: -1,
  934. torch.sigmoid: lambda input, out=None: -1,
  935. torch.sign: lambda input, out=None: -1,
  936. torch.signbit: lambda input, out=None: -1,
  937. torch.sgn: lambda input, out=None: -1,
  938. torch.sin: lambda input, out=None: -1,
  939. torch.sinc: lambda input, out=None: -1,
  940. torch.sinh: lambda input, out=None: -1,
  941. torch.slogdet: lambda input: -1,
  942. torch.linalg.slogdet: lambda input: -1,
  943. torch.smm: lambda input, mat2: -1,
  944. torch.spmm: lambda input, mat2: -1,
  945. torch.softmax: lambda input, dim, dtype=None: -1,
  946. torch.linalg.solve: lambda input, other, out=None: -1,
  947. torch.sort: lambda input, dim=-1, descending=False, *, stable=False, out=None: -1,
  948. torch.split: lambda tensor, split_size_or_sections, dim=0: -1,
  949. torch.split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,
  950. torch.sqrt: lambda input, out=None: -1,
  951. torch.square: lambda input, out=None: -1,
  952. torch.squeeze: lambda input, dim=None, out=None: -1,
  953. torch.sspaddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
  954. torch.stack: lambda tensors, dim=0, out=None: -1,
  955. torch.std: lambda input, dim=None: -1,
  956. torch.std_mean: lambda input, dim=None: -1,
  957. torch.stft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
  958. pad_mode='reflect', normalized=False, onesided=True, return_complex=None: -1),
  959. torch.sub: lambda input, other, out=None: -1,
  960. torch.subtract: lambda input, other, out=None: -1,
  961. torch.sum: lambda input, dim=None: -1,
  962. torch.nansum: lambda input, dim=None: -1,
  963. torch.svd: lambda input, some=True, compute_uv=True, out=None: -1,
  964. torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1,
  965. torch.linalg.svd: lambda input, full_matrices=True, out=None: -1,
  966. torch.linalg.svdvals: lambda input, out=None: -1,
  967. torch.symeig: lambda input, eigenvectors=False, upper=True, out=None: -1,
  968. torch.swapaxes: lambda input, dim0, dim1: -1,
  969. torch.swapdims: lambda input, axis0, axis1: -1,
  970. torch.special.entr: lambda input: -1,
  971. torch.special.erf: lambda input: -1,
  972. torch.special.erfc: lambda input: -1,
  973. torch.special.erfcx: lambda input: -1,
  974. torch.special.erfinv: lambda input: -1,
  975. torch.special.exp2: lambda input: -1,
  976. torch.special.expm1: lambda input: -1,
  977. torch.special.expit: lambda input: -1,
  978. torch.special.polygamma: lambda input, n, out=None: -1,
  979. torch.special.digamma: lambda input: -1,
  980. torch.special.psi: lambda input: -1,
  981. torch.special.gammainc: lambda input, other, out=None: -1,
  982. torch.special.gammaincc: lambda input, other, out=None: -1,
  983. torch.special.gammaln: lambda input: -1,
  984. torch.special.i0: lambda input: -1,
  985. torch.special.i0e: lambda input: -1,
  986. torch.special.i1: lambda input: -1,
  987. torch.special.i1e: lambda input: -1,
  988. torch.special.logit: lambda input: -1,
  989. torch.special.logsumexp: lambda input, dim, keepdim=False, out=None: -1,
  990. torch.special.log1p: lambda input: -1,
  991. torch.special.log_softmax: lambda input, dim, dtype=None: -1,
  992. torch.special.round: lambda input: -1,
  993. torch.special.sinc: lambda input: -1,
  994. torch.special.softmax: lambda input, dim, dtype=None: -1,
  995. torch.special.multigammaln: lambda input, p: -1,
  996. torch.special.ndtri: lambda input: -1,
  997. torch.special.ndtr: lambda input: -1,
  998. torch.special.log_ndtr: lambda input: -1,
  999. torch.special.xlogy: lambda input, other, out=None: -1,
  1000. torch.special.xlog1py: lambda input, other, out=None: -1,
  1001. torch.special.zeta: lambda self, other, out=None: -1,
  1002. torch.t: lambda input: -1,
  1003. torch.take: lambda input, index: -1,
  1004. torch.take_along_dim: lambda input, indices, dim=None, out=None: -1,
  1005. torch.tan: lambda input, out=None: -1,
  1006. torch.tanh: lambda input, out=None: -1,
  1007. torch.linalg.tensorinv: lambda a, ind=2: -1,
  1008. torch.linalg.tensorsolve: lambda a, b, dims=None: -1,
  1009. torch.tensordot: lambda a, b, dims=2, out=None: -1,
  1010. torch.tensor_split: lambda input, indices_or_sections, dim=0: -1,
  1011. torch.threshold: lambda input, threshold, value, inplace=False: -1,
  1012. torch.tile: lambda input, dims: -1,
  1013. torch.topk: lambda input, k, dim=-1, descending=False, out=None: -1,
  1014. torch.trace: lambda input: -1,
  1015. torch.transpose: lambda input, dim0, dim1: -1,
  1016. torch.trapz: lambda y, x=None, dim=-1: -1,
  1017. torch.trapezoid: lambda y, x=None, dim=-1: -1,
  1018. torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
  1019. torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1,
  1020. torch.tril: lambda input, diagonal=0, out=None: -1,
  1021. torch.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False,
  1022. size_average=None, reduce=None, reduction='mean': -1),
  1023. torch.triu: lambda input, diagonal=0, out=None: -1,
  1024. torch.true_divide: lambda input, other: -1,
  1025. torch.trunc: lambda input, out=None: -1,
  1026. torch.unbind: lambda input, dim=0: -1,
  1027. torch.unique: lambda input, sorted=True, return_inverse=False, return_counts=False, dim=None: -1,
  1028. torch.unique_consecutive: lambda input, return_inverse=False, return_counts=False, dim=None: -1,
  1029. torch.unsafe_chunk: lambda input, chunks, dim=0: -1,
  1030. torch.unsafe_split: lambda tensor, split_size_or_sections, dim=0: -1,
  1031. torch.unsafe_split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,
  1032. torch.unsqueeze: lambda input, dim, out=None: -1,
  1033. torch.linalg.vander: lambda x, N=None: -1,
  1034. torch.var: lambda input, dim=None: -1,
  1035. torch.var_mean: lambda input, dim=None: -1,
  1036. torch.vsplit: lambda input, indices_or_sections: -1,
  1037. torch.vstack: lambda tensors, out=None: -1,
  1038. torch.where: lambda condition, x=None, y=None: -1,
  1039. torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
  1040. torch._fw_primal_copy: lambda self, level: -1,
  1041. torch._make_dual_copy: lambda primal, tangent, level: -1,
  1042. torch.view_as_real_copy: lambda self: -1,
  1043. torch.view_as_complex_copy: lambda self: -1,
  1044. torch._conj_copy: lambda self: -1,
  1045. torch._neg_view_copy: lambda self: -1,
  1046. torch.as_strided_copy: lambda self, size, stride, storage_offset=None: -1,
  1047. torch._sparse_broadcast_to_copy: lambda self, size: -1,
  1048. torch.diagonal_copy: lambda self, offset=0, dim1=0, dim2=1: -1,
  1049. torch.expand_copy: lambda self, size, *, implicit=False: -1,
  1050. torch.narrow_copy: lambda self, dim, start, length: -1,
  1051. torch.permute_copy: lambda self, dims: -1,
  1052. torch._reshape_alias_copy: lambda self, size, stride: -1,
  1053. torch.select_copy: lambda self, dim, index: -1,
  1054. torch.detach_copy: lambda self: -1,
  1055. torch.slice_copy: lambda self, dim=0, start=None, end=None, step=1: -1,
  1056. torch.split_copy: lambda self, split_size, dim=0: -1,
  1057. torch.split_with_sizes_copy: lambda self, split_sizes, dim=0: -1,
  1058. torch.squeeze_copy: lambda self: -1,
  1059. torch.squeeze_copy: lambda self, dim: -1,
  1060. torch.t_copy: lambda self: -1,
  1061. torch.transpose_copy: lambda self, dim0, dim1: -1,
  1062. torch.unsqueeze_copy: lambda self, dim: -1,
  1063. torch._indices_copy: lambda self: -1,
  1064. torch._values_copy: lambda self: -1,
  1065. torch.indices_copy: lambda self: -1,
  1066. torch.values_copy: lambda self: -1,
  1067. torch.crow_indices_copy: lambda self: -1,
  1068. torch.col_indices_copy: lambda self: -1,
  1069. torch.ccol_indices_copy: lambda self: -1,
  1070. torch.row_indices_copy: lambda self: -1,
  1071. torch.unbind_copy: lambda self, dim=0: -1,
  1072. torch.view_copy: lambda self, size: -1,
  1073. torch.view_copy: lambda self, dtype: -1,
  1074. torch.unfold_copy: lambda self, dimension, size, step: -1,
  1075. torch.alias_copy: lambda self: -1,
  1076. Tensor.__floordiv__: lambda self, other: -1,
  1077. Tensor.__rfloordiv__: lambda self, other: -1,
  1078. Tensor.__ifloordiv__: lambda self, other: -1,
  1079. Tensor.__truediv__: lambda self, other: -1,
  1080. Tensor.__rtruediv__: lambda self, other: -1,
  1081. Tensor.__itruediv__: lambda self, other: -1,
  1082. Tensor.__lshift__: lambda self, other: -1,
  1083. Tensor.__rlshift__: lambda self, other: -1,
  1084. Tensor.__ilshift__: lambda self, other: -1,
  1085. Tensor.__rshift__: lambda self, other: -1,
  1086. Tensor.__rrshift__: lambda self, other: -1,
  1087. Tensor.__irshift__: lambda self, other: -1,
  1088. Tensor.__and__: lambda self, other: -1,
  1089. Tensor.__or__: lambda self, other: -1,
  1090. Tensor.__xor__: lambda self, other: -1,
  1091. Tensor.__float__: lambda self: -1,
  1092. Tensor.__complex__: lambda self: -1,
  1093. Tensor.__array__: lambda self, dtype: -1,
  1094. Tensor.__bool__: lambda self: -1,
  1095. Tensor.__contains__: lambda self, other: -1,
  1096. Tensor.__neg__: lambda self: -1,
  1097. Tensor.__invert__: lambda self: -1,
  1098. Tensor.__mod__: lambda self, other: -1,
  1099. Tensor.__rmod__: lambda self, other: -1,
  1100. Tensor.__imod__: lambda self, other: -1,
  1101. Tensor.__array_wrap__: lambda self, array: -1,
  1102. Tensor.__getitem__: lambda self, idx: -1,
  1103. Tensor.__deepcopy__: lambda self, memo: -1,
  1104. Tensor.__int__: lambda self: -1,
  1105. Tensor.__long__: lambda self: -1,
  1106. Tensor.__hash__: lambda self: -1,
  1107. Tensor.__index__: lambda self: -1,
  1108. Tensor.__len__: lambda self: -1,
  1109. Tensor.__format__: lambda self, format_spec: -1,
  1110. Tensor.__reduce_ex__: lambda self, proto: -1,
  1111. Tensor.__reversed__: lambda self: -1,
  1112. Tensor.__repr__: lambda self, *, tensor_contents=None: -1,
  1113. Tensor.__setitem__: lambda self, k, v: -1,
  1114. Tensor.__setstate__: lambda self, d: -1,
  1115. Tensor.T.__get__: lambda self: -1,
  1116. Tensor.H.__get__: lambda self: -1,
  1117. Tensor.mT.__get__: lambda self: -1,
  1118. Tensor.mH.__get__: lambda self: -1,
  1119. Tensor._backward_hooks.__get__: lambda self: -1,
  1120. Tensor._base.__get__: lambda self: -1,
  1121. Tensor._cdata.__get__: lambda self: -1,
  1122. Tensor.grad.__get__: lambda self: -1,
  1123. Tensor._grad.__get__: lambda self: -1,
  1124. Tensor._grad_fn.__get__: lambda self: -1,
  1125. Tensor.grad_fn.__get__: lambda self: -1,
  1126. Tensor._version.__get__: lambda self: -1,
  1127. Tensor._autocast_to_reduced_precision: lambda self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype: -1,
  1128. Tensor._autocast_to_full_precision: lambda self, cuda_enabled, cpu_enabled: -1,
  1129. Tensor.data.__get__: lambda self: -1,
  1130. Tensor.device.__get__: lambda self: -1,
  1131. Tensor.dtype.__get__: lambda self: -1,
  1132. Tensor.is_cuda.__get__: lambda self: -1,
  1133. Tensor.is_xpu.__get__: lambda self: -1,
  1134. Tensor.is_ipu.__get__: lambda self: -1,
  1135. Tensor.is_leaf.__get__: lambda self: -1,
  1136. Tensor.retains_grad.__get__: lambda self: -1,
  1137. Tensor.is_meta.__get__: lambda self: -1,
  1138. Tensor.is_mps.__get__: lambda self: -1,
  1139. Tensor.is_nested.__get__: lambda self: -1,
  1140. Tensor.is_ort.__get__: lambda self: -1,
  1141. Tensor.is_mkldnn.__get__: lambda self: -1,
  1142. Tensor.is_quantized.__get__: lambda self: -1,
  1143. Tensor.is_sparse.__get__: lambda self: -1,
  1144. Tensor.is_sparse_csr.__get__: lambda self: -1,
  1145. Tensor.is_vulkan.__get__: lambda self: -1,
  1146. Tensor.layout.__get__: lambda self: -1,
  1147. Tensor.name.__get__: lambda self: -1,
  1148. Tensor.names.__get__: lambda self: -1,
  1149. Tensor.ndim.__get__: lambda self: -1,
  1150. Tensor.output_nr.__get__: lambda self: -1,
  1151. Tensor.requires_grad.__get__: lambda self: -1,
  1152. Tensor.shape.__get__: lambda self: -1,
  1153. Tensor.volatile.__get__: lambda self: -1,
  1154. Tensor.real.__get__: lambda self: -1,
  1155. Tensor.imag.__get__: lambda self: -1,
  1156. Tensor.__cuda_array_interface__.__get__: lambda self: -1,
  1157. Tensor.type: lambda self, dtype=None, non_blocking=False, **kwargs: -1,
  1158. Tensor._coalesced_: lambda self: -1,
  1159. Tensor._dimI: lambda self: -1,
  1160. Tensor._dimV: lambda self: -1,
  1161. Tensor._indices: lambda self: -1,
  1162. Tensor._is_view: lambda self: -1,
  1163. Tensor._nnz: lambda self: -1,
  1164. Tensor.crow_indices: lambda self: -1,
  1165. Tensor.col_indices: lambda self: -1,
  1166. Tensor.ccol_indices: lambda self: -1,
  1167. Tensor.row_indices: lambda self: -1,
  1168. Tensor._update_names: lambda self, names, inplace: -1,
  1169. Tensor._values: lambda self: -1,
  1170. Tensor.adjoint: lambda self: -1,
  1171. Tensor.align_as: lambda self, other: -1,
  1172. Tensor.align_to: lambda self, order, ellipsis_idx: -1,
  1173. Tensor.apply_: lambda self, callable: -1,
  1174. Tensor.as_strided: lambda self, size, stride: -1,
  1175. Tensor.as_strided_: lambda self, size, stride: -1,
  1176. Tensor.backward: lambda self, gradient=None, retain_graph=None, create_graph=False, inputs=None: -1,
  1177. Tensor.bfloat16: lambda self, memory_format=torch.preserve_format: -1,
  1178. Tensor.bool: lambda self, memory_format=torch.preserve_format: -1,
  1179. Tensor.byte: lambda self, memory_format=torch.preserve_format: -1,
  1180. Tensor.char: lambda self, memory_format=torch.preserve_format: -1,
  1181. Tensor.cauchy_: lambda self, median=0, sigma=1, *, generator=None: -1,
  1182. Tensor.coalesce: lambda self: -1,
  1183. Tensor._coalesced_: lambda self, coalesced: -1,
  1184. Tensor.contiguous: lambda self, memory_format=torch.contiguous_format: -1,
  1185. Tensor.copy_: lambda self, src, non_blocking=False: -1,
  1186. Tensor.cpu: lambda self, memory_format=torch.preserve_format: -1,
  1187. Tensor.cuda: lambda self, memory_format=torch.preserve_format: -1,
  1188. Tensor.xpu: lambda self, memory_format=torch.preserve_format: -1,
  1189. Tensor.ipu: lambda self, memory_format=torch.preserve_format: -1,
  1190. Tensor.data_ptr: lambda self: -1,
  1191. Tensor.dense_dim: lambda self: -1,
  1192. Tensor.diagonal_scatter: lambda self, src, offset=0, dim1=0, dim2=1: -1,
  1193. Tensor.dim: lambda self: -1,
  1194. Tensor.double: lambda self, memory_format=torch.preserve_format: -1,
  1195. Tensor.cdouble: lambda self, memory_format=torch.preserve_format: -1,
  1196. Tensor.element_size: lambda self: -1,
  1197. Tensor.expand: lambda self, size: -1,
  1198. Tensor.expand_as: lambda self, other: -1,
  1199. Tensor.exponential_: lambda self, lambd=1, *, generator=None: -1,
  1200. Tensor.fill_: lambda self, value: -1,
  1201. Tensor.fill_diagonal_: lambda self, value: -1,
  1202. Tensor.float: lambda self, memory_format=torch.preserve_format: -1,
  1203. Tensor.cfloat: lambda self, memory_format=torch.preserve_format: -1,
  1204. Tensor.geometric_: lambda self, p, *, generator=None: -1,
  1205. Tensor.get_device: lambda self: -1,
  1206. Tensor.half: lambda self, memory_format=torch.preserve_format: -1,
  1207. Tensor.chalf: lambda self, memory_format=torch.preserve_format: -1,
  1208. Tensor.has_names: lambda self: -1,
  1209. Tensor.indices: lambda self: -1,
  1210. Tensor.int: lambda self, memory_format=torch.preserve_format: -1,
  1211. Tensor.is_coalesced: lambda self: -1,
  1212. Tensor.is_contiguous: lambda self: -1,
  1213. Tensor.is_inference: lambda self: -1,
  1214. Tensor.is_pinned: lambda self: -1,
  1215. Tensor.is_set_to: lambda self, tensor: -1,
  1216. Tensor.is_shared: lambda self: -1,
  1217. Tensor.item: lambda self: -1,
  1218. Tensor.log_normal_: lambda self, mean=1, std=2, *, generator=None: -1,
  1219. Tensor.log_softmax: lambda self, dim: -1,
  1220. Tensor.long: lambda self, memory_format=torch.preserve_format: -1,
  1221. Tensor.map_: lambda self, tensor, callable: -1,
  1222. Tensor.map2_: lambda self, x, y, callable: -1,
  1223. Tensor.mm: lambda self, mat2: -1,
  1224. Tensor.narrow_copy: lambda self, dimension, start, length: -1,
  1225. Tensor.ndimension: lambda self: -1,
  1226. Tensor.nelement: lambda self: -1,
  1227. Tensor.normal_: lambda self: -1,
  1228. Tensor.numpy: lambda self: -1,
  1229. Tensor.permute: lambda self, dim: -1,
  1230. Tensor.pin_memory: lambda self: -1,
  1231. Tensor.put_: lambda self, indices, tensor, accumulate=False: -1,
  1232. Tensor.qscheme: lambda self: -1,
  1233. Tensor.random_: lambda self, from_=0, to=None, *, generator=None: -1,
  1234. Tensor.record_stream: lambda self, stream: -1,
  1235. Tensor.refine_names: lambda self, names: -1,
  1236. Tensor.register_hook: lambda self, hook: -1,
  1237. Tensor.rename: lambda self, name: -1,
  1238. Tensor.repeat: lambda self, *size: -1,
  1239. Tensor.requires_grad_: lambda self, requires_grad=True: -1,
  1240. Tensor.reshape_as: lambda self, other: -1,
  1241. Tensor.resize: lambda self, *size: -1,
  1242. Tensor.resize_: lambda self, size: -1,
  1243. Tensor.resize_as: lambda self, other: -1,
  1244. Tensor.resize_as_sparse_: lambda self, other: -1,
  1245. Tensor.retain_grad: lambda self: -1,
  1246. Tensor.set_: lambda self, source=None, storage_offset=0, size=None, stride=None: -1,
  1247. Tensor.select_scatter: lambda self, src, dim, index: -1,
  1248. Tensor.share_memory_: lambda self: -1,
  1249. Tensor.short: lambda self, memory_format=torch.preserve_format: -1,
  1250. Tensor.size: lambda self: -1,
  1251. Tensor.slice_scatter: lambda self, src, dim=0, start=None, end=None, step=1: -1,
  1252. Tensor.sparse_dim: lambda self: -1,
  1253. Tensor.sparse_mask: lambda self, mask: -1,
  1254. Tensor.sparse_resize_: lambda self, size1, size2, dense_dim: -1,
  1255. Tensor.sparse_resize_and_clear_: lambda self, size1, size2, dense_dim: -1,
  1256. Tensor.sspaddmm: lambda self, mat1, mat2, beta=1, alpha=1, out=None: -1,
  1257. Tensor.storage: lambda self: -1,
  1258. Tensor._storage: lambda self: -1,
  1259. Tensor.storage_offset: lambda self: -1,
  1260. Tensor.storage_type: lambda self: -1,
  1261. Tensor.sum_to_size: lambda self, size: -1,
  1262. Tensor.tile: lambda self, *reps: -1,
  1263. Tensor.to: lambda self, dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format: -1,
  1264. Tensor.to_dense: lambda self, dtype=None: -1,
  1265. Tensor._to_dense: lambda self, dtype=None: -1,
  1266. Tensor.to_sparse: lambda self: -1,
  1267. Tensor.tolist: lambda self: -1,
  1268. Tensor.to_mkldnn: lambda self: -1,
  1269. Tensor.type_as: lambda self, other: -1,
  1270. Tensor.unfold: lambda self, dimension, size, step: -1,
  1271. Tensor.uniform_: lambda self, from_=0, to=1: -1,
  1272. Tensor.values: lambda self: -1,
  1273. Tensor.view: lambda self, shape: -1,
  1274. Tensor.view_as: lambda self, other: -1,
  1275. Tensor.zero_: lambda self: -1,
  1276. Tensor.__dlpack__: lambda self, stream=None: -1,
  1277. Tensor.__dlpack_device__: lambda self: -1,
  1278. torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
  1279. }
  1280. ret2 = {}
  1281. ignored = get_ignored_functions()
  1282. for k, v in ret.items():
  1283. # Generate methods like __add__ and add_ by default from add
  1284. names = [
  1285. k.__name__, # Default method
  1286. k.__name__ + "_", # Inplace variant
  1287. "__" + k.__name__ + "__", # Dunder method
  1288. "__i" + k.__name__ + "__", # Inplace dunder method
  1289. "__r" + k.__name__ + "__", # Reverse dunder method
  1290. ]
  1291. if k.__name__.startswith("bitwise_"):
  1292. # bitwise_<op> have dunder methods of the form __<op>__
  1293. # And so on.
  1294. subname = k.__name__[len("bitwise_"):]
  1295. names.extend([
  1296. "__" + subname + "__",
  1297. "__i" + subname + "__",
  1298. "__r" + subname + "__"
  1299. ])
  1300. for name in names:
  1301. func = getattr(Tensor, name, None)
  1302. if callable(func) and func not in ret and func not in ignored:
  1303. ret2[func] = v
  1304. ret.update(ret2)
  1305. return ret
  1306. def wrap_torch_function(dispatcher: Callable):
  1307. """Wraps a given function with ``__torch_function__`` -related functionality.
  1308. Parameters
  1309. ----------
  1310. dispatcher: Callable
  1311. A callable that returns an iterable of Tensor-likes passed into the function.
  1312. Note
  1313. ----
  1314. This decorator may reduce the performance of your code. Generally, it's enough to express
  1315. your code as a series of functions that, themselves, support __torch_function__. If you
  1316. find yourself in the rare situation where this is not the case, e.g. if you're wrapping a
  1317. low-level library and you also need it to work for Tensor-likes, then this function is available.
  1318. Examples
  1319. --------
  1320. >>> def dispatcher(a): # Must have the same signature as func
  1321. ... return (a,)
  1322. >>> @torch.overrides.wrap_torch_function(dispatcher)
  1323. >>> def func(a): # This will make func dispatchable by __torch_function__
  1324. ... return a + 0
  1325. """
  1326. def inner(func):
  1327. @functools.wraps(func)
  1328. def wrapped(*args, **kwargs):
  1329. relevant_args = dispatcher(*args, **kwargs)
  1330. if has_torch_function(relevant_args):
  1331. return handle_torch_function(wrapped, relevant_args, *args, **kwargs)
  1332. return func(*args, **kwargs)
  1333. return wrapped
  1334. return inner
  1335. def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]:
  1336. """Returns a list of arguments on which to call __torch_function__.
  1337. Checks arguments in relevant_args for __torch_function__ implementations,
  1338. storing references to the arguments and their types in overloaded_args and
  1339. overloaded_types in order of calling precedence. Only distinct types are
  1340. considered. If a type is a subclass of another type it will have higher
  1341. precedence, otherwise the precedence order is the same as the order of
  1342. arguments in relevant_args, that is, from left-to-right in the argument list.
  1343. The precedence-determining algorithm implemented in this function is
  1344. described in `NEP-0018`_.
  1345. See torch::append_overloaded_arg for the equivalent function in the C++
  1346. implementation.
  1347. Parameters
  1348. ----------
  1349. relevant_args : iterable of array-like
  1350. Iterable of array-like arguments to check for __torch_function__
  1351. methods.
  1352. Returns
  1353. -------
  1354. overloaded_args : list
  1355. Arguments from relevant_args on which to call __torch_function__
  1356. methods, in the order in which they should be called.
  1357. .. _NEP-0018:
  1358. https://numpy.org/neps/nep-0018-array-function-protocol.html
  1359. """
  1360. # If torch function is not enabled, there are no overloaded types
  1361. if not torch._C._is_torch_function_enabled():
  1362. return []
  1363. # Runtime is O(num_arguments * num_unique_types)
  1364. overloaded_types: Set[Type] = set()
  1365. overloaded_args: List[Any] = []
  1366. for arg in relevant_args:
  1367. arg_type = type(arg)
  1368. # We only collect arguments if they have a unique type, which ensures
  1369. # reasonable performance even with a long list of possibly overloaded
  1370. # arguments.
  1371. #
  1372. # NB: Important to exclude _disabled_torch_function_impl, otherwise
  1373. # https://github.com/pytorch/pytorch/issues/64687
  1374. if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__') and
  1375. arg_type.__torch_function__ != torch._C._disabled_torch_function_impl):
  1376. # Create lists explicitly for the first type (usually the only one
  1377. # done) to avoid setting up the iterator for overloaded_args.
  1378. if overloaded_types:
  1379. overloaded_types.add(arg_type)
  1380. # By default, insert argument at the end, but if it is
  1381. # subclass of another argument, insert it before that argument.
  1382. # This ensures "subclasses before superclasses".
  1383. index = len(overloaded_args)
  1384. for i, old_arg in enumerate(overloaded_args):
  1385. if issubclass(arg_type, type(old_arg)):
  1386. index = i
  1387. break
  1388. overloaded_args.insert(index, arg)
  1389. else:
  1390. overloaded_types = {arg_type}
  1391. overloaded_args = [arg]
  1392. return overloaded_args
  1393. def handle_torch_function(
  1394. public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs) -> Any:
  1395. """Implement a function with checks for ``__torch_function__`` overrides.
  1396. See torch::autograd::handle_torch_function for the equivalent of this
  1397. function in the C++ implementation.
  1398. Arguments
  1399. ---------
  1400. public_api : function
  1401. Function exposed by the public torch API originally called like
  1402. ``public_api(*args, **kwargs)`` on which arguments are now being
  1403. checked.
  1404. relevant_args : iterable
  1405. Iterable of arguments to check for __torch_function__ methods.
  1406. args : tuple
  1407. Arbitrary positional arguments originally passed into ``public_api``.
  1408. kwargs : tuple
  1409. Arbitrary keyword arguments originally passed into ``public_api``.
  1410. Returns
  1411. -------
  1412. object
  1413. Result from calling ``implementation`` or an ``__torch_function__``
  1414. method, as appropriate.
  1415. Raises
  1416. ------
  1417. TypeError : if no implementation is found.
  1418. Example
  1419. -------
  1420. >>> def func(a):
  1421. ... if has_torch_function_unary(a):
  1422. ... return handle_torch_function(func, (a,), a)
  1423. ... return a + 0
  1424. """
  1425. # Check for __torch_function__ methods.
  1426. overloaded_args = _get_overloaded_args(relevant_args)
  1427. # overloaded_args already have unique types.
  1428. types = tuple(map(type, overloaded_args))
  1429. # Check for __torch_function__ mode.
  1430. mode = _get_torch_function_mode()
  1431. if mode is not None:
  1432. # NB: unlike on tensors, modes are instances
  1433. with _no_torch_function_mode():
  1434. result = mode.__torch_function__(public_api, types, args, kwargs)
  1435. if result is not NotImplemented:
  1436. return result
  1437. # Call overrides
  1438. for overloaded_arg in overloaded_args:
  1439. # This call needs to become a classmethod call in the future.
  1440. # See https://github.com/pytorch/pytorch/issues/63767
  1441. torch_func_method = overloaded_arg.__torch_function__
  1442. if hasattr(torch_func_method, "__self__") and torch_func_method.__self__ is overloaded_arg and \
  1443. torch_func_method is not torch._C._disabled_torch_function_impl:
  1444. warnings.warn("Defining your `__torch_function__ as a plain method is deprecated and "
  1445. "will be an error in future, please define it as a classmethod.",
  1446. DeprecationWarning)
  1447. # Use `public_api` instead of `implementation` so __torch_function__
  1448. # implementations can do equality/identity comparisons.
  1449. result = torch_func_method(public_api, types, args, kwargs)
  1450. if result is not NotImplemented:
  1451. return result
  1452. func_name = '{}.{}'.format(public_api.__module__, public_api.__name__)
  1453. msg = (
  1454. "no implementation found for '{}' on types that implement "
  1455. '__torch_function__: {}'
  1456. ).format(func_name, [type(arg) for arg in overloaded_args])
  1457. if mode is not None:
  1458. msg += f" nor in mode {mode}"
  1459. raise TypeError(msg)
  1460. has_torch_function = _add_docstr(
  1461. _has_torch_function,
  1462. r"""Check for __torch_function__ implementations in the elements of an iterable
  1463. or if a __torch_function__ mode is enabled. Considers exact ``Tensor`` s
  1464. and ``Parameter`` s non-dispatchable. Use this to guard a call to
  1465. :func:`handle_torch_function`; don't use it to test if something
  1466. is Tensor-like, use :func:`is_tensor_like` instead.
  1467. Arguments
  1468. ---------
  1469. relevant_args : iterable
  1470. Iterable or aguments to check for __torch_function__ methods.
  1471. Returns
  1472. -------
  1473. bool
  1474. True if any of the elements of relevant_args have __torch_function__
  1475. implementations, False otherwise.
  1476. See Also
  1477. ________
  1478. torch.is_tensor_like
  1479. Checks if something is a Tensor-like, including an exact ``Tensor``.
  1480. """
  1481. )
  1482. has_torch_function_unary = _add_docstr(
  1483. _has_torch_function_unary,
  1484. r"""Special case of `has_torch_function` for single inputs.
  1485. Instead of:
  1486. `has_torch_function((t,))`
  1487. call:
  1488. `has_torch_function_unary(t)`
  1489. which skips unnecessary packing and unpacking work.
  1490. """
  1491. )
  1492. has_torch_function_variadic = _add_docstr(
  1493. _has_torch_function_variadic,
  1494. r"""Special case of `has_torch_function` that skips tuple creation.
  1495. This uses the METH_FASTCALL protocol introduced in Python 3.7
  1496. Instead of:
  1497. `has_torch_function((a, b))`
  1498. call:
  1499. `has_torch_function_variadic(a, b)`
  1500. which skips unnecessary packing and unpacking work.
  1501. """
  1502. )
  1503. @functools.lru_cache(None)
  1504. def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]:
  1505. overridable_funcs = collections.defaultdict(list)
  1506. index = {}
  1507. tested_namespaces = [
  1508. ("torch", torch, torch.__all__ + dir(torch._C._VariableFunctions)),
  1509. ("torch.functional", torch.functional, torch.functional.__all__),
  1510. ("torch.nn.functional", torch.nn.functional, dir(torch.nn.functional)),
  1511. ("torch.nn.init", torch.nn.init, dir(torch.nn.init)),
  1512. ("torch.Tensor", torch.Tensor, dir(torch.Tensor)),
  1513. ("torch.linalg", torch.linalg, dir(torch.linalg)),
  1514. ("torch.fft", torch.fft, dir(torch.fft)),
  1515. ("torch.special", torch.special, dir(torch.special)),
  1516. ]
  1517. for namespace_str, namespace, ns_funcs in tested_namespaces:
  1518. for func_name in ns_funcs:
  1519. ignore = False
  1520. # ignore private functions or functions that are deleted in torch.__init__
  1521. if namespace is not torch.Tensor:
  1522. if func_name.startswith('__'):
  1523. continue
  1524. elif func_name.startswith('_'):
  1525. ignore = True
  1526. elif func_name.endswith('_'):
  1527. ignore = True
  1528. elif not func_name[0].islower():
  1529. ignore = True
  1530. elif func_name == 'unique_dim':
  1531. continue
  1532. else:
  1533. func = getattr(namespace, func_name)
  1534. if getattr(object, func_name, None) == func:
  1535. continue
  1536. if func_name == '__weakref__':
  1537. continue
  1538. func = getattr(namespace, func_name)
  1539. if namespace is torch.Tensor and getattr(object, func_name, None) == func:
  1540. continue
  1541. # ignore re-exported modules
  1542. if isinstance(func, types.ModuleType):
  1543. continue
  1544. # ignore __future__ imports
  1545. if isinstance(func, __future__._Feature):
  1546. continue
  1547. if not callable(func) and hasattr(func, "__get__"):
  1548. index[func.__get__] = f"{namespace_str}.{func_name}.__get__"
  1549. index[func.__set__] = f"{namespace_str}.{func_name}.__set__"
  1550. if ignore:
  1551. continue
  1552. if func.__get__ in get_ignored_functions():
  1553. msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
  1554. "but still has an explicit override")
  1555. assert func.__get__ not in get_testing_overrides(), msg.format(namespace, func.__name__)
  1556. continue
  1557. else:
  1558. overridable_funcs[func].append(func.__get__)
  1559. continue
  1560. if not callable(func):
  1561. continue
  1562. index[func] = f"{namespace_str}.{func_name}"
  1563. if ignore:
  1564. continue
  1565. # cannot be overriden by __torch_function__
  1566. if func in get_ignored_functions():
  1567. msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
  1568. "but still has an explicit override")
  1569. assert func not in get_testing_overrides(), msg.format(namespace, func.__name__)
  1570. continue
  1571. overridable_funcs[namespace].append(func)
  1572. return overridable_funcs, index
  1573. def get_overridable_functions() -> Dict[Any, List[Callable]]:
  1574. """List functions that are overridable via __torch_function__
  1575. Returns
  1576. -------
  1577. Dict[Any, List[Callable]]
  1578. A dictionary that maps namespaces that contain overridable functions
  1579. to functions in that namespace that can be overridden.
  1580. """
  1581. return _get_overridable_functions()[0]
  1582. def resolve_name(f):
  1583. """Get a human readable string name for a function passed to
  1584. __torch_function__
  1585. Arguments
  1586. ---------
  1587. callable : Callable
  1588. Function to resolve the name of.
  1589. Returns
  1590. -------
  1591. str
  1592. Name of the function; if eval'ed it should give back the input
  1593. function.
  1594. """
  1595. if isinstance(f, torch._ops.OpOverload):
  1596. return str(f)
  1597. return _get_overridable_functions()[1].get(f)
  1598. @functools.lru_cache(None)
  1599. def _get_tensor_methods() -> Set[Callable]:
  1600. """ Returns a set of the overridable methods on ``torch.Tensor`` """
  1601. overridable_funcs = get_overridable_functions()
  1602. methods = set(overridable_funcs[torch.Tensor])
  1603. return methods
  1604. def is_tensor_method_or_property(func: Callable) -> bool:
  1605. """
  1606. Returns True if the function passed in is a handler for a
  1607. method or property belonging to ``torch.Tensor``, as passed
  1608. into ``__torch_function__``.
  1609. .. note::
  1610. For properties, their ``__get__`` method must be passed in.
  1611. This may be needed, in particular, for the following reasons:
  1612. 1. Methods/properties sometimes don't contain a `__module__` slot.
  1613. 2. They require that the first passed-in argument is an instance
  1614. of ``torch.Tensor``.
  1615. Examples
  1616. --------
  1617. >>> is_tensor_method_or_property(torch.Tensor.add)
  1618. True
  1619. >>> is_tensor_method_or_property(torch.add)
  1620. False
  1621. """
  1622. return func in _get_tensor_methods() or func.__name__ == "__get__"
  1623. def is_tensor_like(inp):
  1624. """
  1625. Returns ``True`` if the passed-in input is a Tensor-like.
  1626. Currently, this occurs whenever there's a ``__torch_function__``
  1627. attribute on the type of the input.
  1628. Examples
  1629. --------
  1630. A subclass of tensor is generally a Tensor-like.
  1631. >>> class SubTensor(torch.Tensor): ...
  1632. >>> is_tensor_like(SubTensor([0]))
  1633. True
  1634. Built-in or user types aren't usually Tensor-like.
  1635. >>> is_tensor_like(6)
  1636. False
  1637. >>> is_tensor_like(None)
  1638. False
  1639. >>> class NotATensor: ...
  1640. >>> is_tensor_like(NotATensor())
  1641. False
  1642. But, they can be made Tensor-like by implementing __torch_function__.
  1643. >>> class TensorLike:
  1644. ... @classmethod
  1645. ... def __torch_function__(cls, func, types, args, kwargs):
  1646. ... return -1
  1647. >>> is_tensor_like(TensorLike())
  1648. True
  1649. """
  1650. return type(inp) is torch.Tensor or hasattr(type(inp), "__torch_function__")
  1651. def _wrap_torch_function(f):
  1652. @functools.wraps(f)
  1653. def wrapped(self, *args, **kwargs):
  1654. with enable_torch_function_mode(self.inner):
  1655. return f(self, *args, **kwargs)
  1656. return wrapped
  1657. # Implementation note: I had a choice about how much of mode stacks
  1658. # to implement in Python versus in C++. At time of writing, I did not care
  1659. # too much about implementation efficiency; however, I do care about making it
  1660. # hard for users to implement modes in the wrong way. In the end, it turned
  1661. # out to be possible to implement mode stacks entirely from userland, with the
  1662. # C++ API providing only _get_torch_function_mode() and
  1663. # _set_torch_function_mode(), so I opted to provide some unsafe C++ bindings and
  1664. # have the bulk of the logic for managing the stack in Python, which helped
  1665. # simplify the C++ API surface. It would also have been valid to build in the
  1666. # notion of mode stack directly into C++ but in this design it's substantially
  1667. # more difficult to interact with TorchFunctionModeMeta.
  1668. class _TorchFunctionMetaInitErrorInfo(MetaInitErrorInfo):
  1669. def __init__(self):
  1670. super().__init__(mode_class_name="TorchDispatchMode", mode_name="torch_dispatch")
  1671. class TorchFunctionModeMeta(type):
  1672. """
  1673. Metaclass for :class:`TorchFunctionMode`; it does two things:
  1674. * Adds an implicit ``inner`` kwarg to ``__init__``, to
  1675. allow the modes to be chained together to form a stack.
  1676. * Reenables the inner mode, so that by default PyTorch API calls
  1677. will compositionally proceed to the next mode on the stack.
  1678. The default behavior for the second bullet is important, as it is easy to
  1679. accidentally write ``__torch_function__`` implementations that are not
  1680. compositional, and the wrapping here makes the obvious code do the
  1681. right thing (aka, this is why there is a metaclass).
  1682. """
  1683. def __new__(metacls, name, bases, dct):
  1684. if '__init__' in dct:
  1685. dct['__init__'] = _wrap_init(dct['__init__'], _TorchFunctionMetaInitErrorInfo())
  1686. if '__torch_function__' in dct:
  1687. dct['__torch_function__'] = _wrap_torch_function(dct['__torch_function__'])
  1688. return super().__new__(metacls, name, bases, dct)
  1689. class TorchFunctionMode(metaclass=TorchFunctionModeMeta):
  1690. """
  1691. A ``TorchFunctionMode`` allows you to override the meaning of all
  1692. ``__torch_function__`` overrideable functions within a dynamic scope,
  1693. without having to actually create a tensor subclass or manually
  1694. monkey-patch functions in the PyTorch API. Some common situations
  1695. where you should use a mode:
  1696. * You want to override the meaning of factory functions, or other
  1697. functions that do not otherwise take a tensor as an argument
  1698. (these cannot be overridden with tensor subclasses).
  1699. * You want to override the behavior of all functions without needing
  1700. to wrap your inputs in tensor subclasses; e.g., if you are just
  1701. interested in logging intermediate computations.
  1702. * You want to control the order of execution of various tensor
  1703. subclasses explicitly, rather than implicitly via the return of
  1704. ``NotImplemented``.
  1705. Independent subclasses of :class:`TorchFunctionMode` are compositional:
  1706. modes can be pushed onto a stack with :func:`push_torch_function_mode`.
  1707. When you call functions in the PyTorch API inside your
  1708. ``__torch_function__`` implementation, by default, they will forward on to
  1709. the next mode on the mode stack. If you want recursively call back into
  1710. your current ``__torch_function__`` implementation, either explicitly
  1711. invoke ``self.__torch_function__(...)``, or use the context manager
  1712. ``enable_torch_function_mode(self, replace=self.inner)`` to make PyTorch
  1713. API self-referential (beware of infinite loops, in this case!)
  1714. """
  1715. inner: "TorchFunctionMode"
  1716. # Force metaclass to generate constructor at the base of the hierarchy
  1717. def __init__(self):
  1718. pass
  1719. def __torch_function__(self, func, types, args=(), kwargs=None):
  1720. raise NotImplementedError()
  1721. @classmethod
  1722. def push(cls, *args, **kwargs):
  1723. return push_torch_function_mode(functools.partial(cls, *args, **kwargs))
  1724. class BaseTorchFunctionMode(TorchFunctionMode):
  1725. def __torch_function__(self, func, types, args=(), kwargs=None):
  1726. if kwargs is None:
  1727. kwargs = {}
  1728. return func(*args, **kwargs)
  1729. # This is private API as I'm not sure it's possible for users to use this
  1730. # compositionally (easy to discard too many modes). It is useful for
  1731. # library code though, e.g., in handle_torch_function
  1732. @contextlib.contextmanager
  1733. def _no_torch_function_mode() -> Iterator[None]:
  1734. old = _get_torch_function_mode()
  1735. _set_torch_function_mode(None)
  1736. try:
  1737. yield
  1738. finally:
  1739. _set_torch_function_mode(old)
  1740. class _TorchFunctionModeInfo(_ModeInfo):
  1741. def __init__(self):
  1742. super().__init__(mode_name="torch_function", mode_class=TorchFunctionMode,
  1743. base_mode_class=BaseTorchFunctionMode)
  1744. def get_mode(self):
  1745. return _get_torch_function_mode()
  1746. def set_mode(self, mode):
  1747. return _set_torch_function_mode(mode)
  1748. @contextlib.contextmanager
  1749. def enable_torch_function_mode(mode, *, replace=None, ignore_preexisting=False) -> Iterator[None]:
  1750. """
  1751. Context manager that sets the current :class:`TorchFunctionMode`; see the
  1752. class for more information on what modes are. This function is
  1753. non-compositional; if there is already an existing mode, it will raise an
  1754. error; prefer using :func:`push_torch_function_mode` if your
  1755. ``__torch_function__`` implementation can defer to an inner mode.
  1756. This function is safe to use inside a ``__torch_function__`` mode handler,
  1757. as the mode is guaranteed to be disabled in this context. You can use
  1758. this context manager to reinstate the mode so that calls to overridable
  1759. APIs recursively call back into your mode handler (this can easily cause
  1760. infinite loops, so use with care!)
  1761. Args:
  1762. mode (:class:`TorchFunctionMode`, Tensor-like class or None): the
  1763. mode to set as current mode. If you pass a Tensor-like class,
  1764. it will be treated as a non-compositional mode with no state,
  1765. which is convenient if you have an existing tensor subclass
  1766. that you'd like to apply globally in a quick and dirty way.
  1767. Passing None will disable the current mode.
  1768. replace (:class:`TorchFunctionMode` or Tensor-like class): the
  1769. mode to replace. You can use this argument to change the mode in
  1770. a situation where you know what the current mode is (and you are
  1771. intentionally overwriting it.) If you don't know what the current
  1772. mode is, use ``ignore_preexisting`` instead.
  1773. ignore_preexisting (bool): if True, ignore any preexisting mode
  1774. and overwrite it with the passed mode.
  1775. """
  1776. return _enable_mode(mode, _TorchFunctionModeInfo(), replace=replace, ignore_preexisting=ignore_preexisting)
  1777. @contextlib.contextmanager
  1778. def push_torch_function_mode(ctor) -> Iterator[TorchFunctionMode]:
  1779. """
  1780. Context manager that pushes a :class:`TorchFunctionMode` onto the current
  1781. mode stack; see the class for more information on what modes are. Stacked
  1782. modes can delegate to each other by invoking the ``__torch_function__``
  1783. method for the ``inner`` mode.
  1784. Args:
  1785. ctor: a function that when invoked as ``ctor(inner=...)`` produces
  1786. a :class:`TorchFunctionMode`. If your :class:`TorchFunctionMode`
  1787. has no ``__init__`` implementation, you can simply pass the class
  1788. itself (e.g., ``push_torch_function_mode(MyMode)``); otherwise,
  1789. use ``functools.partial`` to partially apply the constructor with all
  1790. non-inner arguments (e.g.,
  1791. ``push_torch_function_mode(partial(MyMode, arg))``)
  1792. """
  1793. return _push_mode(ctor, _TorchFunctionModeInfo())
  1794. class enable_reentrant_dispatch():
  1795. def __enter__(self):
  1796. self._raii_guard = torch._C._RestorePythonTLSSnapshot()
  1797. def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
  1798. del self._raii_guard