| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706 |
- ## @package control_ops_grad
- # Module caffe2.python.control_ops_grad
- from caffe2.proto import caffe2_pb2
def gen_do_gradient(op, g_output):
    """
    Generates gradient Do operator, given forward Do op and a list
    of gradient blobs corresponding to forward op's outputs
    Returns a gradient op and a list of blobs corresponding to input gradients

    Args:
        op: forward Do OperatorDef. Its last input and last output must be
            the same workspace pointer blob; its "net", "inner_blobs" and
            "outer_blobs_idx" arguments describe the subnet and the
            inner<->outer blob bindings (validated by
            _do_op_sanity_check_and_process).
        g_output: one gradient blob (or a falsy value meaning "no gradient")
            per entry of op.output.

    Returns:
        (grad_ops, g_input): grad_ops is the list of Copy ops produced by
        dedupe_g_output followed by the gradient Do op; g_input holds, for
        every forward input except the trailing workspace pointer, either
        "<input name>_grad" or None. Returns ([], []) when the backward pass
        over the subnet produces no ops.
    """
    from caffe2.python.core import BlobReference
    # Validate the forward op and recover its inner<->outer name bindings
    subnet, outer_to_inner_map, inner_to_outer_map, workspace_blob_name = \
        _do_op_sanity_check_and_process(op)

    assert len(g_output) == len(op.output), \
        "Different number of gradient blobs and Do op outputs"

    # Give every output gradient a distinct name (may emit Copy ops)
    grad_ops, deduped_g_output = dedupe_g_output(op, g_output)
    g_output = deduped_g_output

    # From the outer net point of view:
    # Do is an operator that has some number of inputs and outputs;
    # we have to generate a gradient operator that writes into
    # corresponding input gradient blobs and has access to inputs, outputs
    # and gradient output blobs
    # From the inner net point of view:
    # Do is an operator with a subnet and blob bindings,
    # we need to forward Do's output blob gradients into inner workspace,
    # use them to run backward pass generation and forward Do's input blob
    # gradients back into outer workspace

    op_output = [str(o) for o in op.output]
    op_output = op_output[:-1]  # remove workspace pointer blob
    op_input = [str(i) for i in op.input]
    op_input = op_input[:-1]  # remove workspace pointer blob

    ordered_inner_output_blob_names = [outer_to_inner_map[o] for o in op_output]

    # backward_pass_initial_grad_map: inner output blob -> inner gradient
    #   blob; seeds the backward pass generated over the subnet
    # initial_grad_map: inner gradient blob name -> outer gradient blob name
    backward_pass_initial_grad_map = {}
    initial_grad_map = {}
    # zip stops at the shorter list, so the g_output entry that lines up
    # with the trailing workspace pointer output is ignored
    for inner_output_name, outer_grad_output_name in \
            zip(ordered_inner_output_blob_names, g_output):
        # link inner_output_name to corresponding inner_grad_output_name for
        # backward pass generation;
        if outer_grad_output_name:
            inner_grad_output_name = inner_output_name + "/_DO_OPERATOR_INNER_GRAD_"
            backward_pass_initial_grad_map[BlobReference(inner_output_name)] = \
                BlobReference(inner_grad_output_name)
            initial_grad_map[inner_grad_output_name] = str(outer_grad_output_name)
    assert len(initial_grad_map) > 0, "Empty initial gradient map for Do op"

    inner_grad_ops, inner_grad_names_map = _gen_subgradient_pass(
        subnet, backward_pass_initial_grad_map)

    # Nothing to differentiate inside the subnet: no gradient Do op at all
    if len(inner_grad_ops) == 0:
        return [], []

    # For each forward input that received an inner gradient, export that
    # gradient to the outer blob "<input>_grad" via an intermediate copy
    grad_copy_ops = []
    g_input = []
    new_op_outputs = []
    new_blob_bindings = {}
    for outer_input_name in op_input:
        inner_input_name = outer_to_inner_map[outer_input_name]
        if inner_input_name in inner_grad_names_map:
            inner_grad_input_name = inner_grad_names_map[inner_input_name]
            outer_grad_input_name = outer_input_name + "_grad"
            # It is possible that inner_grad_input_name will need to be
            # linked to another outer blob. For example:
            #
            #    // y - param initialized in init_net
            #    x = ...
            #    z = ...
            #    with ops.IfNet(...):
            #        ops.Add([z, x], y) # inner Do block
            #    loss = f(..., y, ...)
            #
            # In this case x, y and z are external for the inner Do block,
            # the inputs of the Do block are z and x and the output is y.
            # When computing the gradient of input x given the gradient
            # of output y it's easy to see that they are equal.
            # During the generation of gradient Do operator, we link
            # external gradient y (y_grad) to the internal name
            # (y/_DO_OPERATOR_INNER_GRAD_) and generate the backward pass
            # for the internal Do net. As a result we get gradient operators
            # for the gradient Do and gradient map that maps internal Do
            # blobs to their computed gradients.
            # In this example, gradient map may have blob x linked to
            # gradient blob y/_DO_OPERATOR_INNER_GRAD_.
            # We should export gradient for x outside of Do, so
            # we add a blob mapping from inner gradient blob
            # (y/_DO_OPERATOR_INNER_GRAD_) to a new outer name (x_grad).
            #
            # (Note: since we use transparent blob mapping between outer and
            # inner (Do's) workspace, these operations do not involve copying
            # but are merely using blobs in outer workspace in the Do's operator
            # workspace under (possibly) different names)
            #
            # At the same time, we need to add a blob mapping from inner name
            # y/_DO_OPERATOR_INNER_GRAD_ to the outer blob y_grad
            # Hence in this case, we cannot use existing blob mapping scheme
            # that requires a bijection between subset of inner blob names and
            # a set of all (Do's input and output) outer blob names

            # TODO(iliacher): Remove unnecessary blob copying

            new_inner_grad_input_name = \
                inner_input_name + "/_DO_OPERATOR_INNER_GRAD_COPY_"
            grad_copy_ops.append(_prepare_blob_copy_op(
                inner_grad_input_name, new_inner_grad_input_name))

            new_blob_bindings[new_inner_grad_input_name] = outer_grad_input_name
            new_op_outputs.append(outer_grad_input_name)
            g_input.append(outer_grad_input_name)
        else:
            g_input.append(None)

    # Classify every blob read by the inner gradient ops:
    # - skipped if an earlier gradient op already wrote it;
    # - bound to an outer blob (forward input/output, or outer output
    #   gradient) -> becomes an input of the gradient Do op;
    # - otherwise a local blob, recorded in saved_local_blob_names.
    new_op_inputs = []
    overwritten_names = set()
    saved_local_blob_names = set()
    for grad_op in inner_grad_ops:
        grad_op_input = [str(i) for i in grad_op.input]
        grad_op_output = [str(o) for o in grad_op.output]
        for grad_op_input_name in grad_op_input:
            if grad_op_input_name in overwritten_names:
                continue
            # check if this is an external blob
            outer_name = inner_to_outer_map.get(grad_op_input_name, None)
            if not outer_name:
                # check if this is an external gradient blob
                outer_name = initial_grad_map.get(grad_op_input_name, None)
            if outer_name:
                outer_name = str(outer_name)
                if outer_name not in new_op_inputs:
                    new_op_inputs.append(outer_name)

                new_blob_bindings[grad_op_input_name] = outer_name
            else:
                # this is a local blob, we'll get its value from
                # a saved forward op workspace
                saved_local_blob_names.add(grad_op_input_name)
        overwritten_names.update(grad_op_output)

    # add inner gradient copy ops
    inner_grad_ops += grad_copy_ops

    gradient_do_def = _prepare_gradient_do_op(
        fwd_op=op,
        fwd_net=subnet,
        grad_ops=inner_grad_ops,
        inputs=new_op_inputs,
        outputs=new_op_outputs,
        blob_bindings=new_blob_bindings,
        saved_fwd_blobs=saved_local_blob_names,
        workspace_blob_name=workspace_blob_name)
    grad_ops.append(gradient_do_def)

    # Re-validate: the generated gradient Do op must itself be well-formed
    _do_op_sanity_check_and_process(gradient_do_def)

    return grad_ops, g_input
