| 123456789101112131415161718192021222324252627282930313233343536373839404142 |
- import threading
- from contextlib import contextmanager
- from typing import Optional, Iterator
- # Simple dynamic scoping implementation. The name "parametrize" comes
- # from Racket.
- #
- # WARNING WARNING: LOOKING TO EDIT THIS FILE? Think carefully about
- # why you need to add a toggle to the global behavior of code
- # generation. The parameters here should really only be used
- # for "temporary" situations, where we need to temporarily change
- # the codegen in some cases because we cannot conveniently update
- # all call sites, and are slated to be eliminated once all call
- # sites have been updated. If you don't have a plan for how to get there,
- # DON'T add a new entry here.
class Locals(threading.local):
    """Per-thread storage for the dynamically scoped parameters of this module.

    Subclassing ``threading.local`` gives every thread its own copy of the
    attributes below, each starting from the class-level default.
    """

    # ``None`` means "not configured yet"; a concrete bool is installed by
    # ``parametrize`` for the duration of its ``with`` block.
    use_const_ref_for_mutable_tensors: Optional[bool] = None


# Single module-wide instance consulted by the accessor and by ``parametrize``.
_locals = Locals()
def use_const_ref_for_mutable_tensors() -> bool:
    """Return the current thread's ``use_const_ref_for_mutable_tensors`` setting.

    Returns:
        The boolean currently stored on ``_locals``.

    Raises:
        AssertionError: if the setting is still ``None``, i.e. it has not
            been initialized (the message points at ``local.parametrize``).
    """
    value = _locals.use_const_ref_for_mutable_tensors
    if value is None:
        # Raised explicitly rather than via an ``assert`` statement so the
        # check is not stripped under ``python -O``; without it the function
        # would quietly return None despite the ``-> bool`` annotation.
        raise AssertionError(
            "need to initialize local.use_const_ref_for_mutable_tensors with "
            "local.parametrize"
        )
    return value
@contextmanager
def parametrize(*, use_const_ref_for_mutable_tensors: bool) -> Iterator[None]:
    """Temporarily set ``use_const_ref_for_mutable_tensors`` on ``_locals``.

    The value that was stored before entering the ``with`` block is put back
    on exit, including when the body raises, so nested uses unwind in order.

    Args:
        use_const_ref_for_mutable_tensors: value to install while the
            ``with`` block is running (keyword-only).

    Yields:
        None. The context manager is used purely for its side effect.
    """
    # Remember whatever was configured before (possibly None) for restoration.
    saved = _locals.use_const_ref_for_mutable_tensors
    _locals.use_const_ref_for_mutable_tensors = use_const_ref_for_mutable_tensors
    try:
        yield
    finally:
        _locals.use_const_ref_for_mutable_tensors = saved
|