| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213 |
- ## @package session
- # Module caffe2.python.session
- from caffe2.python import core, workspace
- from caffe2.python.task import Cluster, Task, TaskGroup, WorkspaceType
class CompiledRunnable(object):
    """Wrapper for the compiled runnable returned from session.compile().

    Attributes are plain public fields:
        obj: the compiled payload, stored exactly as passed in.
        session_class: the session class the payload was compiled for.
    """

    def __init__(self, obj, session_class):
        # Both values are stored untouched; no validation happens here.
        self.session_class = session_class
        self.obj = obj
class Session(object):
    """
    Runs Nets, ExecutionSteps, Plans, Tasks and TaskGroups.
    A session can potentially run in multiple nodes concurrently.

    Example:
        from core import Net
        from caffe2.python.task import Task, TaskGroup, WorkspaceType

        net = Net('test1')
        net.Add([net.Const(1), net.Const(2)])

        net2 = net.Clone()
        step = core.execution_step('step1', [net2])

        with TaskGroup(WorkspaceType.GLOBAL) as init_tg:
            with Node('node1'):
                n1setup = net.Net('n1setup')
                n1msg = n1setup.Const('Hello from node 1.')
                Task(step=n1setup)

        with TaskGroup() as private_tg:
            with Node('node1'):
                n1 = net.Net('n1')
                n1.Print(n1msg, 0)
                Task(step=n1)
            with Node('node2'):
                n2 = net.Net('n2')
                n2.Print(n2.Const('Hello from node 2.'), 0)
                Task(step=n2)

        session = LocalSession()
        session.run(net)
        session.run(step)
        session.run(init_tg)
        session.run(private_tg)

    NOTE(review): the example above is illustrative only. `Node` is not
    imported in it, and `net.Net(...)` presumably means `core.Net(...)`;
    confirm before copy-pasting.

    Global Workspace:
        At the beginning of the session, a global workspace is created and kept
        alive for the duration of the session.

    Private Workspace:
        Tasks can be run either directly on the global workspace, or they can
        instantiate a private child workspace that is released after each run.

    Blob visibility:
        Tasks running in different nodes in parallel will always run under
        different workspaces, so it must be assumed that they won't be able to
        access each other's blobs. Tasks running on the same node will follow
        Workspace hierarchy rules: tasks running on separate private workspaces
        will only be able to share blobs defined on a common parent Workspace.
    """

    # Cache of CompiledRunnable objects. This dict lives on Session and is
    # therefore shared by every subclass that does not define its own, so
    # entries are keyed by (session class, runnable) rather than by the
    # runnable alone.
    _compiled_cache = {}

    def __init__(self):
        self._open = True

    def is_open(self):
        """Return True until close() has completed."""
        return self._open

    @classmethod
    def compile(cls, runnable, workspace_type=None, setup_net_list=None):
        """Turn `runnable` into a CompiledRunnable for this session class.

        Args:
            runnable: a CompiledRunnable, TaskGroup, Task, core.ExecutionStep,
                core.Plan, or anything accepted by core.execution_step().
            workspace_type: optional workspace type. For a TaskGroup it must
                agree with the type the group already has (if any); for
                anything else it defaults to WorkspaceType.GLOBAL.
            setup_net_list: forwarded unchanged to _compile_task_group().

        Returns:
            A CompiledRunnable whose session_class is `cls`. Results are
            cached per (cls, runnable); on a cache hit `workspace_type` and
            `setup_net_list` are not consulted again.

        Raises:
            AssertionError: if `runnable` was compiled for another session
                class, if a TaskGroup's workspace type conflicts with
                `workspace_type`, or if a Plan has no steps.
        """
        if isinstance(runnable, CompiledRunnable):
            assert cls == runnable.session_class, (
                'Runnable was compiled for different session type. ' +
                'Need: %s, got: %s' % (
                    cls.__name__, runnable.session_class.__name__))
            return runnable

        # Key by session class as well: the cache dict is shared between
        # subclasses, and an object compiled by one session class must not be
        # handed to another one's _run_compiled().
        cache_key = (cls, runnable)
        if cache_key in cls._compiled_cache:
            return cls._compiled_cache[cache_key]

        if isinstance(runnable, TaskGroup):
            if workspace_type:
                if runnable.workspace_type():
                    assert runnable.workspace_type() == workspace_type, \
                        "Require {} but already have {}".format(
                            workspace_type, runnable.workspace_type())
                else:
                    runnable._workspace_type = workspace_type
            tg = runnable
        else:
            if workspace_type is None:
                workspace_type = WorkspaceType.GLOBAL
            tg = TaskGroup(workspace_type=workspace_type)
            if isinstance(runnable, Task):
                tg.add(runnable)
            elif isinstance(runnable, core.ExecutionStep):
                tg.add(Task(step=runnable))
            elif isinstance(runnable, core.Plan):
                # ExecutionSteps in Plan() object is supposed to run
                # sequentially, while tasks in TaskGroup run in parallel. So
                # if we have multiple ExecutionSteps in Plan() object, we
                # choose to have a root ExecutionStep to wrap all
                # ExecutionSteps.
                steps = runnable.Steps()
                assert len(steps) > 0
                if len(steps) == 1:
                    tg.add(Task(step=steps[0]))
                else:
                    # Task takes a list of ExecutionSteps and automatically
                    # wrap into a root ExecutionStep
                    tg.add(Task(step=steps))
            else:
                step = core.execution_step('runnable', runnable)
                tg.add(Task(step=step))

        compiled = CompiledRunnable(
            cls._compile_task_group(tg, setup_net_list), session_class=cls)
        cls._compiled_cache[cache_key] = compiled
        return compiled

    def run(self, runnable, workspace_type=None, setup_net_list=None):
        """Run the given runnable.

        Args:
            runnable: Object recognized by the Session. Currently, we support
                TaskGroup, Task, Plan, ExecutionStep, and Net.
            workspace_type: A string defined in the WorkspaceType object.
            setup_net_list: A list of Net objects or a list of NetDef protos.
                So far this is only used by the DistributedSession, in which we
                need to pass a list of special nets to setup the master.

        Raises:
            AssertionError: if the session is closed or `runnable` is None.
        """
        assert self.is_open(), 'Session is closed.'
        assert runnable is not None, 'Got a none runnable.'
        self._run_compiled(self.compile(runnable, workspace_type,
                                        setup_net_list).obj)

    def close(self):
        """Close the session; calling it again on a closed session is a no-op."""
        if self.is_open():
            self._do_close()
            self._open = False

    def fetch_output(self, output):
        """Not implemented in the base class."""
        raise NotImplementedError()

    def _run_compiled(self, task_group):
        """Execute a compiled payload; subclasses must override."""
        raise NotImplementedError()

    @classmethod
    def _compile_task_group(cls, task_group, setup_net_list=None):
        """Default compilation step: return the task group unchanged."""
        return task_group

    def _do_close(self):
        """Hook invoked by close() while the session is still open."""
        pass

    def __enter__(self):
        assert self._open, 'Session already closed.'
        return self

    def __exit__(self, ex_type, value, traceback):
        # The session is closed only on a clean exit; when the with-body
        # raised, it is left open and the exception propagates.
        if ex_type is None:
            self.close()