def dedupe_g_output(op, g_output):
    """
    Make output gradient blob names distinct across the outputs of ``op``.

    The same gradient blob may be passed for several different forward
    outputs. Do requires a bijection between inner and outer names, so each
    repeated gradient blob gets a fresh "<output>_<grad>_DEDUP" name, filled
    by a Copy op.

    Returns:
        (copy_ops, deduped): the Copy OperatorDefs to run first, and the
        gradient list with duplicates renamed. Falsy entries pass through
        unchanged.
    """
    copy_ops = []
    deduped = []
    output_to_grad = {}
    for out_name, g_name in zip(op.output, g_output):
        if not g_name:
            deduped.append(g_name)
            continue
        if out_name in output_to_grad:
            # same forward output seen again: reuse its gradient name
            deduped.append(output_to_grad[out_name])
            continue
        if g_name in output_to_grad.values():
            fresh_name = out_name + "_" + g_name + "_DEDUP"
            assert fresh_name not in output_to_grad.values()
            copy_op = caffe2_pb2.OperatorDef()
            copy_op.type = "Copy"
            copy_op.input.extend([g_name])
            copy_op.output.extend([fresh_name])
            copy_ops.append(copy_op)
            g_name = fresh_name
        output_to_grad[out_name] = g_name
        deduped.append(g_name)
    return copy_ops, deduped
