recurrent.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. ## @package recurrent
  2. # Module caffe2.python.recurrent
  3. from caffe2.python import core, workspace
  4. from future.utils import viewitems, viewkeys
  def recurrent_net(
          net, cell_net, inputs, initial_cell_inputs,
          links, timestep=None, scope=None, outputs_with_grads=(0,),
          recompute_blobs_on_backward=None, forward_only=False,
  ):
      '''
      Add a RecurrentNetwork operator to `net` that runs `cell_net` once per
      timestep and, unless `forward_only` is set, also builds the matching
      backward step net.

      net: the main net operator should be added to

      cell_net: cell_net which is executed in a recurrent fashion

      inputs: sequences to be fed into the recurrent net. Currently only one
      input is supported. It has to be in a format T x N x (D1...Dk) where T is
      the length of the sequence, N is a batch size and (D1...Dk) are the rest
      of the dimensions. Each entry is a pair
      (cell_net_input_name, external_blob_with_data).

      initial_cell_inputs: inputs of the cell_net for the 0 timestamp.
      Format for each input is:
          (cell_net_input_name, external_blob_with_data)

      links: a dictionary from cell_net input names in moment t+1 and
      output names of moment t. Currently we assume that each output becomes
      an input for the next timestep.

      timestep: name of the timestep blob to be used. If not provided
      "timestep" is used.

      scope: Internal blobs are going to be scoped in a format
      <scope_name>/<blob_name>
      If not provided we generate a scope name automatically

      outputs_with_grads : position indices of output blobs which will receive
      error gradient (from outside recurrent network) during backpropagation

      recompute_blobs_on_backward: specify a list of blobs that will be
      recomputed for backward pass, and thus need not to be
      stored for each forward timestep.

      forward_only: if True, only forward steps are executed

      Returns the outputs of the RecurrentNetwork operator without the trailing
      step-workspaces blob, i.e. for every entry of initial_cell_inputs the
      pair "<cell_output>_all", "<cell_output>_last".

      Raises AssertionError if more than one input sequence is given, or if an
      operator selected for recomputation produces outputs that are not listed
      in recompute_blobs_on_backward.
      '''
      assert len(inputs) == 1, "Only one input blob is supported so far"

      input_blobs = [str(i[0]) for i in inputs]
      initial_input_blobs = [str(x[1]) for x in initial_cell_inputs]
      op_name = net.NextName('recurrent')

      def s(name):
          # We have to manually scope due to our internal/external blob
          # relationships.
          scope_name = op_name if scope is None else scope
          return "{}/{}".format(str(scope_name), str(name))

      # determine inputs that are considered to be references
      # it is those that are not referred to in inputs or initial_cell_inputs
      known_inputs = [str(b) for b in input_blobs + initial_input_blobs]
      known_inputs += [str(x[0]) for x in initial_cell_inputs]
      if timestep is not None:
          known_inputs.append(str(timestep))
      references = [
          core.BlobReference(b) for b in cell_net.Proto().external_input
          if b not in known_inputs]

      inner_outputs = list(cell_net.Proto().external_output)
      # These gradients are expected to be available during the backward pass
      inner_outputs_map = {o: o + '_grad' for o in inner_outputs}

      # compute the backward pass of the cell net
      if not forward_only:
          backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
              cell_net.Proto().op, inner_outputs_map)
          # Key the mapping by plain strings so it can be indexed with the
          # str() names used throughout the rest of this function.
          backward_mapping = {str(k): v for k, v in viewitems(backward_mapping)}

          backward_cell_net = core.Net("RecurrentBackwardStep")
          del backward_cell_net.Proto().op[:]

          if recompute_blobs_on_backward is not None:
              # Insert operators to re-compute the specified blobs.
              # They are added in the same order as for the forward pass, thus
              # the order is correct.
              recompute_blobs_on_backward = {str(b) for b in
                                             recompute_blobs_on_backward}

              for op in cell_net.Proto().op:
                  if not recompute_blobs_on_backward.isdisjoint(set(op.output)):
                      backward_cell_net.Proto().op.extend([op])
                      # This fires if other outputs than the declared
                      # are computed by the ops that are recomputed
                      assert set(op.output).issubset(recompute_blobs_on_backward)

          backward_cell_net.Proto().op.extend(backward_ops)
          # compute blobs used but not defined in the backward pass
          backward_ssa, backward_blob_versions = core.get_ssa(
              backward_cell_net.Proto())
          undefined = core.get_undefined_blobs(backward_ssa)

          # also add to the output list the intermediate outputs of fwd_step that
          # are used by backward.
          _, blob_versions = core.get_ssa(cell_net.Proto())
          scratches = [
              blob
              for blob, ver in viewitems(blob_versions)
              if (ver > 0 and
                  blob in undefined and
                  blob not in cell_net.Proto().external_output)
          ]
          backward_cell_net.Proto().external_input.extend(scratches)
          backward_cell_net.Proto().type = 'simple'
      else:
          backward_cell_net = None

      all_inputs = [i[1] for i in inputs] + [
          x[1] for x in initial_cell_inputs] + references
      all_outputs = []

      cell_net.Proto().type = 'simple'

      # Internal arguments used by RecurrentNetwork operator

      # Links are in the format blob_name, recurrent_states, offset.
      # In the moment t we know that corresponding data block is at
      # t + offset position in the recurrent_states tensor
      forward_links = []
      backward_links = []

      # Aliases are used to expose outputs to external world
      # Format (internal_blob, external_blob, offset)
      # Negative offset stands for going from the end,
      # positive - from the beginning
      aliases = []

      # States held inputs to the cell net
      recurrent_states = []

      for cell_input, _ in initial_cell_inputs:
          cell_input = str(cell_input)
          # Recurrent_states is going to be (T + 1) x ...
          # It stores all inputs and outputs of the cell net over time.
          # Or their gradients in the case of the backward pass.
          state = s(cell_input + "_states")
          states_grad = state + "_grad"
          cell_output = links[str(cell_input)]
          forward_links.append((cell_input, state, 0))
          forward_links.append((cell_output, state, 1))

          aliases.append((state, cell_output + "_all", 1))
          aliases.append((state, cell_output + "_last", -1))
          all_outputs.extend([cell_output + "_all", cell_output + "_last"])

          recurrent_states.append(state)

          if backward_cell_net is not None:
              backward_links.append((cell_output + "_grad", states_grad, 1))
              backward_cell_net.Proto().external_input.append(
                  str(cell_output) + "_grad")

              recurrent_input_grad = cell_input + "_grad"
              if not backward_blob_versions.get(recurrent_input_grad, 0):
                  # If nobody writes to this recurrent input gradient, we need
                  # to make sure it gets to the states grad blob after all.
                  # We do this by using backward_links which triggers an alias
                  # This logic is being used for example in a SumOp case
                  backward_links.append(
                      (backward_mapping[cell_input], states_grad, 0))
              else:
                  backward_links.append((recurrent_input_grad, states_grad, 0))

      for input_t, input_blob in inputs:
          forward_links.append((str(input_t), str(input_blob), 0))

      if backward_cell_net is not None:
          for input_t, input_blob in inputs:
              backward_links.append((
                  backward_mapping[str(input_t)], str(input_blob) + "_grad", 0
              ))
          backward_cell_net.Proto().external_input.extend(
              cell_net.Proto().external_input)
          backward_cell_net.Proto().external_input.extend(
              cell_net.Proto().external_output)

      def unpack_triple(x):
          # Transpose a list of 3-tuples into three parallel sequences; an
          # empty list yields three empty lists.
          if x:
              a, b, c = zip(*x)
              return a, b, c
          return [], [], []

      # Splitting to separate lists so we can pass them to c++
      # where we assemble them back
      link_internal, link_external, link_offset = unpack_triple(forward_links)
      alias_src, alias_dst, alias_offset = unpack_triple(aliases)

      recurrent_inputs = [str(x[1]) for x in initial_cell_inputs]

      # Make sure that recurrent gradients accumulate with internal gradients
      # (if a blob in the backward_cell_net receives gradient from both an
      # external connection as well as from within the backward_cell_net,
      # those gradients need to be added together, rather than one overwriting
      # the other)
      if backward_cell_net is not None:
          proto = backward_cell_net.Proto()
          operators = []
          while len(proto.op) > 0:
              op = proto.op[-1]
              # Delete by position, not by value: remove(op) drops the FIRST
              # op that compares equal, which would reorder a net that
              # contains two identical operators.
              del proto.op[-1]
              operators.append(op)
          # `operators` was collected back-to-front; replay it in the original
          # order, inserting a Sum after every op that needs accumulation.
          for op in operators[::-1]:
              proto.op.extend([op])
              for j, output_blob in enumerate(op.output):
                  if output_blob in proto.external_input:
                      # In place operation won't cause issues because it takes
                      # existing value of a blob into account
                      if output_blob in op.input:
                          continue
                      output_blob = core.BlobReference(output_blob)
                      accum_blob = output_blob + "_accum"
                      # Redirect the op (the copy now stored in the net) to
                      # write into the accumulator blob instead.
                      proto.op[-1].output[j] = str(accum_blob)
                      backward_cell_net.Sum(
                          [output_blob, accum_blob],
                          [output_blob],
                      )

      backward_args = {}
      if backward_cell_net is not None:
          backward_mapping_keys = set(viewkeys(backward_mapping))
          backward_link_internal, backward_link_external, backward_link_offset = \
              unpack_triple(backward_links)
          params = [x for x in references if x in backward_mapping_keys]
          param_grads = [
              str(backward_mapping[x])
              for x in references
              if x in backward_mapping_keys
          ]
          if recompute_blobs_on_backward is None:
              recompute_blobs_on_backward = set()
          backward_args = {
              'param': [all_inputs.index(p) for p in params],
              'backward_link_internal': [str(l) for l in backward_link_internal],
              'backward_link_external': [str(l) for l in backward_link_external],
              'backward_link_offset': backward_link_offset,
              'outputs_with_grads': outputs_with_grads,
              'recompute_blobs_on_backward': [
                  str(b) for b in recompute_blobs_on_backward
              ],
              'param_grads': param_grads,
          }
          if len(backward_cell_net.Proto().op) != 0:
              backward_args['backward_step_net'] = backward_cell_net.Proto()

      results = net.RecurrentNetwork(
          all_inputs,
          all_outputs + [s("step_workspaces")],
          alias_src=alias_src,
          alias_dst=[str(a) for a in alias_dst],
          alias_offset=alias_offset,
          recurrent_states=recurrent_states,
          initial_recurrent_state_ids=[
              all_inputs.index(i) for i in recurrent_inputs
          ],
          link_internal=[str(l) for l in link_internal],
          link_external=[str(l) for l in link_external],
          link_offset=link_offset,
          enable_rnn_executor=1,
          step_net=cell_net.Proto(),
          timestep="timestep" if timestep is None else str(timestep),
          **backward_args
      )

      # Restore net type since 'rnn' is not recognized outside RNNs
      # NOTE(review): the type was already set to 'simple' above and nothing
      # visible here changes it in between -- confirm whether this is still
      # needed.
      cell_net.Proto().type = 'simple'

      # The last output is a list of step workspaces,
      # which is only needed internally for gradient propagation
      return results[:-1]
  def set_rnn_executor_config(rnn_op, num_threads=None, max_cuda_streams=None):
      '''
      Attach RNN executor settings to a recurrent network operator.

      For every setting that is not None, an integer Argument named
      "rnn_executor.<setting>" is appended to rnn_op.arg. num_threads is
      appended before max_cuda_streams.

      rnn_op: operator whose type must be 'RecurrentNetwork' or
      'RecurrentNetworkGradient' (checked with an assert)
      num_threads: value for "rnn_executor.num_threads", skipped when None
      max_cuda_streams: value for "rnn_executor.max_cuda_streams", skipped
      when None
      '''
      from caffe2.proto import caffe2_pb2
      assert rnn_op.type in {'RecurrentNetwork', 'RecurrentNetworkGradient'}

      settings = (
          ('num_threads', num_threads),
          ('max_cuda_streams', max_cuda_streams),
      )
      for key, value in settings:
          if value is None:
              continue
          argument = caffe2_pb2.Argument()
          argument.name = "rnn_executor." + key
          argument.i = value
          rnn_op.arg.extend([argument])
  def retrieve_step_blobs(net, prefix='rnn'):
      '''
      Copy the per-timestep blobs of every RecurrentNetwork operator in `net`
      from its step workspaces into the global workspace, so the intermediate
      recurrent computation can be inspected from python.

      For the k-th RecurrentNetwork operator (counting from 1) a
      RecurrentNetworkBlobFetcher operator is run once on the operator's last
      output, writing to the blob "<prefix>_<k>". That blob is then fetched and
      its contents (presumably the names of the extracted blobs) are added to
      the returned list.

      net: the net from which the step workspace blobs should be extracted

      prefix: prefix used to build the result blob names; it is also passed
      as the `prefix` argument of the fetcher operator

      Returns a flat list with the contents of all fetched result blobs.
      '''
      fetched = []
      rnn_ops = (o for o in net.Proto().op if o.type == "RecurrentNetwork")
      for index, rnn_op in enumerate(rnn_ops, 1):
          result_blob = prefix + "_" + str(index)
          # The step workspaces are always the last output of the operator.
          step_workspaces_blob = rnn_op.output[-1]
          fetcher = core.CreateOperator(
              "RecurrentNetworkBlobFetcher",
              [step_workspaces_blob],
              [result_blob],
              prefix=prefix
          )
          workspace.RunOperatorOnce(fetcher)
          fetched.extend(workspace.FetchBlob(result_blob).tolist())
      return fetched