pipeline.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. ## @package pipeline
  2. # Module caffe2.python.pipeline
  3. from caffe2.python import core, queue_util
  4. from caffe2.python.dataio import Reader, Writer
  5. from caffe2.python.net_builder import NetBuilder, ops
  6. from caffe2.python.schema import as_record, Field
  7. from caffe2.python.task import Node, Task, TaskGroup
  8. class Output(object):
  9. """
  10. Represents the result of a processor function. A processor can either
  11. return an Output, or it can return a record, in which case an Output will be
  12. created for it afterwards.
  13. """
  14. def __init__(self, nets=None, record=None, should_stop=None):
  15. builder_children = NetBuilder.current().get()
  16. assert nets is None or len(builder_children) == 0, (
  17. 'Cannot both use `ops` syntax and return a list of nets.')
  18. if nets is None:
  19. nets = builder_children
  20. if isinstance(nets, core.Net):
  21. nets = [nets]
  22. self.nets = [] if nets is None else list(nets)
  23. self.record = None if record is None else as_record(record)
  24. self.should_stop = should_stop
  25. DEFAULT_QUEUE_CAPACITY = 100
  26. def _init_output(output, capacity, global_init_net, global_exit_net):
  27. if output is None:
  28. out_queue = queue_util.Queue(
  29. capacity=(
  30. capacity if capacity is not None
  31. else DEFAULT_QUEUE_CAPACITY))
  32. writer = out_queue.writer()
  33. elif isinstance(output, Writer):
  34. assert capacity is None, 'capacity would not be used.'
  35. out_queue = None
  36. writer = output
  37. elif hasattr(output, 'writer'):
  38. assert capacity is None, 'capacity would not be used.'
  39. out_queue = output
  40. writer = output.writer()
  41. else:
  42. raise ValueError('output must be a reader, queue or stream.')
  43. writer.setup_ex(global_init_net, global_exit_net)
  44. return out_queue, writer
  45. def make_processor(processor, reader=None):
  46. if processor is None:
  47. return lambda rec: rec
  48. elif isinstance(processor, core.Net):
  49. return NetProcessor(processor)
  50. else:
  51. if reader is not None and hasattr(processor, "schema_func"):
  52. def processor_schema():
  53. return processor.schema_func(reader)
  54. processor.schema = processor_schema
  55. return processor
  56. def normalize_processor_output(output):
  57. """
  58. Allow for processors to return results in several formats.
  59. TODO(azzolini): simplify once all processors use NetBuilder API.
  60. """
  61. if isinstance(output, Output):
  62. """ Processor returned an Output. """
  63. return output
  64. elif isinstance(output, Field):
  65. """ Processor returned a record. """
  66. return Output(record=output)
  67. elif isinstance(output, tuple):
  68. is_record_and_blob = (
  69. len(output) == 2 and
  70. isinstance(output[0], Field) and
  71. isinstance(output[1], core.BlobReference))
  72. if is_record_and_blob:
  73. """ Processor returned (record, stop_blob) """
  74. return Output(None, *output)
  75. else:
  76. """ Processor returned (nets, record, stop_blob) """
  77. return Output(*output)
  78. else:
  79. """ Processor returned nets, no output """
  80. return Output(output)
  81. def pipe(
  82. input, output=None, num_threads=1, processor=None, name=None,
  83. capacity=None, group=None, num_runtime_threads=1):
  84. """
  85. Given a Reader, Queue or DataStream in `input`, and optionally, a Writer,
  86. Queue or DataStream in `output`, creates a Task that, when run, will
  87. pipe the input into the output, using multiple parallel threads.
  88. Additionally, if a processor is given, it will be called between reading
  89. and writing steps, allowing it to transform the record.
  90. Args:
  91. input: either a Reader, Queue or DataStream that will be read
  92. until a stop is signaled either by the reader or the
  93. writer.
  94. output: either a Writer, a Queue or a DataStream that will be
  95. written to as long as neither reader nor writer signal
  96. a stop condition. If output is not provided or is None,
  97. a Queue is created with given `capacity` and written to.
  98. num_threads: number of concurrent threads used for processing and
  99. piping. If set to 0, no Task is created, and a
  100. reader is returned instead -- the reader returned will
  101. read from the reader passed in and process it.
  102. ** DEPRECATED **. Use `num_runtime_threads` instead.
  103. This option will be removed once all readers/processors
  104. support `num_runtime_threads`.
  105. processor: (optional) function that takes an input record and
  106. optionally returns a record; this will be called
  107. between read and write steps. If the processor does
  108. not return a record, a writer will not be instantiated.
  109. Processor can also be a core.Net with input and output
  110. records properly set. In that case, a NetProcessor is
  111. instantiated, cloning the net for each of the threads.
  112. name: (optional) name of the task to be created.
  113. capacity: when output is not passed, a queue of given `capacity`
  114. is created and written to.
  115. group: (optional) explicitly add the created Task to this
  116. TaskGroup, instead of using the currently active one.
  117. num_runtime_threads: Similar to `num_threads`, but instead of expanding
  118. the tasks with a `for` loop in python, does that at
  119. runtime. This is preferable to `num_threads`, but some
  120. processors/readers still require to be called multiple
  121. times in python.
  122. Returns:
  123. Output Queue, DataStream, Reader, or None, depending on the parameters
  124. passed.
  125. """
  126. result, _ = _pipe_step(
  127. input, output, num_threads, processor, name, capacity, group,
  128. num_runtime_threads)
  129. return result
  130. def pipe_and_output(
  131. input, output=None, num_threads=1, processor=None, name=None,
  132. capacity=None, group=None, num_runtime_threads=1, final_outputs=None):
  133. """
  134. Similar to `pipe`, with the additional ability for the pipe Task to
  135. return output values to the `Session` once done.
  136. Returns:
  137. Tuple (out_queue, *task_outputs)
  138. out_queue: same as return value of `pipe`.
  139. task_outputs: TaskOutput object, fetchable from the client after
  140. session.run() returns.
  141. """
  142. assert num_threads > 0
  143. result, task = _pipe_step(
  144. input, output, num_threads, processor, name, capacity, group,
  145. num_runtime_threads, final_outputs)
  146. output = None
  147. if final_outputs is not None:
  148. output = task.outputs()
  149. if type(final_outputs) not in (list, tuple):
  150. output = output[0]
  151. return result, output
  152. def processor_name(processor):
  153. if hasattr(processor, 'name'):
  154. return processor.name
  155. if hasattr(processor, 'func_name'):
  156. if processor.func_name == '<lambda>':
  157. return processor.__module__
  158. if hasattr(processor, 'im_class'):
  159. return '%s.%s' % (processor.im_class.__name__, processor.func_name)
  160. return processor.func_name
  161. return processor.__class__.__name__
  162. def _runtime_threads_task(name, group, final_outputs, reader, num_threads,
  163. output, capacity):
  164. node_name = str(Node.current())
  165. profiler_name = "{0}/{1}/{2}/{3}/{4}".format(
  166. node_name,
  167. "pipe",
  168. name,
  169. processor_name(input) if input else "NoInput",
  170. processor_name(output) if output else "NoOutput")
  171. with Task(name=name, group=group, outputs=final_outputs,
  172. num_instances=num_threads) as task:
  173. global_exit_net = core.Net('pipe:exit')
  174. global_init_net = core.Net('pipe:init')
  175. reader.setup_ex(global_init_net, global_exit_net)
  176. init_net = core.Net('pipe:instance:init')
  177. exit_net = core.Net('pipe:instance:exit')
  178. read_nets, status, rec = reader.read_record_ex(init_net, exit_net)
  179. init_net.ConstantFill(
  180. [], [status],
  181. shape=[],
  182. value=False,
  183. dtype=core.DataType.BOOL
  184. )
  185. if rec is not None:
  186. out_queue, writer = _init_output(
  187. output, capacity, global_init_net, global_exit_net)
  188. write_nets, _ = writer.write_record_ex(
  189. rec, init_net, exit_net, status)
  190. else:
  191. out_queue = None
  192. write_nets = []
  193. with ops.task_init():
  194. ops.net(global_init_net)
  195. with ops.task_instance_init():
  196. ops.net(init_net)
  197. timer_start_net = core.Net('timer_start')
  198. timer = timer_start_net.TimerBegin([], counter_name=profiler_name)
  199. timer_end_net = core.Net('timer_end')
  200. timer_end_net.TimerEnd(timer, [])
  201. ops.net(core.execution_step(
  202. 'body',
  203. [timer_start_net] + list(read_nets) + list(write_nets) +
  204. [timer_end_net],
  205. should_stop_blob=status))
  206. ops.net(timer_end_net)
  207. with ops.task_instance_exit():
  208. ops.net(exit_net)
  209. with ops.task_exit():
  210. ops.net(global_exit_net)
  211. return out_queue, task
  212. def _static_threads_task(name, group, final_outputs, reader, num_threads,
  213. output, capacity):
  214. node_name = str(Node.current())
  215. profiler_name = "{0}/{1}/{2}/{3}/{4}".format(
  216. node_name,
  217. "pipe",
  218. name,
  219. processor_name(input) if input else "NoInput",
  220. processor_name(output) if output else "NoOutput")
  221. with Task(name=name, group=group, outputs=final_outputs) as task:
  222. global_exit_net = core.Net('exit')
  223. global_init_net = core.Net('init')
  224. reader.setup_ex(global_init_net, global_exit_net)
  225. out_queue = None
  226. writer = None
  227. steps = []
  228. for thread_id in range(num_threads):
  229. with NetBuilder(name='t:%d' % thread_id) as nb:
  230. init_net = core.Net('init')
  231. exit_net = core.Net('exit')
  232. read_nets, status, rec = reader.read_record_ex(
  233. init_net, exit_net)
  234. init_net.ConstantFill(
  235. [], [status],
  236. shape=[],
  237. value=False,
  238. dtype=core.DataType.BOOL
  239. )
  240. if rec is not None:
  241. if writer is None:
  242. # hack so that the out queue gets the right name prefix
  243. # (otherwise they would be prefixed with the thread id)
  244. with NetBuilder(_fullname=task.name):
  245. out_queue, writer = _init_output(
  246. output, capacity, global_init_net,
  247. global_exit_net)
  248. write_nets, _ = writer.write_record_ex(
  249. rec, init_net, exit_net, status)
  250. else:
  251. write_nets = []
  252. timer_start_net = core.Net('timer_start')
  253. timer = timer_start_net.TimerBegin([], counter_name=profiler_name)
  254. timer_end_net = core.Net('timer_end')
  255. timer_end_net.TimerEnd(timer, [])
  256. ops.net(init_net)
  257. ops.net(core.execution_step(
  258. 'body',
  259. [timer_start_net] + list(read_nets) + list(write_nets) +
  260. [timer_end_net],
  261. should_stop_blob=status))
  262. ops.net(timer_end_net)
  263. ops.net(exit_net)
  264. steps.append(core.to_execution_step(nb))
  265. ops.net(global_init_net)
  266. ops.net(core.execution_step('body', steps, concurrent_substeps=True))
  267. ops.net(global_exit_net)
  268. return out_queue, task
  269. def _pipe_step(
  270. input, output=None, num_threads=1, processor=None, name=None,
  271. capacity=None, group=None, num_runtime_threads=None, final_outputs=None):
  272. """
  273. """
  274. assert num_threads <= 1 or num_runtime_threads <= 1, (
  275. 'Only one of num_threads or num_runtime_threads must be set.')
  276. if isinstance(input, Reader):
  277. reader = input
  278. elif hasattr(input, 'reader'):
  279. reader = input.reader()
  280. else:
  281. raise ValueError(
  282. 'Input must be a reader, queue or stream. Got {}'.format(type(input)))
  283. if processor is not None:
  284. reader = ProcessingReader(reader, processor)
  285. if num_threads == 0 or num_runtime_threads == 0:
  286. assert output is None
  287. return reader, None
  288. if name is None and processor is not None:
  289. name = processor_name(processor)
  290. if name is None and output is not None:
  291. name = 'pipe_into:%s' % processor_name(output)
  292. if name is None:
  293. name = 'pipe_from:%s' % processor_name(input)
  294. if num_threads > 1:
  295. return _static_threads_task(
  296. name, group, final_outputs, reader, num_threads, output, capacity)
  297. else:
  298. return _runtime_threads_task(
  299. name, group, final_outputs, reader, num_runtime_threads, output,
  300. capacity)
  301. class ProcessingReader(Reader):
  302. """
  303. Reader that reads from an upstream reader, calls the processor, and returns
  304. the processed record.
  305. """
  306. def __init__(self, reader, processor):
  307. Reader.__init__(self)
  308. self.reader = reader
  309. self.processor = make_processor(processor, reader)
  310. def schema(self):
  311. return self.processor.schema()
  312. def setup_ex(self, init_net, finish_net):
  313. self.reader.setup_ex(init_net, finish_net)
  314. def read_ex(self, init_net, exit_net):
  315. read_nets, status, rec = self.reader.read_record_ex(init_net, exit_net)
  316. # We don't use status as stop_blob of NetBuilder it's not guarantee that
  317. # it would end up being the true stob_blob. For example,
  318. # ReaderWithLimitBase doesn't pass the status through but rather copy
  319. # from it.
  320. with NetBuilder() as nb:
  321. # Current NetBuilder is optionally used inside the processor,
  322. # then its children are retrieved inside of
  323. # normalize_processor_output.
  324. # Once readers and writers also use NetBuilder,
  325. # this logic will be more natural.
  326. result = normalize_processor_output(self.processor(rec))
  327. read_nets += result.nets
  328. if result.should_stop or nb._stop_blob:
  329. stop_net = core.Net('stop_net')
  330. if result.should_stop:
  331. stop_net.Or([status, result.should_stop], [status])
  332. if nb._stop_blob:
  333. stop_net.Or([status, nb._stop_blob], [status])
  334. read_nets.append(stop_net)
  335. if hasattr(self.processor, 'setup'):
  336. init_net.add_attribute(TaskGroup.LOCAL_SETUP, self.processor)
  337. self._set_schema(result.record)
  338. fields = result.record.field_blobs() if result.record else None
  339. return read_nets, status, fields
  340. class NetProcessor(object):
  341. """
  342. Processor that clones a core.Net each time it's called, executing
  343. the cloned net as the processor. It requires the Net to have input
  344. and (optionally) output records set, with net.set_input_record() and
  345. net.set_output_record().
  346. """
  347. def __init__(self, net, stop_signal=None, thread_init_nets=None, name=None):
  348. assert isinstance(net, core.Net)
  349. assert stop_signal is None or isinstance(
  350. stop_signal, core.BlobReference)
  351. self.name = name or str(net)
  352. self.thread_init_nets = thread_init_nets or []
  353. self.net = net
  354. self._stop_signal = stop_signal
  355. self._blob_maps = []
  356. self._frozen = False
  357. self._cloned_init_nets = []
  358. def schema(self):
  359. return self.net.output_record()
  360. def setup(self, init_net):
  361. self._frozen = True
  362. cloned_init_nets = self._cloned_init_nets
  363. self._cloned_init_nets = []
  364. return cloned_init_nets
  365. def __call__(self, rec):
  366. assert not self._frozen
  367. prefix = NetBuilder.current().name + '/'
  368. blob_remap = {}
  369. for net in self.thread_init_nets:
  370. new_net, _ = core.clone_and_bind_net(
  371. net, str(net) + prefix, prefix, blob_remap)
  372. self._cloned_init_nets.append(new_net)
  373. new_net, remappings = core.clone_and_bind_net(
  374. self.net, str(self.net) + prefix, prefix, blob_remap, rec)
  375. if self._stop_signal is None:
  376. stop_signal = None
  377. elif str(self._stop_signal) in remappings:
  378. stop_signal = core.BlobReference(
  379. remappings[str(self._stop_signal)],
  380. net=new_net)
  381. else:
  382. stop_signal = self._stop_signal
  383. self._blob_maps.append(remappings)
  384. return Output([new_net], new_net.output_record(), stop_signal)
  385. def blob_maps(self):
  386. self._frozen = True
  387. return self._blob_maps