function.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606
  1. import torch
  2. import torch._C as _C
  3. from torch._C import _functions
  4. import torch.utils.hooks as hooks
  5. from torch._six import with_metaclass
  6. import functools
  7. import warnings
  8. from collections import OrderedDict
  9. from typing import Any, List, Optional
  10. # Formerly known as: _ContextMethodMixin
  11. class FunctionCtx(object):
  12. def save_for_backward(self, *tensors: torch.Tensor):
  13. r"""Saves given tensors for a future call to :func:`~Function.backward`.
  14. ``save_for_backward`` should be called at most once, only from inside the
  15. :func:`forward` method, and only with tensors.
  16. All tensors intended to be used in the backward pass should be saved
  17. with ``save_for_backward`` (as opposed to directly on ``ctx``) to prevent
  18. incorrect gradients and memory leaks, and enable the application of saved
  19. tensor hooks. See :class:`torch.autograd.graph.saved_tensors_hooks`.
  20. In :func:`backward`, saved tensors can be accessed through the :attr:`saved_tensors`
  21. attribute. Before returning them to the user, a check is made to ensure
  22. they weren't used in any in-place operation that modified their content.
  23. Arguments can also be ``None``. This is a no-op.
  24. See :ref:`extending-autograd` for more details on how to use this method.
  25. Example::
  26. >>> class Func(Function):
  27. >>> @staticmethod
  28. >>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
  29. >>> w = x * y * z
  30. >>> out = x * y + y * z + w
  31. >>> ctx.save_for_backward(x, y, w, out)
  32. >>> ctx.z = z # z is not a tensor
  33. >>> return out
  34. >>>
  35. >>> @staticmethod
  36. >>> def backward(ctx, grad_out):
  37. >>> x, y, w, out = ctx.saved_tensors
  38. >>> z = ctx.z
  39. >>> gx = grad_out * (y + y * z)
  40. >>> gy = grad_out * (x + z + x * z)
  41. >>> gz = None
  42. >>> return gx, gy, gz
  43. >>>
  44. >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
  45. >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
  46. >>> c = 4
  47. >>> d = Func.apply(a, b, c)
  48. """
  49. self.to_save = tensors
  50. def save_for_forward(self, *tensors: torch.Tensor):
  51. r"""Saves given tensors for a future call to :func:`~Function.jvp`.
  52. ``save_for_forward`` should be only called once, from inside the :func:`forward`
  53. method, and only be called with tensors.
  54. In :func:`jvp`, saved objects can be accessed through the :attr:`saved_tensors`
  55. attribute.
  56. Arguments can also be ``None``. This is a no-op.
  57. See :ref:`extending-autograd` for more details on how to use this method.
  58. Example::
  59. >>> class Func(torch.autograd.Function):
  60. >>> @staticmethod
  61. >>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
  62. >>> ctx.save_for_backward(x, y)
  63. >>> ctx.save_for_forward(x, y)
  64. >>> ctx.z = z
  65. >>> return x * y * z
  66. >>>
  67. >>> @staticmethod
  68. >>> def jvp(ctx, x_t, y_t, _):
  69. >>> x, y = ctx.saved_tensors
  70. >>> z = ctx.z
  71. >>> return z * (y * x_t + x * y_t)
  72. >>>
  73. >>> @staticmethod
  74. >>> def vjp(ctx, grad_out):
  75. >>> x, y = ctx.saved_tensors
  76. >>> z = ctx.z
  77. >>> return z * grad_out * y, z * grad_out * x, None
  78. >>>
  79. >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
  80. >>> t = torch.tensor(1., dtype=torch.double)
  81. >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
  82. >>> c = 4
  83. >>>
  84. >>> with fwAD.dual_level():
  85. >>> a_dual = fwAD.make_dual(a, t)
  86. >>> d = Func.apply(a_dual, b, c)
  87. """
  88. for tensor in tensors:
  89. assert isinstance(tensor, torch.Tensor) or tensor is None, (
  90. "save_for_forward expects all arguments to be tensors; you should "
  91. "save non-tensors as attributes on ctx.")
  92. self.saved_for_forward = tensors
  93. def mark_dirty(self, *args: torch.Tensor):
  94. r"""Marks given tensors as modified in an in-place operation.
  95. **This should be called at most once, only from inside the**
  96. :func:`forward` **method, and all arguments should be inputs.**
  97. Every tensor that's been modified in-place in a call to :func:`forward`
  98. should be given to this function, to ensure correctness of our checks.
  99. It doesn't matter whether the function is called before or after
  100. modification.
  101. Examples::
  102. >>> class Inplace(Function):
  103. >>> @staticmethod
  104. >>> def forward(ctx, x):
  105. >>> x_npy = x.numpy() # x_npy shares storage with x
  106. >>> x_npy += 1
  107. >>> ctx.mark_dirty(x)
  108. >>> return x
  109. >>>
  110. >>> @staticmethod
  111. >>> @once_differentiable
  112. >>> def backward(ctx, grad_output):
  113. >>> return grad_output
  114. >>>
  115. >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
  116. >>> b = a * a
  117. >>> Inplace.apply(a) # This would lead to wrong gradients!
  118. >>> # but the engine would not know unless we mark_dirty
  119. >>> b.backward() # RuntimeError: one of the variables needed for gradient
  120. >>> # computation has been modified by an inplace operation
  121. """
  122. self.dirty_tensors = args
  123. def mark_shared_storage(self, *pairs):
  124. warnings.warn(
  125. 'mark_shared_storage is deprecated. '
  126. 'Tensors with shared storages are automatically tracked. Note '
  127. 'that calls to `set_()` are not tracked')
  128. def mark_non_differentiable(self, *args: torch.Tensor):
  129. r"""Marks outputs as non-differentiable.
  130. **This should be called at most once, only from inside the**
  131. :func:`forward` **method, and all arguments should be tensor outputs.**
  132. This will mark outputs as not requiring gradients, increasing the
  133. efficiency of backward computation. You still need to accept a gradient
  134. for each output in :meth:`~Function.backward`, but it's always going to
  135. be a zero tensor with the same shape as the shape of a corresponding
  136. output.
  137. This is used e.g. for indices returned from a sort. See example::
  138. >>> class Func(Function):
  139. >>> @staticmethod
  140. >>> def forward(ctx, x):
  141. >>> sorted, idx = x.sort()
  142. >>> ctx.mark_non_differentiable(idx)
  143. >>> ctx.save_for_backward(x, idx)
  144. >>> return sorted, idx
  145. >>>
  146. >>> @staticmethod
  147. >>> @once_differentiable
  148. >>> def backward(ctx, g1, g2): # still need to accept g2
  149. >>> x, idx = ctx.saved_tensors
  150. >>> grad_input = torch.zeros_like(x)
  151. >>> grad_input.index_add_(0, idx, g1)
  152. >>> return grad_input
  153. """
  154. self.non_differentiable = args
  155. def set_materialize_grads(self, value: bool):
  156. r"""Sets whether to materialize output grad tensors. Default is ``True``.
  157. **This should be called only from inside the** :func:`forward` **method**
  158. If ``True``, undefined output grad tensors will be expanded to tensors full
  159. of zeros prior to calling the :func:`backward` method.
  160. Example::
  161. >>> class SimpleFunc(Function):
  162. >>> @staticmethod
  163. >>> def forward(ctx, x):
  164. >>> return x.clone(), x.clone()
  165. >>>
  166. >>> @staticmethod
  167. >>> @once_differentiable
  168. >>> def backward(ctx, g1, g2):
  169. >>> return g1 + g2 # No check for None necessary
  170. >>>
  171. >>> # We modify SimpleFunc to handle non-materialized grad outputs
  172. >>> class Func(Function):
  173. >>> @staticmethod
  174. >>> def forward(ctx, x):
  175. >>> ctx.set_materialize_grads(False)
  176. >>> ctx.save_for_backward(x)
  177. >>> return x.clone(), x.clone()
  178. >>>
  179. >>> @staticmethod
  180. >>> @once_differentiable
  181. >>> def backward(ctx, g1, g2):
  182. >>> x, = ctx.saved_tensors
  183. >>> grad_input = torch.zeros_like(x)
  184. >>> if g1 is not None: # We must check for None now
  185. >>> grad_input += g1
  186. >>> if g2 is not None:
  187. >>> grad_input += g2
  188. >>> return grad_input
  189. >>>
  190. >>> a = torch.tensor(1., requires_grad=True)
  191. >>> b, _ = Func.apply(a) # induces g2 to be undefined
  192. """
  193. self.materialize_grads = value
  194. # DO NOT USE: This is only defined to be able to load old serialized models
  195. _ContextMethodMixin = FunctionCtx
  196. class _HookMixin(object):
  197. @staticmethod
  198. def _register_hook(backward_hooks, hook):
  199. if backward_hooks is None:
  200. backward_hooks = OrderedDict()
  201. handle = hooks.RemovableHandle(backward_hooks)
  202. backward_hooks[handle.id] = hook
  203. return backward_hooks, handle
  204. class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
  205. def apply(self, *args):
  206. # _forward_cls is defined by derived class
  207. # The user should define either backward or vjp but never both.
  208. backward_fn = self._forward_cls.backward # type: ignore[attr-defined]
  209. vjp_fn = self._forward_cls.vjp # type: ignore[attr-defined]
  210. if backward_fn is not Function.backward and vjp_fn is not Function.vjp:
  211. raise RuntimeError("Implementing both 'backward' and 'vjp' for a custom "
  212. "Function is not allowed. You should only implement one "
  213. "of them.")
  214. user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
  215. return user_fn(self, *args)
  216. def apply_jvp(self, *args):
  217. # _forward_cls is defined by derived class
  218. return self._forward_cls.jvp(self, *args) # type: ignore[attr-defined]
  219. class FunctionMeta(type):
  220. """Function metaclass.
  221. This metaclass sets up the following properties:
  222. _backward_cls: The Function class corresponding to the differentiated
  223. version of this function (which is generated on the fly by this
  224. metaclass).
  225. """
  226. def __init__(cls, name, bases, attrs):
  227. backward_fn = type(name + 'Backward', (BackwardCFunction,), {'_forward_cls': cls})
  228. cls._backward_cls = backward_fn
  229. super(FunctionMeta, cls).__init__(name, bases, attrs)
  230. # mypy doesn't understand `with_metaclass` from torch._six
  231. class Function(with_metaclass(FunctionMeta, _C._FunctionBase, FunctionCtx, _HookMixin)): # type: ignore[misc]
  232. r"""Base class to create custom `autograd.Function`
  233. To create a custom `autograd.Function`, subclass this class and implement
  234. the :meth:`forward` and :meth:`backward` static methods. Then, to use your custom
  235. op in the forward pass, call the class method ``apply``. Do not call
  236. :meth:`forward` directly.
  237. To ensure correctness and best performance, make sure you are calling the
  238. correct methods on ``ctx`` and validating your backward function using
  239. :func:`torch.autograd.gradcheck`.
  240. See :ref:`extending-autograd` for more details on how to use this class.
  241. Examples::
  242. >>> class Exp(Function):
  243. >>> @staticmethod
  244. >>> def forward(ctx, i):
  245. >>> result = i.exp()
  246. >>> ctx.save_for_backward(result)
  247. >>> return result
  248. >>>
  249. >>> @staticmethod
  250. >>> def backward(ctx, grad_output):
  251. >>> result, = ctx.saved_tensors
  252. >>> return grad_output * result
  253. >>>
  254. >>> # Use it by calling the apply method:
  255. >>> output = Exp.apply(input)
  256. """
  257. def __init__(self, *args, **kwargs):
  258. cls = self.__class__
  259. warnings.warn(f"{cls} should not be instantiated. Methods on autograd functions"
  260. "are all static, so you should invoke them on the class itself. "
  261. "Instantiating an autograd function will raise an "
  262. "error in a future version of PyTorch.", DeprecationWarning)
  263. def __call__(self, *args, **kwargs):
  264. raise RuntimeError(
  265. "Legacy autograd function with non-static forward method is deprecated. "
  266. "Please use new-style autograd function with static forward method. "
  267. "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)")
  268. # for the tracer
  269. is_traceable = False
  270. @staticmethod
  271. def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
  272. r"""Performs the operation.
  273. This function is to be overridden by all subclasses.
  274. It must accept a context ctx as the first argument, followed by any
  275. number of arguments (tensors or other types).
  276. The context can be used to store arbitrary data that can be then
  277. retrieved during the backward pass. Tensors should not be stored
  278. directly on `ctx` (though this is not currently enforced for
  279. backward compatibility). Instead, tensors should be saved either with
  280. :func:`ctx.save_for_backward` if they are intended to be used in
  281. ``backward`` (equivalently, ``vjp``) or :func:`ctx.save_for_forward`
  282. if they are intended to be used for in ``jvp``.
  283. """
  284. raise NotImplementedError("You must implement the forward function for custom"
  285. " autograd.Function.")
  286. @staticmethod
  287. def backward(ctx: Any, *grad_outputs: Any) -> Any:
  288. r"""Defines a formula for differentiating the operation with backward mode
  289. automatic differentiation (alias to the vjp function).
  290. This function is to be overridden by all subclasses.
  291. It must accept a context :attr:`ctx` as the first argument, followed by
  292. as many outputs as the :func:`forward` returned (None will be passed in
  293. for non tensor outputs of the forward function),
  294. and it should return as many tensors, as there were inputs to
  295. :func:`forward`. Each argument is the gradient w.r.t the given output,
  296. and each returned value should be the gradient w.r.t. the
  297. corresponding input. If an input is not a Tensor or is a Tensor not
  298. requiring grads, you can just pass None as a gradient for that input.
  299. The context can be used to retrieve tensors saved during the forward
  300. pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple
  301. of booleans representing whether each input needs gradient. E.g.,
  302. :func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the
  303. first input to :func:`forward` needs gradient computated w.r.t. the
  304. output.
  305. """
  306. raise NotImplementedError("You must implement either the backward or vjp method for "
  307. "your custom autograd.Function to use it with backward "
  308. "mode AD.")
  309. # vjp and backward are alias of each other
  310. vjp = backward
  311. @staticmethod
  312. def jvp(ctx: Any, *grad_inputs: Any) -> Any:
  313. r"""Defines a formula for differentiating the operation with forward mode
  314. automatic differentiation.
  315. This function is to be overridden by all subclasses.
  316. It must accept a context :attr:`ctx` as the first argument, followed by
  317. as many inputs as the :func:`forward` got (None will be passed in
  318. for non tensor inputs of the forward function),
  319. and it should return as many tensors as there were outputs to
  320. :func:`forward`. Each argument is the gradient w.r.t the given input,
  321. and each returned value should be the gradient w.r.t. the
  322. corresponding output. If an output is not a Tensor or the function is not
  323. differentiable with respect to that output, you can just pass None as a
  324. gradient for that input.
  325. You can use the :attr:`ctx` object to pass any value from the forward to this
  326. functions.
  327. """
  328. raise NotImplementedError("You must implement the jvp function for custom "
  329. "autograd.Function to use it with forward mode AD.")
  330. def once_differentiable(fn):
  331. @functools.wraps(fn)
  332. def wrapper(ctx, *args):
  333. with torch.no_grad():
  334. outputs = fn(ctx, *args)
  335. if not torch.is_grad_enabled():
  336. return outputs
  337. # If any of the inputs have requires_grad=True, we force the outputs
  338. # to have requires_grad=True but point to a grad_fn which throws an
  339. # error message during (double) back-propagation.
  340. # XXX: this is only an approximation of requires_grad - there's no way
  341. # to figure out if fn didn't use ctx.saved_tensors and as a result
  342. # some Tensors might require grad, even if no args do.
  343. # Unfortunately, this leads to unexpected error messages ("no nodes
  344. # require computing gradients"), but I don't have a better idea.
  345. # These functions would raise an error in backward anyway.
  346. requires_grad = any(isinstance(arg, torch.Tensor) and arg.requires_grad
  347. for arg in args)
  348. if not requires_grad:
  349. return outputs
  350. if not isinstance(outputs, tuple):
  351. outputs = (outputs,)
  352. err_fn = _functions.DelayedError(
  353. b"trying to differentiate twice a function that was marked "
  354. b"with @once_differentiable", len(outputs))
  355. # Create aliases of each output that has requires_grad=True. We need
  356. # at least one of the inputs to err_fn to require grad so that the
  357. # output will have a grad_fn.
  358. def fake_requires_grad(var):
  359. if var is not None:
  360. var = var.detach()
  361. var.requires_grad = True
  362. return var
  363. return err_fn(*[fake_requires_grad(v) for v in outputs])
  364. return wrapper
  365. def traceable(fn_cls):
  366. r"""Marks Function as traceable for the JIT.
  367. Traceable functions have additional restrictions - they can't pass any
  368. data-dependent values to backward (e.g. Prod passes the output, which makes
  369. it non-traceable), and their backward should be implemented entirely in terms
  370. of operations on autograd Tensors in all cases.
  371. DON'T USE THIS DECORATOR. IT IS FOR INTERNAL USE ONLY AND SHOULD BE HANDLED WITH
  372. CARE (or can give incorrect results otherwise).
  373. """
  374. fn_cls.is_traceable = True
  375. return fn_cls
  376. class InplaceFunction(Function):
  377. def __init__(self, inplace=False):
  378. super(InplaceFunction, self).__init__()
  379. self.inplace = inplace
  380. def _nested_map(condition, fn, condition_msg=None):
  381. def _map(obj):
  382. if condition(obj):
  383. return fn(obj)
  384. elif obj is None:
  385. return None
  386. elif isinstance(obj, (list, tuple)):
  387. mapped = (_map(x) for x in obj)
  388. if hasattr(obj, '_fields'):
  389. # obj is namedtuple
  390. return type(obj)(*mapped)
  391. return type(obj)(mapped)
  392. elif isinstance(obj, dict):
  393. return {x : _map(obj[x]) for x in obj}
  394. else:
  395. raise ValueError("Auto nesting doesn't know how to process "
  396. "an input object of type " + torch.typename(obj) +
  397. (". Accepted types: " + condition_msg +
  398. ", or lists/tuples of them"
  399. if condition_msg else ""))
  400. return _map
  401. def _jit_unwrap_structured(obj):
  402. if hasattr(obj, "_jit_unwrap"):
  403. return obj._jit_unwrap()
  404. return obj
  405. def _iter_filter(condition, allow_unknown=False, condition_msg=None,
  406. conversion=None):
  407. def _iter(obj):
  408. if conversion is not None:
  409. obj = conversion(obj)
  410. if condition(obj):
  411. yield obj
  412. elif obj is None:
  413. return
  414. elif isinstance(obj, (list, tuple)):
  415. for o in obj:
  416. for var in _iter(o):
  417. yield var
  418. elif isinstance(obj, dict):
  419. # We only accept primitive key types, so we needn't inspect them
  420. for o in obj.values():
  421. for var in _iter(o):
  422. yield var
  423. elif allow_unknown:
  424. yield obj
  425. else:
  426. raise ValueError("Auto nesting doesn't know how to process "
  427. "an input object of type " + torch.typename(obj) +
  428. (". Accepted types: " + condition_msg +
  429. ", or lists/tuples of them"
  430. if condition_msg else ""))
  431. return _iter
  432. def _unflatten(input, proto):
  433. # unflatten a list or tuple input into a nested list/tuple structure
  434. # specified by proto
  435. def unflatten_helper(input, proto):
  436. res: List[Optional[torch.Tensor]] = []
  437. if hasattr(proto, "_jit_wrap"):
  438. return proto._jit_wrap(input)
  439. if not isinstance(proto, (list, tuple)):
  440. return input[0], input[1:]
  441. for e in proto:
  442. if e is None:
  443. res.append(e)
  444. else:
  445. res_e, input = unflatten_helper(input, e)
  446. res.append(res_e)
  447. return type(proto)(res), input
  448. return unflatten_helper(input, proto)[0]
  449. _iter_jit_values = _iter_filter(lambda o: o is None or isinstance(o, torch._C.Value),
  450. condition_msg="jit's Values or None")
  451. _iter_tensors = _iter_filter(lambda x: isinstance(x, torch.Tensor), condition_msg="Tensors",
  452. conversion=_jit_unwrap_structured)
  453. _iter_tensors_permissive = _iter_filter(lambda x: isinstance(x, torch.Tensor),
  454. allow_unknown=True,
  455. condition_msg="Tensors (permissive)")
  456. _iter_None_tensors = _iter_filter(lambda o: o is None or isinstance(o, torch.Tensor),
  457. condition_msg="Tensors or None")
  458. _map_tensor_data = _nested_map(lambda x: isinstance(x, torch.Tensor), lambda o: o.data,
  459. condition_msg="Tensors")
  460. class NestedIOFunction(Function):
  461. # The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
  462. # superclass (Function) but are instance methods here, which mypy reports as incompatible.
  463. def _do_forward(self, *input):
  464. self._nested_input = input
  465. flat_input = tuple(_iter_tensors(input))
  466. flat_output = super(NestedIOFunction, self)._do_forward(*flat_input)
  467. nested_output = self._nested_output
  468. nested_tensors = _unflatten(flat_output, self._nested_output)
  469. return nested_tensors
  470. def _do_backward(self, gradients, retain_variables):
  471. self.retain_variables = retain_variables
  472. result = super(NestedIOFunction, self)._do_backward(gradients, retain_variables)
  473. if not retain_variables:
  474. del self._nested_output
  475. del self._to_save_nested
  476. return result
  477. def backward(self, *gradients: Any) -> Any: # type: ignore[override]
  478. nested_gradients = _unflatten(gradients, self._nested_output)
  479. result = self.backward_extended(*nested_gradients) # type: ignore[func-returns-value]
  480. return tuple(_iter_None_tensors(result))
  481. __call__ = _do_forward
  482. def forward(self, *args: Any) -> Any: # type: ignore[override]
  483. nested_tensors = _map_tensor_data(self._nested_input)
  484. result = self.forward_extended(*nested_tensors) # type: ignore[func-returns-value]
  485. del self._nested_input
  486. self._nested_output = result
  487. return tuple(_iter_tensors(result))
  488. def save_for_backward(self, *args: Any) -> None:
  489. self.to_save = tuple(_iter_tensors(args))
  490. self._to_save_nested = args
  491. @property
  492. def saved_tensors(self):
  493. flat_tensors = super(NestedIOFunction, self).saved_tensors
  494. return _unflatten(flat_tensors, self._to_save_nested)
  495. def mark_dirty(self, *args: Any, **kwargs: Any) -> None:
  496. self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))
  497. def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None:
  498. self.non_differentiable = tuple(_iter_tensors((args, kwargs)))
  499. def forward_extended(self, *input: Any) -> None:
  500. raise NotImplementedError
  501. def backward_extended(self, *grad_output: Any) -> None:
  502. raise NotImplementedError