session.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. ## @package session
  2. # Module caffe2.python.session
  3. from caffe2.python import core, workspace
  4. from caffe2.python.task import Cluster, Task, TaskGroup, WorkspaceType
  5. class CompiledRunnable(object):
  6. """ Wrapper for compiled runnable returned from session.compile() """
  7. def __init__(self, obj, session_class):
  8. self.obj = obj
  9. self.session_class = session_class
  10. class Session(object):
  11. """
  12. Allows to run Nets, ExecutionSteps, Plans, Tasks and TaskGroups.
  13. A session can potentially run in multiple nodes concurrently.
  14. Example:
  15. from core import Net
  16. from caffe2.python.task import Task, TaskGroup, WorkspaceType
  17. net = Net('test1')
  18. net.Add([net.Const(1), net.Const(2)])
  19. net2 = net.Clone()
  20. step = core.execution_step('step1', [net2])
  21. with TaskGroup(WorkspaceType.GLOBAL) as init_tg:
  22. with Node('node1'):
  23. n1setup = net.Net('n1setup')
  24. n1msg = n1setup.Const('Hello from node 1.')
  25. Task(step=n1setup)
  26. with TaskGroup() as private_tg:
  27. with Node('node1'):
  28. n1 = net.Net('n1')
  29. n1.Print(n1msg, 0)
  30. Task(step=n1)
  31. with Node('node2'):
  32. n2 = net.Net('n2')
  33. n2.Print(n2.Const('Hello from node 2.'), 0)
  34. Task(step=n2)
  35. session = LocalSession()
  36. session.run(net)
  37. session.run(step)
  38. session.run(init_tg)
  39. session.run(private_tg)
  40. Global Workspace:
  41. At the beginning of the session, a global workspace is created and kept
  42. alive for the duration of the session.
  43. Private Workspace:
  44. Tasks can be run either directly on the global workspace, or they can
  45. instantiate a private child workspace that is released after each run.
  46. Blob visibility:
  47. Tasks running in different nodes in parallel will always run under
  48. different workspaces, so it must be assumed that they won't be able to
  49. access each other's blobs. Tasks running on the same node will follow
  50. Workspace hierarchy rules: tasks running on separate private workspaces
  51. will only be able to share blobs defined on a common parent Workspace.
  52. """
  53. _compiled_cache = {}
  54. def __init__(self):
  55. self._open = True
  56. def is_open(self):
  57. return self._open
  58. @classmethod
  59. def compile(cls, runnable, workspace_type=None, setup_net_list=None):
  60. if isinstance(runnable, CompiledRunnable):
  61. assert cls == runnable.session_class, (
  62. 'Runnable was compiled for different session type. ' +
  63. 'Need: %s, got: %s' % (
  64. cls.__name__, runnable.session_class.__name__))
  65. return runnable
  66. if runnable in cls._compiled_cache:
  67. return cls._compiled_cache[runnable]
  68. if isinstance(runnable, TaskGroup):
  69. if workspace_type:
  70. if runnable.workspace_type():
  71. assert runnable.workspace_type() == workspace_type, \
  72. "Require {} but already have {}".format(
  73. workspace_type, runnable.workspace_type())
  74. else:
  75. runnable._workspace_type = workspace_type
  76. tg = runnable
  77. else:
  78. if workspace_type is None:
  79. workspace_type = WorkspaceType.GLOBAL
  80. tg = TaskGroup(workspace_type=workspace_type)
  81. if isinstance(runnable, Task):
  82. tg.add(runnable)
  83. elif isinstance(runnable, core.ExecutionStep):
  84. tg.add(Task(step=runnable))
  85. elif isinstance(runnable, core.Plan):
  86. # ExecutionSteps in Plan() object is supposed to run sequentially, while
  87. # tasks in TaskGroup run in parallel. So if we have multiple
  88. # ExecutionSteps in Plan() object, we choose to have a root
  89. # ExecutionStep to wrap all ExecutionSteps.
  90. assert len(runnable.Steps()) > 0
  91. if len(runnable.Steps()) == 1:
  92. tg.add(Task(step=runnable.Steps()[0]))
  93. else:
  94. # Task takes a list of ExecutionSteps and automatically wrap into
  95. # a root ExecutionStep
  96. tg.add(Task(step=runnable.Steps()))
  97. else:
  98. step = core.execution_step('runnable', runnable)
  99. tg.add(Task(step=step))
  100. compiled = CompiledRunnable(
  101. cls._compile_task_group(tg, setup_net_list), session_class=cls)
  102. cls._compiled_cache[runnable] = compiled
  103. return compiled
  104. def run(self, runnable, workspace_type=None, setup_net_list=None):
  105. """Run the given runnable.
  106. Args:
  107. runnable: Object recognized by the Session. Currently, we support
  108. TaskGroup, Task, Plan, ExecutionStep, and Net.
  109. workspace_type: A string defined in the WorkspaceType object.
  110. setup_net_list: A list of Net objects or a list of NetDef protos.
  111. So far this is only used by the DistributedSession, in which we
  112. need to pass a list of special nets to setup the master.
  113. """
  114. assert self.is_open(), 'Session is closed.'
  115. assert runnable is not None, 'Got a none runnable.'
  116. self._run_compiled(self.compile(runnable, workspace_type,
  117. setup_net_list).obj)
  118. def close(self):
  119. if self.is_open():
  120. self._do_close()
  121. self._open = False
  122. def fetch_output(self, output):
  123. raise NotImplementedError()
  124. def _run_compiled(self, task_group):
  125. raise NotImplementedError()
  126. @classmethod
  127. def _compile_task_group(cls, task_group, setup_net_list=None):
  128. return task_group
  129. def _do_close(self):
  130. pass
  131. def __enter__(self):
  132. assert self._open, 'Session already closed.'
  133. return self
  134. def __exit__(self, ex_type, value, traceback):
  135. if ex_type is None:
  136. self.close()
  137. class LocalSession(Session):
  138. """
  139. Session that runs in a single node.
  140. Tasks are all remapped to run in parallel in the 'local' node.
  141. Currently, LocalSession runs all parallel tasks in the same workspace,
  142. but this behavior may change in the future. Only tasks pointing to the
  143. same logical node are guaranteed to always run in the same workspace.
  144. """
  145. def __init__(self, ws=None):
  146. Session.__init__(self)
  147. self._ws = ws or workspace.C.Workspace.current
  148. @classmethod
  149. def _compile_task_group(cls, task_group, setup_net_list=None):
  150. with Cluster():
  151. task = task_group.to_task()
  152. plan = core.Plan('task_group_plan')
  153. plan.AddStep(task.get_step())
  154. return (plan, task.output_list(), task.workspace_type())
  155. def _run_compiled(self, compiled):
  156. plan, output_list, workspace_type = compiled
  157. # make sure the output blobs belong to the parent workspace
  158. outputs = []
  159. for name in output_list.names():
  160. self._ws.create_blob(str(name))
  161. outputs.append(core.BlobReference(str(name)))
  162. output_list.set_values(outputs, _fetch_func=self._fetch_output)
  163. task_ws = (
  164. workspace.C.Workspace(self._ws)
  165. if workspace_type == WorkspaceType.PRIVATE else self._ws)
  166. with workspace.WorkspaceGuard(task_ws):
  167. task_ws.run(plan)
  168. def _fetch_output(self, output):
  169. return self._ws.blobs[str(output)].fetch()