| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575 |
- ## @package control
- # Module caffe2.python.control
- """
- Implement functions for controlling execution of nets and steps, including
- Do
- DoParallel
- For-loop
- While-loop
- Do-While-loop
- Switch
- If
- """
- from caffe2.python import core
- from future.utils import viewitems
# Step names generated by the control functions below must be unique, so every
# name handed out is remembered in _used_step_names. _current_idx is a single
# suffix counter shared by all base names; it only advances on a collision.
_current_idx = 1
_used_step_names = set()


def _get_next_step_name(control_name, base_name):
    """Return a fresh, unique step name derived from base_name/control_name.

    The first request for a given '<base_name>/<control_name>' receives that
    name verbatim. Later requests receive '<base_name>/<control_name>_<k>',
    where k is read from the global counter; the counter is bumped on every
    collision until an unused name turns up. The chosen name is recorded so
    it is never handed out again.
    """
    global _current_idx
    stem = '%s/%s' % (base_name, control_name)
    candidate = stem
    while candidate in _used_step_names:
        candidate = '%s_%d' % (stem, _current_idx)
        _current_idx += 1
    _used_step_names.add(candidate)
    return candidate
def _MakeList(input):
    """Normalize a varargs tuple into a list.

    Args:
        input: the tuple of positional arguments a caller received.

    Returns:
        - several elements, e.g. (a, b, c): a new list [a, b, c];
        - one element that is already a list, e.g. ([a, b, c],): that same
          list object (not a copy);
        - one element that is not a list, e.g. (a,): [a].

    Raises:
        ValueError: if input is empty.
    """
    count = len(input)
    if count == 0:
        raise ValueError(
            'input cannot be empty.')
    if count > 1:
        return list(input)
    sole = input[0]
    return sole if isinstance(sole, list) else [sole]
def _IsNets(nets_or_steps):
    """Tell whether nets_or_steps consists solely of core.Net objects.

    Accepts a single object or a list. A list qualifies only if every
    element is a core.Net, which means an empty list also returns True.
    """
    if not isinstance(nets_or_steps, list):
        return isinstance(nets_or_steps, core.Net)
    return all(isinstance(item, core.Net) for item in nets_or_steps)
def _PrependNets(nets_or_steps, *nets):
    """Return a new list in which `nets` come before `nets_or_steps`.

    If nets_or_steps holds only nets, the extra nets are concatenated in
    front as they are. Otherwise the extra nets are first wrapped into a
    single Do('prepend', ...) step, and that step is placed in front.
    """
    tail = _MakeList((nets_or_steps,))
    head = _MakeList(nets)
    if not _IsNets(tail):
        head = [Do('prepend', head)]
    return head + tail
def _AppendNets(nets_or_steps, *nets):
    """Return a new list in which `nets` come after `nets_or_steps`.

    If nets_or_steps holds only nets, the extra nets are concatenated at the
    end as they are. Otherwise the extra nets are first wrapped into a
    single Do('append', ...) step, and that step is placed at the end.
    """
    head = _MakeList((nets_or_steps,))
    tail = _MakeList(nets)
    if not _IsNets(head):
        tail = [Do('append', tail)]
    return head + tail
def GetConditionBlobFromNet(condition_net):
    """
    Return the condition blob of condition_net.

    The condition blob is the last external_output of the net, and it must
    be a single bool.

    Args:
        condition_net: a net whose Proto() has at least one external_output.

    Returns:
        A core.BlobReference naming the net's last external output.

    Raises:
        AssertionError: (when assertions are enabled) if the net has no
            external output.
    """
    net_proto = condition_net.Proto()
    # The net name comes from the proto. The previous code read
    # `condition_net.Proto.name`, which raises AttributeError (Proto is a
    # method) and hid the real assertion failure.
    assert len(net_proto.external_output) > 0, (
        "Condition net %s must have at least one external output" %
        net_proto.name)
    # Use a blob reference instead of a plain string. A string would get
    # another name_scope added later, when it is used as an input to new ops
    # (such as the OR of two inputs).
    return core.BlobReference(net_proto.external_output[-1])
