dataio.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635
  1. ## @package dataio
  2. # Module caffe2.python.dataio
  3. """
  4. Defines the base interface for reading and writing operations.
  5. Readers/Writers are objects that produce operations that read/write sequences
  6. of data. Each operation reads or writes a list of BlobReferences.
  7. Readers and Writers must be implemented such that read and write operations
  8. are atomic and thread safe.
  9. Examples of possible Readers and Writers:
  10. QueueReader, QueueWriter,
  11. DatasetReader, DatasetWriter,
  12. See `dataset.py` for an example of implementation.
  13. """
  14. from caffe2.python import core
  15. from caffe2.python.schema import Field, Struct, from_blob_list
  16. import numpy as np
  17. import time
  18. class Reader(object):
  19. """
  20. Reader is an abstract class to be implemented in order to provide
  21. operations capable of iterating through a dataset or stream of data.
  22. A Reader must implement at least one operation, `read`, which
  23. adds operations to a net that read the next batch of data. Readers can
  24. optionally support the `reset` operation, which is useful when multiple
  25. passes over the data are required.
  26. """
  27. def __init__(self, schema=None):
  28. if schema is not None:
  29. assert isinstance(schema, Field)
  30. self._schema = schema
  31. def schema(self):
  32. assert self._schema is not None, 'Schema not provided for this reader.'
  33. return self._schema
  34. def _set_schema(self, schema):
  35. self._schema = schema
  36. def setup_ex(self, init_net, finish_net):
  37. """Setup nets to run at task initialization and cleanup time.
  38. Args:
  39. global_init_net: A net invoked at task init time.
  40. global_finish_net: A net invoked at task cleanup time.
  41. """
  42. pass
  43. def read_ex(self, local_init_net, local_finish_net):
  44. read_net = core.Net('reader_body')
  45. return ([read_net], ) + self.read(read_net)
  46. def read_record_ex(self, local_init_net, local_finish_net):
  47. nets, should_stop, fields = self.read_ex(
  48. local_init_net, local_finish_net)
  49. if self._schema:
  50. fields = from_blob_list(self._schema, fields)
  51. return nets, should_stop, fields
  52. def read(self, read_net):
  53. """Append operations to read_net that will read a batch from the
  54. underlying data soruce.
  55. Operations added to `read_net` must be thread safe and atomic, that is,
  56. it should be possible to clone `read_net` and run multiple instances of
  57. it in parallel.
  58. Args:
  59. read_net: the net that will be appended with read operations
  60. Returns:
  61. A tuple (should_stop, fields), with:
  62. should_stop: BlobReference pointing to a boolean scalar
  63. blob that indicates whether the read operation
  64. was succesfull or whether the end of data has
  65. been reached.
  66. fields: A tuple of BlobReference containing the latest batch
  67. of data that was read.
  68. """
  69. raise NotImplementedError('Readers must implement `read`.')
  70. def reset(self, net):
  71. """Append operations to `net` that will reset the reader.
  72. This can be used to read the data multiple times.
  73. Not all readers support this operation.
  74. """
  75. raise NotImplementedError('This reader cannot be resetted.')
  76. def read_record(self, read_net):
  77. should_stop, fields = self.read(read_net)
  78. if self._schema:
  79. fields = from_blob_list(self._schema, fields)
  80. return should_stop, fields
  81. def execution_step(self, reader_net_name=None, external_should_stop=None):
  82. """Create an execution step with a net containing read operators.
  83. The execution step will contain a `stop_blob` that knows how to stop
  84. the execution loop when end of data was reached.
  85. E.g.:
  86. read_step, fields = reader.execution_step()
  87. consume_net = core.Net('consume')
  88. consume_net.Print(fields[0], [])
  89. p = core.Plan('reader')
  90. p.AddStep(read_step.AddNet(consume_net))
  91. core.RunPlan(p)
  92. Args:
  93. reader_net_name: (optional) the name of the reader_net to be
  94. created. The execution step will
  95. be named accordingly.
  96. Returns:
  97. A tuple (read_step, fields), with:
  98. read_step: A newly created execution step containing a net with
  99. read operations. The step will have `stop_blob` set,
  100. in order to stop the loop on end of data.
  101. fields: A tuple of BlobReference containing the latest batch
  102. of data that was read.
  103. """
  104. reader_net = core.Net(reader_net_name or 'reader')
  105. should_stop, fields = self.read_record(reader_net)
  106. if external_should_stop is not None:
  107. should_stop = reader_net.Or([external_should_stop, should_stop])
  108. read_step = core.execution_step(
  109. '{}_step'.format(reader_net_name),
  110. reader_net,
  111. should_stop_blob=should_stop)
  112. return (read_step, fields)
  113. class Writer(object):
  114. """
  115. Writer is an abstract class to be implemented in order to provide
  116. operations capable of feeding a data stream or a dataset.
  117. A Writer must implement 2 operations:
  118. `write`, which adds operations to a net that write the write batch of
  119. data, and `commit`, which adds operations to a net in order to indicate
  120. that no more data will be written.
  121. """
  122. _schema = None
  123. def schema(self):
  124. return self._schema
  125. def write(self, writer_net, fields):
  126. """Add operations to `writer_net` that write the next batch of data.
  127. Operations added to the net must be thread-safe and unique, that is:
  128. multiple writers must be able to write to the dataset in parallel.
  129. Args:
  130. fields: a tuple of BlobReference containing the batch of data to
  131. write.
  132. """
  133. raise NotImplementedError('Writers must implement write.')
  134. def write_record(self, writer_net, fields):
  135. if isinstance(fields, Field):
  136. self._schema = fields
  137. fields = fields.field_blobs()
  138. self.write(writer_net, fields)
  139. def setup_ex(self, init_net, finish_net):
  140. """Experimental, don't use yet"""
  141. self.commit(finish_net)
  142. def write_ex(self, fields, local_init_net, local_finish_net, stop_blob):
  143. """Experimental extension to the interface. Don't use yet"""
  144. write_net = core.Net('write_net')
  145. self.write(write_net, fields)
  146. return [write_net]
  147. def write_record_ex(
  148. self, fields, local_init_net, local_finish_net, stop_blob=None):
  149. """Experimental extension to the interface. Don't use yet."""
  150. if isinstance(fields, Field):
  151. self._schema = fields
  152. fields = fields.field_blobs()
  153. if stop_blob is None:
  154. stop_blob = local_init_net.NextName("dequeue_status")
  155. write_nets = self.write_ex(
  156. fields, local_init_net, local_finish_net, stop_blob)
  157. return (write_nets, stop_blob)
  158. def commit(self, finish_net):
  159. """Add operations to `finish_net` that signal end of data.
  160. This must be implemented by all Writers, but may be no-op for some
  161. of them.
  162. """
  163. pass
  164. class ReaderBuilder(object):
  165. """ Allow usage of a reader in distributed fashion. """
  166. def schema(self):
  167. raise NotImplementedError()
  168. def setup(self, **kwargs):
  169. """
  170. Optionally, perform one-time setup before calling new_reader().
  171. Subclass should make sure this function is only called once.
  172. """
  173. raise NotImplementedError()
  174. def new_reader(self, **kwargs):
  175. raise NotImplementedError()
  176. class PipedReaderBuilder(ReaderBuilder):
  177. """ReaderBuilder that modifies underlying builder by calling `piper`
  178. function on each new reader produced, and return the result of
  179. the function. This way, it is possible to append data processing
  180. pipelines that will be replicated for each reader that gets created.
  181. E.g.:
  182. PipedReaderBuilder(
  183. ReaderBuilder(...),
  184. lambda reader: pipe(reader, processor=my_proc))
  185. """
  186. def __init__(self, builder, piper):
  187. self._builder = builder
  188. self._piper = piper
  189. def schema(self):
  190. return self._builder.schema()
  191. def setup(self, **kwargs):
  192. return self._builder.setup(**kwargs)
  193. def new_reader(self, **kwargs):
  194. # Passing everything down since you could wrap a PipedReaderBuilder in
  195. # another PipedReaderBuilder
  196. output = self._piper(
  197. reader=self._builder.new_reader(**kwargs),
  198. **kwargs
  199. )
  200. return output if isinstance(output, Reader) else output.reader()
  201. class Pipe(object):
  202. def __init__(self, schema=None, obj_key=None):
  203. self._num_writers = 0
  204. self._num_readers = 0
  205. self._schema = schema
  206. self._obj_key = obj_key
  207. def schema(self):
  208. return self._schema
  209. def setup(self, global_init_net):
  210. pass
  211. def reader(self):
  212. raise NotImplementedError()
  213. def writer(self):
  214. raise NotImplementedError()
  215. def num_readers(self):
  216. return self._num_readers
  217. def num_writers(self):
  218. return self._num_writers
  219. def _new_writer(self, writer_schema, writer_init_net):
  220. if writer_schema is not None and self._schema is None:
  221. self._schema = writer_schema
  222. self._num_writers += 1
  223. if self._obj_key is not None:
  224. writer_init_net.add_attribute(self._obj_key, self)
  225. def _new_reader(self, reader_init_net):
  226. self._num_readers += 1
  227. if self._obj_key is not None:
  228. reader_init_net.add_attribute(self._obj_key, self)
  229. class CounterReader(Reader):
  230. """ Reader that produces increasing integers. """
  231. def __init__(self):
  232. Reader.__init__(self, schema=Struct(('iter', np.int64)))
  233. self.counter = None
  234. self.should_stop = None
  235. def setup_ex(self, global_init_net, global_finish_net):
  236. if self.counter is None:
  237. self.counter = global_init_net.CreateCounter([], init_count=0)
  238. self.should_stop = global_init_net.ConstantFill(
  239. [], shape=[], dtype=core.DataType.BOOL, value=False)
  240. def read_ex(self, local_init_net, local_finish_net):
  241. count_net = core.Net('limited_reader_counter')
  242. value = count_net.CountUp([self.counter], 1)
  243. return [count_net], self.should_stop, [value]
