translate.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655
  1. ## @package translate
  2. # Module caffe2.python.models.seq2seq.translate
  3. from abc import ABCMeta, abstractmethod
  4. import argparse
  5. from future.utils import viewitems
  6. import logging
  7. import numpy as np
  8. import sys
  9. from caffe2.python import core, rnn_cell, workspace
  10. from caffe2.python.models.seq2seq.beam_search import BeamSearchForwardOnly
  11. from caffe2.python.models.seq2seq.seq2seq_model_helper import Seq2SeqModelHelper
  12. import caffe2.python.models.seq2seq.seq2seq_util as seq2seq_util
  13. logger = logging.getLogger(__name__)
  14. logger.setLevel(logging.INFO)
  15. logger.addHandler(logging.StreamHandler(sys.stderr))
  16. def _weighted_sum(model, values, weight, output_name):
  17. values_weights = zip(values, [weight] * len(values))
  18. values_weights_flattened = [x for v_w in values_weights for x in v_w]
  19. return model.net.WeightedSum(
  20. values_weights_flattened,
  21. output_name,
  22. )
  23. class Seq2SeqModelCaffe2EnsembleDecoderBase(metaclass=ABCMeta):
  24. @abstractmethod
  25. def get_model_file(self, model):
  26. pass
  27. @abstractmethod
  28. def get_db_type(self):
  29. pass
  30. def build_word_rewards(self, vocab_size, word_reward, unk_reward):
  31. word_rewards = np.full([vocab_size], word_reward, dtype=np.float32)
  32. word_rewards[seq2seq_util.PAD_ID] = 0
  33. word_rewards[seq2seq_util.GO_ID] = 0
  34. word_rewards[seq2seq_util.EOS_ID] = 0
  35. word_rewards[seq2seq_util.UNK_ID] = word_reward + unk_reward
  36. return word_rewards
  37. def load_models(self):
  38. db_reader = 'reader'
  39. for model, scope_name in zip(
  40. self.models,
  41. self.decoder_scope_names,
  42. ):
  43. params_for_current_model = [
  44. param
  45. for param in self.model.GetAllParams()
  46. if str(param).startswith(scope_name)
  47. ]
  48. assert workspace.RunOperatorOnce(core.CreateOperator(
  49. 'CreateDB',
  50. [], [db_reader],
  51. db=self.get_model_file(model),
  52. db_type=self.get_db_type())
  53. ), 'Failed to create db {}'.format(self.get_model_file(model))
  54. assert workspace.RunOperatorOnce(core.CreateOperator(
  55. 'Load',
  56. [db_reader],
  57. params_for_current_model,
  58. load_all=1,
  59. add_prefix=scope_name + '/',
  60. strip_prefix='gpu_0/',
  61. ))
  62. logger.info('Model {} is loaded from a checkpoint {}'.format(
  63. scope_name, self.get_model_file(model)))
