forkserver.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. import errno
  2. import os
  3. import selectors
  4. import signal
  5. import socket
  6. import struct
  7. import sys
  8. import threading
  9. import warnings
  10. from . import connection
  11. from . import process
  12. from .context import reduction
  13. from . import semaphore_tracker
  14. from . import spawn
  15. from . import util
# Public API of this module.  Each name is a bound method of the
# module-level ForkServer singleton created near the end of the file.
__all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
           'set_forkserver_preload']

#
# Protocol constants
#

# Limit on the number of fds carried by one fork request.  The client
# checks it before sending (counting the 4 fds it always adds itself) and
# the server asks for at most MAXFDS_TO_SEND + 1 fds so that it can detect
# an oversized request.
MAXFDS_TO_SEND = 256
# Wire format used by read_signed()/write_signed() for pids and exit codes.
SIGNED_STRUCT = struct.Struct('q')     # large enough for pid_t
  23. #
  24. # Forkserver class
  25. #
  26. class ForkServer(object):
  27. def __init__(self):
  28. self._forkserver_address = None
  29. self._forkserver_alive_fd = None
  30. self._forkserver_pid = None
  31. self._inherited_fds = None
  32. self._lock = threading.Lock()
  33. self._preload_modules = ['__main__']
  34. def _stop(self):
  35. # Method used by unit tests to stop the server
  36. with self._lock:
  37. self._stop_unlocked()
  38. def _stop_unlocked(self):
  39. if self._forkserver_pid is None:
  40. return
  41. # close the "alive" file descriptor asks the server to stop
  42. os.close(self._forkserver_alive_fd)
  43. self._forkserver_alive_fd = None
  44. os.waitpid(self._forkserver_pid, 0)
  45. self._forkserver_pid = None
  46. os.unlink(self._forkserver_address)
  47. self._forkserver_address = None
  48. def set_forkserver_preload(self, modules_names):
  49. '''Set list of module names to try to load in forkserver process.'''
  50. if not all(type(mod) is str for mod in self._preload_modules):
  51. raise TypeError('module_names must be a list of strings')
  52. self._preload_modules = modules_names
  53. def get_inherited_fds(self):
  54. '''Return list of fds inherited from parent process.
  55. This returns None if the current process was not started by fork
  56. server.
  57. '''
  58. return self._inherited_fds
  59. def connect_to_new_process(self, fds):
  60. '''Request forkserver to create a child process.
  61. Returns a pair of fds (status_r, data_w). The calling process can read
  62. the child process's pid and (eventually) its returncode from status_r.
  63. The calling process should write to data_w the pickled preparation and
  64. process data.
  65. '''
  66. self.ensure_running()
  67. if len(fds) + 4 >= MAXFDS_TO_SEND:
  68. raise ValueError('too many fds')
  69. with socket.socket(socket.AF_UNIX) as client:
  70. client.connect(self._forkserver_address)
  71. parent_r, child_w = os.pipe()
  72. child_r, parent_w = os.pipe()
  73. allfds = [child_r, child_w, self._forkserver_alive_fd,
  74. semaphore_tracker.getfd()]
  75. allfds += fds
  76. try:
  77. reduction.sendfds(client, allfds)
  78. return parent_r, parent_w
  79. except:
  80. os.close(parent_r)
  81. os.close(parent_w)
  82. raise
  83. finally:
  84. os.close(child_r)
  85. os.close(child_w)
  86. def ensure_running(self):
  87. '''Make sure that a fork server is running.
  88. This can be called from any process. Note that usually a child
  89. process will just reuse the forkserver started by its parent, so
  90. ensure_running() will do nothing.
  91. '''
  92. with self._lock:
  93. semaphore_tracker.ensure_running()
  94. if self._forkserver_pid is not None:
  95. # forkserver was launched before, is it still running?
  96. pid, status = os.waitpid(self._forkserver_pid, os.WNOHANG)
  97. if not pid:
  98. # still alive
  99. return
  100. # dead, launch it again
  101. os.close(self._forkserver_alive_fd)
  102. self._forkserver_address = None
  103. self._forkserver_alive_fd = None
  104. self._forkserver_pid = None
  105. cmd = ('from multiprocessing.forkserver import main; ' +
  106. 'main(%d, %d, %r, **%r)')
  107. if self._preload_modules:
  108. desired_keys = {'main_path', 'sys_path'}
  109. data = spawn.get_preparation_data('ignore')
  110. data = {x: y for x, y in data.items() if x in desired_keys}
  111. else:
  112. data = {}
  113. with socket.socket(socket.AF_UNIX) as listener:
  114. address = connection.arbitrary_address('AF_UNIX')
  115. listener.bind(address)
  116. if not util.is_abstract_socket_namespace(address):
  117. os.chmod(address, 0o600)
  118. listener.listen()
  119. # all client processes own the write end of the "alive" pipe;
  120. # when they all terminate the read end becomes ready.
  121. alive_r, alive_w = os.pipe()
  122. try:
  123. fds_to_pass = [listener.fileno(), alive_r]
  124. cmd %= (listener.fileno(), alive_r, self._preload_modules,
  125. data)
  126. exe = spawn.get_executable()
  127. args = [exe] + util._args_from_interpreter_flags()
  128. args += ['-c', cmd]
  129. pid = util.spawnv_passfds(exe, args, fds_to_pass)
  130. except:
  131. os.close(alive_w)
  132. raise
  133. finally:
  134. os.close(alive_r)
  135. self._forkserver_address = address
  136. self._forkserver_alive_fd = alive_w
  137. self._forkserver_pid = pid
  138. #
  139. #
  140. #
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    listener_fd is the fd of an AF_UNIX socket on which fork requests are
    accepted.  alive_r is the read end of the "alive" pipe: once it becomes
    readable the server raises SystemExit.  preload is a list of module
    names to import before serving; modules that fail with ImportError are
    skipped.  main_path is passed to spawn.import_main_path() when
    '__main__' is in preload.  sys_path is accepted but not referenced in
    this function.

    The loop only ends through an exception: SystemExit on shutdown, or any
    OSError other than ECONNABORTED.
    '''
    if preload:
        if '__main__' in preload and main_path is not None:
            # Flag is set only for the duration of the main-module import.
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            try:
                __import__(modname)
            except ImportError:
                # Preloading is best effort.
                pass

    util._close_stdin()

    # Pipe used as the signal wakeup fd; both ends are non-blocking.
    sig_r, sig_w = os.pipe()
    os.set_blocking(sig_r, False)
    os.set_blocking(sig_w, False)

    def sigchld_handler(*_unused):
        # Dummy signal handler, doesn't do anything
        pass

    handlers = {
        # unblocking SIGCHLD allows the wakeup fd to notify our event loop
        signal.SIGCHLD: sigchld_handler,
        # protect the process from ^C
        signal.SIGINT: signal.SIG_IGN,
        }
    # Previous handlers are kept so that _serve_one() can reinstall them in
    # each forked child.
    old_handlers = {sig: signal.signal(sig, val)
                    for (sig, val) in handlers.items()}

    # calling os.write() in the Python signal handler is racy
    signal.set_wakeup_fd(sig_w)

    # map child pids to client fds
    pid_to_fd = {}

    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)
        selector.register(sig_r, selectors.EVENT_READ)

        while True:
            try:
                # Wait until at least one registered object is readable.
                while True:
                    rfds = [key.fileobj for (key, events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left
                    assert os.read(alive_r, 1) == b'', "Not at EOF?"
                    raise SystemExit

                if sig_r in rfds:
                    # Got SIGCHLD
                    os.read(sig_r, 65536)  # exhaust
                    while True:
                        # Scan for child processes
                        try:
                            pid, sts = os.waitpid(-1, os.WNOHANG)
                        except ChildProcessError:
                            # No child processes exist at all.
                            break
                        if pid == 0:
                            # Children exist, but none has exited yet.
                            break
                        child_w = pid_to_fd.pop(pid, None)
                        if child_w is not None:
                            # Death by signal is reported as the negated
                            # signal number.
                            if os.WIFSIGNALED(sts):
                                returncode = -os.WTERMSIG(sts)
                            else:
                                if not os.WIFEXITED(sts):
                                    raise AssertionError(
                                        "Child {0:n} status is {1:n}".format(
                                            pid,sts))
                                returncode = os.WEXITSTATUS(sts)
                            # Send exit code to client process
                            try:
                                write_signed(child_w, returncode)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            os.close(child_w)
                        else:
                            # This shouldn't happen really
                            warnings.warn('forkserver: waitpid returned '
                                          'unexpected pid %d' % pid)

                if listener in rfds:
                    # Incoming fork request
                    with listener.accept()[0] as s:
                        # Receive fds from client
                        fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
                        if len(fds) > MAXFDS_TO_SEND:
                            raise RuntimeError(
                                "Too many ({0:n}) fds to send".format(
                                    len(fds)))
                        # The first two fds are the child's pipe ends; the
                        # remainder is handed to _serve_one().
                        child_r, child_w, *fds = fds
                        # Closed before fork() so the connection is not
                        # open in the child.
                        s.close()
                        pid = os.fork()
                        if pid == 0:
                            # Child
                            code = 1
                            try:
                                listener.close()
                                selector.close()
                                unused_fds = [alive_r, child_w, sig_r, sig_w]
                                unused_fds.extend(pid_to_fd.values())
                                code = _serve_one(child_r, fds,
                                                  unused_fds,
                                                  old_handlers)
                            except Exception:
                                sys.excepthook(*sys.exc_info())
                                sys.stderr.flush()
                            finally:
                                # The child never returns into the server
                                # loop.
                                os._exit(code)
                        else:
                            # Send pid to client process
                            try:
                                write_signed(child_w, pid)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            # child_w stays open so the exit code can be
                            # written to it when the child is reaped.
                            pid_to_fd[pid] = child_w
                            os.close(child_r)
                            for fd in fds:
                                os.close(fd)

            except OSError as e:
                # ECONNABORTED is ignored and the loop continues; any other
                # OSError is re-raised.
                if e.errno != errno.ECONNABORTED:
                    raise
  264. def _serve_one(child_r, fds, unused_fds, handlers):
  265. # close unnecessary stuff and reset signal handlers
  266. signal.set_wakeup_fd(-1)
  267. for sig, val in handlers.items():
  268. signal.signal(sig, val)
  269. for fd in unused_fds:
  270. os.close(fd)
  271. (_forkserver._forkserver_alive_fd,
  272. semaphore_tracker._semaphore_tracker._fd,
  273. *_forkserver._inherited_fds) = fds
  274. # Run process object received over pipe
  275. code = spawn._main(child_r)
  276. return code
  277. #
  278. # Read and write signed numbers
  279. #
  280. def read_signed(fd):
  281. data = b''
  282. length = SIGNED_STRUCT.size
  283. while len(data) < length:
  284. s = os.read(fd, length - len(data))
  285. if not s:
  286. raise EOFError('unexpected EOF')
  287. data += s
  288. return SIGNED_STRUCT.unpack(data)[0]
  289. def write_signed(fd, n):
  290. msg = SIGNED_STRUCT.pack(n)
  291. while msg:
  292. nbytes = os.write(fd, msg)
  293. if nbytes == 0:
  294. raise RuntimeError('should not get here')
  295. msg = msg[nbytes:]
  296. #
  297. #
  298. #
# Module-level singleton.  The names listed in __all__ are its bound
# methods, so every caller in this process shares the same server state.
_forkserver = ForkServer()
ensure_running = _forkserver.ensure_running
get_inherited_fds = _forkserver.get_inherited_fds
connect_to_new_process = _forkserver.connect_to_new_process
set_forkserver_preload = _forkserver.set_forkserver_preload