scope.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. ## @package scope
  2. # Module caffe2.python.scope
  3. import contextlib
  4. import threading
  5. from past.builtins import basestring
  6. from caffe2.proto import caffe2_pb2
  7. # The name scope and device scope when creating a new operator.
  8. _NAMESCOPE_SEPARATOR = '/'
  9. _threadlocal_scope = threading.local()
  10. def CurrentNameScope():
  11. global _threadlocal_scope
  12. if not hasattr(_threadlocal_scope, "namescope"):
  13. _threadlocal_scope.namescope = ''
  14. return _threadlocal_scope.namescope
  15. def CurrentDeviceScope():
  16. global _threadlocal_scope
  17. if not hasattr(_threadlocal_scope, "devicescope"):
  18. _threadlocal_scope.devicescope = None
  19. return _threadlocal_scope.devicescope
  20. @contextlib.contextmanager
  21. def NameScope(prefix, reset=False):
  22. global _threadlocal_scope
  23. assert isinstance(prefix, basestring) or prefix is None, \
  24. "NameScope takes in a string as its argument."
  25. old_scope = CurrentNameScope()
  26. prefix = prefix + _NAMESCOPE_SEPARATOR if prefix else ''
  27. if reset:
  28. _threadlocal_scope.namescope = prefix
  29. else:
  30. _threadlocal_scope.namescope = _threadlocal_scope.namescope + prefix
  31. try:
  32. yield
  33. finally:
  34. assert _threadlocal_scope.namescope.endswith(prefix), \
  35. "The namescope variable is changed from outside NameScope() calls."
  36. _threadlocal_scope.namescope = old_scope
  37. @contextlib.contextmanager
  38. def DeviceScope(scope, node_name=None):
  39. new_scope = caffe2_pb2.DeviceOption()
  40. if scope:
  41. assert isinstance(scope, caffe2_pb2.DeviceOption), \
  42. "DeviceScope takes in a caffe2_pb2.DeviceOption as its argument."
  43. new_scope.CopyFrom(scope)
  44. else:
  45. assert node_name, "At least one argument should be non-null in DeviceScope"
  46. # rewrite node_name if it is explicitly given
  47. if node_name:
  48. new_scope.node_name = node_name
  49. global _threadlocal_scope
  50. old_scope = CurrentDeviceScope()
  51. # nested scope should inherit the node_name if it is not explicitly set
  52. if old_scope and old_scope.HasField('node_name') and \
  53. not new_scope.HasField('node_name'):
  54. new_scope.node_name = old_scope.node_name
  55. # nested scope should inherit the extra_info and merged it with new extra_info
  56. if old_scope and hasattr(old_scope, 'extra_info'):
  57. new_scope.extra_info.extend(old_scope.extra_info)
  58. new_scope.extra_info.sort()
  59. _threadlocal_scope.devicescope = new_scope
  60. try:
  61. yield
  62. finally:
  63. assert _threadlocal_scope.devicescope == new_scope, \
  64. "The device scope is changed from outside DeviceScope() calls."
  65. _threadlocal_scope.devicescope = old_scope
  66. @contextlib.contextmanager
  67. def EmptyNameScope():
  68. """
  69. Allow users to 'disable' the name scope behaviour.
  70. This sets the CurrentNameScope() to None, so that the field is
  71. not set in CreateOperator(...), etc.
  72. """
  73. old_scope = CurrentNameScope()
  74. try:
  75. _threadlocal_scope.namescope = ''
  76. yield
  77. finally:
  78. _threadlocal_scope.namescope = old_scope
  79. return
  80. @contextlib.contextmanager
  81. def EmptyDeviceScope():
  82. """
  83. Allow users to 'disable' the device scope behaviour (so it can be
  84. controlled at a NetDef::DeviceOption level, not overridden at
  85. OperatorDef::DeviceOption level).
  86. This sets the CurrentDeviceScope() to None, so that the field is
  87. not set in CreateOperator(...), etc.
  88. """
  89. old_scope = CurrentDeviceScope()
  90. try:
  91. _threadlocal_scope.devicescope = None
  92. yield
  93. finally:
  94. _threadlocal_scope.devicescope = old_scope
  95. return