attention.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. ## @package attention
  2. # Module caffe2.python.attention
  3. from caffe2.python import brew
  4. class AttentionType:
  5. Regular, Recurrent, Dot, SoftCoverage = tuple(range(4))
  6. def s(scope, name):
  7. # We have to manually scope due to our internal/external blob
  8. # relationships.
  9. return "{}/{}".format(str(scope), str(name))
  10. # c_i = \sum_j w_{ij}\textbf{s}_j
  11. def _calc_weighted_context(
  12. model,
  13. encoder_outputs_transposed,
  14. encoder_output_dim,
  15. attention_weights_3d,
  16. scope,
  17. ):
  18. # [batch_size, encoder_output_dim, 1]
  19. attention_weighted_encoder_context = brew.batch_mat_mul(
  20. model,
  21. [encoder_outputs_transposed, attention_weights_3d],
  22. s(scope, 'attention_weighted_encoder_context'),
  23. )
  24. # [batch_size, encoder_output_dim]
  25. attention_weighted_encoder_context, _ = model.net.Reshape(
  26. attention_weighted_encoder_context,
  27. [
  28. attention_weighted_encoder_context,
  29. s(scope, 'attention_weighted_encoder_context_old_shape'),
  30. ],
  31. shape=[1, -1, encoder_output_dim],
  32. )
  33. return attention_weighted_encoder_context
  34. # Calculate a softmax over the passed in attention energy logits
  35. def _calc_attention_weights(
  36. model,
  37. attention_logits_transposed,
  38. scope,
  39. encoder_lengths=None,
  40. ):
  41. if encoder_lengths is not None:
  42. attention_logits_transposed = model.net.SequenceMask(
  43. [attention_logits_transposed, encoder_lengths],
  44. ['masked_attention_logits'],
  45. mode='sequence',
  46. )
  47. # [batch_size, encoder_length, 1]
  48. attention_weights_3d = brew.softmax(
  49. model,
  50. attention_logits_transposed,
  51. s(scope, 'attention_weights_3d'),
  52. engine='CUDNN',
  53. axis=1,
  54. )
  55. return attention_weights_3d
  56. # e_{ij} = \textbf{v}^T tanh \alpha(\textbf{h}_{i-1}, \textbf{s}_j)
  57. def _calc_attention_logits_from_sum_match(
  58. model,
  59. decoder_hidden_encoder_outputs_sum,
  60. encoder_output_dim,
  61. scope,
  62. ):
  63. # [encoder_length, batch_size, encoder_output_dim]
  64. decoder_hidden_encoder_outputs_sum = model.net.Tanh(
  65. decoder_hidden_encoder_outputs_sum,
  66. decoder_hidden_encoder_outputs_sum,
  67. )
  68. # [encoder_length, batch_size, 1]
  69. attention_logits = brew.fc(
  70. model,
  71. decoder_hidden_encoder_outputs_sum,
  72. s(scope, 'attention_logits'),
  73. dim_in=encoder_output_dim,
  74. dim_out=1,
  75. axis=2,
  76. freeze_bias=True,
  77. )
  78. # [batch_size, encoder_length, 1]
  79. attention_logits_transposed = brew.transpose(
  80. model,
  81. attention_logits,
  82. s(scope, 'attention_logits_transposed'),
  83. axes=[1, 0, 2],
  84. )
  85. return attention_logits_transposed
  86. # \textbf{W}^\alpha used in the context of \alpha_{sum}(a,b)
  87. def _apply_fc_weight_for_sum_match(
  88. model,
  89. input,
  90. dim_in,
  91. dim_out,
  92. scope,
  93. name,
  94. ):
  95. output = brew.fc(
  96. model,
  97. input,
  98. s(scope, name),
  99. dim_in=dim_in,
  100. dim_out=dim_out,
  101. axis=2,
  102. )
  103. output = model.net.Squeeze(
  104. output,
  105. output,
  106. dims=[0],
  107. )
  108. return output
  109. # Implement RecAtt due to section 4.1 in http://arxiv.org/abs/1601.03317
  110. def apply_recurrent_attention(
  111. model,
  112. encoder_output_dim,
  113. encoder_outputs_transposed,
  114. weighted_encoder_outputs,
  115. decoder_hidden_state_t,
  116. decoder_hidden_state_dim,
  117. attention_weighted_encoder_context_t_prev,
  118. scope,
  119. encoder_lengths=None,
  120. ):
  121. weighted_prev_attention_context = _apply_fc_weight_for_sum_match(
  122. model=model,
  123. input=attention_weighted_encoder_context_t_prev,
  124. dim_in=encoder_output_dim,
  125. dim_out=encoder_output_dim,
  126. scope=scope,
  127. name='weighted_prev_attention_context',
  128. )
  129. weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
  130. model=model,
  131. input=decoder_hidden_state_t,
  132. dim_in=decoder_hidden_state_dim,
  133. dim_out=encoder_output_dim,
  134. scope=scope,
  135. name='weighted_decoder_hidden_state',
  136. )
  137. # [1, batch_size, encoder_output_dim]
  138. decoder_hidden_encoder_outputs_sum_tmp = model.net.Add(
  139. [
  140. weighted_prev_attention_context,
  141. weighted_decoder_hidden_state,
  142. ],
  143. s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
  144. )
  145. # [encoder_length, batch_size, encoder_output_dim]
  146. decoder_hidden_encoder_outputs_sum = model.net.Add(
  147. [
  148. weighted_encoder_outputs,
  149. decoder_hidden_encoder_outputs_sum_tmp,
  150. ],
  151. s(scope, 'decoder_hidden_encoder_outputs_sum'),
  152. broadcast=1,
  153. )
  154. attention_logits_transposed = _calc_attention_logits_from_sum_match(
  155. model=model,
  156. decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
  157. encoder_output_dim=encoder_output_dim,
  158. scope=scope,
  159. )
  160. # [batch_size, encoder_length, 1]
  161. attention_weights_3d = _calc_attention_weights(
  162. model=model,
  163. attention_logits_transposed=attention_logits_transposed,
  164. scope=scope,
  165. encoder_lengths=encoder_lengths,
  166. )
  167. # [batch_size, encoder_output_dim, 1]
  168. attention_weighted_encoder_context = _calc_weighted_context(
  169. model=model,
  170. encoder_outputs_transposed=encoder_outputs_transposed,
  171. encoder_output_dim=encoder_output_dim,
  172. attention_weights_3d=attention_weights_3d,
  173. scope=scope,
  174. )
  175. return attention_weighted_encoder_context, attention_weights_3d, [
  176. decoder_hidden_encoder_outputs_sum,
  177. ]
  178. def apply_regular_attention(
  179. model,
  180. encoder_output_dim,
  181. encoder_outputs_transposed,
  182. weighted_encoder_outputs,
  183. decoder_hidden_state_t,
  184. decoder_hidden_state_dim,
  185. scope,
  186. encoder_lengths=None,
  187. ):
  188. weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
  189. model=model,
  190. input=decoder_hidden_state_t,
  191. dim_in=decoder_hidden_state_dim,
  192. dim_out=encoder_output_dim,
  193. scope=scope,
  194. name='weighted_decoder_hidden_state',
  195. )
  196. # [encoder_length, batch_size, encoder_output_dim]
  197. decoder_hidden_encoder_outputs_sum = model.net.Add(
  198. [weighted_encoder_outputs, weighted_decoder_hidden_state],
  199. s(scope, 'decoder_hidden_encoder_outputs_sum'),
  200. broadcast=1,
  201. use_grad_hack=1,
  202. )
  203. attention_logits_transposed = _calc_attention_logits_from_sum_match(
  204. model=model,
  205. decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
  206. encoder_output_dim=encoder_output_dim,
  207. scope=scope,
  208. )
  209. # [batch_size, encoder_length, 1]
  210. attention_weights_3d = _calc_attention_weights(
  211. model=model,
  212. attention_logits_transposed=attention_logits_transposed,
  213. scope=scope,
  214. encoder_lengths=encoder_lengths,
  215. )
  216. # [batch_size, encoder_output_dim, 1]
  217. attention_weighted_encoder_context = _calc_weighted_context(
  218. model=model,
  219. encoder_outputs_transposed=encoder_outputs_transposed,
  220. encoder_output_dim=encoder_output_dim,
  221. attention_weights_3d=attention_weights_3d,
  222. scope=scope,
  223. )
  224. return attention_weighted_encoder_context, attention_weights_3d, [
  225. decoder_hidden_encoder_outputs_sum,
  226. ]
  227. def apply_dot_attention(
  228. model,
  229. encoder_output_dim,
  230. # [batch_size, encoder_output_dim, encoder_length]
  231. encoder_outputs_transposed,
  232. # [1, batch_size, decoder_state_dim]
  233. decoder_hidden_state_t,
  234. decoder_hidden_state_dim,
  235. scope,
  236. encoder_lengths=None,
  237. ):
  238. if decoder_hidden_state_dim != encoder_output_dim:
  239. weighted_decoder_hidden_state = brew.fc(
  240. model,
  241. decoder_hidden_state_t,
  242. s(scope, 'weighted_decoder_hidden_state'),
  243. dim_in=decoder_hidden_state_dim,
  244. dim_out=encoder_output_dim,
  245. axis=2,
  246. )
  247. else:
  248. weighted_decoder_hidden_state = decoder_hidden_state_t
  249. # [batch_size, decoder_state_dim]
  250. squeezed_weighted_decoder_hidden_state = model.net.Squeeze(
  251. weighted_decoder_hidden_state,
  252. s(scope, 'squeezed_weighted_decoder_hidden_state'),
  253. dims=[0],
  254. )
  255. # [batch_size, decoder_state_dim, 1]
  256. expanddims_squeezed_weighted_decoder_hidden_state = model.net.ExpandDims(
  257. squeezed_weighted_decoder_hidden_state,
  258. squeezed_weighted_decoder_hidden_state,
  259. dims=[2],
  260. )
  261. # [batch_size, encoder_output_dim, 1]
  262. attention_logits_transposed = model.net.BatchMatMul(
  263. [
  264. encoder_outputs_transposed,
  265. expanddims_squeezed_weighted_decoder_hidden_state,
  266. ],
  267. s(scope, 'attention_logits'),
  268. trans_a=1,
  269. )
  270. # [batch_size, encoder_length, 1]
  271. attention_weights_3d = _calc_attention_weights(
  272. model=model,
  273. attention_logits_transposed=attention_logits_transposed,
  274. scope=scope,
  275. encoder_lengths=encoder_lengths,
  276. )
  277. # [batch_size, encoder_output_dim, 1]
  278. attention_weighted_encoder_context = _calc_weighted_context(
  279. model=model,
  280. encoder_outputs_transposed=encoder_outputs_transposed,
  281. encoder_output_dim=encoder_output_dim,
  282. attention_weights_3d=attention_weights_3d,
  283. scope=scope,
  284. )
  285. return attention_weighted_encoder_context, attention_weights_3d, []
  286. def apply_soft_coverage_attention(
  287. model,
  288. encoder_output_dim,
  289. encoder_outputs_transposed,
  290. weighted_encoder_outputs,
  291. decoder_hidden_state_t,
  292. decoder_hidden_state_dim,
  293. scope,
  294. encoder_lengths,
  295. coverage_t_prev,
  296. coverage_weights,
  297. ):
  298. weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
  299. model=model,
  300. input=decoder_hidden_state_t,
  301. dim_in=decoder_hidden_state_dim,
  302. dim_out=encoder_output_dim,
  303. scope=scope,
  304. name='weighted_decoder_hidden_state',
  305. )
  306. # [encoder_length, batch_size, encoder_output_dim]
  307. decoder_hidden_encoder_outputs_sum_tmp = model.net.Add(
  308. [weighted_encoder_outputs, weighted_decoder_hidden_state],
  309. s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
  310. broadcast=1,
  311. )
  312. # [batch_size, encoder_length]
  313. coverage_t_prev_2d = model.net.Squeeze(
  314. coverage_t_prev,
  315. s(scope, 'coverage_t_prev_2d'),
  316. dims=[0],
  317. )
  318. # [encoder_length, batch_size]
  319. coverage_t_prev_transposed = brew.transpose(
  320. model,
  321. coverage_t_prev_2d,
  322. s(scope, 'coverage_t_prev_transposed'),
  323. )
  324. # [encoder_length, batch_size, encoder_output_dim]
  325. scaled_coverage_weights = model.net.Mul(
  326. [coverage_weights, coverage_t_prev_transposed],
  327. s(scope, 'scaled_coverage_weights'),
  328. broadcast=1,
  329. axis=0,
  330. )
  331. # [encoder_length, batch_size, encoder_output_dim]
  332. decoder_hidden_encoder_outputs_sum = model.net.Add(
  333. [decoder_hidden_encoder_outputs_sum_tmp, scaled_coverage_weights],
  334. s(scope, 'decoder_hidden_encoder_outputs_sum'),
  335. )
  336. # [batch_size, encoder_length, 1]
  337. attention_logits_transposed = _calc_attention_logits_from_sum_match(
  338. model=model,
  339. decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
  340. encoder_output_dim=encoder_output_dim,
  341. scope=scope,
  342. )
  343. # [batch_size, encoder_length, 1]
  344. attention_weights_3d = _calc_attention_weights(
  345. model=model,
  346. attention_logits_transposed=attention_logits_transposed,
  347. scope=scope,
  348. encoder_lengths=encoder_lengths,
  349. )
  350. # [batch_size, encoder_output_dim, 1]
  351. attention_weighted_encoder_context = _calc_weighted_context(
  352. model=model,
  353. encoder_outputs_transposed=encoder_outputs_transposed,
  354. encoder_output_dim=encoder_output_dim,
  355. attention_weights_3d=attention_weights_3d,
  356. scope=scope,
  357. )
  358. # [batch_size, encoder_length]
  359. attention_weights_2d = model.net.Squeeze(
  360. attention_weights_3d,
  361. s(scope, 'attention_weights_2d'),
  362. dims=[2],
  363. )
  364. coverage_t = model.net.Add(
  365. [coverage_t_prev, attention_weights_2d],
  366. s(scope, 'coverage_t'),
  367. broadcast=1,
  368. )
  369. return (
  370. attention_weighted_encoder_context,
  371. attention_weights_3d,
  372. [decoder_hidden_encoder_outputs_sum],
  373. coverage_t,
  374. )