resnet_test.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import numpy as np
  2. import caffe2.python.models.resnet as resnet
  3. import hypothesis.strategies as st
  4. from hypothesis import given, settings
  5. import caffe2.python.hypothesis_test_util as hu
  6. import caffe2.python.models.imagenet_trainer_test_utils as utils
  7. class ResnetMemongerTest(hu.HypothesisTestCase):
  8. @given(with_shapes=st.booleans(), **hu.gcs_cpu_only)
  9. @settings(max_examples=2, deadline=None)
  10. def test_resnet_shared_grads(self, with_shapes, gc, dc):
  11. results = utils.test_shared_grads(
  12. with_shapes,
  13. resnet.create_resnet50,
  14. 'gpu_0/conv1_w',
  15. 'gpu_0/last_out_L1000'
  16. )
  17. self.assertTrue(results[0][0] < results[0][1])
  18. np.testing.assert_almost_equal(results[1][0], results[1][1])
  19. np.testing.assert_almost_equal(results[2][0], results[2][1])
  20. def test_resnet_forward_only(self):
  21. results = utils.test_forward_only(
  22. resnet.create_resnet50,
  23. 'gpu_0/last_out_L1000'
  24. )
  25. self.assertTrue(results[0][0] < results[0][1])
  26. self.assertTrue(results[1] < 7 and results[1] > 0)
  27. np.testing.assert_almost_equal(results[2][0], results[2][1])
  28. def test_resnet_forward_only_fast_simplenet(self):
  29. '''
  30. Test C++ memonger that is only for simple nets
  31. '''
  32. results = utils.test_forward_only_fast_simplenet(
  33. resnet.create_resnet50,
  34. 'gpu_0/last_out_L1000'
  35. )
  36. self.assertTrue(results[0][0] < results[0][1])
  37. self.assertTrue(results[1] < 4 and results[1] > 0)
  38. np.testing.assert_almost_equal(results[2][0], results[2][1])
  39. if __name__ == "__main__":
  40. import unittest
  41. import random
  42. random.seed(2603)
  43. # pyre-fixme[10]: Name `workspace` is used but not defined in the current scope
  44. workspace.GlobalInit([
  45. 'caffe2',
  46. '--caffe2_log_level=0',
  47. '--caffe2_print_blob_sizes_at_exit=0',
  48. '--caffe2_gpu_memory_tracking=1'])
  49. unittest.main()