def BoolNet(*blobs_with_bool_value):
    """Build a net that assigns constant bool values to blobs.

    Mainly used to initialize condition blobs before the net that normally
    produces them has run. For example, in multi-task learning the
    reader_done blobs may be read before reader_net runs, so they must be
    initialized first.

    Args:
        blobs_with_bool_value: one or more (blob, bool_value) pairs, given
            either as separate arguments or as a single list. Each blob is
            filled with its bool_value.

    Returns:
        bool_net: a net, created with the name 'bool_net', whose external
        outputs are the filled blobs.

    Examples:
        - BoolNet((blob_1, bool_value_1), ..., (blob_n, bool_value_n))
        - BoolNet([(blob_1, bool_value_1), ..., (blob_n, bool_value_n)])
        - BoolNet((cond_1, bool_value_1))
    """
    pairs = _MakeList(blobs_with_bool_value)
    bool_net = core.Net('bool_net')
    for target_blob, constant in pairs:
        # Scalar (shape=[]) BOOL fill written into target_blob.
        filled = bool_net.ConstantFill(
            [], [target_blob], shape=[], value=constant,
            dtype=core.DataType.BOOL)
        bool_net.AddExternalOutput(filled)
    return bool_net
def NotNet(condition_blob_or_net):
    """Build a net that computes the logical NOT of a condition.

    The condition net (if one is given) is not copied or run here; only its
    condition blob is read.

    Args:
        condition_blob_or_net: a condition blob, or a Net whose last
            external_output is a single bool used as the condition.

    Returns:
        not_net: a new net applying Not to the condition blob.
        out_blob: the Not output, registered as not_net's external output.
    """
    condition_blob = (
        GetConditionBlobFromNet(condition_blob_or_net)
        if isinstance(condition_blob_or_net, core.Net)
        else condition_blob_or_net)
    not_net = core.Net('not_net')
    negated = not_net.Not(condition_blob)
    not_net.AddExternalOutput(negated)
    return not_net, negated
def _CopyConditionBlobNet(condition_blob):
    """Build a condition net whose output is a copy of condition_blob.

    Args:
        condition_blob: a blob holding a single bool.

    Returns:
        condition_net: a new net containing a single Copy op.
        out_blob: the copied blob, registered as condition_net's external
            output.
    """
    copy_net = core.Net('copy_condition_blob_net')
    copied = copy_net.Copy(condition_blob)
    copy_net.AddExternalOutput(copied)
    return copy_net, copied
def MergeConditionNets(name, condition_nets, relation):
    """
    Merge several condition nets into a single condition net.

    The ops and external inputs of every net in condition_nets are copied
    into a new net. Their condition blobs are folded left to right with the
    `relation` op, and the result becomes the new net's external output. The
    input nets' own external outputs are not carried over.

    Args:
        name: name of the new condition net.
        condition_nets: a list of condition nets. The last external_output
            of each must be a single bool. Every net must have the same
            device_option and type as the new net (asserted).
        relation: name of the binary op used to combine conditions, either
            'And' or 'Or'.

    Returns:
        - condition_nets unchanged if it is not a list;
        - None for an empty list, or the sole element of a one-element list;
        - otherwise a new condition net whose last external output is the
          relation of all the conditions.
    """
    if not isinstance(condition_nets, list):
        return condition_nets
    if len(condition_nets) <= 1:
        return condition_nets[0] if condition_nets else None

    merged_net = core.Net(name)
    last_cond = None
    for index, cond_net in enumerate(condition_nets):
        source_proto = cond_net.Proto()
        # Each input net must match the merged net's device option and type.
        assert source_proto.device_option == merged_net.Proto().device_option
        assert source_proto.type == merged_net.Proto().type
        merged_net.Proto().op.extend(source_proto.op)
        merged_net.Proto().external_input.extend(source_proto.external_input)
        # External outputs are dropped; only the combined condition is
        # exposed by the merged net.
        curr_cond = GetConditionBlobFromNet(cond_net)
        if index == 0:
            last_cond = curr_cond
        else:
            last_cond = merged_net.__getattr__(relation)([last_cond, curr_cond])
        # Concatenate per-key attribute values. This assumes _attr_dict
        # tolerates missing keys with `+=` (e.g. a defaultdict(list)); TODO
        # confirm against core.Net.
        for attr_name, attr_values in viewitems(cond_net._attr_dict):
            merged_net._attr_dict[attr_name] += attr_values
    merged_net.AddExternalOutput(last_cond)
    return merged_net
