| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445 |
- from caffe2.python.dataio import (
- CompositeReader,
- CompositeReaderBuilder,
- ReaderBuilder,
- ReaderWithDelay,
- ReaderWithLimit,
- ReaderWithTimeLimit,
- )
- from caffe2.python.dataset import Dataset
- from caffe2.python.db_file_reader import DBFileReader
- from caffe2.python.pipeline import pipe
- from caffe2.python.schema import Struct, NewRecord, FeedRecord
- from caffe2.python.session import LocalSession
- from caffe2.python.task import TaskGroup, final_output, WorkspaceType
- from caffe2.python.test_util import TestCase
- from caffe2.python.cached_reader import CachedReader
- from caffe2.python import core, workspace, schema
- from caffe2.python.net_builder import ops
- import numpy as np
- import numpy.testing as npt
- import os
- import shutil
- import unittest
- import tempfile
def make_source_dataset(ws, size=100, offset=0, name=None):
    """Create a populated single-field source ``Dataset`` in workspace ``ws``.

    The dataset has one field, ``label``, holding the consecutive integers
    ``offset, offset + 1, ..., offset + size - 1``.

    Args:
        ws: workspace the record is fed into and the init net is run on.
        size: number of label values to generate.
        offset: first label value.
        name: name scope / dataset name; falls back to ``"src"`` when falsy.

    Returns:
        The ``Dataset`` wrapping the newly created record blobs.
    """
    if not name:
        name = "src"
    init_net = core.Net("{}_init".format(name))
    labels = np.array(range(offset, offset + size))
    with core.NameScope(name):
        values = Struct(('label', labels))
        blobs = NewRecord(init_net, values)
        dataset = Dataset(blobs, name=name)
    # Feed the values first, then run the init net, as the original does.
    FeedRecord(blobs, values, ws)
    ws.run(init_net)
    return dataset
def make_destination_dataset(ws, schema, name=None):
    """Create an empty destination ``Dataset`` with the given schema.

    Args:
        ws: workspace on which the dataset's init net is run.
        schema: record schema handed to ``Dataset``. (This parameter shadows
            the imported ``schema`` module inside this function.)
        name: name scope / dataset name; falls back to ``'dst'`` when falsy.

    Returns:
        The ``Dataset`` after ``init_empty`` has been applied and its init
        net has been run on ``ws``.
    """
    if not name:
        name = 'dst'
    init_net = core.Net('{}_init'.format(name))
    with core.NameScope(name):
        dataset = Dataset(schema, name=name)
        dataset.init_empty(init_net)
    ws.run(init_net)
    return dataset
class TestReaderBuilder(ReaderBuilder):
    """``ReaderBuilder`` helper that serves a generated source dataset.

    The dataset is only created when ``setup`` is called; until then
    ``new_reader`` returns ``None``.
    """

    # This is a test helper, not a test case.  Its name starts with "Test"
    # and it defines ``__init__``, so pytest would otherwise try to collect
    # it and emit a collection warning.
    __test__ = False

    def __init__(self, name, size, offset):
        """Remember the parameters later passed to ``make_source_dataset``.

        Args:
            name: name of the source dataset created in ``setup``.
            size: number of label values in that dataset.
            offset: first label value in that dataset.
        """
        self._schema = schema.Struct(
            ('label', schema.Scalar()),
        )
        self._name = name
        self._size = size
        self._offset = offset
        self._src_ds = None

    def schema(self):
        """Return the single-field (``label``) schema of this builder."""
        return self._schema

    def setup(self, ws):
        """Create the source dataset in ``ws``; returns an empty dict."""
        self._src_ds = make_source_dataset(
            ws,
            offset=self._offset,
            size=self._size,
            name=self._name,
        )
        return {}

    def new_reader(self, **kwargs):
        """Return the dataset built by ``setup`` (``None`` before setup).

        Note that the ``Dataset`` itself is returned, not ``dataset.reader()``;
        ``kwargs`` are accepted but ignored.
        """
        return self._src_ds
