checkpoint.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. import numpy as np
  2. import pickle
  3. from collections import OrderedDict
  4. from caffe2.proto import caffe2_pb2
  5. from caffe2.python import workspace, core, scope
  6. import logging
  7. logging.basicConfig()
  8. log = logging.getLogger("AnyExpOnTerm")
  9. log.setLevel(logging.DEBUG)
  10. def initialize_params_from_file(
  11. model, weights_file, num_xpus, opts,
  12. broadcast_computed_param=False, reset_epoch=False):
  13. start_epoch, lr, best_metric = initialize_master_xpu_model_params(
  14. model, weights_file, opts, reset_epoch)
  15. broadcast_parameters(opts, model, num_xpus, broadcast_computed_param)
  16. return start_epoch, lr, best_metric
  17. def initialize_master_xpu_model_params(model, weights_file, opts, reset_epoch):
  18. log.info("Initializing model params from file: {}".format(weights_file))
  19. with open(weights_file, 'r') as fopen:
  20. blobs = pickle.load(fopen)
  21. if 'blobs' in blobs:
  22. blobs = blobs['blobs']
  23. start_epoch = 0
  24. best_metric = float('-inf')
  25. if 'epoch' in blobs:
  26. log.info('epoch {} is found in model file'.format(blobs['epoch']))
  27. if not reset_epoch:
  28. start_epoch = blobs['epoch']
  29. else:
  30. log.info('Reset epoch')
  31. else:
  32. log.info('no epoch is found in model file')
  33. lr = opts['model_param']['base_learning_rate']
  34. if 'lr' in blobs:
  35. lr = blobs['lr']
  36. if 'best_metric' in blobs and not reset_epoch:
  37. best_metric = blobs['best_metric']
  38. if model is not None:
  39. log.info('initialize model parameters using weights file: {}'.format(
  40. weights_file
  41. ))
  42. ws_blobs = workspace.Blobs()
  43. unscoped_blob_names = OrderedDict()
  44. for blob in model.GetAllParams():
  45. unscoped_blob_names[unscope_name(str(blob))] = True
  46. root_xpu_id = opts['distributed']['first_xpu_id']
  47. device = opts['distributed']['device']
  48. caffe2_pb2_DEVICE =\
  49. caffe2_pb2.CUDA if opts['distributed']['device'] == 'gpu'\
  50. else caffe2_pb2.CPU
  51. with core.NameScope('{}_{}'.format(device, root_xpu_id)):
  52. with core.DeviceScope(core.DeviceOption(caffe2_pb2_DEVICE, 0)):
  53. for unscoped_blob_name in unscoped_blob_names.keys():
  54. scoped_blob_name = scoped_name(unscoped_blob_name)
  55. if unscoped_blob_name not in blobs:
  56. log.info('{:s} not found'.format(unscoped_blob_name))
  57. continue
  58. log.info(
  59. '{:s} loaded from weights file into: {:s}'.format(
  60. unscoped_blob_name, scoped_blob_name
  61. )
  62. )
  63. if scoped_blob_name in ws_blobs:
  64. ws_blob = workspace.FetchBlob(scoped_blob_name)
  65. if not ws_blob.shape == blobs[unscoped_blob_name].shape:
  66. log.info(
  67. ('Workspace blob {} with shape {} does '
  68. 'not match weights file shape {}').format(
  69. unscoped_blob_name, ws_blob.shape,
  70. blobs[unscoped_blob_name].shape)
  71. )
  72. else:
  73. workspace.FeedBlob(
  74. scoped_blob_name,
  75. blobs[unscoped_blob_name].astype(
  76. np.float32, copy=False))
  77. else:
  78. log.info('Skip initializing model parameters from file: {}'.format(
  79. weights_file
  80. ))
  81. log.info('Complete initialize_master_xpu_model_params')
  82. return start_epoch, lr, best_metric
  83. def broadcast_parameters(opts, model, num_xpus, broadcast_computed_param=False):
  84. if num_xpus == 1:
  85. log.info("only 1 device. Skip parameter broadcast")
  86. return
  87. all_params = [model.GetParams()]
  88. if broadcast_computed_param:
  89. all_params.append(model.GetComputedParams())
  90. caffe2_pb2_DEVICE =\
  91. caffe2_pb2.CUDA if opts['distributed']['device'] == 'gpu'\
  92. else caffe2_pb2.CPU
  93. for params in all_params:
  94. assert len(params) % num_xpus == 0, \
  95. "Current model doesn't match device number when loading checkpoint"
  96. params_per_xpu = int(len(params) / num_xpus)
  97. for idx in range(params_per_xpu):
  98. blobs = [param for param in params[idx::params_per_xpu]]
  99. data = workspace.FetchBlob(blobs[0])
  100. log.info('Broadcasting {} to'.format(str(blobs[0])))
  101. for i, p in enumerate(blobs[1:]):
  102. log.info(' |-> {}'.format(str(p)))
  103. with core.DeviceScope(core.DeviceOption(caffe2_pb2_DEVICE, i+1)):
  104. workspace.FeedBlob(p, data)
  105. log.info("Complete parameter broadcast")
  106. def save_model_params(is_checkpoint, model, checkpoint_path, epoch, opts, best_metric):
  107. # best_metric=float('-inf')
  108. if checkpoint_path is None:
  109. return None
  110. try:
  111. save_model_params_blob(
  112. model, checkpoint_path, epoch, opts, best_metric
  113. )
  114. except Exception as e:
  115. log.warning('Exception from save_model_params {}'.format(str(e)))
  116. return checkpoint_path
  117. def save_model_params_blob(model, params_file, epoch, opts, best_metric):
  118. # best_metric=float('-inf')
  119. log.info("Saving model params...")
  120. root_xpu_id = opts['distributed']['first_xpu_id']
  121. device = opts['distributed']['device']
  122. save_params = [str(param) for param in
  123. model.GetParams('{}_{}'.format(device, root_xpu_id))]
  124. save_computed_params = [str(param) for param in
  125. model.GetComputedParams('{}_{}'
  126. .format(device, root_xpu_id))]
  127. save_blobs = {}
  128. save_blobs['epoch'] = epoch
  129. save_blobs['best_metric'] = best_metric
  130. save_blobs['lr'] = \
  131. workspace.FetchBlob('{}_{}/lr'.format(device, root_xpu_id))
  132. for param in save_params + save_computed_params:
  133. scoped_blob_name = str(param)
  134. unscoped_blob_name = unscope_name(scoped_blob_name)
  135. if unscoped_blob_name not in save_blobs:
  136. save_blobs[unscoped_blob_name] = workspace.FetchBlob(
  137. scoped_blob_name)
  138. log.debug(
  139. '{:s} -> {:s}'.format(scoped_blob_name, unscoped_blob_name))
  140. log.info('to weights file {}'.format(params_file))
  141. try:
  142. with open(params_file, 'w') as fwrite:
  143. pickle.dump(dict(blobs=save_blobs), fwrite, pickle.HIGHEST_PROTOCOL)
  144. except IOError as e:
  145. log.error('I/O error({0}): {1}'.format(e.errno, e.strerror))
  146. def unscope_name(blob_name):
  147. return blob_name[blob_name.rfind(scope._NAMESCOPE_SEPARATOR) + 1:]
  148. def scoped_name(blob_name):
  149. return scope.CurrentNameScope() + blob_name