cpp.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. """Functionality for Python <-> C++ frontend inter-op."""
  2. from torch import nn
  3. class OrderedDictWrapper(object):
  4. """
  5. A wrapper around a C++ OrderedDict that dynamically evaluates the
  6. OrderedDict getter on a bound C++ module, such that new changes on the C++
  7. side are picked up. Otherwise accessing e.g. ``cpp_module._parameters`` just
  8. once would get a frozen copy of the parameters at the time of access.
  9. ``torch.nn.Module`` accesses ``_parameters`` et al. via ``self.__dict__`` so
  10. using properties does not work.
  11. """
  12. def __init__(self, cpp_module, attr):
  13. self.cpp_module = cpp_module
  14. self.attr = attr
  15. @property
  16. def cpp_dict(self):
  17. return getattr(self.cpp_module, self.attr)
  18. # Magic methods cannot be assigned dynamically and bypass ``getattr``, so we
  19. # must manually override them.
  20. def items(self):
  21. return self.cpp_dict.items()
  22. def keys(self):
  23. return self.cpp_dict.keys()
  24. def values(self):
  25. return self.cpp_dict.values()
  26. def __iter__(self):
  27. return self.cpp_dict.__iter__()
  28. def __len__(self):
  29. return self.cpp_dict.__len__()
  30. def __contains__(self, key):
  31. return self.cpp_dict.__contains__(key)
  32. def __getitem__(self, key):
  33. return self.cpp_dict.__getitem__(key)
  34. class ModuleWrapper(nn.Module):
  35. """
  36. A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and
  37. delegates all access.
  38. """
  39. def __init__(self, cpp_module):
  40. # Assign before the super class constructor so ``self.training`` can be
  41. # assigned to in the super class constructor.
  42. self.cpp_module = cpp_module
  43. super(ModuleWrapper, self).__init__()
  44. self._parameters = OrderedDictWrapper(cpp_module, "_parameters") # type: ignore[assignment]
  45. self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers") # type: ignore[assignment]
  46. self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules") # type: ignore[assignment]
  47. for attr in dir(cpp_module):
  48. # Skip magic methods and the three attributes above.
  49. if not attr.startswith("_"):
  50. setattr(self, attr, getattr(self.cpp_module, attr))
  51. def _apply(self, fn):
  52. for param in self.parameters():
  53. # Tensors stored in modules are graph leaves, and we don't
  54. # want to create copy nodes, so we have to unpack the data.
  55. param.data = fn(param.data)
  56. if param._grad is not None:
  57. param._grad.data = fn(param._grad.data)
  58. for buf in self.buffers():
  59. buf.data = fn(buf.data)
  60. return self
  61. # nn.Module defines training as a boolean
  62. @property # type: ignore[override]
  63. def training(self):
  64. return self.cpp_module.training
  65. @training.setter
  66. def training(self, mode):
  67. self.cpp_module.train(mode)
  68. def __repr__(self):
  69. return self.cpp_module.__repr__()