def CombineConditions(name, condition_nets, relation):
    """
    Combine the conditions of several nets into a single condition net.

    Unlike MergeConditionNets, the bodies of condition_nets are NOT copied
    into the new net. The new net only applies `relation` to their
    condition blobs. One example is multiple readers: each reader net has a
    reader_done condition, and this function can build a net that checks
    whether all readers are done.

    Args:
        name: name of the new condition net (not used in the single-net
            case, which builds a copy net instead).
        condition_nets: a list of condition nets. The last external_output
            of each must be a single bool.
        relation: name of the binary op used to combine conditions, either
            'And' or 'Or'.

    Returns:
        - None if condition_nets is empty or otherwise falsy;
        - for a single net, a new net that copies its condition blob;
        - otherwise a new condition net whose last external output is the
          relation of all the conditions.

    Raises:
        ValueError: if condition_nets is truthy but not a list.
    """
    if not condition_nets:
        return None
    if not isinstance(condition_nets, list):
        raise ValueError('condition_nets must be a list of nets.')
    if len(condition_nets) == 1:
        only_cond = GetConditionBlobFromNet(condition_nets[0])
        copy_net, _copied_blob = _CopyConditionBlobNet(only_cond)
        return copy_net

    combined_net = core.Net(name)
    last_cond = None
    for index, cond_net in enumerate(condition_nets):
        curr_cond = GetConditionBlobFromNet(cond_net)
        if index:
            last_cond = combined_net.__getattr__(relation)(
                [last_cond, curr_cond])
        else:
            last_cond = curr_cond
    combined_net.AddExternalOutput(last_cond)
    return combined_net
def Do(name, *nets_or_steps):
    """
    Build a step that executes the given nets or steps once, in sequence.

    A single ExecutionStep argument (alone or as a one-element list) is
    returned unchanged: no new step is created and no name is reserved.
    Anything else is wrapped in a new scoped execution step named after
    `name`.

    Examples:
        - Do('myDo', net1, net2, ..., net_n)
        - Do('myDo', list_of_nets)
        - Do('myDo', step1, step2, ..., step_n)
        - Do('myDo', list_of_steps)
    """
    items = _MakeList(nets_or_steps)
    if len(items) != 1 or not isinstance(items[0], core.ExecutionStep):
        return core.scoped_execution_step(
            _get_next_step_name('Do', name), items)
    return items[0]
def DoParallel(name, *nets_or_steps):
    """
    Build a step that runs the given nets or steps concurrently and waits
    for all of them to finish.

    A single ExecutionStep argument (alone or as a one-element list) is
    returned unchanged. Anything else is wrapped in a new scoped execution
    step with concurrent_substeps=True.

    Examples:
        - DoParallel('pDo', net1, net2, ..., net_n)
        - DoParallel('pDo', list_of_nets)
        - DoParallel('pDo', step1, step2, ..., step_n)
        - DoParallel('pDo', list_of_steps)
    """
    items = _MakeList(nets_or_steps)
    lone_step = len(items) == 1 and isinstance(items[0], core.ExecutionStep)
    if lone_step:
        return items[0]
    return core.scoped_execution_step(
        _get_next_step_name('DoParallel', name),
        items,
        concurrent_substeps=True)
