data_workers.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. ## @package data_workers
  2. # Module caffe2.python.data_workers
  3. '''
  4. This module provides a python-land multithreaded data input mechanism
  5. for Caffe2 nets.
  6. Basic usage is as follows:
  7. coordinator = data_workers.init_data_input_workers(
  8. net,
  9. ["data", "label"],
  10. my_fetch_fun,
  11. batch_size=32,
  12. input_source_name="train",
  13. dont_rebatch=False
  14. )
  15. ...
  16. coordinator.start()
  17. First argument is the Caffe2 net (or model helper), and second argument
  18. is list of input blobs that are to be fed.
  19. Argument 'input_source_name' is used to distinguish different sources of data,
  20. such as train or test data. This is to ensure the data does not get mixed up,
  21. although two nets would share blobs.
  22. To do the actual data loading, one defines a "fetcher function"
  23. that has call signature
  24. my_fetch_fun(worker_id, batch_size)
  25. Optionally, one can define an "init function" that is called once before
  26. threads start, and has call signature:
  27. my_init_fun(data_coordinator, global_coordinator)
  28. If dont_rebatch is set to True, the data input is not batched into equal sized
  29. chunks but data directly provided by fetchers is used.
  30. 'batch_columns' can be used to specify which dimension is the batch dimension,
  31. for each of the inputs. Default is 0 for all inputs.
  32. 'timeout' is the timeout in seconds after which if no data is available, the
  33. net will fail (default 600s = 10 mins).
  34. The fetcher function returns a list of numpy arrays corresponding to the different
  35. input blobs. In the example above, it would return two arrays, one for the
  36. data blob and another for the labels. These arrays can have arbitrary number
  37. of elements (i.e they do not need to match the batch size). The batch size
  38. is provided for the function as a hint only.
  39. For example, fetcher function could download images from a remote service or
  40. load random images from a directory on a file system.
  41. For a dummy example, see the data_workers_test unit test.
  42. Note that for data_parallel_models, init_data_input_workers will be called
  43. for each GPU. Note that the 'coordinator' returned by the function is the same
  44. each time.
  45. '''
  46. import queue as Queue
  47. from itertools import chain
  48. import logging
  49. import threading
  50. import numpy as np
  51. import time
  52. from caffe2.python import workspace, core, scope, utils
  53. from caffe2.proto import caffe2_pb2
  54. from caffe2.python.parallel_workers import Metrics, State, \
  55. WorkerCoordinator, GlobalWorkerCoordinator, Worker, run_worker