class Seq2SeqModelCaffe2EnsembleDecoder(Seq2SeqModelCaffe2EnsembleDecoderBase):
    """Beam-search decoder over an ensemble of seq2seq models.

    A single Caffe2 graph is built in which every ensemble member lives under
    its own name scope ('model0', 'model1', ...). At each beam-search step
    the members' log-probabilities and attention weights are averaged with
    equal weights (see `_weighted_sum` calls in `__init__`). Checkpoints are
    read from each member's 'model_file' as 'minidb'.
    """

    def get_model_file(self, model):
        """Return the checkpoint path stored under the 'model_file' key."""
        return model['model_file']

    def get_db_type(self):
        """Checkpoints for this decoder are stored as 'minidb'."""
        return 'minidb'

    def scope(self, scope_name, blob_name):
        """Return 'scope_name/blob_name', or `blob_name` if scope_name is None.

        NOTE(review): not referenced elsewhere in this class.
        """
        return (
            scope_name + '/' + blob_name
            if scope_name is not None
            else blob_name
        )

    def _build_decoder(
        self,
        model,
        step_model,
        model_params,
        scope,
        previous_tokens,
        timestep,
        fake_seq_lengths,
    ):
        """Build one ensemble member's encoder and single-step decoder.

        Encoder ops (and their beam-size tiling) are added to `model`; the
        per-step decoder ops are added to `step_model`. Blobs are created
        under the name scope `scope`.

        Args:
            model: model helper holding the encoder graph and parameters.
            step_model: beam-search step model the decoder step is added to.
            model_params: dict with keys 'attention' ('none' or 'regular'),
                'encoder_embedding_size', 'encoder_type',
                'decoder_layer_configs', 'decoder_embedding_size',
                'decoder_softmax_size'.
            scope: name scope for this member, e.g. 'model0'.
            previous_tokens: blob of tokens chosen at the previous step.
            timestep: blob holding the current step index.
            fake_seq_lengths: sequence-length blob passed to the decoder cell.

        Returns:
            (state_configs, output_log_probs, attention_weights): the
            BeamSearchForwardOnly.StateConfig list linking each recurrent
            state to its previous-step blob, the per-step log-probability
            blob, and the per-step attention-weights blob (all-zero
            placeholder when attention is disabled).
        """
        attention_type = model_params['attention']
        assert attention_type in ['none', 'regular']
        use_attention = (attention_type != 'none')

        with core.NameScope(scope):
            encoder_embeddings = seq2seq_util.build_embeddings(
                model=model,
                vocab_size=self.source_vocab_size,
                embedding_size=model_params['encoder_embedding_size'],
                name='encoder_embeddings',
                freeze_embeddings=False,
            )

        (
            encoder_outputs,
            weighted_encoder_outputs,
            final_encoder_hidden_states,
            final_encoder_cell_states,
            encoder_units_per_layer,
        ) = seq2seq_util.build_embedding_encoder(
            model=model,
            encoder_params=model_params['encoder_type'],
            num_decoder_layers=len(model_params['decoder_layer_configs']),
            inputs=self.encoder_inputs,
            input_lengths=self.encoder_lengths,
            vocab_size=self.source_vocab_size,
            embeddings=encoder_embeddings,
            embedding_size=model_params['encoder_embedding_size'],
            use_attention=use_attention,
            num_gpus=0,
            forward_only=True,
            scope=scope,
        )
        with core.NameScope(scope):
            # The encoder runs once per source sentence; its outputs are
            # replicated beam_size times along axis 1 so that every beam
            # hypothesis has its own copy.
            if use_attention:
                # [max_source_length, beam_size, encoder_output_dim]
                encoder_outputs = model.net.Tile(
                    encoder_outputs,
                    'encoder_outputs_tiled',
                    tiles=self.beam_size,
                    axis=1,
                )

            if weighted_encoder_outputs is not None:
                weighted_encoder_outputs = model.net.Tile(
                    weighted_encoder_outputs,
                    'weighted_encoder_outputs_tiled',
                    tiles=self.beam_size,
                    axis=1,
                )

            decoder_embeddings = seq2seq_util.build_embeddings(
                model=model,
                vocab_size=self.target_vocab_size,
                embedding_size=model_params['decoder_embedding_size'],
                name='decoder_embeddings',
                freeze_embeddings=False,
            )
            # Look up embeddings of the tokens picked at the previous step.
            embedded_tokens_t_prev = step_model.net.Gather(
                [decoder_embeddings, previous_tokens],
                'embedded_tokens_t_prev',
            )

        # One LSTM cell per decoder layer; layer 0 consumes the embedding,
        # every later layer consumes the previous layer's output.
        decoder_cells = []
        decoder_units_per_layer = []
        for i, layer_config in enumerate(model_params['decoder_layer_configs']):
            num_units = layer_config['num_units']
            decoder_units_per_layer.append(num_units)
            if i == 0:
                input_size = model_params['decoder_embedding_size']
            else:
                input_size = (
                    model_params['decoder_layer_configs'][i - 1]['num_units']
                )

            cell = rnn_cell.LSTMCell(
                forward_only=True,
                input_size=input_size,
                hidden_size=num_units,
                forget_bias=0.0,
                memory_optimization=False,
            )
            decoder_cells.append(cell)

        with core.NameScope(scope):
            # Tile final encoder states beam_size times along axis 1 so they
            # can seed one decoder state per hypothesis.
            if final_encoder_hidden_states is not None:
                for i in range(len(final_encoder_hidden_states)):
                    if final_encoder_hidden_states[i] is not None:
                        final_encoder_hidden_states[i] = model.net.Tile(
                            final_encoder_hidden_states[i],
                            'final_encoder_hidden_tiled_{}'.format(i),
                            tiles=self.beam_size,
                            axis=1,
                        )
            if final_encoder_cell_states is not None:
                for i in range(len(final_encoder_cell_states)):
                    if final_encoder_cell_states[i] is not None:
                        final_encoder_cell_states[i] = model.net.Tile(
                            final_encoder_cell_states[i],
                            'final_encoder_cell_tiled_{}'.format(i),
                            tiles=self.beam_size,
                            axis=1,
                        )
            initial_states = \
                seq2seq_util.build_initial_rnn_decoder_states(
                    model=model,
                    encoder_units_per_layer=encoder_units_per_layer,
                    decoder_units_per_layer=decoder_units_per_layer,
                    final_encoder_hidden_states=final_encoder_hidden_states,
                    final_encoder_cell_states=final_encoder_cell_states,
                    use_attention=use_attention,
                )

        attention_decoder = seq2seq_util.LSTMWithAttentionDecoder(
            encoder_outputs=encoder_outputs,
            encoder_output_dim=encoder_units_per_layer[-1],
            encoder_lengths=None,
            vocab_size=self.target_vocab_size,
            attention_type=attention_type,
            embedding_size=model_params['decoder_embedding_size'],
            decoder_num_units=decoder_units_per_layer[-1],
            decoder_cells=decoder_cells,
            weighted_encoder_outputs=weighted_encoder_outputs,
            name=scope,
        )
        # Previous-step recurrent states enter the step net as external
        # inputs named '<scope>/<state>_prev'; the names are built explicitly
        # with the scope, so this is outside core.NameScope.
        states_prev = step_model.net.AddExternalInputs(*[
            '{}/{}_prev'.format(scope, s)
            for s in attention_decoder.get_state_names()
        ])
        decoder_outputs, states = attention_decoder.apply(
            model=step_model,
            input_t=embedded_tokens_t_prev,
            seq_lengths=fake_seq_lengths,
            states=states_prev,
            timestep=timestep,
        )

        # Tell the beam search how each recurrent state is initialised and
        # how its previous-step blob (offset 0) links to the blob produced
        # by this step (offset 1).
        state_configs = [
            BeamSearchForwardOnly.StateConfig(
                initial_value=initial_state,
                state_prev_link=BeamSearchForwardOnly.LinkConfig(
                    blob=state_prev,
                    offset=0,
                    window=1,
                ),
                state_link=BeamSearchForwardOnly.LinkConfig(
                    blob=state,
                    offset=1,
                    window=1,
                ),
            )
            for initial_state, state_prev, state in zip(
                initial_states,
                states_prev,
                states,
            )
        ]

        with core.NameScope(scope):
            # Collapse decoder outputs to 2-D [-1, output_dim] before the
            # vocabulary projection.
            decoder_outputs_flattened, _ = step_model.net.Reshape(
                [decoder_outputs],
                [
                    'decoder_outputs_flattened',
                    'decoder_outputs_and_contexts_combination_old_shape',
                ],
                shape=[-1, attention_decoder.get_output_dim()],
            )
            output_logits = seq2seq_util.output_projection(
                model=step_model,
                decoder_outputs=decoder_outputs_flattened,
                decoder_output_size=attention_decoder.get_output_dim(),
                target_vocab_size=self.target_vocab_size,
                decoder_softmax_size=model_params['decoder_softmax_size'],
            )
            # NOTE(review): previously annotated [1, beam_size,
            # target_vocab_size]; since the input was flattened to 2-D above,
            # presumably [beam_size, target_vocab_size] — confirm.
            output_probs = step_model.net.Softmax(
                output_logits,
                'output_probs',
            )
            output_log_probs = step_model.net.Log(
                output_probs,
                'output_log_probs',
            )
            if use_attention:
                attention_weights = attention_decoder.get_attention_weights()
            else:
                # No attention: build an all-zero placeholder from the shape
                # of encoder_inputs, transposed and tiled beam_size times on
                # axis 0, so this member still contributes a blob to the
                # ensemble attention average.
                attention_weights = step_model.net.ConstantFill(
                    [self.encoder_inputs],
                    'zero_attention_weights_tmp_1',
                    value=0.0,
                )
                attention_weights = step_model.net.Transpose(
                    attention_weights,
                    'zero_attention_weights_tmp_2',
                )
                attention_weights = step_model.net.Tile(
                    attention_weights,
                    'zero_attention_weights_tmp',
                    tiles=self.beam_size,
                    axis=0,
                )

        return (
            state_configs,
            output_log_probs,
            attention_weights,
        )

    def __init__(
        self,
        translate_params,
    ):
        """Build the ensemble beam-search graph and initialise its params.

        Args:
            translate_params: dict with
                'ensemble_models': non-empty list of dicts, each with
                    'source_vocab', 'target_vocab' (identical across all
                    members), 'model_params' and 'model_file';
                'decoding_params': dict with 'beam_size', 'word_reward'
                    and 'unk_reward'.

        Side effects: runs the param-init net, feeds the 'word_rewards' blob
        and creates the decoding net in the global Caffe2 workspace.
        Weights are not loaded here; call `load_models()` afterwards.
        """
        self.models = translate_params['ensemble_models']
        decoding_params = translate_params['decoding_params']
        self.beam_size = decoding_params['beam_size']

        # All ensemble members must share the same vocabularies, since their
        # output distributions are averaged token-by-token.
        assert len(self.models) > 0
        source_vocab = self.models[0]['source_vocab']
        target_vocab = self.models[0]['target_vocab']
        for model in self.models:
            assert model['source_vocab'] == source_vocab
            assert model['target_vocab'] == target_vocab

        self.source_vocab_size = len(source_vocab)
        self.target_vocab_size = len(target_vocab)

        self.decoder_scope_names = [
            'model{}'.format(i) for i in range(len(self.models))
        ]

        self.model = Seq2SeqModelHelper(init_params=True)

        self.encoder_inputs = self.model.net.AddExternalInput('encoder_inputs')
        self.encoder_lengths = self.model.net.AddExternalInput(
            'encoder_lengths'
        )
        self.max_output_seq_len = self.model.net.AddExternalInput(
            'max_output_seq_len'
        )

        # Constant, very large per-hypothesis sequence length handed to the
        # decoder step; presumably so no hypothesis is ever cut off by
        # length inside the cell — confirm.
        fake_seq_lengths = self.model.param_init_net.ConstantFill(
            [],
            'fake_seq_lengths',
            shape=[self.beam_size],
            value=100000,
            dtype=core.DataType.INT32,
        )

        beam_decoder = BeamSearchForwardOnly(
            beam_size=self.beam_size,
            model=self.model,
            go_token_id=seq2seq_util.GO_ID,
            eos_token_id=seq2seq_util.EOS_ID,
        )
        step_model = beam_decoder.get_step_model()

        state_configs = []
        output_log_probs = []
        attention_weights = []
        for model, scope_name in zip(
            self.models,
            self.decoder_scope_names,
        ):
            (
                state_configs_per_decoder,
                output_log_probs_per_decoder,
                attention_weights_per_decoder,
            ) = self._build_decoder(
                model=self.model,
                step_model=step_model,
                model_params=model['model_params'],
                scope=scope_name,
                previous_tokens=beam_decoder.get_previous_tokens(),
                timestep=beam_decoder.get_timestep(),
                fake_seq_lengths=fake_seq_lengths,
            )
            state_configs.extend(state_configs_per_decoder)
            output_log_probs.append(output_log_probs_per_decoder)
            if attention_weights_per_decoder is not None:
                attention_weights.append(attention_weights_per_decoder)

        # Equal-weight average of the members' attention weights.
        assert len(attention_weights) > 0
        num_decoders_with_attention_blob = (
            self.model.param_init_net.ConstantFill(
                [],
                'num_decoders_with_attention_blob',
                value=1 / float(len(attention_weights)),
                shape=[1],
            )
        )
        # [beam_size, encoder_length, 1]
        attention_weights_average = _weighted_sum(
            model=step_model,
            values=attention_weights,
            weight=num_decoders_with_attention_blob,
            output_name='attention_weights_average',
        )

        # Equal-weight average of the members' log-probabilities.
        num_decoders_blob = self.model.param_init_net.ConstantFill(
            [],
            'num_decoders_blob',
            value=1 / float(len(output_log_probs)),
            shape=[1],
        )
        # [beam_size, target_vocab_size]
        output_log_probs_average = _weighted_sum(
            model=step_model,
            values=output_log_probs,
            weight=num_decoders_blob,
            output_name='output_log_probs_average',
        )
        # Zero-filled placeholder declared in the init net; the real
        # per-token rewards are fed into 'word_rewards' after the init net
        # runs below.
        word_rewards = self.model.param_init_net.ConstantFill(
            [],
            'word_rewards',
            shape=[self.target_vocab_size],
            value=0.0,
            dtype=core.DataType.FLOAT,
        )
        (
            self.output_token_beam_list,
            self.output_prev_index_beam_list,
            self.output_score_beam_list,
            self.output_attention_weights_beam_list,
        ) = beam_decoder.apply(
            inputs=self.encoder_inputs,
            length=self.max_output_seq_len,
            log_probs=output_log_probs_average,
            attentions=attention_weights_average,
            state_configs=state_configs,
            data_dependencies=[],
            word_rewards=word_rewards,
        )

        workspace.RunNetOnce(self.model.param_init_net)
        workspace.FeedBlob(
            'word_rewards',
            self.build_word_rewards(
                vocab_size=self.target_vocab_size,
                word_reward=translate_params['decoding_params']['word_reward'],
                unk_reward=translate_params['decoding_params']['unk_reward'],
            )
        )

        workspace.CreateNet(
            self.model.net,
            input_blobs=[
                str(self.encoder_inputs),
                str(self.encoder_lengths),
                str(self.max_output_seq_len),
            ],
        )

        logger.info('Params created: ')
        for param in self.model.params:
            logger.info(param)

    def decode(self, numberized_input, max_output_seq_len):
        """Beam-search translate one numberized source sentence.

        Args:
            numberized_input: sequence of source token ids.
            max_output_seq_len: maximum number of decoding steps.

        Returns:
            (output, attention_weights_per_token, best_score): the target
            token ids of the best hypothesis; for each output token, its
            attention weights in original source order, trimmed to
            len(numberized_input); and the NEGATED beam score of that
            hypothesis (presumably so callers can treat it as a cost —
            confirm).
        """
        # The source is fed reversed, one token per row.
        workspace.FeedBlob(
            self.encoder_inputs,
            np.array([
                [token_id] for token_id in reversed(numberized_input)
            ]).astype(dtype=np.int32),
        )
        workspace.FeedBlob(
            self.encoder_lengths,
            np.array([len(numberized_input)]).astype(dtype=np.int32),
        )
        workspace.FeedBlob(
            self.max_output_seq_len,
            np.array([max_output_seq_len]).astype(dtype=np.int64),
        )

        workspace.RunNet(self.model.net)

        num_steps = max_output_seq_len
        score_beam_list = workspace.FetchBlob(self.output_score_beam_list)
        token_beam_list = (
            workspace.FetchBlob(self.output_token_beam_list)
        )
        prev_index_beam_list = (
            workspace.FetchBlob(self.output_prev_index_beam_list)
        )

        attention_weights_beam_list = (
            workspace.FetchBlob(self.output_attention_weights_beam_list)
        )
        # Pick the highest-scoring finished hypothesis: one that emitted EOS
        # at some step, or any hypothesis alive at the final step. The
        # search starts from hypothesis 0 at the last step.
        best_indices = (num_steps, 0)
        for i in range(num_steps + 1):
            for hyp_index in range(self.beam_size):
                if (
                    (
                        token_beam_list[i][hyp_index][0] ==
                        seq2seq_util.EOS_ID or
                        i == num_steps
                    ) and
                    (
                        score_beam_list[i][hyp_index][0] >
                        score_beam_list[best_indices[0]][best_indices[1]][0]
                    )
                ):
                    best_indices = (i, hyp_index)

        # Walk back-pointers from the chosen (step, hypothesis) down to step
        # 1, collecting tokens and attention weights in reverse order. Step 0
        # is not collected (presumably the GO step — confirm).
        i, hyp_index = best_indices
        output = []
        attention_weights_per_token = []
        best_score = -score_beam_list[i][hyp_index][0]
        while i > 0:
            output.append(token_beam_list[i][hyp_index][0])
            attention_weights_per_token.append(
                attention_weights_beam_list[i][hyp_index]
            )
            hyp_index = prev_index_beam_list[i][hyp_index][0]
            i -= 1

        attention_weights_per_token = reversed(attention_weights_per_token)
        # encoder_inputs were fed reversed at the top of this method, so each
        # attention vector is reversed back to source order and trimmed to
        # the real input length.
        attention_weights_per_token = [
            list(reversed(attention_weights))[:len(numberized_input)]
            for attention_weights in attention_weights_per_token
        ]
        output = list(reversed(output))
        return output, attention_weights_per_token, best_score
  476. def run_seq2seq_beam_decoder(args, model_params, decoding_params):
  477. source_vocab = seq2seq_util.gen_vocab(
  478. args.source_corpus,
  479. args.unk_threshold,
  480. )
  481. logger.info('Source vocab size {}'.format(len(source_vocab)))
  482. target_vocab = seq2seq_util.gen_vocab(
  483. args.target_corpus,
  484. args.unk_threshold,
  485. )
  486. inversed_target_vocab = {v: k for (k, v) in viewitems(target_vocab)}
  487. logger.info('Target vocab size {}'.format(len(target_vocab)))
  488. decoder = Seq2SeqModelCaffe2EnsembleDecoder(
  489. translate_params=dict(
  490. ensemble_models=[dict(
  491. source_vocab=source_vocab,
  492. target_vocab=target_vocab,
  493. model_params=model_params,
  494. model_file=args.checkpoint,
  495. )],
  496. decoding_params=decoding_params,
  497. ),
  498. )
  499. decoder.load_models()
  500. for line in sys.stdin:
  501. numerized_source_sentence = seq2seq_util.get_numberized_sentence(
  502. line,
  503. source_vocab,
  504. )
  505. translation, alignment, _ = decoder.decode(
  506. numerized_source_sentence,
  507. 2 * len(numerized_source_sentence) + 5,
  508. )
  509. print(' '.join([inversed_target_vocab[tid] for tid in translation]))
  510. def main():
  511. parser = argparse.ArgumentParser(
  512. description='Caffe2: Seq2Seq Translation',
  513. )
  514. parser.add_argument('--source-corpus', type=str, default=None,
  515. help='Path to source corpus in a text file format. Each '
  516. 'line in the file should contain a single sentence',
  517. required=True)
  518. parser.add_argument('--target-corpus', type=str, default=None,
  519. help='Path to target corpus in a text file format',
  520. required=True)
  521. parser.add_argument('--unk-threshold', type=int, default=50,
  522. help='Threshold frequency under which token becomes '
  523. 'labeled unknown token')
  524. parser.add_argument('--use-bidirectional-encoder', action='store_true',
  525. help='Set flag to use bidirectional recurrent network '
  526. 'in encoder')
  527. parser.add_argument('--use-attention', action='store_true',
  528. help='Set flag to use seq2seq with attention model')
  529. parser.add_argument('--encoder-cell-num-units', type=int, default=512,
  530. help='Number of cell units per encoder layer')
  531. parser.add_argument('--encoder-num-layers', type=int, default=2,
  532. help='Number encoder layers')
  533. parser.add_argument('--decoder-cell-num-units', type=int, default=512,
  534. help='Number of cell units in the decoder layer')
  535. parser.add_argument('--decoder-num-layers', type=int, default=2,
  536. help='Number decoder layers')
  537. parser.add_argument('--encoder-embedding-size', type=int, default=256,
  538. help='Size of embedding in the encoder layer')
  539. parser.add_argument('--decoder-embedding-size', type=int, default=512,
  540. help='Size of embedding in the decoder layer')
  541. parser.add_argument('--decoder-softmax-size', type=int, default=None,
  542. help='Size of softmax layer in the decoder')
  543. parser.add_argument('--beam-size', type=int, default=6,
  544. help='Size of beam for the decoder')
  545. parser.add_argument('--word-reward', type=float, default=0.0,
  546. help='Reward per each word generated.')
  547. parser.add_argument('--unk-reward', type=float, default=0.0,
  548. help='Reward per each UNK token generated. '
  549. 'Typically should be negative.')
  550. parser.add_argument('--checkpoint', type=str, default=None,
  551. help='Path to checkpoint', required=True)
  552. args = parser.parse_args()
  553. encoder_layer_configs = [
  554. dict(
  555. num_units=args.encoder_cell_num_units,
  556. ),
  557. ] * args.encoder_num_layers
  558. if args.use_bidirectional_encoder:
  559. assert args.encoder_cell_num_units % 2 == 0
  560. encoder_layer_configs[0]['num_units'] /= 2
  561. decoder_layer_configs = [
  562. dict(
  563. num_units=args.decoder_cell_num_units,
  564. ),
  565. ] * args.decoder_num_layers
  566. run_seq2seq_beam_decoder(
  567. args,
  568. model_params=dict(
  569. attention=('regular' if args.use_attention else 'none'),
  570. decoder_layer_configs=decoder_layer_configs,
  571. encoder_type=dict(
  572. encoder_layer_configs=encoder_layer_configs,
  573. use_bidirectional_encoder=args.use_bidirectional_encoder,
  574. ),
  575. encoder_embedding_size=args.encoder_embedding_size,
  576. decoder_embedding_size=args.decoder_embedding_size,
  577. decoder_softmax_size=args.decoder_softmax_size,
  578. ),
  579. decoding_params=dict(
  580. beam_size=args.beam_size,
  581. word_reward=args.word_reward,
  582. unk_reward=args.unk_reward,
  583. ),
  584. )
  585. if __name__ == '__main__':
  586. main()