class TestCompositeReader(TestCase):
    """Tests piping several named sources into one destination dataset."""

    @unittest.skipIf(os.environ.get('JENKINS_URL'), 'Flaky test on Jenkins')
    def test_composite_reader(self):
        """Pipe a ``CompositeReader`` over three datasets and compare data."""
        ws = workspace.C.Workspace()
        session = LocalSession(ws)

        num_srcs = 3
        size = 100
        names = ["src_{}".format(i) for i in range(num_srcs)]
        offsets = [i * size for i in range(num_srcs)]
        src_dses = [
            make_source_dataset(ws, offset=offset, size=size, name=name)
            for name, offset in zip(names, offsets)
        ]
        data = [ws.fetch_blob(str(ds.field_blobs[0])) for ds in src_dses]

        # Sanity check: each source still holds its own range of labels.
        for fetched, offset in zip(data, offsets):
            npt.assert_array_equal(fetched, range(offset, offset + size))

        # Destination schema: one sub-record per source, keyed by its name.
        dst_fields = [
            (name, ds.content().clone_schema())
            for name, ds in zip(names, src_dses)
        ]
        dst_ds = make_destination_dataset(ws, schema.Struct(*dst_fields))

        with TaskGroup() as tg:
            sub_readers = [ds.reader() for ds in src_dses]
            reader = CompositeReader(names, sub_readers)
            pipe(reader, dst_ds.writer(), num_runtime_threads=3)
        session.run(tg)

        # Written order is not asserted; compare after sorting.
        for i, (name, expected) in enumerate(zip(names, data)):
            label_blob = str(dst_ds.content()[name].label())
            written_data = sorted(ws.fetch_blob(label_blob))
            npt.assert_array_equal(expected, written_data, "i: {}".format(i))

    @unittest.skipIf(os.environ.get('JENKINS_URL'), 'Flaky test on Jenkins')
    def test_composite_reader_builder(self):
        """Same scenario, but sources come from ``CompositeReaderBuilder``."""
        ws = workspace.C.Workspace()
        session = LocalSession(ws)

        num_srcs = 3
        size = 100
        names = ["src_{}".format(i) for i in range(num_srcs)]
        offsets = [i * size for i in range(num_srcs)]
        builders = [
            TestReaderBuilder(offset=offset, size=size, name=name)
            for name, offset in zip(names, offsets)
        ]

        # Destination schema: one sub-record per builder, keyed by its name.
        dst_fields = [
            (name, builder.schema())
            for name, builder in zip(names, builders)
        ]
        dst_ds = make_destination_dataset(ws, schema.Struct(*dst_fields))

        with TaskGroup() as tg:
            composite_builder = CompositeReaderBuilder(names, builders)
            composite_builder.setup(ws=ws)
            pipe(
                composite_builder.new_reader(),
                dst_ds.writer(),
                num_runtime_threads=3,
            )
        session.run(tg)

        # Written order is not asserted; compare after sorting.
        for name, offset in zip(names, offsets):
            label_blob = str(dst_ds.content()[name].label())
            written_data = sorted(ws.fetch_blob(label_blob))
            npt.assert_array_equal(
                range(offset, offset + size),
                written_data,
                "name: {}".format(name),
            )
