| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338 |
- from caffe2.python.schema import Struct, ConstRecord
- from caffe2.python import core, workspace, model_helper
- from caffe2.python.session import LocalSession
- from caffe2.python.dataset import Dataset
- from caffe2.python.pipeline import pipe
- from caffe2.python.checkpoint import (
- CheckpointManager, MultiNodeCheckpointManager, Job, JobRunner, epoch_limiter,
- UploadTaskGroupBuilder, db_name)
- from caffe2.python.net_builder import ops
- from caffe2.python.task import Node, Task, TaskGroup, WorkspaceType, Cluster
- from caffe2.python.test_util import TestCase
- from caffe2.python.dataio import ReaderWithLimit
- import numpy as np
- import os
- import shutil
- import tempfile
def build_pipeline(node_id):
    """Add a small summing pipeline for node ``trainer_<node_id>`` to the
    currently active Job.

    Must be called while a ``Job`` is current (it uses ``Job.current()``).

    The job's init group gets a Task that builds a constant dataset whose
    ``val`` field holds the integers 0..9 and a running total initialised to
    ``[100]``. A pipe then reads the dataset through a ``ReaderWithLimit``
    (``num_iter=3``) and adds each record's ``val`` into the total. The
    limited reader's ``data_finished()`` is registered as the job's stop
    condition.

    Args:
        node_id: integer used in the node name (``'trainer_%d'``) and the
            dataset name (``'dataset:%d'``).

    Returns:
        A one-element list holding the running-total blob reference.
    """
    with Node('trainer_%d' % node_id):
        with Job.current().init_group, Task():
            values = Struct(('val', np.array(list(range(10)))))
            const_record = ConstRecord(ops, values)
            dataset = Dataset(const_record, name='dataset:%d' % node_id)
            source_reader = dataset.reader(ops)
            running_total = ops.Const([100])

        def inc_total(rec):
            # Accumulate in place: running_total += rec.val
            ops.Add([running_total, rec.val()], [running_total])

        limited_reader = ReaderWithLimit(source_reader, num_iter=3)
        pipe(limited_reader, processor=inc_total)
        # The job stops once the limited reader reports that data is exhausted.
        Job.current().add_stop_condition(limited_reader.data_finished())
    return [running_total]
# Running total produced by build_pipeline after each epoch, as asserted by
# the tests below (index k holds the total after epoch k + 1). The values are
# consistent with a start of 100 and the values 0..9 being consumed three per
# epoch: +0+1+2, +3+4+5, +6+7+8, +9.
EXPECTED_TOTALS = [
    103,  # after epoch 1
    115,  # after epoch 2
    136,  # after epoch 3
    145,  # after epoch 4 (final)
]
def local_copy_op(src, dest):
    """Return a callable that copies the file at ``src`` to ``dest``.

    The returned callable accepts ``(inputs, outputs)`` and ignores both; the
    paths are captured here instead. In this file the factory is handed to
    ``ops.Python`` as ``(local_copy_op, [src, dest], {})``.
    """
    paths = (src, dest)

    def _copy_file(inputs, outputs):
        # inputs/outputs are unused; the copy uses the captured paths.
        shutil.copyfile(*paths)

    return _copy_file
class UploadToLocalFile(UploadTaskGroupBuilder):
    """Upload builder that "uploads" checkpoints by copying each node's
    checkpoint db file into a local directory.

    Args:
        dest_dir: directory receiving one copied file per node, named
            ``str(node)``.
    """

    def __init__(self, dest_dir):
        self.dest_dir = dest_dir

    def build(self, epoch, checkpoint_manager):
        """Build a GLOBAL-workspace TaskGroup with one copy Task per node.

        Args:
            epoch: epoch whose checkpoint db is copied; passed to ``db_name``.
            checkpoint_manager: manager exposing ``_node_managers`` as
                ``(node, manager)`` pairs (a private attribute, read directly).

        Returns:
            The TaskGroup containing the copy Tasks.
        """
        with TaskGroup(WorkspaceType.GLOBAL) as copy_group:
            for node, node_manager in checkpoint_manager._node_managers:
                node_label = str(node)
                with Node(node_label), Task():
                    checkpoint_db = db_name(
                        epoch, node_manager._node_name,
                        node_manager._db_prefix)
                    uploaded_copy = os.path.join(self.dest_dir, node_label)
                    # (factory, args, kwargs) tuple; the op has no blob I/O.
                    copy_spec = (local_copy_op,
                                 [checkpoint_db, uploaded_copy], {})
                    ops.Python(copy_spec)([], [])
        return copy_group
