AnyExpOnTerm.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import argparse
  2. import json
  3. import os
  4. import caffe2.contrib.playground.AnyExp as AnyExp
  5. import caffe2.contrib.playground.checkpoint as checkpoint
  6. import logging
  7. logging.basicConfig()
  8. log = logging.getLogger("AnyExpOnTerm")
  9. log.setLevel(logging.DEBUG)
  10. def runShardedTrainLoop(opts, myTrainFun):
  11. start_epoch = 0
  12. pretrained_model = opts['model_param']['pretrained_model']
  13. if pretrained_model != '' and os.path.exists(pretrained_model):
  14. # Only want to get start_epoch.
  15. start_epoch, prev_checkpointed_lr, best_metric = \
  16. checkpoint.initialize_params_from_file(
  17. model=None,
  18. weights_file=pretrained_model,
  19. num_xpus=1,
  20. opts=opts,
  21. broadcast_computed_param=True,
  22. reset_epoch=opts['model_param']['reset_epoch'],
  23. )
  24. log.info('start epoch: {}'.format(start_epoch))
  25. pretrained_model = None if pretrained_model == '' else pretrained_model
  26. ret = None
  27. pretrained_model = ""
  28. shard_results = []
  29. for epoch in range(start_epoch,
  30. opts['epoch_iter']['num_epochs'],
  31. opts['epoch_iter']['num_epochs_per_flow_schedule']):
  32. # must support checkpoint or the multiple schedule will always
  33. # start from initial state
  34. checkpoint_model = None if epoch == start_epoch else ret['model']
  35. pretrained_model = None if epoch > start_epoch else pretrained_model
  36. shard_results = []
  37. # with LexicalContext('epoch{}_gang'.format(epoch),gang_schedule=False):
  38. for shard_id in range(opts['distributed']['num_shards']):
  39. opts['temp_var']['shard_id'] = shard_id
  40. opts['temp_var']['pretrained_model'] = pretrained_model
  41. opts['temp_var']['checkpoint_model'] = checkpoint_model
  42. opts['temp_var']['epoch'] = epoch
  43. opts['temp_var']['start_epoch'] = start_epoch
  44. shard_ret = myTrainFun(opts)
  45. shard_results.append(shard_ret)
  46. ret = None
  47. # always only take shard_0 return
  48. for shard_ret in shard_results:
  49. if shard_ret is not None:
  50. ret = shard_ret
  51. opts['temp_var']['metrics_output'] = ret['metrics']
  52. break
  53. log.info('ret is: {}'.format(str(ret)))
  54. return ret
  55. def trainFun():
  56. def simpleTrainFun(opts):
  57. trainerClass = AnyExp.createTrainerClass(opts)
  58. trainerClass = AnyExp.overrideAdditionalMethods(trainerClass, opts)
  59. trainer = trainerClass(opts)
  60. return trainer.buildModelAndTrain(opts)
  61. return simpleTrainFun
  62. if __name__ == '__main__':
  63. parser = argparse.ArgumentParser(description='Any Experiment training.')
  64. parser.add_argument("--parameters-json", type=json.loads,
  65. help='model options in json format', dest="params")
  66. args = parser.parse_args()
  67. opts = args.params['opts']
  68. opts = AnyExp.initOpts(opts)
  69. log.info('opts is: {}'.format(str(opts)))
  70. AnyExp.initDefaultModuleMap()
  71. opts['input']['datasets'] = AnyExp.aquireDatasets(opts)
  72. # defined this way so that AnyExp.trainFun(opts) can be replaced with
  73. # some other custermized training function.
  74. ret = runShardedTrainLoop(opts, trainFun())
  75. log.info('ret is: {}'.format(str(ret)))