_ops.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. import torch._C
  2. import contextlib
  3. import ctypes
  4. import sys
  5. import types
  6. import torch.jit
  7. import torch._utils_internal
  8. # Query `hasattr` only once.
  9. _SET_GLOBAL_FLAGS = hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags')
  10. @contextlib.contextmanager
  11. def dl_open_guard():
  12. """
  13. Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
  14. shared library to load custom operators.
  15. """
  16. if _SET_GLOBAL_FLAGS:
  17. old_flags = sys.getdlopenflags()
  18. sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
  19. yield
  20. if _SET_GLOBAL_FLAGS:
  21. sys.setdlopenflags(old_flags)
  22. # Each OpOverload object contains pointer to a a specific operator overload, a pointer to the parent `OpOverloadPacket` object.
  23. # You can obtain an OpOverload object through attribute query on OpOverloadPacket.
  24. class OpOverload:
  25. def __init__(self, overloadpacket, op, schema):
  26. self._op = op
  27. self._schema = schema
  28. self._overloadpacket = overloadpacket
  29. self._overloadname = 'default' if schema.overload_name == '' else schema.overload_name
  30. self.__name__ = "{}.{}".format(self._schema.name.split("::")[1], self._overloadname)
  31. self.__module__ = overloadpacket.__module__
  32. op.__module__ = overloadpacket.__module__
  33. # it's a no-op since OpOverload object is immutable and must be unique for a given op overload.
  34. def __deepcopy__(self, memo=None):
  35. return self
  36. def __repr__(self):
  37. return "<OpOverload(op='{}.{}', overload='{}')>".format(*self._schema.name.split("::"), self._overloadname)
  38. def __call__(self, *args, **kwargs):
  39. return self._op(*args, **kwargs or {})
  40. def __getattr__(self, key):
  41. return getattr(self._op, key)
  42. def __hash__(self):
  43. return hash(self._op)
  44. # `my_namespace.my_op_name.overload_name`
  45. def __str__(self):
  46. return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)
  47. @property
  48. def overloadpacket(self):
  49. return self._overloadpacket
  50. @property
  51. def op(self):
  52. return self._op
  53. # TODO: add more methods to expose information about input and output arguments
  54. # OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator
  55. # You can obtain an OpOverload object through attribute query.
  56. class OpOverloadPacket:
  57. def __init__(self, qualified_op_name, op_name, op, overload_names):
  58. # These attributes are accessible on the object through the properties
  59. # defined below but are immutable
  60. self._qualified_op_name = qualified_op_name
  61. self.__name__ = op_name
  62. self._op = op
  63. self._overload_names = overload_names
  64. # it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op.
  65. def __deepcopy__(self, memo=None):
  66. return self
  67. def __repr__(self):
  68. return "<OpOverloadPacket(op='{}.{}')>".format(*self._qualified_op_name.split("::"))
  69. def __hash__(self):
  70. return hash(self._op)
  71. def __str__(self):
  72. return "{}.{}".format(*self._qualified_op_name.split("::"))
  73. @property
  74. def op(self):
  75. return self._op
  76. def __getattr__(self, key):
  77. # It is not a valid op_name when __file__ is passed in
  78. if key == '__file__':
  79. return 'torch.ops'
  80. # ensure that query for dunder attributes that does not exist on
  81. # opoverloadpacket but instead exists on the self._op object does not unnecessarily call
  82. # `_get_operation_overload` (which is an expensive operation).
  83. # This is done to prevent any potential slowdown. This list can be extended
  84. # if there exists other attributes like `__name__` that only exist on self._op and not on the
  85. # opoverloadpacket.
  86. # This is ok since we are guaranteed that an overload name for an aten op can't start with '__'
  87. try:
  88. if key.startswith('__'):
  89. return getattr(self._op, key)
  90. except AttributeError:
  91. # for consistency because it seems weird to
  92. # throw an attribute error with a message containing
  93. # an object name different from the one the attribute
  94. # query was performed on.
  95. raise AttributeError("'{}' can't have an overload name beginning with '__' and the "
  96. "underlying op {} has no attribute {} either."
  97. .format(str(self), str(self._op), key)) from None
  98. try:
  99. # This is ok since we are guaranteed that an overload name for an aten op can't be 'default'
  100. use_key = '' if key == 'default' else key
  101. # TODO: disallow access to overloads registered by JIT
  102. op_ = torch._C._get_operation_overload(
  103. self._qualified_op_name, use_key)
  104. schema = torch._C._get_schema(self._qualified_op_name, use_key)
  105. overload = OpOverload(self, op_, schema)
  106. # cache the overload object
  107. setattr(self, key, overload)
  108. return overload
  109. except RuntimeError:
  110. raise AttributeError(
  111. "The underlying op of '{}' has no overload name '{}'".format(str(self), key)
  112. ) from None
  113. def __call__(self, *args, **kwargs):
  114. # overloading __call__ to ensure torch.ops.foo.bar()
  115. # is still callable from JIT
  116. # We save the function ptr as the `op` attribute on
  117. # OpOverloadPacket to access it here.
  118. return self._op(*args, **kwargs or {})
  119. # TODO: use this to make a __dir__
  120. def overloads(self):
  121. return [n if n else "default" for n in self._overload_names]
  122. # Resolution of torch.fn is different from torch.ops.aten.fn
  123. # torch.fn uses the Python argparser, matches with the
  124. # appropriate schema, and calls into the unboxed version of the method
  125. # torch.ops.aten.fn resolution is done via the mechanism defined in JIT.
  126. # JIT creates a stack of all the overloads and then tries to match the
  127. # correct one at runtime and always calls into the boxed version of the method
  128. # Autograd codegen creates VariableType, TracerType,
  129. # inplace or view type and python bindings.
  130. # Aten codegen generates tensor methods for the the tensor class.
  131. # _OpNamespace is a subclass of ModuleType because the torch script
  132. # allows attribute lookups on modules only. Since we want torch.ops.foo.bar()
  133. # to work from script, we need to ensure ops and foo are modules
  134. class _OpNamespace(types.ModuleType):
  135. """
  136. An op namespace to dynamically bind Operators into Python.
  137. Say a user has created a custom Operator called "my_namespace::my_op". To
  138. call this op, the user will write torch.ops.my_namespace.my_op(...).
  139. At startup, this operation will not yet be bound into Python. Instead, the
  140. following sequence of magic tricks will occur:
  141. 1. `torch.ops.my_namespace` will invoke the `__getattr__` magic method
  142. on the `torch.ops` object, which will create a new `_OpNamespace`
  143. object called `my_namespace` and set it as an attribute on the `ops`
  144. object.
  145. 2. `torch.ops.my_namespace.my_op` will then invoke `__getattr__` on
  146. the `my_namespace` object, which will retrieve the operation via
  147. `torch.get_operation`, a function bound from C++, and then in a similar
  148. fashion bind this new object onto the `my_namespace` object.
  149. 3. `torch.ops.my_namespace.my_op(...)` then calls this new operation
  150. and subsequent accesses will incur no further lookup (the namespace and
  151. operation will already exist).
  152. """
  153. def __init__(self, name):
  154. super(_OpNamespace, self).__init__('torch.ops.' + name)
  155. self.name = name
  156. def __getattr__(self, op_name):
  157. # It is not a valid op_name when __file__ is passed in
  158. if op_name == '__file__':
  159. return 'torch.ops'
  160. # Get the op `my_namespace::my_op` if available. This will also check
  161. # for overloads and raise an exception if there are more than one.
  162. namespace_name = self.name
  163. qualified_op_name = '{}::{}'.format(namespace_name, op_name)
  164. try:
  165. op, overload_names = torch._C._jit_get_operation(qualified_op_name)
  166. except RuntimeError as e:
  167. # Turn this into AttributeError so getattr(obj, key, default)
  168. # works (this is called by TorchScript with __origin__)
  169. raise AttributeError(f"'_OpNamespace' object has no attribute '{op_name}'") from e
  170. # let the script frontend know that op is identical to the builtin op
  171. # with qualified_op_name
  172. torch.jit._builtins._register_builtin(op, qualified_op_name)
  173. op.__module__ = self.__module__ + "." + namespace_name
  174. opoverloadpacket = OpOverloadPacket(qualified_op_name, op_name, op, overload_names)
  175. opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
  176. # cache the opoverloadpacket to ensure that each op corresponds to
  177. # a unique OpOverloadPacket object
  178. setattr(self, op_name, opoverloadpacket)
  179. return opoverloadpacket
  180. class _Ops(types.ModuleType):
  181. __file__ = '_ops.py'
  182. def __init__(self):
  183. super(_Ops, self).__init__('torch.ops')
  184. self.loaded_libraries = set()
  185. def __getattr__(self, name):
  186. # Here we are creating `torch.ops.my_namespace`
  187. namespace = _OpNamespace(name)
  188. setattr(self, name, namespace)
  189. return namespace
  190. def load_library(self, path):
  191. """
  192. Loads a shared library from the given path into the current process.
  193. The library being loaded may run global initialization code to register
  194. custom operators with the PyTorch JIT runtime. This allows dynamically
  195. loading custom operators. For this, you should compile your operator
  196. and the static registration code into a shared library object, and then
  197. call ``torch.ops.load_library('path/to/libcustom.so')`` to load the
  198. shared object.
  199. After the library is loaded, it is added to the
  200. ``torch.ops.loaded_libraries`` attribute, a set that may be inspected
  201. for the paths of all libraries loaded using this function.
  202. Args:
  203. path (str): A path to a shared library to load.
  204. """
  205. if sys.executable == "torch_deploy":
  206. return
  207. path = torch._utils_internal.resolve_library_path(path)
  208. with dl_open_guard():
  209. # Import the shared library into the process, thus running its
  210. # static (global) initialization code in order to register custom
  211. # operators with the JIT.
  212. ctypes.CDLL(path)
  213. self.loaded_libraries.add(path)
  214. # The ops "namespace"
  215. ops = _Ops()