| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833 |
- ## @package checkpoint
- # Module caffe2.python.checkpoint
- import os
- import logging
- from caffe2.python import core, context
- from caffe2.python.net_builder import ops
- from caffe2.python.task import (
- final_output,
- Node,
- Task,
- TaskGroup,
- TaskOutput,
- WorkspaceType,
- )
- logger = logging.getLogger(__name__)
class Job(context.Managed):
    """
    A Job defines four TaskGroups: the `init_group`, the `epoch_group`, the
    `download_group` and the `exit_group`, which will be run by a JobRunner.

    The `init_group` will be run only once at startup. Its role is to
    initialize globally persistent blobs such as model weights, accumulators
    and data file lists.

    The `epoch_group` will be run in a loop after init_group. The loop will
    exit when any of the stop signals added with `add_stop_condition` is True
    at the end of an epoch.

    The `download_group` will be run only once, after all the executions of
    `epoch_group` finish. Its role is to collect the distributed (scattered)
    parameters back after training.

    The `exit_group` will be run only once at the very end of the job, the
    role of this group is to save the results of training in the end of the
    job.

    Jobs are context-driven, so that Tasks can be added to the active Job
    without having to explicitly pass the job object around. Entering a Job
    also enters its `epoch_group`, so Tasks created directly inside
    `with Job():` are added to the epoch group.

    Example of usage:

        def build_reader(partitions):
            with Job.current().init_group:
                reader = HiveReader(init_reader, ..., partitions)
                Task(step=init_reader)
            with Job.current().epoch_group:
                limited_reader = ReaderWithLimit(reader, num_iter=10000)
                data_queue = pipe(limited_reader, num_threads=8)
                Job.current().add_stop_condition(limited_reader.data_finished())
            return data_queue

        def build_hogwild_trainer(reader, model):
            with Job.current().init_group:
                Task(step=model.param_init_net)
            with Job.current().epoch_group:
                pipe(reader, processor=model, num_threads=8)
            with Job.current().exit_group:
                Task(step=model.save_model_net)

        with Job() as job:
            reader = build_reader(partitions)
            model = build_model(params)
            build_hogwild_trainer(reader, model)
    """
    def __init__(self,
                 init_group=None, epoch_group=None,
                 download_group=None, exit_group=None,
                 stop_conditions=None, nodes_to_checkpoint=None):
        """
        Args:
            init_group: Optional TaskGroup run once at startup. Defaults to a
                new TaskGroup with a GLOBAL workspace type.
            epoch_group: Optional TaskGroup run once per epoch.
            download_group: Optional TaskGroup run once after the epoch loop.
            exit_group: Optional TaskGroup run once at the very end.
            stop_conditions: Optional list of TaskOutputs checked after each
                epoch. Defaults to an empty list.
            nodes_to_checkpoint: Optional explicit list of nodes to
                checkpoint. When falsy, the nodes used by `init_group` are
                used instead (see `nodes_to_checkpoint()`).
        """
        self.init_group = init_group or TaskGroup(
            workspace_type=WorkspaceType.GLOBAL)
        self.epoch_group = epoch_group or TaskGroup()
        self.download_group = download_group or TaskGroup()
        self.exit_group = exit_group or TaskGroup()
        self.stop_conditions = stop_conditions or []
        self._nodes_to_checkpoint = nodes_to_checkpoint

    def nodes_to_checkpoint(self):
        """Return the explicitly configured nodes, else `init_group.used_nodes()`."""
        if self._nodes_to_checkpoint:
            return self._nodes_to_checkpoint
        return self.init_group.used_nodes()

    def compile(self, session_class):
        """Replace every task group with `session_class.compile(group)`.

        The node list is resolved and stored first, while `init_group` is
        still the original (uncompiled) group.
        """
        self._nodes_to_checkpoint = self.nodes_to_checkpoint()
        self.init_group = session_class.compile(self.init_group)
        self.epoch_group = session_class.compile(self.epoch_group)
        self.download_group = session_class.compile(self.download_group)
        self.exit_group = session_class.compile(self.exit_group)

    def __enter__(self):
        super(Job, self).__enter__()
        # Tasks created directly under `with job:` land in the epoch group.
        self.epoch_group.__enter__()
        return self

    def __exit__(self, *args):
        # Leave the contexts in reverse order of entry. The finally clause
        # guarantees the Job itself is unregistered even if leaving the
        # epoch group raises; exception info is forwarded to both.
        try:
            self.epoch_group.__exit__(*args)
        finally:
            super(Job, self).__exit__(*args)

    def add_stop_condition(self, output):
        """Register a signal that ends the epoch loop when it is True.

        Args:
            output: Either a TaskOutput, or a core.BlobReference. A blob
                reference is wrapped in a new Task added to `epoch_group`
                and that task's first output is registered.
        """
        if isinstance(output, core.BlobReference):
            t = Task(outputs=[output], group=self.epoch_group)
            output = t.outputs()[0]
        assert isinstance(output, TaskOutput)
        self.stop_conditions.append(output)
