checkpoint_test.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. from caffe2.python.schema import Struct, ConstRecord
  2. from caffe2.python import core, workspace, model_helper
  3. from caffe2.python.session import LocalSession
  4. from caffe2.python.dataset import Dataset
  5. from caffe2.python.pipeline import pipe
  6. from caffe2.python.checkpoint import (
  7. CheckpointManager, MultiNodeCheckpointManager, Job, JobRunner, epoch_limiter,
  8. UploadTaskGroupBuilder, db_name)
  9. from caffe2.python.net_builder import ops
  10. from caffe2.python.task import Node, Task, TaskGroup, WorkspaceType, Cluster
  11. from caffe2.python.test_util import TestCase
  12. from caffe2.python.dataio import ReaderWithLimit
  13. import numpy as np
  14. import os
  15. import shutil
  16. import tempfile
  def build_pipeline(node_id):
      """Add a toy summing pipeline for trainer ``node_id`` to the current Job.

      Inside the Job's init group this creates a constant dataset holding the
      integers 0..9 and an accumulator blob initialised to ``[100]``.

      At node level it pipes a ``ReaderWithLimit(..., num_iter=3)`` over that
      dataset into a processor that adds each record's ``val`` into the
      accumulator. It also registers the reader's ``data_finished()`` as a
      stop condition of the Job.

      NOTE: later tests hard-code auto-generated blob names, so the order in
      which ops/tasks are created here must not change.

      Returns:
          A one-element list holding the accumulator blob.
      """
      trainer = Node('trainer_%d' % node_id)
      with trainer:
          job = Job.current()
          with job.init_group, Task():
              values = np.arange(10)
              const_record = ConstRecord(ops, Struct(('val', values)))
              dataset = Dataset(const_record, name='dataset:%d' % node_id)
              reader = dataset.reader(ops)
              accumulator = ops.Const([100])

          def inc_total(rec):
              # Add the record's value into the accumulator in place; returns
              # nothing, so the processor emits no output record.
              ops.Add([accumulator, rec.val()], [accumulator])

          limited_reader = ReaderWithLimit(reader, num_iter=3)
          pipe(limited_reader, processor=inc_total)
          job.add_stop_condition(limited_reader.data_finished())
      return [accumulator]
  # Accumulator value after each epoch of build_pipeline: it starts at 100 and
  # each epoch reads (at most) 3 of the dataset values 0..9.
  EXPECTED_TOTALS = [
      103,  # epoch 1: 100 + (0 + 1 + 2)
      115,  # epoch 2: 103 + (3 + 4 + 5)
      136,  # epoch 3: 115 + (6 + 7 + 8)
      145,  # epoch 4: 136 + 9 (dataset exhausted)
  ]
  def local_copy_op(src, dest):
      """Build a Python-op callback that copies file ``src`` to ``dest``.

      The returned callable takes ``(inputs, outputs)`` as used with
      ``ops.Python`` in this file. It ignores both arguments, copies the file
      with ``shutil.copyfile`` and returns None.
      """
      def _copy_file(inputs, outputs):
          # No blobs flow through this op; only the file copy matters.
          del inputs, outputs
          shutil.copyfile(src, dest)

      return _copy_file
  class UploadToLocalFile(UploadTaskGroupBuilder):
      """Upload builder that copies each node's checkpoint db into a local dir.

      Each node's file is written to ``os.path.join(dest_dir, str(node))``.
      """

      def __init__(self, dest_dir):
          # Directory receiving one copied checkpoint file per node.
          self.dest_dir = dest_dir

      def build(self, epoch, checkpoint_manager):
          """Return a GLOBAL-workspace TaskGroup with one copy task per node.

          Args:
              epoch: epoch whose checkpoint db is copied.
              checkpoint_manager: manager whose ``_node_managers`` yields
                  ``(node, manager)`` pairs.
          """
          with TaskGroup(WorkspaceType.GLOBAL) as group:
              for node, node_manager in checkpoint_manager._node_managers:
                  node_label = str(node)
                  with Node(node_label), Task():
                      source = db_name(
                          epoch, node_manager._node_name,
                          node_manager._db_prefix)
                      target = os.path.join(self.dest_dir, node_label)
                      copy_op_spec = (local_copy_op, [source, target], {})
                      ops.Python(copy_op_spec)([], [])
          return group
  class TestCheckpoint(TestCase):
      """End-to-end tests for checkpoint save/resume/load/upload via JobRunner."""

      @staticmethod
      def _fetch_total(session, output_fetcher):
          """Run ``output_fetcher`` in ``session`` and return its first output."""
          session.run(output_fetcher)
          return output_fetcher.outputs()[0].fetch()

      def run_with(self, builder):
          """Train a one-node pipeline, then verify resuming and loading.

          Args:
              builder: zero-argument callable returning a fresh
                  ``(session, checkpoint_manager)`` pair. It is called once for
                  the initial run and once per resumed run.
          """
          with Cluster():
              with Job() as job:
                  outputs = build_pipeline(node_id=0)
              output_fetcher = Task(step=core.Net('empty'), outputs=outputs)

              session, checkpoint = builder()
              job.compile(LocalSession)
              num_epochs = JobRunner(job, checkpoint).train(session)
              self.assertEqual(num_epochs, len(EXPECTED_TOTALS))
              self.assertEqual(self._fetch_total(session, output_fetcher),
                               EXPECTED_TOTALS[-1])

              # Resuming from any saved epoch must reach the same final total.
              for initial_epoch in range(1, num_epochs + 1):
                  session, checkpoint = builder()
                  JobRunner(
                      job,
                      checkpoint, resume_from_epoch=initial_epoch
                  ).train(session)
                  self.assertEqual(self._fetch_total(session, output_fetcher),
                                   EXPECTED_TOTALS[-1])

              # Loading epoch k's checkpoint restores the total after epoch k.
              for epoch in range(1, num_epochs + 1):
                  session.run(checkpoint.load(epoch))
                  self.assertEqual(self._fetch_total(session, output_fetcher),
                                   EXPECTED_TOTALS[epoch - 1])

      def test_single_checkpoint(self):
          # test single node
          # mkdtemp runs before `try` so a failure there cannot make `finally`
          # reference an unbound (or stale) directory name.
          single_dir = tempfile.mkdtemp()
          try:
              def single_builder():
                  ws = workspace.C.Workspace()
                  session = LocalSession(ws)
                  checkpoint = CheckpointManager(
                      single_dir, 'temp_node', 'minidb')
                  return session, checkpoint

              self.run_with(single_builder)
          finally:
              shutil.rmtree(single_dir)

          # test multi-node
          multi_dir = tempfile.mkdtemp()
          try:
              def multi_builder():
                  ws = workspace.C.Workspace()
                  session = LocalSession(ws)
                  checkpoint = MultiNodeCheckpointManager(multi_dir, 'minidb')
                  return session, checkpoint

              self.run_with(multi_builder)
          finally:
              shutil.rmtree(multi_dir)

      def test_ckpt_name_and_load_model_from_ckpts(self):
          num_nodes = 3

          # First, check if the checkpoint name generation mechanism is
          # correct.
          tmpdir = tempfile.mkdtemp()
          try:
              checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
              with Cluster():
                  with Job() as job:
                      for node_id in range(num_nodes):
                          build_pipeline(node_id)
                  job.compile(LocalSession)
                  checkpoint.init(job.nodes_to_checkpoint())

                  epoch = 5
                  for node_id in range(num_nodes):
                      node_name = 'trainer_%d' % node_id
                      expected_db_name = '%s/%s.%d' % (tmpdir, node_name, epoch)
                      self.assertEqual(
                          checkpoint.get_ckpt_db_name(node_name, epoch),
                          expected_db_name)
          finally:
              shutil.rmtree(tmpdir)

          # Next, check mechanism to load model from checkpoints.
          tmpdir = tempfile.mkdtemp()
          try:
              workspace.ResetWorkspace()
              for node_id in range(num_nodes):
                  ws = workspace.C.Workspace()
                  session = LocalSession(ws)
                  checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                  with Cluster():
                      with Job() as job:
                          build_pipeline(node_id)
                      job.compile(LocalSession)
                      job_runner = JobRunner(job, checkpoint)
                      num_epochs = job_runner.train(session)
                  self.assertEqual(num_epochs, len(EXPECTED_TOTALS))

                  # There are 17 global blobs after finishing up the job runner.
                  # (only blobs on init_group are checkpointed)
                  self.assertEqual(len(ws.blobs), 17)

              ws = workspace.C.Workspace()
              session = LocalSession(ws)
              self.assertEqual(len(ws.blobs), 0)
              model_blob_names = ['trainer_1/task_2/GivenTensorInt64Fill:0',
                                  'trainer_2/task_2/GivenTensorInt64Fill:0']
              checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
              with Cluster():
                  with Job() as job:
                      for node_id in range(num_nodes):
                          build_pipeline(node_id)
                  job.compile(LocalSession)
                  job_runner = JobRunner(job, checkpoint)
                  job_runner.load_blobs_from_checkpoints(
                      blob_names=model_blob_names, epoch=1, session=session)

                  # Check that we can successfully load from checkpoints of
                  # every trained epoch (1 to 4), but not the one after (5).
                  last_epoch = len(EXPECTED_TOTALS)
                  for epoch in range(1, last_epoch + 1):
                      self.assertTrue(
                          job_runner.load_blobs_from_checkpoints(
                              blob_names=model_blob_names, epoch=epoch,
                              session=session))
                      # Check that all the model blobs are loaded.
                      for blob_name in model_blob_names:
                          self.assertTrue(ws.has_blob(blob_name))
                          np.testing.assert_array_equal(
                              ws.fetch_blob(blob_name),
                              np.array([EXPECTED_TOTALS[epoch - 1]]))
                  self.assertFalse(
                      job_runner.load_blobs_from_checkpoints(
                          blob_names=model_blob_names, epoch=last_epoch + 1,
                          session=session))
          finally:
              shutil.rmtree(tmpdir)

      def test_upload_checkpoint(self):
          tmpdir = tempfile.mkdtemp()
          try:
              upload_dir = os.path.join(tmpdir, "upload")
              os.mkdir(upload_dir)
              num_nodes = 3

              # The uploaded files do not exist yet.
              for node_id in range(num_nodes):
                  node_name = 'trainer_%d' % node_id
                  upload_path = os.path.join(upload_dir, node_name)
                  self.assertFalse(os.path.exists(upload_path))

              # Create and run the job runner.
              for node_id in range(num_nodes):
                  ws = workspace.C.Workspace()
                  session = LocalSession(ws)
                  checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                  with Cluster():
                      with Job() as job:
                          build_pipeline(node_id)
                      job.compile(LocalSession)
                      local_upload_builder = UploadToLocalFile(upload_dir)
                      job_runner = JobRunner(
                          job, checkpoint,
                          upload_task_group_builder=local_upload_builder)
                      num_epochs = job_runner.train(session)
                      self.assertEqual(num_epochs, len(EXPECTED_TOTALS))

              # The uploaded files should exist now.
              for node_id in range(num_nodes):
                  node_name = 'trainer_%d' % node_id
                  upload_path = os.path.join(upload_dir, node_name)
                  self.assertTrue(os.path.exists(upload_path))
          finally:
              shutil.rmtree(tmpdir)

      def test_ckpt_save_failure(self):
          num_nodes = 3
          # The goal of this test is to ensure that the job runs
          # successfully even if saving a checkpoint fails.
          # Hence tmpdir is a non-existent directory to emulate a failure
          # while saving checkpoints. It is placed under a fresh temporary
          # directory so it is guaranteed not to exist; a fixed path such as
          # /tmp/... is not portable and might already exist.
          parent_dir = tempfile.mkdtemp()
          tmpdir = os.path.join(parent_dir, 'path_does_not_exist', '')
          try:
              # Check the saving checkpoint failure does not cause job failure
              workspace.ResetWorkspace()
              for node_id in range(num_nodes):
                  ws = workspace.C.Workspace()
                  session = LocalSession(ws)
                  checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                  with Cluster():
                      with Job() as job:
                          build_pipeline(node_id)
                      job.compile(LocalSession)
                      job_runner = JobRunner(job, checkpoint)
                      num_epochs = job_runner.train(session)
                  # Make sure all epochs are executed even though saving the
                  # checkpoint failed: a save failure must not fail the job.
                  self.assertEqual(num_epochs, len(EXPECTED_TOTALS))
          finally:
              shutil.rmtree(parent_dir)

      def test_download_group_simple(self):
          """
          A simple test that ensures we have download task group
          executed between epoch_group and exit_group.
          """
          model = model_helper.ModelHelper(name="test_model")
          download_net = core.Net("download_net")

          for name in ["input1", "input2", "output", "download_result"]:
              model.param_init_net.ConstantFill([],
                                                [name],
                                                shape=[8, ],
                                                value=1.0,
                                                run_once=0)
          model.net.Add(["input1", "input2"], ["output"])
          download_net.Copy(["output"], ["download_result"])

          # All blob values are initialized as 1.0, after download_net executed
          # we expect to see download result is the same as training result.
          with Job() as job:
              with Node("trainer:0"):
                  with job.init_group:
                      Task(step=model.param_init_net)
                  with job.epoch_group:
                      with Task():
                          with ops.loop(1):
                              ops.net(model.net)
                  with job.download_group:
                      Task(step=download_net)

                  epoch_limiter(job, 1)

          ws = workspace.C.Workspace()
          session = LocalSession(ws)
          job_runner = JobRunner(job)
          job_runner.train(session)

          expected_result = np.full(8, 2.0).astype(np.float32)
          np.testing.assert_array_equal(ws.fetch_blob("output"),
                                        expected_result)
          np.testing.assert_array_equal(ws.fetch_blob("download_result"),
                                        expected_result)

      def test_reuse_checkpoint_manager(self):
          """
          A simple test that ensures we can reuse a MultiNodeCheckpointManager
          object.
          """
          tmpdir = tempfile.mkdtemp()
          try:
              ws = workspace.C.Workspace()
              session = LocalSession(ws)
              checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')

              with Job() as job:
                  outputs = build_pipeline(node_id=0)
              output_fetcher = Task(step=core.Net('empty'), outputs=outputs)
              job.compile(LocalSession)

              num_epochs = JobRunner(job, checkpoint).train(session)
              for initial_epoch in range(1, num_epochs + 1):
                  JobRunner(
                      job,
                      checkpoint,
                      resume_from_epoch=initial_epoch
                  ).train(session)
                  self.assertEqual(self._fetch_total(session, output_fetcher),
                                   EXPECTED_TOTALS[-1])
          finally:
              shutil.rmtree(tmpdir)