def _RunOnceIf(name, condition_blob_or_net, nets_or_steps):
    """
    Build a step that executes nets_or_steps once if the condition is true.

    If condition_blob_or_net is a Net, its last external_output (a single
    bool) is the condition. That net is run ahead of nets_or_steps so the
    condition is computed first. In every case a Not net is also run ahead
    of nets_or_steps; its output is the step's stop blob (only_once=True),
    so the body is expected to be skipped when the condition is false.
    """
    not_net, stop_blob = NotNet(condition_blob_or_net)
    leading = [not_net]
    if isinstance(condition_blob_or_net, core.Net):
        leading.insert(0, condition_blob_or_net)
    body = _PrependNets(nets_or_steps, *leading)

    def make_step(control_name):
        return core.scoped_execution_step(
            _get_next_step_name(control_name, name),
            body,
            should_stop_blob=stop_blob,
            only_once=True,
        )

    if not _IsNets(body):
        return make_step('_RunOnceIf')
    # The body is a flat list of nets: reset stop_blob to False before the
    # step runs, presumably so a True left over from an earlier run cannot
    # cut the step short.
    reset_net = BoolNet((stop_blob, False))
    return Do(name + '/_RunOnceIf', reset_net, make_step('_RunOnceIf-inner'))
def _RunOnceIfNot(name, condition_blob_or_net, nets_or_steps):
    """
    Like _RunOnceIf(), but execute nets_or_steps once if the condition is
    false.

    The condition blob itself is the stop blob. If condition_blob_or_net is
    a Net, that net is run ahead of nets_or_steps. If it is a blob, a net
    copying it into a fresh blob is run ahead instead, and the copy is used
    as the stop blob.
    """
    if isinstance(condition_blob_or_net, core.Net):
        stop_blob = GetConditionBlobFromNet(condition_blob_or_net)
        leading_net = condition_blob_or_net
    else:
        leading_net, stop_blob = _CopyConditionBlobNet(condition_blob_or_net)
    body = _PrependNets(nets_or_steps, leading_net)
    return core.scoped_execution_step(
        _get_next_step_name('_RunOnceIfNot', name),
        body,
        should_stop_blob=stop_blob,
        only_once=True,
    )
def For(name, nets_or_steps, iter_num):
    """
    Execute nets_or_steps iter_num times.

    An init net, run once, creates a counter initialized to iter_num. Every
    loop iteration first runs a CountDown net on that counter, and its
    output is the loop's stop blob.

    Args:
        nets_or_steps: an ExecutionStep, a Net, or a list of ExecutionSteps
            or of nets.
        iter_num: the number of times to execute nets_or_steps.

    Returns:
        An ExecutionStep instance.
    """
    init_net = core.Net('init-net')
    counter = init_net.CreateCounter([], init_count=iter_num)
    iter_net = core.Net('For-iter')
    done_blob = iter_net.CountDown([counter])
    # Reserve the inner step name before building its body, matching the
    # order in which step names are handed out.
    inner_name = _get_next_step_name('For-inner', name)
    body = _PrependNets(nets_or_steps, iter_net)
    loop_step = core.scoped_execution_step(
        inner_name, body, should_stop_blob=done_blob)
    init_step = Do(name + '/For-init-net', init_net)
    return Do(name + '/For', init_step, loop_step)