def get_ckpt_filename(node_name, epoch):
    """Build the checkpoint file name for one node at one epoch.

    Args:
        node_name: A string. The name of the node.
        epoch: An integer. The checkpoint epoch.
    Returns:
        ckpt_filename: A string of the form '<node_name>.<epoch>'.
    """
    return '.'.join((node_name, str(epoch)))
def db_name(epoch, node_name, db_prefix, path_prefix=None):
    """Compose the full db name under which a checkpoint is stored.

    Args:
        epoch: An integer. The checkpoint epoch.
        node_name: A string. The name of the node.
        db_prefix: A string. The prefix used to construct full db name.
        path_prefix: A string. Optional param used to construct db name or
            path where checkpoint files are stored. When given it takes
            precedence over db_prefix.
    Returns:
        db_name: A string. With a truthy path_prefix, the checkpoint file
            name is appended to it directly (no separator is inserted);
            otherwise db_prefix and the file name are joined with
            os.path.join.
    """
    checkpoint_file = get_ckpt_filename(node_name, epoch)
    if not path_prefix:
        return os.path.join(db_prefix, checkpoint_file)
    return path_prefix + checkpoint_file
class CheckpointManager(object):
    """
    Controls saving and loading of workspaces on every epoch boundary of a job.
    If a CheckpointManager instance is passed to JobRunner, then JobRunner will
    call `init`, `load` and `save` at different moments in between epoch runs.

    Args:
        db_prefix: The prefix used to construct full db name. Since `absolute_path`
            is set to True, this will be used as db_name in SaveOp.
        node_name: Name of the node where this checkpoint_manager is used.
        db_type: Type of database to use for storing checkpoint.
        metadata_handler: An optional object capable of reading/writing
            checkpoint info in storage of choice.
    """

    BLOB_NAMES = "blob_names"

    def __init__(self, db_prefix, node_name, db_type, metadata_handler=None):
        self._db_prefix = db_prefix
        self._node_name = node_name
        self._db_type = db_type
        self._metadata_handler = metadata_handler
        # make sure these blobs are the first in the checkpoint file.
        self._net = core.Net('!!checkpoint_mngr')
        self._blob_names = self._net.AddExternalInput(self.BLOB_NAMES)
        # TaskOutput holding the list of blob names; set by init().
        self._names_output = None
        self._path_prefix = None
        self._path_type = None
        # Db name and timer output of the most recent load/save task; read
        # by collect_checkpoint_stats().
        self._current_db_name = None
        self._current_checkpoint_duration = None

    def init(
        self,
        nodes=None,
        retrieve_from_epoch=None,
        path_prefix=None,
        path_type=None
    ):
        """
        Build a Task that will be run once after the job's `init_group` is run.
        This task will determine which blobs need to be checkpointed.
        If retrieve_from_epoch is not None, then the checkpoint metadata is
        retrieved from a previously saved checkpoint.

        Args:
            nodes: An array of nodes where this checkpoint manager is running.
                Should only contain a single node.
            retrieve_from_epoch: Set to a number to load blobs from this epoch.
            path_prefix: Used to construct db name or path where checkpoint
                files are stored.
            path_type: Indicate the type of path where checkpoint files are
                stored. Overrides the manager's db_type when truthy.
        Returns:
            The Task; its first output is the blob-name list later read by
            `blob_list()`.
        """
        assert nodes is None or len(nodes) == 1, (
            'CheckpointManager only supports single node.')

        with Task(outputs=[self._blob_names]) as task:
            if retrieve_from_epoch is None:
                ops.GetAllBlobNames(
                    [],
                    self._blob_names,
                    include_shared=False)
            else:
                full_db_name = db_name(
                    retrieve_from_epoch,
                    self._node_name, self._db_prefix, path_prefix)
                db_type = path_type or self._db_type
                logger.info("Initializing checkpoints from = %s", full_db_name)
                ops.Load(
                    [], self._blob_names,
                    db=full_db_name,
                    db_type=db_type,
                    absolute_path=True,
                    keep_device=True,
                )
        self._names_output = task.outputs()[0]
        return task

    def blob_list(self):
        """Return the blob names fetched from the `init` task as a list.

        Requires that `init` has been called (and its task run) first.
        """
        assert self._names_output
        return self._names_output.fetch().tolist()

    def _timed_task(self, cp_op_name, add_op):
        """
        Build a Task that will measure the time span of checkpoint operations,
        once operation is done, time can be read from _current_checkpoint_duration.

        Args:
            cp_op_name: A string name of the checkpoint operation.
            add_op: A functor to add the checkpoint operation.

        Returns:
            A task with timer.
        """
        with Task(name=cp_op_name) as task:
            with ops.task_init():
                timer = ops.TimerBegin([], counter_name=self._node_name)
            add_op()
            with ops.task_exit():
                time_span_blob = ops.TimerGetAndEnd(timer)
            self._current_checkpoint_duration = final_output(time_span_blob)
        return task

    def collect_checkpoint_stats(self, stats):
        """
        Add one checkpoint stats into the stats.

        Args:
            stats: A dict of checkpoint stats that will be reported. On
                success it gains one entry mapping the current db name to
                the first element of the fetched duration.
        """
        if self._current_db_name and self._current_checkpoint_duration:
            stats[self._current_db_name] = \
                self._current_checkpoint_duration.fetch()[0]
        else:
            logger.info(
                "Failed to collect checkpoint stats: %s",
                self._current_db_name)

    def load(self, epoch, path_prefix=None, path_type=None):
        """
        Build a Task that will be run by JobRunner when the job is to be
        resumed from a given epoch. This task will run a Load op that will
        load and deserialize all relevant blobs from a persistent storage.

        Args:
            epoch: The checkpoint epoch to load.
            path_prefix: Optional prefix used instead of db_prefix to build
                the db name.
            path_type: Optional db type overriding the manager's db_type.
        """
        self._current_db_name = db_name(
            epoch, self._node_name, self._db_prefix, path_prefix
        )
        db_type = path_type or self._db_type
        logger.info("Loading checkpoints from = %s", self._current_db_name)

        def add_op():
            ops.Load(
                [],
                self.blob_list(),
                db=self._current_db_name,
                db_type=db_type,
                absolute_path=True,
                keep_device=True,
            )

        return self._timed_task('checkpoint_load', add_op)

    def load_blobs_from_checkpoint(self, blob_names, epoch):
        """
        Builds a Task that loads only the necessary blobs from a checkpoint of
        the given epoch. The necessary blobs are given in the blob_names
        argument.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: The checkpoint epoch to load from.

        Returns:
            A Task which loads the specified blobs from the checkpoint of the
            given epoch.
        """
        self._current_db_name = db_name(epoch, self._node_name, self._db_prefix)
        logger.info('Load from %s', self._current_db_name)

        def add_op():
            ops.Load(
                [],
                blob_names,
                db=self._current_db_name,
                db_type=self._db_type,
                absolute_path=True,
                allow_incomplete=True)

        return self._timed_task('checkpoint_partial_load', add_op)

    def check_db_exists(self, epoch):
        """Build a Task whose output tells whether the epoch's checkpoint db exists.

        Args:
            epoch: The checkpoint epoch to look for.
        Returns:
            A Task with a single output written by the DBExists op.
        """
        # Computed once and shared by the log line and the op argument.
        full_db_name = db_name(epoch, self._node_name, self._db_prefix)
        logger.info('Check existence of %s', full_db_name)
        with Task() as task:
            existence = ops.Const(False)
            ops.DBExists(
                [],
                [existence],
                db_name=full_db_name,
                db_type=self._db_type,
                absolute_path=True)
            task.add_output(existence)
        return task

    def report_checkpoint_stats(self, action_name):
        """
        Report checkpoint operation stats for current node.

        Args:
            action_name: A string of the name of checkpoint operation.
        """
        all_stats = {}
        self.collect_checkpoint_stats(all_stats)
        if self._metadata_handler:
            self._metadata_handler.report(action_name, all_stats)

    def save(self, epoch):
        """
        Build a Task that is run once after `init_group` and after each
        epoch is run. This will execute a Save ops to serialize and persist
        blobs present in the global workspace.
        """
        self._current_db_name = db_name(epoch, self._node_name, self._db_prefix)
        logger.info('Saving to %s', self._current_db_name)

        def add_op():
            ops.Save(
                self.blob_list(), [],
                db=self._current_db_name,
                db_type=self._db_type,
                absolute_path=True)

        return self._timed_task('checkpoint_save', add_op)

    def write_checkpoint_metadata(self, epoch):
        """
        Write metadata for checkpoint

        Args:
            epoch: An integer. The epoch-id for which checkpoint metadata is
                written
        """
        if self._metadata_handler is not None:
            self._metadata_handler.write(epoch=epoch)

    def get_resume_from_epoch_id(self, user_epoch=None):
        """
        Identify the epoch-id from which Job must resume

        Args:
            user_epoch: An integer. Optional parameter for user to explicitly
                identify the epoch-id to load checkpoint from
        Returns:
            epoch: the epoch-id to load checkpoints from
                or None if no checkpoints were written. Without a metadata
                handler, user_epoch is returned unchanged.
        """
        last_epoch = user_epoch
        if self._metadata_handler is not None:
            last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
        return last_epoch

    def set_params(self, nodes, path_prefix=None, path_type=None):
        """Set parameters associated with CP manager

        Args:
            nodes: An array of nodes where this checkpoint manager is running.
                Not used by this single-node manager; its own node name is
                what gets forwarded to the metadata handler.
            path_prefix: Used to construct db name or path where checkpoint files are
                stored.
            path_type: Indicate the type of path where checkpoint files are stored.
        """
        # Falsy values leave the previously stored setting untouched.
        if path_prefix:
            self._path_prefix = path_prefix
        if path_type:
            self._path_type = path_type
        if self._metadata_handler:
            self._metadata_handler.set_params(
                db_prefix=self._db_prefix,
                db_type=self._db_type,
                node_names=[str(self._node_name)],
                path_prefix=self._path_prefix,
                path_type=self._path_type)

    def cp_accessible(self, epoch=None):
        """Returns True if Checkpoint data is accessible

        Args:
            epoch: An integer. The epoch of the checkpoint. If None,
                it implies we need to check if checkpoint directory is accessible

        Returns:
            is_cp_accessible: A boolean. Returns True if Checkpoint data is accessible.
                Always True when no metadata handler is configured.
        """
        if self._metadata_handler is not None:
            return self._metadata_handler.cp_accessible(epoch)
        return True
