| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294 |
- # @package parallel_workers
- # Module caffe2.python.parallel_workers
- '''
- This module provides a python-land multithreaded mechanism for executing work.
- Basic usage is as follows:
- coordinator = parallel_workers.init_workers(
- my_worker_fun,
- worker_name="train"
- )
- ...
- coordinator.start()
- First argument is the function to run in a loop on potentially multiple threads.
- It has the call signature
- worker_fun(worker_id)
- Argument 'worker_name' is used to distinguish different workers,
- such as workers processing train data or workers processing test data.
- Optionally, one can define an "init function" that is called once before
- threads start, and has call signature:
- my_init_fun(worker_coordinator, global_coordinator)
- Note that for data_parallel_models, init_workers will be called
- for each GPU. Note that the 'coordinator' returned by the function is same
- each time.
- '''
- import logging
- import threading
- import atexit
- import time
- import collections
- import traceback
- from abc import ABCMeta, abstractmethod
- log = logging.getLogger("parallel_workers")
- log.setLevel(logging.INFO)
- LOG_INT_SECS = 60
- def init_workers(
- worker_fun,
- num_worker_threads=2,
- worker_name="train",
- init_fun=None,
- external_loggers=None,
- shutdown_fun=None,
- ):
- global global_coordinator
- metrics = Metrics(external_loggers)
- worker_ids = [
- global_coordinator.get_new_worker_id()
- for i in range(num_worker_threads)
- ]
- # Create coordinator object
- coordinator = WorkerCoordinator(
- worker_name, worker_ids, init_fun, shutdown_fun=shutdown_fun)
- # Launch fetch worker threads
- workers = [
- threading.Thread(
- target=run_worker,
- name="parallel_workers worker id {}".format(worker_id),
- args=[coordinator,
- Worker(coordinator, worker_id, worker_fun, metrics)],
- ) for worker_id in worker_ids
- ]
- coordinator._workers = workers
- global_coordinator.add(coordinator)
- return global_coordinator
- class Metrics(object):
- def __init__(self, external_loggers):
- self._metrics = collections.defaultdict(lambda: 0)
- self._external_loggers = external_loggers
- def reset_metrics(self):
- self._metrics = collections.defaultdict(lambda: 0)
- def log_metrics(self):
- if not self._external_loggers:
- return
- for logger in self._external_loggers:
- try:
- logger.log(self._metrics)
- except Exception as e:
- print("Failed to call ExternalLogger: {}".format(e))
- def put_metric(self, key, value, count=True):
- self._metrics[key] += value
- if count:
- count_key = '{}_count'.format(key)
- self._metrics[count_key] += 1
- class State():
- __metaclass__ = ABCMeta
- @abstractmethod
- def start(self):
- pass
- @abstractmethod
- def stop(self):
- pass
- @abstractmethod
- def cleanup(self):
- pass
- class WorkerCoordinator(object):
- def __init__(
- self, worker_name, worker_ids, init_fun,
- state=None, shutdown_fun=None
- ):
- self._active = True
- self._started = False
- self._workers = []
- self._worker_name = worker_name
- self._worker_ids = worker_ids
- self._init_fun = init_fun
- self._state = state
- self._shutdown_fun = shutdown_fun
- def is_active(self):
- return self._active
- def init(self, global_coordinator):
- if self._init_fun and not self._started:
- data_coordinator = self
- self._init_fun(data_coordinator, global_coordinator)
- def _start(self):
- if self._started:
- return
- self._active = True
- self._started = True
- if self._state:
- self._state.start()
- for w in self._workers:
- w.daemon = True
- w.start()
- def _stop(self, reason=None):
- self._active = False
- if reason is not None:
- log.error("Data input failed due to an error: {}".format(reason))
- if self._shutdown_fun and self._started:
- self._shutdown_fun()
- if self._state:
- self._state.stop()
- self._started = False
- def _wait_finish(self, cleanup=None):
- print("Wait for workers to die: {}".format(self._worker_name))
- for w in self._workers:
- if w != threading.current_thread():
- w.join(5.0) # don't wait forever, thread may be blocked in i/o
- success = True
- for w in self._workers:
- if w.is_alive():
- print("Worker {} failed to close while waiting".format(w))
- success = False
- # Release memory for the scratch blobs
- if success and self._state:
- self._state.cleanup()
- print("All workers terminated: {}".format(success))
- return success
- def get_worker_ids(self):
- return self._worker_ids
- class GlobalWorkerCoordinator(object):
- def __init__(self):
- self._coordinators = []
- self._fetcher_id_seq = 0
- self._worker_ids = []
- self.register_shutdown_handler()
- def add(self, coordinator):
- self._coordinators.append(coordinator)
- def get_new_worker_id(self):
- worker_id = self._fetcher_id_seq
- self._worker_ids.append(worker_id)
- self._fetcher_id_seq += 1
- return worker_id
- def get_worker_ids(self):
- return self._worker_ids
- def start(self):
- # run init and start in separate for loop to
- # ensure init happens serially before threads are spawn.
- for c in self._coordinators:
- c.init(self)
- for c in self._coordinators:
- c._start()
- def stop(self):
- all_success = True
- for c in self._coordinators:
- c._stop()
- for c in self._coordinators:
- success = c._wait_finish()
- all_success = all_success and success
- self._coordinators = []
- return all_success
- def stop_coordinator(self, worker_name):
- '''
- Stop a specific coordinator
- '''
- for c in self._coordinators:
- if c._worker_name == worker_name:
- c._stop()
- c._wait_finish()
- self._coordinators = [
- c for c in self._coordinators
- if c._worker_name != worker_name
- ]
- def register_shutdown_handler(self):
- def cleanup():
- self.stop()
- atexit.register(cleanup)
- class Worker(object):
- def __init__(
- self,
- coordinator,
- worker_id,
- worker_fun=None,
- metrics=None
- ):
- self._coordinator = coordinator
- self._worker_id = worker_id
- self._worker_fun = worker_fun
- self._metrics = metrics
- def start(self):
- self._start_time = time.time()
- def run(self):
- self._worker_fun(self._worker_id)
- def handle_exception(self, e):
- traceback.print_exc()
- logging.exception("Exception in worker", e)
- self._coordinator._stop("Exception in worker {}: {}".format(
- self._worker_id, e
- ))
- def finish(self):
- self._metrics.put_metric(
- 'worker_time', time.time() - self._start_time)
- self._metrics.log_metrics()
- global_coordinator = GlobalWorkerCoordinator()
- def run_worker(coordinator, worker):
- while coordinator.is_active():
- worker.start()
- try:
- worker.run()
- except Exception as e:
- worker.handle_exception(e)
- finally:
- worker.finish()
|