def gen_while_gradient(op, g_output):
    """
    Generates gradient While operator

    The loop body must be a single Do op without reuse_workspace; its last
    output is used as the blob holding the stack of forward workspaces.

    Returns:
        (grad_ops, g_input): any dedup Copy ops, the condition-init ops and
        the gradient While op; plus one gradient blob name (or None) per
        forward input.
    """
    from caffe2.python.core import BlobReference
    assert op.type == "While", "Expected While op"
    assert len(op.input) > 0, "Expected at least one input in While op"
    assert len(op.output) == len(g_output), \
        "Different number of gradient blobs and While op outputs"

    grad_ops, g_output = dedupe_g_output(op, g_output)

    # forward output blob -> its gradient blob (only outputs with a gradient)
    init_grad_map = {
        BlobReference(out_name): BlobReference(grad_name)
        for out_name, grad_name in zip([str(o) for o in op.output], g_output)
        if grad_name
    }
    assert len(init_grad_map) > 0, "Empty initial gradient map for While op"

    loop_net = _get_net_argument(op, "loop_net")
    assert loop_net, "Expected loop subnet in While op"
    assert len(loop_net.op) == 1 and loop_net.op[0].type == "Do", \
        "Gradient While op requires single Do op as a loop body"

    body_do_op = loop_net.op[0]
    assert not _get_do_arguments(body_do_op).get("reuse_workspace", False), \
        "Gradient While op requires Do loop body op without reuse_workspace set"
    assert len(body_do_op.output) > 0, "Expected Do op with at least one output"
    workspace_blob = body_do_op.output[-1]

    loop_grad_net, loop_grad_map, loop_input_names, loop_output_names = \
        _gen_subnet_gradient(loop_net, init_grad_map)
    assert loop_grad_net, "Failed to get gradient net for loop body in While op"

    grad_ops.extend(_prepare_gradient_while_ops(
        fwd_op=op,
        input_names=loop_input_names,
        output_names=loop_output_names,
        loop_grad_net=loop_grad_net,
        workspace_blob=workspace_blob,
        init_grad_map=init_grad_map,
        loop_grad_map=loop_grad_map))

    g_input = [loop_grad_map.get(str(name), None) for name in op.input]
    return grad_ops, g_input
