# scope_test.py (original listing: 5.1 KB)

  1. from caffe2.python import scope, core, workspace
  2. import unittest
  3. import threading
  4. import time
  5. SUCCESS_COUNT = 0
  6. def thread_runner(idx, testobj):
  7. global SUCCESS_COUNT
  8. testobj.assertEquals(scope.CurrentNameScope(), "")
  9. testobj.assertEquals(scope.CurrentDeviceScope(), None)
  10. namescope = "namescope_{}".format(idx)
  11. dsc = core.DeviceOption(workspace.GpuDeviceType, idx)
  12. with scope.DeviceScope(dsc):
  13. with scope.NameScope(namescope):
  14. testobj.assertEquals(scope.CurrentNameScope(), namescope + "/")
  15. testobj.assertEquals(scope.CurrentDeviceScope(), dsc)
  16. time.sleep(0.01 + idx * 0.01)
  17. testobj.assertEquals(scope.CurrentNameScope(), namescope + "/")
  18. testobj.assertEquals(scope.CurrentDeviceScope(), dsc)
  19. testobj.assertEquals(scope.CurrentNameScope(), "")
  20. testobj.assertEquals(scope.CurrentDeviceScope(), None)
  21. SUCCESS_COUNT += 1
  22. class TestScope(unittest.TestCase):
  23. def testNamescopeBasic(self):
  24. self.assertEquals(scope.CurrentNameScope(), "")
  25. with scope.NameScope("test_scope"):
  26. self.assertEquals(scope.CurrentNameScope(), "test_scope/")
  27. self.assertEquals(scope.CurrentNameScope(), "")
  28. def testNamescopeAssertion(self):
  29. self.assertEquals(scope.CurrentNameScope(), "")
  30. try:
  31. with scope.NameScope("test_scope"):
  32. self.assertEquals(scope.CurrentNameScope(), "test_scope/")
  33. raise Exception()
  34. except Exception:
  35. pass
  36. self.assertEquals(scope.CurrentNameScope(), "")
  37. def testEmptyNamescopeBasic(self):
  38. self.assertEquals(scope.CurrentNameScope(), "")
  39. with scope.NameScope("test_scope"):
  40. with scope.EmptyNameScope():
  41. self.assertEquals(scope.CurrentNameScope(), "")
  42. self.assertEquals(scope.CurrentNameScope(), "test_scope/")
  43. def testDevicescopeBasic(self):
  44. self.assertEquals(scope.CurrentDeviceScope(), None)
  45. dsc = core.DeviceOption(workspace.GpuDeviceType, 9)
  46. with scope.DeviceScope(dsc):
  47. self.assertEquals(scope.CurrentDeviceScope(), dsc)
  48. self.assertEquals(scope.CurrentDeviceScope(), None)
  49. def testEmptyDevicescopeBasic(self):
  50. self.assertEquals(scope.CurrentDeviceScope(), None)
  51. dsc = core.DeviceOption(workspace.GpuDeviceType, 9)
  52. with scope.DeviceScope(dsc):
  53. self.assertEquals(scope.CurrentDeviceScope(), dsc)
  54. with scope.EmptyDeviceScope():
  55. self.assertEquals(scope.CurrentDeviceScope(), None)
  56. self.assertEquals(scope.CurrentDeviceScope(), dsc)
  57. self.assertEquals(scope.CurrentDeviceScope(), None)
  58. def testDevicescopeAssertion(self):
  59. self.assertEquals(scope.CurrentDeviceScope(), None)
  60. dsc = core.DeviceOption(workspace.GpuDeviceType, 9)
  61. try:
  62. with scope.DeviceScope(dsc):
  63. self.assertEquals(scope.CurrentDeviceScope(), dsc)
  64. raise Exception()
  65. except Exception:
  66. pass
  67. self.assertEquals(scope.CurrentDeviceScope(), None)
  68. def testTags(self):
  69. self.assertEquals(scope.CurrentDeviceScope(), None)
  70. extra_info1 = ["key1:value1"]
  71. extra_info2 = ["key2:value2"]
  72. extra_info3 = ["key3:value3"]
  73. extra_info_1_2 = ["key1:value1", "key2:value2"]
  74. extra_info_1_2_3 = ["key1:value1", "key2:value2", "key3:value3"]
  75. with scope.DeviceScope(core.DeviceOption(0, extra_info=extra_info1)):
  76. self.assertEquals(scope.CurrentDeviceScope().extra_info, extra_info1)
  77. with scope.DeviceScope(core.DeviceOption(0, extra_info=extra_info2)):
  78. self.assertEquals(scope.CurrentDeviceScope().extra_info, extra_info_1_2)
  79. with scope.DeviceScope(core.DeviceOption(0, extra_info=extra_info3)):
  80. self.assertEquals(
  81. scope.CurrentDeviceScope().extra_info, extra_info_1_2_3
  82. )
  83. self.assertEquals(scope.CurrentDeviceScope().extra_info, extra_info_1_2)
  84. self.assertEquals(scope.CurrentDeviceScope().extra_info, extra_info1)
  85. self.assertEquals(scope.CurrentDeviceScope(), None)
  86. def testMultiThreaded(self):
  87. """
  88. Test that name/device scope are properly local to the thread
  89. and don't interfere
  90. """
  91. global SUCCESS_COUNT
  92. self.assertEquals(scope.CurrentNameScope(), "")
  93. self.assertEquals(scope.CurrentDeviceScope(), None)
  94. threads = []
  95. for i in range(4):
  96. threads.append(threading.Thread(
  97. target=thread_runner,
  98. args=(i, self),
  99. ))
  100. for t in threads:
  101. t.start()
  102. with scope.NameScope("master"):
  103. self.assertEquals(scope.CurrentDeviceScope(), None)
  104. self.assertEquals(scope.CurrentNameScope(), "master/")
  105. for t in threads:
  106. t.join()
  107. self.assertEquals(scope.CurrentNameScope(), "master/")
  108. self.assertEquals(scope.CurrentDeviceScope(), None)
  109. # Ensure all threads succeeded
  110. self.assertEquals(SUCCESS_COUNT, 4)