brew_test.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. from caffe2.python import brew, core, scope, workspace
  2. from caffe2.python.modeling.parameter_info import ParameterTags
  3. from caffe2.python.model_helper import ModelHelper
  4. from caffe2.python.cnn import CNNModelHelper
  5. import unittest
  6. import numpy as np
  7. class BrewTest(unittest.TestCase):
  8. def setUp(self):
  9. def myhelper(model, val=-1):
  10. return val
  11. if not brew.has_helper(myhelper):
  12. brew.Register(myhelper)
  13. self.myhelper = myhelper
  14. def myhelper2(model, val=-1):
  15. return val
  16. if not brew.has_helper(myhelper2):
  17. brew.Register(myhelper2)
  18. self.myhelper2 = myhelper2
  19. self.model = ModelHelper(name="test_model")
  20. def test_dropout(self):
  21. p = 0.2
  22. X = np.ones((100, 100)).astype(np.float32) - p
  23. workspace.FeedBlob("x", X)
  24. model = ModelHelper(name="test_model")
  25. brew.dropout(model, "x", "out", is_test=False)
  26. workspace.RunNetOnce(model.param_init_net)
  27. workspace.RunNetOnce(model.net)
  28. out = workspace.FetchBlob("out")
  29. self.assertLess(abs(out.mean() - (1 - p)), 0.05)
  30. def test_fc(self):
  31. m, n, k = (15, 15, 15)
  32. X = np.random.rand(m, k).astype(np.float32) - 0.5
  33. workspace.FeedBlob("x", X)
  34. model = ModelHelper(name="test_model")
  35. brew.fc(model, "x", "out_1", k, n)
  36. model.Validate()
  37. workspace.RunNetOnce(model.param_init_net)
  38. workspace.RunNetOnce(model.net)
  39. def test_relu(self):
  40. Xpos = np.ones((5, 5)).astype(np.float32) - 0.5
  41. Xneg = np.ones((5, 5)).astype(np.float32) - 1.5
  42. workspace.FeedBlob("xpos", Xpos)
  43. workspace.FeedBlob("xneg", Xneg)
  44. model = ModelHelper(name="test_model")
  45. brew.relu(model, "xpos", "out_xpos")
  46. brew.relu(model, "xneg", "out_xneg")
  47. model.Validate()
  48. workspace.RunNetOnce(model.param_init_net)
  49. workspace.RunNetOnce(model.net)
  50. pos = workspace.FetchBlob("out_xpos")
  51. self.assertAlmostEqual(pos.mean(), 0.5)
  52. neg = workspace.FetchBlob("out_xneg")
  53. self.assertAlmostEqual(neg.mean(), 0)
  54. def test_tanh(self):
  55. X = np.ones((5, 5)).astype(np.float32) - 0.5
  56. workspace.FeedBlob("x", X)
  57. model = ModelHelper(name="test_model")
  58. brew.tanh(model, "x", "out_tanh")
  59. model.Validate()
  60. workspace.RunNetOnce(model.param_init_net)
  61. workspace.RunNetOnce(model.net)
  62. out = workspace.FetchBlob("out_tanh")
  63. self.assertAlmostEqual(out.mean(), np.tanh(0.5), places=5)
  64. def test_validate(self):
  65. model = ModelHelper(name="test_model")
  66. model.params.append("aaa")
  67. model.params.append("bbb")
  68. self.assertEqual(model._Validate(), [])
  69. model.params.append("xxx")
  70. model.params.append("bbb")
  71. self.assertEqual(model._Validate(), ["bbb"])
  72. def test_arg_scope(self):
  73. myhelper = self.myhelper
  74. myhelper2 = self.myhelper2
  75. n = 15
  76. with brew.arg_scope([myhelper], val=n):
  77. res = brew.myhelper(self.model)
  78. self.assertEqual(n, res)
  79. with brew.arg_scope([myhelper, myhelper2], val=n):
  80. res1 = brew.myhelper(self.model)
  81. res2 = brew.myhelper2(self.model)
  82. self.assertEqual([n, n], [res1, res2])
  83. def test_arg_scope_single(self):
  84. X = np.random.rand(64, 3, 32, 32).astype(np.float32) - 0.5
  85. workspace.FeedBlob("x", X)
  86. model = ModelHelper(name="test_model")
  87. with brew.arg_scope(
  88. brew.conv,
  89. stride=2,
  90. pad=2,
  91. weight_init=('XavierFill', {}),
  92. bias_init=('ConstantFill', {})
  93. ):
  94. brew.conv(
  95. model=model,
  96. blob_in="x",
  97. blob_out="out",
  98. dim_in=3,
  99. dim_out=64,
  100. kernel=3,
  101. )
  102. model.Validate()
  103. workspace.RunNetOnce(model.param_init_net)
  104. workspace.RunNetOnce(model.net)
  105. out = workspace.FetchBlob("out")
  106. self.assertEqual(out.shape, (64, 64, 17, 17))
  107. def test_arg_scope_nested(self):
  108. myhelper = self.myhelper
  109. n = 16
  110. with brew.arg_scope([myhelper], val=-3), \
  111. brew.arg_scope([myhelper], val=-2):
  112. with brew.arg_scope([myhelper], val=n):
  113. res = brew.myhelper(self.model)
  114. self.assertEqual(n, res)
  115. res = brew.myhelper(self.model)
  116. self.assertEqual(res, -2)
  117. res = brew.myhelper(self.model, val=15)
  118. self.model.Validate()
  119. self.assertEqual(res, 15)
  120. def test_double_register(self):
  121. myhelper = self.myhelper
  122. with self.assertRaises(AttributeError):
  123. brew.Register(myhelper)
  124. def test_has_helper(self):
  125. self.assertTrue(brew.has_helper(brew.conv))
  126. self.assertTrue(brew.has_helper("conv"))
  127. def myhelper3():
  128. pass
  129. self.assertFalse(brew.has_helper(myhelper3))
  130. def test_model_helper(self):
  131. X = np.random.rand(64, 32, 32, 3).astype(np.float32) - 0.5
  132. workspace.FeedBlob("x", X)
  133. my_arg_scope = {'order': 'NHWC'}
  134. model = ModelHelper(name="test_model", arg_scope=my_arg_scope)
  135. with brew.arg_scope(
  136. brew.conv,
  137. stride=2,
  138. pad=2,
  139. weight_init=('XavierFill', {}),
  140. bias_init=('ConstantFill', {})
  141. ):
  142. brew.conv(
  143. model=model,
  144. blob_in="x",
  145. blob_out="out",
  146. dim_in=3,
  147. dim_out=64,
  148. kernel=[8, 3]
  149. )
  150. model.Validate()
  151. workspace.RunNetOnce(model.param_init_net)
  152. workspace.RunNetOnce(model.net)
  153. out = workspace.FetchBlob("out")
  154. self.assertEqual(out.shape, (64, 15, 17, 64))
  155. def test_cnn_model_helper_deprecated(self):
  156. X = np.random.rand(64, 32, 32, 3).astype(np.float32) - 0.5
  157. workspace.FeedBlob("x", X)
  158. # CNNModelHelper is going to be deprecated soon. This test is only
  159. # covering some CNNModelHelper logic
  160. model = CNNModelHelper(name="test_model", order='NHWC')
  161. self.assertEqual(model.arg_scope['order'], 'NHWC')
  162. def test_get_params(self):
  163. def param(x):
  164. return core.ScopedBlobReference(x)
  165. def to_str_list(x):
  166. return sorted([str(p) for p in x])
  167. model = ModelHelper(name="test_model")
  168. model.AddParameter(param("a"))
  169. model.AddParameter(param("b"), tags=ParameterTags.COMPUTED_PARAM)
  170. with scope.NameScope("c"):
  171. model.AddParameter(param("a"))
  172. model.AddParameter(param("d"), tags=ParameterTags.COMPUTED_PARAM)
  173. self.assertEqual(to_str_list(model.GetParams()), ['c/a'])
  174. self.assertEqual(to_str_list(model.GetComputedParams()), ['c/d'])
  175. self.assertEqual(to_str_list(model.GetAllParams()), ['c/a', 'c/d'])
  176. # Get AllParams from the global Scope
  177. self.assertEqual(to_str_list(model.GetAllParams('')), [
  178. 'a', 'b', 'c/a', 'c/d'])
  179. self.assertEqual(to_str_list(model.GetParams()), ['a', 'c/a'])
  180. self.assertEqual(to_str_list(model.GetComputedParams()), ['b', 'c/d'])
  181. self.assertEqual(to_str_list(model.GetAllParams()),
  182. ['a', 'b', 'c/a', 'c/d'])
  183. self.assertEqual(to_str_list(model.GetAllParams('')),
  184. ['a', 'b', 'c/a', 'c/d'])
  185. # Get AllParams from the scope 'c'
  186. self.assertEqual(to_str_list(model.GetAllParams('c')), ['c/a', 'c/d'])
  187. self.assertEqual(to_str_list(model.GetAllParams('c/')), ['c/a', 'c/d'])
  188. def test_param_consistence(self):
  189. model = ModelHelper(name='test_mode')
  190. cnv = brew.conv(model, 'data', 'cnv', 32, 32, 4)
  191. step_model = ModelHelper(name='step_model', param_model=model)
  192. a = brew.fc(step_model, cnv, 'a', 100, 200)
  193. brew.fc(model, a, 'b', 200, 5)
  194. # test the _parameters_info is shared between model and step_model
  195. self.assertEqual(model._parameters_info, step_model._parameters_info)
  196. def test_cond(self):
  197. workspace.FeedBlob("cond", np.array(True))
  198. workspace.FeedBlob("then_value", np.array(1))
  199. workspace.FeedBlob("else_value", np.array(2))
  200. then_model = ModelHelper(name="then_test_model")
  201. then_model.net.Copy("then_value", "output_blob")
  202. else_model = ModelHelper(name="else_test_model")
  203. else_model.net.Copy("else_value", "output_blob")
  204. model = ModelHelper(name="test_model")
  205. brew.cond(
  206. model=model,
  207. cond_blob="cond",
  208. external_blobs=["then_value", "else_value", "output_blob"],
  209. then_model=then_model,
  210. else_model=else_model)
  211. workspace.RunNetOnce(model.param_init_net)
  212. workspace.RunNetOnce(model.net)
  213. output_value = workspace.FetchBlob("output_blob")
  214. self.assertEqual(output_value, 1)
  215. workspace.FeedBlob("cond", np.array(False))
  216. workspace.RunNetOnce(model.param_init_net)
  217. workspace.RunNetOnce(model.net)
  218. output_value = workspace.FetchBlob("output_blob")
  219. self.assertEqual(output_value, 2)
  220. def test_loop(self):
  221. workspace.FeedBlob("cond", np.array(True))
  222. workspace.FeedBlob("ONE", np.array(1))
  223. workspace.FeedBlob("TWO", np.array(2))
  224. workspace.FeedBlob("TEN", np.array(10))
  225. workspace.FeedBlob("counter", np.array(0))
  226. workspace.FeedBlob("output_blob", np.array(0))
  227. loop_model = ModelHelper(name="loop_test_model")
  228. loop_model.net.Add(["output_blob", "TWO"], "output_blob")
  229. cond_model = ModelHelper(name="cond_test_model")
  230. cond_model.net.Add(["counter", "ONE"], "counter")
  231. comp_res = cond_model.net.LT(["counter", "TEN"])
  232. cond_model.net.Copy(comp_res, "cond")
  233. model = ModelHelper(name="test_model")
  234. brew.loop(
  235. model=model,
  236. cond_blob="cond",
  237. external_blobs=["cond", "ONE", "TWO", "TEN", "counter", "output_blob"],
  238. loop_model=loop_model,
  239. cond_model=cond_model)
  240. workspace.RunNetOnce(model.param_init_net)
  241. workspace.RunNetOnce(model.net)
  242. output_value = workspace.FetchBlob("output_blob")
  243. self.assertEqual(output_value, 18)
  244. @unittest.skipIf(not workspace.has_gpu_support, "No gpu support.")
  245. class BrewGPUTest(unittest.TestCase):
  246. def test_relu(self):
  247. Xpos = np.ones((5, 5)).astype(np.float32) - 0.5
  248. Xneg = np.ones((5, 5)).astype(np.float32) - 1.5
  249. workspace.FeedBlob("xpos", Xpos)
  250. workspace.FeedBlob("xneg", Xneg)
  251. model = ModelHelper(name="test_model")
  252. brew.relu(model, "xpos", "out_xpos", use_cudnn=True)
  253. brew.relu(model, "xneg", "out_xneg", use_cudnn=True)
  254. model.Validate()
  255. workspace.RunNetOnce(model.param_init_net)
  256. workspace.RunNetOnce(model.net)
  257. pos = workspace.FetchBlob("out_xpos")
  258. self.assertAlmostEqual(pos.mean(), 0.5)
  259. neg = workspace.FetchBlob("out_xneg")
  260. self.assertAlmostEqual(neg.mean(), 0)
  261. def test_tanh(self):
  262. X = np.ones((5, 5)).astype(np.float32) - 0.5
  263. workspace.FeedBlob("x", X)
  264. model = ModelHelper(name="test_model")
  265. brew.tanh(model, "x", "out_tanh", use_cudnn=True)
  266. model.Validate()
  267. workspace.RunNetOnce(model.param_init_net)
  268. workspace.RunNetOnce(model.net)
  269. out = workspace.FetchBlob("out_tanh")
  270. self.assertAlmostEqual(out.mean(), np.tanh(0.5), places=5)