# arg_scope.py
  1. import contextlib
  2. import copy
  3. import threading
  4. _threadlocal_scope = threading.local()
  5. @contextlib.contextmanager
  6. def arg_scope(single_helper_or_list, **kwargs):
  7. global _threadlocal_scope
  8. if not isinstance(single_helper_or_list, list):
  9. assert callable(single_helper_or_list), \
  10. "arg_scope is only supporting single or a list of helper functions."
  11. single_helper_or_list = [single_helper_or_list]
  12. old_scope = copy.deepcopy(get_current_scope())
  13. for helper in single_helper_or_list:
  14. assert callable(helper), \
  15. "arg_scope is only supporting a list of callable helper functions."
  16. helper_key = helper.__name__
  17. if helper_key not in old_scope:
  18. _threadlocal_scope.current_scope[helper_key] = {}
  19. _threadlocal_scope.current_scope[helper_key].update(kwargs)
  20. yield
  21. _threadlocal_scope.current_scope = old_scope
  22. def get_current_scope():
  23. global _threadlocal_scope
  24. if not hasattr(_threadlocal_scope, "current_scope"):
  25. _threadlocal_scope.current_scope = {}
  26. return _threadlocal_scope.current_scope