# Constructs gradient While op, arguments:
#  fwd_op - forward While op
#  input_names - input blob names for a gradient op
#  output_names - output blob names for a gradient op
#  loop_grad_net - gradient loop body net
#  workspace_blob - blob that holds forward workspaces stack
#  init_grad_map - initial gradient to forward blob map
#  loop_grad_map - gradient blob map for loop's body
def _prepare_gradient_while_ops(
        fwd_op, input_names, output_names, loop_grad_net, workspace_blob,
        init_grad_map, loop_grad_map):
    """
    Build the ops implementing the gradient of a While op.

    Returns the ops of a condition-init net followed by the gradient While
    op. The gradient While op copies fwd_op, then replaces its arguments
    (loop_net = loop_grad_net, cond_net = a freshly built condition net),
    clears its control inputs, and sets its inputs to
    [cond blob] + input_names and its outputs to output_names.
    """
    gradient_while_def = caffe2_pb2.OperatorDef()
    gradient_while_def.CopyFrom(fwd_op)
    if gradient_while_def.name:
        gradient_while_def.name += "_grad"

    loop_net_arg = caffe2_pb2.Argument()
    loop_net_arg.name = "loop_net"
    loop_net_arg.n.CopyFrom(loop_grad_net)

    cond_net_arg = caffe2_pb2.Argument()
    cond_net_arg.name = "cond_net"
    from caffe2.python.core import Net, BlobReference
    # Construct condition net - check that there are still forward workspaces
    # left using HasScope op
    cond_net = Net('gradient_loop_cond_net')
    cond_init_net = Net('gradient_loop_cond_net_init')
    cond_blob = cond_net.NextScopedBlob(cond_net.Name() + '/cond')
    cond_init_net.HasScope(workspace_blob, cond_blob)
    cond_net.HasScope(workspace_blob, cond_blob)
    # For every output whose gradient the loop body writes under a different
    # name than its initial gradient blob:
    #  - cond_net copies the loop-body gradient back into the initial
    #    gradient blob;
    #  - cond_init_net copies the initial gradient into the loop-body name.
    # NOTE(review): presumably this carries the gradient from one backward
    # iteration into the next (cond_net run between iterations) - confirm
    # against the While op implementation.
    for blob, init_grad_blob in init_grad_map.items():
        blob_name = str(blob)
        init_grad_blob_name = str(init_grad_blob)
        if blob_name in loop_grad_map and \
                loop_grad_map[blob_name] != init_grad_blob_name:
            cond_net.Copy(
                BlobReference(loop_grad_map[blob_name]), init_grad_blob)
            cond_init_net.Copy(
                init_grad_blob, BlobReference(loop_grad_map[blob_name]))
    cond_net_arg.n.CopyFrom(cond_net.Proto())

    del gradient_while_def.arg[:]
    gradient_while_def.arg.extend([loop_net_arg, cond_net_arg])

    del gradient_while_def.control_input[:]
    del gradient_while_def.input[:]
    # condition blob goes first, followed by the loop gradient net's inputs
    gradient_while_def.input.extend(
        [str(cond_blob).encode('utf-8')] + list(input_names))
    del gradient_while_def.output[:]
    gradient_while_def.output.extend(output_names)
    gradient_while_def.is_gradient_op = True
    # the condition-init ops must run before the gradient While op
    return [o for o in cond_init_net.Proto().op] + [gradient_while_def]
def _get_do_arguments(do_op):
    """
    Collect the recognized arguments of a Do op into a dict.

    Recognized keys: "net" (the subnet), "reuse_workspace" (bool),
    "inner_blobs" (repeated strings) and "outer_blobs_idx" (repeated ints).
    Arguments with an empty name, or with any other name, are ignored.

    Raises:
        AssertionError: if ``do_op`` is not a Do op, or if a recognized
            argument is present but empty.
    """
    assert do_op.type == "Do", "Expected Do op"
    args = {}
    for arg in do_op.arg:
        if not arg.name:
            continue
        if arg.name == "net":
            assert arg.n, "Expected non empty net argument"
            args["net"] = arg.n
        elif arg.name == "reuse_workspace":
            # reuse_workspace=0 (disabled) is a legitimate value, so require
            # only that the integer field is set; a truthiness check on arg.i
            # would wrongly reject it.
            assert arg.HasField("i"), \
                "Expected non empty reuse_workspace argument"
            args["reuse_workspace"] = bool(arg.i)
        elif arg.name == "inner_blobs":
            assert arg.strings, "Expected non empty inner_blobs argument"
            args["inner_blobs"] = arg.strings
        elif arg.name == "outer_blobs_idx":
            assert arg.ints, "Expected non empty outer_blobs_idx argument"
            args["outer_blobs_idx"] = arg.ints
    return args
