# parallel_workers.py (7.5 KB)
# @package parallel_workers
# Module caffe2.python.parallel_workers
'''
This module provides a Python-land multithreaded mechanism for executing work.

Basic usage is as follows:

    coordinator = parallel_workers.init_workers(
        my_worker_fun,
        worker_name="train"
    )
    ...
    coordinator.start()

The first argument is the function to run in a loop on potentially multiple
threads. It has the call signature

    worker_fun(worker_id)

The argument 'worker_name' is used to distinguish different workers,
such as workers processing train data or workers processing test data.

Optionally, one can define an "init function" that is called once before
the threads start, and has the call signature:

    my_init_fun(worker_coordinator, global_coordinator)

Note that for data_parallel_models, init_workers will be called
for each GPU. Note that the 'coordinator' returned by the function is the
same each time.
'''
  24. import logging
  25. import threading
  26. import atexit
  27. import time
  28. import collections
  29. import traceback
  30. from abc import ABCMeta, abstractmethod
  31. log = logging.getLogger("parallel_workers")
  32. log.setLevel(logging.INFO)
  33. LOG_INT_SECS = 60
  34. def init_workers(
  35. worker_fun,
  36. num_worker_threads=2,
  37. worker_name="train",
  38. init_fun=None,
  39. external_loggers=None,
  40. shutdown_fun=None,
  41. ):
  42. global global_coordinator
  43. metrics = Metrics(external_loggers)
  44. worker_ids = [
  45. global_coordinator.get_new_worker_id()
  46. for i in range(num_worker_threads)
  47. ]
  48. # Create coordinator object
  49. coordinator = WorkerCoordinator(
  50. worker_name, worker_ids, init_fun, shutdown_fun=shutdown_fun)
  51. # Launch fetch worker threads
  52. workers = [
  53. threading.Thread(
  54. target=run_worker,
  55. name="parallel_workers worker id {}".format(worker_id),
  56. args=[coordinator,
  57. Worker(coordinator, worker_id, worker_fun, metrics)],
  58. ) for worker_id in worker_ids
  59. ]
  60. coordinator._workers = workers
  61. global_coordinator.add(coordinator)
  62. return global_coordinator
  63. class Metrics(object):
  64. def __init__(self, external_loggers):
  65. self._metrics = collections.defaultdict(lambda: 0)
  66. self._external_loggers = external_loggers
  67. def reset_metrics(self):
  68. self._metrics = collections.defaultdict(lambda: 0)
  69. def log_metrics(self):
  70. if not self._external_loggers:
  71. return
  72. for logger in self._external_loggers:
  73. try:
  74. logger.log(self._metrics)
  75. except Exception as e:
  76. print("Failed to call ExternalLogger: {}".format(e))
  77. def put_metric(self, key, value, count=True):
  78. self._metrics[key] += value
  79. if count:
  80. count_key = '{}_count'.format(key)
  81. self._metrics[count_key] += 1
  82. class State():
  83. __metaclass__ = ABCMeta
  84. @abstractmethod
  85. def start(self):
  86. pass
  87. @abstractmethod
  88. def stop(self):
  89. pass
  90. @abstractmethod
  91. def cleanup(self):
  92. pass
  93. class WorkerCoordinator(object):
  94. def __init__(
  95. self, worker_name, worker_ids, init_fun,
  96. state=None, shutdown_fun=None
  97. ):
  98. self._active = True
  99. self._started = False
  100. self._workers = []
  101. self._worker_name = worker_name
  102. self._worker_ids = worker_ids
  103. self._init_fun = init_fun
  104. self._state = state
  105. self._shutdown_fun = shutdown_fun
  106. def is_active(self):
  107. return self._active
  108. def init(self, global_coordinator):
  109. if self._init_fun and not self._started:
  110. data_coordinator = self
  111. self._init_fun(data_coordinator, global_coordinator)
  112. def _start(self):
  113. if self._started:
  114. return
  115. self._active = True
  116. self._started = True
  117. if self._state:
  118. self._state.start()
  119. for w in self._workers:
  120. w.daemon = True
  121. w.start()
  122. def _stop(self, reason=None):
  123. self._active = False
  124. if reason is not None:
  125. log.error("Data input failed due to an error: {}".format(reason))
  126. if self._shutdown_fun and self._started:
  127. self._shutdown_fun()
  128. if self._state:
  129. self._state.stop()
  130. self._started = False
  131. def _wait_finish(self, cleanup=None):
  132. print("Wait for workers to die: {}".format(self._worker_name))
  133. for w in self._workers:
  134. if w != threading.current_thread():
  135. w.join(5.0) # don't wait forever, thread may be blocked in i/o
  136. success = True
  137. for w in self._workers:
  138. if w.is_alive():
  139. print("Worker {} failed to close while waiting".format(w))
  140. success = False
  141. # Release memory for the scratch blobs
  142. if success and self._state:
  143. self._state.cleanup()
  144. print("All workers terminated: {}".format(success))
  145. return success
  146. def get_worker_ids(self):
  147. return self._worker_ids
  148. class GlobalWorkerCoordinator(object):
  149. def __init__(self):
  150. self._coordinators = []
  151. self._fetcher_id_seq = 0
  152. self._worker_ids = []
  153. self.register_shutdown_handler()
  154. def add(self, coordinator):
  155. self._coordinators.append(coordinator)
  156. def get_new_worker_id(self):
  157. worker_id = self._fetcher_id_seq
  158. self._worker_ids.append(worker_id)
  159. self._fetcher_id_seq += 1
  160. return worker_id
  161. def get_worker_ids(self):
  162. return self._worker_ids
  163. def start(self):
  164. # run init and start in separate for loop to
  165. # ensure init happens serially before threads are spawn.
  166. for c in self._coordinators:
  167. c.init(self)
  168. for c in self._coordinators:
  169. c._start()
  170. def stop(self):
  171. all_success = True
  172. for c in self._coordinators:
  173. c._stop()
  174. for c in self._coordinators:
  175. success = c._wait_finish()
  176. all_success = all_success and success
  177. self._coordinators = []
  178. return all_success
  179. def stop_coordinator(self, worker_name):
  180. '''
  181. Stop a specific coordinator
  182. '''
  183. for c in self._coordinators:
  184. if c._worker_name == worker_name:
  185. c._stop()
  186. c._wait_finish()
  187. self._coordinators = [
  188. c for c in self._coordinators
  189. if c._worker_name != worker_name
  190. ]
  191. def register_shutdown_handler(self):
  192. def cleanup():
  193. self.stop()
  194. atexit.register(cleanup)
  195. class Worker(object):
  196. def __init__(
  197. self,
  198. coordinator,
  199. worker_id,
  200. worker_fun=None,
  201. metrics=None
  202. ):
  203. self._coordinator = coordinator
  204. self._worker_id = worker_id
  205. self._worker_fun = worker_fun
  206. self._metrics = metrics
  207. def start(self):
  208. self._start_time = time.time()
  209. def run(self):
  210. self._worker_fun(self._worker_id)
  211. def handle_exception(self, e):
  212. traceback.print_exc()
  213. logging.exception("Exception in worker", e)
  214. self._coordinator._stop("Exception in worker {}: {}".format(
  215. self._worker_id, e
  216. ))
  217. def finish(self):
  218. self._metrics.put_metric(
  219. 'worker_time', time.time() - self._start_time)
  220. self._metrics.log_metrics()
# Process-wide coordinator shared by every init_workers() call. Creating it
# registers an atexit hook that stops all registered coordinators.
global_coordinator = GlobalWorkerCoordinator()
  222. def run_worker(coordinator, worker):
  223. while coordinator.is_active():
  224. worker.start()
  225. try:
  226. worker.run()
  227. except Exception as e:
  228. worker.handle_exception(e)
  229. finally:
  230. worker.finish()