| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635 |
- ## @package dataio
- # Module caffe2.python.dataio
- """
- Defines the base interface for reading and writing operations.
- Readers/Writers are objects that produce operations that read/write sequences
- of data. Each operation reads or writes a list of BlobReferences.
- Readers and Writers must be implemented such that read and write operations
- are atomic and thread safe.
- Examples of possible Readers and Writers:
- QueueReader, QueueWriter,
- DatasetReader, DatasetWriter,
- See `dataset.py` for an example of implementation.
- """
- from caffe2.python import core
- from caffe2.python.schema import Field, Struct, from_blob_list
- import numpy as np
- import time
- class Reader(object):
- """
- Reader is an abstract class to be implemented in order to provide
- operations capable of iterating through a dataset or stream of data.
- A Reader must implement at least one operation, `read`, which
- adds operations to a net that read the next batch of data. Readers can
- optionally support the `reset` operation, which is useful when multiple
- passes over the data are required.
- """
- def __init__(self, schema=None):
- if schema is not None:
- assert isinstance(schema, Field)
- self._schema = schema
- def schema(self):
- assert self._schema is not None, 'Schema not provided for this reader.'
- return self._schema
- def _set_schema(self, schema):
- self._schema = schema
- def setup_ex(self, init_net, finish_net):
- """Setup nets to run at task initialization and cleanup time.
- Args:
- global_init_net: A net invoked at task init time.
- global_finish_net: A net invoked at task cleanup time.
- """
- pass
- def read_ex(self, local_init_net, local_finish_net):
- read_net = core.Net('reader_body')
- return ([read_net], ) + self.read(read_net)
- def read_record_ex(self, local_init_net, local_finish_net):
- nets, should_stop, fields = self.read_ex(
- local_init_net, local_finish_net)
- if self._schema:
- fields = from_blob_list(self._schema, fields)
- return nets, should_stop, fields
- def read(self, read_net):
- """Append operations to read_net that will read a batch from the
- underlying data soruce.
- Operations added to `read_net` must be thread safe and atomic, that is,
- it should be possible to clone `read_net` and run multiple instances of
- it in parallel.
- Args:
- read_net: the net that will be appended with read operations
- Returns:
- A tuple (should_stop, fields), with:
- should_stop: BlobReference pointing to a boolean scalar
- blob that indicates whether the read operation
- was succesfull or whether the end of data has
- been reached.
- fields: A tuple of BlobReference containing the latest batch
- of data that was read.
- """
- raise NotImplementedError('Readers must implement `read`.')
- def reset(self, net):
- """Append operations to `net` that will reset the reader.
- This can be used to read the data multiple times.
- Not all readers support this operation.
- """
- raise NotImplementedError('This reader cannot be resetted.')
- def read_record(self, read_net):
- should_stop, fields = self.read(read_net)
- if self._schema:
- fields = from_blob_list(self._schema, fields)
- return should_stop, fields
- def execution_step(self, reader_net_name=None, external_should_stop=None):
- """Create an execution step with a net containing read operators.
- The execution step will contain a `stop_blob` that knows how to stop
- the execution loop when end of data was reached.
- E.g.:
- read_step, fields = reader.execution_step()
- consume_net = core.Net('consume')
- consume_net.Print(fields[0], [])
- p = core.Plan('reader')
- p.AddStep(read_step.AddNet(consume_net))
- core.RunPlan(p)
- Args:
- reader_net_name: (optional) the name of the reader_net to be
- created. The execution step will
- be named accordingly.
- Returns:
- A tuple (read_step, fields), with:
- read_step: A newly created execution step containing a net with
- read operations. The step will have `stop_blob` set,
- in order to stop the loop on end of data.
- fields: A tuple of BlobReference containing the latest batch
- of data that was read.
- """
- reader_net = core.Net(reader_net_name or 'reader')
- should_stop, fields = self.read_record(reader_net)
- if external_should_stop is not None:
- should_stop = reader_net.Or([external_should_stop, should_stop])
- read_step = core.execution_step(
- '{}_step'.format(reader_net_name),
- reader_net,
- should_stop_blob=should_stop)
- return (read_step, fields)
- class Writer(object):
- """
- Writer is an abstract class to be implemented in order to provide
- operations capable of feeding a data stream or a dataset.
- A Writer must implement 2 operations:
- `write`, which adds operations to a net that write the write batch of
- data, and `commit`, which adds operations to a net in order to indicate
- that no more data will be written.
- """
- _schema = None
- def schema(self):
- return self._schema
- def write(self, writer_net, fields):
- """Add operations to `writer_net` that write the next batch of data.
- Operations added to the net must be thread-safe and unique, that is:
- multiple writers must be able to write to the dataset in parallel.
- Args:
- fields: a tuple of BlobReference containing the batch of data to
- write.
- """
- raise NotImplementedError('Writers must implement write.')
- def write_record(self, writer_net, fields):
- if isinstance(fields, Field):
- self._schema = fields
- fields = fields.field_blobs()
- self.write(writer_net, fields)
- def setup_ex(self, init_net, finish_net):
- """Experimental, don't use yet"""
- self.commit(finish_net)
- def write_ex(self, fields, local_init_net, local_finish_net, stop_blob):
- """Experimental extension to the interface. Don't use yet"""
- write_net = core.Net('write_net')
- self.write(write_net, fields)
- return [write_net]
- def write_record_ex(
- self, fields, local_init_net, local_finish_net, stop_blob=None):
- """Experimental extension to the interface. Don't use yet."""
- if isinstance(fields, Field):
- self._schema = fields
- fields = fields.field_blobs()
- if stop_blob is None:
- stop_blob = local_init_net.NextName("dequeue_status")
- write_nets = self.write_ex(
- fields, local_init_net, local_finish_net, stop_blob)
- return (write_nets, stop_blob)
- def commit(self, finish_net):
- """Add operations to `finish_net` that signal end of data.
- This must be implemented by all Writers, but may be no-op for some
- of them.
- """
- pass
- class ReaderBuilder(object):
- """ Allow usage of a reader in distributed fashion. """
- def schema(self):
- raise NotImplementedError()
- def setup(self, **kwargs):
- """
- Optionally, perform one-time setup before calling new_reader().
- Subclass should make sure this function is only called once.
- """
- raise NotImplementedError()
- def new_reader(self, **kwargs):
- raise NotImplementedError()
- class PipedReaderBuilder(ReaderBuilder):
- """ReaderBuilder that modifies underlying builder by calling `piper`
- function on each new reader produced, and return the result of
- the function. This way, it is possible to append data processing
- pipelines that will be replicated for each reader that gets created.
- E.g.:
- PipedReaderBuilder(
- ReaderBuilder(...),
- lambda reader: pipe(reader, processor=my_proc))
- """
- def __init__(self, builder, piper):
- self._builder = builder
- self._piper = piper
- def schema(self):
- return self._builder.schema()
- def setup(self, **kwargs):
- return self._builder.setup(**kwargs)
- def new_reader(self, **kwargs):
- # Passing everything down since you could wrap a PipedReaderBuilder in
- # another PipedReaderBuilder
- output = self._piper(
- reader=self._builder.new_reader(**kwargs),
- **kwargs
- )
- return output if isinstance(output, Reader) else output.reader()
- class Pipe(object):
- def __init__(self, schema=None, obj_key=None):
- self._num_writers = 0
- self._num_readers = 0
- self._schema = schema
- self._obj_key = obj_key
- def schema(self):
- return self._schema
- def setup(self, global_init_net):
- pass
- def reader(self):
- raise NotImplementedError()
- def writer(self):
- raise NotImplementedError()
- def num_readers(self):
- return self._num_readers
- def num_writers(self):
- return self._num_writers
- def _new_writer(self, writer_schema, writer_init_net):
- if writer_schema is not None and self._schema is None:
- self._schema = writer_schema
- self._num_writers += 1
- if self._obj_key is not None:
- writer_init_net.add_attribute(self._obj_key, self)
- def _new_reader(self, reader_init_net):
- self._num_readers += 1
- if self._obj_key is not None:
- reader_init_net.add_attribute(self._obj_key, self)
- class CounterReader(Reader):
- """ Reader that produces increasing integers. """
- def __init__(self):
- Reader.__init__(self, schema=Struct(('iter', np.int64)))
- self.counter = None
- self.should_stop = None
- def setup_ex(self, global_init_net, global_finish_net):
- if self.counter is None:
- self.counter = global_init_net.CreateCounter([], init_count=0)
- self.should_stop = global_init_net.ConstantFill(
- [], shape=[], dtype=core.DataType.BOOL, value=False)
- def read_ex(self, local_init_net, local_finish_net):
- count_net = core.Net('limited_reader_counter')
- value = count_net.CountUp([self.counter], 1)
- return [count_net], self.should_stop, [value]
- class ReaderWithLimitBase(Reader):
- """Abstract Reader constrained by certain conditions.
- Base class for Reader classes which check for certain conditions to stop
- further processing (e.g. max number of iterations or time limit).
- Also produces a boolean blob (data_finished) that can be used to see if
- the reader exausted all input data (true) or stopped for another reason
- (false).
- """
- def __init__(self, reader):
- Reader.__init__(self, schema=reader._schema)
- self.reader = reader
- self.net = core.Net('reader_with_limit')
- self._data_finished = self.net.AddExternalInput(
- self.net.NextName('data_finished'))
- self.should_stop = None
- def setup_ex(self, global_init_net, global_finish_net):
- global_init_net.ConstantFill(
- [], [self._data_finished],
- shape=[], value=False, dtype=core.DataType.BOOL)
- self.reader.setup_ex(global_init_net, global_finish_net)
- self.setup_limiter(global_init_net, global_finish_net)
- def read_ex(self, local_init_net, local_finish_net):
- """Reads from an underlying Reader class, but may stop due to additional
- constraints.
- Build and return network(s) to read data from a Reader with
- additional constraints, depending on which derived class is used.
- Derived classes implement setup_limited and check_limiter_condition
- which determine the nature of the constraint imposed on the reader,
- e.g. iteration limits or time limit.
- Args:
- local_init_net: A net invoked at task instance init time (Once per
- parallel thread).
- local_finish_net: A net invoked at task instance cleanup time (Once
- per parallel thread).
- """
- # Check if limiting constraint is met.
- stop_condition_net = core.Net('limited_reader_condition')
- should_stop = self.check_limiter_condition(stop_condition_net)
- # Call original reader.
- nets, local_data_finished, fields = self.reader.read_ex(
- local_init_net, local_finish_net)
- self._set_schema(self.reader._schema)
- # Check if original reader is done.
- check_done_net = core.Net('limited_reader_post')
- # Copy to the same blob as the counter output to trigger reader
- # stopping - this is ok because execution will check should_stop_blob
- # after every single operation, so it has already been checked on this
- # iteration by this point.
- check_done_net.Copy(local_data_finished, should_stop)
- # Update externally-accessible flag indicating if reader is done
- check_done_net.Or([self._data_finished, local_data_finished],
- [self._data_finished])
- return [stop_condition_net] + nets + [check_done_net], should_stop, fields
- def setup_limiter(self, global_init_net, global_finish_net):
- """Configure task level init/cleanup nets required to implement limit
- condition. Must be implemented by subclass.
- Args:
- global_init_net: A net invoked at task init time.
- global_finish_net: A net invoked at task cleanup time.
- """
- raise NotImplementedError("Subclass must implement `setup_limiter`")
- def check_limiter_condition(self, stop_condition_net):
- """Configure a net that is invoked between reading batches to see if
- limit condition is met. Must be implemented by subclass.
- Args:
- stop_condition_net: A net invoked to evaluate an early termination
- condition.
- """
- raise NotImplementedError("Subclass must implement `check_limiter_condition")
- def data_finished(self):
- """
- Return a blob that can be checked after the end of the reading task,
- which will contain a scalar float indicating whether the underlying
- reader has been exhausted (True) or whether we stopped because reached
- the limit of iterations (False).
- """
- return self._data_finished
- class ReaderWithLimit(ReaderWithLimitBase):
- """Reader that stops after `num_iter` batches.
- If `num_iter` <= 0 or is None, reverts to an unconstrained reader that
- exports a boolean blob indicating that the reader has exhausted
- the data steam.
- """
- def __init__(self, reader, num_iter=1):
- """Class initializer.
- Args:
- reader: The underlying reader object doing the actual read.
- num_iter: Number of batches to read. If `None`,
- the class reverts to a normal reader except that it also
- produces a data_finished blob as a side effect to indicate
- whether the input stream is exhausted.
- """
- super(ReaderWithLimit, self).__init__(reader)
- self.counter = None
- self.num_iter = num_iter
- if self.num_iter is not None:
- self.counter = self.net.AddExternalInput(
- self.net.NextName('counter'))
- def setup_limiter(self, global_init_net, global_finish_net):
- if self.counter:
- global_init_net.CreateCounter(
- [], [self.counter], init_count=int(self.num_iter))
- def check_limiter_condition(self, stop_condition_net):
- if self.counter:
- return stop_condition_net.CountDown([self.counter], 1)
- else:
- return stop_condition_net.ConstantFill(
- [], 1,
- shape=[], value=False, dtype=core.DataType.BOOL)
- def CountUntil(num_iter):
- return ReaderWithLimit(CounterReader(), num_iter)
- class ReaderWithTimeLimit(ReaderWithLimitBase):
- """Reader that stops after `duration` seconds.
- If `duration` <= 0 or is None, reverts to an unconstrained reader that
- exports a boolean blob indicating that the reader has exhausted
- the data steam.
- """
- def __init__(self, reader, duration=0):
- """Class initializer.
- Args:
- reader: The underlying reader object doing the actual read.
- duration: Number of seconds to read. If un-specified, None, or <= 0,
- the class reverts to a normal reader except that it also
- produces a data_finished blob as a side effect to indicate
- whether the input stream is exhausted.
- """
- super(ReaderWithTimeLimit, self).__init__(reader)
- self.timer = None
- self.duration = duration
- self.duration_ns_blob = None
- def setup_limiter(self, global_init_net, global_finish_net):
- if self.duration is not None and self.duration > 0:
- duration_ns = int(self.duration * (10**9))
- self.timer = global_init_net.TimerBegin(
- [], counter_name='epoch_timer')
- start_time = global_init_net.TimerGet(self.timer)
- self.duration_ns_blob = global_init_net.ConstantFill(
- [start_time], value=duration_ns)
- global_finish_net.TimerEnd([self.timer], [])
- def check_limiter_condition(self, stop_condition_net):
- if self.duration:
- time_elapsed = stop_condition_net.TimerGet(self.timer)
- return stop_condition_net.GE(
- [time_elapsed, self.duration_ns_blob], str(self.should_stop))
- else:
- return stop_condition_net.ConstantFill(
- [], 1, shape=[], value=False, dtype=core.DataType.BOOL
- )
- class ReaderWithDelay(Reader):
- """Test reader class that inserts a delay between reading batches."""
- def __init__(self, reader, delay):
- Reader.__init__(self, schema=reader._schema)
- self.reader = reader
- self.delay = delay
- def setup_ex(self, global_init_net, global_finish_net):
- self.reader.setup_ex(global_init_net, global_finish_net)
- def read_ex(self, local_init_net, local_finish_net):
- read_net = core.Net("reader_body")
- def sleep_op(*args, **argd):
- time.sleep(self.delay)
- read_net.Python(sleep_op)([], [])
- return ([read_net],) + self.reader.read(read_net)
- class CompositeReader(Reader):
- """
- Base class for a reader that wrap multiple readers, e.g., reading from
- multiple sources simultaneously.
- """
- def __init__(self, names, readers):
- """
- Args:
- names: list[str] names of readers; used as schema keys
- readers: list[Reader] Reader instances, must have schema
- """
- assert len(names) == len(readers)
- super(CompositeReader, self).__init__(schema=Struct(*[
- (name, reader.schema()) for name, reader in zip(names, readers)
- ]))
- self._names = names
- self._readers = readers
- def setup_ex(self, init_net, finish_net):
- for reader in self._readers:
- reader.setup_ex(init_net, finish_net)
- def read_ex(self, local_init_net, local_finish_net):
- """
- Stops when one of the reader finished
- """
- # First, instantiate all the reader nets
- fields = []
- stop_blobs = []
- all_sub_read_nets = []
- for name, reader in zip(self._names, self._readers):
- sub_read_nets, should_stop, record = reader.read_record_ex(
- local_init_net, local_finish_net)
- stop_blobs.append(should_stop)
- all_sub_read_nets.append(sub_read_nets)
- fields.extend(record.field_blobs())
- read_nets = []
- # Use the stop blob of the last reader as stop blob of composite reader.
- local_should_stop = stop_blobs[-1]
- for name, sub_read_nets, stop_blob in zip(self._names, all_sub_read_nets, stop_blobs):
- read_nets.extend(sub_read_nets)
- if stop_blob == local_should_stop:
- # Skip adding stop net because Or([A, A], A) doesn't pass operator
- # schema check
- continue
- stop_net = core.Net("{}_stop".format(name))
- stop_net.Or([local_should_stop, stop_blob], local_should_stop)
- read_nets.append(stop_net)
- return read_nets, local_should_stop, fields
- def reset(self, net):
- for reader in self._readers:
- reader.reset(net)
- class CompositeReaderBuilder(ReaderBuilder):
- """
- A reader builder for CompositeReader
- """
- def __init__(self, names, reader_builders):
- """
- Args:
- names: list[str] names of readers; used as schema keys
- reader_builders: list[ReaderBuilder] ReaderBuilder instances;
- must have schema
- """
- super(CompositeReaderBuilder, self).__init__()
- self._names = names
- self._reader_builders = reader_builders
- self._schema = Struct(*[
- (name, reader_builder.schema())
- for name, reader_builder in zip(names, reader_builders)
- ])
- def schema(self):
- return self._schema
- def setup(self, **kwargs):
- data_finished_blobs = {}
- # limiter is stateful; it can only be used once. Since
- # CompositeReader stops when one of the reader stops,
- # this is fine.
- if "limiter" in kwargs:
- limiter = kwargs.pop("limiter")
- else:
- limiter = None
- for i, reader_builder in enumerate(self._reader_builders):
- if i == len(self._reader_builders) - 1 and limiter is not None:
- # The limiter must be applied to the last reader so that the
- # batch counter is incremented only if every reader has data
- kwargs["limiter"] = limiter
- sub_reader_data_finished_blobs = reader_builder.setup(**kwargs)
- overlapping_keys = set(data_finished_blobs.keys()) & set(sub_reader_data_finished_blobs.keys())
- overlapping_values = set(data_finished_blobs.values()) & set(sub_reader_data_finished_blobs.values())
- assert overlapping_keys == set(), "Overlapping keys: {}".format(overlapping_keys)
- assert overlapping_values == set(), "Overlapping values: {}".format(overlapping_values)
- data_finished_blobs.update(sub_reader_data_finished_blobs)
- return data_finished_blobs
- def new_reader(self, **kwargs):
- readers = []
- for reader_builder in self._reader_builders:
- reader = reader_builder.new_reader(**kwargs)
- if isinstance(reader, Reader):
- pass
- elif hasattr(reader, 'reader'):
- reader = reader.reader()
- else:
- raise ValueError('reader must be an instance of Reader or Pipe')
- readers.append(reader)
- multi_reader = CompositeReader(self._names, readers)
- assert multi_reader.schema() == self._schema
- return multi_reader
|