def While(name, condition_blob_or_net, nets_or_steps):
    """
    Repeat nets_or_steps while the condition is true.

    Args:
        condition_blob_or_net: a condition blob, or a Net whose last
            external_output is a single bool. A Net is placed at the front
            of the loop body, so it is re-run on every iteration.
        nets_or_steps: an ExecutionStep, a Net, or a list of ExecutionSteps
            or of nets.

    Returns:
        An ExecutionStep instance.
    """
    not_net, stop_blob = NotNet(condition_blob_or_net)
    if isinstance(condition_blob_or_net, core.Net):
        leading = (condition_blob_or_net, not_net)
    else:
        leading = (not_net,)
    body = _PrependNets(nets_or_steps, *leading)

    def make_loop(control_name):
        return core.scoped_execution_step(
            _get_next_step_name(control_name, name),
            body,
            should_stop_blob=stop_blob,
        )

    if not _IsNets(body):
        return make_loop('While')
    # Here the loop's sub-nets are the condition net (if any), then not_net,
    # then the body nets. If stop_blob is already True when the loop starts
    # (this may happen when While() is called twice), the loop would exit
    # right after evaluating the condition. Use BoolNet to reset stop_blob
    # to False first.
    reset_net = BoolNet((stop_blob, False))
    return Do(name + '/While', reset_net, make_loop('While-inner'))
def Until(name, condition_blob_or_net, nets_or_steps):
    """
    Like While(), but repeat nets_or_steps while the condition is false.

    The condition itself serves as the stop blob. If condition_blob_or_net
    is a Net, it is placed ahead of nets_or_steps in every iteration and its
    last external_output is the condition. Otherwise the given blob (by
    name) is used directly.
    """
    if not isinstance(condition_blob_or_net, core.Net):
        stop_blob = core.BlobReference(str(condition_blob_or_net))
        body = nets_or_steps
    else:
        stop_blob = GetConditionBlobFromNet(condition_blob_or_net)
        body = _PrependNets(nets_or_steps, condition_blob_or_net)
    return core.scoped_execution_step(
        _get_next_step_name('Until', name),
        body,
        should_stop_blob=stop_blob)
def DoWhile(name, condition_blob_or_net, nets_or_steps):
    """
    Repeat nets_or_steps while the condition is true, running nets_or_steps
    before the condition is evaluated on each iteration.

    Args:
        condition_blob_or_net: a condition blob, or a Net whose last
            external_output is a single bool.
        nets_or_steps: an ExecutionStep, a Net, or a list of ExecutionSteps
            or of nets.

    Returns:
        An ExecutionStep instance.
    """
    not_net, stop_blob = NotNet(condition_blob_or_net)
    if isinstance(condition_blob_or_net, core.Net):
        trailing = (condition_blob_or_net, not_net)
    else:
        trailing = (not_net,)
    body = _AppendNets(nets_or_steps, *trailing)
    # If stop_blob is already True (this may happen when DoWhile() is called
    # twice), the loop would exit after the first net/step of nets_or_steps.
    # Use BoolNet to reset stop_blob to False before the loop starts.
    reset_net = BoolNet((stop_blob, False))
    inner_step = core.scoped_execution_step(
        _get_next_step_name('DoWhile-inner', name),
        body,
        should_stop_blob=stop_blob,
    )
    return Do(name + '/DoWhile', reset_net, inner_step)
def DoUntil(name, condition_blob_or_net, nets_or_steps):
    """
    Like DoWhile(), but repeat nets_or_steps while the condition is false,
    running nets_or_steps before the condition is evaluated.

    Special case: if condition_blob_or_net is a blob that is already true,
    only the first net/step of nets_or_steps runs before the loop exits. No
    reset is applied in the blob case, so be careful about the blob's
    initial value, especially when DoUntil() is called twice.
    """
    if isinstance(condition_blob_or_net, core.Net):
        body = _AppendNets(nets_or_steps, condition_blob_or_net)
        stop_blob = GetConditionBlobFromNet(condition_blob_or_net)
        # If stop_blob is already True (this may happen when DoUntil() is
        # called twice), the loop would exit after the first net/step of
        # nets_or_steps. Use BoolNet to reset stop_blob to False first.
        reset_net = BoolNet((stop_blob, False))
        inner_step = core.scoped_execution_step(
            _get_next_step_name('DoUntil-inner', name),
            body,
            should_stop_blob=stop_blob,
        )
        return Do(name + '/DoUntil', reset_net, inner_step)

    stop_blob = core.BlobReference(condition_blob_or_net)
    return core.scoped_execution_step(
        _get_next_step_name('DoUntil', name),
        nets_or_steps,
        should_stop_blob=stop_blob)
