allcompare_test.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. #!/usr/bin/env python3
  2. from hypothesis import given, settings
  3. import hypothesis.strategies as st
  4. from multiprocessing import Process
  5. import numpy as np
  6. import tempfile
  7. import shutil
  8. import caffe2.python.hypothesis_test_util as hu
  9. op_engine = 'GLOO'
  10. class TemporaryDirectory:
  11. def __enter__(self):
  12. self.tmpdir = tempfile.mkdtemp()
  13. return self.tmpdir
  14. def __exit__(self, type, value, traceback):
  15. shutil.rmtree(self.tmpdir)
  16. def allcompare_process(filestore_dir, process_id, data, num_procs):
  17. from caffe2.python import core, data_parallel_model, workspace, dyndep
  18. from caffe2.python.model_helper import ModelHelper
  19. from caffe2.proto import caffe2_pb2
  20. dyndep.InitOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")
  21. workspace.RunOperatorOnce(
  22. core.CreateOperator(
  23. "FileStoreHandlerCreate", [], ["store_handler"], path=filestore_dir
  24. )
  25. )
  26. rendezvous = dict(
  27. kv_handler="store_handler",
  28. shard_id=process_id,
  29. num_shards=num_procs,
  30. engine=op_engine,
  31. exit_nets=None
  32. )
  33. model = ModelHelper()
  34. model._rendezvous = rendezvous
  35. workspace.FeedBlob("test_data", data)
  36. data_parallel_model._RunComparison(
  37. model, "test_data", core.DeviceOption(caffe2_pb2.CPU, 0)
  38. )
  39. class TestAllCompare(hu.HypothesisTestCase):
  40. @given(
  41. d=st.integers(1, 5), n=st.integers(2, 11), num_procs=st.integers(1, 8)
  42. )
  43. @settings(deadline=10000)
  44. def test_allcompare(self, d, n, num_procs):
  45. dims = []
  46. for _ in range(d):
  47. dims.append(np.random.randint(1, high=n))
  48. test_data = np.random.ranf(size=tuple(dims)).astype(np.float32)
  49. with TemporaryDirectory() as tempdir:
  50. processes = []
  51. for idx in range(num_procs):
  52. process = Process(
  53. target=allcompare_process,
  54. args=(tempdir, idx, test_data, num_procs)
  55. )
  56. processes.append(process)
  57. process.start()
  58. while len(processes) > 0:
  59. process = processes.pop()
  60. process.join()
  61. if __name__ == "__main__":
  62. import unittest
  63. unittest.main()