| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975 |
- from collections import OrderedDict, namedtuple
- import itertools
- import warnings
- import functools
- import torch
- from ..parameter import Parameter
- import torch.utils.hooks as hooks
- from torch import Tensor, device, dtype
- from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List
- from ...utils.hooks import RemovableHandle
- _grad_t = Union[Tuple[Tensor, ...], Tensor]
- # See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use
- # of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be
- # the type of the subclass, not the looser type of `Module`.
- T = TypeVar('T', bound='Module')
- class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])):
- def __repr__(self):
- if not self.missing_keys and not self.unexpected_keys:
- return '<All keys matched successfully>'
- return super(_IncompatibleKeys, self).__repr__()
- __str__ = __repr__
- def _addindent(s_, numSpaces):
- s = s_.split('\n')
- # don't do anything for single-line stuff
- if len(s) == 1:
- return s_
- first = s.pop(0)
- s = [(numSpaces * ' ') + line for line in s]
- s = '\n'.join(s)
- s = first + '\n' + s
- return s
- r"""This tracks hooks common to all modules that are executed before/after
- calling forward and backward. This is global state used for debugging/profiling
- purposes"""
- _global_backward_hooks: Dict[int, Callable] = OrderedDict()
- _global_is_full_backward_hook: Optional[bool] = None
- _global_forward_pre_hooks: Dict[int, Callable] = OrderedDict()
- _global_forward_hooks: Dict[int, Callable] = OrderedDict()
- _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle:
    r"""Registers a forward pre-hook common to all modules.

    .. warning ::

        This adds global state to the `nn.module` module
        and it is only intended for debugging/profiling purposes.

    The hook will be called every time before :func:`forward` is invoked.
    It should have the following signature::

        hook(module, input) -> None or modified input

    The input contains only the positional arguments given to the module.
    Keyword arguments won't be passed to the hooks and only to the ``forward``.
    The hook can modify the input. The user can return either a tuple or a
    single modified value from the hook; a single value is wrapped into a
    tuple (unless that value is already a tuple).

    This hook has precedence over the specific module hooks registered with
    ``register_forward_pre_hook``.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    removable = hooks.RemovableHandle(_global_forward_pre_hooks)
    # The hook is stored in the global registry under the handle's id.
    _global_forward_pre_hooks[removable.id] = hook
    return removable
def register_module_forward_hook(hook: Callable[..., None]) -> RemovableHandle:
    r"""Registers a global forward hook for all the modules.

    .. warning ::

        This adds global state to the `nn.module` module
        and it is only intended for debugging/profiling purposes.

    The hook will be called every time after :func:`forward` has computed an output.
    It should have the following signature::

        hook(module, input, output) -> None or modified output

    The input contains only the positional arguments given to the module.
    Keyword arguments won't be passed to the hooks and only to the ``forward``.
    The hook can modify the output. It can modify the input inplace but
    it will not have effect on forward since this is called after
    :func:`forward` is called.

    This hook will be executed before specific module hooks registered with
    ``register_forward_hook``.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    removable = hooks.RemovableHandle(_global_forward_hooks)
    # The hook is stored in the global registry under the handle's id.
    _global_forward_hooks[removable.id] = hook
    return removable
def register_module_backward_hook(
    hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]]
) -> RemovableHandle:
    r"""Registers a backward hook common to all the modules.

    This function is deprecated in favor of
    :func:`torch.nn.modules.module.register_module_full_backward_hook`
    and the behavior of this function will change in future versions.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``

    Raises:
        RuntimeError: if a full backward hook has already been registered
            globally; the two kinds cannot be mixed.
    """
    global _global_is_full_backward_hook
    if _global_is_full_backward_hook is True:
        message = ("Cannot use both regular backward hooks and full backward hooks as a "
                   "global Module hook. Please use only one of them.")
        raise RuntimeError(message)

    # Record that the global registry now holds regular (non-full) hooks.
    _global_is_full_backward_hook = False

    removable = hooks.RemovableHandle(_global_backward_hooks)
    _global_backward_hooks[removable.id] = hook
    return removable
def register_module_full_backward_hook(
    hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]]
) -> RemovableHandle:
    r"""Registers a backward hook common to all the modules.

    .. warning ::
        This adds global state to the `nn.module` module
        and it is only intended for debugging/profiling purposes.

    The hook will be called every time the gradients with respect to module
    inputs are computed. The hook should have the following signature::

        hook(module, grad_input, grad_output) -> Tensor or None

    The :attr:`grad_input` and :attr:`grad_output` are tuples. The hook should
    not modify its arguments, but it can optionally return a new gradient with
    respect to the input that will be used in place of :attr:`grad_input` in
    subsequent computations. :attr:`grad_input` will only correspond to the inputs given
    as positional arguments and all kwarg arguments will not appear in the hook. Entries
    in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor
    arguments.

    For technical reasons, when this hook is applied to a Module, its forward function will
    receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
    of each Tensor returned by the Module's forward function.

    Global hooks are called before hooks registered with `register_backward_hook`.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``

    Raises:
        RuntimeError: if a regular (non-full) backward hook has already been
            registered globally; the two kinds cannot be mixed.
    """
    global _global_is_full_backward_hook
    if _global_is_full_backward_hook is False:
        message = ("Cannot use both regular backward hooks and full backward hooks as a "
                   "global Module hook. Please use only one of them.")
        raise RuntimeError(message)

    # Record that the global registry now holds full backward hooks.
    _global_is_full_backward_hook = True

    removable = hooks.RemovableHandle(_global_backward_hooks)
    _global_backward_hooks[removable.id] = hook
    return removable
- # Trick mypy into not applying contravariance rules to inputs by defining
- # forward as a value, rather than a function. See also
- # https://github.com/python/mypy/issues/8795
- def _forward_unimplemented(self, *input: Any) -> None:
- r"""Defines the computation performed at every call.
- Should be overridden by all subclasses.
- .. note::
- Although the recipe for forward pass needs to be defined within
- this function, one should call the :class:`Module` instance afterwards
- instead of this since the former takes care of running the
- registered hooks while the latter silently ignores them.
- """
- raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
- class Module:
- r"""Base class for all neural network modules.
- Your models should also subclass this class.
- Modules can also contain other Modules, allowing them to be nested in
- a tree structure. You can assign the submodules as regular attributes::
- import torch.nn as nn
- import torch.nn.functional as F
- class Model(nn.Module):
- def __init__(self):
- super().__init__()
- self.conv1 = nn.Conv2d(1, 20, 5)
- self.conv2 = nn.Conv2d(20, 20, 5)
- def forward(self, x):
- x = F.relu(self.conv1(x))
- return F.relu(self.conv2(x))
- Submodules assigned in this way will be registered, and will have their
- parameters converted too when you call :meth:`to`, etc.
- .. note::
- As per the example above, an ``__init__()`` call to the parent class
- must be made before assignment on the child.
- :ivar training: Boolean representing whether this module is in training or
- evaluation mode.
- :vartype training: bool
- """
- dump_patches: bool = False
- _version: int = 1
- r"""This allows better BC support for :meth:`load_state_dict`. In
- :meth:`state_dict`, the version number will be saved as in the attribute
- `_metadata` of the returned state dict, and thus pickled. `_metadata` is a
- dictionary with keys that follow the naming convention of state dict. See
- ``_load_from_state_dict`` on how to use this information in loading.
- If new parameters/buffers are added/removed from a module, this number shall
- be bumped, and the module's `_load_from_state_dict` method can compare the
- version number and do appropriate changes if the state dict is from before
- the change."""
- training: bool
- _is_full_backward_hook: Optional[bool]
- def __init__(self) -> None:
- """
- Initializes internal Module state, shared by both nn.Module and ScriptModule.
- """
- torch._C._log_api_usage_once("python.nn_module")
- self.training = True
- self._parameters: Dict[str, Optional[Parameter]] = OrderedDict()
- self._buffers: Dict[str, Optional[Tensor]] = OrderedDict()
- self._non_persistent_buffers_set: Set[str] = set()
- self._backward_hooks: Dict[int, Callable] = OrderedDict()
- self._is_full_backward_hook = None
- self._forward_hooks: Dict[int, Callable] = OrderedDict()
- self._forward_pre_hooks: Dict[int, Callable] = OrderedDict()
- self._state_dict_hooks: Dict[int, Callable] = OrderedDict()
- self._load_state_dict_pre_hooks: Dict[int, Callable] = OrderedDict()
- self._load_state_dict_post_hooks: Dict[int, Callable] = OrderedDict()
- self._modules: Dict[str, Optional['Module']] = OrderedDict()
- forward: Callable[..., Any] = _forward_unimplemented
def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
    r"""Adds a buffer to the module.

    This is typically used to register a buffer that should not be
    considered a model parameter. For example, BatchNorm's ``running_mean``
    is not a parameter, but is part of the module's state. Buffers, by
    default, are persistent and will be saved alongside parameters. This
    behavior can be changed by setting :attr:`persistent` to ``False``. The
    only difference between a persistent buffer and a non-persistent buffer
    is that the latter will not be a part of this module's
    :attr:`state_dict`.

    Buffers can be accessed as attributes using given names.

    Args:
        name (string): name of the buffer. The buffer can be accessed
            from this module using the given name
        tensor (Tensor or None): buffer to be registered. If ``None``, then operations
            that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,
            the buffer is **not** included in the module's :attr:`state_dict`.
        persistent (bool): whether the buffer is part of this module's
            :attr:`state_dict`.

    Raises:
        RuntimeError: if ``persistent`` is ``False`` on a ScriptModule.
        AttributeError: if called before ``Module.__init__()``.
        TypeError: if ``name`` is not a ``str``, or ``tensor`` is neither a
            Tensor nor ``None``.
        KeyError: if ``name`` contains ``"."``, is empty, or collides with
            an existing non-buffer attribute.

    Example::

        >>> self.register_buffer('running_mean', torch.zeros(num_features))

    """
    if persistent is False and isinstance(self, torch.jit.ScriptModule):
        raise RuntimeError("ScriptModule does not support non-persistent buffers")

    if '_buffers' not in self.__dict__:
        raise AttributeError(
            "cannot assign buffer before Module.__init__() call")
    # Validate against `str` directly rather than the private
    # `torch._six.string_classes` shim, so non-str names get the intended
    # message here instead of failing in a later check.
    if not isinstance(name, str):
        raise TypeError(f"buffer name should be a string. Got {torch.typename(name)}")
    if '.' in name:
        raise KeyError("buffer name can't contain \".\"")
    if name == '':
        raise KeyError("buffer name can't be empty string \"\"")
    # Re-registering an existing buffer is allowed; shadowing any other
    # attribute is not.
    if hasattr(self, name) and name not in self._buffers:
        raise KeyError(f"attribute '{name}' already exists")
    if tensor is not None and not isinstance(tensor, torch.Tensor):
        raise TypeError(f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' "
                        "(torch Tensor or None required)")

    self._buffers[name] = tensor
    if persistent:
        self._non_persistent_buffers_set.discard(name)
    else:
        self._non_persistent_buffers_set.add(name)
