| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136 |
- ## @package queue_util
- # Module caffe2.python.queue_util
- from caffe2.python import core, dataio
- from caffe2.python.task import TaskGroup
- import logging
- logger = logging.getLogger(__name__)
class _QueueReader(dataio.Reader):
    """Reader that dequeues records from the queue held by a QueueWrapper.

    The reader's schema is taken from the wrapper; each read dequeues one
    blob per schema field using the module-level `dequeue` helper.
    """

    def __init__(self, wrapper, num_dequeue_records=1):
        """
        Args:
            wrapper: QueueWrapper providing the queue blob (`wrapper.queue()`)
                and the schema of the queued fields (`wrapper.schema()`).
            num_dequeue_records: forwarded to `dequeue` as `num_records`.

        Raises:
            AssertionError: if the wrapper has no schema.
        """
        # `wrapper.schema` is a method: it must be *called* before the None
        # check. Testing the bound-method object itself is always true, which
        # silently disabled this assertion.
        schema = wrapper.schema()
        assert schema is not None, (
            'Queue needs a schema in order to be read from.')
        dataio.Reader.__init__(self, schema)
        self._wrapper = wrapper
        self._num_dequeue_records = num_dequeue_records

    def setup_ex(self, init_net, exit_net):
        """Add a CloseBlobsQueue op for the wrapped queue to `exit_net`."""
        exit_net.CloseBlobsQueue([self._wrapper.queue()], 0)

    def read_ex(self, local_init_net, local_finish_net):
        """Build a new 'dequeue' net that pops one blob per schema field.

        Calls the wrapper's `_new_reader` hook with `local_init_net`;
        `local_finish_net` is not used.

        Returns:
            ([dequeue_net], status_blob, fields), where `fields` is the list
            of dequeued data blobs in schema field order.
        """
        self._wrapper._new_reader(local_init_net)
        dequeue_net = core.Net('dequeue')
        field_names = self.schema().field_names()
        fields, status_blob = dequeue(
            dequeue_net,
            self._wrapper.queue(),
            len(field_names),
            field_names=field_names,
            num_records=self._num_dequeue_records)
        return [dequeue_net], status_blob, fields

    def read(self, net):
        """Run `read_ex` with `net` as the local init net.

        Returns:
            (nets, fields), where `nets` is the `[dequeue_net]` list produced
            by `read_ex`.
        """
        # NOTE(review): the first element returned is the list of nets from
        # read_ex, not a status blob, and the dequeue op is built in that
        # separate net rather than in `net`. Confirm this is what callers of
        # `read` expect.
        nets, _, fields = self.read_ex(net, None)
        return nets, fields
class _QueueWriter(dataio.Writer):
    """Writer that pushes field blobs into the queue owned by a QueueWrapper."""

    def __init__(self, wrapper):
        # Only a handle to the wrapper is kept; the queue blob is fetched
        # from it on demand.
        self._wrapper = wrapper

    def setup_ex(self, init_net, exit_net):
        """Append a CloseBlobsQueue op for the wrapped queue to `exit_net`."""
        queue_blob = self._wrapper.queue()
        exit_net.CloseBlobsQueue([queue_blob], 0)

    def write_ex(self, fields, local_init_net, local_finish_net, status):
        """Return a one-element list holding a new 'enqueue' net.

        Calls the wrapper's `_new_writer` hook with this writer's schema and
        `local_init_net`, then enqueues `fields` (reporting into `status`).
        `local_finish_net` is not used.
        """
        self._wrapper._new_writer(self.schema(), local_init_net)
        nets = [core.Net('enqueue')]
        enqueue(nets[0], self._wrapper.queue(), fields, status)
        return nets
class QueueWrapper(dataio.Pipe):
    """Pipe backed by an already-existing queue blob handle."""

    def __init__(self, handler, schema=None, num_dequeue_records=1):
        """
        Args:
            handler: blob holding the queue; returned by `queue()`.
            schema: forwarded to `dataio.Pipe.__init__`.
            num_dequeue_records: handed to every reader built by `reader()`.
        """
        # TaskGroup.LOCAL_SETUP is passed as the second Pipe argument.
        dataio.Pipe.__init__(self, schema, TaskGroup.LOCAL_SETUP)
        self._queue = handler
        self._num_dequeue_records = num_dequeue_records

    def reader(self):
        """Return a new _QueueReader over this queue."""
        records_per_read = self._num_dequeue_records
        return _QueueReader(self, num_dequeue_records=records_per_read)

    def writer(self):
        """Return a new _QueueWriter over this queue."""
        return _QueueWriter(self)

    def queue(self):
        """Return the queue blob handle given at construction."""
        return self._queue
