| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333 |
- ## @package recurrent
- # Module caffe2.python.recurrent
- from caffe2.python import core, workspace
- from future.utils import viewitems, viewkeys
def recurrent_net(
    net, cell_net, inputs, initial_cell_inputs,
    links, timestep=None, scope=None, outputs_with_grads=(0,),
    recompute_blobs_on_backward=None, forward_only=False,
):
    '''
    Adds a RecurrentNetwork operator to `net` which executes `cell_net` as the
    per-timestep step net. Unless forward_only is set, a backward step net is
    also derived from cell_net and passed to the operator.

    net: the main net operator should be added to

    cell_net: cell_net which is executed in a recurrent fashion

    inputs: sequences to be fed into the recurrent net. Currently only one input
    is supported. It has to be in a format T x N x (D1...Dk) where T is lengths
    of the sequence. N is a batch size and (D1...Dk) are the rest of dimensions.
    Each entry is unpacked as a pair
        (cell_net_input_name, external_blob_with_data)

    initial_cell_inputs: inputs of the cell_net for the 0 timestamp.
    Format for each input is:
        (cell_net_input_name, external_blob_with_data)

    links: a dictionary from cell_net input names in moment t+1 and
    output names of moment t. Currently we assume that each output becomes
    an input for the next timestep.

    timestep: name of the timestep blob to be used. If not provided "timestep"
    is used.

    scope: Internal blobs are going to be scoped in a format
    <scope_name>/<blob_name>
    If not provided we generate a scope name automatically

    outputs_with_grads : position indices of output blobs which will receive
    error gradient (from outside recurrent network) during backpropagation

    recompute_blobs_on_backward: specify a list of blobs that will be
                 recomputed for backward pass, and thus need not to be
                 stored for each forward timestep.

    forward_only: if True, only forward steps are executed

    Returns all outputs of the RecurrentNetwork operator except the last one
    (the step workspaces blob). For every initial cell input, in order, the
    requested outputs are "<cell_output>_all" and "<cell_output>_last".

    Side effects visible here: cell_net.Proto().type is set to 'simple', and
    an AssertionError is raised when len(inputs) != 1.
    '''
    assert len(inputs) == 1, "Only one input blob is supported so far"

    input_blobs = [str(i[0]) for i in inputs]
    initial_input_blobs = [str(x[1]) for x in initial_cell_inputs]
    op_name = net.NextName('recurrent')

    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        scope_name = op_name if scope is None else scope
        return "{}/{}".format(str(scope_name), str(name))

    # determine inputs that are considered to be references
    # it is those that are not referred to in inputs or initial_cell_inputs
    known_inputs = [str(b) for b in input_blobs + initial_input_blobs]
    known_inputs += [str(x[0]) for x in initial_cell_inputs]
    if timestep is not None:
        known_inputs.append(str(timestep))
    references = [
        core.BlobReference(b) for b in cell_net.Proto().external_input
        if b not in known_inputs]

    inner_outputs = list(cell_net.Proto().external_output)
    # These gradients are expected to be available during the backward pass
    inner_outputs_map = {o: o + '_grad' for o in inner_outputs}

    # compute the backward pass of the cell net
    if not forward_only:
        backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
            cell_net.Proto().op, inner_outputs_map)
        # Normalize keys to plain strings so later lookups by str(name) work.
        backward_mapping = {str(k): v for k, v in viewitems(backward_mapping)}

        backward_cell_net = core.Net("RecurrentBackwardStep")
        # Start from an empty op list before adding recompute/backward ops.
        del backward_cell_net.Proto().op[:]

        if recompute_blobs_on_backward is not None:
            # Insert operators to re-compute the specified blobs.
            # They are added in the same order as for the forward pass, thus
            # the order is correct.
            recompute_blobs_on_backward = {str(b) for b in
                                           recompute_blobs_on_backward}

            for op in cell_net.Proto().op:
                if not recompute_blobs_on_backward.isdisjoint(set(op.output)):
                    backward_cell_net.Proto().op.extend([op])
                    # This fires if other outputs than the declared
                    # are computed by the ops that are recomputed
                    assert set(op.output).issubset(recompute_blobs_on_backward)

        backward_cell_net.Proto().op.extend(backward_ops)
        # compute blobs used but not defined in the backward pass
        backward_ssa, backward_blob_versions = core.get_ssa(
            backward_cell_net.Proto())
        undefined = core.get_undefined_blobs(backward_ssa)

        # also add to the output list the intermediate outputs of fwd_step that
        # are used by backward.
        # NOTE(review): `ssa` is not read below; only blob_versions is used.
        ssa, blob_versions = core.get_ssa(cell_net.Proto())
        scratches = [
            blob
            for blob, ver in viewitems(blob_versions)
            if (ver > 0 and
                blob in undefined and
                blob not in cell_net.Proto().external_output)
        ]
        backward_cell_net.Proto().external_input.extend(scratches)
        backward_cell_net.Proto().type = 'simple'
    else:
        # forward_only: no backward net; later sections check for None.
        backward_cell_net = None

    # Operator inputs, in order: sequence inputs, initial state blobs, then
    # the remaining external inputs of cell_net (references).
    all_inputs = [i[1] for i in inputs] + [
        x[1] for x in initial_cell_inputs] + references
    all_outputs = []

    cell_net.Proto().type = 'simple'

    # Internal arguments used by RecurrentNetwork operator

    # Links are in the format blob_name, recurrent_states, offset.
    # In the moment t we know that corresponding data block is at
    # t + offset position in the recurrent_states tensor
    forward_links = []
    backward_links = []

    # Aliases are used to expose outputs to external world
    # Format (internal_blob, external_blob, offset)
    # Negative offset stands for going from the end,
    # positive - from the beginning
    aliases = []

    # States held inputs to the cell net
    recurrent_states = []

    for cell_input, _ in initial_cell_inputs:
        cell_input = str(cell_input)
        # Recurrent_states is going to be (T + 1) x ...
        # It stores all inputs and outputs of the cell net over time.
        # Or their gradients in the case of the backward pass.
        state = s(cell_input + "_states")
        states_grad = state + "_grad"
        cell_output = links[str(cell_input)]
        # The input at moment t and the output at moment t share one states
        # blob, one position apart.
        forward_links.append((cell_input, state, 0))
        forward_links.append((cell_output, state, 1))

        aliases.append((state, cell_output + "_all", 1))
        aliases.append((state, cell_output + "_last", -1))
        all_outputs.extend([cell_output + "_all", cell_output + "_last"])

        recurrent_states.append(state)

        if backward_cell_net is not None:
            backward_links.append((cell_output + "_grad", states_grad, 1))
            backward_cell_net.Proto().external_input.append(
                str(cell_output) + "_grad")

            recurrent_input_grad = cell_input + "_grad"
            if not backward_blob_versions.get(recurrent_input_grad, 0):
                # If nobody writes to this recurrent input gradient, we need
                # to make sure it gets to the states grad blob after all.
                # We do this by using backward_links which triggers an alias
                # This logic is being used for example in a SumOp case
                backward_links.append(
                    (backward_mapping[cell_input], states_grad, 0))
            else:
                backward_links.append((recurrent_input_grad, states_grad, 0))

    for input_t, input_blob in inputs:
        forward_links.append((str(input_t), str(input_blob), 0))

    if backward_cell_net is not None:
        for input_t, input_blob in inputs:
            backward_links.append((
                backward_mapping[str(input_t)], str(input_blob) + "_grad", 0
            ))
        # Everything the forward step reads or writes is declared as an
        # external input of the backward step net.
        backward_cell_net.Proto().external_input.extend(
            cell_net.Proto().external_input)
        backward_cell_net.Proto().external_input.extend(
            cell_net.Proto().external_output)

    def unpack_triple(x):
        # Transpose a list of 3-tuples into three sequences; an empty list
        # yields three empty lists.
        if x:
            a, b, c = zip(*x)
            return a, b, c
        return [], [], []

    # Splitting to separate lists so we can pass them to c++
    # where we assemble them back
    link_internal, link_external, link_offset = unpack_triple(forward_links)
    alias_src, alias_dst, alias_offset = unpack_triple(aliases)

    recurrent_inputs = [str(x[1]) for x in initial_cell_inputs]

    # Make sure that recurrent gradients accumulate with internal gradients
    # (if a blob in the backward_cell_net receives gradient from both an
    # external connection as well as from within the backward_cell_net,
    # those gradients need to be added together, rather than one overwriting
    # the other)
    if backward_cell_net is not None:
        proto = backward_cell_net.Proto()
        operators = []
        # Drain proto.op from the back; `operators` ends up in reverse order.
        while len(proto.op) > 0:
            op = proto.op[-1]
            proto.op.remove(op)
            operators.append(op)
        # Re-add the ops in their original order, one at a time, so that a
        # Sum op can be added after each op whose output is redirected.
        for op in operators[::-1]:
            proto.op.extend([op])
            for j, output_blob in enumerate(op.output):
                if output_blob in proto.external_input:
                    # In place operation won't cause issues because it takes
                    # existing value of a blob into account
                    if output_blob in op.input:
                        continue
                    output_blob = core.BlobReference(output_blob)
                    accum_blob = output_blob + "_accum"
                    # Redirect the op's output to "<blob>_accum" and then sum
                    # it into the original blob.
                    # NOTE(review): assumes backward_cell_net.Sum appends its
                    # op to this same proto, right after proto.op[-1] --
                    # confirm against core.Net.
                    proto.op[-1].output[j] = str(accum_blob)
                    backward_cell_net.Sum(
                        [output_blob, accum_blob],
                        [output_blob],
                    )

    # NOTE(review): map_to_dual_list is not referenced anywhere in this
    # function.
    def map_to_dual_list(m):
        return [str(x) for x in list(m.keys())] + \
               [str(x) for x in list(m.values())]

    # Extra operator arguments; stays empty when forward_only is set.
    backward_args = {}
    if backward_cell_net is not None:
        backward_mapping_keys = set(viewkeys(backward_mapping))
        backward_link_internal, backward_link_external, backward_link_offset = \
            unpack_triple(backward_links)
        # References that have an entry in backward_mapping are treated as
        # parameters; 'param' holds their positions in all_inputs.
        params = [x for x in references if x in backward_mapping_keys]
        param_grads = [
            str(backward_mapping[x])
            for x in references
            if x in backward_mapping_keys
        ]
        if recompute_blobs_on_backward is None:
            recompute_blobs_on_backward = set()
        backward_args = {
            'param': [all_inputs.index(p) for p in params],
            'backward_link_internal': [str(l) for l in backward_link_internal],
            'backward_link_external': [str(l) for l in backward_link_external],
            'backward_link_offset': backward_link_offset,
            'outputs_with_grads': outputs_with_grads,
            'recompute_blobs_on_backward': [
                str(b) for b in recompute_blobs_on_backward
            ],
            'param_grads': param_grads,
        }
        # The backward step net is only passed when it contains ops.
        if len(backward_cell_net.Proto().op) != 0:
            backward_args['backward_step_net'] = backward_cell_net.Proto()

    results = net.RecurrentNetwork(
        all_inputs,
        all_outputs + [s("step_workspaces")],
        alias_src=alias_src,
        alias_dst=[str(a) for a in alias_dst],
        alias_offset=alias_offset,
        recurrent_states=recurrent_states,
        initial_recurrent_state_ids=[
            all_inputs.index(i) for i in recurrent_inputs
        ],
        link_internal=[str(l) for l in link_internal],
        link_external=[str(l) for l in link_external],
        link_offset=link_offset,
        enable_rnn_executor=1,
        step_net=cell_net.Proto(),
        timestep="timestep" if timestep is None else str(timestep),
        **backward_args
    )

    # Restore net type since 'rnn' is not recognized outside RNNs
    cell_net.Proto().type = 'simple'

    # The last output is a list of step workspaces,
    # which is only needed internally for gradient propagation
    return results[:-1]
