workspace_test.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939
  1. import errno
  2. import os
  3. import shutil
  4. import tempfile
  5. import unittest
  6. from collections import namedtuple
  7. from typing import List
  8. import caffe2.python.hypothesis_test_util as htu
  9. import hypothesis.strategies as st
  10. import numpy as np
  11. import torch
  12. from torch import Tensor
  13. from caffe2.proto import caffe2_pb2
  14. from caffe2.python import core, test_util, workspace, model_helper, brew
  15. from hypothesis import given, settings
  16. class TestWorkspace(unittest.TestCase):
  17. def setUp(self):
  18. self.net = core.Net("test-net")
  19. self.testblob_ref = self.net.ConstantFill(
  20. [], "testblob", shape=[1, 2, 3, 4], value=1.0
  21. )
  22. workspace.ResetWorkspace()
  23. def testRootFolder(self):
  24. self.assertEqual(workspace.ResetWorkspace(), True)
  25. self.assertEqual(workspace.RootFolder(), ".")
  26. self.assertEqual(workspace.ResetWorkspace("/tmp/caffe-workspace-test"), True)
  27. self.assertEqual(workspace.RootFolder(), "/tmp/caffe-workspace-test")
  28. def testWorkspaceHasBlobWithNonexistingName(self):
  29. self.assertEqual(workspace.HasBlob("non-existing"), False)
  30. def testRunOperatorOnce(self):
  31. self.assertEqual(
  32. workspace.RunOperatorOnce(self.net.Proto().op[0].SerializeToString()), True
  33. )
  34. self.assertEqual(workspace.HasBlob("testblob"), True)
  35. blobs = workspace.Blobs()
  36. self.assertEqual(len(blobs), 1)
  37. self.assertEqual(blobs[0], "testblob")
  38. def testGetOperatorCost(self):
  39. op = core.CreateOperator(
  40. "Conv2D",
  41. ["X", "W"],
  42. ["Y"],
  43. stride_h=1,
  44. stride_w=1,
  45. pad_t=1,
  46. pad_l=1,
  47. pad_b=1,
  48. pad_r=1,
  49. kernel=3,
  50. )
  51. X = np.zeros((1, 8, 8, 8))
  52. W = np.zeros((1, 1, 3, 3))
  53. workspace.FeedBlob("X", X)
  54. workspace.FeedBlob("W", W)
  55. op_cost = workspace.GetOperatorCost(op.SerializeToString(), ["X", "W"])
  56. self.assertTupleEqual(
  57. op_cost,
  58. namedtuple("Cost", ["flops", "bytes_written", "bytes_read"])(
  59. 1152, 256, 4168
  60. ),
  61. )
  62. def testRunNetOnce(self):
  63. self.assertEqual(
  64. workspace.RunNetOnce(self.net.Proto().SerializeToString()), True
  65. )
  66. self.assertEqual(workspace.HasBlob("testblob"), True)
  67. def testCurrentWorkspaceWrapper(self):
  68. self.assertNotIn("testblob", workspace.C.Workspace.current.blobs)
  69. self.assertEqual(
  70. workspace.RunNetOnce(self.net.Proto().SerializeToString()), True
  71. )
  72. self.assertEqual(workspace.HasBlob("testblob"), True)
  73. self.assertIn("testblob", workspace.C.Workspace.current.blobs)
  74. workspace.ResetWorkspace()
  75. self.assertNotIn("testblob", workspace.C.Workspace.current.blobs)
  76. def testRunPlan(self):
  77. plan = core.Plan("test-plan")
  78. plan.AddStep(core.ExecutionStep("test-step", self.net))
  79. self.assertEqual(workspace.RunPlan(plan.Proto().SerializeToString()), True)
  80. self.assertEqual(workspace.HasBlob("testblob"), True)
  81. def testRunPlanInBackground(self):
  82. plan = core.Plan("test-plan")
  83. plan.AddStep(core.ExecutionStep("test-step", self.net))
  84. background_plan = workspace.RunPlanInBackground(plan)
  85. while not background_plan.is_done():
  86. pass
  87. self.assertEqual(background_plan.is_succeeded(), True)
  88. self.assertEqual(workspace.HasBlob("testblob"), True)
  89. def testConstructPlanFromSteps(self):
  90. step = core.ExecutionStep("test-step-as-plan", self.net)
  91. self.assertEqual(workspace.RunPlan(step), True)
  92. self.assertEqual(workspace.HasBlob("testblob"), True)
  93. def testResetWorkspace(self):
  94. self.assertEqual(
  95. workspace.RunNetOnce(self.net.Proto().SerializeToString()), True
  96. )
  97. self.assertEqual(workspace.HasBlob("testblob"), True)
  98. self.assertEqual(workspace.ResetWorkspace(), True)
  99. self.assertEqual(workspace.HasBlob("testblob"), False)
  100. def testTensorAccess(self):
  101. ws = workspace.C.Workspace()
  102. """ test in-place modification """
  103. ws.create_blob("tensor").feed(np.array([1.1, 1.2, 1.3]))
  104. tensor = ws.blobs["tensor"].tensor()
  105. tensor.data[0] = 3.3
  106. val = np.array([3.3, 1.2, 1.3])
  107. np.testing.assert_array_equal(tensor.data, val)
  108. np.testing.assert_array_equal(ws.blobs["tensor"].fetch(), val)
  109. """ test in-place initialization """
  110. tensor.init([2, 3], core.DataType.INT32)
  111. for x in range(2):
  112. for y in range(3):
  113. tensor.data[x, y] = 0
  114. tensor.data[1, 1] = 100
  115. val = np.zeros([2, 3], dtype=np.int32)
  116. val[1, 1] = 100
  117. np.testing.assert_array_equal(tensor.data, val)
  118. np.testing.assert_array_equal(ws.blobs["tensor"].fetch(), val)
  119. """ strings cannot be initialized from python """
  120. with self.assertRaises(RuntimeError):
  121. tensor.init([3, 4], core.DataType.STRING)
  122. """ feed (copy) data into tensor """
  123. val = np.array([[b"abc", b"def"], [b"ghi", b"jkl"]], dtype=np.object)
  124. tensor.feed(val)
  125. self.assertEquals(tensor.data[0, 0], b"abc")
  126. np.testing.assert_array_equal(ws.blobs["tensor"].fetch(), val)
  127. val = np.array([1.1, 10.2])
  128. tensor.feed(val)
  129. val[0] = 5.2
  130. self.assertEquals(tensor.data[0], 1.1)
  131. """ fetch (copy) data from tensor """
  132. val = np.array([1.1, 1.2])
  133. tensor.feed(val)
  134. val2 = tensor.fetch()
  135. tensor.data[0] = 5.2
  136. val3 = tensor.fetch()
  137. np.testing.assert_array_equal(val, val2)
  138. self.assertEquals(val3[0], 5.2)
  139. def testFetchFeedBlob(self):
  140. self.assertEqual(
  141. workspace.RunNetOnce(self.net.Proto().SerializeToString()), True
  142. )
  143. fetched = workspace.FetchBlob("testblob")
  144. # check if fetched is correct.
  145. self.assertEqual(fetched.shape, (1, 2, 3, 4))
  146. np.testing.assert_array_equal(fetched, 1.0)
  147. fetched[:] = 2.0
  148. self.assertEqual(workspace.FeedBlob("testblob", fetched), True)
  149. fetched_again = workspace.FetchBlob("testblob")
  150. self.assertEqual(fetched_again.shape, (1, 2, 3, 4))
  151. np.testing.assert_array_equal(fetched_again, 2.0)
  152. def testFetchFeedBlobViaBlobReference(self):
  153. self.assertEqual(
  154. workspace.RunNetOnce(self.net.Proto().SerializeToString()), True
  155. )
  156. fetched = workspace.FetchBlob(self.testblob_ref)
  157. # check if fetched is correct.
  158. self.assertEqual(fetched.shape, (1, 2, 3, 4))
  159. np.testing.assert_array_equal(fetched, 1.0)
  160. fetched[:] = 2.0
  161. self.assertEqual(workspace.FeedBlob(self.testblob_ref, fetched), True)
  162. fetched_again = workspace.FetchBlob("testblob") # fetch by name now
  163. self.assertEqual(fetched_again.shape, (1, 2, 3, 4))
  164. np.testing.assert_array_equal(fetched_again, 2.0)
  165. def testFetchFeedBlobTypes(self):
  166. for dtype in [
  167. np.float16,
  168. np.float32,
  169. np.float64,
  170. np.bool,
  171. np.int8,
  172. np.int16,
  173. np.int32,
  174. np.int64,
  175. np.uint8,
  176. np.uint16,
  177. ]:
  178. try:
  179. rng = np.iinfo(dtype).max * 2
  180. except ValueError:
  181. rng = 1000
  182. data = ((np.random.rand(2, 3, 4) - 0.5) * rng).astype(dtype)
  183. self.assertEqual(workspace.FeedBlob("testblob_types", data), True)
  184. fetched_back = workspace.FetchBlob("testblob_types")
  185. self.assertEqual(fetched_back.shape, (2, 3, 4))
  186. self.assertEqual(fetched_back.dtype, dtype)
  187. np.testing.assert_array_equal(fetched_back, data)
  188. def testFetchFeedBlobBool(self):
  189. """Special case for bool to ensure coverage of both true and false."""
  190. data = np.zeros((2, 3, 4)).astype(np.bool)
  191. data.flat[::2] = True
  192. self.assertEqual(workspace.FeedBlob("testblob_types", data), True)
  193. fetched_back = workspace.FetchBlob("testblob_types")
  194. self.assertEqual(fetched_back.shape, (2, 3, 4))
  195. self.assertEqual(fetched_back.dtype, np.bool)
  196. np.testing.assert_array_equal(fetched_back, data)
  197. def testGetBlobSizeBytes(self):
  198. for dtype in [
  199. np.float16,
  200. np.float32,
  201. np.float64,
  202. np.bool,
  203. np.int8,
  204. np.int16,
  205. np.int32,
  206. np.int64,
  207. np.uint8,
  208. np.uint16,
  209. ]:
  210. data = np.random.randn(2, 3).astype(dtype)
  211. self.assertTrue(workspace.FeedBlob("testblob_sizeBytes", data), True)
  212. self.assertEqual(
  213. workspace.GetBlobSizeBytes("testblob_sizeBytes"),
  214. 6 * np.dtype(dtype).itemsize,
  215. )
  216. strs1 = np.array([b"Hello World!", b"abcd"])
  217. strs2 = np.array([b"element1", b"element2"])
  218. strs1_len, strs2_len = 0, 0
  219. for str in strs1:
  220. strs1_len += len(str)
  221. for str in strs2:
  222. strs2_len += len(str)
  223. self.assertTrue(workspace.FeedBlob("testblob_str1", strs1), True)
  224. self.assertTrue(workspace.FeedBlob("testblob_str2", strs2), True)
  225. # size of blob "testblob_str1" = size_str1 * meta_.itemsize() + strs1_len
  226. # size of blob "testblob_str2" = size_str2 * meta_.itemsize() + strs2_len
  227. self.assertEqual(
  228. workspace.GetBlobSizeBytes("testblob_str1")
  229. - workspace.GetBlobSizeBytes("testblob_str2"),
  230. strs1_len - strs2_len,
  231. )
  232. def testFetchFeedBlobZeroDim(self):
  233. data = np.empty(shape=(2, 0, 3), dtype=np.float32)
  234. self.assertEqual(workspace.FeedBlob("testblob_empty", data), True)
  235. fetched_back = workspace.FetchBlob("testblob_empty")
  236. self.assertEqual(fetched_back.shape, (2, 0, 3))
  237. self.assertEqual(fetched_back.dtype, np.float32)
  238. def testFetchFeedLongStringTensor(self):
  239. # long strings trigger array of object creation
  240. strs = np.array(
  241. [
  242. b" ".join(10 * [b"long string"]),
  243. b" ".join(128 * [b"very long string"]),
  244. b"small \0\1\2 string",
  245. b"Hello, world! I have special \0 symbols \1!",
  246. ]
  247. )
  248. workspace.FeedBlob("my_str_tensor", strs)
  249. strs2 = workspace.FetchBlob("my_str_tensor")
  250. self.assertEqual(strs.shape, strs2.shape)
  251. for i in range(0, strs.shape[0]):
  252. self.assertEqual(strs[i], strs2[i])
  253. def testFetchFeedShortStringTensor(self):
  254. # small strings trigger NPY_STRING array
  255. strs = np.array([b"elem1", b"elem 2", b"element 3"])
  256. workspace.FeedBlob("my_str_tensor_2", strs)
  257. strs2 = workspace.FetchBlob("my_str_tensor_2")
  258. self.assertEqual(strs.shape, strs2.shape)
  259. for i in range(0, strs.shape[0]):
  260. self.assertEqual(strs[i], strs2[i])
  261. def testFetchFeedPlainString(self):
  262. # this is actual string, not a tensor of strings
  263. s = b"Hello, world! I have special \0 symbols \1!"
  264. workspace.FeedBlob("my_plain_string", s)
  265. s2 = workspace.FetchBlob("my_plain_string")
  266. self.assertEqual(s, s2)
  267. def testFetchBlobs(self):
  268. s1 = b"test1"
  269. s2 = b"test2"
  270. workspace.FeedBlob("s1", s1)
  271. workspace.FeedBlob("s2", s2)
  272. fetch1, fetch2 = workspace.FetchBlobs(["s1", "s2"])
  273. self.assertEquals(s1, fetch1)
  274. self.assertEquals(s2, fetch2)
  275. def testFetchFeedViaBlobDict(self):
  276. self.assertEqual(
  277. workspace.RunNetOnce(self.net.Proto().SerializeToString()), True
  278. )
  279. fetched = workspace.blobs["testblob"]
  280. # check if fetched is correct.
  281. self.assertEqual(fetched.shape, (1, 2, 3, 4))
  282. np.testing.assert_array_equal(fetched, 1.0)
  283. fetched[:] = 2.0
  284. workspace.blobs["testblob"] = fetched
  285. fetched_again = workspace.blobs["testblob"]
  286. self.assertEqual(fetched_again.shape, (1, 2, 3, 4))
  287. np.testing.assert_array_equal(fetched_again, 2.0)
  288. self.assertTrue("testblob" in workspace.blobs)
  289. self.assertFalse("non_existant" in workspace.blobs)
  290. self.assertEqual(len(workspace.blobs), 1)
  291. for key in workspace.blobs:
  292. self.assertEqual(key, "testblob")
  293. def testTorchInterop(self):
  294. workspace.RunOperatorOnce(
  295. core.CreateOperator(
  296. "ConstantFill", [], "foo", shape=(4,), value=2, dtype=10
  297. )
  298. )
  299. t = workspace.FetchTorch("foo")
  300. t.resize_(5)
  301. t[4] = t[2] = 777
  302. np.testing.assert_array_equal(t.numpy(), np.array([2, 2, 777, 2, 777]))
  303. np.testing.assert_array_equal(
  304. workspace.FetchBlob("foo"), np.array([2, 2, 777, 2, 777])
  305. )
  306. z = torch.ones((4,), dtype=torch.int64)
  307. workspace.FeedBlob("bar", z)
  308. workspace.RunOperatorOnce(
  309. core.CreateOperator("Reshape", ["bar"], ["bar", "_"], shape=(2, 2))
  310. )
  311. z[0, 1] = 123
  312. np.testing.assert_array_equal(
  313. workspace.FetchBlob("bar"), np.array([[1, 123], [1, 1]])
  314. )
  315. np.testing.assert_array_equal(z, np.array([[1, 123], [1, 1]]))
  316. class TestMultiWorkspaces(unittest.TestCase):
  317. def setUp(self):
  318. workspace.SwitchWorkspace("default")
  319. workspace.ResetWorkspace()
  320. def testCreateWorkspace(self):
  321. self.net = core.Net("test-net")
  322. self.net.ConstantFill([], "testblob", shape=[1, 2, 3, 4], value=1.0)
  323. self.assertEqual(
  324. workspace.RunNetOnce(self.net.Proto().SerializeToString()), True
  325. )
  326. self.assertEqual(workspace.HasBlob("testblob"), True)
  327. self.assertEqual(workspace.SwitchWorkspace("test", True), None)
  328. self.assertEqual(workspace.HasBlob("testblob"), False)
  329. self.assertEqual(workspace.SwitchWorkspace("default"), None)
  330. self.assertEqual(workspace.HasBlob("testblob"), True)
  331. try:
  332. # The following should raise an error.
  333. workspace.SwitchWorkspace("non-existing")
  334. # so this should never happen.
  335. self.assertEqual(True, False)
  336. except RuntimeError:
  337. pass
  338. workspaces = workspace.Workspaces()
  339. self.assertTrue("default" in workspaces)
  340. self.assertTrue("test" in workspaces)
  341. @unittest.skipIf(not workspace.has_gpu_support, "No gpu support.")
  342. class TestWorkspaceGPU(test_util.TestCase):
  343. def setUp(self):
  344. workspace.ResetWorkspace()
  345. self.net = core.Net("test-net")
  346. self.net.ConstantFill([], "testblob", shape=[1, 2, 3, 4], value=1.0)
  347. self.net.RunAllOnGPU()
  348. def testFetchBlobGPU(self):
  349. self.assertEqual(
  350. workspace.RunNetOnce(self.net.Proto().SerializeToString()), True
  351. )
  352. fetched = workspace.FetchBlob("testblob")
  353. # check if fetched is correct.
  354. self.assertEqual(fetched.shape, (1, 2, 3, 4))
  355. np.testing.assert_array_equal(fetched, 1.0)
  356. fetched[:] = 2.0
  357. self.assertEqual(workspace.FeedBlob("testblob", fetched), True)
  358. fetched_again = workspace.FetchBlob("testblob")
  359. self.assertEqual(fetched_again.shape, (1, 2, 3, 4))
  360. np.testing.assert_array_equal(fetched_again, 2.0)
  361. def testGetGpuPeerAccessPattern(self):
  362. pattern = workspace.GetGpuPeerAccessPattern()
  363. self.assertEqual(type(pattern), np.ndarray)
  364. self.assertEqual(pattern.ndim, 2)
  365. self.assertEqual(pattern.shape[0], pattern.shape[1])
  366. self.assertEqual(pattern.shape[0], workspace.NumGpuDevices())
  367. @unittest.skipIf(
  368. not workspace.has_cuda_support, "Tensor interop doesn't yet work on ROCm"
  369. )
  370. def testTorchInterop(self):
  371. # CUDA has convenient mem stats, let's use them to make sure we didn't
  372. # leak memory
  373. initial_mem = torch.cuda.memory_allocated()
  374. workspace.RunOperatorOnce(
  375. core.CreateOperator(
  376. "ConstantFill",
  377. [],
  378. "foo",
  379. shape=(4,),
  380. value=2,
  381. dtype=10,
  382. device_option=core.DeviceOption(workspace.GpuDeviceType),
  383. )
  384. )
  385. t = workspace.FetchTorch("foo")
  386. t.resize_(5)
  387. self.assertTrue(t.is_cuda)
  388. t[4] = t[2] = 777
  389. np.testing.assert_array_equal(t.cpu().numpy(), np.array([2, 2, 777, 2, 777]))
  390. np.testing.assert_array_equal(
  391. workspace.FetchBlob("foo"), np.array([2, 2, 777, 2, 777])
  392. )
  393. z = torch.ones((4,), dtype=torch.int64, device="cuda")
  394. workspace.FeedBlob("bar", z)
  395. workspace.RunOperatorOnce(
  396. core.CreateOperator(
  397. "Reshape",
  398. ["bar"],
  399. ["bar", "_"],
  400. shape=(2, 2),
  401. device_option=core.DeviceOption(workspace.GpuDeviceType),
  402. )
  403. )
  404. z[0, 1] = 123
  405. np.testing.assert_array_equal(
  406. workspace.FetchBlob("bar"), np.array([[1, 123], [1, 1]])
  407. )
  408. np.testing.assert_array_equal(z.cpu(), np.array([[1, 123], [1, 1]]))
  409. self.assertGreater(torch.cuda.memory_allocated(), initial_mem)
  410. # clean up everything
  411. del t
  412. del z
  413. workspace.ResetWorkspace()
  414. self.assertEqual(torch.cuda.memory_allocated(), initial_mem)
  415. @unittest.skipIf(not workspace.C.use_mkldnn, "No MKLDNN support.")
  416. class TestWorkspaceIDEEP(test_util.TestCase):
  417. def testFeedFetchBlobIDEEP(self):
  418. arr = np.random.randn(2, 3).astype(np.float32)
  419. workspace.FeedBlob("testblob_ideep", arr, core.DeviceOption(caffe2_pb2.IDEEP))
  420. fetched = workspace.FetchBlob("testblob_ideep")
  421. np.testing.assert_array_equal(arr, fetched)
  422. class TestImmedibate(test_util.TestCase):
  423. def testImmediateEnterExit(self):
  424. workspace.StartImmediate(i_know=True)
  425. self.assertTrue(workspace.IsImmediate())
  426. workspace.StopImmediate()
  427. self.assertFalse(workspace.IsImmediate())
  428. def testImmediateRunsCorrectly(self):
  429. workspace.StartImmediate(i_know=True)
  430. net = core.Net("test-net")
  431. net.ConstantFill([], "testblob", shape=[1, 2, 3, 4], value=1.0)
  432. self.assertEqual(workspace.ImmediateBlobs(), ["testblob"])
  433. content = workspace.FetchImmediate("testblob")
  434. # Also, the immediate mode should not invade the original namespace,
  435. # so we check if this is so.
  436. with self.assertRaises(RuntimeError):
  437. workspace.FetchBlob("testblob")
  438. np.testing.assert_array_equal(content, 1.0)
  439. content[:] = 2.0
  440. self.assertTrue(workspace.FeedImmediate("testblob", content))
  441. np.testing.assert_array_equal(workspace.FetchImmediate("testblob"), 2.0)
  442. workspace.StopImmediate()
  443. with self.assertRaises(RuntimeError):
  444. content = workspace.FetchImmediate("testblob")
  445. def testImmediateRootFolder(self):
  446. workspace.StartImmediate(i_know=True)
  447. # for testing we will look into the _immediate_root_folder variable
  448. # but in normal usage you should not access that.
  449. self.assertTrue(len(workspace._immediate_root_folder) > 0)
  450. root_folder = workspace._immediate_root_folder
  451. self.assertTrue(os.path.isdir(root_folder))
  452. workspace.StopImmediate()
  453. self.assertTrue(len(workspace._immediate_root_folder) == 0)
  454. # After termination, immediate mode should have the root folder
  455. # deleted.
  456. self.assertFalse(os.path.exists(root_folder))
  457. class TestCppEnforceAsException(test_util.TestCase):
  458. def testEnforce(self):
  459. op = core.CreateOperator("Relu", ["X"], ["Y"])
  460. with self.assertRaises(RuntimeError):
  461. workspace.RunOperatorOnce(op)
  462. class TestCWorkspace(htu.HypothesisTestCase):
  463. def test_net_execution(self):
  464. ws = workspace.C.Workspace()
  465. self.assertEqual(ws.nets, {})
  466. self.assertEqual(ws.blobs, {})
  467. net = core.Net("test-net")
  468. net.ConstantFill([], "testblob", shape=[1, 2, 3, 4], value=1.0)
  469. ws.create_net(net)
  470. # If we do not specify overwrite, this should raise an error.
  471. with self.assertRaises(RuntimeError):
  472. ws.create_net(net)
  473. # But, if we specify overwrite, this should pass.
  474. ws.create_net(net, True)
  475. # Overwrite can also be a kwarg.
  476. ws.create_net(net, overwrite=True)
  477. self.assertIn("testblob", ws.blobs)
  478. self.assertEqual(len(ws.nets), 1)
  479. net_name = net.Proto().name
  480. self.assertIn("test-net", net_name)
  481. net = ws.nets[net_name].run()
  482. blob = ws.blobs["testblob"]
  483. np.testing.assert_array_equal(
  484. np.ones((1, 2, 3, 4), dtype=np.float32), blob.fetch()
  485. )
  486. @given(name=st.text(), value=st.floats(min_value=-1, max_value=1.0))
  487. def test_operator_run(self, name, value):
  488. ws = workspace.C.Workspace()
  489. op = core.CreateOperator("ConstantFill", [], [name], shape=[1], value=value)
  490. ws.run(op)
  491. self.assertIn(name, ws.blobs)
  492. np.testing.assert_allclose(
  493. [value], ws.blobs[name].fetch(), atol=1e-4, rtol=1e-4
  494. )
  495. @given(
  496. blob_name=st.text(),
  497. net_name=st.text(),
  498. value=st.floats(min_value=-1, max_value=1.0),
  499. )
  500. def test_net_run(self, blob_name, net_name, value):
  501. ws = workspace.C.Workspace()
  502. net = core.Net(net_name)
  503. net.ConstantFill([], [blob_name], shape=[1], value=value)
  504. ws.run(net)
  505. self.assertIn(blob_name, ws.blobs)
  506. self.assertNotIn(net_name, ws.nets)
  507. np.testing.assert_allclose(
  508. [value], ws.blobs[blob_name].fetch(), atol=1e-4, rtol=1e-4
  509. )
  510. @given(
  511. blob_name=st.text(),
  512. net_name=st.text(),
  513. plan_name=st.text(),
  514. value=st.floats(min_value=-1, max_value=1.0),
  515. )
  516. def test_plan_run(self, blob_name, plan_name, net_name, value):
  517. ws = workspace.C.Workspace()
  518. plan = core.Plan(plan_name)
  519. net = core.Net(net_name)
  520. net.ConstantFill([], [blob_name], shape=[1], value=value)
  521. plan.AddStep(core.ExecutionStep("step", nets=[net], num_iter=1))
  522. ws.run(plan)
  523. self.assertIn(blob_name, ws.blobs)
  524. self.assertIn(net.Name(), ws.nets)
  525. np.testing.assert_allclose(
  526. [value], ws.blobs[blob_name].fetch(), atol=1e-4, rtol=1e-4
  527. )
  528. @given(
  529. blob_name=st.text(),
  530. net_name=st.text(),
  531. value=st.floats(min_value=-1, max_value=1.0),
  532. )
  533. def test_net_create(self, blob_name, net_name, value):
  534. ws = workspace.C.Workspace()
  535. net = core.Net(net_name)
  536. net.ConstantFill([], [blob_name], shape=[1], value=value)
  537. ws.create_net(net).run()
  538. self.assertIn(blob_name, ws.blobs)
  539. self.assertIn(net.Name(), ws.nets)
  540. np.testing.assert_allclose(
  541. [value], ws.blobs[blob_name].fetch(), atol=1e-4, rtol=1e-4
  542. )
  543. @given(
  544. name=st.text(),
  545. value=htu.tensor(),
  546. device_option=st.sampled_from(htu.device_options),
  547. )
  548. def test_array_serde(self, name, value, device_option):
  549. ws = workspace.C.Workspace()
  550. ws.create_blob(name).feed(value, device_option=device_option)
  551. self.assertIn(name, ws.blobs)
  552. blob = ws.blobs[name]
  553. np.testing.assert_equal(value, ws.blobs[name].fetch())
  554. serde_blob = ws.create_blob("{}_serde".format(name))
  555. serde_blob.deserialize(blob.serialize(name))
  556. np.testing.assert_equal(value, serde_blob.fetch())
  557. @given(name=st.text(), value=st.text())
  558. def test_string_serde(self, name, value):
  559. value = value.encode("ascii", "ignore")
  560. ws = workspace.C.Workspace()
  561. ws.create_blob(name).feed(value)
  562. self.assertIn(name, ws.blobs)
  563. blob = ws.blobs[name]
  564. self.assertEqual(value, ws.blobs[name].fetch())
  565. serde_blob = ws.create_blob("{}_serde".format(name))
  566. serde_blob.deserialize(blob.serialize(name))
  567. self.assertEqual(value, serde_blob.fetch())
  568. def test_exception(self):
  569. ws = workspace.C.Workspace()
  570. with self.assertRaises(TypeError):
  571. ws.create_net("...")
  572. class TestPredictor(unittest.TestCase):
  573. def _create_model(self):
  574. m = model_helper.ModelHelper()
  575. y = brew.fc(
  576. m,
  577. "data",
  578. "y",
  579. dim_in=4,
  580. dim_out=2,
  581. weight_init=("ConstantFill", dict(value=1.0)),
  582. bias_init=("ConstantFill", dict(value=0.0)),
  583. axis=0,
  584. )
  585. m.net.AddExternalOutput(y)
  586. return m
  587. # Use this test with a bigger model to see how using Predictor allows to
  588. # avoid issues with low protobuf size limit in Python
  589. #
  590. # def test_predictor_predefined(self):
  591. # workspace.ResetWorkspace()
  592. # path = 'caffe2/caffe2/test/assets/'
  593. # with open(path + 'squeeze_predict_net.pb') as f:
  594. # self.predict_net = f.read()
  595. # with open(path + 'squeeze_init_net.pb') as f:
  596. # self.init_net = f.read()
  597. # self.predictor = workspace.Predictor(self.init_net, self.predict_net)
  598. # inputs = [np.zeros((1, 3, 256, 256), dtype='f')]
  599. # outputs = self.predictor.run(inputs)
  600. # self.assertEqual(len(outputs), 1)
  601. # self.assertEqual(outputs[0].shape, (1, 1000, 1, 1))
  602. # self.assertAlmostEqual(outputs[0][0][0][0][0], 5.19026289e-05)
  603. def test_predictor_memory_model(self):
  604. workspace.ResetWorkspace()
  605. m = self._create_model()
  606. workspace.FeedBlob("data", np.zeros([4], dtype="float32"))
  607. self.predictor = workspace.Predictor(
  608. workspace.StringifyProto(m.param_init_net.Proto()),
  609. workspace.StringifyProto(m.net.Proto()),
  610. )
  611. inputs = np.array([1, 3, 256, 256], dtype="float32")
  612. outputs = self.predictor.run([inputs])
  613. np.testing.assert_array_almost_equal(
  614. np.array([[516, 516]], dtype="float32"), outputs
  615. )
  616. class TestTransform(htu.HypothesisTestCase):
  617. @given(
  618. input_dim=st.integers(min_value=1, max_value=10),
  619. output_dim=st.integers(min_value=1, max_value=10),
  620. batch_size=st.integers(min_value=1, max_value=10),
  621. )
  622. def test_simple_transform(self, input_dim, output_dim, batch_size):
  623. m = model_helper.ModelHelper()
  624. fc1 = brew.fc(m, "data", "fc1", dim_in=input_dim, dim_out=output_dim)
  625. fc2 = brew.fc(m, fc1, "fc2", dim_in=output_dim, dim_out=output_dim)
  626. conv = brew.conv(
  627. m,
  628. fc2,
  629. "conv",
  630. dim_in=output_dim,
  631. dim_out=output_dim,
  632. use_cudnn=True,
  633. engine="CUDNN",
  634. kernel=3,
  635. )
  636. conv.Relu([], conv).Softmax([], "pred").LabelCrossEntropy(
  637. ["label"], ["xent"]
  638. ).AveragedLoss([], "loss")
  639. transformed_net_proto = workspace.ApplyTransform("ConvToNNPack", m.net.Proto())
  640. self.assertEqual(transformed_net_proto.op[2].engine, "NNPACK")
  641. @given(
  642. input_dim=st.integers(min_value=1, max_value=10),
  643. output_dim=st.integers(min_value=1, max_value=10),
  644. batch_size=st.integers(min_value=1, max_value=10),
  645. )
  646. @settings(deadline=10000)
  647. def test_registry_invalid(self, input_dim, output_dim, batch_size):
  648. m = model_helper.ModelHelper()
  649. brew.fc(m, "data", "fc1", dim_in=input_dim, dim_out=output_dim)
  650. with self.assertRaises(RuntimeError):
  651. workspace.ApplyTransform("definitely_not_a_real_transform", m.net.Proto())
  652. @given(value=st.floats(min_value=-1, max_value=1))
  653. @settings(deadline=10000)
  654. def test_apply_transform_if_faster(self, value):
  655. init_net = core.Net("init_net")
  656. init_net.ConstantFill([], ["data"], shape=[5, 5, 5, 5], value=value)
  657. init_net.ConstantFill([], ["conv_w"], shape=[5, 5, 3, 3], value=value)
  658. init_net.ConstantFill([], ["conv_b"], shape=[5], value=value)
  659. self.assertEqual(
  660. workspace.RunNetOnce(init_net.Proto().SerializeToString()), True
  661. )
  662. m = model_helper.ModelHelper()
  663. conv = brew.conv(
  664. m,
  665. "data",
  666. "conv",
  667. dim_in=5,
  668. dim_out=5,
  669. kernel=3,
  670. use_cudnn=True,
  671. engine="CUDNN",
  672. )
  673. conv.Relu([], conv).Softmax([], "pred").AveragedLoss([], "loss")
  674. self.assertEqual(workspace.RunNetOnce(m.net.Proto().SerializeToString()), True)
  675. proto = workspace.ApplyTransformIfFaster(
  676. "ConvToNNPack", m.net.Proto(), init_net.Proto()
  677. )
  678. self.assertEqual(workspace.RunNetOnce(proto.SerializeToString()), True)
  679. proto = workspace.ApplyTransformIfFaster(
  680. "ConvToNNPack",
  681. m.net.Proto(),
  682. init_net.Proto(),
  683. warmup_runs=10,
  684. main_runs=100,
  685. improvement_threshold=2.0,
  686. )
  687. self.assertEqual(workspace.RunNetOnce(proto.SerializeToString()), True)
