net_builder.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743
  1. ## @package net_builder
  2. # Module caffe2.python.net_builder
  3. from caffe2.python import core, context
  4. from caffe2.python.task import Task, TaskGroup
  5. from caffe2.python.control_ops_util import add_if_op, add_while_op
  6. class NetBuilder(context.Managed):
  7. """
  8. Scope-driven mechanism for building nets, loops and conditional blocks.
  9. Args:
  10. name: NetBuilder's name
  11. initial_scope: list of blobs that are available for reading/writing
  12. Example:
  13. from caffe2.python.net_builder import NetBuilder, ops
  14. with NetBuilder() as nb:
  15. c = ops.Const(5)
  16. d = ops.Const(0)
  17. with ops.loop():
  18. ops.stop_if(ops.LE([c, ops.Const(0)]))
  19. ops.Add([c, ops.Const(-1)], [c])
  20. with ops.If(ops.GE([c, ops.Const(3)])):
  21. ops.Add([d, ops.Const(10)], [d])
  22. ops.Print(c, [])
  23. ops.Print(d, [])
  24. step = core.to_execution_step(nb)
  25. """
  26. def __init__(self, name=None, initial_scope=None, _stop_blob_required=False,
  27. _stop_blob=None, _fullname=None, _use_control_ops=False):
  28. parent = NetBuilder.current(required=False)
  29. assert not _fullname or not name, 'Cannot set both _fullname and name'
  30. assert not _use_control_ops or \
  31. (not _stop_blob_required and not _stop_blob), \
  32. 'Stop blobs are not used with control operators'
  33. self.name = _fullname or '/'.join(
  34. n for n in (parent.name if parent else None, name) if n
  35. )
  36. self._frozen = False
  37. self._current_net = None
  38. self._children = []
  39. if parent:
  40. # make sure parent has an up to date lexical scope computed
  41. parent._update_lexical_scope()
  42. self._init_lexical_scope = set(parent._lexical_scope) if parent else set()
  43. if initial_scope:
  44. self._init_lexical_scope |= set([str(b) for b in initial_scope])
  45. self._lexical_scope = set(self._init_lexical_scope)
  46. self._stop_blob = _stop_blob
  47. self._stop_blob_required = _stop_blob_required
  48. self._use_control_ops = _use_control_ops
  49. def stop_blob(self):
  50. """
  51. Returns the BlobReference to the stop_blob of this NetBuilder.
  52. If one is not yet available, creates one.
  53. This function assumes that the stop_blob() will be used immediatelly
  54. in the current net, so it doesn't initialize it if the current net is
  55. the first of the builder.
  56. """
  57. assert not self._use_control_ops, \
  58. 'Stop blobs are not used with control operators'
  59. if self._stop_blob is None:
  60. net = self.current_net()
  61. self._stop_blob = core.BlobReference(
  62. net.NextName('stop_blob'), net=net)
  63. net.Const(False, blob_out=self._stop_blob)
  64. if self._current_net != self._children[0]:
  65. self._children.insert(0, core.Net('stop_blob_init'))
  66. self._children[0].Const(False, blob_out=self._stop_blob)
  67. return self._stop_blob
  68. def stop_if(self, blob):
  69. assert not self._use_control_ops, \
  70. 'Stop blobs are not used with control operators'
  71. stop_blob = self.stop_blob()
  72. ops.Or([stop_blob, blob], [stop_blob])
  73. self._current_net = None
  74. def _assert_mutable(self):
  75. assert not self._frozen, (
  76. 'This NetBuilder (%s) has been built already.' % self.name)
  77. def _update_lexical_scope(self):
  78. """
  79. Updates lexical scope based on the current list of children.
  80. Lexical scope contains names of blobs that are currently available
  81. and were introduced in the net builder
  82. """
  83. self._lexical_scope = set(self._init_lexical_scope)
  84. for child in self._children:
  85. if isinstance(child, core.Net):
  86. self._lexical_scope |= child.UsedBlobNames()
  87. elif isinstance(child, NetBuilder) and child._use_control_ops:
  88. self._lexical_scope |= child._lexical_scope
  89. def _reset_children(self):
  90. self._current_net = None
  91. self._children = []
  92. self._lexical_scope = set(self._init_lexical_scope)
  93. def add(self, child):
  94. self._assert_mutable()
  95. if self._use_control_ops:
  96. assert isinstance(child, core.Net) or (
  97. isinstance(child, NetBuilder) and child._use_control_ops), \
  98. "Expected Net or NetBuilder with control ops"
  99. self._current_net = None
  100. self._children.append(child)
  101. # to-do : check it's not a dag net
  102. if isinstance(child, core.Net):
  103. self._current_net = child
  104. self._update_lexical_scope()
  105. return child
  106. def current_net(self, name=None):
  107. self._assert_mutable()
  108. if self._current_net is None or name is not None:
  109. self.add(core.Net(name))
  110. return self._current_net
  111. def freeze(self):
  112. for child in self._children:
  113. if hasattr(child, 'freeze'):
  114. child.freeze()
  115. self._current_net = None
  116. self._frozen = True
  117. def get(self):
  118. self.freeze()
  119. return self._children
  120. def __exit__(self, etype, *args):
  121. super(NetBuilder, self).__exit__(etype, *args)
  122. if self._use_control_ops and len(self._children) > 0:
  123. _children = self._children
  124. self._reset_children()
  125. merged_net = NetBuilder.merge_nets(
  126. _children, self._lexical_scope)
  127. assert merged_net, "Expected a non-empty merge of children"
  128. self._children = [merged_net]
  129. self.freeze()
  130. if etype is not None:
  131. return
  132. assert (not self._stop_blob_required) or self._stop_blob is not None, (
  133. 'This NetBuilder (%s) requires a stop condition ' % self.name +
  134. 'to be set with `stop` or `stop_if`')
  135. @staticmethod
  136. def merge_nets(nets_or_builders, outer_blob_names):
  137. # Only nets or builders with control ops are allowed.
  138. # Need to pay attention to external outputs, e.g.
  139. # ...
  140. # IfNet1 (cond_blob):
  141. # (Net1)
  142. # X = 1
  143. # IfNet2 (...):
  144. # X = X + 1
  145. # ...
  146. # In this example there're two children in then branch of IfNet1:
  147. # a subnet Net1 that creates blob X and sets its value to one, and
  148. # a net builder IfNet2 that (conditionally) increments X.
  149. # From IfNet2's point of view X is an external input
  150. # and output blob, it will be put into IfNet2 net's external_output.
  151. # At the same time, from the point of view of IfNet1 X is purely local.
  152. # Net.AppendNet just merges external outputs of the networks, so
  153. # without checking this the result of Net1.AppendNet(IfNet2's net)
  154. # would have blob X in external_output
  155. net = None
  156. for n in nets_or_builders:
  157. cur = None
  158. if isinstance(n, NetBuilder):
  159. assert n._use_control_ops, \
  160. "Merging of NetBuilder supported only for control ops"
  161. nets = n.get()
  162. assert len(nets) == 1 and isinstance(nets[0], core.Net), \
  163. "Invalid control op net builder"
  164. cur = nets[0]
  165. else:
  166. assert isinstance(n, core.Net)
  167. cur = n
  168. if net:
  169. net.AppendNet(cur)
  170. else:
  171. net = cur
  172. if net:
  173. # correct external output
  174. external_outputs = [o for o in net.Proto().external_output
  175. if o in outer_blob_names]
  176. net.Proto().external_output[:] = external_outputs
  177. return net
  178. def __str__(self):
  179. return self.name or 'Un-named NetBuilder'
  180. class Operations(object):
  181. """
  182. Operations to be used in the context of a NetBuilder.
  183. """
  184. def net(self, net=None, name=None):
  185. """
  186. Retrieves the current net, or add a new net to the builder.
  187. Args:
  188. net: If provided, add the given net to the active builder.
  189. Else, returns the current Net or creates a new one as needed.
  190. name: if provided, creates a new Net with given name and makes
  191. it the new current net of the active builder. Cannot
  192. be provided if net is provided.
  193. """
  194. assert name is None or net is None, (
  195. 'Cannot provide both `net` and `name`.')
  196. if net is not None:
  197. NetBuilder.current().add(net)
  198. return net
  199. return NetBuilder.current().current_net(name=name)
  200. def __getattr__(self, op_type):
  201. """
  202. Adds an operator call to the currently active Net.
  203. """
  204. if op_type.startswith('__'):
  205. raise AttributeError()
  206. # We want hasattr to work properly even if no context is active.
  207. if NetBuilder.current(required=False) is None:
  208. raise AttributeError('No active NetBuilder.')
  209. return getattr(self.net(), op_type)
  210. def task_group(self):
  211. """
  212. Creates a local task group which will execute as the next step of
  213. the current NetBuilder.
  214. """
  215. from caffe2.python import task
  216. group = NetBuilder.current()
  217. with task.Cluster():
  218. with task.Node('local'):
  219. tg = task.TaskGroup()
  220. group.add(tg)
  221. return tg
  222. def stop(self):
  223. """
  224. Stop execution of the current execution step.
  225. Example:
  226. ops.Print(a, 0)
  227. ops.stop()
  228. ops.Print(b, 0)
  229. In the example, 'b' will never be printed.
  230. """
  231. return self.stop_if(ops.Const(True))
  232. def stop_if(self, blob):
  233. """
  234. Stop execution of the current execution step if the
  235. condition `blob` is met.
  236. Example:
  237. ops.Print(a, 0)
  238. ops.stop_if(ops.LE([x, ops.Const(0)]))
  239. ops.Print(b, 0)
  240. In the example, 'b' will only be printed if the value of scalar
  241. tensor 'x' is greater than 0.
  242. """
  243. return NetBuilder.current().stop_if(blob)
  244. def loop(self, iters=None, name=None):
  245. """
  246. Creates a NetBuilder that will execute in a loop as the next step of
  247. the current NetBuilder. If `iters` is provided, the loop will execute
  248. for `iters` iterations and then stop. `iters` can be a constant or a
  249. BlobReference. If `iters` is not provided, the loop will execute
  250. until `ops.stop` or `ops.stop_if` is called.
  251. Examples:
  252. a = ops.Const(5)
  253. with ops.loop():
  254. ops.stop_if(ops.LE([a, ops.Const(0)]))
  255. ops.Print(a, 0)
  256. ops.Add([a, ops.Const(-1)], [a])
  257. Above, 'a' will be printed 5 times, with values 5 to 1.
  258. with ops.loop(10) as loop:
  259. ops.LogInfo(loop.iter())
  260. This will print the numbers from 0 to 9.
  261. x = ops.Add([ops.Const(10), ops.Const(10)])
  262. with ops.loop(x) as loop:
  263. ops.LogInfo(loop.iter())
  264. This will print the numbers from 0 to 19.
  265. """
  266. return NetBuilder.current().add(_Loop(iters, name=name))
  267. def stop_guard(self, has_stopped_blob=None, name=None):
  268. """
  269. Creates a NetBuilder that will execute once as the next step of the
  270. current NetBuilder. After execution, a bool tensor will indicate
  271. whether the inner execution was halted with `stop` or `stop_if`.
  272. Example:
  273. a = ops.Const(True)
  274. with ops.stop_guard() as sg1:
  275. ops.stop_if(a)
  276. ops.Print(ops.Const('did not stop'))
  277. b = ops.Const(False)
  278. with ops.stop_guard() as sg2:
  279. ops.stop_if(b)
  280. ops.Print(ops.Const('did not stop'))
  281. ops.Print(sg1.has_stopped(), [])
  282. ops.Print(sg2.has_stopped(), [])
  283. In the example, 'did not stop' will be printed once,
  284. followed by True and False.
  285. """
  286. return NetBuilder.current().add(
  287. _StopGuard(has_stopped_blob=has_stopped_blob, name=name))
  288. def If(self, cond, name=None):
  289. """
  290. Creates a NetBuilder that will execute once as the next step of the
  291. current NetBuilder if the blob `cond` is True.
  292. Example:
  293. with ops.If(ops.Const(True)):
  294. ops.Print(ops.Const('Will print'))
  295. with ops.If(ops.Const(False)):
  296. ops.Print(ops.Const('Wont print'))
  297. The example will print 'Will print' once.
  298. """
  299. return NetBuilder.current().add(_RunIf(cond, name=name))
  300. def IfNet(self, cond, name=None):
  301. """
  302. Same as If, but uses 'If' operator instead of execution step logic
  303. """
  304. return NetBuilder.current().add(_RunIfNet(cond, name=name))
  305. def Else(self, name=None):
  306. """
  307. Else branch of IfNet, has to be specified immediately after IfNet.
  308. Example:
  309. with ops.IfNet(ops.LT([x, y])):
  310. ...
  311. with ops.Else():
  312. ...
  313. """
  314. return _RunElseNet(name=name)
  315. def WhileNet(self, name=None):
  316. """
  317. NetBuilder for 'While' control operator
  318. """
  319. return NetBuilder.current().add(_RunWhileNet(name=name))
  320. def Condition(self, name=None):
  321. """
  322. Loop's condition, executed within WhileNet context
  323. """
  324. assert isinstance(NetBuilder.current(), _RunWhileNet), \
  325. "Use of Condition outside of WhileNet"
  326. return _RunWhileCondition(name=name)
  327. def task_init(self):
  328. """
  329. Defines operations that will be executed once at task startup.
  330. Useful when implementing processors, that don't have access to the Task
  331. top-level structure.
  332. This setup will be run only once, even if multiple instances of the task
  333. will run in parallel. For instance-local initialization, use
  334. `task_instance_init` instead.
  335. Example:
  336. def my_processor(rec):
  337. with ops.task_init():
  338. one = ops.Const(1)
  339. two = ops.Const(1)
  340. return Tuple(
  341. ops.Add(rec[0](), zero), ops.Add(rec[1](), two))
  342. """
  343. setup = _SetupBuilder(_SetupBuilder.INIT)
  344. self.net().add_attribute(Task.TASK_SETUP, setup)
  345. return setup
  346. def task_exit(self):
  347. """
  348. Define operations to be executed once at task shutdown.
  349. Useful when implementing processors, that don't have access to the Task
  350. top-level structure.
  351. This shutdown will be run only once, after all concurrent instances of
  352. the task have already finished. For instance-local shutdown,
  353. use `task_instance_exit` instead.
  354. Example:
  355. def read_queue(queue):
  356. with ops.task_exit():
  357. queue.close(ops.net())
  358. return queue.read(ops.net())
  359. """
  360. setup = _SetupBuilder(_SetupBuilder.EXIT)
  361. self.net().add_attribute(Task.TASK_SETUP, setup)
  362. return setup
  363. def task_instance_init(self):
  364. """
  365. Defines operations that will be executed once at startup of each
  366. instance of a task. This can be seen as "thread_local" initialization.
  367. It is guaranteed to run only after all `task_init` logic finishes.
  368. This setup will be run concurrently for each instance of a task.
  369. For global task initialization, use `task_init` instead.
  370. """
  371. setup = _SetupBuilder(_SetupBuilder.INIT)
  372. self.net().add_attribute(Task.TASK_INSTANCE_SETUP, setup)
  373. return setup
  374. def task_instance_exit(self):
  375. """
  376. Defines operations that will be executed once at shutdown of each
  377. instance of a task. This can be seen as "thread_local" finalization.
  378. This shutdown will be run concurrently for each instance of a task.
  379. For global task shutdown, use `task_exit` instead.
  380. """
  381. setup = _SetupBuilder(_SetupBuilder.EXIT)
  382. self.net().add_attribute(Task.TASK_INSTANCE_SETUP, setup)
  383. return setup
  384. def local_init(self):
  385. """
  386. Similar to `task_init`, but executes at TaskGroup's startup instead,
  387. before any task of the group starts executing. This will run only
  388. once on each node, before initialization of any task, so it can be
  389. used e.g. to initialize blobs shared across tasks.
  390. """
  391. setup = _SetupBuilder(_SetupBuilder.INIT)
  392. self.net().add_attribute(TaskGroup.LOCAL_SETUP, setup)
  393. return setup
  394. def local_exit(self, name=None):
  395. """
  396. Similar to `task_exit`, but executes at TaskGroup's exit instead,
  397. after all tasks of the group finished execution.
  398. This will run only once on each node.
  399. """
  400. setup = _SetupBuilder(_SetupBuilder.EXIT, name)
  401. self.net().add_attribute(TaskGroup.LOCAL_SETUP, setup)
  402. return setup
  403. def task_reporter(self, interval_ms=1000, name=None):
  404. """
  405. Define operations to be executed at every time interval from
  406. task start-up to finish. These operations are guaranteed to
  407. execute at least once after all other operations of the task are
  408. finished.
  409. Example:
  410. with ops.task_reporter(interval_ms=10000):
  411. ops.LogInfo('10s elapsed')
  412. """
  413. return _ReporterBuilder(interval_ms, net=self.net(), name=name)
  414. def local_reporter(self, interval_ms=1000, name=None):
  415. """
  416. Similar to task_report, but operations defined within this block
  417. will run repeatedly for as long as any of the tasks in the current
  418. TaskGroup have not finished.
  419. """
  420. return _ReporterBuilder(interval_ms, name=name)
  421. ops = Operations()
  422. class _ReporterBuilder(NetBuilder):
  423. def __init__(self, interval_ms, net=None, name=None):
  424. NetBuilder.__init__(self, name)
  425. self._net = net
  426. self.interval_ms = interval_ms
  427. def __exit__(self, etype, *args):
  428. if etype is None:
  429. step = core.to_execution_step(self)
  430. step.RunEveryMillis(self.interval_ms)
  431. if self._net:
  432. self._net.add_attribute(Task.REPORT_STEP, step)
  433. else:
  434. TaskGroup.current().report_step(
  435. step, interval_ms=self.interval_ms)
  436. NetBuilder.__exit__(self, etype, *args)
  437. class _SetupBuilder(NetBuilder):
  438. INIT = 'init'
  439. EXIT = 'exit'
  440. def __init__(self, type, name=None):
  441. NetBuilder.__init__(self, name)
  442. self.type = type
  443. def setup(self, net):
  444. if self.type == _SetupBuilder.INIT:
  445. return core.to_execution_step(self)
  446. def exit(self, net):
  447. if self.type == _SetupBuilder.EXIT:
  448. return core.to_execution_step(self)
  449. class _RunOnce(NetBuilder):
  450. def __init__(self, name=None):
  451. NetBuilder.__init__(self, name)
  452. def __exit__(self, etype, *args):
  453. if etype is None and self._stop_blob is not None:
  454. ops.stop()
  455. NetBuilder.__exit__(self, etype, *args)
  456. class _StopGuard(_RunOnce):
  457. def __init__(self, has_stopped_blob=None, name=None):
  458. _RunOnce.__init__(self, name)
  459. self._stopped = has_stopped_blob
  460. self._ran = False
  461. def __enter__(self):
  462. r = _RunOnce.__enter__(self)
  463. self._stopped = ops.Const(True, blob_out=self._stopped)
  464. return r
  465. def __exit__(self, etype, *args):
  466. if etype is None:
  467. self._ran = True
  468. ops.Const(False, blob_out=self._stopped)
  469. _RunOnce.__exit__(self, etype, *args)
  470. def has_stopped(self):
  471. """
  472. Return a blob that will be set to scalar bool `True` after
  473. this net builder ran, iff it was halted early.
  474. """
  475. assert self._ran, 'Context not used yet.'
  476. return self._stopped
  477. class _Loop(NetBuilder):
  478. def __init__(self, iters=None, name=None):
  479. NetBuilder.__init__(self, name, _stop_blob_required=True)
  480. if iters is not None:
  481. self._inc = ops.Const(1)
  482. self._iter = ops.Const(0)
  483. self._num_iters = (
  484. iters if isinstance(iters, core.BlobReference)
  485. else ops.Const(iters))
  486. else:
  487. self._num_iters = None
  488. def iter(self):
  489. assert self._num_iters is not None, (
  490. 'This loop does not have a number of iterations.')
  491. assert self._iter is not None, (
  492. 'iter() must be called from inside the loop context')
  493. return self._iter
  494. def __enter__(self):
  495. builder = NetBuilder.__enter__(self)
  496. if self._num_iters is not None:
  497. ops.stop_if(ops.GE([self._iter, self._num_iters]))
  498. return builder
  499. def __exit__(self, type, *args):
  500. if type is None and self._num_iters is not None:
  501. self.current_net().Add([self._iter, self._inc], [self._iter])
  502. NetBuilder.__exit__(self, type, *args)
  503. class _RunIf(_RunOnce):
  504. def __init__(self, cond_blob=None, name=None, _already_ran=None):
  505. _RunOnce.__init__(self, name)
  506. assert cond_blob or _already_ran
  507. self._is_else = cond_blob is None
  508. if _already_ran is None:
  509. self._else_blob = ops.Not(cond_blob)
  510. self._already_ran = ops.Const(False)
  511. else:
  512. self._already_ran = _already_ran
  513. self._else_blob = _already_ran if cond_blob is None else (
  514. ops.Or([_already_ran, ops.Not(cond_blob)]))
  515. def __enter__(self):
  516. r = _RunOnce.__enter__(self)
  517. ops.stop_if(self._else_blob)
  518. ops.Const(True, blob_out=self._already_ran)
  519. return r
  520. def Elif(self, cond, name=None):
  521. assert not self._is_else, 'Else not allowed for an Else.'
  522. return NetBuilder.current().add(_RunIf(
  523. cond, name=name or self.name, _already_ran=self._already_ran))
  524. def Else(self, name=None):
  525. assert not self._is_else, 'Elif not allowed for an Else.'
  526. return NetBuilder.current().add(
  527. _RunIf(name=name or self.name, _already_ran=self._already_ran))
  528. class _RunIfNet(NetBuilder):
  529. """
  530. Generates a single net that uses If operator
  531. """
  532. def __init__(self, cond_blob, name=None):
  533. NetBuilder.__init__(self, name=name, _use_control_ops=True)
  534. assert cond_blob, 'Conditional blob is not specified for an If net'
  535. self._cond_blob = cond_blob
  536. self._then_net = None
  537. self._else_net = None
  538. def add(self, child):
  539. return NetBuilder.add(self, child)
  540. def __exit__(self, type, *args):
  541. if type is None:
  542. _then_nets = self._children
  543. self._reset_children()
  544. self._then_net = NetBuilder.merge_nets(
  545. _then_nets, self._lexical_scope)
  546. if not self._then_net:
  547. self._then_net = core.Net('empty_then_net')
  548. if_net = core.Net(self.name + '/if_net')
  549. add_if_op(if_net, self._cond_blob, self._lexical_scope,
  550. self._then_net, self._else_net)
  551. self._current_net = if_net
  552. self._children = [if_net]
  553. NetBuilder.__exit__(self, type, *args)
  554. class _RunElseNet(NetBuilder):
  555. """
  556. Else branch for _RunIfNet builder
  557. """
  558. def __init__(self, name=None):
  559. NetBuilder.__init__(self, name=name, _use_control_ops=True)
  560. parent = NetBuilder.current(required=False)
  561. assert parent and len(parent._children) > 0 and \
  562. isinstance(parent._children[-1], _RunIfNet), \
  563. 'Invalid use of Else builder'
  564. self._if_builder = parent._children[-1]
  565. def __exit__(self, type, *args):
  566. if type is None:
  567. _else_nets = self._children
  568. self._reset_children()
  569. self._if_builder._else_net = NetBuilder.merge_nets(
  570. _else_nets, self._lexical_scope)
  571. if self._if_builder._else_net:
  572. if_else_net = core.Net(self.name + '/if_else_net')
  573. add_if_op(
  574. if_else_net,
  575. self._if_builder._cond_blob,
  576. self._lexical_scope,
  577. self._if_builder._then_net,
  578. self._if_builder._else_net)
  579. self._if_builder._current_net = if_else_net
  580. self._if_builder._children = [if_else_net]
  581. NetBuilder.__exit__(self, type, *args)
  582. class _RunWhileNet(NetBuilder):
  583. """
  584. Generates a single net that uses While operator
  585. """
  586. def __init__(self, name=None):
  587. NetBuilder.__init__(self, name=name, _use_control_ops=True)
  588. self._cond_builder = None
  589. def __exit__(self, type, *args):
  590. if type is None:
  591. assert self._cond_builder, \
  592. 'Condition builder must be specified in While op'
  593. _cond_blob = self._cond_builder._cond_blob
  594. _cond_net = self._cond_builder._cond_net
  595. loop_body = self._children
  596. self._reset_children()
  597. loop_body_net = NetBuilder.merge_nets(
  598. loop_body, self._lexical_scope)
  599. if not loop_body_net:
  600. loop_body_net = core.Net('empty_loop_body_net')
  601. while_net = core.Net(self.name + '/while_net')
  602. add_while_op(while_net, _cond_blob, self._lexical_scope,
  603. loop_body_net, _cond_net)
  604. self._current_net = while_net
  605. self._children = [while_net]
  606. NetBuilder.__exit__(self, type, *args)
  607. class _RunWhileCondition(NetBuilder):
  608. """
  609. Computes loop's condition, used in the context of WhileNet.
  610. Last operator must have a single scalar boolean output that will be used
  611. as a condition value, no other blobs created in the condition net are
  612. visible outside of it
  613. """
  614. def __init__(self, name=None):
  615. NetBuilder.__init__(self, name=name, _use_control_ops=True)
  616. parent = NetBuilder.current(required=False)
  617. assert parent and isinstance(parent, _RunWhileNet), \
  618. 'Invalid use of loop condition builder'
  619. assert not parent._cond_builder, \
  620. 'Multiple loop condition builders specified'
  621. assert len(parent._children) == 0, \
  622. 'Condition definition must be specified before the loop\'s body'
  623. parent._cond_builder = self
  624. self._cond_blob = None
  625. self._cond_net = None
  626. def __exit__(self, type, *args):
  627. if type is None:
  628. condition_body = self._children
  629. self._reset_children()
  630. self._cond_net = NetBuilder.merge_nets(
  631. condition_body, self._lexical_scope)
  632. assert self._cond_net, 'Invalid loop condition specified'
  633. assert len(self._cond_net.Proto().op) > 0, 'Invalid condition net'
  634. last_op = self._cond_net.Proto().op[-1]
  635. assert len(last_op.output) == 1, 'Invalid condition net'
  636. self._cond_blob = core.BlobReference(name=last_op.output[0], net=None)
  637. self._current_net = self._cond_net
  638. self._children = [self._cond_net]
  639. NetBuilder.__exit__(self, type, *args)