class TestReaderWithLimit(TestCase):
    """Tests for ``ReaderWithLimit`` / ``ReaderWithTimeLimit`` and threads."""

    def test_runtime_threads(self):
        """Check per-task and per-thread init/exit hooks via counters."""
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        src_ds = make_source_dataset(ws)
        totals = [None] * 3

        def proc(rec):
            # executed once
            with ops.task_init():
                per_record = ops.CreateCounter([], ['global_counter'])
                per_record_sum = ops.CreateCounter([], ['global_counter2'])
                per_thread = ops.CreateCounter([], ['global_counter3'])
            # executed once per thread
            with ops.task_instance_init():
                local_counter = ops.CreateCounter([], ['task_counter'])
            # executed on each iteration
            ops.CountUp(per_record)
            ops.CountUp(local_counter)
            # executed once per thread
            with ops.task_instance_exit():
                # Add this thread's iteration count to the second counter.
                with ops.loop(ops.RetrieveCount(local_counter)):
                    ops.CountUp(per_record_sum)
                ops.CountUp(per_thread)
            # executed once
            with ops.task_exit():
                counters = (per_record, per_record_sum, per_thread)
                for slot, counter in enumerate(counters):
                    totals[slot] = final_output(ops.RetrieveCount(counter))
            return rec

        # Read full data set from original reader
        with TaskGroup() as tg:
            pipe(src_ds.reader(), num_runtime_threads=8, processor=proc)
        session.run(tg)
        for total, expected in zip(totals, (100, 100, 8)):
            self.assertEqual(total.fetch(), expected)

        # Read with a count-limited reader
        with TaskGroup() as tg:
            q1 = pipe(src_ds.reader(), num_runtime_threads=2)
            limited = ReaderWithLimit(q1.reader(), num_iter=25)
            q2 = pipe(limited, num_runtime_threads=3)
            pipe(q2, processor=proc, num_runtime_threads=6)
        session.run(tg)
        for total, expected in zip(totals, (25, 25, 6)):
            self.assertEqual(total.fetch(), expected)

    def _test_limit_reader_init_shared(self, size):
        """Build workspace, session, a source of ``size`` rows and a sink."""
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        # Source dataset with `size` labels.
        src_ds = make_source_dataset(ws, size=size)
        # Empty destination dataset sharing the source's schema.
        dst_schema = src_ds.content().clone_schema()
        dst_ds = make_destination_dataset(ws, dst_schema)
        return ws, session, src_ds, dst_ds

    def _test_limit_reader_shared(self, reader_class, size, expected_read_len,
                                  expected_read_len_threshold,
                                  expected_finish, num_threads, read_delay,
                                  **limiter_args):
        """Pipe a limited reader into a dataset and check what was read.

        Args:
            reader_class: limiting reader class wrapped around the source.
            size: number of rows in the source dataset.
            expected_read_len: expected number of rows written.
            expected_read_len_threshold: allowed deviation (+/-) from
                ``expected_read_len``.
            expected_finish: expected value of the reader's
                ``data_finished()`` blob after the run.
            num_threads: ``num_runtime_threads`` for the pipe.
            read_delay: when > 0, the source reader is wrapped in
                ``ReaderWithDelay`` with this delay.
            **limiter_args: forwarded to ``reader_class``.
        """
        ws, session, src_ds, dst_ds = \
            self._test_limit_reader_init_shared(size)

        # Read through the limiting reader.
        # WorkspaceType.GLOBAL is required because we are fetching
        # reader.data_finished() after the TaskGroup finishes.
        with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
            source_reader = src_ds.reader()
            if read_delay > 0:
                source_reader = ReaderWithDelay(source_reader, read_delay)
            reader = reader_class(source_reader, **limiter_args)
            pipe(reader, dst_ds.writer(), num_runtime_threads=num_threads)
        session.run(tg)

        label_blob = str(dst_ds.content().label())
        written = sorted(ws.blobs[label_blob].fetch())
        read_len = len(written)

        # Do a fuzzy match (expected_read_len +/- expected_read_len_threshold)
        # to eliminate flakiness for time-limited tests
        lower_bound = expected_read_len - expected_read_len_threshold
        upper_bound = expected_read_len + expected_read_len_threshold
        self.assertGreaterEqual(read_len, lower_bound)
        self.assertLessEqual(read_len, upper_bound)

        # Whatever was read must be exactly the first `read_len` labels.
        self.assertEqual(written, list(range(read_len)))

        finished_blob = str(reader.data_finished())
        self.assertEqual(ws.blobs[finished_blob].fetch(), expected_finish)

    def test_count_limit_reader_without_limit(self):
        """``num_iter=None``: every record is read and data is finished."""
        self._test_limit_reader_shared(
            ReaderWithLimit,
            size=100,
            expected_read_len=100,
            expected_read_len_threshold=0,
            expected_finish=True,
            num_threads=8,
            read_delay=0,
            num_iter=None,
        )

    def test_count_limit_reader_with_zero_limit(self):
        """``num_iter=0``: no records are read and data is not finished."""
        self._test_limit_reader_shared(
            ReaderWithLimit,
            size=100,
            expected_read_len=0,
            expected_read_len_threshold=0,
            expected_finish=False,
            num_threads=8,
            read_delay=0,
            num_iter=0,
        )

    def test_count_limit_reader_with_low_limit(self):
        """Limit below the dataset size: exactly ``num_iter`` records."""
        self._test_limit_reader_shared(
            ReaderWithLimit,
            size=100,
            expected_read_len=10,
            expected_read_len_threshold=0,
            expected_finish=False,
            num_threads=8,
            read_delay=0,
            num_iter=10,
        )

    def test_count_limit_reader_with_high_limit(self):
        """Limit above the dataset size: the whole dataset is read."""
        self._test_limit_reader_shared(
            ReaderWithLimit,
            size=100,
            expected_read_len=100,
            expected_read_len_threshold=0,
            expected_finish=True,
            num_threads=8,
            read_delay=0,
            num_iter=110,
        )

    def test_time_limit_reader_without_limit(self):
        """``duration=0``: expected to read all records (no time limit)."""
        self._test_limit_reader_shared(
            ReaderWithTimeLimit,
            size=100,
            expected_read_len=100,
            expected_read_len_threshold=0,
            expected_finish=True,
            num_threads=8,
            read_delay=0.1,
            duration=0,
        )

    def test_time_limit_reader_with_short_limit(self):
        """Time limit too short to read the whole dataset."""
        size = 50
        num_threads = 4
        sleep_duration = 0.25
        nominal_duration = 1
        expected_read_len = int(
            round(num_threads * nominal_duration / sleep_duration))
        # Because the time limit check happens before the delay + read op,
        # subtract a little bit of time to ensure we don't get in an extra read
        duration = nominal_duration - 0.25 * sleep_duration

        # NOTE: `expected_read_len_threshold` was added because this test case
        # has significant execution variation under stress. Under stress, we
        # may read strictly less than the expected # of samples; anywhere from
        # [0,N] where N = expected_read_len.
        # Hence we set expected_read_len to N/2, plus or minus N/2.
        half_expected = expected_read_len / 2
        self._test_limit_reader_shared(
            ReaderWithTimeLimit,
            size=size,
            expected_read_len=half_expected,
            expected_read_len_threshold=half_expected,
            expected_finish=False,
            num_threads=num_threads,
            read_delay=sleep_duration,
            duration=duration,
        )

    def test_time_limit_reader_with_long_limit(self):
        """Time limit generous enough to read the whole dataset."""
        # NOTE: we don't use `expected_read_len_threshold` because the
        # duration, read_delay, and # threads should be more than sufficient
        self._test_limit_reader_shared(
            ReaderWithTimeLimit,
            size=50,
            expected_read_len=50,
            expected_read_len_threshold=0,
            expected_finish=True,
            num_threads=4,
            read_delay=0.2,
            duration=10,
        )
