| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- ## @package test_util
- # Module caffe2.python.test_util
- import numpy as np
- from caffe2.python import core, workspace
- import os
- import pathlib
- import shutil
- import tempfile
- import unittest
- from typing import Any, Callable, Tuple, Type
- from types import TracebackType
def rand_array(*dims):
    """Return a float32 array of shape ``dims`` with values in [-0.5, 0.5).

    With no dimensions, ``np.random.rand()`` yields a plain Python float
    rather than a 0-dim array, so the shifted sample is wrapped in
    ``np.array`` before the cast; the result is then always an ndarray.
    """
    samples = np.random.rand(*dims)
    centered = np.array(samples - 0.5)
    return centered.astype(np.float32)
def randBlob(name, type, *dims, **kwargs):
    """Feed a random blob called ``name`` via ``workspace.FeedBlob``.

    The data is ``np.random.rand(*dims)`` cast to ``type`` and then shifted
    by the optional ``offset`` keyword argument (default ``0.0``). Any other
    keyword arguments are ignored.
    """
    shift = kwargs.get('offset', 0.0)
    values = np.random.rand(*dims).astype(type)
    workspace.FeedBlob(name, values + shift)
def randBlobFloat32(name, *dims, **kwargs):
    """Feed a random ``np.float32`` blob called ``name``.

    Thin wrapper over ``randBlob`` with the element type fixed; ``dims`` and
    keyword arguments (e.g. ``offset``) are forwarded unchanged.
    """
    element_type = np.float32
    randBlob(name, element_type, *dims, **kwargs)
def randBlobsFloat32(names, *dims, **kwargs):
    """Call ``randBlobFloat32`` once for every entry of ``names``.

    Each call receives the same ``dims`` and keyword arguments.
    """
    for blob_name in names:
        randBlobFloat32(blob_name, *dims, **kwargs)
def numOps(net):
    """Return the length of the ``op`` field of ``net.Proto()``."""
    proto = net.Proto()
    return len(proto.op)
def str_compare(a, b, encoding="utf8"):
    """Compare ``a`` and ``b`` for equality, decoding ``bytes`` operands first.

    Either argument that is a ``bytes`` instance is decoded with ``encoding``
    before the comparison; anything else is compared as-is.
    """
    left = a.decode(encoding) if isinstance(a, bytes) else a
    right = b.decode(encoding) if isinstance(b, bytes) else b
    return left == right
def get_default_test_flags():
    """Return a new list holding the default flags used for tests.

    A fresh list is built on every call, so callers may modify the result.
    """
    flags = ['caffe2']
    flags.extend((
        '--caffe2_log_level=0',
        '--caffe2_cpu_allocator_do_zero_fill=0',
        '--caffe2_cpu_allocator_do_junk_fill=1',
    ))
    return flags
def caffe2_flaky(test_method):
    """Decorator that marks a test method as flaky.

    It works together with the CAFFE2_RUN_FLAKY_TESTS environment variable,
    which selects "flaky tests" mode:
      * when flaky tests mode is on, only flaky tests are run;
      * when flaky tests mode is off, only non-flaky tests are run.

    NOTE: apply this as the top-level (outermost) decorator of a test method.
    """
    # The marker is stored on the function object itself and the very same
    # object is handed back.
    setattr(test_method, "__caffe2_flaky__", True)
    return test_method
def is_flaky_test_mode():
    """Return True only when CAFFE2_RUN_FLAKY_TESTS is set to exactly '1'.

    An unset variable is treated as '0'.
    """
    setting = os.environ.get('CAFFE2_RUN_FLAKY_TESTS', '0')
    return setting == '1'
class TestCase(unittest.TestCase):
    """Base ``unittest.TestCase`` for caffe2 tests.

    It runs ``workspace.GlobalInit`` once per class, resets the workspace
    around every test, skips tests according to flaky-test mode, and offers
    ``make_tempdir`` for temporary directories that are removed on cleanup.
    """

    @classmethod
    def setUpClass(cls):
        default_flags = get_default_test_flags()
        workspace.GlobalInit(default_flags)
        # clear the default engines settings to separate out its
        # affect from the ops tests
        core.SetEnginePref({}, {})

    def setUp(self):
        # A test only runs when its flaky marker agrees with the current
        # mode: flaky tests in flaky test mode, non-flaky tests otherwise.
        flaky_mode = is_flaky_test_mode()
        current_test = getattr(self, self._testMethodName)
        marked_flaky = getattr(current_test, "__caffe2_flaky__", False)

        if flaky_mode and not marked_flaky:
            raise unittest.SkipTest("Non-flaky tests are skipped in flaky test mode")
        if marked_flaky and not flaky_mode:
            raise unittest.SkipTest("Flaky tests are skipped in regular test mode")

        self.ws = workspace.C.Workspace()
        workspace.ResetWorkspace()

    def tearDown(self):
        workspace.ResetWorkspace()

    def make_tempdir(self) -> pathlib.Path:
        """Create a temporary directory and register its removal as a cleanup."""
        created = tempfile.mkdtemp(prefix="caffe2_test.")
        tmp_folder = pathlib.Path(created)
        self.addCleanup(self._remove_tempdir, tmp_folder)
        return tmp_folder

    def _remove_tempdir(self, path: pathlib.Path) -> None:
        """Recursively delete ``path``, tolerating entries that are already gone."""

        def _reraise_unless_missing(
            fn: Callable[..., Any],
            failed_path: str,
            exc_info: Tuple[Type[BaseException], BaseException, TracebackType],
        ) -> None:
            # Ignore FileNotFoundError, but re-raise anything else
            _, error, traceback = exc_info
            if isinstance(error, FileNotFoundError):
                return
            raise error.with_traceback(traceback)

        shutil.rmtree(str(path), onerror=_reraise_unless_missing)
|