def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
    r"""Adds a parameter to the module.

    The parameter can be accessed as an attribute using given name.

    Args:
        name (string): name of the parameter. The parameter can be accessed
            from this module using the given name
        param (Parameter or None): parameter to be added to the module. If
            ``None``, then operations that run on parameters, such as :attr:`cuda`,
            are ignored. If ``None``, the parameter is **not** included in the
            module's :attr:`state_dict`.

    Raises:
        AttributeError: if called before ``Module.__init__()``.
        TypeError: if ``name`` is not a ``str``, or ``param`` is neither a
            ``Parameter`` nor ``None``.
        KeyError: if ``name`` contains ``"."``, is empty, or collides with
            an existing non-parameter attribute.
        ValueError: if ``param`` has a ``grad_fn`` (i.e. is a non-leaf tensor).
    """
    if '_parameters' not in self.__dict__:
        raise AttributeError(
            "cannot assign parameter before Module.__init__() call")
    # Validate against `str` directly rather than the private
    # `torch._six.string_classes` shim, so non-str names get the intended
    # message here instead of failing in a later check.
    if not isinstance(name, str):
        raise TypeError(f"parameter name should be a string. Got {torch.typename(name)}")
    if '.' in name:
        raise KeyError("parameter name can't contain \".\"")
    if name == '':
        raise KeyError("parameter name can't be empty string \"\"")
    # Re-registering an existing parameter is allowed; shadowing any other
    # attribute is not.
    if hasattr(self, name) and name not in self._parameters:
        raise KeyError(f"attribute '{name}' already exists")

    if param is None:
        self._parameters[name] = None
        return
    if not isinstance(param, Parameter):
        raise TypeError(f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
                        "(torch.nn.Parameter or None required)")
    if param.grad_fn:
        raise ValueError(
            "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
            "parameters must be created explicitly. To express '{0}' "
            "as a function of another Tensor, compute the value in "
            "the forward() method.".format(name))
    self._parameters[name] = param
def add_module(self, name: str, module: Optional['Module']) -> None:
    r"""Adds a child module to the current module.

    The module can be accessed as an attribute using the given name.

    Args:
        name (string): name of the child module. The child module can be
            accessed from this module using the given name
        module (Module or None): child module to be added to the module.

    Raises:
        TypeError: if ``module`` is neither a ``Module`` nor ``None``, or
            ``name`` is not a ``str``.
        KeyError: if ``name`` collides with an existing non-module
            attribute, contains ``"."``, or is empty.
    """
    if module is not None and not isinstance(module, Module):
        raise TypeError(f"{torch.typename(module)} is not a Module subclass")
    # Validate against `str` directly rather than the private
    # `torch._six.string_classes` shim, so non-str names get the intended
    # message here instead of failing in a later check.
    if not isinstance(name, str):
        raise TypeError(f"module name should be a string. Got {torch.typename(name)}")
    # Re-assigning an existing child is allowed; shadowing any other
    # attribute is not.
    if hasattr(self, name) and name not in self._modules:
        raise KeyError(f"attribute '{name}' already exists")
    if '.' in name:
        raise KeyError(f"module name can't contain \".\", got: {name}")
    if name == '':
        raise KeyError("module name can't be empty string \"\"")
    self._modules[name] = module
def register_module(self, name: str, module: Optional['Module']) -> None:
    r"""Alias for :func:`add_module`.

    Args:
        name (string): name of the child module.
        module (Module or None): child module to register.
    """
    # Pure delegation: all validation and storage happen in add_module.
    self.add_module(name, module)
def get_submodule(self, target: str) -> "Module":
    """
    Returns the submodule given by ``target`` if it exists,
    otherwise throws an error.

    For example, let's say you have an ``nn.Module`` ``A`` that
    looks like this:

    .. code-block:: text

        A(
            (net_b): Module(
                (net_c): Module(
                    (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
                )
                (linear): Linear(in_features=100, out_features=200, bias=True)
            )
        )

    (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested
    submodule ``net_b``, which itself has two submodules ``net_c``
    and ``linear``. ``net_c`` then has a submodule ``conv``.)

    To check whether or not we have the ``linear`` submodule, we
    would call ``get_submodule("net_b.linear")``. To check whether
    we have the ``conv`` submodule, we would call
    ``get_submodule("net_b.net_c.conv")``.

    The runtime of ``get_submodule`` is bounded by the degree
    of module nesting in ``target``. A query against
    ``named_modules`` achieves the same result, but it is O(N) in
    the number of transitive modules. So, for a simple check to see
    if some submodule exists, ``get_submodule`` should always be
    used.

    Args:
        target: The fully-qualified string name of the submodule
            to look for. (See above example for how to specify a
            fully-qualified string.) The empty string refers to this
            module itself.

    Returns:
        torch.nn.Module: The submodule referenced by ``target``

    Raises:
        AttributeError: If the target string references an invalid
            path or resolves to something that is not an
            ``nn.Module``
    """
    if target == "":
        return self

    current: torch.nn.Module = self
    # Walk one dotted component at a time, checking each hop.
    for name in target.split("."):
        if not hasattr(current, name):
            raise AttributeError(f"{current._get_name()} has no attribute `{name}`")
        current = getattr(current, name)
        if not isinstance(current, torch.nn.Module):
            raise AttributeError(f"`{name}` is not an nn.Module")
    return current
def get_parameter(self, target: str) -> "Parameter":
    """
    Returns the parameter given by ``target`` if it exists,
    otherwise throws an error.

    See the docstring for ``get_submodule`` for a more detailed
    explanation of this method's functionality as well as how to
    correctly specify ``target``.

    Args:
        target: The fully-qualified string name of the Parameter
            to look for. (See ``get_submodule`` for how to specify a
            fully-qualified string.)

    Returns:
        torch.nn.Parameter: The Parameter referenced by ``target``

    Raises:
        AttributeError: If the target string references an invalid
            path or resolves to something that is not an
            ``nn.Parameter``
    """
    # Everything before the last dot names the owning module; the final
    # component names the parameter on it.
    owner_path, _, leaf = target.rpartition(".")
    owner: torch.nn.Module = self.get_submodule(owner_path)

    if not hasattr(owner, leaf):
        raise AttributeError(f"{owner._get_name()} has no attribute `{leaf}`")

    found: torch.nn.Parameter = getattr(owner, leaf)
    if not isinstance(found, torch.nn.Parameter):
        raise AttributeError(f"`{leaf}` is not an nn.Parameter")
    return found
