#!/usr/bin/env python3
# gloo_test.py
import os
import pickle
import shutil
import tempfile
import time
from multiprocessing import Process, Queue

import hypothesis.strategies as st
import numpy as np
from hypothesis import given, settings

from caffe2.python import core, workspace, dyndep
import caffe2.python.hypothesis_test_util as hu
from gloo.python import IoError
  # Load the operator libraries this test relies on: the file/Redis store
  # handlers, the store ops, and the Gloo ops (plus their GPU variant).
  for _ops_target in (
      "@/caffe2/caffe2/distributed:file_store_handler_ops",
      "@/caffe2/caffe2/distributed:redis_store_handler_ops",
      "@/caffe2/caffe2/distributed:store_ops",
      "@/caffe2/caffe2/contrib/gloo:gloo_ops",
      "@/caffe2/caffe2/contrib/gloo:gloo_ops_gpu",
  ):
      dyndep.InitOpsLibrary(_ops_target)

  # Engine name passed as ``engine=`` to every collective operator in this file.
  op_engine = 'GLOO'
  class TemporaryDirectory:
      """Context manager yielding the path of a fresh temporary directory.

      The directory is created on entry and removed, together with its
      contents, on exit.
      """

      def __enter__(self):
          created = tempfile.mkdtemp()
          # Remember the path so __exit__ knows what to remove.
          self.tmpdir = created
          return created

      def __exit__(self, type, value, traceback):
          # Implicitly returns None, so an exception raised inside the
          # ``with`` body keeps propagating after the cleanup.
          shutil.rmtree(self.tmpdir)
  class TestCase(hu.HypothesisTestCase):
      """Hypothesis-driven tests for the Gloo collective operators."""

      # Incremented at the start of every test_* method; its value is used
      # as the key prefix when a Redis store handler is created.
      test_counter = 0
      # Incremented by synchronize(); used to name each sync blob.
      sync_counter = 0
  def run_test_locally(self, fn, device_option=None, **kwargs):
      """Run ``fn`` in ``kwargs['comm_size']`` local worker processes.

      One process is started per rank. Each calls ``fn(**kwargs)`` with
      ``comm_rank`` set to its rank, inside
      ``core.DeviceScope(device_option)``, then resets the workspace and
      reports ``True``. An exception caught in a worker is reported
      instead and re-raised in the calling process.

      Args:
          fn: Callable executed in every worker.
          device_option: Device option handed to ``core.DeviceScope``.
          **kwargs: Forwarded to ``fn``; must contain ``comm_size``.

      Raises:
          AssertionError: If a worker exits without reporting a result.
          Exception: Whatever exception a worker reported.
      """
      # Queue carrying each worker's outcome: True or the caught exception.
      results = Queue()

      # Capture any exception thrown by the subprocess
      def run_fn(*worker_args, **worker_kwargs):
          try:
              with core.DeviceScope(device_option):
                  fn(*worker_args, **worker_kwargs)
                  workspace.ResetWorkspace()
                  results.put(True)
          except Exception as ex:
              results.put(ex)

      procs = []
      try:
          # Start N processes in the background
          for rank in range(kwargs['comm_size']):
              kwargs['comm_rank'] = rank
              proc = Process(target=run_fn, kwargs=kwargs)
              proc.start()
              procs.append(proc)

          # Join the workers in start order, consuming one result per join.
          for proc in procs:
              while proc.is_alive():
                  proc.join(10)

              # The result read here is not necessarily the one produced by
              # the worker just joined, but one result is read per worker,
              # so any reported exception is eventually re-raised.
              self.assertFalse(
                  results.empty(), "Job failed without a result")
              outcome = results.get()
              if isinstance(outcome, Exception):
                  raise outcome
              self.assertTrue(outcome)
      finally:
          # On the failure path some workers may still be running; stop
          # them so they do not outlive the test. On success every worker
          # has already exited and this loop does nothing.
          for proc in procs:
              if proc.is_alive():
                  proc.terminate()
                  proc.join()
  def run_test_distributed(self, fn, device_option=None, **kwargs):
      """Run ``fn`` in this process with rank/size taken from the environment.

      ``COMM_RANK`` and ``COMM_SIZE`` must both be set. Their integer values
      are passed to ``fn`` as ``comm_rank`` and ``comm_size`` together with
      the remaining keyword arguments. The workspace is reset afterwards.
      """
      rank_env = os.getenv('COMM_RANK')
      self.assertIsNotNone(rank_env)
      size_env = os.getenv('COMM_SIZE')
      self.assertIsNotNone(size_env)

      kwargs.update(comm_rank=int(rank_env), comm_size=int(size_env))
      with core.DeviceScope(device_option):
          fn(**kwargs)
          workspace.ResetWorkspace()
  def create_common_world(self, comm_rank, comm_size, tmpdir=None, existing_cw=None):
      """Create a common world, or clone ``existing_cw`` when one is given.

      Args:
          comm_rank: Rank passed to ``CreateCommonWorld``.
          comm_size: World size passed to ``CreateCommonWorld``.
          tmpdir: Path for the file store handler (used when ``REDIS_HOST``
              is not set).
          existing_cw: Name of a common world to clone; when given, no
              store handler is created.

      Returns:
          Tuple ``(store_handler, common_world)`` of blob names.
      """
      store_handler = "store_handler"

      if existing_cw is not None:
          # Fork path: clone the existing world under a derived name.
          common_world = str(existing_cw) + ".forked"
          clone_op = core.CreateOperator(
              "CloneCommonWorld",
              [existing_cw],
              [common_world],
              sync=True,
              engine=op_engine)
          workspace.RunOperatorOnce(clone_op)
          return (store_handler, common_world)

      # If REDIS_HOST is set, use RedisStoreHandler for rendezvous.
      redis_host = os.getenv("REDIS_HOST")
      redis_port = int(os.getenv("REDIS_PORT", 6379))
      if redis_host is None:
          handler_op = core.CreateOperator(
              "FileStoreHandlerCreate",
              [],
              [store_handler],
              path=tmpdir)
      else:
          handler_op = core.CreateOperator(
              "RedisStoreHandlerCreate",
              [],
              [store_handler],
              prefix=str(TestCase.test_counter) + "/",
              host=redis_host,
              port=redis_port)
      workspace.RunOperatorOnce(handler_op)

      common_world = "common_world"
      create_op = core.CreateOperator(
          "CreateCommonWorld",
          [store_handler],
          [common_world],
          size=comm_size,
          rank=comm_rank,
          sync=True,
          engine=op_engine)
      workspace.RunOperatorOnce(create_op)
      return (store_handler, common_world)
  def synchronize(self, store_handler, value, comm_rank=None):
      """Return rank 0's ``value`` on every rank, exchanged via the store.

      Rank 0 pickles ``value`` into a blob and publishes it with
      ``StoreSet``; every other rank fetches the blob with ``StoreGet``.
      All ranks then return the unpickled blob contents.
      """
      TestCase.sync_counter += 1
      blob = "sync_{}".format(TestCase.sync_counter)

      if comm_rank == 0:
          workspace.FeedBlob(blob, pickle.dumps(value))
          exchange_op = core.CreateOperator(
              "StoreSet",
              [store_handler, blob],
              [])
      else:
          exchange_op = core.CreateOperator(
              "StoreGet",
              [store_handler],
              [blob])
      workspace.RunOperatorOnce(exchange_op)

      return pickle.loads(workspace.FetchBlob(blob))
  def _test_broadcast(
      self,
      comm_rank=None,
      comm_size=None,
      blob_size=None,
      num_blobs=None,
      tmpdir=None,
      use_float16=False,
  ):
      """Broadcast from every rank in turn and check the received values.

      ``blob_size`` and ``num_blobs`` are first replaced by rank 0's values
      through ``synchronize``.
      """
      store_handler, common_world = self.create_common_world(
          comm_rank=comm_rank,
          comm_size=comm_size,
          tmpdir=tmpdir)

      blob_size = self.synchronize(
          store_handler, blob_size, comm_rank=comm_rank)
      num_blobs = self.synchronize(
          store_handler, num_blobs, comm_rank=comm_rank)

      dtype = np.float16 if use_float16 else np.float32
      for root in range(comm_size):
          # Refill the blobs with this rank's own values before each round.
          blobs = ["blob_{}".format(j) for j in range(num_blobs)]
          for j, blob in enumerate(blobs):
              fill = (comm_rank * num_blobs) + j
              workspace.FeedBlob(blob, np.full(blob_size, fill, dtype))

          net = core.Net("broadcast")
          net.Broadcast(
              [common_world] + blobs,
              blobs,
              root=root,
              engine=op_engine)

          workspace.CreateNet(net)
          workspace.RunNet(net.Name())

          # Every blob is compared against the fill value of the root's
          # first blob, ``root * num_blobs``.
          for blob in blobs:
              np.testing.assert_array_equal(
                  workspace.FetchBlob(blob),
                  root * num_blobs)

          # Run the net a few more times to check the operator
          # works not just the first time it's called
          for _ in range(4):
              workspace.RunNet(net.Name())
  @given(comm_size=st.integers(min_value=2, max_value=8),
         blob_size=st.integers(min_value=int(1e3), max_value=int(1e6)),
         num_blobs=st.integers(min_value=1, max_value=4),
         device_option=st.sampled_from([hu.cpu_do]),
         use_float16=st.booleans())
  @settings(deadline=10000)
  def test_broadcast(self, comm_size, blob_size, num_blobs, device_option,
                     use_float16):
      """Run ``_test_broadcast`` for generated sizes and dtypes.

      When ``COMM_RANK`` is set the check runs in this process through
      ``run_test_distributed`` (the generated ``comm_size`` is not used);
      otherwise ``comm_size`` local workers share a temporary directory.
      """
      TestCase.test_counter += 1
      params = dict(
          blob_size=blob_size,
          num_blobs=num_blobs,
          use_float16=use_float16,
          device_option=device_option)

      if os.getenv('COMM_RANK') is None:
          with TemporaryDirectory() as tmpdir:
              self.run_test_locally(
                  self._test_broadcast,
                  comm_size=comm_size,
                  tmpdir=tmpdir,
                  **params)
      else:
          self.run_test_distributed(self._test_broadcast, **params)
  def _test_allreduce(
      self,
      comm_rank=None,
      comm_size=None,
      blob_size=None,
      num_blobs=None,
      tmpdir=None,
      use_float16=False,
  ):
      """Allreduce ``num_blobs`` blobs and check the reduced value.

      ``blob_size`` and ``num_blobs`` are first replaced by rank 0's values
      through ``synchronize``.
      """
      store_handler, common_world = self.create_common_world(
          comm_rank=comm_rank,
          comm_size=comm_size,
          tmpdir=tmpdir)

      blob_size = self.synchronize(
          store_handler, blob_size, comm_rank=comm_rank)
      num_blobs = self.synchronize(
          store_handler, num_blobs, comm_rank=comm_rank)

      dtype = np.float16 if use_float16 else np.float32
      blobs = ["blob_{}".format(i) for i in range(num_blobs)]
      for i, blob in enumerate(blobs):
          fill = (comm_rank * num_blobs) + i
          workspace.FeedBlob(blob, np.full(blob_size, fill, dtype))

      net = core.Net("allreduce")
      net.Allreduce(
          [common_world] + blobs,
          blobs,
          engine=op_engine)

      workspace.CreateNet(net)
      workspace.RunNet(net.Name())

      # Fill values over all ranks and blobs are 0 .. n-1 with
      # n = num_blobs * comm_size; the expected value is their sum.
      for blob in blobs:
          np.testing.assert_array_equal(
              workspace.FetchBlob(blob),
              (num_blobs * comm_size) * (num_blobs * comm_size - 1) / 2)

      # Run the net a few more times to check the operator
      # works not just the first time it's called
      for _ in range(4):
          workspace.RunNet(net.Name())
  def _test_allreduce_multicw(
      self,
      comm_rank=None,
      comm_size=None,
      tmpdir=None,
  ):
      """Allreduce over a common world and over a clone of it.

      Uses fixed sizes (4 blobs of 10**4 float32 elements) and runs the
      same check once per common world.
      """
      _store_handler, common_world = self.create_common_world(
          comm_rank=comm_rank,
          comm_size=comm_size,
          tmpdir=tmpdir)

      # Second world is cloned from the first one.
      _, common_world2 = self.create_common_world(
          comm_rank=comm_rank,
          comm_size=comm_size,
          tmpdir=tmpdir,
          existing_cw=common_world)

      blob_size = int(1e4)
      num_blobs = 4

      for cw in (common_world, common_world2):
          blobs = ["blob_{}".format(i) for i in range(num_blobs)]
          for i, blob in enumerate(blobs):
              fill = (comm_rank * num_blobs) + i
              workspace.FeedBlob(blob, np.full(blob_size, fill, np.float32))

          net = core.Net("allreduce_multicw")
          net.Allreduce(
              [cw] + blobs,
              blobs,
              engine=op_engine)

          workspace.RunNetOnce(net)
          for blob in blobs:
              np.testing.assert_array_equal(
                  workspace.FetchBlob(blob),
                  (num_blobs * comm_size) * (num_blobs * comm_size - 1) / 2)
  @given(comm_size=st.integers(min_value=2, max_value=8),
         blob_size=st.integers(min_value=int(1e3), max_value=int(1e6)),
         num_blobs=st.integers(min_value=1, max_value=4),
         device_option=st.sampled_from([hu.cpu_do]),
         use_float16=st.booleans())
  @settings(deadline=10000)
  def test_allreduce(self, comm_size, blob_size, num_blobs, device_option,
                     use_float16):
      """Run ``_test_allreduce`` for generated sizes and dtypes.

      When ``COMM_RANK`` is set the check runs in this process through
      ``run_test_distributed`` (the generated ``comm_size`` is not used);
      otherwise ``comm_size`` local workers share a temporary directory.
      """
      TestCase.test_counter += 1
      params = dict(
          blob_size=blob_size,
          num_blobs=num_blobs,
          use_float16=use_float16,
          device_option=device_option)

      if os.getenv('COMM_RANK') is None:
          with TemporaryDirectory() as tmpdir:
              self.run_test_locally(
                  self._test_allreduce,
                  comm_size=comm_size,
                  tmpdir=tmpdir,
                  **params)
      else:
          self.run_test_distributed(self._test_allreduce, **params)
  def _test_reduce_scatter(
      self,
      comm_rank=None,
      comm_size=None,
      blob_size=None,
      num_blobs=None,
      tmpdir=None,
      use_float16=False,
  ):
      """Reduce-scatter ``num_blobs`` blobs and check this rank's share.

      ``blob_size`` and ``num_blobs`` are first replaced by rank 0's values
      through ``synchronize``. The per-rank element counts are fed to the
      operator as an extra ``recvCounts`` blob.
      """
      store_handler, common_world = self.create_common_world(
          comm_rank=comm_rank,
          comm_size=comm_size,
          tmpdir=tmpdir)

      blob_size = self.synchronize(
          store_handler, blob_size, comm_rank=comm_rank)
      num_blobs = self.synchronize(
          store_handler, num_blobs, comm_rank=comm_rank)

      dtype = np.float16 if use_float16 else np.float32
      blobs = []
      for i in range(num_blobs):
          blob = "blob_{}".format(i)
          fill = (comm_rank * num_blobs) + i
          workspace.FeedBlob(blob, np.full(blob_size, fill, dtype))
          blobs.append(blob)

      # Specify distribution among ranks i.e. number of elements
      # scattered/distributed to each process.
      recv_counts = np.zeros(comm_size, dtype=np.int32)
      remaining = blob_size
      # Ceiling division. This must be ``//``: with ``/`` the chunk size is
      # a float under Python 3, gets truncated when stored in the int32
      # array, and the counts can then sum to less than blob_size.
      chunk_size = (blob_size + comm_size - 1) // comm_size
      for i in range(comm_size):
          recv_counts[i] = min(chunk_size, remaining)
          remaining = max(remaining - chunk_size, 0)

      recv_counts_blob = "recvCounts"
      workspace.FeedBlob(recv_counts_blob, recv_counts)
      # The counts blob travels as the last input/output of the operator.
      blobs.append(recv_counts_blob)

      net = core.Net("reduce_scatter")
      net.ReduceScatter(
          [common_world] + blobs,
          blobs,
          engine=op_engine)

      workspace.CreateNet(net)
      workspace.RunNet(net.Name())

      # Only the data blobs are checked (the first num_blobs entries), each
      # resized to the number of elements assigned to this rank.
      for i in range(num_blobs):
          np.testing.assert_array_equal(
              np.resize(workspace.FetchBlob(blobs[i]), recv_counts[comm_rank]),
              (num_blobs * comm_size) * (num_blobs * comm_size - 1) / 2)

      # Run the net a few more times to check the operator
      # works not just the first time it's called
      for _ in range(4):
          workspace.RunNet(net.Name())
  @given(comm_size=st.integers(min_value=2, max_value=8),
         blob_size=st.integers(min_value=int(1e3), max_value=int(1e6)),
         num_blobs=st.integers(min_value=1, max_value=4),
         device_option=st.sampled_from([hu.cpu_do]),
         use_float16=st.booleans())
  @settings(deadline=10000)
  def test_reduce_scatter(self, comm_size, blob_size, num_blobs,
                          device_option, use_float16):
      """Run ``_test_reduce_scatter`` for generated sizes and dtypes.

      When ``COMM_RANK`` is set the check runs in this process through
      ``run_test_distributed`` (the generated ``comm_size`` is not used);
      otherwise ``comm_size`` local workers share a temporary directory.
      """
      TestCase.test_counter += 1
      params = dict(
          blob_size=blob_size,
          num_blobs=num_blobs,
          use_float16=use_float16,
          device_option=device_option)

      if os.getenv('COMM_RANK') is None:
          with TemporaryDirectory() as tmpdir:
              self.run_test_locally(
                  self._test_reduce_scatter,
                  comm_size=comm_size,
                  tmpdir=tmpdir,
                  **params)
      else:
          self.run_test_distributed(self._test_reduce_scatter, **params)
  def _test_allgather(
      self,
      comm_rank=None,
      comm_size=None,
      blob_size=None,
      num_blobs=None,
      tmpdir=None,
      use_float16=False,
  ):
      """Allgather ``num_blobs`` blobs into ``Gathered`` and check it.

      ``blob_size`` and ``num_blobs`` are first replaced by rank 0's values
      through ``synchronize``. The expected output is every rank's blobs
      laid out in rank order, then blob order.
      """
      store_handler, common_world = self.create_common_world(
          comm_rank=comm_rank,
          comm_size=comm_size,
          tmpdir=tmpdir)

      blob_size = self.synchronize(
          store_handler, blob_size, comm_rank=comm_rank)
      num_blobs = self.synchronize(
          store_handler, num_blobs, comm_rank=comm_rank)

      dtype = np.float16 if use_float16 else np.float32
      blobs = []
      for i in range(num_blobs):
          blob = "blob_{}".format(i)
          fill = (comm_rank * num_blobs) + i
          workspace.FeedBlob(blob, np.full(blob_size, fill, dtype))
          blobs.append(blob)

      net = core.Net("allgather")
      net.Allgather(
          [common_world] + blobs,
          ["Gathered"],
          engine=op_engine)

      workspace.CreateNet(net)
      workspace.RunNet(net.Name())

      # create expected output
      # Collect the pieces and concatenate once instead of re-copying the
      # growing array on every iteration. The empty float64 seed keeps the
      # result dtype the same as concatenating onto ``np.array([])``.
      pieces = [np.array([])]
      for i in range(comm_size):
          for j in range(num_blobs):
              pieces.append(np.full(blob_size, (i * num_blobs) + j, dtype))
      expected_output = np.concatenate(pieces)

      np.testing.assert_array_equal(
          workspace.FetchBlob("Gathered"), expected_output)

      # Run the net a few more times to check the operator
      # works not just the first time it's called
      for _ in range(4):
          workspace.RunNet(net.Name())
  @given(comm_size=st.integers(min_value=2, max_value=8),
         blob_size=st.integers(min_value=int(1e3), max_value=int(1e6)),
         num_blobs=st.integers(min_value=1, max_value=4),
         device_option=st.sampled_from([hu.cpu_do]),
         use_float16=st.booleans())
  @settings(max_examples=10, deadline=None)
  def test_allgather(self, comm_size, blob_size, num_blobs, device_option,
                     use_float16):
      """Run ``_test_allgather`` for generated sizes and dtypes.

      When ``COMM_RANK`` is set the check runs in this process through
      ``run_test_distributed`` (the generated ``comm_size`` is not used);
      otherwise ``comm_size`` local workers share a temporary directory.
      """
      TestCase.test_counter += 1
      params = dict(
          blob_size=blob_size,
          num_blobs=num_blobs,
          use_float16=use_float16,
          device_option=device_option)

      if os.getenv('COMM_RANK') is None:
          with TemporaryDirectory() as tmpdir:
              self.run_test_locally(
                  self._test_allgather,
                  comm_size=comm_size,
                  tmpdir=tmpdir,
                  **params)
      else:
          self.run_test_distributed(self._test_allgather, **params)
  @given(device_option=st.sampled_from([hu.cpu_do]))
  @settings(deadline=10000)
  def test_forked_cw(self, device_option):
      """Run ``_test_allreduce_multicw``, which clones a common world."""
      TestCase.test_counter += 1

      if os.getenv('COMM_RANK') is None:
          # Note: this test exercises the path where we fork a common world.
          # We therefore don't need a comm size larger than 2. It used to be
          # run with comm_size=8, which causes flaky results in a stress run.
          # The flakiness was caused by too many listening sockets being
          # created by Gloo context initialization (8 processes times
          # 7 sockets times 20-way concurrency, plus TIME_WAIT).
          with TemporaryDirectory() as tmpdir:
              self.run_test_locally(
                  self._test_allreduce_multicw,
                  comm_size=2,
                  device_option=device_option,
                  tmpdir=tmpdir)
      else:
          self.run_test_distributed(
              self._test_allreduce_multicw,
              device_option=device_option)
  def _test_barrier(
      self,
      comm_rank=None,
      comm_size=None,
      tmpdir=None,
  ):
      """Run a ``Barrier`` net on a fresh common world five times in total."""
      _handler, common_world = self.create_common_world(
          comm_rank=comm_rank, comm_size=comm_size, tmpdir=tmpdir
      )

      net = core.Net("barrier")
      net.Barrier(
          [common_world],
          [],
          engine=op_engine)

      workspace.CreateNet(net)
      net_name = net.Name()
      workspace.RunNet(net_name)

      # Run the net a few more times to check the operator
      # works not just the first time it's called
      for _ in range(4):
          workspace.RunNet(net_name)
  @given(comm_size=st.integers(min_value=2, max_value=8),
         device_option=st.sampled_from([hu.cpu_do]))
  @settings(deadline=10000)
  def test_barrier(self, comm_size, device_option):
      """Run ``_test_barrier`` locally or, with ``COMM_RANK`` set, in-process."""
      TestCase.test_counter += 1

      if os.getenv('COMM_RANK') is None:
          with TemporaryDirectory() as tmpdir:
              self.run_test_locally(
                  self._test_barrier,
                  comm_size=comm_size,
                  device_option=device_option,
                  tmpdir=tmpdir)
      else:
          self.run_test_distributed(
              self._test_barrier,
              device_option=device_option)
  def _test_close_connection(
      self,
      comm_rank=None,
      comm_size=None,
      tmpdir=None,
  ):
      """
      One node closes its connections while the others wait on a barrier.
      The test checks that every node exits eventually.
      """
      # Caffe's for closers only:
      # https://www.youtube.com/watch?v=QMFwFgG9NE8
      # A trailing comma here used to turn this into a one-element tuple,
      # which is always truthy, so every rank took the closing branch.
      closer = comm_rank == comm_size // 2

      _handler, common_world = self.create_common_world(
          comm_rank=comm_rank, comm_size=comm_size, tmpdir=tmpdir
      )

      net = core.Net("barrier_or_close")
      if closer:
          net.DestroyCommonWorld(
              [common_world], [common_world], engine=op_engine)
          # Sleep a bit to ensure others start the barrier
          time.sleep(0.1)
      else:
          # NOTE(review): this branch was unreachable before the fix above;
          # confirm that waiting ranks finish without reporting an error
          # once the closer tears its connections down.
          net.Barrier(
              [common_world],
              [],
              engine=op_engine)

      workspace.CreateNet(net)
      workspace.RunNet(net.Name())
  @given(comm_size=st.integers(min_value=2, max_value=8),
         device_option=st.sampled_from([hu.cpu_do]))
  @settings(deadline=10000)
  def test_close_connection(self, comm_size, device_option):
      """Run ``_test_close_connection`` and bound its end-to-end runtime."""
      # Monotonic clock: the elapsed-time check below must not be affected
      # by adjustments of the system wall clock.
      started = time.monotonic()
      TestCase.test_counter += 1

      if os.getenv('COMM_RANK') is not None:
          self.run_test_distributed(
              self._test_close_connection,
              device_option=device_option)
      else:
          with TemporaryDirectory() as tmpdir:
              self.run_test_locally(
                  self._test_close_connection,
                  comm_size=comm_size,
                  device_option=device_option,
                  tmpdir=tmpdir)

      # Check that test finishes quickly because connections get closed.
      # This assert used to check that the end to end runtime was less
      # than 2 seconds, but this may not always be the case if there
      # is significant overhead in starting processes. Ideally, this
      # assert is replaced by one that doesn't depend on time but rather
      # checks the success/failure status of the barrier that is run.
      self.assertLess(time.monotonic() - started, 20.0)
  def _test_io_error(
      self,
      comm_rank=None,
      comm_size=None,
      tmpdir=None,
  ):
      """
      Only one node will participate in allreduce, resulting in an IoError
      """
      _handler, common_world = self.create_common_world(
          comm_rank=comm_rank,
          comm_size=comm_size,
          tmpdir=tmpdir)

      # Every rank other than 0 returns right after joining the world.
      if comm_rank != 0:
          return

      blob_size = 1000
      num_blobs = 1

      blobs = ["blob_{}".format(i) for i in range(num_blobs)]
      for i, blob in enumerate(blobs):
          fill = (comm_rank * num_blobs) + i
          workspace.FeedBlob(blob, np.full(blob_size, fill, np.float32))

      net = core.Net("allreduce")
      net.Allreduce(
          [common_world] + blobs,
          blobs,
          engine=op_engine)

      workspace.CreateNet(net)
      workspace.RunNet(net.Name())
  @given(comm_size=st.integers(min_value=2, max_value=8),
         device_option=st.sampled_from([hu.cpu_do]))
  @settings(deadline=10000)
  def test_io_error(self, comm_size, device_option):
      """Expect ``IoError`` from running ``_test_io_error``."""
      TestCase.test_counter += 1

      with self.assertRaises(IoError):
          if os.getenv('COMM_RANK') is None:
              with TemporaryDirectory() as tmpdir:
                  self.run_test_locally(
                      self._test_io_error,
                      comm_size=comm_size,
                      device_option=device_option,
                      tmpdir=tmpdir)
          else:
              self.run_test_distributed(
                  self._test_io_error,
                  device_option=device_option)
  if __name__ == "__main__":
      # unittest is only needed when this file is executed directly.
      import unittest

      unittest.main()