def gen_if_gradient(op, g_output):
    """
    Generates gradient If operator, given forward If op and a list
    of gradient blobs corresponding to forward op's outputs.

    Args:
        op: forward If OperatorDef. op.input[0] is the condition blob; the
            branches are stored in the "then_net" argument and the optional
            "else_net" argument.
        g_output: one gradient blob (or a falsy value) per forward output.

    Returns:
        (grad_ops, g_input): any Copy ops produced by dedupe_g_output
        followed by the gradient If op, and one gradient blob name (or None)
        per forward input.

    Raises:
        AssertionError: on malformed forward ops or missing gradients.
        RuntimeError: if the then/else branches name a blob's gradient
            differently and neither name is the initial gradient blob.
    """
    from caffe2.python.core import BlobReference
    assert op.type == "If", "Expected If op"
    # first input is the condition blob
    assert len(op.input) > 0, "Expected at least one input in If op"
    assert len(op.output) == len(g_output), \
        "Different number of gradient blobs and If op outputs"

    grad_ops, deduped_g_output = dedupe_g_output(op, g_output)
    g_output = deduped_g_output

    init_grad_map = {}  # map from if's output blob to output gradient blob
    op_input = [str(i) for i in op.input]
    op_output = [str(o) for o in op.output]
    for output_name, grad_output_name in zip(op_output, g_output):
        if grad_output_name:
            init_grad_map[BlobReference(output_name)] = \
                BlobReference(grad_output_name)
    # shouldn't call without at least one output gradient available
    assert len(init_grad_map) > 0, "Empty initial gradient map for If op"

    grad_map = {}  # map from blob to gradient blob
    then_net = _get_net_argument(op, "then_net")
    assert then_net, "Expected then subnet in If op"
    then_grad_net, then_grad_map, then_input_names, then_output_names = \
        _gen_subnet_gradient(then_net, init_grad_map)
    assert then_grad_net, "Failed to get gradient net for then in If op"
    grad_map.update(then_grad_map)

    else_input_names = set()
    else_output_names = set()
    else_grad_map = {}
    else_grad_net = None
    else_net = _get_net_argument(op, "else_net")
    if else_net:
        else_grad_net, else_grad_map, else_input_names, else_output_names = \
            _gen_subnet_gradient(else_net, init_grad_map)
        assert else_grad_net, "Failed to get gradient net for else in If op"
        # Consider the case where one branch does not touch a blob's gradient
        # (keeping the original name from init_grad_map) while the other
        # branch computes it under a new <blob_name>_grad name.
        for else_blob, else_grad_blob in else_grad_map.items():
            if else_blob not in then_grad_map:
                grad_map[else_blob] = else_grad_blob
                continue
            then_grad_blob = then_grad_map[else_blob]
            if then_grad_blob == else_grad_blob:
                continue
            init_grad_name = init_grad_map.get(else_blob)
            if then_grad_blob == init_grad_name:
                grad_map[else_blob] = else_grad_blob
            elif else_grad_blob == init_grad_name:
                grad_map[else_blob] = then_grad_blob
            else:
                # Raising a plain string is itself a TypeError in Python 3;
                # raise a real exception carrying the same message.
                raise RuntimeError(
                    "Unexpected grad blob name " + str(else_blob) + ", " +
                    str(else_grad_blob) + ", " + str(then_grad_blob))

    # make sure gradients of blobs that were not computed
    # by the selected if's branch are initialized with zeros
    then_other_output_names = then_output_names - else_output_names
    then_grad_values = set(then_grad_map.values())
    then_other_grad_output_names = {
        o for o in then_other_output_names if o in then_grad_values}
    zero_then = _gen_grad_zero_init_ops(
        init_grad_map, then_grad_map, then_other_grad_output_names)
    if else_grad_net:
        else_grad_net.op.extend(zero_then)
    elif len(zero_then) > 0:
        # no else branch: synthesize one that only runs the zero-init ops
        else_grad_net = caffe2_pb2.NetDef()
        else_grad_net.CopyFrom(then_grad_net)
        if else_grad_net.name:
            else_grad_net.name += "_auto_else_zero_blobs_"
        del else_grad_net.op[:]
        else_grad_net.op.extend(zero_then)
        del else_grad_net.external_input[:]
        del else_grad_net.external_output[:]

    else_other_output_names = else_output_names - then_output_names
    else_grad_values = set(else_grad_map.values())
    else_other_grad_output_names = {
        o for o in else_other_output_names if o in else_grad_values}
    zero_else = _gen_grad_zero_init_ops(
        init_grad_map, else_grad_map, else_other_grad_output_names)
    then_grad_net.op.extend(zero_else)

    output_names = list(then_output_names | else_output_names)
    input_names = then_input_names | else_input_names
    # Make sure the condition blob is first and appears only once. Use a
    # one-element set: set(name) would split the name into characters.
    cond_name = op_input[0]
    input_names = [cond_name] + list(input_names - {cond_name})
    gradient_if_def = _prepare_gradient_if_op(
        fwd_op=op,
        input_names=input_names,
        output_names=output_names,
        then_grad_net=then_grad_net,
        else_grad_net=else_grad_net)
    g_input = [grad_map.get(i, None) for i in op_input]
    return grad_ops + [gradient_if_def], g_input