def get_buffer(self, target: str) -> "Tensor":
    """
    Returns the buffer given by ``target`` if it exists,
    otherwise throws an error.

    See the docstring for ``get_submodule`` for a more detailed
    explanation of this method's functionality as well as how to
    correctly specify ``target``.

    Args:
        target: The fully-qualified string name of the buffer
            to look for. (See ``get_submodule`` for how to specify a
            fully-qualified string.)

    Returns:
        torch.Tensor: The buffer referenced by ``target``

    Raises:
        AttributeError: If the target string references an invalid
            path or resolves to something that is not a
            buffer
    """
    # Everything before the last dot names the owning module; the final
    # component names the buffer on it.
    owner_path, _, leaf = target.rpartition(".")
    owner: torch.nn.Module = self.get_submodule(owner_path)

    if not hasattr(owner, leaf):
        raise AttributeError(f"{owner._get_name()} has no attribute `{leaf}`")

    found: torch.Tensor = getattr(owner, leaf)
    # Being an attribute is not enough: the name must be registered in the
    # owner's buffer table.
    if leaf not in owner._buffers:
        raise AttributeError(f"`{leaf}` is not a buffer")
    return found
def get_extra_state(self) -> Any:
    """
    Returns any extra state to include in the module's state_dict.
    Implement this and a corresponding :func:`set_extra_state` for your module
    if you need to store extra state. This function is called when building the
    module's `state_dict()`.

    Note that extra state should be pickleable to ensure working serialization
    of the state_dict. We only provide backwards compatibility guarantees
    for serializing Tensors; other objects may break backwards compatibility if
    their serialized pickled form changes.

    Returns:
        object: Any extra state to store in the module's state_dict

    Raises:
        RuntimeError: always, in this base implementation.
    """
    message = (
        "Reached a code path in Module.get_extra_state() that should never be called. "
        "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
        "to report this bug."
    )
    raise RuntimeError(message)
def set_extra_state(self, state: Any):
    """
    This function is called from :func:`load_state_dict` to handle any extra state
    found within the `state_dict`. Implement this function and a corresponding
    :func:`get_extra_state` for your module if you need to store extra state within its
    `state_dict`.

    Args:
        state (dict): Extra state from the `state_dict`

    Raises:
        RuntimeError: always, in this base implementation.
    """
    message = (
        "Reached a code path in Module.set_extra_state() that should never be called. "
        "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
        "to report this bug."
    )
    raise RuntimeError(message)
