local.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import threading
  2. from contextlib import contextmanager
  3. from typing import Optional, Iterator
# Simple dynamic scoping implementation. The name "parametrize" comes
# from Racket (where the form is spelled `parameterize`).
#
# WARNING WARNING: LOOKING TO EDIT THIS FILE? Think carefully about
# why you need to add a toggle to the global behavior of code
# generation. The parameters here should really only be used
# for "temporary" situations, where we need to temporarily change
# the codegen in some cases because we cannot conveniently update
# all call sites. Each parameter is slated to be removed once all of
# its call sites have been updated. If you don't have a plan for how
# to get there, DON'T add a new entry here.
  15. class Locals(threading.local):
  16. use_const_ref_for_mutable_tensors: Optional[bool] = None
  17. _locals = Locals()
  18. def use_const_ref_for_mutable_tensors() -> bool:
  19. assert _locals.use_const_ref_for_mutable_tensors is not None, (
  20. "need to initialize local.use_const_ref_for_mutable_tensors with "
  21. "local.parametrize"
  22. )
  23. return _locals.use_const_ref_for_mutable_tensors
  24. @contextmanager
  25. def parametrize(*, use_const_ref_for_mutable_tensors: bool) -> Iterator[None]:
  26. old_use_const_ref_for_mutable_tensors = _locals.use_const_ref_for_mutable_tensors
  27. try:
  28. _locals.use_const_ref_for_mutable_tensors = use_const_ref_for_mutable_tensors
  29. yield
  30. finally:
  31. _locals.use_const_ref_for_mutable_tensors = (
  32. old_use_const_ref_for_mutable_tensors
  33. )