| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331 |
- from caffe2.python import control, core, test_util, workspace
- import logging
- logger = logging.getLogger(__name__)
- class TestControl(test_util.TestCase):
- def setUp(self):
- super(TestControl, self).setUp()
- self.N_ = 10
- self.init_net_ = core.Net("init-net")
- cnt = self.init_net_.CreateCounter([], init_count=0)
- const_n = self.init_net_.ConstantFill(
- [], shape=[], value=self.N_, dtype=core.DataType.INT64)
- const_0 = self.init_net_.ConstantFill(
- [], shape=[], value=0, dtype=core.DataType.INT64)
- self.cnt_net_ = core.Net("cnt-net")
- self.cnt_net_.CountUp([cnt])
- curr_cnt = self.cnt_net_.RetrieveCount([cnt])
- self.init_net_.ConstantFill(
- [], [curr_cnt], shape=[], value=0, dtype=core.DataType.INT64)
- self.cnt_net_.AddExternalOutput(curr_cnt)
- self.cnt_2_net_ = core.Net("cnt-2-net")
- self.cnt_2_net_.CountUp([cnt])
- self.cnt_2_net_.CountUp([cnt])
- curr_cnt_2 = self.cnt_2_net_.RetrieveCount([cnt])
- self.init_net_.ConstantFill(
- [], [curr_cnt_2], shape=[], value=0, dtype=core.DataType.INT64)
- self.cnt_2_net_.AddExternalOutput(curr_cnt_2)
- self.cond_net_ = core.Net("cond-net")
- cond_blob = self.cond_net_.LT([curr_cnt, const_n])
- self.cond_net_.AddExternalOutput(cond_blob)
- self.not_cond_net_ = core.Net("not-cond-net")
- cond_blob = self.not_cond_net_.GE([curr_cnt, const_n])
- self.not_cond_net_.AddExternalOutput(cond_blob)
- self.true_cond_net_ = core.Net("true-cond-net")
- true_blob = self.true_cond_net_.LT([const_0, const_n])
- self.true_cond_net_.AddExternalOutput(true_blob)
- self.false_cond_net_ = core.Net("false-cond-net")
- false_blob = self.false_cond_net_.GT([const_0, const_n])
- self.false_cond_net_.AddExternalOutput(false_blob)
- self.idle_net_ = core.Net("idle-net")
- self.idle_net_.ConstantFill(
- [], shape=[], value=0, dtype=core.DataType.INT64)
- def CheckNetOutput(self, nets_and_expects):
- """
- Check the net output is expected
- nets_and_expects is a list of tuples (net, expect)
- """
- for net, expect in nets_and_expects:
- output = workspace.FetchBlob(
- net.Proto().external_output[-1])
- self.assertEqual(output, expect)
- def CheckNetAllOutput(self, net, expects):
- """
- Check the net output is expected
- expects is a list of bools.
- """
- self.assertEqual(len(net.Proto().external_output), len(expects))
- for i in range(len(expects)):
- output = workspace.FetchBlob(
- net.Proto().external_output[i])
- self.assertEqual(output, expects[i])
- def BuildAndRunPlan(self, step):
- plan = core.Plan("test")
- plan.AddStep(control.Do('init', self.init_net_))
- plan.AddStep(step)
- self.assertEqual(workspace.RunPlan(plan), True)
- def ForLoopTest(self, nets_or_steps):
- step = control.For('myFor', nets_or_steps, self.N_)
- self.BuildAndRunPlan(step)
- self.CheckNetOutput([(self.cnt_net_, self.N_)])
- def testForLoopWithNets(self):
- self.ForLoopTest(self.cnt_net_)
- self.ForLoopTest([self.cnt_net_, self.idle_net_])
- def testForLoopWithStep(self):
- step = control.Do('count', self.cnt_net_)
- self.ForLoopTest(step)
- self.ForLoopTest([step, self.idle_net_])
- def WhileLoopTest(self, nets_or_steps):
- step = control.While('myWhile', self.cond_net_, nets_or_steps)
- self.BuildAndRunPlan(step)
- self.CheckNetOutput([(self.cnt_net_, self.N_)])
- def testWhileLoopWithNet(self):
- self.WhileLoopTest(self.cnt_net_)
- self.WhileLoopTest([self.cnt_net_, self.idle_net_])
- def testWhileLoopWithStep(self):
- step = control.Do('count', self.cnt_net_)
- self.WhileLoopTest(step)
- self.WhileLoopTest([step, self.idle_net_])
- def UntilLoopTest(self, nets_or_steps):
- step = control.Until('myUntil', self.not_cond_net_, nets_or_steps)
- self.BuildAndRunPlan(step)
- self.CheckNetOutput([(self.cnt_net_, self.N_)])
- def testUntilLoopWithNet(self):
- self.UntilLoopTest(self.cnt_net_)
- self.UntilLoopTest([self.cnt_net_, self.idle_net_])
- def testUntilLoopWithStep(self):
- step = control.Do('count', self.cnt_net_)
- self.UntilLoopTest(step)
- self.UntilLoopTest([step, self.idle_net_])
- def DoWhileLoopTest(self, nets_or_steps):
- step = control.DoWhile('myDoWhile', self.cond_net_, nets_or_steps)
- self.BuildAndRunPlan(step)
- self.CheckNetOutput([(self.cnt_net_, self.N_)])
- def testDoWhileLoopWithNet(self):
- self.DoWhileLoopTest(self.cnt_net_)
- self.DoWhileLoopTest([self.idle_net_, self.cnt_net_])
- def testDoWhileLoopWithStep(self):
- step = control.Do('count', self.cnt_net_)
- self.DoWhileLoopTest(step)
- self.DoWhileLoopTest([self.idle_net_, step])
- def DoUntilLoopTest(self, nets_or_steps):
- step = control.DoUntil('myDoUntil', self.not_cond_net_, nets_or_steps)
- self.BuildAndRunPlan(step)
- self.CheckNetOutput([(self.cnt_net_, self.N_)])
- def testDoUntilLoopWithNet(self):
- self.DoUntilLoopTest(self.cnt_net_)
- self.DoUntilLoopTest([self.cnt_net_, self.idle_net_])
- def testDoUntilLoopWithStep(self):
- step = control.Do('count', self.cnt_net_)
- self.DoUntilLoopTest(step)
- self.DoUntilLoopTest([self.idle_net_, step])
- def IfCondTest(self, cond_net, expect, cond_on_blob):
- if cond_on_blob:
- step = control.Do(
- 'if-all',
- control.Do('count', cond_net),
- control.If('myIf', cond_net.Proto().external_output[-1],
- self.cnt_net_))
- else:
- step = control.If('myIf', cond_net, self.cnt_net_)
- self.BuildAndRunPlan(step)
- self.CheckNetOutput([(self.cnt_net_, expect)])
- def testIfCondTrueOnNet(self):
- self.IfCondTest(self.true_cond_net_, 1, False)
- def testIfCondTrueOnBlob(self):
- self.IfCondTest(self.true_cond_net_, 1, True)
- def testIfCondFalseOnNet(self):
- self.IfCondTest(self.false_cond_net_, 0, False)
- def testIfCondFalseOnBlob(self):
- self.IfCondTest(self.false_cond_net_, 0, True)
- def IfElseCondTest(self, cond_net, cond_value, expect, cond_on_blob):
- if cond_value:
- run_net = self.cnt_net_
- else:
- run_net = self.cnt_2_net_
- if cond_on_blob:
- step = control.Do(
- 'if-else-all',
- control.Do('count', cond_net),
- control.If('myIfElse', cond_net.Proto().external_output[-1],
- self.cnt_net_, self.cnt_2_net_))
- else:
- step = control.If('myIfElse', cond_net,
- self.cnt_net_, self.cnt_2_net_)
- self.BuildAndRunPlan(step)
- self.CheckNetOutput([(run_net, expect)])
- def testIfElseCondTrueOnNet(self):
- self.IfElseCondTest(self.true_cond_net_, True, 1, False)
- def testIfElseCondTrueOnBlob(self):
- self.IfElseCondTest(self.true_cond_net_, True, 1, True)
- def testIfElseCondFalseOnNet(self):
- self.IfElseCondTest(self.false_cond_net_, False, 2, False)
- def testIfElseCondFalseOnBlob(self):
- self.IfElseCondTest(self.false_cond_net_, False, 2, True)
- def IfNotCondTest(self, cond_net, expect, cond_on_blob):
- if cond_on_blob:
- step = control.Do(
- 'if-not',
- control.Do('count', cond_net),
- control.IfNot('myIfNot', cond_net.Proto().external_output[-1],
- self.cnt_net_))
- else:
- step = control.IfNot('myIfNot', cond_net, self.cnt_net_)
- self.BuildAndRunPlan(step)
- self.CheckNetOutput([(self.cnt_net_, expect)])
- def testIfNotCondTrueOnNet(self):
- self.IfNotCondTest(self.true_cond_net_, 0, False)
- def testIfNotCondTrueOnBlob(self):
- self.IfNotCondTest(self.true_cond_net_, 0, True)
- def testIfNotCondFalseOnNet(self):
- self.IfNotCondTest(self.false_cond_net_, 1, False)
- def testIfNotCondFalseOnBlob(self):
- self.IfNotCondTest(self.false_cond_net_, 1, True)
- def IfNotElseCondTest(self, cond_net, cond_value, expect, cond_on_blob):
- if cond_value:
- run_net = self.cnt_2_net_
- else:
- run_net = self.cnt_net_
- if cond_on_blob:
- step = control.Do(
- 'if-not-else',
- control.Do('count', cond_net),
- control.IfNot('myIfNotElse',
- cond_net.Proto().external_output[-1],
- self.cnt_net_, self.cnt_2_net_))
- else:
- step = control.IfNot('myIfNotElse', cond_net,
- self.cnt_net_, self.cnt_2_net_)
- self.BuildAndRunPlan(step)
- self.CheckNetOutput([(run_net, expect)])
- def testIfNotElseCondTrueOnNet(self):
- self.IfNotElseCondTest(self.true_cond_net_, True, 2, False)
- def testIfNotElseCondTrueOnBlob(self):
- self.IfNotElseCondTest(self.true_cond_net_, True, 2, True)
- def testIfNotElseCondFalseOnNet(self):
- self.IfNotElseCondTest(self.false_cond_net_, False, 1, False)
- def testIfNotElseCondFalseOnBlob(self):
- self.IfNotElseCondTest(self.false_cond_net_, False, 1, True)
- def testSwitch(self):
- step = control.Switch(
- 'mySwitch',
- (self.false_cond_net_, self.cnt_net_),
- (self.true_cond_net_, self.cnt_2_net_)
- )
- self.BuildAndRunPlan(step)
- self.CheckNetOutput([(self.cnt_net_, 0), (self.cnt_2_net_, 2)])
- def testSwitchNot(self):
- step = control.SwitchNot(
- 'mySwitchNot',
- (self.false_cond_net_, self.cnt_net_),
- (self.true_cond_net_, self.cnt_2_net_)
- )
- self.BuildAndRunPlan(step)
- self.CheckNetOutput([(self.cnt_net_, 1), (self.cnt_2_net_, 0)])
- def testBoolNet(self):
- bool_net = control.BoolNet(('a', True))
- step = control.Do('bool', bool_net)
- self.BuildAndRunPlan(step)
- self.CheckNetAllOutput(bool_net, [True])
- bool_net = control.BoolNet(('a', True), ('b', False))
- step = control.Do('bool', bool_net)
- self.BuildAndRunPlan(step)
- self.CheckNetAllOutput(bool_net, [True, False])
- bool_net = control.BoolNet([('a', True), ('b', False)])
- step = control.Do('bool', bool_net)
- self.BuildAndRunPlan(step)
- self.CheckNetAllOutput(bool_net, [True, False])
- def testCombineConditions(self):
- # combined by 'Or'
- combine_net = control.CombineConditions(
- 'test', [self.true_cond_net_, self.false_cond_net_], 'Or')
- step = control.Do('combine',
- self.true_cond_net_,
- self.false_cond_net_,
- combine_net)
- self.BuildAndRunPlan(step)
- self.CheckNetOutput([(combine_net, True)])
- # combined by 'And'
- combine_net = control.CombineConditions(
- 'test', [self.true_cond_net_, self.false_cond_net_], 'And')
- step = control.Do('combine',
- self.true_cond_net_,
- self.false_cond_net_,
- combine_net)
- self.BuildAndRunPlan(step)
- self.CheckNetOutput([(combine_net, False)])
- def testMergeConditionNets(self):
- # merged by 'Or'
- merge_net = control.MergeConditionNets(
- 'test', [self.true_cond_net_, self.false_cond_net_], 'Or')
- step = control.Do('merge', merge_net)
- self.BuildAndRunPlan(step)
- self.CheckNetOutput([(merge_net, True)])
- # merged by 'And'
- merge_net = control.MergeConditionNets(
- 'test', [self.true_cond_net_, self.false_cond_net_], 'And')
- step = control.Do('merge', merge_net)
- self.BuildAndRunPlan(step)
- self.CheckNetOutput([(merge_net, False)])
|