- def _apply(self, fn):
- for module in self.children():
- module._apply(fn)
- def compute_should_use_set_data(tensor, tensor_applied):
- if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
- # If the new tensor has compatible tensor type as the existing tensor,
- # the current behavior is to change the tensor in-place using `.data =`,
- # and the future behavior is to overwrite the existing tensor. However,
- # changing the current behavior is a BC-breaking change, and we want it
- # to happen in future releases. So for now we introduce the
- # `torch.__future__.get_overwrite_module_params_on_conversion()`
- # global flag to let the user control whether they want the future
- # behavior of overwriting the existing tensor or not.
- return not torch.__future__.get_overwrite_module_params_on_conversion()
- else:
- return False
- for key, param in self._parameters.items():
- if param is None:
- continue
- # Tensors stored in modules are graph leaves, and we don't want to
- # track autograd history of `param_applied`, so we have to use
- # `with torch.no_grad():`
- with torch.no_grad():
- param_applied = fn(param)
- should_use_set_data = compute_should_use_set_data(param, param_applied)
- if should_use_set_data:
- param.data = param_applied
- out_param = param
- else:
- assert isinstance(param, Parameter)
- assert param.is_leaf
- out_param = Parameter(param_applied, param.requires_grad)
- self._parameters[key] = out_param
- if param.grad is not None:
- with torch.no_grad():
- grad_applied = fn(param.grad)
- should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
- if should_use_set_data:
- out_param.grad.data = grad_applied
- else:
- assert param.grad.is_leaf
- out_param.grad = grad_applied.requires_grad_(param.grad.requires_grad)
- for key, buf in self._buffers.items():
- if buf is not None:
- self._buffers[key] = fn(buf)
- return self
def apply(self: T, fn: Callable[['Module'], None]) -> T:
    r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``)
    as well as self. Typical use includes initializing the parameters of a model
    (see also :ref:`nn-init-doc`).

    Children are visited before the module itself.

    Args:
        fn (:class:`Module` -> None): function to be applied to each submodule

    Returns:
        Module: self

    Example::

        >>> @torch.no_grad()
        >>> def init_weights(m):
        >>>     print(m)
        >>>     if type(m) == nn.Linear:
        >>>         m.weight.fill_(1.0)
        >>>         print(m.weight)
        >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
        >>> net.apply(init_weights)
        Linear(in_features=2, out_features=2, bias=True)
        Parameter containing:
        tensor([[ 1.,  1.],
                [ 1.,  1.]])
        Linear(in_features=2, out_features=2, bias=True)
        Parameter containing:
        tensor([[ 1.,  1.],
                [ 1.,  1.]])
        Sequential(
          (0): Linear(in_features=2, out_features=2, bias=True)
          (1): Linear(in_features=2, out_features=2, bias=True)
        )
        Sequential(
          (0): Linear(in_features=2, out_features=2, bias=True)
          (1): Linear(in_features=2, out_features=2, bias=True)
        )
    """
    for child in self.children():
        child.apply(fn)
    # The module itself is handled last, after all of its children.
    fn(self)
    return self
def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
    r"""Moves all model parameters and buffers to the GPU.

    This also makes associated parameters and buffers different objects. So
    it should be called before constructing optimizer if the module will
    live on GPU while being optimized.

    .. note::
        This method modifies the module in-place.

    Args:
        device (int, optional): if specified, all parameters will be
            copied to that device

    Returns:
        Module: self
    """
    def to_gpu(tensor):
        return tensor.cuda(device)

    return self._apply(to_gpu)
def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:
    r"""Moves all model parameters and buffers to the IPU.

    This also makes associated parameters and buffers different objects. So
    it should be called before constructing optimizer if the module will
    live on IPU while being optimized.

    .. note::
        This method modifies the module in-place.

    Args:
        device (int, optional): if specified, all parameters will be
            copied to that device

    Returns:
        Module: self
    """
    def to_ipu(tensor):
        return tensor.ipu(device)

    return self._apply(to_ipu)
def xpu(self: T, device: Optional[Union[int, device]] = None) -> T:
    r"""Moves all model parameters and buffers to the XPU.

    This also makes associated parameters and buffers different objects. So
    it should be called before constructing optimizer if the module will
    live on XPU while being optimized.

    .. note::
        This method modifies the module in-place.

    Args:
        device (int, optional): if specified, all parameters will be
            copied to that device

    Returns:
        Module: self
    """
    def to_xpu(tensor):
        return tensor.xpu(device)

    return self._apply(to_xpu)
def cpu(self: T) -> T:
    r"""Transfers every parameter and buffer of the module to the CPU.

    .. note::
        The module is modified in-place.

    Returns:
        Module: self
    """
    def move_to_cpu(tensor):
        return tensor.cpu()

    return self._apply(move_to_cpu)
def type(self: T, dst_type: Union[dtype, str]) -> T:
    r"""Converts every parameter and buffer to :attr:`dst_type`.

    .. note::
        The module is modified in-place.

    Args:
        dst_type (type or string): the type to cast to

    Returns:
        Module: self
    """
    def cast(tensor):
        return tensor.type(dst_type)

    return self._apply(cast)
def float(self: T) -> T:
    r"""Converts every floating point parameter and buffer to ``float`` datatype.

    Tensors that are not floating point are left untouched.

    .. note::
        The module is modified in-place.

    Returns:
        Module: self
    """
    def cast(tensor):
        if tensor.is_floating_point():
            return tensor.float()
        return tensor

    return self._apply(cast)
def double(self: T) -> T:
    r"""Converts every floating point parameter and buffer to ``double`` datatype.

    Tensors that are not floating point are left untouched.

    .. note::
        The module is modified in-place.

    Returns:
        Module: self
    """
    def cast(tensor):
        if tensor.is_floating_point():
            return tensor.double()
        return tensor

    return self._apply(cast)
def half(self: T) -> T:
    r"""Converts every floating point parameter and buffer to ``half`` datatype.

    Tensors that are not floating point are left untouched.

    .. note::
        The module is modified in-place.

    Returns:
        Module: self
    """
    def cast(tensor):
        if tensor.is_floating_point():
            return tensor.half()
        return tensor

    return self._apply(cast)
def bfloat16(self: T) -> T:
    r"""Converts every floating point parameter and buffer to ``bfloat16`` datatype.

    Tensors that are not floating point are left untouched.

    .. note::
        The module is modified in-place.

    Returns:
        Module: self
    """
    def cast(tensor):
        if tensor.is_floating_point():
            return tensor.bfloat16()
        return tensor

    return self._apply(cast)
def to_empty(self: T, *, device: Union[str, device]) -> T:
    r"""Re-creates the parameters and buffers on the given device without copying storage.

    Each tensor is replaced by the result of ``torch.empty_like`` on the
    target device, so the existing values are not carried over.

    Args:
        device (:class:`torch.device`): The desired device of the parameters
            and buffers in this module.

    Returns:
        Module: self
    """
    def allocate_like(tensor):
        return torch.empty_like(tensor, device=device)

    return self._apply(allocate_like)
@overload
def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ...,
       non_blocking: bool = ...) -> T:
    ...


@overload
def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
    ...


@overload
def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
    ...


def to(self, *args, **kwargs):
    r"""Moves and/or casts the parameters and buffers.

    The accepted call forms are

    .. function:: to(device=None, dtype=None, non_blocking=False)
       :noindex:

    .. function:: to(dtype, non_blocking=False)
       :noindex:

    .. function:: to(tensor, non_blocking=False)
       :noindex:

    .. function:: to(memory_format=torch.channels_last)
       :noindex:

    The signature mirrors :meth:`torch.Tensor.to`, except that only
    floating point or complex :attr:`dtype`\ s are accepted. When a
    :attr:`dtype` is given, only the floating point or complex parameters
    and buffers are cast to it; integral parameters and buffers keep
    their dtype but are still moved to :attr:`device` when one is given.
    When :attr:`non_blocking` is set, the conversion/move is attempted
    asynchronously with respect to the host if possible, e.g., moving CPU
    Tensors with pinned memory to CUDA devices.

    See below for examples.

    .. note::
        The module is modified in-place.

    Args:
        device (:class:`torch.device`): the desired device of the parameters
            and buffers in this module
        dtype (:class:`torch.dtype`): the desired floating point or complex dtype of
            the parameters and buffers in this module
        tensor (torch.Tensor): Tensor whose dtype and device are the desired
            dtype and device for all parameters and buffers in this module
        memory_format (:class:`torch.memory_format`): the desired memory
            format, applied to 4D or 5D parameters and buffers in this
            module (keyword only argument)

    Returns:
        Module: self

    Raises:
        TypeError: if :attr:`dtype` is neither floating point nor complex

    Examples::

        >>> linear = nn.Linear(2, 2)
        >>> linear.weight
        Parameter containing:
        tensor([[ 0.1913, -0.3420],
                [-0.5113, -0.2325]])
        >>> linear.to(torch.double)
        Linear(in_features=2, out_features=2, bias=True)
        >>> linear.weight
        Parameter containing:
        tensor([[ 0.1913, -0.3420],
                [-0.5113, -0.2325]], dtype=torch.float64)
        >>> gpu1 = torch.device("cuda:1")
        >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
        Linear(in_features=2, out_features=2, bias=True)
        >>> linear.weight
        Parameter containing:
        tensor([[ 0.1914, -0.3420],
                [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
        >>> cpu = torch.device("cpu")
        >>> linear.to(cpu)
        Linear(in_features=2, out_features=2, bias=True)
        >>> linear.weight
        Parameter containing:
        tensor([[ 0.1914, -0.3420],
                [-0.5112, -0.2324]], dtype=torch.float16)

        >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
        >>> linear.weight
        Parameter containing:
        tensor([[ 0.3741+0.j,  0.2382+0.j],
                [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
        >>> linear(torch.ones(3, 2, dtype=torch.cdouble))
        tensor([[0.6122+0.j, 0.1150+0.j],
                [0.6122+0.j, 0.1150+0.j],
                [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)

    """
    device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

    if dtype is not None:
        if not dtype.is_floating_point and not dtype.is_complex:
            raise TypeError('nn.Module.to only accepts floating point or complex '
                            'dtypes, but got desired dtype={}'.format(dtype))
        if dtype.is_complex:
            warnings.warn(
                "Complex modules are a new feature under active development whose design may change, "
                "and some modules might not work as expected when using complex tensors as parameters or buffers. "
                "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
                "if a complex module does not work as expected.")

    def cast_and_move(tensor):
        # Only floating point / complex tensors take the requested dtype;
        # everything else keeps its own dtype (None means "leave as is").
        if tensor.is_floating_point() or tensor.is_complex():
            target_dtype = dtype
        else:
            target_dtype = None
        # The memory format is forwarded only for 4D and 5D tensors.
        if convert_to_format is not None and tensor.dim() in (4, 5):
            return tensor.to(device, target_dtype, non_blocking, memory_format=convert_to_format)
        return tensor.to(device, target_dtype, non_blocking)

    return self._apply(cast_and_move)
def register_backward_hook(
    self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]]
) -> RemovableHandle:
    r"""Registers a backward hook on the module.

    This function is deprecated in favor of
    :meth:`~torch.nn.Module.register_full_backward_hook`, and its
    behavior will change in future versions.

    Raises:
        RuntimeError: if the module is already marked as using full
            backward hooks

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    if self._is_full_backward_hook is True:
        raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a "
                           "single Module. Please use only one of them.")

    # Mark the module as using the non-full flavour of backward hooks.
    self._is_full_backward_hook = False

    registry = self._backward_hooks
    handle = hooks.RemovableHandle(registry)
    registry[handle.id] = hook
    return handle
def register_full_backward_hook(
    self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]]
) -> RemovableHandle:
    r"""Registers a backward hook on the module.

    The hook is invoked every time the gradients with respect to the
    module inputs are computed. Its expected signature is::

        hook(module, grad_input, grad_output) -> tuple(Tensor) or None

    :attr:`grad_input` and :attr:`grad_output` are tuples holding the
    gradients with respect to the inputs and outputs respectively. The
    hook should not modify its arguments, but it may return a new gradient
    with respect to the input, which is then used in place of
    :attr:`grad_input` in subsequent computations. :attr:`grad_input`
    only corresponds to the inputs passed as positional arguments; all
    kwarg arguments are ignored. Entries in :attr:`grad_input` and
    :attr:`grad_output` are ``None`` for all non-Tensor arguments.

    For technical reasons, when this hook is applied to a Module, its
    forward function receives a view of each Tensor passed to the Module.
    Similarly, the caller receives a view of each Tensor returned by the
    Module's forward function.

    .. warning ::
        Modifying inputs or outputs inplace is not allowed when using
        backward hooks and will raise an error.

    Raises:
        RuntimeError: if the module is already marked as using regular
            (non-full) backward hooks

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    if self._is_full_backward_hook is False:
        raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a "
                           "single Module. Please use only one of them.")

    # Mark the module as using the full flavour of backward hooks.
    self._is_full_backward_hook = True

    registry = self._backward_hooks
    handle = hooks.RemovableHandle(registry)
    registry[handle.id] = hook
    return handle
def _get_backward_hooks(self):
    r"""Collects the backward hooks to be used by the call function.

    Two lists are returned: the full backward hooks and the non-full
    backward hooks. In each list the global hooks come before the hooks
    registered on this module. A registry whose "is full" flag is ``None``
    contributes to neither list.
    """
    global_is_full = _global_is_full_backward_hook
    own_is_full = self._is_full_backward_hook

    full_backward_hooks: List[Callable] = []
    if global_is_full is True:
        full_backward_hooks.extend(_global_backward_hooks.values())
    if own_is_full is True:
        full_backward_hooks.extend(self._backward_hooks.values())

    non_full_backward_hooks: List[Callable] = []
    if global_is_full is False:
        non_full_backward_hooks.extend(_global_backward_hooks.values())
    if own_is_full is False:
        non_full_backward_hooks.extend(self._backward_hooks.values())

    return full_backward_hooks, non_full_backward_hooks
def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn):
    # Emits a warning for each situation in which a non-full backward hook
    # attached to ``grad_fn`` would miss part of grad_input / grad_output.
    if isinstance(result, torch.Tensor):
        result = (result,)
    elif not isinstance(result, tuple) or not all(isinstance(r, torch.Tensor) for r in result):
        warnings.warn("Using non-full backward hooks on a Module that does not return a "
                      "single Tensor or a tuple of Tensors is deprecated and will be removed "
                      "in future versions. This hook will be missing some of the grad_output. "
                      "Please use register_full_backward_hook to get the documented behavior.")
        return

    if isinstance(inputs, torch.Tensor):
        inputs = (inputs,)
    elif not isinstance(inputs, tuple) or not all(isinstance(i, torch.Tensor) for i in inputs):
        warnings.warn("Using non-full backward hooks on a Module that does not take as input a "
                      "single Tensor or a tuple of Tensors is deprecated and will be removed "
                      "in future versions. This hook will be missing some of the grad_input. "
                      "Please use register_full_backward_hook to get the documented behavior.")
        return

    # From here on, both inputs and result are tuples of Tensors.
    out_grad_fn = {r.grad_fn for r in result if r.grad_fn is not None}
    node_count = len(out_grad_fn)

    if node_count > 1:
        warnings.warn("Using a non-full backward hook when outputs are generated by different autograd Nodes "
                      "is deprecated and will be removed in future versions. This hook will be missing "
                      "some grad_output. Please use register_full_backward_hook to get the documented behavior.")
        return

    if node_count == 0 or grad_fn not in out_grad_fn:
        warnings.warn("Using a non-full backward hook when outputs are nested in python data structure "
                      "is deprecated and will be removed in future versions. This hook will be missing "
                      "some grad_output.")
        return

    # At this point the grad_output part of the hook will most likely be correct
    inputs_grad_fn = {i.grad_fn for i in inputs if i.grad_fn is not None}
    next_functions = {n[0] for n in grad_fn.next_functions}
    if inputs_grad_fn != next_functions:
        warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "
                      "is deprecated and will be removed in future versions. This hook will be missing "
                      "some grad_input. Please use register_full_backward_hook to get the documented "
                      "behavior.")
def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:
    r"""Registers a forward pre-hook on the module.

    The hook is invoked every time before :func:`forward` is invoked.
    Its expected signature is::

        hook(module, input) -> None or modified input

    The input holds only the positional arguments given to the module.
    Keyword arguments are not passed to the hooks, only to ``forward``.
    The hook may modify the input, returning either a tuple or a single
    modified value. A single value is wrapped into a tuple (unless that
    value is already a tuple).

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    registry = self._forward_pre_hooks
    handle = hooks.RemovableHandle(registry)
    registry[handle.id] = hook
    return handle
def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle:
    r"""Registers a forward hook on the module.

    The hook is invoked every time after :func:`forward` has computed an
    output. Its expected signature is::

        hook(module, input, output) -> None or modified output

    The input holds only the positional arguments given to the module.
    Keyword arguments are not passed to the hooks, only to ``forward``.
    The hook may modify the output. It may also modify the input inplace,
    but that has no effect on forward because the hook runs after
    :func:`forward` is called.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    registry = self._forward_hooks
    handle = hooks.RemovableHandle(registry)
    registry[handle.id] = hook
    return handle
def _slow_forward(self, *input, **kwargs):
    # Calls ``forward``; while tracing, the call is additionally wrapped in
    # a named scope when the trace module map has a name for this module.
    tracing_state = torch._C._get_tracing_state()
    if not tracing_state or isinstance(self.forward, torch._C.ScriptMethod):
        return self.forward(*input, **kwargs)

    module_map = torch.jit._trace._trace_module_map
    scope_name = None
    if module_map is not None and self in module_map:
        scope_name = module_map[self]

    # A scope is only pushed (and later popped) for a truthy name.
    scope_pushed = bool(scope_name)
    if scope_pushed:
        tracing_state.push_scope(scope_name)
    try:
        return self.forward(*input, **kwargs)
    finally:
        if scope_pushed:
            tracing_state.pop_scope()
def _call_impl(self, *input, **kwargs):
    # Runs forward together with every registered hook, in this order:
    # forward pre-hooks, full backward hook input setup, forward, forward
    # hooks, full backward hook output setup, non-full backward hooks.
    if torch._C._get_tracing_state():
        forward_call = self._slow_forward
    else:
        forward_call = self.forward

    # Fast path: with no hooks of any kind there is nothing to do except
    # calling forward.
    any_hook_registered = (
        self._backward_hooks or self._forward_hooks or self._forward_pre_hooks
        or _global_backward_hooks or _global_forward_hooks or _global_forward_pre_hooks
    )
    if not any_hook_registered:
        return forward_call(*input, **kwargs)

    full_backward_hooks, non_full_backward_hooks = [], []
    if self._backward_hooks or _global_backward_hooks:
        full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()

    if _global_forward_pre_hooks or self._forward_pre_hooks:
        # The hooks are snapshotted into a tuple before any of them runs.
        pre_hooks = (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.values())
        for pre_hook in pre_hooks:
            replacement = pre_hook(self, input)
            if replacement is None:
                continue
            # A single returned value is wrapped so input stays a tuple.
            input = replacement if isinstance(replacement, tuple) else (replacement,)

    bw_hook = None
    if full_backward_hooks:
        bw_hook = hooks.BackwardHook(self, full_backward_hooks)
        input = bw_hook.setup_input_hook(input)

    result = forward_call(*input, **kwargs)

    if _global_forward_hooks or self._forward_hooks:
        post_hooks = (*_global_forward_hooks.values(), *self._forward_hooks.values())
        for post_hook in post_hooks:
            replacement = post_hook(self, input, result)
            if replacement is not None:
                result = replacement

    if bw_hook:
        result = bw_hook.setup_output_hook(result)

    # Handle the non-full backward hooks
    if non_full_backward_hooks:
        # Descend into the result until a Tensor is found: for a dict take
        # its first Tensor value, otherwise take element 0.
        first_tensor = result
        while not isinstance(first_tensor, torch.Tensor):
            if isinstance(first_tensor, dict):
                first_tensor = next(v for v in first_tensor.values() if isinstance(v, torch.Tensor))
            else:
                first_tensor = first_tensor[0]
        grad_fn = first_tensor.grad_fn
        if grad_fn is not None:
            for legacy_hook in non_full_backward_hooks:
                bound_hook = functools.partial(legacy_hook, self)
                functools.update_wrapper(bound_hook, legacy_hook)
                grad_fn.register_hook(bound_hook)
            self._maybe_warn_non_full_backward_hook(input, result, grad_fn)

    return result


__call__ : Callable[..., Any] = _call_impl
def __setstate__(self, state):
    # Restores the instance dict from ``state`` and then fills in any
    # attribute that old checkpoints do not carry.
    self.__dict__.update(state)

    # Support loading old checkpoints that don't have the following attrs:
    legacy_defaults = (
        ('_forward_pre_hooks', OrderedDict),
        ('_state_dict_hooks', OrderedDict),
        ('_load_state_dict_pre_hooks', OrderedDict),
        ('_load_state_dict_post_hooks', OrderedDict),
        ('_non_persistent_buffers_set', set),
        ('_is_full_backward_hook', lambda: None),
    )
    for attr_name, make_default in legacy_defaults:
        if attr_name not in self.__dict__:
            setattr(self, attr_name, make_default())
- def __getattr__(self, name: str) -> Union[Tensor, 'Module']:
- if '_parameters' in self.__dict__:
- _parameters = self.__dict__['_parameters']
- if name in _parameters:
- return _parameters[name]
- if '_buffers' in self.__dict__:
- _buffers = self.__dict__['_buffers']
- if name in _buffers:
- return _buffers[name]
- if '_modules' in self.__dict__:
- modules = self.__dict__['_modules']
- if name in modules:
- return modules[name]
- raise AttributeError("'{}' object has no attribute '{}'".format(
- type(self).__name__, name))
def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
    # Routes an assignment to the parameter, module or buffer store when
    # the value type or an existing entry calls for it; anything else is
    # set as a plain attribute.
    def remove_from(*dicts_or_sets):
        for container in dicts_or_sets:
            if name not in container:
                continue
            if isinstance(container, dict):
                del container[name]
            else:
                container.discard(name)

    params = self.__dict__.get('_parameters')
    if isinstance(value, Parameter):
        if params is None:
            raise AttributeError(
                "cannot assign parameters before Module.__init__() call")
        remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
        self.register_parameter(name, value)
        return
    if params is not None and name in params:
        # An existing parameter slot may only be reset to None.
        if value is not None:
            raise TypeError("cannot assign '{}' as parameter '{}' "
                            "(torch.nn.Parameter or None expected)"
                            .format(torch.typename(value), name))
        self.register_parameter(name, value)
        return

    modules = self.__dict__.get('_modules')
    if isinstance(value, Module):
        if modules is None:
            raise AttributeError(
                "cannot assign module before Module.__init__() call")
        remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
        modules[name] = value
        return
    if modules is not None and name in modules:
        # An existing child-module slot may only be reset to None.
        if value is not None:
            raise TypeError("cannot assign '{}' as child module '{}' "
                            "(torch.nn.Module or None expected)"
                            .format(torch.typename(value), name))
        modules[name] = value
        return

    buffers = self.__dict__.get('_buffers')
    if buffers is None or name not in buffers:
        object.__setattr__(self, name, value)
        return
    # An existing buffer slot accepts a Tensor or None.
    if value is not None and not isinstance(value, torch.Tensor):
        raise TypeError("cannot assign '{}' as buffer '{}' "
                        "(torch.Tensor or None expected)"
                        .format(torch.typename(value), name))
    buffers[name] = value
def __delattr__(self, name):
    # Removes ``name`` from the first store that holds it (parameters,
    # buffers, modules); otherwise falls back to plain attribute deletion.
    if name in self._parameters:
        del self._parameters[name]
        return
    if name in self._buffers:
        del self._buffers[name]
        # Also forget any non-persistent marking for the buffer.
        self._non_persistent_buffers_set.discard(name)
        return
    if name in self._modules:
        del self._modules[name]
        return
    object.__delattr__(self, name)
def _register_state_dict_hook(self, hook):
    r"""Registers a hook that runs after the `state_dict` of `self` is set.

    The hook is called with the arguments `self`, `state_dict`, `prefix`
    and `local_metadata`. Only parameters and buffers of `self` or its
    children are guaranteed to exist in `state_dict`. The hook may modify
    `state_dict` inplace or return a new one.

    Returns a handle created over the module's state-dict hook registry.
    """
    registry = self._state_dict_hooks
    handle = hooks.RemovableHandle(registry)
    registry[handle.id] = hook
    return handle
def _save_to_state_dict(self, destination, prefix, keep_vars):
    r"""Saves the state of this module, but not of its descendants, into
    the `destination` dictionary. This is called on every submodule in
    :meth:`~torch.nn.Module.state_dict`.

    In rare cases, subclasses can achieve class-specific behavior by
    overriding this method with custom logic.

    Args:
        destination (dict): a dict where state will be stored
        prefix (str): the prefix for parameters and buffers used in this
            module
        keep_vars (bool): when true the tensors are stored as they are,
            otherwise their ``detach()`` result is stored
    """
    for param_name, param in self._parameters.items():
        if param is None:
            continue
        destination[prefix + param_name] = param if keep_vars else param.detach()

    for buffer_name, buf in self._buffers.items():
        # Unset and non-persistent buffers are not saved.
        if buf is None or buffer_name in self._non_persistent_buffers_set:
            continue
        destination[prefix + buffer_name] = buf if keep_vars else buf.detach()

    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
    # Extra state is stored only when the class overrides get_extra_state.
    class_getter = getattr(self.__class__, "get_extra_state", Module.get_extra_state)
    if class_getter is not Module.get_extra_state:
        destination[extra_state_key] = self.get_extra_state()
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
# back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
T_destination = TypeVar('T_destination', bound=Dict[str, Any])


@overload
def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination:
    ...


@overload
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
    ...


# TODO: Change `*args` to `*` and remove the corresponding warning in docs when BC allows.
# Also remove the logic for arg parsing together.
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
    r"""Returns a dictionary holding the whole state of the module.

    Parameters as well as persistent buffers (e.g. running averages) are
    included, keyed by the corresponding parameter and buffer names.
    Parameters and buffers set to ``None`` are not included.

    .. warning::
        Currently ``state_dict()`` also accepts positional arguments for
        ``destination``, ``prefix`` and ``keep_vars`` in order. However,
        this is being deprecated and keyword arguments will be enforced in
        future releases.

    .. warning::
        Please avoid the use of argument ``destination`` as it is not
        designed for end-users.

    Args:
        destination (dict, optional): If provided, the state of module will
            be updated into the dict and the same object is returned.
            Otherwise, an ``OrderedDict`` will be created and returned.
            Default: ``None``.
        prefix (str, optional): a prefix added to parameter and buffer
            names to compose the keys in state_dict. Default: ``''``.
        keep_vars (bool, optional): by default the :class:`~torch.Tensor` s
            returned in the state dict are detached from autograd. If it's
            set to ``True``, detaching will not be performed.
            Default: ``False``.

    Returns:
        dict:
            a dictionary containing a whole state of the module

    Example::

        >>> module.state_dict().keys()
        ['bias', 'weight']

    """
    # TODO: Remove `args` and the parsing logic when BC allows.
    if args:
        positional_count = len(args)
        # A positional value is used only while the matching keyword still
        # holds its default.
        if destination is None:
            destination = args[0]
        if positional_count > 1 and prefix == '':
            prefix = args[1]
        if positional_count > 2 and keep_vars is False:
            keep_vars = args[2]
        # DeprecationWarning is ignored by default
        warnings.warn(
            "Positional args are being deprecated, use kwargs instead. Refer to "
            "https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict"
            " for details.")

    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()

    local_metadata = {'version': self._version}
    if hasattr(destination, "_metadata"):
        # The metadata key is the prefix without its last character.
        destination._metadata[prefix[:-1]] = local_metadata

    self._save_to_state_dict(destination, prefix, keep_vars)
    for child_name, child in self._modules.items():
        if child is None:
            continue
        child.state_dict(destination=destination, prefix=prefix + child_name + '.', keep_vars=keep_vars)

    for state_hook in self._state_dict_hooks.values():
        replacement = state_hook(self, destination, prefix, local_metadata)
        if replacement is not None:
            destination = replacement
    return destination
def _register_load_state_dict_pre_hook(self, hook, with_module=False):
    r"""Registers a hook that runs before `state_dict` is loaded into `self`.

    The hook is called with the arguments `state_dict`, `prefix`,
    `local_metadata`, `strict`, `missing_keys`, `unexpected_keys` and
    `error_msgs`, i.e. exactly the arguments of `_load_from_state_dict`.

    If ``with_module`` is ``True``, then the first argument to the hook is
    an instance of the module.

    Args:
        hook (Callable): Callable hook that will be invoked before
            loading the state dict.
        with_module (bool, optional): Whether or not to pass the module
            instance to the hook as the first parameter.
    """
    registry = self._load_state_dict_pre_hooks
    handle = hooks.RemovableHandle(registry)
    # With ``with_module`` the module is bound as the leading argument.
    registry[handle.id] = functools.partial(hook, self) if with_module else hook
    return handle
def register_load_state_dict_post_hook(self, hook):
    r"""Registers a post hook to be run after module's ``load_state_dict``
    is called.

    Its expected signature is::

        hook(module, incompatible_keys) -> None

    The ``module`` argument is the current module that this hook is
    registered on, and the ``incompatible_keys`` argument is a
    ``NamedTuple`` consisting of attributes ``missing_keys`` and
    ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str``
    containing the missing keys and ``unexpected_keys`` is a ``list`` of
    ``str`` containing the unexpected keys.

    The given incompatible_keys can be modified inplace if needed.

    Note that the checks performed when calling :func:`load_state_dict`
    with ``strict=True`` are affected by modifications the hook makes to
    ``missing_keys`` or ``unexpected_keys``, as expected. Additions to
    either set of keys will result in an error being thrown when
    ``strict=True``, and clearing out both missing and unexpected keys
    will avoid an error.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    registry = self._load_state_dict_post_hooks
    handle = hooks.RemovableHandle(registry)
    registry[handle.id] = hook
    return handle
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    r"""Copies parameters and buffers from :attr:`state_dict` into only
    this module, but not its descendants. This is called on every submodule
    in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
    module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
    For state dicts without metadata, :attr:`local_metadata` is empty.
    Subclasses can achieve class-specific backward compatible loading using
    the version number at `local_metadata.get("version", None)`.

    .. note::
        :attr:`state_dict` is not the same object as the input
        :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
        it can be modified.

    Args:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        prefix (str): the prefix for parameters and buffers used in this
            module
        local_metadata (dict): a dict containing the metadata for this module.
        strict (bool): whether to strictly enforce that the keys in
            :attr:`state_dict` with :attr:`prefix` match the names of
            parameters and buffers in this module
        missing_keys (list of str): if ``strict=True``, add missing keys to
            this list
        unexpected_keys (list of str): if ``strict=True``, add unexpected
            keys to this list
        error_msgs (list of str): error messages should be added to this
            list, and will be reported together in
            :meth:`~torch.nn.Module.load_state_dict`
    """
    for pre_hook in self._load_state_dict_pre_hooks.values():
        pre_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    # Loadable local state: parameters first, then persistent buffers;
    # entries set to None are left out.
    local_state = {}
    for member_name, member in self._parameters.items():
        if member is not None:
            local_state[member_name] = member
    for member_name, member in self._buffers.items():
        if member_name in self._non_persistent_buffers_set:
            continue
        if member is not None:
            local_state[member_name] = member

    for name, param in local_state.items():
        key = prefix + name
        if key not in state_dict:
            if strict:
                missing_keys.append(key)
            continue

        input_param = state_dict[key]
        if not torch.overrides.is_tensor_like(input_param):
            error_msgs.append('While copying the parameter named "{}", '
                              'expected torch.Tensor or Tensor-like object from checkpoint but '
                              'received {}'
                              .format(key, type(input_param)))
            continue

        # This is used to avoid copying uninitialized parameters into
        # non-lazy modules, since they dont have the hook to do the checks
        # in such case, it will error when accessing the .shape attribute.
        is_param_lazy = torch.nn.parameter.is_lazy(param)
        if not is_param_lazy:
            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
            if len(param.shape) == 0 and len(input_param.shape) == 1:
                input_param = input_param[0]
            if input_param.shape != param.shape:
                # local shape should match the one in checkpoint
                error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                  'the shape in current model is {}.'
                                  .format(key, input_param.shape, param.shape))
                continue

        try:
            with torch.no_grad():
                param.copy_(input_param)
        except Exception as ex:
            error_msgs.append('While copying the parameter named "{}", '
                              'whose dimensions in the model are {} and '
                              'whose dimensions in the checkpoint are {}, '
                              'an exception occurred : {}.'
                              .format(key, param.size(), input_param.size(), ex.args))

    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
    # Extra state is consumed only when the class overrides set_extra_state.
    class_setter = getattr(self.__class__, "set_extra_state", Module.set_extra_state)
    if class_setter is not Module.set_extra_state:
        if extra_state_key in state_dict:
            self.set_extra_state(state_dict[extra_state_key])
        elif strict:
            missing_keys.append(extra_state_key)
    elif strict and (extra_state_key in state_dict):
        unexpected_keys.append(extra_state_key)

    if not strict:
        return
    for key in state_dict.keys():
        if not key.startswith(prefix) or key == extra_state_key:
            continue
        # get the name of param/buffer/child
        input_name = key[len(prefix):].partition('.')[0]
        if input_name not in self._modules and input_name not in local_state:
            unexpected_keys.append(key)
def load_state_dict(self, state_dict: Mapping[str, Any],
                    strict: bool = True):
    r"""Copies parameters and buffers from :attr:`state_dict` into
    this module and its descendants. If :attr:`strict` is ``True``, then
    the keys of :attr:`state_dict` must exactly match the keys returned
    by this module's :meth:`~torch.nn.Module.state_dict` function.

    Args:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        strict (bool, optional): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``True``

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys
            * **unexpected_keys** is a list of str containing the unexpected keys

    Raises:
        TypeError: if :attr:`state_dict` is not a ``Mapping``
        RuntimeError: if any error message was collected while loading,
            which with ``strict=True`` includes missing or unexpected keys

    Note:
        If a parameter or buffer is registered as ``None`` and its corresponding key
        exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
        ``RuntimeError``.
    """
    if not isinstance(state_dict, Mapping):
        raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))

    missing_keys: List[str] = []
    unexpected_keys: List[str] = []
    error_msgs: List[str] = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        # mypy isn't aware that "_metadata" exists in state_dict
        state_dict._metadata = metadata  # type: ignore[attr-defined]

    def load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        # ``True`` is passed as ``strict`` regardless of the caller's choice,
        # so missing and unexpected keys are always collected for the return
        # value; whether they become an error is decided further below.
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

        # Note that the hook can modify missing_keys and unexpected_keys.
        incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
        for hook in module._load_state_dict_post_hooks.values():
            out = hook(module, incompatible_keys)
            # The adjacent literals are joined verbatim, so each one carries
            # its own trailing space.
            assert out is None, (
                "Hooks registered with ``register_load_state_dict_post_hook`` are not "
                "expected to return new values, if incompatible_keys need to be modified, "
                "it should be done inplace."
            )

    load(self)
    del load

    if strict:
        # Both inserts go to position 0, so the "Missing" line ends up
        # ahead of the "Unexpected" line.
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in missing_keys)))

    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
            self.__class__.__name__, "\n\t".join(error_msgs)))
    return _IncompatibleKeys(missing_keys, unexpected_keys)