class LocalSession(Session):
    """
    Session that runs in a single node.
    Tasks are all remapped to run in parallel in the 'local' node.

    Currently, LocalSession runs all parallel tasks in the same workspace,
    but this behavior may change in the future. Only tasks pointing to the
    same logical node are guaranteed to always run in the same workspace.
    """

    def __init__(self, ws=None):
        Session.__init__(self)
        # Use the supplied workspace when it is truthy; otherwise fall back to
        # workspace.C.Workspace.current.
        if ws:
            self._ws = ws
        else:
            self._ws = workspace.C.Workspace.current

    @classmethod
    def _compile_task_group(cls, task_group, setup_net_list=None):
        """Compile a task group into a (plan, output_list, workspace_type) tuple.

        `setup_net_list` is accepted for interface compatibility and is not
        used here.
        """
        with Cluster():
            merged_task = task_group.to_task()
        task_plan = core.Plan('task_group_plan')
        task_plan.AddStep(merged_task.get_step())
        return (
            task_plan,
            merged_task.output_list(),
            merged_task.workspace_type(),
        )

    def _run_compiled(self, compiled):
        """Run a tuple produced by _compile_task_group()."""
        plan, output_list, workspace_type = compiled

        # Make sure the output blobs belong to the parent workspace by
        # creating each of them on the session workspace first.
        blob_refs = []
        for name in output_list.names():
            self._ws.create_blob(str(name))
            blob_refs.append(core.BlobReference(str(name)))
        output_list.set_values(blob_refs, _fetch_func=self._fetch_output)

        # PRIVATE runs get a fresh Workspace built from the session workspace;
        # every other type runs directly on the session workspace.
        if workspace_type == WorkspaceType.PRIVATE:
            task_ws = workspace.C.Workspace(self._ws)
        else:
            task_ws = self._ws

        with workspace.WorkspaceGuard(task_ws):
            task_ws.run(plan)

    def _fetch_output(self, output):
        """Fetch the value of the blob named str(output) from the session workspace."""
        blob = self._ws.blobs[str(output)]
        return blob.fetch()
|