# transformations_test.py
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
from hypothesis import given
import hypothesis.strategies as st
import numpy as np

from caffe2.python.transformations import Transformer
from caffe2.python import core, workspace
from caffe2.python import test_util as tu
# Module-level Transformer instance; every test in this file applies its
# rewrites (AddNNPACK, FuseNNPACKConvRelu, FuseConvBN) through this object.
transformer = Transformer()
  22. class TestTransformations(tu.TestCase):
  23. def _base_test_net(self):
  24. net = core.Net("net")
  25. net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
  26. return net
  27. def _add_nnpack(self, net):
  28. transformer.AddNNPACK(net)
  29. assert tu.str_compare(net.Proto().op[0].engine, "NNPACK")
  30. def _fuse_nnpack_convrelu(self, net, expected_result_num_ops,
  31. expected_activation_arg=True):
  32. self._add_nnpack(net)
  33. transformer.FuseNNPACKConvRelu(net)
  34. self.assertEquals(tu.numOps(net), expected_result_num_ops)
  35. has_activation_arg = False
  36. for arg in net.Proto().op[0].arg:
  37. if tu.str_compare(arg.name, "activation"):
  38. assert tu.str_compare(arg.s, "Relu")
  39. has_activation_arg = True
  40. if expected_activation_arg:
  41. assert has_activation_arg
  42. else:
  43. assert not has_activation_arg
  44. def test_transformer_AddNNPACK(self):
  45. net = self._base_test_net()
  46. net.Relu(["Y"], ["Y2"])
  47. self._add_nnpack(net)
  48. def test_transformer_FuseNNPACKConvRelu(self):
  49. net = self._base_test_net()
  50. net.Relu(["Y"], ["Y2"])
  51. self._fuse_nnpack_convrelu(net, 1)
  52. def test_noFuseNNPACKConvRelu(self):
  53. net = self._base_test_net()
  54. net.Relu(["Y"], ["Y2"])
  55. net.Relu(["Y"], ["Y3"])
  56. self._fuse_nnpack_convrelu(net, 3, expected_activation_arg=False)
  57. def test_transformer_FuseNNPACKConvReluNoInplace(self):
  58. net = self._base_test_net()
  59. net.Relu(["Y"], ["X"])
  60. self._fuse_nnpack_convrelu(net, 1)
  61. assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
  62. def test_transformer_FuseNNPACKConvReluInplaceRelu(self):
  63. net = self._base_test_net()
  64. net.Relu(["Y"], ["Y"])
  65. self._fuse_nnpack_convrelu(net, 1)
  66. assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
  67. def test_transformer_FuseNNPACKConvReluPingPongNaming(self):
  68. net = self._base_test_net()
  69. net.Relu(["Y"], ["X"])
  70. net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
  71. self._fuse_nnpack_convrelu(net, 2)
  72. assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
  73. assert net.Proto().op[1].output[0] != net.Proto().op[1].input[0]
  74. def test_transformer_FuseNNPACKConvReluFollowedByMultipleInputOp(self):
  75. net = self._base_test_net()
  76. net.Relu(["Y"], ["Y2"])
  77. net.Conv(["Y2", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
  78. net.Relu(["Y"], ["Y2"])
  79. self._fuse_nnpack_convrelu(net, 2)
  80. assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
  81. assert net.Proto().op[1].output[0] != net.Proto().op[1].input[0]
  82. def test_transformer_FuseNNPACKConvReluInplaceFollowedByMultipleInputOp(self):
  83. net = self._base_test_net()
  84. net.Relu(["Y"], ["Y"])
  85. net.Conv(["Y", "w", "b"], ["Y2"], stride=1, pad=0, kernel=3, order="NCHW")
  86. net.Relu(["Y2"], ["Y2"])
  87. self._fuse_nnpack_convrelu(net, 2)
  88. assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
  89. assert net.Proto().op[1].output[0] != net.Proto().op[1].input[0]
  90. @given(
  91. size=st.integers(7, 10),
  92. input_channels=st.integers(1, 10),
  93. seed=st.integers(0, 65535),
  94. order=st.sampled_from(["NCHW", "NHWC"]),
  95. epsilon=st.floats(min_value=1e-5, max_value=1e-2),
  96. )
  97. def test_transformer_FuseConvBN(self, size, input_channels, seed, order, epsilon):
  98. workspace.ResetWorkspace()
  99. net = core.Net("net")
  100. c = input_channels
  101. h = size
  102. w = size
  103. k = 3
  104. net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=k, order=order)
  105. net.SpatialBN(
  106. ["Y", "scale", "bias", "mean", "var"],
  107. ["Y2"],
  108. is_test=True,
  109. order=order,
  110. epsilon=epsilon,
  111. )
  112. np.random.seed(seed)
  113. if order == "NCHW":
  114. tu.randBlobFloat32("X", 1, c, h, w)
  115. tu.randBlobFloat32("w", c, c, k, k)
  116. else:
  117. tu.randBlobFloat32("X", 1, h, w, c)
  118. tu.randBlobFloat32("w", c, k, k, c)
  119. tu.randBlobsFloat32(["b", "scale", "bias", "mean"], c)
  120. # This is necessary because 1/sqrt(var) is used and if var is too small
  121. # we get floating point artifacts that cause test failures
  122. tu.randBlobFloat32("var", c, offset=0.5)
  123. workspace.RunNetOnce(net)
  124. preTransformOutput = workspace.FetchBlob("Y2").flatten()
  125. workspace.FeedBlob("Y2", np.zeros((1, 1)))
  126. transformer.FuseConvBN(net)
  127. # Ensure fusion
  128. assert tu.numOps(net) == 1
  129. workspace.RunNetOnce(net)
  130. postTransformOutput = workspace.FetchBlob("Y2").flatten()
  131. # Check that there is no numerical difference
  132. assert np.allclose(
  133. preTransformOutput,
  134. postTransformOutput,
  135. rtol=5e-02,
  136. atol=1e-03
  137. )
  138. @given(
  139. size=st.integers(7, 10),
  140. input_channels=st.integers(1, 10),
  141. seed=st.integers(0, 65535),
  142. order=st.sampled_from(["NCHW", "NHWC"]),
  143. epsilon=st.floats(min_value=1e-5, max_value=1e-2),
  144. )
  145. def test_transformer_FuseConvBNNoConvBias(self, size, input_channels, seed, order, epsilon):
  146. workspace.ResetWorkspace()
  147. net = core.Net("net")
  148. c = input_channels
  149. h = size
  150. w = size
  151. k = 3
  152. net.Conv(["X", "w"], ["Y"], stride=1, pad=0, kernel=k, order=order)
  153. net.SpatialBN(
  154. ["Y", "scale", "bias", "mean", "var"],
  155. ["Y2"],
  156. is_test=True,
  157. order=order,
  158. epsilon=epsilon,
  159. )
  160. np.random.seed(seed)
  161. if order == "NCHW":
  162. tu.randBlobFloat32("X", 1, c, h, w)
  163. tu.randBlobFloat32("w", c, c, k, k)
  164. else:
  165. tu.randBlobFloat32("X", 1, h, w, c)
  166. tu.randBlobFloat32("w", c, k, k, c)
  167. tu.randBlobsFloat32(["scale", "bias", "mean"], c)
  168. # This is necessary because 1/sqrt(var) is used and if var is too small
  169. # we get floating point artifacts that cause test failures
  170. tu.randBlobFloat32("var", c, offset=0.5)
  171. workspace.RunNetOnce(net)
  172. preTransformOutput = workspace.FetchBlob("Y2").flatten()
  173. workspace.FeedBlob("Y2", np.zeros((1, 1)))
  174. transformer.FuseConvBN(net)
  175. # Ensure fusion
  176. assert tu.numOps(net) == 1
  177. workspace.RunNetOnce(net)
  178. postTransformOutput = workspace.FetchBlob("Y2").flatten()
  179. # Check that there is no numerical difference
  180. assert np.allclose(
  181. preTransformOutput,
  182. postTransformOutput,
  183. rtol=5e-02,
  184. atol=1e-03
  185. )
  186. @given(
  187. size=st.integers(7, 10),
  188. input_channels=st.integers(1, 10),
  189. seed=st.integers(0, 65535),
  190. order=st.sampled_from(["NCHW", "NHWC"]),
  191. epsilon=st.floats(min_value=1e-5, max_value=1e-2),
  192. )
  193. def test_transformer_FuseConvBNNoConvBiasDuplicatedName(self, size, input_channels, seed, order, epsilon):
  194. workspace.ResetWorkspace()
  195. net = core.Net("net")
  196. c = input_channels
  197. h = size
  198. w = size
  199. k = 3
  200. net.Conv(["X", "w"], ["Y"], stride=1, pad=0, kernel=k, order=order)
  201. net.SpatialBN(
  202. ["Y", "scale", "_bias0", "mean", "var"],
  203. ["Y2"],
  204. is_test=True,
  205. order=order,
  206. epsilon=epsilon,
  207. )
  208. np.random.seed(seed)
  209. if order == "NCHW":
  210. tu.randBlobFloat32("X", 1, c, h, w)
  211. tu.randBlobFloat32("w", c, c, k, k)
  212. else:
  213. tu.randBlobFloat32("X", 1, h, w, c)
  214. tu.randBlobFloat32("w", c, k, k, c)
  215. tu.randBlobsFloat32(["scale", "_bias0", "mean"], c)
  216. # This is necessary because 1/sqrt(var) is used and if var is too small
  217. # we get floating point artifacts that cause test failures
  218. tu.randBlobFloat32("var", c, offset=0.5)
  219. workspace.RunNetOnce(net)
  220. preTransformOutput = workspace.FetchBlob("Y2").flatten()
  221. workspace.FeedBlob("Y2", np.zeros((1, 1)))
  222. transformer.FuseConvBN(net)
  223. # Ensure fusion
  224. assert tu.numOps(net) == 1
  225. workspace.RunNetOnce(net)
  226. postTransformOutput = workspace.FetchBlob("Y2").flatten()
  227. print("pre")
  228. print(preTransformOutput)
  229. print("after")
  230. print(postTransformOutput)
  231. # Check that there is no numerical difference
  232. assert np.allclose(
  233. preTransformOutput,
  234. postTransformOutput,
  235. rtol=5e-02,
  236. atol=1e-03
  237. )
  238. @given(
  239. size=st.integers(7, 10),
  240. input_channels=st.integers(1, 10),
  241. kt=st.integers(3, 5),
  242. kh=st.integers(3, 5),
  243. kw=st.integers(3, 5),
  244. seed=st.integers(0, 65535),
  245. epsilon=st.floats(min_value=1e-5, max_value=1e-2),
  246. )
  247. def test_transformer_FuseConv3DBN(
  248. self, size, input_channels, kt, kh, kw, seed, epsilon
  249. ):
  250. workspace.ResetWorkspace()
  251. net = core.Net("net")
  252. c = input_channels
  253. t = size
  254. h = size
  255. w = size
  256. net.Conv(
  257. ["X", "w", "b"],
  258. ["Y"],
  259. kernels=[kt, kh, kw],
  260. )
  261. net.SpatialBN(
  262. ["Y", "scale", "bias", "mean", "var"],
  263. ["Y2"],
  264. is_test=True,
  265. epsilon=epsilon,
  266. )
  267. np.random.seed(seed)
  268. tu.randBlobFloat32("X", 1, c, t, h, w)
  269. tu.randBlobFloat32("w", c, c, kt, kh, kw)
  270. tu.randBlobsFloat32(["b", "scale", "bias", "mean"], c)
  271. # This is necessary because 1/sqrt(var) is used and if var is too small
  272. # we get floating point artifacts that cause test failures
  273. tu.randBlobFloat32("var", c, offset=0.5)
  274. workspace.RunNetOnce(net)
  275. preTransformOutput = workspace.FetchBlob("Y2").flatten()
  276. workspace.FeedBlob("Y2", np.zeros((1, 1)))
  277. transformer.FuseConvBN(net)
  278. # Ensure fusion
  279. assert tu.numOps(net) == 1
  280. workspace.RunNetOnce(net)
  281. postTransformOutput = workspace.FetchBlob("Y2").flatten()
  282. # Check that there is no numerical difference
  283. assert np.allclose(
  284. preTransformOutput,
  285. postTransformOutput,
  286. rtol=1e-02,
  287. atol=1e-04
  288. )
  289. def test_converterDontEnforceUnusedInputs(self):
  290. net = core.Net("net")
  291. net.Relu(["X"], ["Y"])
  292. net.Proto().external_input.extend(["fake"])
  293. # This should now work
  294. transformer.AddNNPACK(net) # just testing the converter
  295. def test_converterDontEnforceUnusedOutputs(self):
  296. net = core.Net("net")
  297. net.Relu(["X"], ["Y"])
  298. net.Proto().external_output.extend(["fake"])
  299. transformer.AddNNPACK(net) # just testing the converter