- def _named_members(self, get_members_fn, prefix='', recurse=True):
- r"""Helper method for yielding various names + members of modules."""
- memo = set()
- modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
- for module_prefix, module in modules:
- members = get_members_fn(module)
- for k, v in members:
- if v is None or v in memo:
- continue
- memo.add(v)
- name = module_prefix + ('.' if module_prefix else '') + k
- yield name, v
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
    r"""Returns an iterator over module parameters.

    This is typically passed to an optimizer. The values come from
    :meth:`named_parameters` with the names dropped.

    Args:
        recurse (bool): if True, then yields parameters of this module
            and all submodules. Otherwise, yields only parameters that
            are direct members of this module.

    Yields:
        Parameter: module parameter

    Example::

        >>> for param in model.parameters():
        >>>     print(type(param), param.size())
        <class 'torch.Tensor'> (20L,)
        <class 'torch.Tensor'> (20L, 1L, 5L, 5L)

    """
    yield from (param for _, param in self.named_parameters(recurse=recurse))
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
    r"""Returns an iterator over module parameters, yielding both the
    name of the parameter as well as the parameter itself.

    The work is delegated to ``self._named_members`` with a getter that
    reads each module's ``_parameters`` mapping.

    Args:
        prefix (str): prefix to prepend to all parameter names.
        recurse (bool): if True, then yields parameters of this module
            and all submodules. Otherwise, yields only parameters that
            are direct members of this module.

    Yields:
        (string, Parameter): Tuple containing the name and parameter

    Example::

        >>> for name, param in self.named_parameters():
        >>>     if name in ['bias']:
        >>>         print(param.size())

    """
    def own_parameters(module):
        return module._parameters.items()

    yield from self._named_members(own_parameters, prefix=prefix, recurse=recurse)
