context.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. ## @package context
  2. # Module caffe2.python.context
  3. import inspect
  4. import threading
  5. import functools
  6. class _ContextInfo(object):
  7. def __init__(self, cls, allow_default):
  8. self.cls = cls
  9. self.allow_default = allow_default
  10. self._local_stack = threading.local()
  11. @property
  12. def _stack(self):
  13. if not hasattr(self._local_stack, 'obj'):
  14. self._local_stack.obj = []
  15. return self._local_stack.obj
  16. def enter(self, value):
  17. self._stack.append(value)
  18. def exit(self, value):
  19. assert len(self._stack) > 0, 'Context %s is empty.' % self.cls
  20. assert self._stack.pop() == value
  21. def get_active(self, required=True):
  22. if len(self._stack) == 0:
  23. if not required:
  24. return None
  25. assert self.allow_default, (
  26. 'Context %s is required but none is active.' % self.cls)
  27. self.enter(self.cls())
  28. return self._stack[-1]
  29. class _ContextRegistry(object):
  30. def __init__(self):
  31. self._ctxs = {}
  32. def get(self, cls):
  33. if cls not in self._ctxs:
  34. assert issubclass(cls, Managed), "must be a context managed class, got {}".format(cls)
  35. self._ctxs[cls] = _ContextInfo(cls, allow_default=issubclass(cls, DefaultManaged))
  36. return self._ctxs[cls]
  37. _CONTEXT_REGISTRY = _ContextRegistry()
  38. def _context_registry():
  39. global _CONTEXT_REGISTRY
  40. return _CONTEXT_REGISTRY
  41. def _get_managed_classes(obj):
  42. return [
  43. cls for cls in inspect.getmro(obj.__class__)
  44. if issubclass(cls, Managed) and cls != Managed and cls != DefaultManaged
  45. ]
  46. class Managed(object):
  47. """
  48. Managed makes the inheritted class a context managed class.
  49. class Foo(Managed): ...
  50. with Foo() as f:
  51. assert f == Foo.current()
  52. """
  53. @classmethod
  54. def current(cls, value=None, required=True):
  55. ctx_info = _context_registry().get(cls)
  56. if value is not None:
  57. assert isinstance(value, cls), (
  58. 'Wrong context type. Expected: %s, got %s.' % (cls, type(value)))
  59. return value
  60. return ctx_info.get_active(required=required)
  61. def __enter__(self):
  62. for cls in _get_managed_classes(self):
  63. _context_registry().get(cls).enter(self)
  64. return self
  65. def __exit__(self, *args):
  66. for cls in _get_managed_classes(self):
  67. _context_registry().get(cls).exit(self)
  68. def __call__(self, func):
  69. @functools.wraps(func)
  70. def wrapper(*args, **kwargs):
  71. with self:
  72. return func(*args, **kwargs)
  73. return wrapper
  74. class DefaultManaged(Managed):
  75. """
  76. DefaultManaged is similar to Managed but if there is no parent when
  77. current() is called it makes a new one.
  78. """
  79. pass