class TestCheckpoint(TestCase):
    """End-to-end tests of checkpoint save / resume / load / upload through
    ``JobRunner`` using the pipeline from ``build_pipeline``."""

    def _train_node(self, node_id, checkpoint, session, **job_runner_kwargs):
        """Build, compile and train a one-node pipeline in a fresh Cluster.

        Args:
            node_id: index passed to ``build_pipeline``.
            checkpoint: checkpoint manager handed to ``JobRunner``.
            session: session the job is trained in.
            **job_runner_kwargs: extra keyword arguments for ``JobRunner``.

        Returns:
            The result of ``JobRunner.train`` (callers compare it with the
            expected number of epochs).
        """
        with Cluster():
            with Job() as job:
                build_pipeline(node_id)
            job.compile(LocalSession)
            job_runner = JobRunner(job, checkpoint, **job_runner_kwargs)
            return job_runner.train(session)

    def run_with(self, builder):
        """Train once, resume from every epoch, then reload every checkpoint.

        Args:
            builder: zero-argument callable returning a fresh
                ``(session, checkpoint_manager)`` pair; it is called for the
                initial run and once per resumed run.
        """
        with Cluster():
            with Job() as job:
                outputs = build_pipeline(node_id=0)
            # Kept outside the `with Job()` block; it is only run explicitly
            # by fetch_total to read the running total.
            output_fetcher = Task(step=core.Net('empty'), outputs=outputs)

            def fetch_total(session):
                session.run(output_fetcher)
                return output_fetcher.outputs()[0].fetch()

            session, checkpoint = builder()
            job.compile(LocalSession)
            num_epochs = JobRunner(job, checkpoint).train(session)
            self.assertEqual(num_epochs, len(EXPECTED_TOTALS))
            self.assertEqual(fetch_total(session), EXPECTED_TOTALS[-1])

            # Resuming from any saved epoch must still reach the final total.
            for initial_epoch in range(1, num_epochs + 1):
                session, checkpoint = builder()
                JobRunner(
                    job,
                    checkpoint, resume_from_epoch=initial_epoch
                ).train(session)
                self.assertEqual(fetch_total(session), EXPECTED_TOTALS[-1])

            # Loading the checkpoint of epoch k restores the total after k.
            for epoch in range(1, num_epochs + 1):
                session.run(checkpoint.load(epoch))
                self.assertEqual(fetch_total(session),
                                 EXPECTED_TOTALS[epoch - 1])

    def test_single_checkpoint(self):
        # test single node
        # mkdtemp is outside the try so `finally` never sees an unbound name.
        tmpdir = tempfile.mkdtemp()
        try:
            def single_node_builder():
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = CheckpointManager(tmpdir, 'temp_node', 'minidb')
                return session, checkpoint

            self.run_with(single_node_builder)
        finally:
            shutil.rmtree(tmpdir)

        # test multi-node
        tmpdir = tempfile.mkdtemp()
        try:
            def multi_node_builder():
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                return session, checkpoint

            self.run_with(multi_node_builder)
        finally:
            shutil.rmtree(tmpdir)

    def test_ckpt_name_and_load_model_from_ckpts(self):
        num_nodes = 3
        expected_epochs = len(EXPECTED_TOTALS)

        # First, check if the checkpoint name generation mechanism is
        # correct.
        tmpdir = tempfile.mkdtemp()
        try:
            checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
            with Cluster():
                with Job() as job:
                    for node_id in range(num_nodes):
                        build_pipeline(node_id)
                job.compile(LocalSession)
                checkpoint.init(job.nodes_to_checkpoint())

                epoch = 5
                for node_id in range(num_nodes):
                    node_name = 'trainer_%d' % node_id
                    expected_db_name = '%s/%s.%d' % (tmpdir, node_name, epoch)
                    self.assertEqual(
                        checkpoint.get_ckpt_db_name(node_name, epoch),
                        expected_db_name)
        finally:
            shutil.rmtree(tmpdir)

        # Next, check mechanism to load model from checkpoints.
        tmpdir = tempfile.mkdtemp()
        try:
            workspace.ResetWorkspace()
            for node_id in range(num_nodes):
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                num_epochs = self._train_node(node_id, checkpoint, session)
                self.assertEqual(num_epochs, expected_epochs)

                # There are 17 global blobs after finishing up the job runner.
                # (only blobs on init_group are checkpointed)
                self.assertEqual(len(ws.blobs), 17)

            ws = workspace.C.Workspace()
            session = LocalSession(ws)
            self.assertEqual(len(ws.blobs), 0)
            model_blob_names = ['trainer_1/task_2/GivenTensorInt64Fill:0',
                                'trainer_2/task_2/GivenTensorInt64Fill:0']
            checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
            with Cluster():
                with Job() as job:
                    for node_id in range(num_nodes):
                        build_pipeline(node_id)
                job.compile(LocalSession)
                job_runner = JobRunner(job, checkpoint)
                # The return value of this call is not checked; epoch 1 is
                # loaded and asserted again in the loop below.
                job_runner.load_blobs_from_checkpoints(
                    blob_names=model_blob_names, epoch=1, session=session)

                # Loading succeeds for every saved epoch ...
                for epoch in range(1, expected_epochs + 1):
                    self.assertTrue(
                        job_runner.load_blobs_from_checkpoints(
                            blob_names=model_blob_names, epoch=epoch,
                            session=session))
                    # Check that all the model blobs are loaded.
                    for blob_name in model_blob_names:
                        self.assertTrue(ws.has_blob(blob_name))
                        np.testing.assert_array_equal(
                            ws.fetch_blob(blob_name),
                            np.array([EXPECTED_TOTALS[epoch - 1]]))
                # ... but fails for the first epoch that was never saved.
                self.assertFalse(
                    job_runner.load_blobs_from_checkpoints(
                        blob_names=model_blob_names,
                        epoch=expected_epochs + 1, session=session))
        finally:
            shutil.rmtree(tmpdir)

    def test_upload_checkpoint(self):
        num_nodes = 3
        tmpdir = tempfile.mkdtemp()
        try:
            upload_dir = os.path.join(tmpdir, "upload")
            os.mkdir(upload_dir)
            upload_paths = [os.path.join(upload_dir, 'trainer_%d' % node_id)
                            for node_id in range(num_nodes)]

            # The uploaded files do not exist yet.
            for upload_path in upload_paths:
                self.assertFalse(os.path.exists(upload_path))

            # Create and run one job runner per node.
            for node_id in range(num_nodes):
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                num_epochs = self._train_node(
                    node_id, checkpoint, session,
                    upload_task_group_builder=UploadToLocalFile(upload_dir))
                self.assertEqual(num_epochs, len(EXPECTED_TOTALS))

            # The uploaded files should exist now.
            for upload_path in upload_paths:
                self.assertTrue(os.path.exists(upload_path))
        finally:
            shutil.rmtree(tmpdir)

    def test_ckpt_save_failure(self):
        num_nodes = 3
        # The goal of this test is to ensure that the job runs successfully
        # even if saving a checkpoint fails. The checkpoint directory is a
        # path that does not exist, to emulate a failure while saving. It is
        # placed inside a fresh temporary directory so it is guaranteed not
        # to exist (a fixed shared path could have been created by someone
        # else) and anything written there is removed afterwards.
        parent_dir = tempfile.mkdtemp()
        tmpdir = os.path.join(parent_dir, 'path_does_not_exist')
        try:
            # Check the saving checkpoint failure does not cause job failure
            workspace.ResetWorkspace()
            for node_id in range(num_nodes):
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                num_epochs = self._train_node(node_id, checkpoint, session)
                # Make sure all epochs are executed even though saving the
                # checkpoint failed.
                self.assertEqual(num_epochs, len(EXPECTED_TOTALS))
        finally:
            shutil.rmtree(parent_dir)

    def test_download_group_simple(self):
        """
        A simple test that ensures we have download task group
        executed between epoch_group and exit_group.
        """
        model = model_helper.ModelHelper(name="test_model")
        download_net = core.Net("download_net")

        for name in ["input1", "input2", "output", "download_result"]:
            model.param_init_net.ConstantFill([],
                                              [name],
                                              shape=[8],
                                              value=1.0,
                                              run_once=0)
        model.net.Add(["input1", "input2"], ["output"])
        download_net.Copy(["output"], ["download_result"])

        # All blob values are initialized as 1.0, after download_net executed
        # we expect to see download result is the same as training result.
        with Job() as job:
            with Node("trainer:0"):
                with job.init_group:
                    Task(step=model.param_init_net)
                with job.epoch_group:
                    with Task():
                        with ops.loop(1):
                            ops.net(model.net)
                with job.download_group:
                    Task(step=download_net)

                epoch_limiter(job, 1)

        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        job_runner = JobRunner(job)
        job_runner.train(session)

        expected_result = np.full(8, 2.0).astype(np.float32)
        self.assertTrue(np.array_equal(expected_result,
                                       ws.fetch_blob("output")))
        self.assertTrue(np.array_equal(expected_result,
                                       ws.fetch_blob("download_result")))

    def test_reuse_checkpoint_manager(self):
        """
        A simple test that ensures we can reuse a MultiNodeCheckpointManager
        object.
        """
        tmpdir = tempfile.mkdtemp()
        try:
            ws = workspace.C.Workspace()
            session = LocalSession(ws)
            checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')

            with Job() as job:
                outputs = build_pipeline(node_id=0)
            output_fetcher = Task(step=core.Net('empty'), outputs=outputs)
            job.compile(LocalSession)

            def fetch_total(session):
                session.run(output_fetcher)
                return output_fetcher.outputs()[0].fetch()

            num_epochs = JobRunner(job, checkpoint).train(session)
            # Guard against the resume loop below running zero times.
            self.assertEqual(num_epochs, len(EXPECTED_TOTALS))
            for initial_epoch in range(1, num_epochs + 1):
                JobRunner(
                    job,
                    checkpoint,
                    resume_from_epoch=initial_epoch
                ).train(session)
                self.assertEqual(fetch_total(session), EXPECTED_TOTALS[-1])
        finally:
            shutil.rmtree(tmpdir)
|