class MultiNodeCheckpointManager(object):
    """
    Coordinates checkpointing across multiple nodes.
    Each of `init`, `load` and `save` will build TaskGroups which will
    trigger checkpointing on each of the nodes involved in a distributed job.

    Args:
        db_prefix: The prefix used to construct full db name. Since `absolute_path`
            is set to True, this will be used as db_name in SaveOp.
        db_type: Type of database to use for storing checkpoint.
        metadata_handler: An optional object capable of reading/writing
            checkpoint info in storage of choice.
    """
    def __init__(self, db_prefix, db_type, metadata_handler=None):
        # List of (node, CheckpointManager) pairs; None until `init` or
        # `load_blobs_locally` creates it.
        self._node_managers = None
        # Populated by set_params().
        self._node_names = None
        self._db_prefix = db_prefix
        self._db_type = db_type
        self._metadata_handler = metadata_handler
        self._path_prefix = None
        self._path_type = None

    def _task_group(self, func, *args, **kw):
        """Build a GLOBAL TaskGroup by calling `func(manager, *args, **kw)`
        for every per-node manager, each under its own Node context."""
        assert self._node_managers is not None, 'init must be called first.'
        with TaskGroup(WorkspaceType.GLOBAL) as task_group:
            for node, manager in self._node_managers:
                with Node(node):
                    func(manager, *args, **kw)
            return task_group

    def _ensure_node_managers(self, nodes):
        """Create one CheckpointManager per node on first use.

        If the managers already exist, only check that `nodes` matches the
        nodes they were created for.

        Returns:
            True if the managers were created by this call, False if they
            already existed.
        """
        if self._node_managers is not None:
            assert [node for node, _ in self._node_managers] == nodes
            return False
        self._node_managers = []
        for node in nodes:
            with Node(node):
                manager = CheckpointManager(
                    db_prefix=self._db_prefix,
                    node_name=str(node),
                    db_type=self._db_type)
                self._node_managers.append((node, manager))
        return True

    def init(
        self, nodes, retrieve_from_epoch=None, path_prefix=None, path_type=None
    ):
        """Build the TaskGroup that initializes checkpointing on every node.

        Args:
            nodes: An array of nodes where this checkpoint manager is running.
            retrieve_from_epoch: Set to a number to load blobs from this epoch.
            path_prefix: Used to construct db name or path where checkpoint
                files are stored.
            path_type: Indicate the type of path where checkpoint files are
                stored.
        Returns:
            A GLOBAL TaskGroup. It is empty when the managers were already
            initialized by an earlier call.
        """
        if not self._ensure_node_managers(nodes):
            return TaskGroup(WorkspaceType.GLOBAL)
        # Built inline rather than through _task_group so that each manager
        # is handed its own node. The previous code reused the loop variable
        # left over from manager creation (always the last node, and unbound
        # for an empty `nodes`).
        with TaskGroup(WorkspaceType.GLOBAL) as task_group:
            for node, manager in self._node_managers:
                with Node(node):
                    CheckpointManager.init(
                        manager,
                        nodes=[node],
                        retrieve_from_epoch=retrieve_from_epoch,
                        path_prefix=path_prefix,
                        path_type=path_type)
            return task_group

    def load(self, epoch, path_prefix=None, path_type=None):
        """Build the TaskGroup that loads the given epoch's checkpoint on
        every node. Requires the per-node managers to exist."""
        return self._task_group(
            CheckpointManager.load,
            epoch,
            path_prefix=path_prefix,
            path_type=path_type)

    def load_blobs_locally(self, nodes, blob_names, epoch, session):
        """Loads the necessary blobs from the checkpoints to the current node.

        Args:
            nodes: An array of nodes whose checkpoints are read. Per-node
                managers are created from it if they do not exist yet.
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: An integer. The checkpoint epoch to load from.
            session: A Session object to execute the Load ops.
        Returns:
            True if every node's checkpoint db existed and was loaded; False
            as soon as one db is found missing (earlier nodes stay loaded).
        """
        self._ensure_node_managers(nodes)
        for _, manager in self._node_managers:
            existence_task = manager.check_db_exists(epoch)
            session.run(existence_task)
            existence = existence_task.outputs()[0].fetch()
            if not existence:
                logger.info(
                    'DB %s does not exist!',
                    db_name(epoch, manager._node_name, manager._db_prefix))
                return False
            load_task = manager.load_blobs_from_checkpoint(blob_names, epoch)
            session.run(load_task)
        logger.info('Successfully loaded from checkpoints.')
        return True

    def get_ckpt_db_name(self, node_name, epoch):
        """Returns the DB name of the given node and the given epoch.

        The DB name is effectively the checkpoint path of the given node and
        the given epoch.

        Args:
            node_name: A string. The node name of interest.
            epoch: An integer. The epoch of the checkpoint.

        Returns:
            checkpoint_db_name: A string. The checkpoint path of the given
                node and the given epoch, or None when no manager exists
                for that node.
        """
        # `or ()` keeps the lookup safe before the managers are created.
        for node, manager in self._node_managers or ():
            if str(node) == node_name:
                return db_name(epoch, manager._node_name, manager._db_prefix)
        return None

    def report_checkpoint_stats(self, action_name):
        """
        Report the checkpoint stats for all the nodes, we need to aggregate all
        the node's stats together so that we know which node's checkpoint
        operation dominates.

        Args:
            action_name: A string of the name of checkpoint operation.
        """
        all_stats = {}
        for _, manager in self._node_managers or ():
            manager.collect_checkpoint_stats(all_stats)
        logger.debug("checkpoint stats: %s", all_stats)
        if self._metadata_handler:
            self._metadata_handler.report(action_name, all_stats)

    def save(self, epoch):
        """
        Build a Task that will execute a Save ops to serialize and persist
        blobs present in the global workspace.
        """
        return self._task_group(CheckpointManager.save, epoch)

    def write_checkpoint_metadata(self, epoch):
        """
        Write metadata for checkpoint

        Args:
            epoch: An integer. The epoch-id for which checkpoint metadata is
                written
        """
        if self._metadata_handler is not None:
            self._metadata_handler.write(epoch=epoch)

    def get_resume_from_epoch_id(self, user_epoch=None):
        """
        Identify the epoch-id from which Job must resume

        Args:
            user_epoch: An integer. Optional parameter for user to explicitly
                identify the epoch-id to load checkpoint from
        Returns:
            epoch: the epoch-id to load checkpoints from
                or None if no checkpoints were written. Without a metadata
                handler, user_epoch is returned unchanged.
        """
        last_epoch = user_epoch
        if self._metadata_handler is not None:
            last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
        return last_epoch

    def set_params(self, nodes, path_prefix=None, path_type=None):
        """Set parameters associated with CP manager

        Args:
            nodes: An array of nodes where this checkpoint manager is running.
            path_prefix: Used to construct db name or path where checkpoint files are
                stored.
            path_type: Indicate the type of path where checkpoint files are stored.
        """
        self._node_names = [str(node) for node in nodes]
        # Falsy values leave the previously stored setting untouched.
        if path_prefix:
            self._path_prefix = path_prefix
        if path_type:
            self._path_type = path_type
        if self._metadata_handler:
            self._metadata_handler.set_params(
                db_prefix=self._db_prefix,
                db_type=self._db_type,
                node_names=self._node_names,
                path_prefix=self._path_prefix,
                path_type=self._path_type)

    def cp_accessible(self, epoch=None):
        """Returns True if Checkpoint data is accessible

        Args:
            epoch: An integer. The epoch of the checkpoint. If None,
                it implies we need to check if checkpoint directory is accessible

        Returns:
            is_cp_accessible: A boolean. Returns True if Checkpoint data is accessible.
                Always True when no metadata handler is configured.
        """
        if self._metadata_handler is not None:
            return self._metadata_handler.cp_accessible(epoch)
        return True
