| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- ## @package context
- # Module caffe2.python.context
- import inspect
- import threading
- import functools
- class _ContextInfo(object):
- def __init__(self, cls, allow_default):
- self.cls = cls
- self.allow_default = allow_default
- self._local_stack = threading.local()
- @property
- def _stack(self):
- if not hasattr(self._local_stack, 'obj'):
- self._local_stack.obj = []
- return self._local_stack.obj
- def enter(self, value):
- self._stack.append(value)
- def exit(self, value):
- assert len(self._stack) > 0, 'Context %s is empty.' % self.cls
- assert self._stack.pop() == value
- def get_active(self, required=True):
- if len(self._stack) == 0:
- if not required:
- return None
- assert self.allow_default, (
- 'Context %s is required but none is active.' % self.cls)
- self.enter(self.cls())
- return self._stack[-1]
- class _ContextRegistry(object):
- def __init__(self):
- self._ctxs = {}
- def get(self, cls):
- if cls not in self._ctxs:
- assert issubclass(cls, Managed), "must be a context managed class, got {}".format(cls)
- self._ctxs[cls] = _ContextInfo(cls, allow_default=issubclass(cls, DefaultManaged))
- return self._ctxs[cls]
- _CONTEXT_REGISTRY = _ContextRegistry()
- def _context_registry():
- global _CONTEXT_REGISTRY
- return _CONTEXT_REGISTRY
- def _get_managed_classes(obj):
- return [
- cls for cls in inspect.getmro(obj.__class__)
- if issubclass(cls, Managed) and cls != Managed and cls != DefaultManaged
- ]
- class Managed(object):
- """
- Managed makes the inheritted class a context managed class.
- class Foo(Managed): ...
- with Foo() as f:
- assert f == Foo.current()
- """
- @classmethod
- def current(cls, value=None, required=True):
- ctx_info = _context_registry().get(cls)
- if value is not None:
- assert isinstance(value, cls), (
- 'Wrong context type. Expected: %s, got %s.' % (cls, type(value)))
- return value
- return ctx_info.get_active(required=required)
- def __enter__(self):
- for cls in _get_managed_classes(self):
- _context_registry().get(cls).enter(self)
- return self
- def __exit__(self, *args):
- for cls in _get_managed_classes(self):
- _context_registry().get(cls).exit(self)
- def __call__(self, func):
- @functools.wraps(func)
- def wrapper(*args, **kwargs):
- with self:
- return func(*args, **kwargs)
- return wrapper
- class DefaultManaged(Managed):
- """
- DefaultManaged is similar to Managed but if there is no parent when
- current() is called it makes a new one.
- """
- pass
|