control_test.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. from caffe2.python import control, core, test_util, workspace
  2. import logging
  3. logger = logging.getLogger(__name__)
  class TestControl(test_util.TestCase):
      """Tests for the step builders exposed by ``caffe2.python.control``.

      Each test assembles a plan that runs ``init_net_`` followed by the step
      under test, then compares external-output blobs fetched from the
      workspace with the expected values.
      """

      def setUp(self):
          """Create the nets shared by every test.

          NOTE(review): the order in which operators are added to each net is
          kept fixed on purpose; auto-generated blob names presumably depend
          on it -- confirm before reordering.
          """
          super(TestControl, self).setUp()
          self.N_ = 10
          int64 = core.DataType.INT64

          # init-net owns the counter and the scalar constants; it also gets
          # a zero fill for each count blob further down.
          self.init_net_ = core.Net("init-net")
          counter = self.init_net_.CreateCounter([], init_count=0)
          limit = self.init_net_.ConstantFill(
              [], shape=[], value=self.N_, dtype=int64)
          zero = self.init_net_.ConstantFill(
              [], shape=[], value=0, dtype=int64)

          # cnt-net: one CountUp, then expose the retrieved count.
          self.cnt_net_ = core.Net("cnt-net")
          self.cnt_net_.CountUp([counter])
          count_blob = self.cnt_net_.RetrieveCount([counter])
          self.init_net_.ConstantFill(
              [], [count_blob], shape=[], value=0, dtype=int64)
          self.cnt_net_.AddExternalOutput(count_blob)

          # cnt-2-net: two CountUp ops, then expose the retrieved count.
          self.cnt_2_net_ = core.Net("cnt-2-net")
          self.cnt_2_net_.CountUp([counter])
          self.cnt_2_net_.CountUp([counter])
          count_2_blob = self.cnt_2_net_.RetrieveCount([counter])
          self.init_net_.ConstantFill(
              [], [count_2_blob], shape=[], value=0, dtype=int64)
          self.cnt_2_net_.AddExternalOutput(count_2_blob)

          # cond-net: count < N_
          self.cond_net_ = core.Net("cond-net")
          below_limit = self.cond_net_.LT([count_blob, limit])
          self.cond_net_.AddExternalOutput(below_limit)

          # not-cond-net: count >= N_
          self.not_cond_net_ = core.Net("not-cond-net")
          reached_limit = self.not_cond_net_.GE([count_blob, limit])
          self.not_cond_net_.AddExternalOutput(reached_limit)

          # true-cond-net: 0 < N_
          self.true_cond_net_ = core.Net("true-cond-net")
          always_true = self.true_cond_net_.LT([zero, limit])
          self.true_cond_net_.AddExternalOutput(always_true)

          # false-cond-net: 0 > N_
          self.false_cond_net_ = core.Net("false-cond-net")
          always_false = self.false_cond_net_.GT([zero, limit])
          self.false_cond_net_.AddExternalOutput(always_false)

          # idle-net: a single constant fill that no condition reads.
          self.idle_net_ = core.Net("idle-net")
          self.idle_net_.ConstantFill(
              [], shape=[], value=0, dtype=int64)

      def CheckNetOutput(self, nets_and_expects):
          """
          Compare the last external output of each net with its expectation.
          nets_and_expects is a list of tuples (net, expect)
          """
          for net, expect in nets_and_expects:
              blob_name = net.Proto().external_output[-1]
              self.assertEqual(workspace.FetchBlob(blob_name), expect)

      def CheckNetAllOutput(self, net, expects):
          """
          Compare every external output of the net with its expectation.
          expects is a list of bools.
          """
          output_names = net.Proto().external_output
          self.assertEqual(len(output_names), len(expects))
          for blob_name, expect in zip(output_names, expects):
              self.assertEqual(workspace.FetchBlob(blob_name), expect)

      def BuildAndRunPlan(self, step):
          """Run a plan made of the init step followed by `step`.

          Asserts that workspace.RunPlan returns True.
          """
          plan = core.Plan("test")
          for plan_step in (control.Do('init', self.init_net_), step):
              plan.AddStep(plan_step)
          succeeded = workspace.RunPlan(plan)
          self.assertEqual(succeeded, True)

      def ForLoopTest(self, nets_or_steps):
          """Run control.For over `nets_or_steps` with N_ iterations.

          Asserts that the last output of cnt_net_ equals N_.
          """
          loop = control.For('myFor', nets_or_steps, self.N_)
          self.BuildAndRunPlan(loop)
          self.CheckNetOutput([(self.cnt_net_, self.N_)])

      def testForLoopWithNets(self):
          # Loop body given as a bare net, then as a list of nets.
          for body in (self.cnt_net_, [self.cnt_net_, self.idle_net_]):
              self.ForLoopTest(body)

      def testForLoopWithStep(self):
          # Loop body given as a step, then as a mixed step/net list.
          count_step = control.Do('count', self.cnt_net_)
          for body in (count_step, [count_step, self.idle_net_]):
              self.ForLoopTest(body)

      def WhileLoopTest(self, nets_or_steps):
          """Run control.While guarded by cond_net_ over `nets_or_steps`.

          Asserts that the last output of cnt_net_ equals N_.
          """
          loop = control.While('myWhile', self.cond_net_, nets_or_steps)
          self.BuildAndRunPlan(loop)
          self.CheckNetOutput([(self.cnt_net_, self.N_)])

      def testWhileLoopWithNet(self):
          # Loop body given as a bare net, then as a list of nets.
          for body in (self.cnt_net_, [self.cnt_net_, self.idle_net_]):
              self.WhileLoopTest(body)

      def testWhileLoopWithStep(self):
          # Loop body given as a step, then as a mixed step/net list.
          count_step = control.Do('count', self.cnt_net_)
          for body in (count_step, [count_step, self.idle_net_]):
              self.WhileLoopTest(body)

      def UntilLoopTest(self, nets_or_steps):
          """Run control.Until guarded by not_cond_net_ over `nets_or_steps`.

          Asserts that the last output of cnt_net_ equals N_.
          """
          loop = control.Until('myUntil', self.not_cond_net_, nets_or_steps)
          self.BuildAndRunPlan(loop)
          self.CheckNetOutput([(self.cnt_net_, self.N_)])

      def testUntilLoopWithNet(self):
          # Loop body given as a bare net, then as a list of nets.
          for body in (self.cnt_net_, [self.cnt_net_, self.idle_net_]):
              self.UntilLoopTest(body)

      def testUntilLoopWithStep(self):
          # Loop body given as a step, then as a mixed step/net list.
          count_step = control.Do('count', self.cnt_net_)
          for body in (count_step, [count_step, self.idle_net_]):
              self.UntilLoopTest(body)

      def DoWhileLoopTest(self, nets_or_steps):
          """Run control.DoWhile guarded by cond_net_ over `nets_or_steps`.

          Asserts that the last output of cnt_net_ equals N_.
          """
          loop = control.DoWhile('myDoWhile', self.cond_net_, nets_or_steps)
          self.BuildAndRunPlan(loop)
          self.CheckNetOutput([(self.cnt_net_, self.N_)])

      def testDoWhileLoopWithNet(self):
          # Loop body given as a bare net, then as a list with idle-net first.
          for body in (self.cnt_net_, [self.idle_net_, self.cnt_net_]):
              self.DoWhileLoopTest(body)

      def testDoWhileLoopWithStep(self):
          # Loop body given as a step, then as a list with idle-net first.
          count_step = control.Do('count', self.cnt_net_)
          for body in (count_step, [self.idle_net_, count_step]):
              self.DoWhileLoopTest(body)

      def DoUntilLoopTest(self, nets_or_steps):
          """Run control.DoUntil guarded by not_cond_net_ over `nets_or_steps`.

          Asserts that the last output of cnt_net_ equals N_.
          """
          loop = control.DoUntil(
              'myDoUntil', self.not_cond_net_, nets_or_steps)
          self.BuildAndRunPlan(loop)
          self.CheckNetOutput([(self.cnt_net_, self.N_)])

      def testDoUntilLoopWithNet(self):
          # Loop body given as a bare net, then as a list of nets.
          for body in (self.cnt_net_, [self.cnt_net_, self.idle_net_]):
              self.DoUntilLoopTest(body)

      def testDoUntilLoopWithStep(self):
          # Loop body given as a step, then as a list with idle-net first.
          count_step = control.Do('count', self.cnt_net_)
          for body in (count_step, [self.idle_net_, count_step]):
              self.DoUntilLoopTest(body)

      def IfCondTest(self, cond_net, expect, cond_on_blob):
          """Run control.If with cnt_net_ as the true branch.

          When `cond_on_blob` is set, `cond_net` is run first and the
          condition is passed as its last external-output blob; otherwise the
          net itself is passed as the condition. Asserts that the last output
          of cnt_net_ equals `expect`.
          """
          if not cond_on_blob:
              step = control.If('myIf', cond_net, self.cnt_net_)
          else:
              run_cond = control.Do('count', cond_net)
              cond_blob = cond_net.Proto().external_output[-1]
              branch = control.If('myIf', cond_blob, self.cnt_net_)
              step = control.Do('if-all', run_cond, branch)
          self.BuildAndRunPlan(step)
          self.CheckNetOutput([(self.cnt_net_, expect)])

      def testIfCondTrueOnNet(self):
          self.IfCondTest(self.true_cond_net_, 1, cond_on_blob=False)

      def testIfCondTrueOnBlob(self):
          self.IfCondTest(self.true_cond_net_, 1, cond_on_blob=True)

      def testIfCondFalseOnNet(self):
          self.IfCondTest(self.false_cond_net_, 0, cond_on_blob=False)

      def testIfCondFalseOnBlob(self):
          self.IfCondTest(self.false_cond_net_, 0, cond_on_blob=True)

      def IfElseCondTest(self, cond_net, cond_value, expect, cond_on_blob):
          """Run control.If with cnt_net_ / cnt_2_net_ as true / false branch.

          `cond_value` selects which net's last output is compared with
          `expect`: cnt_net_ when truthy, cnt_2_net_ otherwise. See IfCondTest
          for the meaning of `cond_on_blob`.
          """
          run_net = self.cnt_net_ if cond_value else self.cnt_2_net_
          if not cond_on_blob:
              step = control.If('myIfElse', cond_net,
                                self.cnt_net_, self.cnt_2_net_)
          else:
              run_cond = control.Do('count', cond_net)
              cond_blob = cond_net.Proto().external_output[-1]
              branch = control.If('myIfElse', cond_blob,
                                  self.cnt_net_, self.cnt_2_net_)
              step = control.Do('if-else-all', run_cond, branch)
          self.BuildAndRunPlan(step)
          self.CheckNetOutput([(run_net, expect)])

      def testIfElseCondTrueOnNet(self):
          self.IfElseCondTest(
              self.true_cond_net_, True, 1, cond_on_blob=False)

      def testIfElseCondTrueOnBlob(self):
          self.IfElseCondTest(
              self.true_cond_net_, True, 1, cond_on_blob=True)

      def testIfElseCondFalseOnNet(self):
          self.IfElseCondTest(
              self.false_cond_net_, False, 2, cond_on_blob=False)

      def testIfElseCondFalseOnBlob(self):
          self.IfElseCondTest(
              self.false_cond_net_, False, 2, cond_on_blob=True)

      def IfNotCondTest(self, cond_net, expect, cond_on_blob):
          """Run control.IfNot with cnt_net_ as its single branch.

          See IfCondTest for the meaning of `cond_on_blob`. Asserts that the
          last output of cnt_net_ equals `expect`.
          """
          if not cond_on_blob:
              step = control.IfNot('myIfNot', cond_net, self.cnt_net_)
          else:
              run_cond = control.Do('count', cond_net)
              cond_blob = cond_net.Proto().external_output[-1]
              branch = control.IfNot('myIfNot', cond_blob, self.cnt_net_)
              step = control.Do('if-not', run_cond, branch)
          self.BuildAndRunPlan(step)
          self.CheckNetOutput([(self.cnt_net_, expect)])

      def testIfNotCondTrueOnNet(self):
          self.IfNotCondTest(self.true_cond_net_, 0, cond_on_blob=False)

      def testIfNotCondTrueOnBlob(self):
          self.IfNotCondTest(self.true_cond_net_, 0, cond_on_blob=True)

      def testIfNotCondFalseOnNet(self):
          self.IfNotCondTest(self.false_cond_net_, 1, cond_on_blob=False)

      def testIfNotCondFalseOnBlob(self):
          self.IfNotCondTest(self.false_cond_net_, 1, cond_on_blob=True)

      def IfNotElseCondTest(self, cond_net, cond_value, expect, cond_on_blob):
          """Run control.IfNot with cnt_net_ and cnt_2_net_ as its branches.

          `cond_value` selects which net's last output is compared with
          `expect`: cnt_2_net_ when truthy, cnt_net_ otherwise. See IfCondTest
          for the meaning of `cond_on_blob`.
          """
          run_net = self.cnt_2_net_ if cond_value else self.cnt_net_
          if not cond_on_blob:
              step = control.IfNot('myIfNotElse', cond_net,
                                   self.cnt_net_, self.cnt_2_net_)
          else:
              run_cond = control.Do('count', cond_net)
              cond_blob = cond_net.Proto().external_output[-1]
              branch = control.IfNot('myIfNotElse', cond_blob,
                                     self.cnt_net_, self.cnt_2_net_)
              step = control.Do('if-not-else', run_cond, branch)
          self.BuildAndRunPlan(step)
          self.CheckNetOutput([(run_net, expect)])

      def testIfNotElseCondTrueOnNet(self):
          self.IfNotElseCondTest(
              self.true_cond_net_, True, 2, cond_on_blob=False)

      def testIfNotElseCondTrueOnBlob(self):
          self.IfNotElseCondTest(
              self.true_cond_net_, True, 2, cond_on_blob=True)

      def testIfNotElseCondFalseOnNet(self):
          self.IfNotElseCondTest(
              self.false_cond_net_, False, 1, cond_on_blob=False)

      def testIfNotElseCondFalseOnBlob(self):
          self.IfNotElseCondTest(
              self.false_cond_net_, False, 1, cond_on_blob=True)

      def testSwitch(self):
          # Asserted outcome: cnt-net's output is 0 and cnt-2-net's is 2.
          cases = (
              (self.false_cond_net_, self.cnt_net_),
              (self.true_cond_net_, self.cnt_2_net_),
          )
          step = control.Switch('mySwitch', *cases)
          self.BuildAndRunPlan(step)
          self.CheckNetOutput([(self.cnt_net_, 0), (self.cnt_2_net_, 2)])

      def testSwitchNot(self):
          # Asserted outcome: cnt-net's output is 1 and cnt-2-net's is 0.
          cases = (
              (self.false_cond_net_, self.cnt_net_),
              (self.true_cond_net_, self.cnt_2_net_),
          )
          step = control.SwitchNot('mySwitchNot', *cases)
          self.BuildAndRunPlan(step)
          self.CheckNetOutput([(self.cnt_net_, 1), (self.cnt_2_net_, 0)])

      def testBoolNet(self):
          # Each case: positional arguments for BoolNet, then the expected
          # value of every external output, in order.
          cases = (
              ((('a', True),), [True]),
              ((('a', True), ('b', False)), [True, False]),
              (([('a', True), ('b', False)],), [True, False]),
          )
          for args, expects in cases:
              bool_net = control.BoolNet(*args)
              step = control.Do('bool', bool_net)
              self.BuildAndRunPlan(step)
              self.CheckNetAllOutput(bool_net, expects)

      def testCombineConditions(self):
          # The condition nets are run explicitly before the combining net;
          # true/false combined by 'Or' must give True, by 'And' False.
          for relation, expect in (('Or', True), ('And', False)):
              combine_net = control.CombineConditions(
                  'test', [self.true_cond_net_, self.false_cond_net_],
                  relation)
              step = control.Do('combine',
                                self.true_cond_net_,
                                self.false_cond_net_,
                                combine_net)
              self.BuildAndRunPlan(step)
              self.CheckNetOutput([(combine_net, expect)])

      def testMergeConditionNets(self):
          # Only the merged net is run here; true/false merged by 'Or' must
          # give True, by 'And' False.
          for relation, expect in (('Or', True), ('And', False)):
              merge_net = control.MergeConditionNets(
                  'test', [self.true_cond_net_, self.false_cond_net_],
                  relation)
              step = control.Do('merge', merge_net)
              self.BuildAndRunPlan(step)
              self.CheckNetOutput([(merge_net, expect)])