def _gen_subnet_gradient(subnet, init_grad):
    """
    Build a gradient NetDef for ``subnet`` seeded with ``init_grad``.

    Returns:
        (gradient_net_def, grad_names_map, input_names, output_names):
        input_names are blobs that some gradient op reads before any
        preceding gradient op has written them; output_names are all blobs
        written by the gradient ops.
    """
    grad_ops, grad_names_map = _gen_subgradient_pass(subnet, init_grad)

    written = set()
    read_from_outside = set()
    for grad_op in grad_ops:
        # inputs are checked only against outputs of *earlier* ops
        read_from_outside.update(
            name for name in (str(i) for i in grad_op.input)
            if name not in written)
        written.update(str(o) for o in grad_op.output)

    # keep the subnet's metadata, but swap in the gradient ops as the body
    # and drop the external input/output lists
    grad_net = caffe2_pb2.NetDef()
    grad_net.CopyFrom(subnet)
    if grad_net.name:
        grad_net.name += "_grad"
    del grad_net.op[:]
    grad_net.op.extend(grad_ops)
    del grad_net.external_input[:]
    del grad_net.external_output[:]
    return grad_net, grad_names_map, read_from_outside, written
def _get_net_argument(op, net_name):
    """
    Return the net stored in the first argument of ``op`` named
    ``net_name``, or None if there is no such argument.
    """
    found = next(
        (candidate for candidate in op.arg
         if candidate.name and candidate.name == net_name),
        None)
    if found is None:
        return None
    assert found.n, "Expected non empty net argument " + net_name
    return found.n
def getNetArgument(op, net_name):
    """Public entry point for looking up a net-valued argument.

    Delegates to _get_net_argument: returns the net stored in the argument
    of ``op`` named ``net_name``, or None if absent.
    """
    return _get_net_argument(op, net_name=net_name)
def _gen_subgradient_pass(subnet, init_grad):
    """
    Run backward-pass generation over the ops of ``subnet``.

    Returns:
        (grad_ops, grad_names_map): the gradient ops produced by
        IR.GetBackwardPass, and that call's blob->gradient map with both
        keys and values converted to str.
    """
    from caffe2.python.core import IR
    grad_ops, grad_blob_map = IR(subnet.op).GetBackwardPass(init_grad)
    grad_names_map = {
        str(blob): str(grad) for blob, grad in grad_blob_map.items()}
    return grad_ops, grad_names_map
def _do_op_sanity_check_and_process(op):
    """
    Validate a Do op and extract its inner/outer blob bindings.

    The last input and the last output of the op must be the same workspace
    pointer blob. The "inner_blobs" and "outer_blobs_idx" arguments pair
    each inner (subnet) blob name with an index into the op's remaining
    inputs followed by its remaining outputs. Each outer name must be bound
    exactly once.

    Returns:
        (subnet, outer_to_inner_map, inner_to_outer_map, workspace_blob_name)
    """
    assert op.type == "Do", "Expected Do op"
    subnet = _get_net_argument(op, "net")
    assert subnet, "No net argument found in Do op"

    inner_blobs = None
    outer_blobs_idx = None
    for arg in op.arg:
        if not arg.name:
            continue
        if arg.name == "inner_blobs":
            assert not inner_blobs, "inner_blobs redefinition"
            assert arg.strings and len(arg.strings) > 0, \
                "Empty inner_blobs argument in Do op"
            inner_blobs = [raw.decode('utf-8') for raw in arg.strings]
        elif arg.name == "outer_blobs_idx":
            assert not outer_blobs_idx, "outer_blobs_idx redefinition"
            assert arg.ints and len(arg.ints) > 0, \
                "Empty outer_blobs_idx argument in Do op"
            outer_blobs_idx = arg.ints
        if inner_blobs and outer_blobs_idx:
            # both binding arguments found; later arguments are not inspected
            break

    assert inner_blobs, "No inner_blobs argument found in Do op"
    assert outer_blobs_idx, "No outer_blobs_idx argument found in Do op"
    assert len(inner_blobs) == len(outer_blobs_idx), \
        "Arguments inner_blobs and outer_blobs_idx of different length in Do op"
    assert len(set(inner_blobs)) == len(inner_blobs), \
        "Found duplicates in inner_blobs in Do op"

    op_input = [str(name) for name in op.input]
    assert len(op_input) > 0, "Expected at least one input blob"
    input_workspace_blob_name = op_input.pop()  # trailing workspace pointer

    op_output = [str(name) for name in op.output]
    assert len(op_output) > 0, "Expected at least one output blob"
    workspace_blob_name = op_output.pop()  # trailing workspace pointer
    assert input_workspace_blob_name == workspace_blob_name, \
        "Expected same input/output workspace blob"

    assert len(set(op_input)) == len(op_input), \
        "Found duplicates in Do op inputs"
    assert len(set(op_output)) == len(op_output), \
        "Found duplicates in Do op outputs"

    outer_names = op_input + op_output
    outer_to_inner_map = {}
    inner_to_outer_map = {}
    for inner_name, idx in zip(inner_blobs, outer_blobs_idx):
        assert 0 <= idx < len(outer_names), \
            "Outer blob index is out of bounds in Do op"
        outer_name = outer_names[idx]
        assert outer_name not in outer_to_inner_map, \
            "Reusage of outer blob name " + outer_name + " in Do op"
        outer_to_inner_map[outer_name] = inner_name
        inner_to_outer_map[inner_name] = outer_name

    assert len(outer_to_inner_map) == len(set(outer_names)), \
        "Not all outer blob names are used in blob bindings in Do op"

    return subnet, outer_to_inner_map, inner_to_outer_map, workspace_blob_name