class UploadTaskGroupBuilder(object):
    """Interface for objects that build a checkpoint-upload task group."""

    def build(self, epoch, checkpoint_manager):
        """Return the task group that uploads the checkpoints of one epoch.

        Args:
            epoch: An integer. The checkpoint epoch to be uploaded.
            checkpoint_manager: Can be a CheckpointManager for single machine
                or a MultiNodeCheckpointManager for multi-machine. The manager
                that initializes/saves/loads checkpoints.

        Raises:
            NotImplementedError: Always, in this base class; subclasses
                provide the actual implementation.
        """
        raise NotImplementedError
class JobRunner(object):
    """
    Runtime driver for jobs that checkpoint at epoch granularity. Works for
    both single-host and distributed jobs. The master calls `train` once with
    a session; the call blocks until the whole Job has finished.

    When a checkpoint_manager is supplied, a checkpoint is taken right after
    initialization and again after every epoch. When `resume_from_epoch` is
    also an epoch number, that checkpoint is loaded and execution continues
    from it; the job's init_group is then not run.

    Refer to checkpoint_test.py for an example.
    """
    def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None,
                 upload_task_group_builder=None):
        """Initializes the JobRunner.

        Args:
            job: A Job object. The job to be executed.
            checkpoint_manager: Can be a CheckpointManager for single machine
                or a MultiNodeCheckpointManager for multi-machine. The manager
                that initializes/saves/loads checkpoints.
            resume_from_epoch: An integer. The epoch to resume from.
            upload_task_group_builder: A subclass of the
                UploadTaskGroupBuilder. Creates a task group to upload
                checkpoints.
        """
        self.job = job
        self.checkpoint_manager = checkpoint_manager
        self.resume_from_epoch = resume_from_epoch
        self.upload_task_group_builder = upload_task_group_builder

    def train(self, session):
        """Runs the training flow.

        Args:
            session: A Session object. Valid choices are: LocalSession,
                LocalHostScheduler, and DistributedSession. It is used to
                execute one TaskGroup at a time.

        Returns:
            The number of the last epoch that was run.
        """
        manager = self.checkpoint_manager
        job = self.job

        # Let the checkpoint manager decide which epoch to resume from.
        if manager:
            manager.set_params(nodes=job.nodes_to_checkpoint())
            self.resume_from_epoch = manager.get_resume_from_epoch_id(
                self.resume_from_epoch)
            if self.resume_from_epoch is not None:
                logger.info(
                    'Resuming from epoch {}'.format(self.resume_from_epoch))

        # The init group only runs when there is nothing to resume from.
        from_scratch = self.resume_from_epoch is None
        if from_scratch:
            session.run(job.init_group)

        if manager:
            logger.info('Preparing checkpoints ...')
            init_tasks = manager.init(
                job.nodes_to_checkpoint(),
                retrieve_from_epoch=self.resume_from_epoch)
            session.run(init_tasks)
            # Either write the epoch-0 checkpoint before training starts, or
            # restore the one being resumed from.
            if from_scratch:
                self.save_checkpoints(0, session)
            else:
                logger.info('Loading checkpoints for epoch {} ...'.format(
                    self.resume_from_epoch))
                session.run(manager.load(self.resume_from_epoch))
                manager.report_checkpoint_stats('checkpoint_load')
                logger.info('Checkpoint loaded')

        logger.info("Finished initializing")

        # Epoch loop: run, read the stop signals, checkpoint, then decide.
        if from_scratch:
            epoch = 1
        else:
            epoch = self.resume_from_epoch + 1
        finished = False
        while not finished:
            logger.info('Starting epoch %d' % epoch)
            session.run(job.epoch_group)
            logger.info('Finished epoch %d' % epoch)
            # Signals are fetched before checkpointing, as a complete list.
            stop_signals = [signal.fetch() for signal in job.stop_conditions]

            if manager:
                self.save_checkpoints(epoch, session)

            finished = any(stop_signals)
            if finished:
                logger.info('Stopping')
            else:
                epoch += 1
        logger.info('Finished training')

        # Optional upload of the checkpoints of the last epoch.
        builder = self.upload_task_group_builder
        if builder:
            session.run(builder.build(epoch, manager))
            logger.info('Finished uploading the checkpoints')

        # Gather the parameters back.
        session.run(job.download_group)
        logger.info('Finished downloading the parameters')

        # Last of all, the exit group saves the nets.
        session.run(job.exit_group)
        logger.info('Finished running the exit group')
        return epoch

    def load_blobs_from_checkpoints(self, blob_names, epoch, session):
        """Loads the necessary blobs from the checkpoints.

        Checkpoints store the snapshots of the workspace in each node.
        Sometimes only a subset of the blobs is needed, for instance just the
        model blobs for evaluation. Given the names of those blobs, this
        method goes over the checkpoints of all the nodes and loads only the
        blobs listed in blob_names into the current workspace.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: An integer. The checkpoint epoch to load from.
            session: A Session object to execute the load ops.

        Returns:
            Whatever the checkpoint manager's `load_blobs_locally` returns.

        Raises:
            ValueError: When the checkpoint manager is invalid.
        """
        manager = self.checkpoint_manager
        if not manager:
            raise ValueError('Checkpoint manager is None')
        logger.info('Loading checkpoint for epoch {} ...'.format(epoch))
        nodes = self.job.nodes_to_checkpoint()
        loaded = manager.load_blobs_locally(nodes, blob_names, epoch, session)
        manager.report_checkpoint_stats('checkpoint_partial_load')
        return loaded

    def save_checkpoints(self, epoch, session):
        """Triggers operation to save checkpoints

        This method will trigger the Save ops to serialize and persist the
        blobs present in the global workspace. Failures while saving are
        logged as warnings and not propagated.

        Args:
            epoch: An integer. The checkpoint epoch-id that we are saving.
            session: A Session object to execute the save ops.

        Raises:
            ValueError: When the checkpoint manager is invalid.
        """
        manager = self.checkpoint_manager
        if not manager:
            raise ValueError('Checkpoint manager is None')
        try:
            # epoch=None asks about the checkpoint location as a whole.
            if not manager.cp_accessible(epoch=None):
                logger.warning("Checkpoint files cannot be accessed!")
                return
            logger.info('Saving checkpoints for epoch {}'.format(epoch))
            session.run(manager.save(epoch))
            manager.write_checkpoint_metadata(epoch)
            logger.info('Checkpoints saved')
            manager.report_checkpoint_stats('checkpoint_save')
        except Exception as ex:
            logger.warning(
                "Unable to write checkpoint for epoch {}. Error={}".format(
                    epoch, ex))
def epoch_limiter(job, num_epochs):
    """
    Add a stop condition to `job` that outputs True once the given number
    of epochs has finished.

    A counter initialized to `num_epochs - 1` is created in the job's
    init_group; a CountDown op on that counter runs in the epoch_group and
    its output is registered through `job.add_stop_condition`.
    """
    with job.init_group:
        counter_init_net = core.Net('epoch_counter_init')
        epochs_left = counter_init_net.CreateCounter(
            [], init_count=num_epochs - 1)
        Task(step=counter_init_net)

    with job.epoch_group:
        countdown_net = core.Net('epoch_countdown')
        done_blob = countdown_net.CountDown(epochs_left)
        countdown_task = Task(step=countdown_net, outputs=done_blob)
        stop_signal = countdown_task.outputs()[0]
    job.add_stop_condition(stop_signal)
|