train.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769
  1. ## @package train
  2. # Module caffe2.python.models.seq2seq.train
  3. import argparse
  4. import collections
  5. import logging
  6. import math
  7. import numpy as np
  8. import random
  9. import time
  10. import sys
  11. import os
  12. import caffe2.proto.caffe2_pb2 as caffe2_pb2
  13. from caffe2.python import core, workspace, data_parallel_model
  14. import caffe2.python.models.seq2seq.seq2seq_util as seq2seq_util
  15. from caffe2.python.models.seq2seq.seq2seq_model_helper import Seq2SeqModelHelper
  16. logger = logging.getLogger(__name__)
  17. logger.setLevel(logging.INFO)
  18. logger.addHandler(logging.StreamHandler(sys.stderr))
  19. Batch = collections.namedtuple('Batch', [
  20. 'encoder_inputs',
  21. 'encoder_lengths',
  22. 'decoder_inputs',
  23. 'decoder_lengths',
  24. 'targets',
  25. 'target_weights',
  26. ])
  27. def prepare_batch(batch):
  28. encoder_lengths = [len(entry[0]) for entry in batch]
  29. max_encoder_length = max(encoder_lengths)
  30. decoder_lengths = []
  31. max_decoder_length = max([len(entry[1]) for entry in batch])
  32. batch_encoder_inputs = []
  33. batch_decoder_inputs = []
  34. batch_targets = []
  35. batch_target_weights = []
  36. for source_seq, target_seq in batch:
  37. encoder_pads = (
  38. [seq2seq_util.PAD_ID] * (max_encoder_length - len(source_seq))
  39. )
  40. batch_encoder_inputs.append(
  41. list(reversed(source_seq)) + encoder_pads
  42. )
  43. decoder_pads = (
  44. [seq2seq_util.PAD_ID] * (max_decoder_length - len(target_seq))
  45. )
  46. target_seq_with_go_token = [seq2seq_util.GO_ID] + target_seq
  47. decoder_lengths.append(len(target_seq_with_go_token))
  48. batch_decoder_inputs.append(target_seq_with_go_token + decoder_pads)
  49. target_seq_with_eos = target_seq + [seq2seq_util.EOS_ID]
  50. targets = target_seq_with_eos + decoder_pads
  51. batch_targets.append(targets)
  52. if len(source_seq) + len(target_seq) == 0:
  53. target_weights = [0] * len(targets)
  54. else:
  55. target_weights = [
  56. 1 if target != seq2seq_util.PAD_ID else 0
  57. for target in targets
  58. ]
  59. batch_target_weights.append(target_weights)
  60. return Batch(
  61. encoder_inputs=np.array(
  62. batch_encoder_inputs,
  63. dtype=np.int32,
  64. ).transpose(),
  65. encoder_lengths=np.array(encoder_lengths, dtype=np.int32),
  66. decoder_inputs=np.array(
  67. batch_decoder_inputs,
  68. dtype=np.int32,
  69. ).transpose(),
  70. decoder_lengths=np.array(decoder_lengths, dtype=np.int32),
  71. targets=np.array(
  72. batch_targets,
  73. dtype=np.int32,
  74. ).transpose(),
  75. target_weights=np.array(
  76. batch_target_weights,
  77. dtype=np.float32,
  78. ).transpose(),
  79. )
  80. class Seq2SeqModelCaffe2(object):
  81. def _build_model(
  82. self,
  83. init_params,
  84. ):
  85. model = Seq2SeqModelHelper(init_params=init_params)
  86. self._build_shared(model)
  87. self._build_embeddings(model)
  88. forward_model = Seq2SeqModelHelper(init_params=init_params)
  89. self._build_shared(forward_model)
  90. self._build_embeddings(forward_model)
  91. if self.num_gpus == 0:
  92. loss_blobs = self.model_build_fun(model)
  93. model.AddGradientOperators(loss_blobs)
  94. self.norm_clipped_grad_update(
  95. model,
  96. scope='norm_clipped_grad_update'
  97. )
  98. self.forward_model_build_fun(forward_model)
  99. else:
  100. assert (self.batch_size % self.num_gpus) == 0
  101. data_parallel_model.Parallelize_GPU(
  102. forward_model,
  103. input_builder_fun=lambda m: None,
  104. forward_pass_builder_fun=self.forward_model_build_fun,
  105. param_update_builder_fun=None,
  106. devices=list(range(self.num_gpus)),
  107. )
  108. def clipped_grad_update_bound(model):
  109. self.norm_clipped_grad_update(
  110. model,
  111. scope='norm_clipped_grad_update',
  112. )
  113. data_parallel_model.Parallelize_GPU(
  114. model,
  115. input_builder_fun=lambda m: None,
  116. forward_pass_builder_fun=self.model_build_fun,
  117. param_update_builder_fun=clipped_grad_update_bound,
  118. devices=list(range(self.num_gpus)),
  119. )
  120. self.norm_clipped_sparse_grad_update(
  121. model,
  122. scope='norm_clipped_sparse_grad_update',
  123. )
  124. self.model = model
  125. self.forward_net = forward_model.net
  126. def _build_shared(self, model):
  127. optimizer_params = self.model_params['optimizer_params']
  128. with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
  129. self.learning_rate = model.AddParam(
  130. name='learning_rate',
  131. init_value=float(optimizer_params['learning_rate']),
  132. trainable=False,
  133. )
  134. self.global_step = model.AddParam(
  135. name='global_step',
  136. init_value=0,
  137. trainable=False,
  138. )
  139. self.start_time = model.AddParam(
  140. name='start_time',
  141. init_value=time.time(),
  142. trainable=False,
  143. )
  144. def _build_embeddings(self, model):
  145. with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
  146. sqrt3 = math.sqrt(3)
  147. self.encoder_embeddings = model.param_init_net.UniformFill(
  148. [],
  149. 'encoder_embeddings',
  150. shape=[
  151. self.source_vocab_size,
  152. self.model_params['encoder_embedding_size'],
  153. ],
  154. min=-sqrt3,
  155. max=sqrt3,
  156. )
  157. model.params.append(self.encoder_embeddings)
  158. self.decoder_embeddings = model.param_init_net.UniformFill(
  159. [],
  160. 'decoder_embeddings',
  161. shape=[
  162. self.target_vocab_size,
  163. self.model_params['decoder_embedding_size'],
  164. ],
  165. min=-sqrt3,
  166. max=sqrt3,
  167. )
  168. model.params.append(self.decoder_embeddings)
  169. def model_build_fun(self, model, forward_only=False, loss_scale=None):
  170. encoder_inputs = model.net.AddExternalInput(
  171. workspace.GetNameScope() + 'encoder_inputs',
  172. )
  173. encoder_lengths = model.net.AddExternalInput(
  174. workspace.GetNameScope() + 'encoder_lengths',
  175. )
  176. decoder_inputs = model.net.AddExternalInput(
  177. workspace.GetNameScope() + 'decoder_inputs',
  178. )
  179. decoder_lengths = model.net.AddExternalInput(
  180. workspace.GetNameScope() + 'decoder_lengths',
  181. )
  182. targets = model.net.AddExternalInput(
  183. workspace.GetNameScope() + 'targets',
  184. )
  185. target_weights = model.net.AddExternalInput(
  186. workspace.GetNameScope() + 'target_weights',
  187. )
  188. attention_type = self.model_params['attention']
  189. assert attention_type in ['none', 'regular', 'dot']
  190. (
  191. encoder_outputs,
  192. weighted_encoder_outputs,
  193. final_encoder_hidden_states,
  194. final_encoder_cell_states,
  195. encoder_units_per_layer,
  196. ) = seq2seq_util.build_embedding_encoder(
  197. model=model,
  198. encoder_params=self.encoder_params,
  199. num_decoder_layers=len(self.model_params['decoder_layer_configs']),
  200. inputs=encoder_inputs,
  201. input_lengths=encoder_lengths,
  202. vocab_size=self.source_vocab_size,
  203. embeddings=self.encoder_embeddings,
  204. embedding_size=self.model_params['encoder_embedding_size'],
  205. use_attention=(attention_type != 'none'),
  206. num_gpus=self.num_gpus,
  207. )
  208. (
  209. decoder_outputs,
  210. decoder_output_size,
  211. ) = seq2seq_util.build_embedding_decoder(
  212. model,
  213. decoder_layer_configs=self.model_params['decoder_layer_configs'],
  214. inputs=decoder_inputs,
  215. input_lengths=decoder_lengths,
  216. encoder_lengths=encoder_lengths,
  217. encoder_outputs=encoder_outputs,
  218. weighted_encoder_outputs=weighted_encoder_outputs,
  219. final_encoder_hidden_states=final_encoder_hidden_states,
  220. final_encoder_cell_states=final_encoder_cell_states,
  221. encoder_units_per_layer=encoder_units_per_layer,
  222. vocab_size=self.target_vocab_size,
  223. embeddings=self.decoder_embeddings,
  224. embedding_size=self.model_params['decoder_embedding_size'],
  225. attention_type=attention_type,
  226. forward_only=False,
  227. num_gpus=self.num_gpus,
  228. )
  229. output_logits = seq2seq_util.output_projection(
  230. model=model,
  231. decoder_outputs=decoder_outputs,
  232. decoder_output_size=decoder_output_size,
  233. target_vocab_size=self.target_vocab_size,
  234. decoder_softmax_size=self.model_params['decoder_softmax_size'],
  235. )
  236. targets, _ = model.net.Reshape(
  237. [targets],
  238. ['targets', 'targets_old_shape'],
  239. shape=[-1],
  240. )
  241. target_weights, _ = model.net.Reshape(
  242. [target_weights],
  243. ['target_weights', 'target_weights_old_shape'],
  244. shape=[-1],
  245. )
  246. _, loss_per_word = model.net.SoftmaxWithLoss(
  247. [output_logits, targets, target_weights],
  248. ['OutputProbs_INVALID', 'loss_per_word'],
  249. only_loss=True,
  250. )
  251. num_words = model.net.SumElements(
  252. [target_weights],
  253. 'num_words',
  254. )
  255. total_loss_scalar = model.net.Mul(
  256. [loss_per_word, num_words],
  257. 'total_loss_scalar',
  258. )
  259. total_loss_scalar_weighted = model.net.Scale(
  260. [total_loss_scalar],
  261. 'total_loss_scalar_weighted',
  262. scale=1.0 / self.batch_size,
  263. )
  264. return [total_loss_scalar_weighted]
  265. def forward_model_build_fun(self, model, loss_scale=None):
  266. return self.model_build_fun(
  267. model=model,
  268. forward_only=True,
  269. loss_scale=loss_scale
  270. )
  271. def _calc_norm_ratio(self, model, params, scope, ONE):
  272. with core.NameScope(scope):
  273. grad_squared_sums = []
  274. for i, param in enumerate(params):
  275. logger.info(param)
  276. grad = (
  277. model.param_to_grad[param]
  278. if not isinstance(
  279. model.param_to_grad[param],
  280. core.GradientSlice,
  281. ) else model.param_to_grad[param].values
  282. )
  283. grad_squared = model.net.Sqr(
  284. [grad],
  285. 'grad_{}_squared'.format(i),
  286. )
  287. grad_squared_sum = model.net.SumElements(
  288. grad_squared,
  289. 'grad_{}_squared_sum'.format(i),
  290. )
  291. grad_squared_sums.append(grad_squared_sum)
  292. grad_squared_full_sum = model.net.Sum(
  293. grad_squared_sums,
  294. 'grad_squared_full_sum',
  295. )
  296. global_norm = model.net.Pow(
  297. grad_squared_full_sum,
  298. 'global_norm',
  299. exponent=0.5,
  300. )
  301. clip_norm = model.param_init_net.ConstantFill(
  302. [],
  303. 'clip_norm',
  304. shape=[],
  305. value=float(self.model_params['max_gradient_norm']),
  306. )
  307. max_norm = model.net.Max(
  308. [global_norm, clip_norm],
  309. 'max_norm',
  310. )
  311. norm_ratio = model.net.Div(
  312. [clip_norm, max_norm],
  313. 'norm_ratio',
  314. )
  315. return norm_ratio
  316. def _apply_norm_ratio(
  317. self, norm_ratio, model, params, learning_rate, scope, ONE
  318. ):
  319. for param in params:
  320. param_grad = model.param_to_grad[param]
  321. nlr = model.net.Negative(
  322. [learning_rate],
  323. 'negative_learning_rate',
  324. )
  325. with core.NameScope(scope):
  326. update_coeff = model.net.Mul(
  327. [nlr, norm_ratio],
  328. 'update_coeff',
  329. broadcast=1,
  330. )
  331. if isinstance(param_grad, core.GradientSlice):
  332. param_grad_values = param_grad.values
  333. model.net.ScatterWeightedSum(
  334. [
  335. param,
  336. ONE,
  337. param_grad.indices,
  338. param_grad_values,
  339. update_coeff,
  340. ],
  341. param,
  342. )
  343. else:
  344. model.net.WeightedSum(
  345. [
  346. param,
  347. ONE,
  348. param_grad,
  349. update_coeff,
  350. ],
  351. param,
  352. )
  353. def norm_clipped_grad_update(self, model, scope):
  354. if self.num_gpus == 0:
  355. learning_rate = self.learning_rate
  356. else:
  357. learning_rate = model.CopyCPUToGPU(self.learning_rate, 'LR')
  358. params = []
  359. for param in model.GetParams(top_scope=True):
  360. if param in model.param_to_grad:
  361. if not isinstance(
  362. model.param_to_grad[param],
  363. core.GradientSlice,
  364. ):
  365. params.append(param)
  366. ONE = model.param_init_net.ConstantFill(
  367. [],
  368. 'ONE',
  369. shape=[1],
  370. value=1.0,
  371. )
  372. logger.info('Dense trainable variables: ')
  373. norm_ratio = self._calc_norm_ratio(model, params, scope, ONE)
  374. self._apply_norm_ratio(
  375. norm_ratio, model, params, learning_rate, scope, ONE
  376. )
  377. def norm_clipped_sparse_grad_update(self, model, scope):
  378. learning_rate = self.learning_rate
  379. params = []
  380. for param in model.GetParams(top_scope=True):
  381. if param in model.param_to_grad:
  382. if isinstance(
  383. model.param_to_grad[param],
  384. core.GradientSlice,
  385. ):
  386. params.append(param)
  387. ONE = model.param_init_net.ConstantFill(
  388. [],
  389. 'ONE',
  390. shape=[1],
  391. value=1.0,
  392. )
  393. logger.info('Sparse trainable variables: ')
  394. norm_ratio = self._calc_norm_ratio(model, params, scope, ONE)
  395. self._apply_norm_ratio(
  396. norm_ratio, model, params, learning_rate, scope, ONE
  397. )
  398. def total_loss_scalar(self):
  399. if self.num_gpus == 0:
  400. return workspace.FetchBlob('total_loss_scalar')
  401. else:
  402. total_loss = 0
  403. for i in range(self.num_gpus):
  404. name = 'gpu_{}/total_loss_scalar'.format(i)
  405. gpu_loss = workspace.FetchBlob(name)
  406. total_loss += gpu_loss
  407. return total_loss
  408. def _init_model(self):
  409. workspace.RunNetOnce(self.model.param_init_net)
  410. def create_net(net):
  411. workspace.CreateNet(
  412. net,
  413. input_blobs=[str(i) for i in net.external_inputs],
  414. )
  415. create_net(self.model.net)
  416. create_net(self.forward_net)
  417. def __init__(
  418. self,
  419. model_params,
  420. source_vocab_size,
  421. target_vocab_size,
  422. num_gpus=1,
  423. num_cpus=1,
  424. ):
  425. self.model_params = model_params
  426. self.encoder_type = 'rnn'
  427. self.encoder_params = model_params['encoder_type']
  428. self.source_vocab_size = source_vocab_size
  429. self.target_vocab_size = target_vocab_size
  430. self.num_gpus = num_gpus
  431. self.num_cpus = num_cpus
  432. self.batch_size = model_params['batch_size']
  433. workspace.GlobalInit([
  434. 'caffe2',
  435. # NOTE: modify log level for debugging purposes
  436. '--caffe2_log_level=0',
  437. # NOTE: modify log level for debugging purposes
  438. '--v=0',
  439. # Fail gracefully if one of the threads fails
  440. '--caffe2_handle_executor_threads_exceptions=1',
  441. '--caffe2_mkl_num_threads=' + str(self.num_cpus),
  442. ])
  443. def __enter__(self):
  444. return self
  445. def __exit__(self, exc_type, exc_value, traceback):
  446. workspace.ResetWorkspace()
  447. def initialize_from_scratch(self):
  448. logger.info('Initializing Seq2SeqModelCaffe2 from scratch: Start')
  449. self._build_model(init_params=True)
  450. self._init_model()
  451. logger.info('Initializing Seq2SeqModelCaffe2 from scratch: Finish')
  452. def get_current_step(self):
  453. return workspace.FetchBlob(self.global_step)[0]
  454. def inc_current_step(self):
  455. workspace.FeedBlob(
  456. self.global_step,
  457. np.array([self.get_current_step() + 1]),
  458. )
  459. def step(
  460. self,
  461. batch,
  462. forward_only
  463. ):
  464. if self.num_gpus < 1:
  465. batch_obj = prepare_batch(batch)
  466. for batch_obj_name, batch_obj_value in zip(
  467. Batch._fields,
  468. batch_obj,
  469. ):
  470. workspace.FeedBlob(batch_obj_name, batch_obj_value)
  471. else:
  472. for i in range(self.num_gpus):
  473. gpu_batch = batch[i::self.num_gpus]
  474. batch_obj = prepare_batch(gpu_batch)
  475. for batch_obj_name, batch_obj_value in zip(
  476. Batch._fields,
  477. batch_obj,
  478. ):
  479. name = 'gpu_{}/{}'.format(i, batch_obj_name)
  480. if batch_obj_name in ['encoder_inputs', 'decoder_inputs']:
  481. dev = core.DeviceOption(caffe2_pb2.CPU)
  482. else:
  483. dev = core.DeviceOption(workspace.GpuDeviceType, i)
  484. workspace.FeedBlob(name, batch_obj_value, device_option=dev)
  485. if forward_only:
  486. workspace.RunNet(self.forward_net)
  487. else:
  488. workspace.RunNet(self.model.net)
  489. self.inc_current_step()
  490. return self.total_loss_scalar()
  491. def save(self, checkpoint_path_prefix, current_step):
  492. checkpoint_path = '{0}-{1}'.format(
  493. checkpoint_path_prefix,
  494. current_step,
  495. )
  496. assert workspace.RunOperatorOnce(core.CreateOperator(
  497. 'Save',
  498. self.model.GetAllParams(),
  499. [],
  500. absolute_path=True,
  501. db=checkpoint_path,
  502. db_type='minidb',
  503. ))
  504. checkpoint_config_path = os.path.join(
  505. os.path.dirname(checkpoint_path_prefix),
  506. 'checkpoint',
  507. )
  508. with open(checkpoint_config_path, 'w') as checkpoint_config_file:
  509. checkpoint_config_file.write(
  510. 'model_checkpoint_path: "' + checkpoint_path + '"\n'
  511. 'all_model_checkpoint_paths: "' + checkpoint_path + '"\n'
  512. )
  513. logger.info('Saved checkpoint file to ' + checkpoint_path)
  514. return checkpoint_path
  515. def gen_batches(source_corpus, target_corpus, source_vocab, target_vocab,
  516. batch_size, max_length):
  517. with open(source_corpus) as source, open(target_corpus) as target:
  518. parallel_sentences = []
  519. for source_sentence, target_sentence in zip(source, target):
  520. numerized_source_sentence = seq2seq_util.get_numberized_sentence(
  521. source_sentence,
  522. source_vocab,
  523. )
  524. numerized_target_sentence = seq2seq_util.get_numberized_sentence(
  525. target_sentence,
  526. target_vocab,
  527. )
  528. if (
  529. len(numerized_source_sentence) > 0 and
  530. len(numerized_target_sentence) > 0 and
  531. (
  532. max_length is None or (
  533. len(numerized_source_sentence) <= max_length and
  534. len(numerized_target_sentence) <= max_length
  535. )
  536. )
  537. ):
  538. parallel_sentences.append((
  539. numerized_source_sentence,
  540. numerized_target_sentence,
  541. ))
  542. parallel_sentences.sort(key=lambda s_t: (len(s_t[0]), len(s_t[1])))
  543. batches, batch = [], []
  544. for sentence_pair in parallel_sentences:
  545. batch.append(sentence_pair)
  546. if len(batch) >= batch_size:
  547. batches.append(batch)
  548. batch = []
  549. if len(batch) > 0:
  550. while len(batch) < batch_size:
  551. batch.append(batch[-1])
  552. assert len(batch) == batch_size
  553. batches.append(batch)
  554. random.shuffle(batches)
  555. return batches
  556. def run_seq2seq_model(args, model_params=None):
  557. source_vocab = seq2seq_util.gen_vocab(
  558. args.source_corpus,
  559. args.unk_threshold,
  560. )
  561. target_vocab = seq2seq_util.gen_vocab(
  562. args.target_corpus,
  563. args.unk_threshold,
  564. )
  565. logger.info('Source vocab size {}'.format(len(source_vocab)))
  566. logger.info('Target vocab size {}'.format(len(target_vocab)))
  567. batches = gen_batches(args.source_corpus, args.target_corpus, source_vocab,
  568. target_vocab, model_params['batch_size'],
  569. args.max_length)
  570. logger.info('Number of training batches {}'.format(len(batches)))
  571. batches_eval = gen_batches(args.source_corpus_eval, args.target_corpus_eval,
  572. source_vocab, target_vocab,
  573. model_params['batch_size'], args.max_length)
  574. logger.info('Number of eval batches {}'.format(len(batches_eval)))
  575. with Seq2SeqModelCaffe2(
  576. model_params=model_params,
  577. source_vocab_size=len(source_vocab),
  578. target_vocab_size=len(target_vocab),
  579. num_gpus=args.num_gpus,
  580. num_cpus=20,
  581. ) as model_obj:
  582. model_obj.initialize_from_scratch()
  583. for i in range(args.epochs):
  584. logger.info('Epoch {}'.format(i))
  585. total_loss = 0
  586. for batch in batches:
  587. total_loss += model_obj.step(
  588. batch=batch,
  589. forward_only=False,
  590. )
  591. logger.info('\ttraining loss {}'.format(total_loss))
  592. total_loss = 0
  593. for batch in batches_eval:
  594. total_loss += model_obj.step(
  595. batch=batch,
  596. forward_only=True,
  597. )
  598. logger.info('\teval loss {}'.format(total_loss))
  599. if args.checkpoint is not None:
  600. model_obj.save(args.checkpoint, i)
  601. def main():
  602. random.seed(31415)
  603. parser = argparse.ArgumentParser(
  604. description='Caffe2: Seq2Seq Training'
  605. )
  606. parser.add_argument('--source-corpus', type=str, default=None,
  607. help='Path to source corpus in a text file format. Each '
  608. 'line in the file should contain a single sentence',
  609. required=True)
  610. parser.add_argument('--target-corpus', type=str, default=None,
  611. help='Path to target corpus in a text file format',
  612. required=True)
  613. parser.add_argument('--max-length', type=int, default=None,
  614. help='Maximal lengths of train and eval sentences')
  615. parser.add_argument('--unk-threshold', type=int, default=50,
  616. help='Threshold frequency under which token becomes '
  617. 'labeled unknown token')
  618. parser.add_argument('--batch-size', type=int, default=32,
  619. help='Training batch size')
  620. parser.add_argument('--epochs', type=int, default=10,
  621. help='Number of iterations over training data')
  622. parser.add_argument('--learning-rate', type=float, default=0.5,
  623. help='Learning rate')
  624. parser.add_argument('--max-gradient-norm', type=float, default=1.0,
  625. help='Max global norm of gradients at the end of each '
  626. 'backward pass. We do clipping to match the number.')
  627. parser.add_argument('--num-gpus', type=int, default=0,
  628. help='Number of GPUs for data parallel model')
  629. parser.add_argument('--use-bidirectional-encoder', action='store_true',
  630. help='Set flag to use bidirectional recurrent network '
  631. 'for first layer of encoder')
  632. parser.add_argument('--use-attention', action='store_true',
  633. help='Set flag to use seq2seq with attention model')
  634. parser.add_argument('--source-corpus-eval', type=str, default=None,
  635. help='Path to source corpus for evaluation in a text '
  636. 'file format', required=True)
  637. parser.add_argument('--target-corpus-eval', type=str, default=None,
  638. help='Path to target corpus for evaluation in a text '
  639. 'file format', required=True)
  640. parser.add_argument('--encoder-cell-num-units', type=int, default=512,
  641. help='Number of cell units per encoder layer')
  642. parser.add_argument('--encoder-num-layers', type=int, default=2,
  643. help='Number encoder layers')
  644. parser.add_argument('--decoder-cell-num-units', type=int, default=512,
  645. help='Number of cell units in the decoder layer')
  646. parser.add_argument('--decoder-num-layers', type=int, default=2,
  647. help='Number decoder layers')
  648. parser.add_argument('--encoder-embedding-size', type=int, default=256,
  649. help='Size of embedding in the encoder layer')
  650. parser.add_argument('--decoder-embedding-size', type=int, default=512,
  651. help='Size of embedding in the decoder layer')
  652. parser.add_argument('--decoder-softmax-size', type=int, default=None,
  653. help='Size of softmax layer in the decoder')
  654. parser.add_argument('--checkpoint', type=str, default=None,
  655. help='Path to checkpoint')
  656. args = parser.parse_args()
  657. encoder_layer_configs = [
  658. dict(
  659. num_units=args.encoder_cell_num_units,
  660. ),
  661. ] * args.encoder_num_layers
  662. if args.use_bidirectional_encoder:
  663. assert args.encoder_cell_num_units % 2 == 0
  664. encoder_layer_configs[0]['num_units'] /= 2
  665. decoder_layer_configs = [
  666. dict(
  667. num_units=args.decoder_cell_num_units,
  668. ),
  669. ] * args.decoder_num_layers
  670. run_seq2seq_model(args, model_params=dict(
  671. attention=('regular' if args.use_attention else 'none'),
  672. decoder_layer_configs=decoder_layer_configs,
  673. encoder_type=dict(
  674. encoder_layer_configs=encoder_layer_configs,
  675. use_bidirectional_encoder=args.use_bidirectional_encoder,
  676. ),
  677. batch_size=args.batch_size,
  678. optimizer_params=dict(
  679. learning_rate=args.learning_rate,
  680. ),
  681. encoder_embedding_size=args.encoder_embedding_size,
  682. decoder_embedding_size=args.decoder_embedding_size,
  683. decoder_softmax_size=args.decoder_softmax_size,
  684. max_gradient_norm=args.max_gradient_norm,
  685. ))
  686. if __name__ == '__main__':
  687. main()