dataio_test.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. from caffe2.python.dataio import (
  2. CompositeReader,
  3. CompositeReaderBuilder,
  4. ReaderBuilder,
  5. ReaderWithDelay,
  6. ReaderWithLimit,
  7. ReaderWithTimeLimit,
  8. )
  9. from caffe2.python.dataset import Dataset
  10. from caffe2.python.db_file_reader import DBFileReader
  11. from caffe2.python.pipeline import pipe
  12. from caffe2.python.schema import Struct, NewRecord, FeedRecord
  13. from caffe2.python.session import LocalSession
  14. from caffe2.python.task import TaskGroup, final_output, WorkspaceType
  15. from caffe2.python.test_util import TestCase
  16. from caffe2.python.cached_reader import CachedReader
  17. from caffe2.python import core, workspace, schema
  18. from caffe2.python.net_builder import ops
  19. import numpy as np
  20. import numpy.testing as npt
  21. import os
  22. import shutil
  23. import unittest
  24. import tempfile
  25. def make_source_dataset(ws, size=100, offset=0, name=None):
  26. name = name or "src"
  27. src_init = core.Net("{}_init".format(name))
  28. with core.NameScope(name):
  29. src_values = Struct(('label', np.array(range(offset, offset + size))))
  30. src_blobs = NewRecord(src_init, src_values)
  31. src_ds = Dataset(src_blobs, name=name)
  32. FeedRecord(src_blobs, src_values, ws)
  33. ws.run(src_init)
  34. return src_ds
  35. def make_destination_dataset(ws, schema, name=None):
  36. name = name or 'dst'
  37. dst_init = core.Net('{}_init'.format(name))
  38. with core.NameScope(name):
  39. dst_ds = Dataset(schema, name=name)
  40. dst_ds.init_empty(dst_init)
  41. ws.run(dst_init)
  42. return dst_ds
  43. class TestReaderBuilder(ReaderBuilder):
  44. def __init__(self, name, size, offset):
  45. self._schema = schema.Struct(
  46. ('label', schema.Scalar()),
  47. )
  48. self._name = name
  49. self._size = size
  50. self._offset = offset
  51. self._src_ds = None
  52. def schema(self):
  53. return self._schema
  54. def setup(self, ws):
  55. self._src_ds = make_source_dataset(ws, offset=self._offset, size=self._size,
  56. name=self._name)
  57. return {}
  58. def new_reader(self, **kwargs):
  59. return self._src_ds
  60. class TestCompositeReader(TestCase):
  61. @unittest.skipIf(os.environ.get('JENKINS_URL'), 'Flaky test on Jenkins')
  62. def test_composite_reader(self):
  63. ws = workspace.C.Workspace()
  64. session = LocalSession(ws)
  65. num_srcs = 3
  66. names = ["src_{}".format(i) for i in range(num_srcs)]
  67. size = 100
  68. offsets = [i * size for i in range(num_srcs)]
  69. src_dses = [make_source_dataset(ws, offset=offset, size=size, name=name)
  70. for (name, offset) in zip(names, offsets)]
  71. data = [ws.fetch_blob(str(src.field_blobs[0])) for src in src_dses]
  72. # Sanity check we didn't overwrite anything
  73. for d, offset in zip(data, offsets):
  74. npt.assert_array_equal(d, range(offset, offset + size))
  75. # Make an identically-sized empty destination dataset
  76. dst_ds_schema = schema.Struct(
  77. *[
  78. (name, src_ds.content().clone_schema())
  79. for name, src_ds in zip(names, src_dses)
  80. ]
  81. )
  82. dst_ds = make_destination_dataset(ws, dst_ds_schema)
  83. with TaskGroup() as tg:
  84. reader = CompositeReader(names,
  85. [src_ds.reader() for src_ds in src_dses])
  86. pipe(reader, dst_ds.writer(), num_runtime_threads=3)
  87. session.run(tg)
  88. for i in range(num_srcs):
  89. written_data = sorted(
  90. ws.fetch_blob(str(dst_ds.content()[names[i]].label())))
  91. npt.assert_array_equal(data[i], written_data, "i: {}".format(i))
  92. @unittest.skipIf(os.environ.get('JENKINS_URL'), 'Flaky test on Jenkins')
  93. def test_composite_reader_builder(self):
  94. ws = workspace.C.Workspace()
  95. session = LocalSession(ws)
  96. num_srcs = 3
  97. names = ["src_{}".format(i) for i in range(num_srcs)]
  98. size = 100
  99. offsets = [i * size for i in range(num_srcs)]
  100. src_ds_builders = [
  101. TestReaderBuilder(offset=offset, size=size, name=name)
  102. for (name, offset) in zip(names, offsets)
  103. ]
  104. # Make an identically-sized empty destination dataset
  105. dst_ds_schema = schema.Struct(
  106. *[
  107. (name, src_ds_builder.schema())
  108. for name, src_ds_builder in zip(names, src_ds_builders)
  109. ]
  110. )
  111. dst_ds = make_destination_dataset(ws, dst_ds_schema)
  112. with TaskGroup() as tg:
  113. reader_builder = CompositeReaderBuilder(
  114. names, src_ds_builders)
  115. reader_builder.setup(ws=ws)
  116. pipe(reader_builder.new_reader(), dst_ds.writer(),
  117. num_runtime_threads=3)
  118. session.run(tg)
  119. for name, offset in zip(names, offsets):
  120. written_data = sorted(
  121. ws.fetch_blob(str(dst_ds.content()[name].label())))
  122. npt.assert_array_equal(range(offset, offset + size), written_data,
  123. "name: {}".format(name))
  124. class TestReaderWithLimit(TestCase):
  125. def test_runtime_threads(self):
  126. ws = workspace.C.Workspace()
  127. session = LocalSession(ws)
  128. src_ds = make_source_dataset(ws)
  129. totals = [None] * 3
  130. def proc(rec):
  131. # executed once
  132. with ops.task_init():
  133. counter1 = ops.CreateCounter([], ['global_counter'])
  134. counter2 = ops.CreateCounter([], ['global_counter2'])
  135. counter3 = ops.CreateCounter([], ['global_counter3'])
  136. # executed once per thread
  137. with ops.task_instance_init():
  138. task_counter = ops.CreateCounter([], ['task_counter'])
  139. # executed on each iteration
  140. ops.CountUp(counter1)
  141. ops.CountUp(task_counter)
  142. # executed once per thread
  143. with ops.task_instance_exit():
  144. with ops.loop(ops.RetrieveCount(task_counter)):
  145. ops.CountUp(counter2)
  146. ops.CountUp(counter3)
  147. # executed once
  148. with ops.task_exit():
  149. totals[0] = final_output(ops.RetrieveCount(counter1))
  150. totals[1] = final_output(ops.RetrieveCount(counter2))
  151. totals[2] = final_output(ops.RetrieveCount(counter3))
  152. return rec
  153. # Read full data set from original reader
  154. with TaskGroup() as tg:
  155. pipe(src_ds.reader(), num_runtime_threads=8, processor=proc)
  156. session.run(tg)
  157. self.assertEqual(totals[0].fetch(), 100)
  158. self.assertEqual(totals[1].fetch(), 100)
  159. self.assertEqual(totals[2].fetch(), 8)
  160. # Read with a count-limited reader
  161. with TaskGroup() as tg:
  162. q1 = pipe(src_ds.reader(), num_runtime_threads=2)
  163. q2 = pipe(
  164. ReaderWithLimit(q1.reader(), num_iter=25),
  165. num_runtime_threads=3)
  166. pipe(q2, processor=proc, num_runtime_threads=6)
  167. session.run(tg)
  168. self.assertEqual(totals[0].fetch(), 25)
  169. self.assertEqual(totals[1].fetch(), 25)
  170. self.assertEqual(totals[2].fetch(), 6)
  171. def _test_limit_reader_init_shared(self, size):
  172. ws = workspace.C.Workspace()
  173. session = LocalSession(ws)
  174. # Make source dataset
  175. src_ds = make_source_dataset(ws, size=size)
  176. # Make an identically-sized empty destination Dataset
  177. dst_ds = make_destination_dataset(ws, src_ds.content().clone_schema())
  178. return ws, session, src_ds, dst_ds
  179. def _test_limit_reader_shared(self, reader_class, size, expected_read_len,
  180. expected_read_len_threshold,
  181. expected_finish, num_threads, read_delay,
  182. **limiter_args):
  183. ws, session, src_ds, dst_ds = \
  184. self._test_limit_reader_init_shared(size)
  185. # Read without limiter
  186. # WorkspaceType.GLOBAL is required because we are fetching
  187. # reader.data_finished() after the TaskGroup finishes.
  188. with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
  189. if read_delay > 0:
  190. reader = reader_class(ReaderWithDelay(src_ds.reader(),
  191. read_delay),
  192. **limiter_args)
  193. else:
  194. reader = reader_class(src_ds.reader(), **limiter_args)
  195. pipe(reader, dst_ds.writer(), num_runtime_threads=num_threads)
  196. session.run(tg)
  197. read_len = len(sorted(ws.blobs[str(dst_ds.content().label())].fetch()))
  198. # Do a fuzzy match (expected_read_len +/- expected_read_len_threshold)
  199. # to eliminate flakiness for time-limited tests
  200. self.assertGreaterEqual(
  201. read_len,
  202. expected_read_len - expected_read_len_threshold)
  203. self.assertLessEqual(
  204. read_len,
  205. expected_read_len + expected_read_len_threshold)
  206. self.assertEqual(
  207. sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
  208. list(range(read_len))
  209. )
  210. self.assertEqual(ws.blobs[str(reader.data_finished())].fetch(),
  211. expected_finish)
  212. def test_count_limit_reader_without_limit(self):
  213. # No iter count specified, should read all records.
  214. self._test_limit_reader_shared(ReaderWithLimit,
  215. size=100,
  216. expected_read_len=100,
  217. expected_read_len_threshold=0,
  218. expected_finish=True,
  219. num_threads=8,
  220. read_delay=0,
  221. num_iter=None)
  222. def test_count_limit_reader_with_zero_limit(self):
  223. # Zero iter count specified, should read 0 records.
  224. self._test_limit_reader_shared(ReaderWithLimit,
  225. size=100,
  226. expected_read_len=0,
  227. expected_read_len_threshold=0,
  228. expected_finish=False,
  229. num_threads=8,
  230. read_delay=0,
  231. num_iter=0)
  232. def test_count_limit_reader_with_low_limit(self):
  233. # Read with limit smaller than size of dataset
  234. self._test_limit_reader_shared(ReaderWithLimit,
  235. size=100,
  236. expected_read_len=10,
  237. expected_read_len_threshold=0,
  238. expected_finish=False,
  239. num_threads=8,
  240. read_delay=0,
  241. num_iter=10)
  242. def test_count_limit_reader_with_high_limit(self):
  243. # Read with limit larger than size of dataset
  244. self._test_limit_reader_shared(ReaderWithLimit,
  245. size=100,
  246. expected_read_len=100,
  247. expected_read_len_threshold=0,
  248. expected_finish=True,
  249. num_threads=8,
  250. read_delay=0,
  251. num_iter=110)
  252. def test_time_limit_reader_without_limit(self):
  253. # No duration specified, should read all records.
  254. self._test_limit_reader_shared(ReaderWithTimeLimit,
  255. size=100,
  256. expected_read_len=100,
  257. expected_read_len_threshold=0,
  258. expected_finish=True,
  259. num_threads=8,
  260. read_delay=0.1,
  261. duration=0)
  262. def test_time_limit_reader_with_short_limit(self):
  263. # Read with insufficient time limit
  264. size = 50
  265. num_threads = 4
  266. sleep_duration = 0.25
  267. duration = 1
  268. expected_read_len = int(round(num_threads * duration / sleep_duration))
  269. # Because the time limit check happens before the delay + read op,
  270. # subtract a little bit of time to ensure we don't get in an extra read
  271. duration = duration - 0.25 * sleep_duration
  272. # NOTE: `expected_read_len_threshold` was added because this test case
  273. # has significant execution variation under stress. Under stress, we may
  274. # read strictly less than the expected # of samples; anywhere from
  275. # [0,N] where N = expected_read_len.
  276. # Hence we set expected_read_len to N/2, plus or minus N/2.
  277. self._test_limit_reader_shared(ReaderWithTimeLimit,
  278. size=size,
  279. expected_read_len=expected_read_len / 2,
  280. expected_read_len_threshold=expected_read_len / 2,
  281. expected_finish=False,
  282. num_threads=num_threads,
  283. read_delay=sleep_duration,
  284. duration=duration)
  285. def test_time_limit_reader_with_long_limit(self):
  286. # Read with ample time limit
  287. # NOTE: we don't use `expected_read_len_threshold` because the duration,
  288. # read_delay, and # threads should be more than sufficient
  289. self._test_limit_reader_shared(ReaderWithTimeLimit,
  290. size=50,
  291. expected_read_len=50,
  292. expected_read_len_threshold=0,
  293. expected_finish=True,
  294. num_threads=4,
  295. read_delay=0.2,
  296. duration=10)
  297. class TestDBFileReader(TestCase):
  298. def setUp(self):
  299. self.temp_paths = []
  300. def tearDown(self):
  301. # In case any test method fails, clean up temp paths.
  302. for path in self.temp_paths:
  303. self._delete_path(path)
  304. @staticmethod
  305. def _delete_path(path):
  306. if os.path.isfile(path):
  307. os.remove(path) # Remove file.
  308. elif os.path.isdir(path):
  309. shutil.rmtree(path) # Remove dir recursively.
  310. def _make_temp_path(self):
  311. # Make a temp path as db_path.
  312. with tempfile.NamedTemporaryFile() as f:
  313. temp_path = f.name
  314. self.temp_paths.append(temp_path)
  315. return temp_path
  316. @staticmethod
  317. def _build_source_reader(ws, size):
  318. src_ds = make_source_dataset(ws, size)
  319. return src_ds.reader()
  320. @staticmethod
  321. def _read_all_data(ws, reader, session):
  322. dst_ds = make_destination_dataset(ws, reader.schema().clone_schema())
  323. with TaskGroup() as tg:
  324. pipe(reader, dst_ds.writer(), num_runtime_threads=8)
  325. session.run(tg)
  326. return ws.blobs[str(dst_ds.content().label())].fetch()
  327. @unittest.skipIf("LevelDB" not in core.C.registered_dbs(), "Need LevelDB")
  328. def test_cached_reader(self):
  329. ws = workspace.C.Workspace()
  330. session = LocalSession(ws)
  331. db_path = self._make_temp_path()
  332. # Read data for the first time.
  333. cached_reader1 = CachedReader(
  334. self._build_source_reader(ws, 100), db_path, loop_over=False,
  335. )
  336. build_cache_step = cached_reader1.build_cache_step()
  337. session.run(build_cache_step)
  338. data = self._read_all_data(ws, cached_reader1, session)
  339. self.assertEqual(sorted(data), list(range(100)))
  340. # Read data from cache.
  341. cached_reader2 = CachedReader(
  342. self._build_source_reader(ws, 200), db_path,
  343. )
  344. build_cache_step = cached_reader2.build_cache_step()
  345. session.run(build_cache_step)
  346. data = self._read_all_data(ws, cached_reader2, session)
  347. self.assertEqual(sorted(data), list(range(100)))
  348. self._delete_path(db_path)
  349. # We removed cache so we expect to receive data from original reader.
  350. cached_reader3 = CachedReader(
  351. self._build_source_reader(ws, 300), db_path,
  352. )
  353. build_cache_step = cached_reader3.build_cache_step()
  354. session.run(build_cache_step)
  355. data = self._read_all_data(ws, cached_reader3, session)
  356. self.assertEqual(sorted(data), list(range(300)))
  357. self._delete_path(db_path)
  358. @unittest.skipIf("LevelDB" not in core.C.registered_dbs(), "Need LevelDB")
  359. def test_db_file_reader(self):
  360. ws = workspace.C.Workspace()
  361. session = LocalSession(ws)
  362. db_path = self._make_temp_path()
  363. # Build a cache DB file.
  364. cached_reader = CachedReader(
  365. self._build_source_reader(ws, 100),
  366. db_path=db_path,
  367. db_type='LevelDB',
  368. )
  369. build_cache_step = cached_reader.build_cache_step()
  370. session.run(build_cache_step)
  371. # Read data from cache DB file.
  372. db_file_reader = DBFileReader(
  373. db_path=db_path,
  374. db_type='LevelDB',
  375. )
  376. data = self._read_all_data(ws, db_file_reader, session)
  377. self.assertEqual(sorted(data), list(range(100)))
  378. self._delete_path(db_path)