control_ops_grad.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706
  1. ## @package control_ops_grad
  2. # Module caffe2.python.control_ops_grad
  3. from caffe2.proto import caffe2_pb2
def gen_do_gradient(op, g_output):
    """
    Generates gradient Do operator, given forward Do op and a list
    of gradient blobs corresponding to forward op's outputs
    Returns a gradient op and a list of blobs corresponding to input gradients

    Args:
        op: forward Do OperatorDef. Its last input and last output are the
            workspace pointer blob (validated by
            _do_op_sanity_check_and_process).
        g_output: one entry per element of op.output (including the
            workspace pointer output); each entry is a gradient blob name
            or a falsy value meaning "no gradient for this output".

    Returns:
        (grad_ops, g_input): grad_ops holds the Copy ops produced by
        dedupe_g_output followed by the new gradient Do op; g_input has one
        entry per non-workspace forward input, either "<input>_grad" or None.
        If the inner backward pass produces no ops, returns ([], []).
    """
    from caffe2.python.core import BlobReference
    subnet, outer_to_inner_map, inner_to_outer_map, workspace_blob_name = \
        _do_op_sanity_check_and_process(op)
    assert len(g_output) == len(op.output), \
        "Different number of gradient blobs and Do op outputs"
    # Make gradient names unique per output (Do needs a bijection between
    # inner and outer blob names).
    grad_ops, deduped_g_output = dedupe_g_output(op, g_output)
    g_output = deduped_g_output
    # From the outer net point of view:
    # Do is an operator that has some number of inputs and outputs;
    # we have to generate a gradient operator that writes into
    # corresponding input gradient blobs and has access to inputs, outputs
    # and gradient output blobs
    # From the inner net point of view:
    # Do is an operator with a subnet and blob bindings,
    # we need to forward Do's output blob gradients into inner workspace,
    # use them to run backward pass generation and forward Do's input blob
    # gradients back into outer workspace
    op_output = [str(o) for o in op.output]
    op_output = op_output[:-1]  # remove workspace pointer blob
    op_input = [str(i) for i in op.input]
    op_input = op_input[:-1]  # remove workspace pointer blob
    ordered_inner_output_blob_names = [outer_to_inner_map[o] for o in op_output]
    # inner output blob -> inner gradient blob, used to seed the inner
    # backward pass
    backward_pass_initial_grad_map = {}
    # inner gradient blob name -> outer gradient blob name
    initial_grad_map = {}
    # zip() stops at the shorter list, so the g_output entry that belongs to
    # the workspace pointer output is ignored here.
    for inner_output_name, outer_grad_output_name in \
            zip(ordered_inner_output_blob_names, g_output):
        # link inner_output_name to corresponding inner_grad_output_name for
        # backward pass generation;
        if outer_grad_output_name:
            inner_grad_output_name = inner_output_name + "/_DO_OPERATOR_INNER_GRAD_"
            backward_pass_initial_grad_map[BlobReference(inner_output_name)] = \
                BlobReference(inner_grad_output_name)
            initial_grad_map[inner_grad_output_name] = str(outer_grad_output_name)
    assert len(initial_grad_map) > 0, "Empty initial gradient map for Do op"
    inner_grad_ops, inner_grad_names_map = _gen_subgradient_pass(
        subnet, backward_pass_initial_grad_map)
    # Nothing to differentiate inside the subnet: no gradient op is emitted
    # (the dedupe Copy ops are dropped as well).
    if len(inner_grad_ops) == 0:
        return [], []
    grad_copy_ops = []
    g_input = []
    new_op_outputs = []
    # inner blob name -> outer blob name bindings for the gradient Do op
    new_blob_bindings = {}
    for outer_input_name in op_input:
        inner_input_name = outer_to_inner_map[outer_input_name]
        if inner_input_name in inner_grad_names_map:
            inner_grad_input_name = inner_grad_names_map[inner_input_name]
            outer_grad_input_name = outer_input_name + "_grad"
            # It is possible that inner_grad_input_name will need to be
            # linked to another outer blob. For example:
            #
            #    // y - param initialized in init_net
            #    x = ...
            #    z = ...
            #    with ops.IfNet(...):
            #        ops.Add([z, x], y) # inner Do block
            #    loss = f(..., y, ...)
            #
            # In this case x, y and z are external for the inner Do block,
            # the inputs of the Do block are z and x and the output is y.
            # When computing the gradient of input x given the gradient
            # of output y it's easy to see that they are equal.
            # During the generation of gradient Do operator, we link
            # external gradient y (y_grad) to the internal name
            # (y/_DO_OPERATOR_INNER_GRAD_) and generate the backward pass
            # for the internal Do net. As a result we get gradient operators
            # for the gradient Do and gradient map that maps internal Do
            # blobs to their computed gradients.
            # In this example, gradient map may have blob x linked to
            # gradient blob y/_DO_OPERATOR_INNER_GRAD_.
            # We should export gradient for x outside of Do, so
            # we add a blob mapping from inner gradient blob
            # (y/_DO_OPERATOR_INNER_GRAD_) to a new outer name (x_grad).
            #
            # (Note: since we use transparent blob mapping between outer and
            # inner (Do's) workspace, these operations do not involve copying
            # but are merely using blobs in outer workspace in the Do's operator
            # workspace under (possibly) different names)
            #
            # At the same time, we need to add a blob mapping from inner name
            # y/_DO_OPERATOR_INNER_GRAD_ to the outer blob y_grad
            # Hence in this case, we cannot use existing blob mapping scheme
            # that requires a bijection between subset of inner blob names and
            # a set of all (Do's input and output) outer blob names
            # TODO(iliacher): Remove unnecessary blob copying
            new_inner_grad_input_name = \
                inner_input_name + "/_DO_OPERATOR_INNER_GRAD_COPY_"
            grad_copy_ops.append(_prepare_blob_copy_op(
                inner_grad_input_name, new_inner_grad_input_name))
            new_blob_bindings[new_inner_grad_input_name] = outer_grad_input_name
            new_op_outputs.append(outer_grad_input_name)
            g_input.append(outer_grad_input_name)
        else:
            g_input.append(None)
    new_op_inputs = []
    # blobs already written by an earlier inner gradient op; reads of these
    # are satisfied inside the gradient net and need no binding
    overwritten_names = set()
    saved_local_blob_names = set()
    for grad_op in inner_grad_ops:
        grad_op_input = [str(i) for i in grad_op.input]
        grad_op_output = [str(o) for o in grad_op.output]
        for grad_op_input_name in grad_op_input:
            if grad_op_input_name in overwritten_names:
                continue
            # check if this is an external blob
            outer_name = inner_to_outer_map.get(grad_op_input_name, None)
            if not outer_name:
                # check if this is an external gradient blob
                outer_name = initial_grad_map.get(grad_op_input_name, None)
            if outer_name:
                outer_name = str(outer_name)
                if outer_name not in new_op_inputs:
                    new_op_inputs.append(outer_name)
                new_blob_bindings[grad_op_input_name] = outer_name
            else:
                # this is a local blob, we'll get its value from
                # a saved forward op workspace
                saved_local_blob_names.add(grad_op_input_name)
        overwritten_names.update(grad_op_output)
    # add inner gradient copy ops (they run after the backward pass ops;
    # note this extends the list returned by _gen_subgradient_pass in place)
    inner_grad_ops += grad_copy_ops
    gradient_do_def = _prepare_gradient_do_op(
        fwd_op=op,
        fwd_net=subnet,
        grad_ops=inner_grad_ops,
        inputs=new_op_inputs,
        outputs=new_op_outputs,
        blob_bindings=new_blob_bindings,
        saved_fwd_blobs=saved_local_blob_names,
        workspace_blob_name=workspace_blob_name)
    grad_ops.append(gradient_do_def)
    # validate the generated gradient Do op with the same checks as the
    # forward one (the result is discarded)
    _do_op_sanity_check_and_process(gradient_do_def)
    return grad_ops, g_input
  141. def dedupe_g_output(op, g_output):
  142. # When generation a gradient op it's possible to receive the same gradient
  143. # blob corresponding to different forward op output blobs, Do operator
  144. # requires a bijection between inner and outer names, make sure we do
  145. # deduplication
  146. grad_ops = []
  147. deduped_g_output = []
  148. init_grad_map = {}
  149. for output_name, grad_name in zip(op.output, g_output):
  150. if not grad_name:
  151. deduped_g_output.append(grad_name)
  152. continue
  153. if output_name in init_grad_map:
  154. deduped_g_output.append(init_grad_map[output_name])
  155. else:
  156. if grad_name not in init_grad_map.values():
  157. init_grad_map[output_name] = grad_name
  158. deduped_g_output.append(grad_name)
  159. else:
  160. deduped_grad_name = output_name + "_" + grad_name + "_DEDUP"
  161. assert deduped_grad_name not in init_grad_map.values()
  162. grad_copy_op = caffe2_pb2.OperatorDef()
  163. grad_copy_op.type = "Copy"
  164. grad_copy_op.input.extend([grad_name])
  165. grad_copy_op.output.extend([deduped_grad_name])
  166. grad_ops.append(grad_copy_op)
  167. deduped_g_output.append(deduped_grad_name)
  168. init_grad_map[output_name] = deduped_grad_name
  169. return grad_ops, deduped_g_output
  170. def gen_while_gradient(op, g_output):
  171. """
  172. Generates gradient While operator
  173. """
  174. from caffe2.python.core import BlobReference
  175. assert op.type == "While", "Expected While op"
  176. assert len(op.input) > 0, "Expected at least one input in While op"
  177. assert len(op.output) == len(g_output), \
  178. "Different number of gradient blobs and While op outputs"
  179. grad_ops, deduped_g_output = dedupe_g_output(op, g_output)
  180. g_output = deduped_g_output
  181. init_grad_map = {}
  182. op_output = [str(o) for o in op.output]
  183. for output_name, grad_output_name in zip(op_output, g_output):
  184. if grad_output_name:
  185. init_grad_map[BlobReference(output_name)] = \
  186. BlobReference(grad_output_name)
  187. assert len(init_grad_map) > 0, "Empty initial gradient map for While op"
  188. loop_net = _get_net_argument(op, "loop_net")
  189. assert loop_net, "Expected loop subnet in While op"
  190. assert len(loop_net.op) == 1 and loop_net.op[0].type == "Do", \
  191. "Gradient While op requires single Do op as a loop body"
  192. do_op = loop_net.op[0]
  193. do_args = _get_do_arguments(do_op)
  194. assert "reuse_workspace" not in do_args or not do_args["reuse_workspace"], \
  195. "Gradient While op requires Do loop body op without reuse_workspace set"
  196. assert len(do_op.output) > 0, "Expected Do op with at least one output"
  197. workspace_blob = do_op.output[-1]
  198. loop_grad_net, loop_grad_map, loop_input_names, loop_output_names = \
  199. _gen_subnet_gradient(loop_net, init_grad_map)
  200. assert loop_grad_net, "Failed to get gradient net for loop body in While op"
  201. grad_ops += _prepare_gradient_while_ops(
  202. fwd_op=op,
  203. input_names=loop_input_names,
  204. output_names=loop_output_names,
  205. loop_grad_net=loop_grad_net,
  206. workspace_blob=workspace_blob,
  207. init_grad_map=init_grad_map,
  208. loop_grad_map=loop_grad_map)
  209. op_input = [str(i) for i in op.input]
  210. g_input = [loop_grad_map.get(i, None) for i in op_input]
  211. return grad_ops, g_input
  212. # Constructs gradient While op, arguments:
  213. # fwd_op - forward While op
  214. # input_names - input blob names for a gradient op
  215. # output_names - output blob names for a gradient op
  216. # loop_grad_net - gradient loop body net
  217. # workspace_blob - blob that holds forward workspaces stack
  218. # init_grad_map - initial gradient to forward blob map
  219. # loop_grad_map - gradient blob map for loop's body
  220. def _prepare_gradient_while_ops(
  221. fwd_op, input_names, output_names, loop_grad_net, workspace_blob,
  222. init_grad_map, loop_grad_map):
  223. gradient_while_def = caffe2_pb2.OperatorDef()
  224. gradient_while_def.CopyFrom(fwd_op)
  225. if gradient_while_def.name:
  226. gradient_while_def.name += "_grad"
  227. loop_net_arg = caffe2_pb2.Argument()
  228. loop_net_arg.name = "loop_net"
  229. loop_net_arg.n.CopyFrom(loop_grad_net)
  230. cond_net_arg = caffe2_pb2.Argument()
  231. cond_net_arg.name = "cond_net"
  232. from caffe2.python.core import Net, BlobReference
  233. # Construct condition net - check that there're still forward workspaces
  234. # left using HasScope op
  235. cond_net = Net('gradient_loop_cond_net')
  236. cond_init_net = Net('gradient_loop_cond_net_init')
  237. cond_blob = cond_net.NextScopedBlob(cond_net.Name() + '/cond')
  238. cond_init_net.HasScope(workspace_blob, cond_blob)
  239. cond_net.HasScope(workspace_blob, cond_blob)
  240. for blob, init_grad_blob in init_grad_map.items():
  241. blob_name = str(blob)
  242. init_grad_blob_name = str(init_grad_blob)
  243. if blob_name in loop_grad_map and \
  244. loop_grad_map[blob_name] != init_grad_blob_name:
  245. cond_net.Copy(
  246. BlobReference(loop_grad_map[blob_name]), init_grad_blob)
  247. cond_init_net.Copy(
  248. init_grad_blob, BlobReference(loop_grad_map[blob_name]))
  249. cond_net_arg.n.CopyFrom(cond_net.Proto())
  250. del gradient_while_def.arg[:]
  251. gradient_while_def.arg.extend([loop_net_arg, cond_net_arg])
  252. del gradient_while_def.control_input[:]
  253. del gradient_while_def.input[:]
  254. gradient_while_def.input.extend(
  255. [str(cond_blob).encode('utf-8')] + list(input_names))
  256. del gradient_while_def.output[:]
  257. gradient_while_def.output.extend(output_names)
  258. gradient_while_def.is_gradient_op = True
  259. return [o for o in cond_init_net.Proto().op] + [gradient_while_def]
  260. def _get_do_arguments(do_op):
  261. assert do_op.type == "Do", "Expected Do op"
  262. args = {}
  263. for arg in do_op.arg:
  264. if not arg.name:
  265. continue
  266. if arg.name == "net":
  267. assert arg.n, "Expected non empty net argument"
  268. args["net"] = arg.n
  269. elif arg.name == "reuse_workspace":
  270. assert arg.i, "Expected non empty reuse_workspace argument"
  271. args["reuse_workspace"] = bool(arg.i)
  272. elif arg.name == "inner_blobs":
  273. assert arg.strings, "Expected non empty inner_blobs argument"
  274. args["inner_blobs"] = arg.strings
  275. elif arg.name == "outer_blobs_idx":
  276. assert arg.ints, "Expected non empty outer_blobs_idx argument"
  277. args["outer_blobs_idx"] = arg.ints
  278. return args
  279. def gen_if_gradient(op, g_output):
  280. """
  281. Generates gradient If operator, given forward If op and a list
  282. of gradient blobs corresponding to forward op's outputs
  283. Returns a gradient op and a list of blobs corresponding to input gradients
  284. """
  285. from caffe2.python.core import BlobReference
  286. assert op.type == "If", "Expected If op"
  287. # first input is the condition blob
  288. assert len(op.input) > 0, "Expected at least one input in If op"
  289. assert len(op.output) == len(g_output), \
  290. "Different number of gradient blobs and If op outputs"
  291. grad_ops, deduped_g_output = dedupe_g_output(op, g_output)
  292. g_output = deduped_g_output
  293. init_grad_map = {} # map from if's output blob to output gradient blob
  294. op_input = [str(i) for i in op.input]
  295. op_output = [str(o) for o in op.output]
  296. for output_name, grad_output_name in zip(op_output, g_output):
  297. if grad_output_name:
  298. init_grad_map[BlobReference(output_name)] = \
  299. BlobReference(grad_output_name)
  300. # shouldn't call without at least one output gradient available
  301. assert len(init_grad_map) > 0, "Empty initial gradient map for If op"
  302. grad_map = {} # map from blob to gradient blob
  303. then_net = _get_net_argument(op, "then_net")
  304. assert then_net, "Expected then subnet in If op"
  305. then_grad_net, then_grad_map, then_input_names, then_output_names = \
  306. _gen_subnet_gradient(then_net, init_grad_map)
  307. assert then_grad_net, "Failed to get gradient net for then in If op"
  308. grad_map.update(then_grad_map)
  309. else_input_names = set()
  310. else_output_names = set()
  311. else_grad_map = {}
  312. else_grad_net = None
  313. else_net = _get_net_argument(op, "else_net")
  314. if else_net:
  315. else_grad_net, else_grad_map, else_input_names, else_output_names = \
  316. _gen_subnet_gradient(else_net, init_grad_map)
  317. assert else_grad_net, "Failed to get gradient net for else in If op"
  318. # consider case: else doesn't update blob's gradient and keeps original
  319. # from init_grad_map, but then updates the gradient
  320. for else_blob, else_grad_blob in else_grad_map.items():
  321. if else_blob in then_grad_map:
  322. then_grad_blob = then_grad_map[else_blob]
  323. # if both then and else branches have grad blob name for the same
  324. # blob and grad names are different, then one of the branches
  325. # doesn't use blob and has original grad blob name in it's grad map,
  326. # and another branch uses blob and has <blob_name>_grad name
  327. # in it's grad map (might be different from original grad blob)
  328. if then_grad_blob != else_grad_blob:
  329. init_grad_name = init_grad_map[else_blob] \
  330. if else_blob in init_grad_map else None
  331. if then_grad_blob == init_grad_name:
  332. grad_map[else_blob] = else_grad_blob
  333. elif else_grad_blob == init_grad_name:
  334. grad_map[else_blob] = then_grad_blob
  335. else:
  336. raise "Unexpected grad blob name " + else_blob + ", " + \
  337. else_grad_blob + ", " + then_grad_blob
  338. else:
  339. grad_map[else_blob] = else_grad_blob
  340. # make sure gradients of blobs that were not computed
  341. # by the selected if's branch are initialized with zeros
  342. then_other_output_names = \
  343. then_output_names - (then_output_names & else_output_names)
  344. then_other_grad_output_names = set(
  345. [o for o in then_other_output_names if o in then_grad_map.values()])
  346. zero_then = _gen_grad_zero_init_ops(
  347. init_grad_map, then_grad_map, then_other_grad_output_names)
  348. if else_grad_net:
  349. else_grad_net.op.extend(zero_then)
  350. elif len(zero_then) > 0:
  351. else_grad_net = caffe2_pb2.NetDef()
  352. else_grad_net.CopyFrom(then_grad_net)
  353. if else_grad_net.name:
  354. else_grad_net.name += "_auto_else_zero_blobs_"
  355. del else_grad_net.op[:]
  356. else_grad_net.op.extend(zero_then)
  357. del else_grad_net.external_input[:]
  358. del else_grad_net.external_output[:]
  359. else_other_output_names = \
  360. else_output_names - (then_output_names & else_output_names)
  361. else_other_grad_output_names = set(
  362. [o for o in else_other_output_names if o in else_grad_map.values()])
  363. zero_else = _gen_grad_zero_init_ops(
  364. init_grad_map, else_grad_map, else_other_grad_output_names)
  365. then_grad_net.op.extend(zero_else)
  366. output_names = list(then_output_names | else_output_names)
  367. input_names = then_input_names | else_input_names
  368. # make sure condition blob is the first in the list
  369. input_names = [op_input[0]] + list(input_names - set(op_input[0]))
  370. gradient_if_def = _prepare_gradient_if_op(
  371. fwd_op=op,
  372. input_names=input_names,
  373. output_names=output_names,
  374. then_grad_net=then_grad_net,
  375. else_grad_net=else_grad_net)
  376. g_input = [grad_map.get(i, None) for i in op_input]
  377. return grad_ops + [gradient_if_def], g_input
  378. def _gen_subnet_gradient(subnet, init_grad):
  379. grad_ops, grad_names_map = _gen_subgradient_pass(
  380. subnet, init_grad)
  381. output_names = set()
  382. input_names = set()
  383. for grad_op in grad_ops:
  384. for grad_op_input in grad_op.input:
  385. if str(grad_op_input) not in output_names:
  386. input_names.add(str(grad_op_input))
  387. for grad_op_output in grad_op.output:
  388. output_names.add(str(grad_op_output))
  389. gradient_net_def = caffe2_pb2.NetDef()
  390. gradient_net_def.CopyFrom(subnet)
  391. if gradient_net_def.name:
  392. gradient_net_def.name += "_grad"
  393. del gradient_net_def.op[:]
  394. gradient_net_def.op.extend(grad_ops)
  395. del gradient_net_def.external_input[:]
  396. del gradient_net_def.external_output[:]
  397. return gradient_net_def, grad_names_map, input_names, output_names
  398. def _get_net_argument(op, net_name):
  399. for arg in op.arg:
  400. if arg.name and arg.name == net_name:
  401. assert arg.n, "Expected non empty net argument " + net_name
  402. return arg.n
  403. return None
  404. def getNetArgument(op, net_name):
  405. """A wrapper for external call"""
  406. return _get_net_argument(op, net_name)
  407. def _gen_subgradient_pass(subnet, init_grad):
  408. from caffe2.python.core import IR
  409. subnet_ir = IR(subnet.op)
  410. grad_ops, grad_blob_map = \
  411. subnet_ir.GetBackwardPass(init_grad)
  412. grad_names_map = {}
  413. for b, g in grad_blob_map.items():
  414. grad_names_map[str(b)] = str(g)
  415. return grad_ops, grad_names_map
  416. def _do_op_sanity_check_and_process(op):
  417. assert op.type == "Do", "Expected Do op"
  418. subnet = _get_net_argument(op, "net")
  419. assert subnet, "No net argument found in Do op"
  420. inner_blobs = None
  421. outer_blobs_idx = None
  422. for arg in op.arg:
  423. if arg.name and arg.name == "inner_blobs":
  424. assert not inner_blobs, "inner_blobs redefinition"
  425. assert arg.strings and len(arg.strings) > 0, \
  426. "Empty inner_blobs argument in Do op"
  427. inner_blobs = [s.decode('utf-8') for s in arg.strings]
  428. if arg.name and arg.name == "outer_blobs_idx":
  429. assert not outer_blobs_idx, "outer_blobs_idx redefinition"
  430. assert arg.ints and len(arg.ints) > 0, \
  431. "Empty outer_blobs_idx argument in Do op"
  432. outer_blobs_idx = arg.ints
  433. if inner_blobs and outer_blobs_idx:
  434. break
  435. assert inner_blobs, "No inner_blobs argument found in Do op"
  436. assert outer_blobs_idx, "No outer_blobs_idx argument found in Do op"
  437. assert len(inner_blobs) == len(outer_blobs_idx), \
  438. "Arguments inner_blobs and outer_blobs_idx of different length in Do op"
  439. all_inner_blobs = set(inner_blobs)
  440. assert len(all_inner_blobs) == len(inner_blobs), \
  441. "Found duplicates in inner_blobs in Do op"
  442. op_input = [str(i) for i in op.input]
  443. assert len(op_input) > 0, "Expected at least one input blob"
  444. # remove last input blob that holds pointer to workspace
  445. input_workspace_blob_name = op_input[-1]
  446. op_input = op_input[:-1]
  447. op_output = [str(o) for o in op.output]
  448. assert len(op_output) > 0, "Expected at least one output blob"
  449. # remove last output blob that holds pointer to workspace
  450. workspace_blob_name = op_output[-1]
  451. assert input_workspace_blob_name == workspace_blob_name, \
  452. "Expected same input/output workspace blob"
  453. op_output = op_output[:-1]
  454. all_op_input_blob_names = set(op_input)
  455. assert len(all_op_input_blob_names) == len(op_input), \
  456. "Found duplicates in Do op inputs"
  457. all_op_output_blob_names = set(op_output)
  458. assert len(all_op_output_blob_names) == len(op_output), \
  459. "Found duplicates in Do op outputs"
  460. ordered_outer_blob_names = op_input + op_output
  461. all_outer_blob_names = set(ordered_outer_blob_names)
  462. used_outer_blob_names = set()
  463. outer_to_inner_map = {}
  464. inner_to_outer_map = {}
  465. for inner_name, outer_blob_idx in zip(inner_blobs, outer_blobs_idx):
  466. assert outer_blob_idx >= 0 and \
  467. outer_blob_idx < len(ordered_outer_blob_names), \
  468. "Outer blob index is out of bounds in Do op"
  469. outer_name = ordered_outer_blob_names[outer_blob_idx]
  470. assert outer_name not in used_outer_blob_names, \
  471. "Reusage of outer blob name " + outer_name + " in Do op"
  472. used_outer_blob_names.add(outer_name)
  473. outer_to_inner_map[outer_name] = inner_name
  474. inner_to_outer_map[inner_name] = outer_name
  475. assert len(used_outer_blob_names) == len(all_outer_blob_names), \
  476. "Not all outer blob names are used in blob bindings in Do op"
  477. return subnet, outer_to_inner_map, inner_to_outer_map, workspace_blob_name
  478. def _prepare_blob_copy_op(from_name, to_name):
  479. copy_op_def = caffe2_pb2.OperatorDef()
  480. copy_op_def.type = "Copy"
  481. copy_op_def.input.extend([from_name])
  482. copy_op_def.output.extend([to_name])
  483. return copy_op_def
  484. def _prepare_gradient_do_op(
  485. fwd_op, fwd_net, grad_ops, inputs, outputs, blob_bindings, saved_fwd_blobs,
  486. workspace_blob_name):
  487. gradient_net_def = caffe2_pb2.NetDef()
  488. gradient_net_def.CopyFrom(fwd_net)
  489. if gradient_net_def.name:
  490. gradient_net_def.name += "_grad"
  491. del gradient_net_def.op[:]
  492. gradient_net_def.op.extend(grad_ops)
  493. del gradient_net_def.external_input[:]
  494. del gradient_net_def.external_output[:]
  495. gradient_do_def = caffe2_pb2.OperatorDef()
  496. gradient_do_def.CopyFrom(fwd_op)
  497. if gradient_do_def.name and len(gradient_do_def.name) > 0:
  498. gradient_do_def.name += "_grad"
  499. del gradient_do_def.input[:]
  500. gradient_do_def.input.extend(inputs)
  501. # workspace pointer blob
  502. gradient_do_def.input.append(workspace_blob_name)
  503. del gradient_do_def.output[:]
  504. gradient_do_def.output.extend(outputs)
  505. # workspace pointer blob
  506. gradient_do_def.output.append(workspace_blob_name)
  507. net_arg = caffe2_pb2.Argument()
  508. net_arg.name = "net"
  509. net_arg.n.CopyFrom(gradient_net_def)
  510. ordered_new_outer_names = inputs + outputs
  511. inner_blobs = blob_bindings.keys()
  512. new_outer_blobs_idx = [ordered_new_outer_names.index(blob_bindings[b])
  513. for b in inner_blobs]
  514. inner_blobs_arg = caffe2_pb2.Argument()
  515. inner_blobs_arg.name = "inner_blobs"
  516. inner_blobs_arg.strings.extend([b.encode('utf-8') for b in inner_blobs])
  517. outer_blobs_idx_arg = caffe2_pb2.Argument()
  518. outer_blobs_idx_arg.name = "outer_blobs_idx"
  519. outer_blobs_idx_arg.ints.extend(new_outer_blobs_idx)
  520. saved_blobs_arg = caffe2_pb2.Argument()
  521. saved_blobs_arg.name = "saved_fwd_blobs"
  522. saved_blobs_arg.strings.extend(
  523. [b.encode('utf-8') for b in saved_fwd_blobs])
  524. del gradient_do_def.arg[:]
  525. gradient_do_def.arg.extend([
  526. net_arg, inner_blobs_arg, outer_blobs_idx_arg, saved_blobs_arg])
  527. del gradient_do_def.control_input[:]
  528. gradient_do_def.is_gradient_op = True
  529. return gradient_do_def
  530. def _gen_grad_zero_init_ops(init_grad_map, grad_map, grad_output_names):
  531. grad_init_ops = []
  532. for grad_output in grad_output_names:
  533. # get the corresponding output name blob and use it in ConstantFill
  534. # so that grad_output has the same shape
  535. output_name = None
  536. for o, g in grad_map.items():
  537. if g == grad_output:
  538. output_name = o
  539. break
  540. assert output_name, "Unknown gradient output " + grad_output
  541. grad_init_op = None
  542. # make sure that we do not overwrite existing gradients with zeros
  543. if output_name in init_grad_map:
  544. init_grad_name = init_grad_map[output_name]
  545. # in case we use a different gradient blob name, copy gradient
  546. if init_grad_name != grad_output:
  547. grad_init_op = caffe2_pb2.OperatorDef()
  548. grad_init_op.type = "Copy"
  549. grad_init_op.input.extend([str(init_grad_name)])
  550. grad_init_op.output.extend([str(grad_output)])
  551. else:
  552. grad_init_op = caffe2_pb2.OperatorDef()
  553. grad_init_op.type = "ConstantFill"
  554. grad_init_op.input.extend([output_name])
  555. grad_init_op.output.extend([grad_output])
  556. value_arg = caffe2_pb2.Argument()
  557. value_arg.name = "value"
  558. value_arg.f = 0.0
  559. grad_init_op.arg.extend([value_arg])
  560. if grad_init_op:
  561. grad_init_ops.append(grad_init_op)
  562. return grad_init_ops
  563. def _prepare_gradient_if_op(
  564. fwd_op, input_names, output_names, then_grad_net, else_grad_net):
  565. gradient_if_def = caffe2_pb2.OperatorDef()
  566. gradient_if_def.CopyFrom(fwd_op)
  567. del gradient_if_def.input[:]
  568. gradient_if_def.input.extend(input_names)
  569. del gradient_if_def.output[:]
  570. gradient_if_def.output.extend(output_names)
  571. then_net_arg = caffe2_pb2.Argument()
  572. then_net_arg.name = "then_net"
  573. then_net_arg.n.CopyFrom(then_grad_net)
  574. gradient_args = [then_net_arg]
  575. if else_grad_net:
  576. else_net_arg = caffe2_pb2.Argument()
  577. else_net_arg.name = "else_net"
  578. else_net_arg.n.CopyFrom(else_grad_net)
  579. gradient_args.append(else_net_arg)
  580. del gradient_if_def.arg[:]
  581. gradient_if_def.arg.extend(gradient_args)
  582. if gradient_if_def.name:
  583. gradient_if_def.name += "_grad"
  584. del gradient_if_def.control_input[:]
  585. gradient_if_def.is_gradient_op = True
  586. return gradient_if_def
  587. def disambiguate_grad_if_op_output(grad_op, idx, new_grad_output):
  588. then_net = _get_net_argument(grad_op, "then_net")
  589. old_grad_out_match = grad_op.output[idx]
  590. for op in then_net.op:
  591. for i, out in enumerate(op.output):
  592. if out == old_grad_out_match:
  593. op.output[i] = new_grad_output
  594. else_net = _get_net_argument(grad_op, "else_net")
  595. if else_net:
  596. for op in else_net.op:
  597. for i, out in enumerate(op.output):
  598. if out == old_grad_out_match:
  599. op.output[i] = new_grad_output
  600. grad_op.output[idx] = new_grad_output