queue_util.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. ## @package queue_util
  2. # Module caffe2.python.queue_util
  3. from caffe2.python import core, dataio
  4. from caffe2.python.task import TaskGroup
  5. import logging
  6. logger = logging.getLogger(__name__)
  7. class _QueueReader(dataio.Reader):
  8. def __init__(self, wrapper, num_dequeue_records=1):
  9. assert wrapper.schema is not None, (
  10. 'Queue needs a schema in order to be read from.')
  11. dataio.Reader.__init__(self, wrapper.schema())
  12. self._wrapper = wrapper
  13. self._num_dequeue_records = num_dequeue_records
  14. def setup_ex(self, init_net, exit_net):
  15. exit_net.CloseBlobsQueue([self._wrapper.queue()], 0)
  16. def read_ex(self, local_init_net, local_finish_net):
  17. self._wrapper._new_reader(local_init_net)
  18. dequeue_net = core.Net('dequeue')
  19. fields, status_blob = dequeue(
  20. dequeue_net,
  21. self._wrapper.queue(),
  22. len(self.schema().field_names()),
  23. field_names=self.schema().field_names(),
  24. num_records=self._num_dequeue_records)
  25. return [dequeue_net], status_blob, fields
  26. def read(self, net):
  27. net, _, fields = self.read_ex(net, None)
  28. return net, fields
  29. class _QueueWriter(dataio.Writer):
  30. def __init__(self, wrapper):
  31. self._wrapper = wrapper
  32. def setup_ex(self, init_net, exit_net):
  33. exit_net.CloseBlobsQueue([self._wrapper.queue()], 0)
  34. def write_ex(self, fields, local_init_net, local_finish_net, status):
  35. self._wrapper._new_writer(self.schema(), local_init_net)
  36. enqueue_net = core.Net('enqueue')
  37. enqueue(enqueue_net, self._wrapper.queue(), fields, status)
  38. return [enqueue_net]
  39. class QueueWrapper(dataio.Pipe):
  40. def __init__(self, handler, schema=None, num_dequeue_records=1):
  41. dataio.Pipe.__init__(self, schema, TaskGroup.LOCAL_SETUP)
  42. self._queue = handler
  43. self._num_dequeue_records = num_dequeue_records
  44. def reader(self):
  45. return _QueueReader(
  46. self, num_dequeue_records=self._num_dequeue_records)
  47. def writer(self):
  48. return _QueueWriter(self)
  49. def queue(self):
  50. return self._queue
  51. class Queue(QueueWrapper):
  52. def __init__(self, capacity, schema=None, name='queue',
  53. num_dequeue_records=1):
  54. # find a unique blob name for the queue
  55. net = core.Net(name)
  56. queue_blob = net.AddExternalInput(net.NextName('handler'))
  57. QueueWrapper.__init__(
  58. self, queue_blob, schema, num_dequeue_records=num_dequeue_records)
  59. self.capacity = capacity
  60. self._setup_done = False
  61. def setup(self, global_init_net):
  62. assert self._schema, 'This queue does not have a schema.'
  63. self._setup_done = True
  64. global_init_net.CreateBlobsQueue(
  65. [],
  66. [self._queue],
  67. capacity=self.capacity,
  68. num_blobs=len(self._schema.field_names()),
  69. field_names=self._schema.field_names())
  70. def enqueue(net, queue, data_blobs, status=None):
  71. if status is None:
  72. status = net.NextName('status')
  73. # Enqueueing moved the data into the queue;
  74. # duplication will result in data corruption
  75. queue_blobs = []
  76. for blob in data_blobs:
  77. if blob not in queue_blobs:
  78. queue_blobs.append(blob)
  79. else:
  80. logger.warning("Need to copy blob {} to enqueue".format(blob))
  81. queue_blobs.append(net.Copy(blob))
  82. results = net.SafeEnqueueBlobs([queue] + queue_blobs, queue_blobs + [status])
  83. return results[-1]
  84. def dequeue(net, queue, num_blobs, status=None, field_names=None,
  85. num_records=1):
  86. if field_names is not None:
  87. assert len(field_names) == num_blobs
  88. data_names = [net.NextName(name) for name in field_names]
  89. else:
  90. data_names = [net.NextName('data', i) for i in range(num_blobs)]
  91. if status is None:
  92. status = net.NextName('status')
  93. results = net.SafeDequeueBlobs(
  94. queue, data_names + [status], num_records=num_records)
  95. results = list(results)
  96. status_blob = results.pop(-1)
  97. return results, status_blob
  98. def close_queue(step, *queues):
  99. close_net = core.Net("close_queue_net")
  100. for queue in queues:
  101. close_net.CloseBlobsQueue([queue], 0)
  102. close_step = core.execution_step("%s_step" % str(close_net), close_net)
  103. return core.execution_step(
  104. "%s_wraper_step" % str(close_net),
  105. [step, close_step])