_pytree.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, cast, Optional
  2. from collections import namedtuple
"""
Contains utility functions for working with nested Python data structures.

A *pytree* is a Python nested data structure. It is a tree in the sense that
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
Python values. Furthermore, a pytree should not contain reference cycles.

pytrees are useful for working with nested collections of Tensors. For example,
one can use `tree_map` to map a function over all Tensors inside some nested
collection of Tensors and `tree_flatten` to get a flat list of all Tensors
inside some nested collection. pytrees are helpful for implementing nested
collection support for PyTorch APIs.

This pytree implementation is not very performant due to Python overhead.
To improve the performance we can move parts of the implementation to C++.
"""
  16. # A NodeDef holds two callables:
  17. # - flatten_fn should take the collection and return a flat list of values.
  18. # It can also return some context that is used in reconstructing the
  19. # collection.
  20. # - unflatten_fn should take a flat list of values and some context
  21. # (returned by flatten_fn). It returns the collection by reconstructing
  22. # it from the list and the context.
  23. Context = Any
  24. PyTree = Any
  25. FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
  26. UnflattenFunc = Callable[[List, Context], PyTree]
  27. class NodeDef(NamedTuple):
  28. flatten_fn: FlattenFunc
  29. unflatten_fn: UnflattenFunc
  30. SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}
  31. def _register_pytree_node(typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc) -> None:
  32. SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn)
  33. def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
  34. return list(d.values()), list(d.keys())
  35. def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
  36. return {key: value for key, value in zip(context, values)}
  37. def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
  38. return d, None
  39. def _list_unflatten(values: List[Any], context: Context) -> List[Any]:
  40. return list(values)
  41. def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]:
  42. return list(d), None
  43. def _tuple_unflatten(values: List[Any], context: Context) -> Tuple[Any, ...]:
  44. return tuple(values)
  45. def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]:
  46. return list(d), type(d)
  47. def _namedtuple_unflatten(values: List[Any], context: Context) -> NamedTuple:
  48. return cast(NamedTuple, context(*values))
  49. _register_pytree_node(dict, _dict_flatten, _dict_unflatten)
  50. _register_pytree_node(list, _list_flatten, _list_unflatten)
  51. _register_pytree_node(tuple, _tuple_flatten, _tuple_unflatten)
  52. _register_pytree_node(namedtuple, _namedtuple_flatten, _namedtuple_unflatten)
  53. # h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
  54. def _is_namedtuple_instance(pytree: Any) -> bool:
  55. typ = type(pytree)
  56. bases = typ.__bases__
  57. if len(bases) != 1 or bases[0] != tuple:
  58. return False
  59. fields = getattr(typ, '_fields', None)
  60. if not isinstance(fields, tuple):
  61. return False
  62. return all(type(entry) == str for entry in fields)
  63. def _get_node_type(pytree: Any) -> Any:
  64. if _is_namedtuple_instance(pytree):
  65. return namedtuple
  66. return type(pytree)
  67. # A leaf is defined as anything that is not a Node.
  68. def _is_leaf(pytree: PyTree) -> bool:
  69. return _get_node_type(pytree) not in SUPPORTED_NODES.keys()
  70. # A TreeSpec represents the structure of a pytree. It holds:
  71. # "type": the type of root Node of the pytree
  72. # context: some context that is useful in unflattening the pytree
  73. # children_specs: specs for each child of the root Node
  74. # num_leaves: the number of leaves
  75. class TreeSpec:
  76. def __init__(self, typ: Any, context: Context, children_specs: List['TreeSpec']) -> None:
  77. self.type = typ
  78. self.context = context
  79. self.children_specs = children_specs
  80. self.num_leaves: int = sum([spec.num_leaves for spec in children_specs])
  81. def __repr__(self) -> str:
  82. return f'TreeSpec({self.type.__name__}, {self.context}, {self.children_specs})'
  83. def __eq__(self, other: Any) -> bool:
  84. result = self.type == other.type and self.context == other.context \
  85. and self.children_specs == other.children_specs \
  86. and self.num_leaves == other.num_leaves
  87. # This should really not be necessary, but mypy errors out without it.
  88. return cast(bool, result)
  89. def __ne__(self, other: Any) -> bool:
  90. return not self.__eq__(other)
  91. class LeafSpec(TreeSpec):
  92. def __init__(self) -> None:
  93. super().__init__(None, None, [])
  94. self.num_leaves = 1
  95. def __repr__(self) -> str:
  96. return '*'
  97. def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]:
  98. """Flattens a pytree into a list of values and a TreeSpec that can be used
  99. to reconstruct the pytree.
  100. """
  101. if _is_leaf(pytree):
  102. return [pytree], LeafSpec()
  103. node_type = _get_node_type(pytree)
  104. flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
  105. child_pytrees, context = flatten_fn(pytree)
  106. # Recursively flatten the children
  107. result : List[Any] = []
  108. children_specs : List['TreeSpec'] = []
  109. for child in child_pytrees:
  110. flat, child_spec = tree_flatten(child)
  111. result += flat
  112. children_specs.append(child_spec)
  113. return result, TreeSpec(node_type, context, children_specs)
  114. def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree:
  115. """Given a list of values and a TreeSpec, builds a pytree.
  116. This is the inverse operation of `tree_flatten`.
  117. """
  118. if not isinstance(spec, TreeSpec):
  119. raise ValueError(
  120. f'tree_unflatten(values, spec): Expected `spec` to be instance of '
  121. f'TreeSpec but got item of type {type(spec)}.')
  122. if len(values) != spec.num_leaves:
  123. raise ValueError(
  124. f'tree_unflatten(values, spec): `values` has length {len(values)} '
  125. f'but the spec refers to a pytree that holds {spec.num_leaves} '
  126. f'items ({spec}).')
  127. if isinstance(spec, LeafSpec):
  128. return values[0]
  129. unflatten_fn = SUPPORTED_NODES[spec.type].unflatten_fn
  130. # Recursively unflatten the children
  131. start = 0
  132. end = 0
  133. child_pytrees = []
  134. for child_spec in spec.children_specs:
  135. end += child_spec.num_leaves
  136. child_pytrees.append(tree_unflatten(values[start:end], child_spec))
  137. start = end
  138. return unflatten_fn(child_pytrees, spec.context)
  139. def tree_map(fn: Any, pytree: PyTree) -> PyTree:
  140. flat_args, spec = tree_flatten(pytree)
  141. return tree_unflatten([fn(i) for i in flat_args], spec)
  142. # Broadcasts a pytree to the provided TreeSpec and returns the flattened
  143. # values. If this is not possible, then this function returns None.
  144. #
  145. # For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]),
  146. # would return [0, 0]. This is useful for part of the vmap implementation:
  147. # a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
  148. # broadcastable to the tree structure of `inputs` and we use
  149. # _broadcast_to_and_flatten to check this.
  150. def _broadcast_to_and_flatten(pytree: PyTree, spec: TreeSpec) -> Optional[List[Any]]:
  151. assert isinstance(spec, TreeSpec)
  152. if _is_leaf(pytree):
  153. return [pytree] * spec.num_leaves
  154. if isinstance(spec, LeafSpec):
  155. return None
  156. node_type = _get_node_type(pytree)
  157. if node_type != spec.type:
  158. return None
  159. flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
  160. child_pytrees, ctx = flatten_fn(pytree)
  161. # Check if the Node is different from the spec
  162. if len(child_pytrees) != len(spec.children_specs) or ctx != spec.context:
  163. return None
  164. # Recursively flatten the children
  165. result : List[Any] = []
  166. for child, child_spec in zip(child_pytrees, spec.children_specs):
  167. flat = _broadcast_to_and_flatten(child, child_spec)
  168. if flat is not None:
  169. result += flat
  170. else:
  171. return None
  172. return result