def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
    r"""Returns an iterator over module buffers.

    The values come from :meth:`named_buffers` with the names dropped.

    Args:
        recurse (bool): if True, then yields buffers of this module
            and all submodules. Otherwise, yields only buffers that
            are direct members of this module.

    Yields:
        torch.Tensor: module buffer

    Example::

        >>> for buf in model.buffers():
        >>>     print(type(buf), buf.size())
        <class 'torch.Tensor'> (20L,)
        <class 'torch.Tensor'> (20L, 1L, 5L, 5L)

    """
    yield from (buffer for _, buffer in self.named_buffers(recurse=recurse))
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
    r"""Returns an iterator over module buffers, yielding both the
    name of the buffer as well as the buffer itself.

    The work is delegated to ``self._named_members`` with a getter that
    reads each module's ``_buffers`` mapping.

    Args:
        prefix (str): prefix to prepend to all buffer names.
        recurse (bool): if True, then yields buffers of this module
            and all submodules. Otherwise, yields only buffers that
            are direct members of this module.

    Yields:
        (string, torch.Tensor): Tuple containing the name and buffer

    Example::

        >>> for name, buf in self.named_buffers():
        >>>     if name in ['running_var']:
        >>>         print(buf.size())

    """
    def own_buffers(module):
        return module._buffers.items()

    yield from self._named_members(own_buffers, prefix=prefix, recurse=recurse)