class MyModule(torch.jit.ScriptModule):
    """Small TorchScript module used as the payload for the ScriptModule op tests.

    It has several scripted methods with different input/output arities. The
    tests select among them through the operator's ``method`` argument.
    """

    def __init__(self):
        super(MyModule, self).__init__()
        # 1x5 parameter; forward() left-multiplies its input by it.
        self.mult = torch.nn.Parameter(torch.tensor([[1, 2, 3, 4, 5.0]]))

    # Scripted methods are documented with `#` comments; their bodies are
    # compiled by TorchScript.
    @torch.jit.script_method
    def forward(self, x):
        # Returns self.mult @ x (the tests pass a 5x5 float32 matrix).
        return self.mult.mm(x)

    @torch.jit.script_method
    def multi_input(self, x: torch.Tensor, y: torch.Tensor, z: int = 2) -> torch.Tensor:
        # Element-wise x + y + z. The tests rely on the default z == 2.
        return x + y + z

    @torch.jit.script_method
    def multi_input_tensor_list(self, tensor_list: List[Tensor]) -> Tensor:
        # Sums the first three tensors of the list (the list is assumed to
        # hold at least three tensors).
        return tensor_list[0] + tensor_list[1] + tensor_list[2]

    @torch.jit.script_method
    def multi_output(self, x):
        # Returns two outputs: x unchanged and x + 1.
        return (x, x + 1)
  704. @unittest.skipIf(
  705. "ScriptModule" not in core._REGISTERED_OPERATORS,
  706. "Script module integration in Caffe2 is not enabled",
  707. )
  708. class TestScriptModule(test_util.TestCase):
  709. def _createFeedModule(self):
  710. workspace.FeedBlob("m", MyModule())
  711. def testCreation(self):
  712. m = MyModule()
  713. workspace.FeedBlob("module", m)
  714. m2 = workspace.FetchBlob("module")
  715. self.assertTrue(m2 is not None)
  716. def testForward(self):
  717. self._createFeedModule()
  718. val = np.random.rand(5, 5).astype(np.float32)
  719. param = np.array([[1, 2, 3, 4, 5]]).astype(np.float32)
  720. workspace.FeedBlob("w", val)
  721. workspace.RunOperatorOnce(
  722. core.CreateOperator("ScriptModule", ["m", "w"], ["y"])
  723. )
  724. np.testing.assert_almost_equal(
  725. workspace.FetchBlob("y"), np.matmul(param, val), decimal=5
  726. )
  727. def testMultiInputOutput(self):
  728. self._createFeedModule()
  729. val = np.random.rand(5, 5).astype(np.float32)
  730. workspace.FeedBlob("w", val)
  731. val2 = np.random.rand(5, 5).astype(np.float32)
  732. workspace.FeedBlob("w2", val2)
  733. workspace.RunOperatorOnce(
  734. core.CreateOperator(
  735. "ScriptModule", ["m", "w", "w2"], ["y"], method="multi_input"
  736. )
  737. )
  738. workspace.RunOperatorOnce(
  739. core.CreateOperator(
  740. "ScriptModule", ["m", "w"], ["y1", "y2"], method="multi_output"
  741. )
  742. )
  743. np.testing.assert_almost_equal(
  744. workspace.FetchBlob("y"), val + val2 + 2, decimal=5
  745. )
  746. np.testing.assert_almost_equal(workspace.FetchBlob("y1"), val, decimal=5)
  747. np.testing.assert_almost_equal(workspace.FetchBlob("y2"), val + 1, decimal=5)
  748. def testMultiTensorListInput(self):
  749. self._createFeedModule()
  750. val = np.random.rand(5, 5).astype(np.float32)
  751. workspace.FeedBlob("w", val)
  752. val2 = np.random.rand(5, 5).astype(np.float32)
  753. workspace.FeedBlob("w2", val2)
  754. val3 = np.random.rand(5, 5).astype(np.float32)
  755. workspace.FeedBlob("w3", val3)
  756. workspace.RunOperatorOnce(
  757. core.CreateOperator(
  758. "ScriptModule",
  759. ["m", "w", "w2", "w3"],
  760. ["y"],
  761. method="multi_input_tensor_list",
  762. pass_inputs_as_tensor_list=True,
  763. )
  764. )
  765. np.testing.assert_almost_equal(
  766. workspace.FetchBlob("y"), val + val2 + val3, decimal=5
  767. )
  768. def testSerialization(self):
  769. tmpdir = tempfile.mkdtemp()
  770. try:
  771. self._createFeedModule()
  772. workspace.RunOperatorOnce(
  773. core.CreateOperator(
  774. "Save",
  775. ["m"],
  776. [],
  777. absolute_path=1,
  778. db=os.path.join(tmpdir, "db"),
  779. db_type="minidb",
  780. )
  781. )
  782. workspace.ResetWorkspace()
  783. self.assertFalse(workspace.HasBlob("m"))
  784. workspace.RunOperatorOnce(
  785. core.CreateOperator(
  786. "Load",
  787. [],
  788. [],
  789. absolute_path=1,
  790. db=os.path.join(tmpdir, "db"),
  791. db_type="minidb",
  792. load_all=1,
  793. )
  794. )
  795. self.assertTrue(workspace.HasBlob("m"))
  796. # TODO: make caffe2 side load return python-sided module
  797. # right now it returns the base class (torch._C.ScriptModule)
  798. # self.assertTrue(isinstance(workspace.FetchBlob('m'), torch.jit.ScriptModule))
  799. # do something with the module
  800. val = np.random.rand(5, 5).astype(np.float32)
  801. param = np.array([[1, 2, 3, 4, 5]]).astype(np.float32)
  802. workspace.FeedBlob("w", val)
  803. workspace.RunOperatorOnce(
  804. core.CreateOperator("ScriptModule", ["m", "w"], ["y"])
  805. )
  806. np.testing.assert_almost_equal(
  807. workspace.FetchBlob("y"), np.matmul(param, val), decimal=5
  808. )
  809. finally:
  810. # clean up temp folder.
  811. try:
  812. shutil.rmtree(tmpdir)
  813. except OSError as e:
  814. if e.errno != errno.ENOENT:
  815. raise
  816. class TestScriptModuleFromString(TestScriptModule):
  817. def _createFeedModule(self):
  818. workspace.RunOperatorOnce(
  819. core.CreateOperator(
  820. "ScriptModuleLoad",
  821. [],
  822. ["m"],
  823. serialized_binary=self._get_modules_bytes(MyModule()),
  824. )
  825. )
  826. def _get_modules_bytes(self, the_module):
  827. import io
  828. buffer = io.BytesIO()
  829. torch.jit.save(the_module, buffer)
  830. return buffer.getvalue()
if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()