context_test.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. from caffe2.python import context, test_util
  2. from threading import Thread
  3. class MyContext(context.Managed):
  4. pass
  5. class DefaultMyContext(context.DefaultManaged):
  6. pass
  7. class ChildMyContext(MyContext):
  8. pass
  9. class TestContext(test_util.TestCase):
  10. def use_my_context(self):
  11. try:
  12. for _ in range(100):
  13. with MyContext() as a:
  14. for _ in range(100):
  15. self.assertTrue(MyContext.current() == a)
  16. except Exception as e:
  17. self._exceptions.append(e)
  18. def testMultiThreaded(self):
  19. threads = []
  20. self._exceptions = []
  21. for _ in range(8):
  22. thread = Thread(target=self.use_my_context)
  23. thread.start()
  24. threads.append(thread)
  25. for t in threads:
  26. t.join()
  27. for e in self._exceptions:
  28. raise e
  29. @MyContext()
  30. def testDecorator(self):
  31. self.assertIsNotNone(MyContext.current())
  32. def testNonDefaultCurrent(self):
  33. with self.assertRaises(AssertionError):
  34. MyContext.current()
  35. ctx = MyContext()
  36. self.assertEqual(MyContext.current(value=ctx), ctx)
  37. self.assertIsNone(MyContext.current(required=False))
  38. def testDefaultCurrent(self):
  39. self.assertIsInstance(DefaultMyContext.current(), DefaultMyContext)
  40. def testNestedContexts(self):
  41. with MyContext() as ctx1:
  42. with DefaultMyContext() as ctx2:
  43. self.assertEqual(DefaultMyContext.current(), ctx2)
  44. self.assertEqual(MyContext.current(), ctx1)
  45. def testChildClasses(self):
  46. with ChildMyContext() as ctx:
  47. self.assertEqual(ChildMyContext.current(), ctx)
  48. self.assertEqual(MyContext.current(), ctx)