def children(self) -> Iterator['Module']:
    r"""Returns an iterator over immediate children modules.

    The values come from :meth:`named_children` with the names dropped.

    Yields:
        Module: a child module
    """
    yield from (child for _, child in self.named_children())
def named_children(self) -> Iterator[Tuple[str, 'Module']]:
    r"""Returns an iterator over immediate children modules, yielding both
    the name of the module as well as the module itself.

    Entries of ``self._modules`` that are ``None`` are skipped, and a module
    registered under several names is yielded only under the first one.

    Yields:
        (string, Module): Tuple containing a name and child module

    Example::

        >>> for name, module in model.named_children():
        >>>     if name in ['conv4', 'conv5']:
        >>>         print(module)

    """
    seen = set()
    for child_name, child in self._modules.items():
        if child is None or child in seen:
            continue
        seen.add(child)
        yield child_name, child
def modules(self) -> Iterator['Module']:
    r"""Returns an iterator over all modules in the network.

    The values come from :meth:`named_modules` with the names dropped.

    Yields:
        Module: a module in the network

    Note:
        Duplicate modules are returned only once. In the following
        example, ``l`` will be returned only once.

    Example::

        >>> l = nn.Linear(2, 2)
        >>> net = nn.Sequential(l, l)
        >>> for idx, m in enumerate(net.modules()):
        ...     print(idx, '->', m)
        0 -> Sequential(
         (0): Linear(in_features=2, out_features=2, bias=True)
         (1): Linear(in_features=2, out_features=2, bias=True)
        )
        1 -> Linear(in_features=2, out_features=2, bias=True)

    """
    yield from (module for _, module in self.named_modules())
