crf.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. ## @package crf
  2. # Module caffe2.python.crf
  3. import numpy as np
  4. from caffe2.python import brew, core, model_helper, recurrent
  5. """
  6. Due to a limitation in RecurrentNetworkOp, this layer only supports batch_size=1
  7. In order to support batch_size > 1, we will have to implement the CRFUnit
  8. and its gradient in C++ and handle the different batches there.
  9. """
  10. class CRFWithLoss(object):
  11. def __init__(self, model, num_classes, transitions_blob=None):
  12. self.model = model
  13. self.num_classes = num_classes
  14. self.num_classes_padded = num_classes + 2 # After adding BOS and EOS
  15. if not transitions_blob:
  16. transitions_blob = self.model.param_init_net.UniformFill(
  17. [],
  18. [core.ScopedBlobReference("crf_transitions")],
  19. shape=[self.num_classes_padded, self.num_classes_padded],
  20. min=-1.0,
  21. max=1.0,
  22. )
  23. self.transitions = transitions_blob
  24. self.model.params.append(self.transitions)
  25. def crf_loss(self, predictions, labels, seq_lengths=None):
  26. # Since the transitions matrix is a shared parameter, need to
  27. # take a snapshot of it at the beginning since it can be updated
  28. # in between the operators that uses it when doing parallel updates
  29. transitions_snapshot = self.model.net.Copy(
  30. self.transitions, core.ScopedBlobReference("transitions_snapshot")
  31. )
  32. # Compute best path unary score from the logits
  33. path_unary_score = self._gather_entries_sum(
  34. predictions, labels, self.num_classes
  35. )
  36. # Append BOS and EOS entries to the predictions and labels
  37. predictions = CRFWithLoss.pad_predictions(
  38. predictions, self.model.param_init_net, self.model.net, self.num_classes
  39. )
  40. labels = CRFWithLoss.pad_labels(
  41. labels, self.model.param_init_net, self.model.net, self.num_classes
  42. )
  43. # Compute best path binary scores from the transitions matrix
  44. path_binary_score = self._path_binary_scores(
  45. labels, transitions_snapshot, seq_lengths
  46. )
  47. path_total_score = self.model.net.Add(
  48. [path_binary_score, path_unary_score],
  49. core.ScopedBlobReference("path_total"),
  50. )
  51. # Compute all paths score
  52. zero_index = self.model.param_init_net.ConstantFill([], shape=[1], value=0)
  53. initial_state = self.model.net.Gather(
  54. [predictions, zero_index],
  55. core.ScopedBlobReference("rnn_initial"),
  56. dense_gradient=True,
  57. )
  58. input_data, _ = self.model.net.RemovePadding(
  59. [predictions], padding_width=1, end_padding_width=0, outputs=2
  60. )
  61. input_data = self.model.net.ExpandDims(
  62. [input_data], core.ScopedBlobReference("rnn_input_data"), dims=[1]
  63. )
  64. # Due to a bug in RecurrentNetworkGradientOp, we need to copy the
  65. # transitions blob before sending it to the recurrent network
  66. transitions_copy = self.model.net.Copy(
  67. transitions_snapshot, core.ScopedBlobReference("transitions_copy")
  68. )
  69. all_paths_scores = self._crf_forward(
  70. input_data, initial_state, transitions_copy
  71. )
  72. loss = self.model.net.Sub(
  73. [all_paths_scores, path_total_score], core.ScopedBlobReference("crf_loss")
  74. )
  75. return loss
  76. def _path_binary_scores(self, labels, transitions, seq_lengths=None):
  77. column_ids, _ = self.model.net.RemovePadding(
  78. [labels], outputs=2, padding_width=1, end_padding_width=0
  79. )
  80. row_ids, _ = self.model.net.RemovePadding(
  81. [labels], outputs=2, padding_width=0, end_padding_width=1
  82. )
  83. # Since there is no multi-dimensional gather, I flatten the matrix to
  84. # a 1-d vector and transform the ids to (row_ids * num_columns +
  85. # column_ids) and do gather in 1-d
  86. num_columns_blob = self.model.net.ConstantFill(
  87. [row_ids], value=self.num_classes_padded
  88. )
  89. flattened_ids = self.model.net.Mul([row_ids, num_columns_blob])
  90. flattened_ids = self.model.net.Add([flattened_ids, column_ids])
  91. flattened_transitions = self.model.net.FlattenToVec([transitions])
  92. entries = self.model.net.Gather(
  93. [flattened_transitions, flattened_ids], dense_gradient=True
  94. )
  95. return self.model.ReduceFrontSum(entries)
  96. def _gather_entries_sum(self, in_data, indices, index_size):
  97. indices = self.model.net.Cast([indices], to="int64")
  98. index_size_blob = self.model.param_init_net.ConstantFill(
  99. [], shape=[1], value=index_size
  100. )
  101. query_one_hot = self.model.net.OneHot([indices, index_size_blob])
  102. flattend_query = self.model.net.FlattenToVec(query_one_hot)
  103. flattend_data = self.model.net.FlattenToVec(in_data)
  104. query_scores = self.model.net.DotProduct([flattend_query, flattend_data])
  105. final_sum = self.model.net.ReduceFrontSum([query_scores])
  106. return final_sum
  107. def _crf_forward(
  108. self, input_blob, initial_state, transitions_copy, seq_lengths=None
  109. ):
  110. # Build the RNN net and get the last timestep output
  111. out_last = self.build_crf_net(input_blob, initial_state, transitions_copy)
  112. out_last, _ = self.model.net.Reshape(
  113. [out_last], outputs=2, shape=(self.num_classes_padded,)
  114. )
  115. zero_segment_id = self.model.param_init_net.ConstantFill(
  116. [], value=0, shape=[self.num_classes_padded], dtype=core.DataType.INT32
  117. )
  118. # Compute the accumulated total score of all the paths
  119. accum_score = self.model.net.SortedSegmentRangeLogSumExp(
  120. [out_last, zero_segment_id]
  121. )
  122. accum_score, _ = self.model.net.Reshape(accum_score, outputs=2, shape=())
  123. return accum_score
  124. def build_crf_net(self, input_blob, initial_state, transitions):
  125. """
  126. Adds the crf_net recurrent operator to the model.
  127. model: model_helper.ModelHelper object new operators would be added
  128. to
  129. input_blob: the input sequence in a format T x N x D
  130. where T is sequence size, N - batch size and D - input dimension
  131. ##Only supports batch-size 1##
  132. seq_lengths: blob containing sequence lengths (unused)
  133. """
  134. scope = "crf_net"
  135. def s(name):
  136. ""
  137. # We have to manually scope due to our internal/external blob
  138. # relationships.
  139. return "{}/{}".format(str(scope), str(name))
  140. step_model = model_helper.ModelHelper(name="crf_step", param_model=self.model)
  141. input_t, cell_t_prev, _ = step_model.net.AddExternalInputs(
  142. core.ScopedBlobReference("input_t"),
  143. core.ScopedBlobReference("cell_t_prev"),
  144. transitions,
  145. )
  146. zero_segment_id = step_model.param_init_net.ConstantFill(
  147. [],
  148. [s("zero_segment_id")],
  149. value=0,
  150. shape=[self.num_classes_padded],
  151. dtype=core.DataType.INT32,
  152. )
  153. # A hack to bypass model cloning for test
  154. step_model.param_init_net.AddExternalOutput(zero_segment_id)
  155. """ the CRF step """
  156. # Do tile
  157. prev_transpose = brew.transpose(
  158. step_model, cell_t_prev, [s("prev_transpose")], axes=(0, 2, 1)
  159. )
  160. prev_tiled = step_model.net.Tile(
  161. prev_transpose, [s("prev_tiled")], tiles=self.num_classes_padded, axis=2
  162. )
  163. input_t_tiled = step_model.net.Tile(
  164. input_t, [s("input_t_tiled")], tiles=self.num_classes_padded, axis=1
  165. )
  166. input_with_prev = step_model.net.Add(
  167. [prev_tiled, input_t_tiled], [s("input_with_prev")]
  168. )
  169. all_with_transitions = step_model.net.Add(
  170. [input_with_prev, transitions],
  171. [s("prev_with_transitions")],
  172. broadcast=1,
  173. use_grad_hack=1,
  174. )
  175. all_with_transitions_reshaped, _ = step_model.net.Reshape(
  176. all_with_transitions,
  177. [s("all_with_transitions_reshaped"), s("all_with_transitions_orig")],
  178. shape=(self.num_classes_padded, self.num_classes_padded),
  179. )
  180. cell_t = step_model.net.SortedSegmentRangeLogSumExp(
  181. [all_with_transitions_reshaped, zero_segment_id], [s("cell_t")]
  182. )
  183. step_model.net.AddExternalOutputs(cell_t)
  184. """ recurrent network """
  185. cell_input_blob = initial_state
  186. out_all, out_last = recurrent.recurrent_net(
  187. net=self.model.net,
  188. cell_net=step_model.net,
  189. inputs=[(input_t, input_blob)],
  190. initial_cell_inputs=[(cell_t_prev, cell_input_blob)],
  191. links={cell_t_prev: cell_t},
  192. scope=scope,
  193. outputs_with_grads=(1,),
  194. )
  195. return out_last
  196. def update_predictions(self, classes):
  197. def crf_update_predictions_op(inputs, outputs):
  198. # This operator will compute the best path of classes by performing
  199. # Viterbi decoding and then updates the predictions to make the tag
  200. # On the best path has the highest score among the others
  201. predictions = inputs[0].data
  202. transitions = inputs[1].data
  203. predictions = inputs[0].data
  204. predictions_shape = inputs[0].shape
  205. outputs[0].reshape(predictions_shape)
  206. trellis = np.zeros(predictions_shape)
  207. backpointers = np.zeros(predictions_shape, dtype=np.int32)
  208. trellis[0] = predictions[0]
  209. for t in range(1, predictions_shape[0]):
  210. v = np.expand_dims(trellis[t - 1], 1) + transitions
  211. trellis[t] = predictions[t] + np.max(v, 0)
  212. backpointers[t] = np.argmax(v, 0)
  213. viterbi = [np.argmax(trellis[-1])]
  214. for bp in reversed(backpointers[1:]):
  215. viterbi.append(bp[viterbi[-1]])
  216. viterbi.reverse()
  217. new_predictions = np.zeros(predictions_shape)
  218. old_bests = []
  219. for i, w_predictions in enumerate(predictions):
  220. # Get the current tag with the maximum score
  221. new_predictions[i] = predictions[i]
  222. old_best = np.argmax(w_predictions)
  223. old_bests.append(old_best)
  224. # Swap the scores of the current best tag and the tag on the
  225. # Viterbi path
  226. w_predictions[viterbi[i]], w_predictions[old_best] = (
  227. w_predictions[old_best],
  228. w_predictions[viterbi[i]],
  229. )
  230. new_predictions[i] = w_predictions
  231. # Remove the BOS and EOS entries from the predictions matrix
  232. orig_predictions = new_predictions[1:-1, 0:-2]
  233. outputs[0].reshape(orig_predictions.shape)
  234. outputs[0].data[...] = orig_predictions
  235. padded_classes = CRFWithLoss.pad_predictions(
  236. classes, self.model.param_init_net, self.model.net, self.num_classes
  237. )
  238. new_classes = self.model.net.Python(crf_update_predictions_op)(
  239. [padded_classes, self.transitions],
  240. core.ScopedBlobReference("post_crf_classes"),
  241. )
  242. return new_classes
  243. @staticmethod
  244. def pad_labels(labels, init_net, net, num_classes):
  245. bos_i = num_classes
  246. eos_i = num_classes + 1
  247. bos_i_b = init_net.ConstantFill([], shape=[1], value=bos_i)
  248. eos_i_b = init_net.ConstantFill([], shape=[1], value=eos_i)
  249. labels = net.Cast([labels], to="int64")
  250. padded_labels, _ = net.Concat([bos_i_b, labels, eos_i_b], axis=0, outputs=2)
  251. return padded_labels
  252. @staticmethod
  253. def pad_predictions(predictions, init_net, net, num_classes):
  254. # This function will introduce two labels for beginning of sequence
  255. # And end of sequence, it will make the necessary udpates to the
  256. # the predictions blob
  257. low_score = -1000.0 # An arbitray very low number
  258. b_scores = np.array([[low_score] * num_classes + [0, low_score]]).astype(
  259. np.float32
  260. )
  261. e_scores = np.array([[low_score] * num_classes + [low_score, 0]]).astype(
  262. np.float32
  263. )
  264. b_scores = init_net.GivenTensorFill(
  265. [], "b_scores", shape=[1, num_classes + 2], values=b_scores
  266. )
  267. e_scores = init_net.GivenTensorFill(
  268. [], "e_scores", shape=[1, num_classes + 2], values=e_scores
  269. )
  270. zero_index = net.ConstantFill([], shape=[1], value=0)
  271. length = net.Gather([net.Shape([predictions]), zero_index])
  272. length = net.Cast(length, to="int32")
  273. t_range = net.LengthsRangeFill(length)
  274. padding = net.ConstantFill([t_range], value=low_score)
  275. padding = net.ExpandDims(padding, dims=[1])
  276. padded_predictions, _ = net.Concat(
  277. [predictions, padding, padding], outputs=2, axis=1
  278. )
  279. padded_predictions_concat, _ = net.Concat(
  280. [b_scores, padded_predictions, e_scores], outputs=2, axis=0
  281. )
  282. return padded_predictions_concat