# Module-level logger used by every helper in this file.
log = logging.getLogger("data_workers")
# Log at INFO and above so the periodic throughput lines are emitted.
log.setLevel(logging.INFO)
# Interval in seconds between periodic throughput log lines; also the minimum
# spacing between "data loading lagging behind" warnings in BatchFeeder.put.
LOG_INT_SECS = 60
  59. def get_worker_ids(num_workers):
  60. return list(range(0, num_workers))
  61. def init_data_input_workers(
  62. net,
  63. input_blob_names,
  64. fetch_fun,
  65. batch_size,
  66. num_worker_threads=2,
  67. input_source_name="train",
  68. max_buffered_batches=800,
  69. init_fun=None,
  70. external_loggers=None,
  71. dont_rebatch=False,
  72. batch_columns=None,
  73. timeout=600
  74. ):
  75. global global_coordinator
  76. device_option = scope.CurrentDeviceScope()
  77. if (device_option is None):
  78. device_option = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU)
  79. metrics = Metrics(external_loggers)
  80. batch_feeder = BatchFeeder(
  81. net,
  82. input_blob_names,
  83. batch_size,
  84. device_option,
  85. scope.CurrentNameScope(),
  86. input_source_name,
  87. global_coordinator.get_queue(input_source_name, max_buffered_batches),
  88. metrics,
  89. dont_rebatch,
  90. batch_columns,
  91. timeout=timeout
  92. )
  93. # Launch fetch worker threads
  94. worker_ids = [
  95. global_coordinator.get_new_worker_id()
  96. for i in range(num_worker_threads)
  97. ]
  98. # Create coordinator object
  99. coordinator = WorkerCoordinator(
  100. input_source_name, worker_ids, init_fun, batch_feeder)
  101. workers = [
  102. threading.Thread(
  103. target=run_worker,
  104. name="data_workers fetcher id {}".format(worker_id),
  105. args=[coordinator,
  106. DataWorker(coordinator, worker_id, fetch_fun, metrics,
  107. batch_size, batch_feeder)],
  108. ) for worker_id in worker_ids
  109. ]
  110. workers.append(threading.Thread(
  111. target=enqueuer,
  112. name="Enqueuer {} {}".format(input_source_name, scope.CurrentNameScope()),
  113. args=[coordinator, batch_feeder]))
  114. coordinator._workers = workers
  115. global_coordinator.add(coordinator)
  116. return global_coordinator
  117. class BatchFeeder(State):
  118. def __init__(self, net, input_blob_names, batch_size,
  119. device_option, namescope, input_source_name, queue,
  120. metrics, dont_rebatch, batch_columns, timeout=600):
  121. self._counter = 0
  122. self._input_blob_names = input_blob_names
  123. self._batch_size = batch_size
  124. self._internal_queue = queue
  125. self._queues = []
  126. self._device_option = device_option
  127. self._namescope = namescope
  128. self._timeout = timeout
  129. self._input_source_name = input_source_name
  130. self._c2_queue_capacity = 4
  131. self._create_caffe2_queues(net)
  132. self._create_caffe2_ops(net)
  133. self._inputs = 0
  134. self._prev_seconds = 0
  135. self._last_warning = time.time()
  136. self._dont_rebatch = dont_rebatch
  137. self._init_scratch()
  138. self._metrics = metrics
  139. if batch_columns is None:
  140. batch_columns = [0 for _ in input_blob_names]
  141. self._batch_columns = batch_columns
  142. def start(self):
  143. self._inputs = 0
  144. self._prev_seconds = time.time()
  145. def stop(self):
  146. try:
  147. for q in self._queues:
  148. workspace.RunOperatorOnce(
  149. core.CreateOperator("CloseBlobsQueue", [q], [])
  150. )
  151. finally:
  152. self._log_inputs_per_interval(0, force=True)
  153. def cleanup(self):
  154. utils.ResetBlobs(self._scratch_blob.values())
  155. utils.ResetBlobs(self._scratch_status.values())
  156. def _get(self, data_input_coordinator):
  157. start_time = time.time()
  158. last_warning = time.time()
  159. while data_input_coordinator.is_active():
  160. try:
  161. return self._internal_queue.get(block=True, timeout=0.5)
  162. except Queue.Empty:
  163. if time.time() - last_warning > 10.0:
  164. log.warning("** Data input is slow: (still) no data in {} secs.".format(
  165. time.time() - start_time))
  166. last_warning = time.time()
  167. continue
  168. return None
  169. def _validate_chunk(self, chunk):
  170. if chunk is None:
  171. log.warning("Fetcher function returned None")
  172. return False
  173. assert len(chunk) == len(self._input_blob_names), \
  174. "Expecting data blob for each input"
  175. for d in chunk:
  176. assert isinstance(d, np.ndarray), \
  177. "Fetcher function must return a numpy array"
  178. if not self._dont_rebatch:
  179. j = 1
  180. for d in chunk[1:]:
  181. assert d.shape[self._batch_columns[j]] == \
  182. chunk[0].shape[self._batch_columns[0]], \
  183. "Each returned input must have equal number of samples"
  184. j += 1
  185. if len(chunk) == 0:
  186. log.warning("Worker provided zero length input")
  187. return False
  188. return True
  189. def put(self, chunk, data_input_coordinator):
  190. if not self._validate_chunk(chunk):
  191. return
  192. while data_input_coordinator.is_active():
  193. try:
  194. qsize = self._internal_queue.qsize()
  195. if qsize < 2 and (time.time() - self._last_warning) > LOG_INT_SECS:
  196. log.warning("Warning, data loading lagging behind: " +
  197. "queue size={}, name={}".format(qsize, self._input_source_name))
  198. self._last_warning = time.time()
  199. self._counter += 1
  200. self._internal_queue.put(chunk, block=True, timeout=0.5)
  201. self._log_inputs_per_interval(chunk[0].shape[0])
  202. return
  203. except Queue.Full:
  204. log.debug("Queue full: stalling fetchers...")
  205. continue
  206. def _enqueue_batch_direct(self, data_input_coordinator):
  207. data = self._get(data_input_coordinator)
  208. if data is None:
  209. return
  210. if data_input_coordinator.is_active():
  211. for b, q, c in zip(self._input_blob_names, self._queues, data):
  212. self._enqueue(b, q, c)
  213. def _enqueue_batch(self, data_input_coordinator):
  214. '''
  215. This pulls data from the python-side queue and collects them
  216. into batch-sized pieces, unless dont_rebatch is set to true.
  217. '''
  218. if self._dont_rebatch:
  219. self._enqueue_batch_direct(data_input_coordinator)
  220. return
  221. cur_batch = [np.array([]) for d in self._input_blob_names]
  222. first_batch_col = self._batch_columns[0]
  223. # Collect data until we have a full batch size
  224. while (
  225. cur_batch[0].shape[0] == 0 or
  226. cur_batch[0].shape[first_batch_col] < self._batch_size
  227. ) and data_input_coordinator.is_active():
  228. chunk = self._get(data_input_coordinator)
  229. if chunk is None:
  230. continue
  231. for j, chunk_elem in enumerate(chunk):
  232. if cur_batch[j].shape[0] == 0:
  233. cur_batch[j] = chunk_elem.copy()
  234. else:
  235. cur_batch[j] = np.append(
  236. cur_batch[j], chunk_elem, axis=self._batch_columns[j]
  237. )
  238. start_time = time.time()
  239. try:
  240. # Return data over the batch size back to queue
  241. if cur_batch[0].shape[0] > 0 and cur_batch[0].shape[
  242. first_batch_col
  243. ] > self._batch_size:
  244. leftover = []
  245. trimmed_batch = []
  246. for j, b in enumerate(cur_batch):
  247. [c, l] = np.split(
  248. b, [self._batch_size], axis=self._batch_columns[j]
  249. )
  250. leftover.append(l)
  251. trimmed_batch.append(c)
  252. cur_batch = trimmed_batch
  253. try:
  254. self._internal_queue.put(leftover, block=False)
  255. except Queue.Full:
  256. pass
  257. assert cur_batch[0].shape[first_batch_col] == self._batch_size
  258. if data_input_coordinator.is_active():
  259. for b, q, c in zip(
  260. self._input_blob_names, self._queues, cur_batch
  261. ):
  262. self._enqueue(b, q, c)
  263. finally:
  264. self._metrics.put_metric('enqueue_time', time.time() - start_time)
  265. def _init_scratch(self):
  266. self._scratch_blob = {}
  267. self._scratch_status = {}
  268. for blob_name in self._input_blob_names:
  269. scratch_name = self._namescope + blob_name + \
  270. "_scratch_" + self._input_source_name
  271. self._scratch_blob[blob_name] = core.BlobReference(scratch_name)
  272. self._scratch_status[blob_name] = core.BlobReference(
  273. scratch_name + "_status"
  274. )
  275. # Feed empty arrays to the scratch blobs here, so that there won't be
  276. # race conditions when calling FeedBlob (which calls wworkspace
  277. # CreateBlob()) from enqueue threads
  278. for b in chain(
  279. self._scratch_blob.values(), self._scratch_status.values()
  280. ):
  281. workspace.FeedBlob(
  282. b,
  283. np.array([]).astype(np.float32),
  284. device_option=self._device_option,
  285. )
  286. def _enqueue(self, blob_name, queue, data_arr):
  287. '''
  288. Enqueue the correctly sized batch arrays to Caffe2's queue.
  289. '''
  290. workspace.FeedBlob(
  291. self._scratch_blob[blob_name],
  292. data_arr,
  293. device_option=self._device_option
  294. )
  295. op = core.CreateOperator(
  296. "SafeEnqueueBlobs",
  297. [queue, self._scratch_blob[blob_name]],
  298. [self._scratch_blob[blob_name], self._scratch_status[blob_name]],
  299. device_option=self._device_option
  300. )
  301. workspace.RunOperatorOnce(op)
  302. def _create_caffe2_queues(self, net):
  303. '''
  304. Creates queues on caffe2 side
  305. '''
  306. def create_queue(queue_name, num_blobs, capacity):
  307. workspace.RunOperatorOnce(
  308. core.CreateOperator(
  309. "CreateBlobsQueue",
  310. [], [queue_name],
  311. num_blobs=1,
  312. capacity=capacity))
  313. return core.ScopedBlobReference(queue_name)
  314. for blob_name in self._input_blob_names:
  315. qname = blob_name + "_c2queue" + "_" + self._input_source_name
  316. q = create_queue(
  317. qname, num_blobs=1, capacity=self._c2_queue_capacity
  318. )
  319. self._queues.append(q)
  320. def _create_caffe2_ops(self, net):
  321. '''
  322. Creates dequeue-ops on caffe2 side
  323. '''
  324. for q, blob_name in zip(self._queues, self._input_blob_names):
  325. # Add operator to the Caffe2 network to dequeue
  326. net.DequeueBlobs(q, blob_name, timeout_secs=float(self._timeout))
  327. def _log_inputs_per_interval(self, inputs, force=False):
  328. self._inputs += inputs
  329. current_seconds = time.time()
  330. delta_seconds = current_seconds - self._prev_seconds
  331. if delta_seconds >= LOG_INT_SECS or force:
  332. inputs_per_sec = int(self._inputs / delta_seconds)
  333. qsize = self._internal_queue.qsize()
  334. log.info("{}/{}: {} inputs/sec".format(
  335. self._input_source_name,
  336. self._namescope,
  337. inputs_per_sec,
  338. ))
  339. log.info("-- queue: {} batches".format(qsize))
  340. # log and reset perf metrics
  341. self._metrics.put_metric(
  342. 'inputs_per_sec', inputs_per_sec, False)
  343. self._metrics.put_metric('queue_size', qsize, False)
  344. self._metrics.put_metric(
  345. 'time_elapsed', delta_seconds, False)
  346. self._metrics.log_metrics()
  347. self._metrics.reset_metrics()
  348. self._inputs = 0
  349. self._prev_seconds = current_seconds
  350. class GlobalCoordinator(GlobalWorkerCoordinator):
  351. def __init__(self):
  352. GlobalWorkerCoordinator.__init__(self)
  353. self._queues = {}
  354. def get_queue(self, queue_name, max_buffered_batches):
  355. assert isinstance(max_buffered_batches, int)
  356. if queue_name not in self._queues:
  357. self._queues[queue_name] = Queue.Queue(maxsize=max_buffered_batches)
  358. return self._queues[queue_name]
  359. def reset_data_input(self, namescope, name, net, batch_size):
  360. log.info("Reset data input {}, batch size {}: ".format(name, batch_size))
  361. for c in self._coordinators:
  362. if c._worker_name == name and c._state._namescope == namescope:
  363. c._state._batch_size = batch_size
  364. c._state._create_caffe2_ops(net)
  365. class DataWorker(Worker):
  366. def __init__(
  367. self,
  368. coordinator,
  369. worker_id,
  370. worker_fun,
  371. metrics,
  372. batch_size,
  373. batch_feeder
  374. ):
  375. Worker.__init__(self, coordinator, worker_id, worker_fun=worker_fun,
  376. metrics=metrics)
  377. self._batch_size = batch_size
  378. self._batch_feeder = batch_feeder
  379. def run(self):
  380. input_data = self._worker_fun(self._worker_id, self._batch_size)
  381. self._batch_feeder.put(input_data, self._coordinator)
  382. def finish(self):
  383. self._metrics.put_metric(
  384. 'fetcher_time', time.time() - self._start_time)
# Module-level GlobalCoordinator instance. init_data_input_workers() takes its
# queues and worker ids from it, registers each new coordinator with it, and
# returns it to the caller.
global_coordinator = GlobalCoordinator()
  386. def enqueuer(coordinator, batch_feeder):
  387. while coordinator.is_active():
  388. batch_feeder._enqueue_batch(coordinator)