# net_builder_test.py (11 KB)
  1. from caffe2.python import workspace
  2. from caffe2.python.core import Plan, to_execution_step, Net
  3. from caffe2.python.task import Task, TaskGroup, final_output
  4. from caffe2.python.net_builder import ops, NetBuilder
  5. from caffe2.python.session import LocalSession
  6. import unittest
  7. import threading
  8. class PythonOpStats(object):
  9. lock = threading.Lock()
  10. num_instances = 0
  11. num_calls = 0
  12. def python_op_builder():
  13. PythonOpStats.lock.acquire()
  14. PythonOpStats.num_instances += 1
  15. PythonOpStats.lock.release()
  16. def my_op(inputs, outputs):
  17. PythonOpStats.lock.acquire()
  18. PythonOpStats.num_calls += 1
  19. PythonOpStats.lock.release()
  20. return my_op
  21. def _test_loop():
  22. x = ops.Const(5)
  23. y = ops.Const(0)
  24. with ops.loop():
  25. ops.stop_if(ops.EQ([x, ops.Const(0)]))
  26. ops.Add([x, ops.Const(-1)], [x])
  27. ops.Add([y, ops.Const(1)], [y])
  28. return y
  29. def _test_inner_stop(x):
  30. ops.stop_if(ops.LT([x, ops.Const(5)]))
  31. def _test_outer():
  32. x = ops.Const(10)
  33. # test stop_if(False)
  34. with ops.stop_guard() as g1:
  35. _test_inner_stop(x)
  36. # test stop_if(True)
  37. y = ops.Const(3)
  38. with ops.stop_guard() as g2:
  39. _test_inner_stop(y)
  40. # test no stop
  41. with ops.stop_guard() as g4:
  42. ops.Const(0)
  43. # test empty clause
  44. with ops.stop_guard() as g3:
  45. pass
  46. return (
  47. g1.has_stopped(), g2.has_stopped(), g3.has_stopped(), g4.has_stopped())
  48. def _test_if(x):
  49. y = ops.Const(1)
  50. with ops.If(ops.GT([x, ops.Const(50)])):
  51. ops.Const(2, blob_out=y)
  52. with ops.If(ops.LT([x, ops.Const(50)])):
  53. ops.Const(3, blob_out=y)
  54. ops.stop()
  55. ops.Const(4, blob_out=y)
  56. return y
  57. class TestNetBuilder(unittest.TestCase):
  58. def test_ops(self):
  59. with NetBuilder() as nb:
  60. y = _test_loop()
  61. z, w, a, b = _test_outer()
  62. p = _test_if(ops.Const(75))
  63. q = _test_if(ops.Const(25))
  64. plan = Plan('name')
  65. plan.AddStep(to_execution_step(nb))
  66. ws = workspace.C.Workspace()
  67. ws.run(plan)
  68. expected = [
  69. (y, 5),
  70. (z, False),
  71. (w, True),
  72. (a, False),
  73. (b, False),
  74. (p, 2),
  75. (q, 3),
  76. ]
  77. for b, expected in expected:
  78. actual = ws.blobs[str(b)].fetch()
  79. self.assertEquals(actual, expected)
  80. def _expected_loop(self):
  81. total = 0
  82. total_large = 0
  83. total_small = 0
  84. total_tiny = 0
  85. for loop_iter in range(10):
  86. outer = loop_iter * 10
  87. for inner_iter in range(loop_iter):
  88. val = outer + inner_iter
  89. if val >= 80:
  90. total_large += val
  91. elif val >= 50:
  92. total_small += val
  93. else:
  94. total_tiny += val
  95. total += val
  96. return total, total_large, total_small, total_tiny
  97. def _actual_loop(self):
  98. total = ops.Const(0)
  99. total_large = ops.Const(0)
  100. total_small = ops.Const(0)
  101. total_tiny = ops.Const(0)
  102. with ops.loop(10) as loop:
  103. outer = ops.Mul([loop.iter(), ops.Const(10)])
  104. with ops.loop(loop.iter()) as inner:
  105. val = ops.Add([outer, inner.iter()])
  106. with ops.If(ops.GE([val, ops.Const(80)])) as c:
  107. ops.Add([total_large, val], [total_large])
  108. with c.Elif(ops.GE([val, ops.Const(50)])) as c:
  109. ops.Add([total_small, val], [total_small])
  110. with c.Else():
  111. ops.Add([total_tiny, val], [total_tiny])
  112. ops.Add([total, val], total)
  113. return [
  114. final_output(x)
  115. for x in [total, total_large, total_small, total_tiny]
  116. ]
  117. def test_net_multi_use(self):
  118. with Task() as task:
  119. total = ops.Const(0)
  120. net = Net('my_net')
  121. net.Add([total, net.Const(1)], [total])
  122. ops.net(net)
  123. ops.net(net)
  124. result = final_output(total)
  125. with LocalSession() as session:
  126. session.run(task)
  127. self.assertEquals(2, result.fetch())
  128. def test_loops(self):
  129. with Task() as task:
  130. out_actual = self._actual_loop()
  131. with LocalSession() as session:
  132. session.run(task)
  133. expected = self._expected_loop()
  134. actual = [o.fetch() for o in out_actual]
  135. for e, a in zip(expected, actual):
  136. self.assertEquals(e, a)
  137. def test_setup(self):
  138. with Task() as task:
  139. with ops.task_init():
  140. one = ops.Const(1)
  141. two = ops.Add([one, one])
  142. with ops.task_init():
  143. three = ops.Const(3)
  144. accum = ops.Add([two, three])
  145. # here, accum should be 5
  146. with ops.task_exit():
  147. # here, accum should be 6, since this executes after lines below
  148. seven_1 = ops.Add([accum, one])
  149. six = ops.Add([accum, one])
  150. ops.Add([accum, one], [accum])
  151. seven_2 = ops.Add([accum, one])
  152. o6 = final_output(six)
  153. o7_1 = final_output(seven_1)
  154. o7_2 = final_output(seven_2)
  155. with LocalSession() as session:
  156. session.run(task)
  157. self.assertEquals(o6.fetch(), 6)
  158. self.assertEquals(o7_1.fetch(), 7)
  159. self.assertEquals(o7_2.fetch(), 7)
  160. def test_multi_instance_python_op(self):
  161. """
  162. When task instances are created at runtime, C++ concurrently creates
  163. multiple instances of operators in C++, and concurrently destroys them
  164. once the task is finished. This means that the destructor of PythonOp
  165. will be called concurrently, so the GIL must be acquired. This
  166. test exercises this condition.
  167. """
  168. with Task(num_instances=64) as task:
  169. with ops.loop(4):
  170. ops.Python((python_op_builder, [], {}))([], [])
  171. with LocalSession() as session:
  172. PythonOpStats.num_instances = 0
  173. PythonOpStats.num_calls = 0
  174. session.run(task)
  175. self.assertEquals(PythonOpStats.num_instances, 64)
  176. self.assertEquals(PythonOpStats.num_calls, 256)
  177. def test_multi_instance(self):
  178. NUM_INSTANCES = 10
  179. NUM_ITERS = 15
  180. with TaskGroup() as tg:
  181. with Task(num_instances=NUM_INSTANCES):
  182. with ops.task_init():
  183. counter1 = ops.CreateCounter([], ['global_counter'])
  184. counter2 = ops.CreateCounter([], ['global_counter2'])
  185. counter3 = ops.CreateCounter([], ['global_counter3'])
  186. # both task_counter and local_counter should be thread local
  187. with ops.task_instance_init():
  188. task_counter = ops.CreateCounter([], ['task_counter'])
  189. local_counter = ops.CreateCounter([], ['local_counter'])
  190. with ops.loop(NUM_ITERS):
  191. ops.CountUp(counter1)
  192. ops.CountUp(task_counter)
  193. ops.CountUp(local_counter)
  194. # gather sum of squares of local counters to make sure that
  195. # each local counter counted exactly up to NUM_ITERS, and
  196. # that there was no false sharing of counter instances.
  197. with ops.task_instance_exit():
  198. count2 = ops.RetrieveCount(task_counter)
  199. with ops.loop(ops.Mul([count2, count2])):
  200. ops.CountUp(counter2)
  201. # This should have the same effect as the above
  202. count3 = ops.RetrieveCount(local_counter)
  203. with ops.loop(ops.Mul([count3, count3])):
  204. ops.CountUp(counter3)
  205. # The code below will only run once
  206. with ops.task_exit():
  207. total1 = final_output(ops.RetrieveCount(counter1))
  208. total2 = final_output(ops.RetrieveCount(counter2))
  209. total3 = final_output(ops.RetrieveCount(counter3))
  210. with LocalSession() as session:
  211. session.run(tg)
  212. self.assertEquals(total1.fetch(), NUM_INSTANCES * NUM_ITERS)
  213. self.assertEquals(total2.fetch(), NUM_INSTANCES * (NUM_ITERS ** 2))
  214. self.assertEquals(total3.fetch(), NUM_INSTANCES * (NUM_ITERS ** 2))
  215. def test_if_net(self):
  216. with NetBuilder() as nb:
  217. x0 = ops.Const(0)
  218. x1 = ops.Const(1)
  219. x2 = ops.Const(2)
  220. y0 = ops.Const(0)
  221. y1 = ops.Const(1)
  222. y2 = ops.Const(2)
  223. # basic logic
  224. first_res = ops.Const(0)
  225. with ops.IfNet(ops.Const(True)):
  226. ops.Const(1, blob_out=first_res)
  227. with ops.Else():
  228. ops.Const(2, blob_out=first_res)
  229. second_res = ops.Const(0)
  230. with ops.IfNet(ops.Const(False)):
  231. ops.Const(1, blob_out=second_res)
  232. with ops.Else():
  233. ops.Const(2, blob_out=second_res)
  234. # nested and sequential ifs,
  235. # empty then/else,
  236. # passing outer blobs into branches,
  237. # writing into outer blobs, incl. into input blob
  238. # using local blobs
  239. with ops.IfNet(ops.LT([x0, x1])):
  240. local_blob = ops.Const(900)
  241. ops.Add([ops.Const(100), local_blob], [y0])
  242. gt = ops.GT([x1, x2])
  243. with ops.IfNet(gt):
  244. # empty then
  245. pass
  246. with ops.Else():
  247. ops.Add([y1, local_blob], [local_blob])
  248. ops.Add([ops.Const(100), y1], [y1])
  249. with ops.IfNet(ops.EQ([local_blob, ops.Const(901)])):
  250. ops.Const(7, blob_out=y2)
  251. ops.Add([y1, y2], [y2])
  252. with ops.Else():
  253. # empty else
  254. pass
  255. plan = Plan('if_net_test')
  256. plan.AddStep(to_execution_step(nb))
  257. ws = workspace.C.Workspace()
  258. ws.run(plan)
  259. first_res_value = ws.blobs[str(first_res)].fetch()
  260. second_res_value = ws.blobs[str(second_res)].fetch()
  261. y0_value = ws.blobs[str(y0)].fetch()
  262. y1_value = ws.blobs[str(y1)].fetch()
  263. y2_value = ws.blobs[str(y2)].fetch()
  264. self.assertEquals(first_res_value, 1)
  265. self.assertEquals(second_res_value, 2)
  266. self.assertEquals(y0_value, 1000)
  267. self.assertEquals(y1_value, 101)
  268. self.assertEquals(y2_value, 108)
  269. self.assertTrue(str(local_blob) not in ws.blobs)
  270. def test_while_net(self):
  271. with NetBuilder() as nb:
  272. x = ops.Const(0)
  273. y = ops.Const(0)
  274. with ops.WhileNet():
  275. with ops.Condition():
  276. ops.Add([x, ops.Const(1)], [x])
  277. ops.LT([x, ops.Const(7)])
  278. ops.Add([x, y], [y])
  279. plan = Plan('while_net_test')
  280. plan.AddStep(to_execution_step(nb))
  281. ws = workspace.C.Workspace()
  282. ws.run(plan)
  283. x_value = ws.blobs[str(x)].fetch()
  284. y_value = ws.blobs[str(y)].fetch()
  285. self.assertEqual(x_value, 7)
  286. self.assertEqual(y_value, 21)