def _prepare_blob_copy_op(from_name, to_name):
    """Return a Copy OperatorDef that reads ``from_name`` and writes ``to_name``."""
    op_def = caffe2_pb2.OperatorDef()
    op_def.type = "Copy"
    op_def.input.append(from_name)
    op_def.output.append(to_name)
    return op_def
def _prepare_gradient_do_op(
        fwd_op, fwd_net, grad_ops, inputs, outputs, blob_bindings, saved_fwd_blobs,
        workspace_blob_name):
    """
    Build the gradient Do OperatorDef.

    Args:
        fwd_op: forward Do op; copied, then its inputs, outputs, args and
            control inputs are replaced.
        fwd_net: forward subnet; a copy becomes the gradient subnet, with
            ``grad_ops`` as its body and empty external input/output lists.
        grad_ops: ops of the gradient subnet.
        inputs, outputs: outer blob names of the gradient op (concatenated
            with ``+``); the workspace pointer blob is appended to both.
        blob_bindings: inner blob name -> outer blob name; each outer name
            must occur in ``inputs + outputs``.
        saved_fwd_blobs: inner blob names stored in "saved_fwd_blobs".
        workspace_blob_name: name of the workspace pointer blob.

    Returns:
        The gradient Do OperatorDef, with is_gradient_op set.
    """
    def make_arg(arg_name):
        new_arg = caffe2_pb2.Argument()
        new_arg.name = arg_name
        return new_arg

    grad_net = caffe2_pb2.NetDef()
    grad_net.CopyFrom(fwd_net)
    if grad_net.name:
        grad_net.name += "_grad"
    del grad_net.op[:]
    grad_net.op.extend(grad_ops)
    del grad_net.external_input[:]
    del grad_net.external_output[:]

    grad_do = caffe2_pb2.OperatorDef()
    grad_do.CopyFrom(fwd_op)
    if grad_do.name:
        grad_do.name += "_grad"
    del grad_do.input[:]
    grad_do.input.extend(inputs)
    grad_do.input.append(workspace_blob_name)  # workspace pointer blob
    del grad_do.output[:]
    grad_do.output.extend(outputs)
    grad_do.output.append(workspace_blob_name)  # workspace pointer blob

    net_arg = make_arg("net")
    net_arg.n.CopyFrom(grad_net)

    # inner names and their outer positions are emitted in the same order
    outer_names = inputs + outputs
    bound_inner_names = list(blob_bindings)
    outer_positions = [
        outer_names.index(blob_bindings[name]) for name in bound_inner_names]

    inner_blobs_arg = make_arg("inner_blobs")
    inner_blobs_arg.strings.extend(
        [name.encode('utf-8') for name in bound_inner_names])
    outer_blobs_idx_arg = make_arg("outer_blobs_idx")
    outer_blobs_idx_arg.ints.extend(outer_positions)

    saved_blobs_arg = make_arg("saved_fwd_blobs")
    saved_blobs_arg.strings.extend(
        [name.encode('utf-8') for name in saved_fwd_blobs])

    del grad_do.arg[:]
    grad_do.arg.extend(
        [net_arg, inner_blobs_arg, outer_blobs_idx_arg, saved_blobs_arg])
    del grad_do.control_input[:]
    grad_do.is_gradient_op = True
    return grad_do
