| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091 |
- """Functionality for Python <-> C++ frontend inter-op."""
- from torch import nn
class OrderedDictWrapper:
    """Live, read-only view over a dict-like attribute of a C++ module.

    Every access re-reads ``getattr(cpp_module, attr)``, so changes made on the
    C++ side are always visible. Reading e.g. ``cpp_module._parameters`` just
    once would instead freeze a copy of the parameters as they were at that
    moment. A ``@property`` on the Python module cannot be used for this,
    because ``torch.nn.Module`` reads ``_parameters`` et al. directly out of
    ``self.__dict__``.

    Args:
        cpp_module: the bound C++ module object to read from.
        attr: name of the dict-like attribute on ``cpp_module`` to expose.
    """

    def __init__(self, cpp_module, attr):
        self.cpp_module = cpp_module
        self.attr = attr

    @property
    def cpp_dict(self):
        """Current value of ``cpp_module.<attr>``, fetched fresh on every access."""
        return getattr(self.cpp_module, self.attr)

    # Special methods are looked up on the type, not the instance, and never go
    # through ``getattr``. Each one therefore has to be written out explicitly
    # and forwarded to the freshly fetched dict.

    def items(self):
        current = self.cpp_dict
        return current.items()

    def keys(self):
        current = self.cpp_dict
        return current.keys()

    def values(self):
        current = self.cpp_dict
        return current.values()

    def __iter__(self):
        return iter(self.cpp_dict)

    def __len__(self):
        return len(self.cpp_dict)

    def __contains__(self, key):
        return key in self.cpp_dict

    def __getitem__(self, key):
        return self.cpp_dict[key]
class ModuleWrapper(nn.Module):
    """
    A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and
    delegates all access.

    ``_parameters``, ``_buffers`` and ``_modules`` are replaced by live
    ``OrderedDictWrapper`` views onto the C++ module. As a result, the
    ``nn.Module`` helpers (``parameters()``, ``buffers()``, ...) always see the
    C++ side's current state. Public (non-underscore) attributes of the C++
    module are copied onto the wrapper once, at construction time. The
    ``training`` flag is read from and written to the C++ module.

    Args:
        cpp_module: the bound C++ frontend module to wrap.
    """

    def __init__(self, cpp_module):
        # Assign before the super class constructor so ``self.training`` can be
        # assigned to in the super class constructor. The ``training`` setter
        # below forwards to ``self.cpp_module``.
        self.cpp_module = cpp_module
        super().__init__()
        # Replace the module's own containers with live views onto the C++
        # module's containers.
        self._parameters = OrderedDictWrapper(cpp_module, "_parameters")  # type: ignore[assignment]
        self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers")  # type: ignore[assignment]
        self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules")  # type: ignore[assignment]
        for attr in dir(cpp_module):
            # Skip magic methods and private names; this also covers the three
            # attributes above. This is a one-time copy: attributes added to
            # ``cpp_module`` after construction are not mirrored here.
            if not attr.startswith("_"):
                setattr(self, attr, getattr(self.cpp_module, attr))

    def _apply(self, fn, recurse=True):
        """Apply ``fn`` to the data of every parameter, gradient and buffer.

        Args:
            fn: callable that takes a tensor and returns the tensor to store in
                its place.
            recurse: if ``True`` (the default, matching the previous behavior),
                also visit the parameters and buffers of submodules. If
                ``False``, visit only this module's own. The parameter mirrors
                the ``nn.Module._apply`` signature, so callers that pass it by
                keyword (e.g. ``Module.to_empty(recurse=...)`` in newer
                PyTorch) do not fail with ``TypeError``.

        Returns:
            ``self``.
        """
        for param in self.parameters(recurse=recurse):
            # Tensors stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)

        for buf in self.buffers(recurse=recurse):
            buf.data = fn(buf.data)

        return self

    # nn.Module defines training as a boolean. Here it is a property that
    # proxies the wrapped C++ module's flag instead.
    @property  # type: ignore[override]
    def training(self):
        """Training-mode flag of the wrapped C++ module."""
        return self.cpp_module.training

    @training.setter
    def training(self, mode):
        # Forwarded to the C++ module's ``train()``.
        self.cpp_module.train(mode)

    def __repr__(self):
        # Delegate to the C++ module's own representation.
        return self.cpp_module.__repr__()
|