| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122 |
- ## @package scope
- # Module caffe2.python.scope
- import contextlib
- import threading
- from past.builtins import basestring
- from caffe2.proto import caffe2_pb2
# The name scope and device scope when creating a new operator.
# Separator appended after every prefix that NameScope() adds to the scope.
_NAMESCOPE_SEPARATOR = '/'
# Thread-local holder for the `namescope` and `devicescope` attributes; the
# attributes are created lazily by CurrentNameScope()/CurrentDeviceScope().
_threadlocal_scope = threading.local()
def CurrentNameScope():
    """Return the name scope string stored for the calling thread.

    The thread-local `namescope` attribute is created as the empty string
    the first time it is requested on a thread.
    """
    try:
        return _threadlocal_scope.namescope
    except AttributeError:
        # First access on this thread: start from an empty scope.
        _threadlocal_scope.namescope = ''
        return _threadlocal_scope.namescope
def CurrentDeviceScope():
    """Return the device scope stored for the calling thread.

    The thread-local `devicescope` attribute is created as None the first
    time it is requested on a thread.
    """
    if hasattr(_threadlocal_scope, "devicescope"):
        return _threadlocal_scope.devicescope
    # First access on this thread: no device scope is active yet.
    _threadlocal_scope.devicescope = None
    return _threadlocal_scope.devicescope
@contextlib.contextmanager
def NameScope(prefix, reset=False):
    """Context manager that extends (or replaces) the current name scope.

    Args:
        prefix: string to add to the scope; the '/' separator is appended to
            it. A falsy value (None or '') contributes an empty string.
        reset: if True the scope is replaced by the new prefix; otherwise
            the prefix is appended to the scope that was active on entry.

    Raises:
        AssertionError: if `prefix` is neither a string nor None, or if on
            exit the name scope no longer ends with the prefix that was
            pushed (i.e. it was modified outside NameScope()).

    The scope that was active on entry is always restored on exit, including
    when the consistency check fails.
    """
    assert isinstance(prefix, basestring) or prefix is None, \
        "NameScope takes in a string as its argument."
    old_scope = CurrentNameScope()
    prefix = prefix + _NAMESCOPE_SEPARATOR if prefix else ''
    if reset:
        _threadlocal_scope.namescope = prefix
    else:
        _threadlocal_scope.namescope = old_scope + prefix

    try:
        yield
    finally:
        # Restore before checking, so a failed check cannot leave this
        # thread stuck on a tampered scope.
        scope_at_exit = _threadlocal_scope.namescope
        _threadlocal_scope.namescope = old_scope
        assert scope_at_exit.endswith(prefix), \
            "The namescope variable is changed from outside NameScope() calls."
@contextlib.contextmanager
def DeviceScope(scope, node_name=None):
    """Context manager that installs a new device scope for this thread.

    Args:
        scope: a caffe2_pb2.DeviceOption that is copied into the new scope.
            May be falsy, in which case `node_name` must be given.
        node_name: optional node name; when truthy it overrides the
            `node_name` field of the new scope.

    Raises:
        AssertionError: if a truthy `scope` is not a DeviceOption, if both
            arguments are falsy, or if on exit the active device scope is no
            longer equal to the one installed here.

    When an enclosing device scope is active, the new scope takes over its
    `node_name` (unless the new scope already has one) and is extended with
    its `extra_info` entries; `extra_info` is then sorted. The scope that was
    active on entry is always restored on exit, including when the
    consistency check fails.
    """
    new_scope = caffe2_pb2.DeviceOption()
    if scope:
        assert isinstance(scope, caffe2_pb2.DeviceOption), \
            "DeviceScope takes in a caffe2_pb2.DeviceOption as its argument."
        new_scope.CopyFrom(scope)
    else:
        assert node_name, "At least one argument should be non-null in DeviceScope"

    # rewrite node_name if it is explicitly given
    if node_name:
        new_scope.node_name = node_name

    old_scope = CurrentDeviceScope()
    # nested scope should inherit the node_name if it is not explicitly set
    if old_scope and old_scope.HasField('node_name') and \
            not new_scope.HasField('node_name'):
        new_scope.node_name = old_scope.node_name

    # nested scope should inherit the extra_info and merge it with the new
    # extra_info
    if old_scope and hasattr(old_scope, 'extra_info'):
        new_scope.extra_info.extend(old_scope.extra_info)
    new_scope.extra_info.sort()

    _threadlocal_scope.devicescope = new_scope
    try:
        yield
    finally:
        # Restore before checking, so a failed check cannot leave this
        # thread stuck on a tampered scope.
        scope_at_exit = _threadlocal_scope.devicescope
        _threadlocal_scope.devicescope = old_scope
        assert scope_at_exit == new_scope, \
            "The device scope is changed from outside DeviceScope() calls."
@contextlib.contextmanager
def EmptyNameScope():
    """
    Allow users to 'disable' the name scope behaviour.

    While the context is active the thread's name scope is the empty
    string; the scope that was active on entry is put back on exit, even
    if the body raises.
    """
    saved_scope = CurrentNameScope()
    _threadlocal_scope.namescope = ''
    try:
        yield
    finally:
        _threadlocal_scope.namescope = saved_scope
@contextlib.contextmanager
def EmptyDeviceScope():
    """
    Allow users to 'disable' the device scope behaviour.

    While the context is active the thread's device scope is None; the
    scope that was active on entry is put back on exit, even if the body
    raises.
    """
    saved_scope = CurrentDeviceScope()
    _threadlocal_scope.devicescope = None
    try:
        yield
    finally:
        _threadlocal_scope.devicescope = saved_scope
|