def _gen_grad_zero_init_ops(init_grad_map, grad_map, grad_output_names):
    """
    Create ops that initialize each gradient blob in ``grad_output_names``.

    Each gradient blob is mapped back to its forward blob through
    ``grad_map``; the first matching entry is used.
    - If the forward blob has an initial gradient under a different name, a
      Copy op from that initial gradient is emitted (existing gradients are
      never overwritten with zeros).
    - If it has an initial gradient under the same name, nothing is emitted.
    - Otherwise a ConstantFill op with value 0.0 is emitted, taking the
      forward blob as input so the gradient gets the same shape.

    Raises:
        AssertionError: if a gradient name has no (truthy) forward blob in
            ``grad_map``.
    """
    init_ops = []
    for grad_output in grad_output_names:
        fwd_name = next(
            (blob for blob, grad in grad_map.items() if grad == grad_output),
            None)
        assert fwd_name, "Unknown gradient output " + grad_output

        if fwd_name not in init_grad_map:
            fill_op = caffe2_pb2.OperatorDef()
            fill_op.type = "ConstantFill"
            fill_op.input.extend([fwd_name])
            fill_op.output.extend([grad_output])
            zero_arg = caffe2_pb2.Argument()
            zero_arg.name = "value"
            zero_arg.f = 0.0
            fill_op.arg.extend([zero_arg])
            init_ops.append(fill_op)
            continue

        init_grad_name = init_grad_map[fwd_name]
        if init_grad_name == grad_output:
            continue
        copy_op = caffe2_pb2.OperatorDef()
        copy_op.type = "Copy"
        copy_op.input.extend([str(init_grad_name)])
        copy_op.output.extend([str(grad_output)])
        init_ops.append(copy_op)
    return init_ops
def _prepare_gradient_if_op(
        fwd_op, input_names, output_names, then_grad_net, else_grad_net):
    """
    Build the gradient If OperatorDef from the forward If op.

    The inputs and outputs are replaced with ``input_names`` and
    ``output_names``. The arguments are replaced with "then_net"
    (then_grad_net) and, only when ``else_grad_net`` is truthy, "else_net".
    Control inputs are cleared and is_gradient_op is set.
    """
    def net_arg(arg_name, net):
        new_arg = caffe2_pb2.Argument()
        new_arg.name = arg_name
        new_arg.n.CopyFrom(net)
        return new_arg

    grad_if = caffe2_pb2.OperatorDef()
    grad_if.CopyFrom(fwd_op)
    if grad_if.name:
        grad_if.name += "_grad"

    del grad_if.input[:]
    grad_if.input.extend(input_names)
    del grad_if.output[:]
    grad_if.output.extend(output_names)

    branch_args = [net_arg("then_net", then_grad_net)]
    if else_grad_net:
        branch_args.append(net_arg("else_net", else_grad_net))
    del grad_if.arg[:]
    grad_if.arg.extend(branch_args)

    del grad_if.control_input[:]
    grad_if.is_gradient_op = True
    return grad_if
def disambiguate_grad_if_op_output(grad_op, idx, new_grad_output):
    """
    Rename output ``idx`` of a gradient If op to ``new_grad_output``.

    Every op output inside the "then_net" (and, if present, "else_net")
    argument that matches the old name is renamed as well, so the branches
    keep writing the op's declared output.
    """
    then_net = _get_net_argument(grad_op, "then_net")
    old_name = grad_op.output[idx]

    def rename_outputs(net):
        for inner_op in net.op:
            for pos, name in enumerate(inner_op.output):
                if name == old_name:
                    inner_op.output[pos] = new_grad_output

    rename_outputs(then_net)
    else_net = _get_net_argument(grad_op, "else_net")
    if else_net:
        rename_outputs(else_net)
    grad_op.output[idx] = new_grad_output
|