class Queue(QueueWrapper):
    """QueueWrapper that names its own queue blob and creates it in `setup`."""

    def __init__(self, capacity, schema=None, name='queue',
                 num_dequeue_records=1):
        """
        Args:
            capacity: passed as `capacity` to CreateBlobsQueue in `setup`.
            schema: schema of the queued fields; required by `setup`.
            name: name of the scratch Net used to mint the queue blob name.
            num_dequeue_records: forwarded to readers of this queue.
        """
        # A scratch Net is used only to find a unique 'handler' blob name.
        scratch_net = core.Net(name)
        handle = scratch_net.AddExternalInput(scratch_net.NextName('handler'))
        QueueWrapper.__init__(
            self, handle, schema, num_dequeue_records=num_dequeue_records)
        self.capacity = capacity
        self._setup_done = False

    def setup(self, global_init_net):
        """Add a CreateBlobsQueue op for this queue to `global_init_net`.

        Raises:
            AssertionError: if the queue has no (truthy) schema.
        """
        assert self._schema, 'This queue does not have a schema.'
        self._setup_done = True
        names = self._schema.field_names()
        global_init_net.CreateBlobsQueue(
            [], [self._queue],
            capacity=self.capacity,
            num_blobs=len(names),
            field_names=names)
def enqueue(net, queue, data_blobs, status=None):
    """Add a SafeEnqueueBlobs op to `net` pushing `data_blobs` onto `queue`.

    Args:
        net: net the ops are added to.
        queue: queue blob to enqueue into.
        data_blobs: blobs to enqueue.
        status: status output blob; a fresh `net.NextName('status')` name is
            used when None.

    Returns:
        The last output of the SafeEnqueueBlobs op, i.e. the status blob.
    """
    status = net.NextName('status') if status is None else status
    # Enqueueing moves each blob's data into the queue, so the same blob must
    # not appear twice in one op; repeated entries are replaced by copies.
    unique_blobs = []
    for candidate in data_blobs:
        if candidate in unique_blobs:
            logger.warning("Need to copy blob {} to enqueue".format(candidate))
            candidate = net.Copy(candidate)
        unique_blobs.append(candidate)
    outputs = net.SafeEnqueueBlobs(
        [queue] + unique_blobs, unique_blobs + [status])
    return outputs[-1]
def dequeue(net, queue, num_blobs, status=None, field_names=None,
            num_records=1):
    """Add a SafeDequeueBlobs op to `net` popping `num_blobs` blobs from `queue`.

    Args:
        net: net the op is added to.
        queue: queue blob to dequeue from.
        num_blobs: number of data outputs.
        status: status output blob; a fresh `net.NextName('status')` name is
            used when None.
        field_names: optional output name prefixes, one per data blob; when
            absent, names come from `net.NextName('data', i)`.
        num_records: forwarded to SafeDequeueBlobs.

    Returns:
        (data_blobs, status_blob): list of the data outputs, and the final
        (status) output.

    Raises:
        AssertionError: if `field_names` is given with a length other than
            `num_blobs`.
    """
    if field_names is None:
        data_names = [net.NextName('data', i) for i in range(num_blobs)]
    else:
        assert len(field_names) == num_blobs
        data_names = [net.NextName(name) for name in field_names]
    if status is None:
        status = net.NextName('status')
    outputs = list(net.SafeDequeueBlobs(
        queue, data_names + [status], num_records=num_records))
    status_blob = outputs[-1]
    return outputs[:-1], status_blob
def close_queue(step, *queues):
    """Bundle `step` with a step that closes every queue in `queues`.

    A net named "close_queue_net" gets one CloseBlobsQueue op per queue and
    is wrapped in its own execution step. The returned execution step has
    substeps `[step, close_step]`. Presumably substeps run in list order, so
    the queues are closed once `step` finishes; confirm against
    `core.execution_step`.
    """
    closer = core.Net("close_queue_net")
    for q in queues:
        closer.CloseBlobsQueue([q], 0)
    net_name = str(closer)
    closing_step = core.execution_step("%s_step" % net_name, closer)
    return core.execution_step(
        "%s_wraper_step" % net_name, [step, closing_step])
|