| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174 |
- import numpy as np
- import pickle
- from collections import OrderedDict
- from caffe2.proto import caffe2_pb2
- from caffe2.python import workspace, core, scope
- import logging
- logging.basicConfig()
- log = logging.getLogger("AnyExpOnTerm")
- log.setLevel(logging.DEBUG)
def initialize_params_from_file(
        model, weights_file, num_xpus, opts,
        broadcast_computed_param=False, reset_epoch=False):
    """Load a weights file on the root device, then broadcast the params.

    Args:
        model: model object, forwarded unchanged to both helpers.
        weights_file: path of the pickled weights file to read.
        num_xpus: number of devices taking part in the broadcast.
        opts: options dictionary, forwarded unchanged to both helpers.
        broadcast_computed_param: forwarded to ``broadcast_parameters``.
        reset_epoch: forwarded to ``initialize_master_xpu_model_params``.

    Returns:
        The ``(start_epoch, lr, best_metric)`` tuple produced by
        ``initialize_master_xpu_model_params``.
    """
    restored_state = initialize_master_xpu_model_params(
        model, weights_file, opts, reset_epoch)
    # Unpack before broadcasting so a malformed result fails early.
    start_epoch, lr, best_metric = restored_state
    broadcast_parameters(opts, model, num_xpus, broadcast_computed_param)
    return start_epoch, lr, best_metric
def initialize_master_xpu_model_params(model, weights_file, opts, reset_epoch):
    """Restore training state and root-device parameters from a weights file.

    The weights file is a pickle containing a dict, optionally wrapped as
    ``{'blobs': {...}}``. Besides parameter arrays keyed by unscoped blob
    name, the dict may carry ``'epoch'``, ``'lr'`` and ``'best_metric'``.

    Args:
        model: model whose ``GetAllParams()`` names the blobs to restore, or
            ``None`` to read only the training state and skip the workspace.
        weights_file: path of the pickled weights file.
        opts: options dictionary; reads
            ``opts['model_param']['base_learning_rate']`` and, when ``model``
            is given, ``opts['distributed']['first_xpu_id']`` and
            ``opts['distributed']['device']``.
        reset_epoch: when true, the stored epoch and best metric are ignored
            and the defaults (``0`` and ``-inf``) are returned instead.

    Returns:
        Tuple ``(start_epoch, lr, best_metric)``. ``lr`` is the stored value
        when the file has one, otherwise the base learning rate from ``opts``.
    """
    log.info("Initializing model params from file: {}".format(weights_file))
    # Pickle streams are bytes: the file has to be opened in binary mode.
    # Text mode makes pickle.load fail on Python 3 and allows newline
    # translation to corrupt the stream on some platforms.
    # NOTE: pickle.load can execute arbitrary code embedded in the file, so
    # only weights files from a trusted source should be passed in.
    with open(weights_file, 'rb') as fopen:
        blobs = pickle.load(fopen)
    if 'blobs' in blobs:
        blobs = blobs['blobs']

    start_epoch = 0
    best_metric = float('-inf')
    if 'epoch' in blobs:
        log.info('epoch {} is found in model file'.format(blobs['epoch']))
        if not reset_epoch:
            start_epoch = blobs['epoch']
        else:
            log.info('Reset epoch')
    else:
        log.info('no epoch is found in model file')

    lr = opts['model_param']['base_learning_rate']
    if 'lr' in blobs:
        lr = blobs['lr']
    if 'best_metric' in blobs and not reset_epoch:
        best_metric = blobs['best_metric']

    if model is not None:
        log.info('initialize model parameters using weights file: {}'.format(
            weights_file
        ))
        ws_blobs = workspace.Blobs()
        # An ordered dict de-duplicates the unscoped names while keeping the
        # order in which the model reports its parameters.
        unscoped_blob_names = OrderedDict()
        for blob in model.GetAllParams():
            unscoped_blob_names[unscope_name(str(blob))] = True

        root_xpu_id = opts['distributed']['first_xpu_id']
        device = opts['distributed']['device']
        device_type = caffe2_pb2.CUDA if device == 'gpu' else caffe2_pb2.CPU

        with core.NameScope('{}_{}'.format(device, root_xpu_id)):
            with core.DeviceScope(core.DeviceOption(device_type, 0)):
                for unscoped_blob_name in unscoped_blob_names:
                    scoped_blob_name = scoped_name(unscoped_blob_name)
                    if unscoped_blob_name not in blobs:
                        log.info('{:s} not found'.format(unscoped_blob_name))
                        continue
                    log.info(
                        '{:s} loaded from weights file into: {:s}'.format(
                            unscoped_blob_name, scoped_blob_name
                        )
                    )
                    # Blobs that do not exist in the workspace yet are left
                    # alone; only existing blobs are overwritten.
                    if scoped_blob_name not in ws_blobs:
                        continue
                    saved_blob = blobs[unscoped_blob_name]
                    ws_blob = workspace.FetchBlob(scoped_blob_name)
                    if ws_blob.shape != saved_blob.shape:
                        # A shape mismatch is reported and the blob keeps
                        # its current workspace value.
                        log.info(
                            ('Workspace blob {} with shape {} does '
                             'not match weights file shape {}').format(
                                unscoped_blob_name, ws_blob.shape,
                                saved_blob.shape)
                        )
                    else:
                        workspace.FeedBlob(
                            scoped_blob_name,
                            saved_blob.astype(np.float32, copy=False))
    else:
        log.info('Skip initializing model parameters from file: {}'.format(
            weights_file
        ))
    log.info('Complete initialize_master_xpu_model_params')
    return start_epoch, lr, best_metric