def set_rnn_executor_config(rnn_op, num_threads=None, max_cuda_streams=None):
    '''
    Appends RNN executor settings to a recurrent operator as integer
    arguments named "rnn_executor.<setting>".

    rnn_op: operator whose type must be 'RecurrentNetwork' or
    'RecurrentNetworkGradient' (asserted); its `arg` list is extended in place
    num_threads: value for "rnn_executor.num_threads"; skipped when None
    max_cuda_streams: value for "rnn_executor.max_cuda_streams"; skipped
    when None
    '''
    from caffe2.proto import caffe2_pb2
    assert rnn_op.type in {'RecurrentNetwork', 'RecurrentNetworkGradient'}

    # Order matters only for the resulting argument order on the operator:
    # num_threads first, then max_cuda_streams.
    requested = (
        ('num_threads', num_threads),
        ('max_cuda_streams', max_cuda_streams),
    )
    for setting, value in requested:
        if value is None:
            continue
        argument = caffe2_pb2.Argument()
        argument.name = "rnn_executor." + setting
        argument.i = value
        rnn_op.arg.extend([argument])
def retrieve_step_blobs(net, prefix='rnn'):
    '''
    Pulls the blobs held in the step workspaces of every RecurrentNetwork
    operator in `net` (the intermediate per-timestep computation) into the
    global workspace, so their contents can be inspected from python.

    net: the net whose RecurrentNetwork operators should be processed
    prefix: prefix used for the extracted blob names in the global workspace;
    also used to name the blob that receives each fetcher's result
    ("<prefix>_1", "<prefix>_2", ... in operator order)

    Returns the concatenated contents of those result blobs as a list, i.e.
    the extracted blob names. The list is empty when the net has no
    RecurrentNetwork operator.
    '''
    extracted = []
    rnn_index = 0
    for operator in net.Proto().op:
        if operator.type != "RecurrentNetwork":
            continue
        rnn_index += 1
        result_blob = prefix + "_" + str(rnn_index)
        # The step workspaces blob is the last output of the operator.
        step_workspaces_blob = operator.output[-1]
        fetcher = core.CreateOperator(
            "RecurrentNetworkBlobFetcher",
            [step_workspaces_blob],
            [result_blob],
            prefix=prefix
        )
        workspace.RunOperatorOnce(fetcher)
        extracted.extend(workspace.FetchBlob(result_blob).tolist())
    return extracted
|