AnyExp.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. from abc import abstractmethod
  2. from caffe2.python import workspace
  3. from caffe2.python import timeout_guard
  4. from caffe2.python import data_parallel_model
  5. from . import checkpoint as checkpoint
  6. from . import ModuleRegister as ModuleRegister
  7. from . import module_map as module_map
  8. # Instantiating a logger outside of the distributed operators may trigger
  9. # errors; the logger needs to be created in each individual operator instead.
  10. import os
  11. import inspect
  12. import time
  13. import logging
# Module-level logger named "AnyExp", set to DEBUG so every message emitted
# through it is passed on to the handlers installed by basicConfig().
logging.basicConfig()
log = logging.getLogger("AnyExp")
log.setLevel(logging.DEBUG)
  17. def initOpts(opts):
  18. workspace.GlobalInit(
  19. ['caffe2', '--caffe2_log_level=2', '--caffe2_gpu_memory_tracking=0'])
  20. assert (opts['distributed']['num_gpus'] > 0 or
  21. opts['distributed']['num_cpus'] > 0),\
  22. "Need to specify num_gpus or num_cpus to decide which device to use."
  23. trainWithCPU = (opts['distributed']['num_gpus'] == 0)
  24. num_xpus = opts['distributed']['num_cpus'] if \
  25. trainWithCPU else opts['distributed']['num_gpus']
  26. first_xpu = opts['distributed']['first_cpu_id'] if \
  27. trainWithCPU else opts['distributed']['first_gpu_id']
  28. opts['distributed']['device'] = 'cpu' if trainWithCPU else 'gpu'
  29. opts['model_param']['combine_spatial_bn'] =\
  30. trainWithCPU and opts['model_param']['combine_spatial_bn']
  31. opts['distributed']['num_xpus'] = num_xpus
  32. opts['distributed']['first_xpu_id'] = first_xpu
  33. opts['temp_var'] = {}
  34. opts['temp_var']['metrics_output'] = {}
  35. return opts
  36. def initDefaultModuleMap():
  37. registerModuleMap(module_map)
  38. def registerModuleMap(module_map):
  39. ModuleRegister.registerModuleMap(module_map)
  40. def aquireDatasets(opts):
  41. myAquireDataModule = ModuleRegister.getModule(opts['input']['input_name_py'])
  42. return myAquireDataModule.get_input_dataset(opts)
  43. def createTrainerClass(opts):
  44. return ModuleRegister.constructTrainerClass(AnyExpTrainer, opts)
  45. def overrideAdditionalMethods(myTrainerClass, opts):
  46. return ModuleRegister.overrideAdditionalMethods(myTrainerClass, opts)
  47. def initialize_params_from_file(*args, **kwargs):
  48. return checkpoint.initialize_params_from_file(*args, **kwargs)
  49. class AnyExpTrainer(object):
  50. def __init__(self, opts):
  51. import logging
  52. logging.basicConfig()
  53. log = logging.getLogger("AnyExp")
  54. log.setLevel(logging.DEBUG)
  55. self.log = log
  56. self.opts = opts
  57. self.train_dataset = None
  58. self.test_dataset = None
  59. self.train_df = None
  60. self.test_df = None
  61. self.metrics = {}
  62. self.plotsIngredients = []
  63. self.record_epochs = []
  64. self.samples_per_sec = []
  65. self.secs_per_train = []
  66. self.metrics_output = opts['temp_var']['metrics_output']
  67. first_xpu = opts['distributed']['first_xpu_id']
  68. num_xpus = opts['distributed']['num_xpus']
  69. self.xpus = range(first_xpu, first_xpu + num_xpus)
  70. self.total_batch_size = \
  71. self.opts['epoch_iter']['batch_per_device'] * \
  72. self.opts['distributed']['num_xpus'] * \
  73. self.opts['distributed']['num_shards']
  74. self.epoch_iterations = \
  75. self.opts['epoch_iter']['num_train_sample_per_epoch'] // \
  76. self.total_batch_size
  77. if len(opts['input']['datasets']) > 0:
  78. self.train_df = opts['input']['datasets'][0]
  79. if len(opts['input']['datasets']) == 2:
  80. self.test_df = opts['input']['datasets'][1]
  81. # at this point, the intance of this class becomes many instances
  82. # running on different machines. Most of their attributes are same,
  83. # but the shard_ids are different.
  84. self.shard_id = opts['temp_var']['shard_id']
  85. self.start_epoch = opts['temp_var']['start_epoch']
  86. self.epoch = opts['temp_var']['epoch']
  87. self.epochs_to_run = opts['epoch_iter']['num_epochs_per_flow_schedule']
  88. log.info('opts: {}'.format(str(opts)))
  89. @abstractmethod
  90. def get_input_dataset(self, opts):
  91. pass
  92. @abstractmethod
  93. def get_model_input_fun(self):
  94. pass
  95. @abstractmethod
  96. def init_model(self):
  97. pass
  98. def init_metrics(self):
  99. metrics = self.opts['output']['metrics']
  100. for metric in metrics:
  101. meterClass = self.getMeterClass(metric['meter_py'])
  102. # log.info('metric.meter_kargs {}'.format(metric.meter_kargs))
  103. # log.info('type meter_kargs {}'.format(type(metric.meter_kargs)))
  104. meterInstance = meterClass(opts=self.opts, **metric['meter_kargs'])
  105. self.add_metric(metric['name'], meterInstance, metric['is_train'])
  106. def getMeterClass(self, meterName):
  107. return ModuleRegister.getClassFromModule(meterName, meterName)
  108. def add_metric(self, name, calculator, is_train):
  109. metrics = self.metrics
  110. metrics[name] = {}
  111. metrics[name]['calculator'] = calculator
  112. metrics[name]['is_train'] = is_train
  113. metrics[name]['output'] = []
  114. def extendMetricsOutput(self):
  115. metrics_output = self.metrics_output
  116. if not metrics_output:
  117. metrics_output['epochs'] = self.record_epochs
  118. metrics_output['samples_per_sec'] = self.samples_per_sec
  119. metrics_output['secs_per_train'] = self.secs_per_train
  120. for metric, value in self.metrics.items():
  121. metrics_output[metric] = value['output']
  122. else:
  123. metrics_output['epochs'].extend(self.record_epochs)
  124. metrics_output['samples_per_sec'].extend(self.samples_per_sec)
  125. metrics_output['secs_per_train'].extend(self.secs_per_train)
  126. for metric, value in self.metrics.items():
  127. metrics_output[metric].extend(value['output'])
  128. @abstractmethod
  129. def init_plots(self):
  130. pass
  131. def add_plot(self, x, x_title, ys, y_title):
  132. plotsIngredients = self.plotsIngredients
  133. aPlotIngredients = {}
  134. aPlotIngredients['x'] = x
  135. aPlotIngredients['x_title'] = x_title
  136. aPlotIngredients['ys'] = ys
  137. aPlotIngredients['y_title'] = y_title
  138. plotsIngredients.append(aPlotIngredients)
  139. @abstractmethod
  140. def init_logs(self):
  141. pass
  142. def list_of_epochs(self):
  143. iter_end_point = min(self.opts['epoch_iter']['num_epochs'],
  144. self.epoch +
  145. self.opts['epoch_iter']['num_epochs_per_flow_schedule'])
  146. return range(self.epoch, iter_end_point)
  147. def list_of_epoch_iters(self):
  148. return range(0, self.epoch_iterations)
  149. @abstractmethod
  150. def fun_per_epoch_b4RunNet(self, epoch):
  151. pass
  152. @abstractmethod
  153. def fun_per_epoch_aftRunNet(self, epoch):
  154. pass
  155. def checkpoint(self, epoch):
  156. self.model_path = checkpoint.save_model_params(
  157. True, self.train_model, self.gen_checkpoint_path(True, epoch + 1),
  158. epoch + 1, self.opts, float('-inf'))
  159. def gen_checkpoint_path(self, is_checkpoint, epoch):
  160. if (is_checkpoint):
  161. filename = "model_checkpoint_epoch{}.pkl".format(epoch)
  162. else:
  163. filename = "model_final.pkl"
  164. return self.opts['output']['checkpoint_folder'] + filename
  165. # @abstractmethod
  166. # def gen_checkpoint_path(self, is_checkpoint, epoch):
  167. # pass
  168. @abstractmethod
  169. def fun_per_iter_b4RunNet(self, epoch, epoch_iter):
  170. pass
  171. @abstractmethod
  172. def fun_per_iter_aftRunNetB4Test(self, epoch, epoch_iter):
  173. pass
  174. @abstractmethod
  175. def fun_per_iter_aftRunNetAftTest(self, epoch, epoch_iter):
  176. pass
  177. @abstractmethod
  178. def fun_conclude_operator(self, opts):
  179. pass
  180. def createMetricsPlotsModelsOutputs(self):
  181. self.extendMetricsOutput()
  182. self.model_output = self.model_path
  183. @abstractmethod
  184. def assembleAllOutputs(self):
  185. pass
  186. @abstractmethod
  187. def gen_input_builder_fun(self, model, dataset, is_train):
  188. pass
  189. @abstractmethod
  190. def gen_forward_pass_builder_fun(self, model, dataset, is_train):
  191. pass
  192. @abstractmethod
  193. def gen_param_update_builder_fun(self, model, dataset, is_train):
  194. pass
  195. @abstractmethod
  196. def gen_optimizer_fun(self, model, dataset, is_train):
  197. pass
  198. @abstractmethod
  199. def gen_rendezvous_ctx(self, model, dataset, is_train):
  200. pass
  201. @abstractmethod
  202. def run_training_net(self):
  203. pass
  204. @abstractmethod
  205. def run_testing_net(self):
  206. if self.test_model is None:
  207. return
  208. timeout = 2000.0
  209. with timeout_guard.CompleteInTimeOrDie(timeout):
  210. workspace.RunNet(self.test_model.net.Proto().name)
  211. # @abstractmethod
  212. def planning_output(self):
  213. self.init_metrics()
  214. self.init_plots()
  215. self.init_logs()
  216. def prep_data_parallel_models(self):
  217. self.prep_a_data_parallel_model(self.train_model,
  218. self.train_dataset, True)
  219. self.prep_a_data_parallel_model(self.test_model,
  220. self.test_dataset, False)
  221. def prep_a_data_parallel_model(self, model, dataset, is_train):
  222. if model is None:
  223. return
  224. log.info('in prep_a_data_parallel_model')
  225. param_update = \
  226. self.gen_param_update_builder_fun(model, dataset, is_train) \
  227. if self.gen_param_update_builder_fun is not None else None
  228. log.info('in prep_a_data_parallel_model param_update done ')
  229. optimizer = \
  230. self.gen_optimizer_fun(model, dataset, is_train) \
  231. if self.gen_optimizer_fun is not None else None
  232. log.info('in prep_a_data_parallel_model optimizer done ')
  233. max_ops = self.opts['model_param']['max_concurrent_distributed_ops']
  234. data_parallel_model.Parallelize(
  235. model,
  236. input_builder_fun=self.gen_input_builder_fun(model, dataset, is_train),
  237. forward_pass_builder_fun=self.gen_forward_pass_builder_fun(
  238. model, dataset, is_train),
  239. param_update_builder_fun=param_update,
  240. optimizer_builder_fun=optimizer,
  241. devices=self.xpus,
  242. rendezvous=self.gen_rendezvous_ctx(model, dataset, is_train),
  243. broadcast_computed_params=False,
  244. optimize_gradient_memory=self.opts['model_param']['memonger'],
  245. use_nccl=self.opts['model_param']['cuda_nccl'],
  246. max_concurrent_distributed_ops=max_ops,
  247. cpu_device=(self.opts['distributed']['device'] == 'cpu'),
  248. # "shared model" will only keep model parameters for cpu_0 or gpu_0
  249. # will cause issue when initialize each gpu_0, gpu_1, gpu_2 ...
  250. # shared_model=(self.opts['distributed']['device'] == 'cpu'),
  251. combine_spatial_bn=self.opts['model_param']['combine_spatial_bn'],
  252. )
  253. log.info('in prep_a_data_parallel_model Parallelize done ')
  254. # log.info("Current blobs in workspace: {}".format(workspace.Blobs()))
  255. workspace.RunNetOnce(model.param_init_net)
  256. log.info('in prep_a_data_parallel_model RunNetOnce done ')
  257. # for op in model.net.Proto().op:
  258. # log.info('op type engine {} {}'.format(op.type, op.engine))
  259. log.info('model.net.Proto() {}'.format(model.net.Proto()))
  260. workspace.CreateNet(model.net)
  261. # for op in model.net.Proto().op:
  262. # log.info('after CreateNet op type engine {} {}'.
  263. # format(op.type, op.engine))
  264. log.info('in prep_a_data_parallel_model CreateNet done ')
  265. def loadCheckpoint(self):
  266. opts = self.opts
  267. previous_checkpoint = opts['temp_var']['checkpoint_model']
  268. pretrained_model = opts['temp_var']['pretrained_model']
  269. num_xpus = opts['distributed']['num_xpus']
  270. if (previous_checkpoint is not None):
  271. if os.path.exists(previous_checkpoint):
  272. log.info('Load previous checkpoint:{}'.format(
  273. previous_checkpoint
  274. ))
  275. start_epoch, prev_checkpointed_lr, _best_metric = \
  276. checkpoint.initialize_params_from_file(
  277. model=self.train_model,
  278. weights_file=previous_checkpoint,
  279. num_xpus=num_xpus,
  280. opts=opts,
  281. broadcast_computed_param=True,
  282. reset_epoch=False,
  283. )
  284. elif pretrained_model is not None and os.path.exists(pretrained_model):
  285. log.info("Load pretrained model: {}".format(pretrained_model))
  286. start_epoch, prev_checkpointed_lr, best_metric = \
  287. checkpoint.initialize_params_from_file(
  288. model=self.train_model,
  289. weights_file=pretrained_model,
  290. num_xpus=num_xpus,
  291. opts=opts,
  292. broadcast_computed_param=True,
  293. reset_epoch=opts['model_param']['reset_epoch'],
  294. )
  295. data_parallel_model.FinalizeAfterCheckpoint(self.train_model)
  296. def buildModelAndTrain(self, opts):
  297. log.info('in buildModelAndTrain, trainer_input: {}'.format(str(opts)))
  298. log.info("check type self: {}".format(type(self)))
  299. log.info("check self dir: {}".format(dir(self)))
  300. log.info("check self source: {}".format(self.__dict__))
  301. log.info("check self get_input_dataset methods: {}".
  302. format(inspect.getsource(self.get_input_dataset)))
  303. log.info("check self gen_input_builder_fun method: {}".
  304. format(inspect.getsource(self.gen_input_builder_fun)))
  305. log.info("check self gen_forward_pass_builder_fun method: {}".
  306. format(inspect.getsource(self.gen_forward_pass_builder_fun)))
  307. if self.gen_param_update_builder_fun is not None:
  308. log.info("check self gen_param_update_builder_fun method: {}".
  309. format(inspect.getsource(self.gen_param_update_builder_fun)))
  310. else:
  311. log.info("check self gen_optimizer_fun method: {}".
  312. format(inspect.getsource(self.gen_optimizer_fun)))
  313. log.info("check self assembleAllOutputs method: {}".
  314. format(inspect.getsource(self.assembleAllOutputs)))
  315. log.info("check self prep_data_parallel_models method: {}".
  316. format(inspect.getsource(self.prep_data_parallel_models)))
  317. self.get_model_input_fun()
  318. self.init_model()
  319. self.planning_output()
  320. self.prep_data_parallel_models()
  321. self.loadCheckpoint()
  322. for epoch in self.list_of_epochs():
  323. log.info("start training epoch {}".format(epoch))
  324. self.fun_per_epoch_b4RunNet(epoch)
  325. for epoch_iter in self.list_of_epoch_iters():
  326. self.iter_start_time = time.time()
  327. self.fun_per_iter_b4RunNet(epoch, epoch_iter)
  328. if self.train_model is not None:
  329. self.run_training_net()
  330. self.fun_per_iter_aftRunNetB4Test(epoch, epoch_iter)
  331. self.iter_end_time = time.time()
  332. if (epoch_iter %
  333. opts['epoch_iter']['num_train_iteration_per_test'] == 0):
  334. secs_per_train = (self.iter_end_time - self.iter_start_time)
  335. self.secs_per_train.append(secs_per_train)
  336. sample_trained = self.total_batch_size
  337. samples_per_sec = sample_trained / secs_per_train
  338. self.samples_per_sec.append(samples_per_sec)
  339. self.fract_epoch = (epoch +
  340. float(epoch_iter) / self.epoch_iterations)
  341. self.record_epochs.append(self.fract_epoch)
  342. for key in self.metrics:
  343. metric = self.metrics[key]
  344. if not metric['is_train']:
  345. continue
  346. metric['calculator'].Add()
  347. metric['output'].append(metric['calculator'].Compute())
  348. self.test_loop_start_time = time.time()
  349. for _test_iter in range(0, opts['epoch_iter']['num_test_iter']):
  350. self.run_testing_net()
  351. for key in self.metrics:
  352. metric = self.metrics[key]
  353. if metric['is_train']:
  354. continue
  355. metric['calculator'].Add()
  356. self.test_loop_end_time = time.time()
  357. self.sec_per_test_loop = \
  358. self.test_loop_end_time - self.test_loop_start_time
  359. for metric in self.metrics.values():
  360. if metric['is_train']:
  361. continue
  362. metric['output'].append(metric['calculator'].Compute())
  363. logStr = 'epoch:{}/{} iter:{}/{} secs_per_train:{} '.format(
  364. self.fract_epoch, self.opts['epoch_iter']['num_epochs'],
  365. epoch_iter, self.epoch_iterations, secs_per_train)
  366. logStr += 'samples_per_sec:{} loop {} tests takes {} sec'.format(
  367. samples_per_sec, opts['epoch_iter']['num_test_iter'],
  368. self.sec_per_test_loop)
  369. for metric, value in self.metrics.items():
  370. logStr += ' {}:{} '.format(metric, value['output'][-1])
  371. log.info('Iter Stats: {}'.format(logStr))
  372. self.fun_per_iter_aftRunNetAftTest(epoch, epoch_iter)
  373. self.checkpoint(epoch)
  374. self.fun_per_epoch_aftRunNet(epoch)
  375. self.fun_conclude_operator()
  376. self.createMetricsPlotsModelsOutputs()
  377. return self.assembleAllOutputs()