def broadcast_parameters(opts, model, num_xpus, broadcast_computed_param=False):
    """Copy each parameter blob of the first device to the other devices.

    The parameter list returned by the model is treated as ``num_xpus``
    consecutive groups of equal size; the blob at position ``idx`` of the
    first group is fetched and fed to the blob at the same position of every
    other group.

    Args:
        opts: options dictionary; ``opts['distributed']['device']`` selects
            CUDA (when equal to ``'gpu'``) or CPU device options.
        model: model providing ``GetParams()`` and ``GetComputedParams()``.
        num_xpus: number of devices; nothing is done when it equals 1.
        broadcast_computed_param: also broadcast the computed params.
    """
    if num_xpus == 1:
        log.info("only 1 device. Skip parameter broadcast")
        return

    param_groups = [model.GetParams()]
    if broadcast_computed_param:
        param_groups.append(model.GetComputedParams())

    if opts['distributed']['device'] == 'gpu':
        device_type = caffe2_pb2.CUDA
    else:
        device_type = caffe2_pb2.CPU

    for params in param_groups:
        assert len(params) % num_xpus == 0, \
            "Current model doesn't match device number when loading checkpoint"
        params_per_xpu = int(len(params) / num_xpus)
        for idx in range(params_per_xpu):
            # Every params_per_xpu-th entry is the same parameter on the
            # next device.
            replicas = list(params[idx::params_per_xpu])
            source = replicas[0]
            data = workspace.FetchBlob(source)
            log.info('Broadcasting {} to'.format(str(source)))
            for offset, target in enumerate(replicas[1:]):
                log.info(' |-> {}'.format(str(target)))
                device_option = core.DeviceOption(device_type, offset + 1)
                with core.DeviceScope(device_option):
                    workspace.FeedBlob(target, data)
    log.info("Complete parameter broadcast")
def save_model_params(is_checkpoint, model, checkpoint_path, epoch, opts, best_metric):
    """Save the model parameters to ``checkpoint_path`` on a best-effort basis.

    Args:
        is_checkpoint: accepted for interface compatibility; not used here.
        model: model whose parameters are saved.
        checkpoint_path: destination file, or ``None`` to skip saving.
        epoch: epoch number stored alongside the parameters.
        opts: options dictionary forwarded to ``save_model_params_blob``.
        best_metric: best metric value stored alongside the parameters.

    Returns:
        ``None`` when ``checkpoint_path`` is ``None``; otherwise
        ``checkpoint_path``, even if saving raised (the exception is logged
        as a warning and not propagated).
    """
    if checkpoint_path is not None:
        try:
            save_model_params_blob(
                model, checkpoint_path, epoch, opts, best_metric)
        except Exception as err:
            # Saving must never interrupt the caller; report and carry on.
            log.warning('Exception from save_model_params {}'.format(str(err)))
        return checkpoint_path
    return None
def save_model_params_blob(model, params_file, epoch, opts, best_metric):
    """Pickle the root device's parameters and training state to a file.

    The file holds ``{'blobs': {...}}`` where the inner dict maps
    ``'epoch'``, ``'best_metric'``, ``'lr'`` and every unscoped parameter
    name to its value.

    Args:
        model: model providing ``GetParams(scope)`` and
            ``GetComputedParams(scope)``.
        params_file: path of the file to write.
        epoch: value stored under ``'epoch'``.
        opts: options dictionary; reads
            ``opts['distributed']['first_xpu_id']`` and
            ``opts['distributed']['device']`` to build the root name scope.
        best_metric: value stored under ``'best_metric'``.

    I/O errors while writing are logged and not propagated.
    """
    log.info("Saving model params...")
    root_xpu_id = opts['distributed']['first_xpu_id']
    device = opts['distributed']['device']
    root_scope = '{}_{}'.format(device, root_xpu_id)

    save_params = [str(param) for param in model.GetParams(root_scope)]
    save_computed_params = [
        str(param) for param in model.GetComputedParams(root_scope)]

    save_blobs = {
        'epoch': epoch,
        'best_metric': best_metric,
        'lr': workspace.FetchBlob('{}/lr'.format(root_scope)),
    }
    for scoped_blob_name in save_params + save_computed_params:
        unscoped_blob_name = unscope_name(scoped_blob_name)
        # The first occurrence wins, so a parameter named like one of the
        # bookkeeping keys above does not overwrite it.
        if unscoped_blob_name not in save_blobs:
            save_blobs[unscoped_blob_name] = workspace.FetchBlob(
                scoped_blob_name)
            log.debug(
                '{:s} -> {:s}'.format(scoped_blob_name, unscoped_blob_name))

    log.info('to weights file {}'.format(params_file))
    try:
        # HIGHEST_PROTOCOL is a binary pickle protocol, so the file must be
        # opened in binary mode; in text mode Python 3 raises TypeError and
        # no usable checkpoint is written.
        with open(params_file, 'wb') as fwrite:
            pickle.dump(dict(blobs=save_blobs), fwrite, pickle.HIGHEST_PROTOCOL)
    except IOError as e:
        log.error('I/O error({0}): {1}'.format(e.errno, e.strerror))
def unscope_name(blob_name):
    """Return ``blob_name`` with everything up to the last separator removed.

    The separator is ``scope._NAMESCOPE_SEPARATOR``. A name that does not
    contain it is returned unchanged.
    """
    # rfind returns -1 when the separator is absent, so the slice below
    # then starts at 0 and keeps the whole name.
    last_separator = blob_name.rfind(scope._NAMESCOPE_SEPARATOR)
    return blob_name[last_separator + 1:]
def scoped_name(blob_name):
    """Return ``blob_name`` prefixed with the currently active name scope."""
    current_prefix = scope.CurrentNameScope()
    return current_prefix + blob_name
|