def Switch(name, *conditions):
    """
    Execute the steps whose condition is true.

    Each condition is a tuple (condition_blob_or_net, nets_or_steps).

    Note:
        1. Several steps can execute if several conditions are true.
        2. Every condition_blob_or_net that is a Net is executed once.

    Examples:
        - Switch('name', (cond_1, net_1), (cond_2, net_2), ..., (cond_n, net_n))
        - Switch('name', [(cond_1, net_1), (cond_2, net_2), ..., (cond_n, net_n)])
        - Switch('name', (cond_1, net_1))
    """
    pairs = _MakeList(conditions)
    # Reserve the Switch step name before the branch steps are named.
    switch_name = _get_next_step_name('Switch', name)
    branches = [
        _RunOnceIf(name + '/Switch', cond, step) for cond, step in pairs]
    return core.scoped_execution_step(switch_name, branches)
def SwitchNot(name, *conditions):
    """
    Like Switch(), but execute the steps whose condition is False.

    Each condition is a tuple (condition_blob_or_net, nets_or_steps).
    """
    pairs = _MakeList(conditions)
    # Reserve the SwitchNot step name before the branch steps are named.
    switch_name = _get_next_step_name('SwitchNot', name)
    branches = [
        _RunOnceIfNot(name + '/SwitchNot', cond, step) for cond, step in pairs]
    return core.scoped_execution_step(switch_name, branches)
def If(name, condition_blob_or_net,
       true_nets_or_steps, false_nets_or_steps=None):
    """
    Evaluate the condition, then run true_nets_or_steps if it is true and
    false_nets_or_steps otherwise.

    If condition_blob_or_net is a Net, its last external_output (a single
    bool) is the condition, and the Net is executed before both branches so
    that the condition is computed. A falsy false_nets_or_steps (None or an
    empty list) means there is no else-branch.

    NOTE: the else-branch is a separate step that re-reads the condition
    blob after the true branch has run. If the true branch writes that blob,
    both branches can execute.
    """
    if not false_nets_or_steps:
        return _RunOnceIf(name + '/If',
                          condition_blob_or_net, true_nets_or_steps)
    condition_blob = (
        GetConditionBlobFromNet(condition_blob_or_net)
        if isinstance(condition_blob_or_net, core.Net)
        else condition_blob_or_net)
    true_branch = _RunOnceIf(
        name + '/If-true', condition_blob_or_net, true_nets_or_steps)
    false_branch = _RunOnceIfNot(
        name + '/If-false', condition_blob, false_nets_or_steps)
    return Do(name + '/If', true_branch, false_branch)
def IfNot(name, condition_blob_or_net,
          true_nets_or_steps, false_nets_or_steps=None):
    """
    Run true_nets_or_steps if the condition is false, and
    false_nets_or_steps otherwise.

    If condition_blob_or_net is a Net, its last external_output (a single
    bool) is the condition, and the Net is executed before both branches. A
    falsy false_nets_or_steps (None or an empty list) means there is no
    else-branch.

    NOTE: the else-branch is a separate step that re-reads the condition
    blob after the first branch has run. If that branch writes the blob,
    both branches can execute.
    """
    if not false_nets_or_steps:
        return _RunOnceIfNot(name + '/IfNot',
                             condition_blob_or_net, true_nets_or_steps)
    condition_blob = (
        GetConditionBlobFromNet(condition_blob_or_net)
        if isinstance(condition_blob_or_net, core.Net)
        else condition_blob_or_net)
    true_branch = _RunOnceIfNot(
        name + '/IfNot-true', condition_blob_or_net, true_nets_or_steps)
    false_branch = _RunOnceIf(
        name + '/IfNot-false', condition_blob, false_nets_or_steps)
    return Do(name + '/IfNot', true_branch, false_branch)
|