test_util.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. ## @package test_util
  2. # Module caffe2.python.test_util
  3. import numpy as np
  4. from caffe2.python import core, workspace
  5. import os
  6. import pathlib
  7. import shutil
  8. import tempfile
  9. import unittest
  10. from typing import Any, Callable, Tuple, Type
  11. from types import TracebackType
  12. def rand_array(*dims):
  13. # np.random.rand() returns float instead of 0-dim array, that's why need to
  14. # do some tricks
  15. return np.array(np.random.rand(*dims) - 0.5).astype(np.float32)
  16. def randBlob(name, type, *dims, **kwargs):
  17. offset = kwargs['offset'] if 'offset' in kwargs else 0.0
  18. workspace.FeedBlob(name, np.random.rand(*dims).astype(type) + offset)
  19. def randBlobFloat32(name, *dims, **kwargs):
  20. randBlob(name, np.float32, *dims, **kwargs)
  21. def randBlobsFloat32(names, *dims, **kwargs):
  22. for name in names:
  23. randBlobFloat32(name, *dims, **kwargs)
  24. def numOps(net):
  25. return len(net.Proto().op)
  26. def str_compare(a, b, encoding="utf8"):
  27. if isinstance(a, bytes):
  28. a = a.decode(encoding)
  29. if isinstance(b, bytes):
  30. b = b.decode(encoding)
  31. return a == b
  32. def get_default_test_flags():
  33. return [
  34. 'caffe2',
  35. '--caffe2_log_level=0',
  36. '--caffe2_cpu_allocator_do_zero_fill=0',
  37. '--caffe2_cpu_allocator_do_junk_fill=1',
  38. ]
  39. def caffe2_flaky(test_method):
  40. # This decorator is used to mark a test method as flaky.
  41. # This is used in conjunction with the environment variable
  42. # CAFFE2_RUN_FLAKY_TESTS that specifies "flaky tests" mode
  43. # If flaky tests mode are on, only flaky tests are run
  44. # If flaky tests mode are off, only non-flaky tests are run
  45. # NOTE: the decorator should be applied as the top-level decorator
  46. # in a test method.
  47. test_method.__caffe2_flaky__ = True
  48. return test_method
  49. def is_flaky_test_mode():
  50. return os.getenv('CAFFE2_RUN_FLAKY_TESTS', '0') == '1'
  51. class TestCase(unittest.TestCase):
  52. @classmethod
  53. def setUpClass(cls):
  54. workspace.GlobalInit(get_default_test_flags())
  55. # clear the default engines settings to separate out its
  56. # affect from the ops tests
  57. core.SetEnginePref({}, {})
  58. def setUp(self):
  59. # Skip tests based on whether we're in flaky test mode and
  60. # the test is decorated as a flaky test.
  61. test_method = getattr(self, self._testMethodName)
  62. is_flaky_test = getattr(test_method, "__caffe2_flaky__", False)
  63. if (is_flaky_test_mode() and not is_flaky_test):
  64. raise unittest.SkipTest("Non-flaky tests are skipped in flaky test mode")
  65. elif (not is_flaky_test_mode() and is_flaky_test):
  66. raise unittest.SkipTest("Flaky tests are skipped in regular test mode")
  67. self.ws = workspace.C.Workspace()
  68. workspace.ResetWorkspace()
  69. def tearDown(self):
  70. workspace.ResetWorkspace()
  71. def make_tempdir(self) -> pathlib.Path:
  72. tmp_folder = pathlib.Path(tempfile.mkdtemp(prefix="caffe2_test."))
  73. self.addCleanup(self._remove_tempdir, tmp_folder)
  74. return tmp_folder
  75. def _remove_tempdir(self, path: pathlib.Path) -> None:
  76. def _onerror(
  77. fn: Callable[..., Any],
  78. path: str,
  79. exc_info: Tuple[Type[BaseException], BaseException, TracebackType],
  80. ) -> None:
  81. # Ignore FileNotFoundError, but re-raise anything else
  82. if not isinstance(exc_info[1], FileNotFoundError):
  83. raise exc_info[1].with_traceback(exc_info[2])
  84. shutil.rmtree(str(path), onerror=_onerror)