library.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. from ._ops import OpOverload
  2. from typing import Set
  3. import traceback
  4. import torch
  5. __all__ = ['Library', 'impl', 'define']
  6. # Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
  7. # The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
  8. # This set is maintained to ensure that two libraries don't try to override the exact same functionality to avoid
  9. # libraries calling into kernels not intended to be called.
  10. _impls: Set[str] = set()
  11. class Library:
  12. """
  13. A class to create libraries that can be used to register new operators or
  14. override operators in existing libraries from Python.
  15. A user can optionally pass in a dispatch keyname if they only want to register
  16. kernels corresponding to only one specific dispatch key.
  17. Args:
  18. ns: library name
  19. kind: "DEF", "IMPL" (default: "IMPL")
  20. dispatch_key: PyTorch dispatch key (default: "")
  21. """
  22. def __init__(self, ns, kind, dispatch_key=""):
  23. if kind != "IMPL" and kind != "DEF":
  24. raise ValueError("Unsupported kind: ", kind)
  25. frame = traceback.extract_stack(limit=3)[0]
  26. filename, lineno = frame.filename, frame.lineno
  27. self.m = torch._C._dispatch_library(kind, ns, dispatch_key, filename, lineno)
  28. self.ns = ns
  29. self._op_impls = set()
  30. self.kind = kind
  31. self.dispatch_key = dispatch_key
  32. def __repr__(self):
  33. return "Library(kind={}, ns={}, dispatch_key={})>".format(self.kind, self.ns, self.dispatch_key)
  34. def impl(self, op_name, fn, dispatch_key=''):
  35. if dispatch_key == '':
  36. dispatch_key = self.dispatch_key
  37. if isinstance(op_name, str):
  38. name = op_name
  39. elif isinstance(op_name, OpOverload):
  40. name = op_name._schema.name
  41. overload_name = op_name._schema.overload_name
  42. if overload_name != '':
  43. name = name + '.' + overload_name
  44. else:
  45. raise RuntimeError("impl should be passed either a name or an OpOverload object as the first argument")
  46. key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
  47. if key in _impls:
  48. # TODO: in future, add more info about where the existing function is registered (this info is
  49. # today already returned by the C++ warning when impl is called but we error out before that)
  50. raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
  51. "'s behavior for {} dispatch key and {} namespace.".
  52. format(name.split("::")[-1], dispatch_key, self.ns))
  53. self.m.impl(name, dispatch_key, fn)
  54. _impls.add(key)
  55. self._op_impls.add(key)
  56. def define(self, schema, alias_analysis=""):
  57. '''
  58. Takes a schema to define a new operator.
  59. Also, optionally takes `alias_analysis` argument to indicate if the aliasing properties of the arguments
  60. can be inferred from the schema (default behavior) or not ("CONSERVATIVE").
  61. Returns the name of the operator as inferred from the schema.
  62. '''
  63. # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
  64. # AliasAnalysis type in C++
  65. if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
  66. raise RuntimeError("Invalid alias_analysis type")
  67. return self.m.define(schema, alias_analysis)
  68. def __del__(self):
  69. for key in self._op_impls:
  70. _impls.remove(key)
  71. del self.m
  72. # decorator to register python functions for library ops
  73. # Note: this decorator API should remain consistent with `Library.impl` API
  74. def impl(lib, name, dispatch_key=""):
  75. def wrap(f):
  76. lib.impl(name, f, dispatch_key)
  77. return wrap
  78. def define(lib, schema, alias_analysis=""):
  79. def wrap(f):
  80. name = lib.define(schema, alias_analysis)
  81. lib.impl(name, f)
  82. return wrap