def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True):
    r"""Returns an iterator over all modules in the network, yielding
    both the name of the module as well as the module itself.

    The traversal is depth-first: this module is yielded first, then each
    non-``None`` entry of ``self._modules`` is visited recursively with the
    same ``memo``.

    Args:
        memo: a memo to store the set of modules already added to the result
        prefix: a prefix that will be added to the name of the module
        remove_duplicate: whether to remove the duplicated module instances in the result
            or not

    Yields:
        (string, Module): Tuple of name and module

    Note:
        Duplicate modules are returned only once. In the following
        example, ``l`` will be returned only once.

    Example::

        >>> l = nn.Linear(2, 2)
        >>> net = nn.Sequential(l, l)
        >>> for idx, m in enumerate(net.named_modules()):
        ...     print(idx, '->', m)
        0 -> ('', Sequential(
         (0): Linear(in_features=2, out_features=2, bias=True)
         (1): Linear(in_features=2, out_features=2, bias=True)
        ))
        1 -> ('0', Linear(in_features=2, out_features=2, bias=True))

    """
    if memo is None:
        memo = set()
    if self in memo:
        return

    # Modules are only recorded when de-duplication is requested, so with
    # ``remove_duplicate=False`` the membership test above never triggers
    # for a memo that started out empty.
    if remove_duplicate:
        memo.add(self)
    yield prefix, self

    for child_name, child in self._modules.items():
        if child is None:
            continue
        child_prefix = prefix + '.' + child_name if prefix else child_name
        yield from child.named_modules(memo, child_prefix, remove_duplicate)
def train(self: T, mode: bool = True) -> T:
    r"""Sets the module in training mode.

    This has any effect only on certain modules. See documentations of
    particular modules for details of their behaviors in training/evaluation
    mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
    etc.

    The flag is stored on ``self.training`` and then propagated by calling
    ``train(mode)`` on every module returned by :meth:`children`.

    Args:
        mode (bool): whether to set training mode (``True``) or evaluation
            mode (``False``). Default: ``True``.

    Returns:
        Module: self

    Raises:
        ValueError: if ``mode`` is not a ``bool``.
    """
    if not isinstance(mode, bool):
        raise ValueError("training mode is expected to be boolean")

    self.training = mode
    for child in self.children():
        child.train(mode)
    return self
def eval(self: T) -> T:
    r"""Sets the module in evaluation mode.

    This has any effect only on certain modules. See documentations of
    particular modules for details of their behaviors in training/evaluation
    mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
    etc.

    This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.

    See :ref:`locally-disable-grad-doc` for a comparison between
    `.eval()` and several similar mechanisms that may be confused with it.

    Returns:
        Module: self
    """
    # Evaluation mode is simply "training switched off".
    result = self.train(False)
    return result
def requires_grad_(self: T, requires_grad: bool = True) -> T:
    r"""Change if autograd should record operations on parameters in this
    module.

    This method sets the parameters' :attr:`requires_grad` attributes
    in-place, by calling ``requires_grad_`` on every parameter returned by
    :meth:`parameters`.

    This method is helpful for freezing part of the module for finetuning
    or training parts of a model individually (e.g., GAN training).

    See :ref:`locally-disable-grad-doc` for a comparison between
    `.requires_grad_()` and several similar mechanisms that may be confused with it.

    Args:
        requires_grad (bool): whether autograd should record operations on
                              parameters in this module. Default: ``True``.

    Returns:
        Module: self
    """
    for param in self.parameters():
        param.requires_grad_(requires_grad)
    return self
def zero_grad(self, set_to_none: bool = False) -> None:
    r"""Sets gradients of all model parameters to zero. See similar function
    under :class:`torch.optim.Optimizer` for more context.

    Parameters whose ``grad`` is ``None`` are left untouched. A warning is
    emitted first when the module carries a truthy ``_is_replica`` attribute.

    Args:
        set_to_none (bool): instead of setting to zero, set the grads to None.
            See :meth:`torch.optim.Optimizer.zero_grad` for details.
    """
    if getattr(self, '_is_replica', False):
        warnings.warn(
            "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
            "The parameters are copied (in a differentiable manner) from the original module. "
            "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
            "If you need gradients in your forward method, consider using autograd.grad instead.")

    for param in self.parameters():
        grad = param.grad
        if grad is None:
            continue
        if set_to_none:
            param.grad = None
            continue
        # A gradient with a grad_fn is detached in place; otherwise its
        # requires_grad flag is cleared. Either way it is then zeroed.
        if grad.grad_fn is None:
            grad.requires_grad_(False)
        else:
            grad.detach_()
        grad.zero_()
def share_memory(self: T) -> T:
    r"""See :meth:`torch.Tensor.share_memory_`

    Passes a function that calls ``share_memory_()`` on its argument to
    ``self._apply`` and returns whatever ``_apply`` returns.
    """
    def move_to_shared(tensor):
        return tensor.share_memory_()

    return self._apply(move_to_shared)
- def _get_name(self):
- return self.__class__.__name__
def extra_repr(self) -> str:
    r"""Set the extra representation of the module

    The base implementation returns an empty string. To print customized
    extra information, you should re-implement this method in your own
    modules. Both single-line and multi-line strings are acceptable.
    """
    return ""
def __repr__(self):
    r"""Build the printable form ``Name(...)`` of this module.

    The parentheses contain the lines of :meth:`extra_repr` followed by one
    ``(key): <child repr>`` entry per item of ``self._modules``. A module
    with exactly one extra line and no children is rendered on one line.
    """
    # We treat the extra repr like the sub-module, one item per line.
    # An empty extra repr contributes no lines at all (splitting '' would
    # otherwise give ['']).
    extra_text = self.extra_repr()
    extra_lines = extra_text.split('\n') if extra_text else []

    child_lines = [
        '(' + key + '): ' + _addindent(repr(module), 2)
        for key, module in self._modules.items()
    ]

    lines = extra_lines + child_lines
    header = self._get_name() + '('
    if not lines:
        return header + ')'

    # simple one-liner info, which most builtin Modules will use
    if len(extra_lines) == 1 and not child_lines:
        return header + extra_lines[0] + ')'

    # NOTE(review): child reprs are re-indented with width 2 above while the
    # separator below contains a single space -- confirm the intended indent
    # width against the file on disk.
    return header + '\n ' + '\n '.join(lines) + '\n' + ')'
- def __dir__(self):
- module_attrs = dir(self.__class__)
- attrs = list(self.__dict__.keys())
- parameters = list(self._parameters.keys())
- modules = list(self._modules.keys())
- buffers = list(self._buffers.keys())
- keys = module_attrs + attrs + parameters + modules + buffers
- # Eliminate attrs that are not legal Python variable names
- keys = [key for key in keys if not key[0].isdigit()]
- return sorted(keys)
- def _replicate_for_data_parallel(self):
- replica = self.__new__(type(self))
- replica.__dict__ = self.__dict__.copy()
- # replicas do not have parameters themselves, the replicas reference the original
- # module.
- replica._parameters = OrderedDict()
- replica._buffers = replica._buffers.copy()
- replica._modules = replica._modules.copy()
- replica._is_replica = True # type: ignore[assignment]
- return replica
|