#!/usr/bin/env python3
"""Tests for caffe2's lazily registered operator libraries (``lazy_dyndep``)."""

from multiprocessing import Process
import shutil
import tempfile
import unittest

from hypothesis import given, settings
import hypothesis.strategies as st
import numpy as np

import caffe2.python.hypothesis_test_util as hu

# Engine name placed under the "engine" key of the rendezvous dict that
# allcompare_process builds.
op_engine = 'GLOO'


  11. class TemporaryDirectory:
  12. def __enter__(self):
  13. self.tmpdir = tempfile.mkdtemp()
  14. return self.tmpdir
  15. def __exit__(self, type, value, traceback):
  16. shutil.rmtree(self.tmpdir)
  17. def allcompare_process(filestore_dir, process_id, data, num_procs):
  18. from caffe2.python import core, data_parallel_model, workspace, lazy_dyndep
  19. from caffe2.python.model_helper import ModelHelper
  20. from caffe2.proto import caffe2_pb2
  21. lazy_dyndep.RegisterOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")
  22. workspace.RunOperatorOnce(
  23. core.CreateOperator(
  24. "FileStoreHandlerCreate", [], ["store_handler"], path=filestore_dir
  25. )
  26. )
  27. rendezvous = dict(
  28. kv_handler="store_handler",
  29. shard_id=process_id,
  30. num_shards=num_procs,
  31. engine=op_engine,
  32. exit_nets=None
  33. )
  34. model = ModelHelper()
  35. model._rendezvous = rendezvous
  36. workspace.FeedBlob("test_data", data)
  37. data_parallel_model._RunComparison(
  38. model, "test_data", core.DeviceOption(caffe2_pb2.CPU, 0)
  39. )
  40. class TestLazyDynDepAllCompare(hu.HypothesisTestCase):
  41. @given(
  42. d=st.integers(1, 5), n=st.integers(2, 11), num_procs=st.integers(1, 8)
  43. )
  44. @settings(deadline=None)
  45. def test_allcompare(self, d, n, num_procs):
  46. dims = []
  47. for _ in range(d):
  48. dims.append(np.random.randint(1, high=n))
  49. test_data = np.random.ranf(size=tuple(dims)).astype(np.float32)
  50. with TemporaryDirectory() as tempdir:
  51. processes = []
  52. for idx in range(num_procs):
  53. process = Process(
  54. target=allcompare_process,
  55. args=(tempdir, idx, test_data, num_procs)
  56. )
  57. processes.append(process)
  58. process.start()
  59. while len(processes) > 0:
  60. process = processes.pop()
  61. process.join()
  62. class TestLazyDynDepError(unittest.TestCase):
  63. def test_errorhandler(self):
  64. from caffe2.python import core, lazy_dyndep
  65. import tempfile
  66. with tempfile.NamedTemporaryFile() as f:
  67. lazy_dyndep.RegisterOpsLibrary(f.name)
  68. def handler(e):
  69. raise ValueError("test")
  70. lazy_dyndep.SetErrorHandler(handler)
  71. with self.assertRaises(ValueError, msg="test"):
  72. core.RefreshRegisteredOperators()
  73. def test_importaftererror(self):
  74. from caffe2.python import core, lazy_dyndep
  75. import tempfile
  76. with tempfile.NamedTemporaryFile() as f:
  77. lazy_dyndep.RegisterOpsLibrary(f.name)
  78. def handler(e):
  79. raise ValueError("test")
  80. lazy_dyndep.SetErrorHandler(handler)
  81. with self.assertRaises(ValueError):
  82. core.RefreshRegisteredOperators()
  83. def handlernoop(e):
  84. raise
  85. lazy_dyndep.SetErrorHandler(handlernoop)
  86. lazy_dyndep.RegisterOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")
  87. core.RefreshRegisteredOperators()
  88. def test_workspacecreatenet(self):
  89. from caffe2.python import workspace, lazy_dyndep
  90. import tempfile
  91. with tempfile.NamedTemporaryFile() as f:
  92. lazy_dyndep.RegisterOpsLibrary(f.name)
  93. called = False
  94. def handler(e):
  95. raise ValueError("test")
  96. lazy_dyndep.SetErrorHandler(handler)
  97. with self.assertRaises(ValueError, msg="test"):
  98. workspace.CreateNet("fake")
  99. if __name__ == "__main__":
  100. unittest.main()