class ReaderWithLimitBase(Reader):
    """Abstract Reader constrained by certain conditions.

    Base class for Reader classes which check for certain conditions to stop
    further processing (e.g. max number of iterations or time limit).
    Also produces a boolean blob (data_finished) that can be used to see if
    the reader exhausted all input data (true) or stopped for another reason
    (false).
    """

    def __init__(self, reader):
        """
        Args:
            reader: the underlying Reader; its schema is reused for this
                    reader.
        """
        Reader.__init__(self, schema=reader._schema)
        self.reader = reader
        # This net is only used here to generate names and declare the
        # `data_finished` blob as an external input.
        self.net = core.Net('reader_with_limit')
        self._data_finished = self.net.AddExternalInput(
            self.net.NextName('data_finished'))
        # Never assigned anywhere else in this class.
        self.should_stop = None

    def setup_ex(self, global_init_net, global_finish_net):
        """Initialize `data_finished` to False, then set up the wrapped
        reader and the limiter (in that order).

        Args:
            global_init_net: A net invoked at task init time.
            global_finish_net: A net invoked at task cleanup time.
        """
        global_init_net.ConstantFill(
            [], [self._data_finished],
            shape=[], value=False, dtype=core.DataType.BOOL)
        self.reader.setup_ex(global_init_net, global_finish_net)
        self.setup_limiter(global_init_net, global_finish_net)

    def read_ex(self, local_init_net, local_finish_net):
        """Reads from an underlying Reader class, but may stop due to additional
        constraints.

        Build and return network(s) to read data from a Reader with
        additional constraints, depending on which derived class is used.
        Derived classes implement setup_limiter and check_limiter_condition
        which determine the nature of the constraint imposed on the reader,
        e.g. iteration limits or time limit.

        Args:
            local_init_net: A net invoked at task instance init time (Once per
                            parallel thread).
            local_finish_net: A net invoked at task instance cleanup time (Once
                              per parallel thread).

        Returns:
            A tuple (nets, should_stop, fields): the limiter condition net,
            followed by the wrapped reader's nets, followed by the
            post-read check net; the limiter's stop blob; and the fields
            returned by the wrapped reader.
        """
        # Check if limiting constraint is met.
        stop_condition_net = core.Net('limited_reader_condition')
        should_stop = self.check_limiter_condition(stop_condition_net)
        # Call original reader.
        nets, local_data_finished, fields = self.reader.read_ex(
            local_init_net, local_finish_net)
        # Re-sync the schema with the wrapped reader after its read_ex call.
        self._set_schema(self.reader._schema)
        # Check if original reader is done.
        check_done_net = core.Net('limited_reader_post')
        # Copy to the same blob as the counter output to trigger reader
        # stopping - this is ok because execution will check should_stop_blob
        # after every single operation, so it has already been checked on this
        # iteration by this point.
        check_done_net.Copy(local_data_finished, should_stop)
        # Update externally-accessible flag indicating if reader is done
        check_done_net.Or([self._data_finished, local_data_finished],
                          [self._data_finished])
        return [stop_condition_net] + nets + [check_done_net], should_stop, fields

    def setup_limiter(self, global_init_net, global_finish_net):
        """Configure task level init/cleanup nets required to implement limit
        condition. Must be implemented by subclass.

        Args:
            global_init_net: A net invoked at task init time.
            global_finish_net: A net invoked at task cleanup time.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Subclass must implement `setup_limiter`")

    def check_limiter_condition(self, stop_condition_net):
        """Configure a net that is invoked between reading batches to see if
        limit condition is met. Must be implemented by subclass.

        Args:
            stop_condition_net: A net invoked to evaluate an early termination
                                condition.

        Returns:
            Subclasses return the blob used as the stop signal; `read_ex`
            uses it as `should_stop`.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Subclass must implement `check_limiter_condition")

    def data_finished(self):
        """
        Return a blob that can be checked after the end of the reading task,
        which will contain a boolean scalar indicating whether the underlying
        reader has been exhausted (True) or whether we stopped because the
        limit was reached (False).
        """
        return self._data_finished
  321. class ReaderWithLimit(ReaderWithLimitBase):
  322. """Reader that stops after `num_iter` batches.
  323. If `num_iter` <= 0 or is None, reverts to an unconstrained reader that
  324. exports a boolean blob indicating that the reader has exhausted
  325. the data steam.
  326. """
  327. def __init__(self, reader, num_iter=1):
  328. """Class initializer.
  329. Args:
  330. reader: The underlying reader object doing the actual read.
  331. num_iter: Number of batches to read. If `None`,
  332. the class reverts to a normal reader except that it also
  333. produces a data_finished blob as a side effect to indicate
  334. whether the input stream is exhausted.
  335. """
  336. super(ReaderWithLimit, self).__init__(reader)
  337. self.counter = None
  338. self.num_iter = num_iter
  339. if self.num_iter is not None:
  340. self.counter = self.net.AddExternalInput(
  341. self.net.NextName('counter'))
  342. def setup_limiter(self, global_init_net, global_finish_net):
  343. if self.counter:
  344. global_init_net.CreateCounter(
  345. [], [self.counter], init_count=int(self.num_iter))
  346. def check_limiter_condition(self, stop_condition_net):
  347. if self.counter:
  348. return stop_condition_net.CountDown([self.counter], 1)
  349. else:
  350. return stop_condition_net.ConstantFill(
  351. [], 1,
  352. shape=[], value=False, dtype=core.DataType.BOOL)
  353. def CountUntil(num_iter):
  354. return ReaderWithLimit(CounterReader(), num_iter)
  355. class ReaderWithTimeLimit(ReaderWithLimitBase):
  356. """Reader that stops after `duration` seconds.
  357. If `duration` <= 0 or is None, reverts to an unconstrained reader that
  358. exports a boolean blob indicating that the reader has exhausted
  359. the data steam.
  360. """
  361. def __init__(self, reader, duration=0):
  362. """Class initializer.
  363. Args:
  364. reader: The underlying reader object doing the actual read.
  365. duration: Number of seconds to read. If un-specified, None, or <= 0,
  366. the class reverts to a normal reader except that it also
  367. produces a data_finished blob as a side effect to indicate
  368. whether the input stream is exhausted.
  369. """
  370. super(ReaderWithTimeLimit, self).__init__(reader)
  371. self.timer = None
  372. self.duration = duration
  373. self.duration_ns_blob = None
  374. def setup_limiter(self, global_init_net, global_finish_net):
  375. if self.duration is not None and self.duration > 0:
  376. duration_ns = int(self.duration * (10**9))
  377. self.timer = global_init_net.TimerBegin(
  378. [], counter_name='epoch_timer')
  379. start_time = global_init_net.TimerGet(self.timer)
  380. self.duration_ns_blob = global_init_net.ConstantFill(
  381. [start_time], value=duration_ns)
  382. global_finish_net.TimerEnd([self.timer], [])
  383. def check_limiter_condition(self, stop_condition_net):
  384. if self.duration:
  385. time_elapsed = stop_condition_net.TimerGet(self.timer)
  386. return stop_condition_net.GE(
  387. [time_elapsed, self.duration_ns_blob], str(self.should_stop))
  388. else:
  389. return stop_condition_net.ConstantFill(
  390. [], 1, shape=[], value=False, dtype=core.DataType.BOOL
  391. )
  392. class ReaderWithDelay(Reader):
  393. """Test reader class that inserts a delay between reading batches."""
  394. def __init__(self, reader, delay):
  395. Reader.__init__(self, schema=reader._schema)
  396. self.reader = reader
  397. self.delay = delay
  398. def setup_ex(self, global_init_net, global_finish_net):
  399. self.reader.setup_ex(global_init_net, global_finish_net)
  400. def read_ex(self, local_init_net, local_finish_net):
  401. read_net = core.Net("reader_body")
  402. def sleep_op(*args, **argd):
  403. time.sleep(self.delay)
  404. read_net.Python(sleep_op)([], [])
  405. return ([read_net],) + self.reader.read(read_net)
  406. class CompositeReader(Reader):
  407. """
  408. Base class for a reader that wrap multiple readers, e.g., reading from
  409. multiple sources simultaneously.
  410. """
  411. def __init__(self, names, readers):
  412. """
  413. Args:
  414. names: list[str] names of readers; used as schema keys
  415. readers: list[Reader] Reader instances, must have schema
  416. """
  417. assert len(names) == len(readers)
  418. super(CompositeReader, self).__init__(schema=Struct(*[
  419. (name, reader.schema()) for name, reader in zip(names, readers)
  420. ]))
  421. self._names = names
  422. self._readers = readers
  423. def setup_ex(self, init_net, finish_net):
  424. for reader in self._readers:
  425. reader.setup_ex(init_net, finish_net)
  426. def read_ex(self, local_init_net, local_finish_net):
  427. """
  428. Stops when one of the reader finished
  429. """
  430. # First, instantiate all the reader nets
  431. fields = []
  432. stop_blobs = []
  433. all_sub_read_nets = []
  434. for name, reader in zip(self._names, self._readers):
  435. sub_read_nets, should_stop, record = reader.read_record_ex(
  436. local_init_net, local_finish_net)
  437. stop_blobs.append(should_stop)
  438. all_sub_read_nets.append(sub_read_nets)
  439. fields.extend(record.field_blobs())
  440. read_nets = []
  441. # Use the stop blob of the last reader as stop blob of composite reader.
  442. local_should_stop = stop_blobs[-1]
  443. for name, sub_read_nets, stop_blob in zip(self._names, all_sub_read_nets, stop_blobs):
  444. read_nets.extend(sub_read_nets)
  445. if stop_blob == local_should_stop:
  446. # Skip adding stop net because Or([A, A], A) doesn't pass operator
  447. # schema check
  448. continue
  449. stop_net = core.Net("{}_stop".format(name))
  450. stop_net.Or([local_should_stop, stop_blob], local_should_stop)
  451. read_nets.append(stop_net)
  452. return read_nets, local_should_stop, fields
  453. def reset(self, net):
  454. for reader in self._readers:
  455. reader.reset(net)
  456. class CompositeReaderBuilder(ReaderBuilder):
  457. """
  458. A reader builder for CompositeReader
  459. """
  460. def __init__(self, names, reader_builders):
  461. """
  462. Args:
  463. names: list[str] names of readers; used as schema keys
  464. reader_builders: list[ReaderBuilder] ReaderBuilder instances;
  465. must have schema
  466. """
  467. super(CompositeReaderBuilder, self).__init__()
  468. self._names = names
  469. self._reader_builders = reader_builders
  470. self._schema = Struct(*[
  471. (name, reader_builder.schema())
  472. for name, reader_builder in zip(names, reader_builders)
  473. ])
  474. def schema(self):
  475. return self._schema
  476. def setup(self, **kwargs):
  477. data_finished_blobs = {}
  478. # limiter is stateful; it can only be used once. Since
  479. # CompositeReader stops when one of the reader stops,
  480. # this is fine.
  481. if "limiter" in kwargs:
  482. limiter = kwargs.pop("limiter")
  483. else:
  484. limiter = None
  485. for i, reader_builder in enumerate(self._reader_builders):
  486. if i == len(self._reader_builders) - 1 and limiter is not None:
  487. # The limiter must be applied to the last reader so that the
  488. # batch counter is incremented only if every reader has data
  489. kwargs["limiter"] = limiter
  490. sub_reader_data_finished_blobs = reader_builder.setup(**kwargs)
  491. overlapping_keys = set(data_finished_blobs.keys()) & set(sub_reader_data_finished_blobs.keys())
  492. overlapping_values = set(data_finished_blobs.values()) & set(sub_reader_data_finished_blobs.values())
  493. assert overlapping_keys == set(), "Overlapping keys: {}".format(overlapping_keys)
  494. assert overlapping_values == set(), "Overlapping values: {}".format(overlapping_values)
  495. data_finished_blobs.update(sub_reader_data_finished_blobs)
  496. return data_finished_blobs
  497. def new_reader(self, **kwargs):
  498. readers = []
  499. for reader_builder in self._reader_builders:
  500. reader = reader_builder.new_reader(**kwargs)
  501. if isinstance(reader, Reader):
  502. pass
  503. elif hasattr(reader, 'reader'):
  504. reader = reader.reader()
  505. else:
  506. raise ValueError('reader must be an instance of Reader or Pipe')
  507. readers.append(reader)
  508. multi_reader = CompositeReader(self._names, readers)
  509. assert multi_reader.schema() == self._schema
  510. return multi_reader