class TestDBFileReader(TestCase):
    """Round-trip tests for ``CachedReader`` and ``DBFileReader``."""

    def setUp(self):
        # Every path registered here is removed again in tearDown.
        self.temp_paths = []

    def tearDown(self):
        # In case any test method fails, clean up temp paths.
        for path in self.temp_paths:
            self._delete_path(path)

    @staticmethod
    def _delete_path(path):
        """Remove ``path`` if it is a file or directory; no-op otherwise."""
        if os.path.isfile(path):
            os.remove(path)  # Remove file.
        elif os.path.isdir(path):
            shutil.rmtree(path)  # Remove dir recursively.

    def _make_temp_path(self):
        """Return a fresh path that does not exist yet, for use as db_path.

        A private directory is created with ``tempfile.mkdtemp`` and a child
        path inside it is returned.  This replaces taking the name of an
        already-deleted ``NamedTemporaryFile``, which lets another process
        claim the same name before the DB is created.  The child reuses the
        directory's random basename so the returned path still ends in a
        unique component.  The directory is registered for removal in
        ``tearDown``.
        """
        temp_dir = tempfile.mkdtemp()
        self.temp_paths.append(temp_dir)
        return os.path.join(temp_dir, os.path.basename(temp_dir))

    @staticmethod
    def _build_source_reader(ws, size):
        """Return a reader over a new source dataset with ``size`` labels."""
        src_ds = make_source_dataset(ws, size)
        return src_ds.reader()

    @staticmethod
    def _read_all_data(ws, reader, session):
        """Pipe ``reader`` into an empty dataset and return its labels."""
        dst_ds = make_destination_dataset(ws, reader.schema().clone_schema())

        with TaskGroup() as tg:
            pipe(reader, dst_ds.writer(), num_runtime_threads=8)
        session.run(tg)

        return ws.blobs[str(dst_ds.content().label())].fetch()

    @unittest.skipIf("LevelDB" not in core.C.registered_dbs(), "Need LevelDB")
    def test_cached_reader(self):
        """Cache is built once, reused while present, rebuilt once deleted."""
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        db_path = self._make_temp_path()

        # Read data for the first time.
        cached_reader1 = CachedReader(
            self._build_source_reader(ws, 100), db_path, loop_over=False,
        )
        build_cache_step = cached_reader1.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader1, session)
        self.assertEqual(sorted(data), list(range(100)))

        # Read data from cache: the 200-row source must not be consulted.
        cached_reader2 = CachedReader(
            self._build_source_reader(ws, 200), db_path,
        )
        build_cache_step = cached_reader2.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader2, session)
        self.assertEqual(sorted(data), list(range(100)))

        self._delete_path(db_path)

        # We removed cache so we expect to receive data from original reader.
        cached_reader3 = CachedReader(
            self._build_source_reader(ws, 300), db_path,
        )
        build_cache_step = cached_reader3.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader3, session)
        self.assertEqual(sorted(data), list(range(300)))

        self._delete_path(db_path)

    @unittest.skipIf("LevelDB" not in core.C.registered_dbs(), "Need LevelDB")
    def test_db_file_reader(self):
        """A DB written by ``CachedReader`` is readable by ``DBFileReader``."""
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        db_path = self._make_temp_path()

        # Build a cache DB file.
        cached_reader = CachedReader(
            self._build_source_reader(ws, 100),
            db_path=db_path,
            db_type='LevelDB',
        )
        build_cache_step = cached_reader.build_cache_step()
        session.run(build_cache_step)

        # Read data from cache DB file.
        db_file_reader = DBFileReader(
            db_path=db_path,
            db_type='LevelDB',
        )
        data = self._read_all_data(ws, db_file_reader, session)
        self.assertEqual(sorted(data), list(range(100)))

        self._delete_path(db_path)
|