hypothesis_test_util.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751
## @package hypothesis_test_util
# Module caffe2.python.hypothesis_test_util
"""
The Hypothesis library uses *property-based testing* to check
invariants about the code under test under a variety of random inputs.

The key idea here is to express properties of the code under test
(e.g. that it passes a gradient check, that it implements a reference
function, etc.), and then generate random instances and verify they
satisfy these properties.

The main functions of interest are exposed on `HypothesisTestCase`.
You can usually just add a short function in this to generate an
arbitrary number of test cases for your operator.

The key functions are:

- `assertDeviceChecks(devices, op, inputs, outputs)`. This asserts that the
  operator computes the same outputs, regardless of which device it is
  executed on.
- `assertGradientChecks(device, op, inputs, output_,
  outputs_with_grads)`. This implements a standard numerical gradient checker
  for the operator in question.
- `assertReferenceChecks(device, op, inputs, reference)`. This runs the
  reference function (effectively calling `reference(*inputs)`) and compares
  its result to the output of the operator.

`hypothesis_test_util.py` exposes some useful pre-built samplers.

- `hu.gcs` - a gradient checker device (`gc`) and device checker devices (`dc`)
- `hu.gcs_cpu_only` - a CPU-only gradient checker device (`gc`) and
  device checker devices (`dc`). Used for when your operator is only
  implemented on the CPU.
"""
  29. from caffe2.proto import caffe2_pb2
  30. from caffe2.python import (
  31. workspace, device_checker, gradient_checker, test_util, core)
  32. import contextlib
  33. import copy
  34. import functools
  35. import hypothesis
  36. import hypothesis.extra.numpy
  37. import hypothesis.strategies as st
  38. import logging
  39. import numpy as np
  40. import os
  41. import struct
  42. def is_sandcastle():
  43. return os.getenv('SANDCASTLE') == '1' or os.getenv('TW_JOB_USER') == 'sandcastle'
  44. def is_travis():
  45. return 'TRAVIS' in os.environ
  46. def to_float32(x):
  47. return struct.unpack("f", struct.pack("f", float(x)))[0]
  48. # "min_satisfying_examples" setting has been deprecated in hypothesis
  49. # 3.56.0 and removed in hypothesis 4.x
  50. def settings(*args, **kwargs):
  51. if 'min_satisfying_examples' in kwargs and hypothesis.version.__version_info__ >= (3, 56, 0):
  52. kwargs.pop('min_satisfying_examples')
  53. if 'deadline' in kwargs and hypothesis.version.__version_info__ < (4, 44, 0):
  54. kwargs.pop('deadline')
  55. if 'timeout' in kwargs and hypothesis.version.__version_info__ >= (4, 44, 0):
  56. if 'deadline' not in kwargs:
  57. kwargs['deadline'] = kwargs['timeout'] * 1e3
  58. kwargs.pop('timeout')
  59. return hypothesis.settings(*args, **kwargs)
  60. # This wrapper wraps around `st.floats` and
  61. # sets width parameters to 32 if version is newer than 3.67.0
  62. def floats(*args, **kwargs):
  63. width_supported = hypothesis.version.__version_info__ >= (3, 67, 0)
  64. if 'width' in kwargs and not width_supported:
  65. kwargs.pop('width')
  66. if 'width' not in kwargs and width_supported:
  67. kwargs['width'] = 32
  68. if kwargs.get('min_value', None) is not None:
  69. kwargs['min_value'] = to_float32(kwargs['min_value'])
  70. if kwargs.get('max_value', None) is not None:
  71. kwargs['max_value'] = to_float32(kwargs['max_value'])
  72. return st.floats(*args, **kwargs)
  73. hypothesis.settings.register_profile(
  74. "sandcastle",
  75. settings(
  76. derandomize=True,
  77. suppress_health_check=[hypothesis.HealthCheck.too_slow],
  78. database=None,
  79. max_examples=50,
  80. min_satisfying_examples=1,
  81. verbosity=hypothesis.Verbosity.verbose,
  82. deadline=10000))
  83. hypothesis.settings.register_profile(
  84. "dev",
  85. settings(
  86. suppress_health_check=[hypothesis.HealthCheck.too_slow],
  87. database=None,
  88. max_examples=10,
  89. min_satisfying_examples=1,
  90. verbosity=hypothesis.Verbosity.verbose,
  91. deadline=10000))
  92. hypothesis.settings.register_profile(
  93. "debug",
  94. settings(
  95. suppress_health_check=[hypothesis.HealthCheck.too_slow],
  96. database=None,
  97. max_examples=1000,
  98. min_satisfying_examples=1,
  99. verbosity=hypothesis.Verbosity.verbose,
  100. deadline=50000))
  101. hypothesis.settings.load_profile(
  102. 'sandcastle' if is_sandcastle() else os.getenv('CAFFE2_HYPOTHESIS_PROFILE',
  103. 'dev')
  104. )
  105. def dims(min_value=1, max_value=5):
  106. return st.integers(min_value=min_value, max_value=max_value)
  107. def elements_of_type(dtype=np.float32, filter_=None):
  108. elems = None
  109. if dtype is np.float16:
  110. elems = floats(min_value=-1.0, max_value=1.0, width=16)
  111. elif dtype is np.float32:
  112. elems = floats(min_value=-1.0, max_value=1.0, width=32)
  113. elif dtype is np.float64:
  114. elems = floats(min_value=-1.0, max_value=1.0, width=64)
  115. elif dtype is np.int32:
  116. elems = st.integers(min_value=0, max_value=2 ** 31 - 1)
  117. elif dtype is np.int64:
  118. elems = st.integers(min_value=0, max_value=2 ** 63 - 1)
  119. elif dtype is np.bool:
  120. elems = st.booleans()
  121. else:
  122. raise ValueError("Unexpected dtype without elements provided")
  123. return elems if filter_ is None else elems.filter(filter_)
  124. def arrays(dims, dtype=np.float32, elements=None, unique=False):
  125. if elements is None:
  126. elements = elements_of_type(dtype)
  127. return hypothesis.extra.numpy.arrays(
  128. dtype,
  129. dims,
  130. elements=elements,
  131. unique=unique,
  132. )
  133. def tensor(min_dim=1,
  134. max_dim=4,
  135. dtype=np.float32,
  136. elements=None,
  137. unique=False,
  138. **kwargs):
  139. dims_ = st.lists(dims(**kwargs), min_size=min_dim, max_size=max_dim)
  140. return dims_.flatmap(
  141. lambda dims: arrays(dims, dtype, elements, unique=unique))
  142. def tensor1d(min_len=1, max_len=64, dtype=np.float32, elements=None):
  143. return tensor(1, 1, dtype, elements, min_value=min_len, max_value=max_len)
  144. def segment_ids(size, is_sorted):
  145. if size == 0:
  146. return st.just(np.empty(shape=[0], dtype=np.int32))
  147. if is_sorted:
  148. return arrays(
  149. [size],
  150. dtype=np.int32,
  151. elements=st.booleans()).map(
  152. lambda x: np.cumsum(x, dtype=np.int32) - x[0])
  153. else:
  154. return arrays(
  155. [size],
  156. dtype=np.int32,
  157. elements=st.integers(min_value=0, max_value=2 * size))
  158. def lengths(size, min_segments=None, max_segments=None, **kwargs):
  159. # First generate number of boarders between segments
  160. # Then create boarder values and add 0 and size
  161. # By sorting and computing diff we convert them to lengths of
  162. # possible 0 value
  163. if min_segments is None:
  164. min_segments = 0
  165. if max_segments is None:
  166. max_segments = size
  167. assert min_segments >= 0
  168. assert min_segments <= max_segments
  169. if size == 0 and max_segments == 0:
  170. return st.just(np.empty(shape=[0], dtype=np.int32))
  171. assert max_segments > 0, "size is not 0, need at least one segment"
  172. return st.integers(
  173. min_value=max(min_segments - 1, 0), max_value=max_segments - 1
  174. ).flatmap(
  175. lambda num_borders:
  176. hypothesis.extra.numpy.arrays(
  177. np.int32, num_borders, elements=st.integers(
  178. min_value=0, max_value=size
  179. )
  180. )
  181. ).map(
  182. lambda x: np.append(x, np.array([0, size], dtype=np.int32))
  183. ).map(sorted).map(np.diff)
  184. def segmented_tensor(
  185. min_dim=1,
  186. max_dim=4,
  187. dtype=np.float32,
  188. is_sorted=True,
  189. elements=None,
  190. segment_generator=segment_ids,
  191. allow_empty=False,
  192. **kwargs
  193. ):
  194. gen_empty = st.booleans() if allow_empty else st.just(False)
  195. data_dims_ = st.lists(dims(**kwargs), min_size=min_dim, max_size=max_dim)
  196. data_dims_ = st.tuples(
  197. gen_empty, data_dims_
  198. ).map(lambda pair: ([0] if pair[0] else []) + pair[1])
  199. return data_dims_.flatmap(lambda data_dims: st.tuples(
  200. arrays(data_dims, dtype, elements),
  201. segment_generator(data_dims[0], is_sorted=is_sorted),
  202. ))
  203. def lengths_tensor(min_segments=None, max_segments=None, *args, **kwargs):
  204. gen = functools.partial(
  205. lengths, min_segments=min_segments, max_segments=max_segments)
  206. return segmented_tensor(*args, segment_generator=gen, **kwargs)
  207. def sparse_segmented_tensor(min_dim=1, max_dim=4, dtype=np.float32,
  208. is_sorted=True, elements=None, allow_empty=False,
  209. segment_generator=segment_ids, itype=np.int64,
  210. **kwargs):
  211. gen_empty = st.booleans() if allow_empty else st.just(False)
  212. data_dims_ = st.lists(dims(**kwargs), min_size=min_dim, max_size=max_dim)
  213. all_dims_ = st.tuples(gen_empty, data_dims_).flatmap(
  214. lambda pair: st.tuples(
  215. st.just(pair[1]),
  216. (st.integers(min_value=1, max_value=pair[1][0]) if not pair[0]
  217. else st.just(0)),
  218. ))
  219. return all_dims_.flatmap(lambda dims: st.tuples(
  220. arrays(dims[0], dtype, elements),
  221. arrays(dims[1], dtype=itype, elements=st.integers(
  222. min_value=0, max_value=dims[0][0] - 1)),
  223. segment_generator(dims[1], is_sorted=is_sorted),
  224. ))
  225. def sparse_lengths_tensor(**kwargs):
  226. return sparse_segmented_tensor(segment_generator=lengths, **kwargs)
  227. def tensors(n, min_dim=1, max_dim=4, dtype=np.float32, elements=None, **kwargs):
  228. dims_ = st.lists(dims(**kwargs), min_size=min_dim, max_size=max_dim)
  229. return dims_.flatmap(
  230. lambda dims: st.lists(
  231. arrays(dims, dtype, elements),
  232. min_size=n,
  233. max_size=n))
  234. def tensors1d(n, min_len=1, max_len=64, dtype=np.float32, elements=None):
  235. return tensors(
  236. n, 1, 1, dtype, elements, min_value=min_len, max_value=max_len
  237. )
  238. cpu_do = caffe2_pb2.DeviceOption()
  239. cuda_do = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CUDA)
  240. hip_do = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.HIP)
  241. gpu_do = caffe2_pb2.DeviceOption(device_type=workspace.GpuDeviceType) # CUDA or ROCm
  242. _cuda_do_list = ([cuda_do] if workspace.has_cuda_support else [])
  243. _hip_do_list = ([hip_do] if workspace.has_hip_support else [])
  244. _gpu_do_list = ([gpu_do] if workspace.has_gpu_support else [])
  245. # (bddppq) Do not rely on this no_hip option! It's just used to
  246. # temporarily skip some flaky tests on ROCM before it's getting more mature.
  247. _device_options_no_hip = [cpu_do] + _cuda_do_list
  248. device_options = _device_options_no_hip + _hip_do_list
  249. # Include device option for each GPU
  250. expanded_device_options = [cpu_do] + [
  251. caffe2_pb2.DeviceOption(device_type=workspace.GpuDeviceType, device_id=i)
  252. for i in range(workspace.NumGpuDevices())]
  253. def device_checker_device_options():
  254. return st.just(device_options)
  255. def gradient_checker_device_option():
  256. return st.sampled_from(device_options)
  257. gcs = dict(
  258. gc=gradient_checker_device_option(),
  259. dc=device_checker_device_options()
  260. )
  261. gcs_cpu_only = dict(gc=st.sampled_from([cpu_do]), dc=st.just([cpu_do]))
  262. gcs_cuda_only = dict(gc=st.sampled_from(_cuda_do_list), dc=st.just(_cuda_do_list))
  263. gcs_gpu_only = dict(gc=st.sampled_from(_gpu_do_list), dc=st.just(_gpu_do_list)) # CUDA or ROCm
  264. gcs_no_hip = dict(gc=st.sampled_from(_device_options_no_hip), dc=st.just(_device_options_no_hip))
  265. @contextlib.contextmanager
  266. def temp_workspace(name=b"temp_ws"):
  267. old_ws_name = workspace.CurrentWorkspace()
  268. workspace.SwitchWorkspace(name, True)
  269. yield
  270. workspace.ResetWorkspace()
  271. workspace.SwitchWorkspace(old_ws_name)
  272. def runOpBenchmark(
  273. device_option,
  274. op,
  275. inputs,
  276. input_device_options=None,
  277. iterations=10,
  278. ):
  279. op = copy.deepcopy(op)
  280. op.device_option.CopyFrom(device_option)
  281. net = caffe2_pb2.NetDef()
  282. net.op.extend([op])
  283. net.name = op.name if op.name else "test"
  284. with temp_workspace():
  285. _input_device_options = input_device_options or \
  286. core.InferOpBlobDevicesAsDict(op)[0]
  287. for (n, b) in zip(op.input, inputs):
  288. workspace.FeedBlob(
  289. n,
  290. b,
  291. device_option=_input_device_options.get(n, device_option)
  292. )
  293. workspace.CreateNet(net)
  294. ret = workspace.BenchmarkNet(net.name, 1, iterations, True)
  295. return ret
  296. def runOpOnInput(
  297. device_option,
  298. op,
  299. inputs,
  300. input_device_options=None,
  301. ):
  302. op = copy.deepcopy(op)
  303. op.device_option.CopyFrom(device_option)
  304. with temp_workspace():
  305. if (len(op.input) > len(inputs)):
  306. raise ValueError(
  307. 'must supply an input for each input on the op: %s vs %s' %
  308. (op.input, inputs))
  309. _input_device_options = input_device_options or \
  310. core.InferOpBlobDevicesAsDict(op)[0]
  311. for (n, b) in zip(op.input, inputs):
  312. workspace.FeedBlob(
  313. n,
  314. b,
  315. device_option=_input_device_options.get(n, device_option)
  316. )
  317. workspace.RunOperatorOnce(op)
  318. outputs_to_check = list(range(len(op.output)))
  319. outs = []
  320. for output_index in outputs_to_check:
  321. output_blob_name = op.output[output_index]
  322. output = workspace.FetchBlob(output_blob_name)
  323. outs.append(output)
  324. return outs
  325. class HypothesisTestCase(test_util.TestCase):
  326. """
  327. A unittest.TestCase subclass with some helper functions for
  328. utilizing the `hypothesis` (hypothesis.readthedocs.io) library.
  329. """
  330. def assertDeviceChecks(
  331. self,
  332. device_options,
  333. op,
  334. inputs,
  335. outputs_to_check,
  336. input_device_options=None,
  337. threshold=0.01
  338. ):
  339. """
  340. Asserts that the operator computes the same outputs, regardless of
  341. which device it is executed on.
  342. Useful for checking the consistency of GPU and CPU
  343. implementations of operators.
  344. Usage example:
  345. @given(inputs=hu.tensors(n=2), in_place=st.booleans(), **hu.gcs)
  346. def test_sum(self, inputs, in_place, gc, dc):
  347. op = core.CreateOperator("Sum", ["X1", "X2"],
  348. ["Y" if not in_place else "X1"])
  349. X1, X2 = inputs
  350. self.assertDeviceChecks(dc, op, [X1, X2], [0])
  351. """
  352. dc = device_checker.DeviceChecker(
  353. threshold,
  354. device_options=device_options
  355. )
  356. self.assertTrue(
  357. dc.CheckSimple(op, inputs, outputs_to_check, input_device_options)
  358. )
  359. def assertGradientChecks(
  360. self,
  361. device_option,
  362. op,
  363. inputs,
  364. outputs_to_check,
  365. outputs_with_grads,
  366. grad_ops=None,
  367. threshold=0.005,
  368. stepsize=0.05,
  369. input_device_options=None,
  370. ensure_outputs_are_inferred=False,
  371. ):
  372. """
  373. Implements a standard numerical gradient checker for the operator
  374. in question.
  375. Useful for checking the consistency of the forward and
  376. backward implementations of operators.
  377. Usage example:
  378. @given(inputs=hu.tensors(n=2), in_place=st.booleans(), **hu.gcs)
  379. def test_sum(self, inputs, in_place, gc, dc):
  380. op = core.CreateOperator("Sum", ["X1", "X2"],
  381. ["Y" if not in_place else "X1"])
  382. X1, X2 = inputs
  383. self.assertGradientChecks(gc, op, [X1, X2], 0, [0])
  384. """
  385. gc = gradient_checker.GradientChecker(
  386. stepsize=stepsize,
  387. threshold=threshold,
  388. device_option=device_option,
  389. workspace_name=str(device_option),
  390. input_device_options=input_device_options,
  391. )
  392. res, grad, grad_estimated = gc.CheckSimple(
  393. op, inputs, outputs_to_check, outputs_with_grads,
  394. grad_ops=grad_ops,
  395. input_device_options=input_device_options,
  396. ensure_outputs_are_inferred=ensure_outputs_are_inferred,
  397. )
  398. self.assertEqual(grad.shape, grad_estimated.shape)
  399. self.assertTrue(
  400. res,
  401. "Gradient check failed for input " + str(op.input[outputs_to_check])
  402. )
  403. def _assertGradReferenceChecks(
  404. self,
  405. op,
  406. inputs,
  407. ref_outputs,
  408. output_to_grad,
  409. grad_reference,
  410. threshold=1e-4,
  411. ):
  412. grad_blob_name = output_to_grad + '_grad'
  413. grad_ops, grad_map = core.GradientRegistry.GetBackwardPass(
  414. [op], {output_to_grad: grad_blob_name})
  415. output_grad = workspace.FetchBlob(output_to_grad)
  416. grad_ref_outputs = grad_reference(output_grad, ref_outputs, inputs)
  417. workspace.FeedBlob(grad_blob_name, workspace.FetchBlob(output_to_grad))
  418. workspace.RunOperatorsOnce(grad_ops)
  419. self.assertEqual(len(grad_ref_outputs), len(inputs))
  420. for (n, ref) in zip(op.input, grad_ref_outputs):
  421. grad_names = grad_map.get(n)
  422. if not grad_names:
  423. # no grad for this input
  424. self.assertIsNone(ref)
  425. else:
  426. if isinstance(grad_names, core.BlobReference):
  427. # dense gradient
  428. ref_vals = ref
  429. ref_indices = None
  430. val_name = grad_names
  431. else:
  432. # sparse gradient
  433. ref_vals, ref_indices = ref
  434. val_name = grad_names.values
  435. vals = workspace.FetchBlob(str(val_name))
  436. np.testing.assert_allclose(
  437. vals,
  438. ref_vals,
  439. atol=threshold,
  440. rtol=threshold,
  441. err_msg='Gradient {0} (x) is not matching the reference (y)'
  442. .format(val_name),
  443. )
  444. if ref_indices is not None:
  445. indices = workspace.FetchBlob(str(grad_names.indices))
  446. np.testing.assert_allclose(indices, ref_indices,
  447. atol=1e-4, rtol=1e-4)
  448. def _assertInferTensorChecks(self, name, shapes, types, output,
  449. ensure_output_is_inferred=False):
  450. self.assertTrue(
  451. not ensure_output_is_inferred or (name in shapes),
  452. 'Shape for {0} was not inferred'.format(name))
  453. if name not in shapes:
  454. # No inferred shape or type available
  455. return
  456. output = workspace.FetchBlob(name)
  457. if type(output) is np.ndarray:
  458. if output.dtype == np.dtype('float64'):
  459. correct_type = caffe2_pb2.TensorProto.DOUBLE
  460. elif output.dtype == np.dtype('float32'):
  461. correct_type = caffe2_pb2.TensorProto.FLOAT
  462. elif output.dtype == np.dtype('int32'):
  463. correct_type = caffe2_pb2.TensorProto.INT32
  464. elif output.dtype == np.dtype('int64'):
  465. correct_type = caffe2_pb2.TensorProto.INT64
  466. else:
  467. correct_type = "unknown {}".format(np.dtype)
  468. else:
  469. correct_type = str(type(output))
  470. try:
  471. np.testing.assert_array_equal(
  472. np.array(shapes[name]).astype(np.int32),
  473. np.array(output.shape).astype(np.int32),
  474. err_msg='Shape {} mismatch: {} vs. {}'.format(
  475. name,
  476. shapes[name],
  477. output.shape))
  478. # BUG: Workspace blob type not being set correctly T16121392
  479. if correct_type != caffe2_pb2.TensorProto.INT32:
  480. return
  481. np.testing.assert_equal(
  482. types[name],
  483. correct_type,
  484. err_msg='Type {} mismatch: {} vs. {}'.format(
  485. name, types[name], correct_type,
  486. )
  487. )
  488. except AssertionError as e:
  489. # Temporarily catch these assertion errors when validating
  490. # inferred shape and type info
  491. logging.warning(str(e))
  492. if os.getenv('CAFFE2_ASSERT_SHAPEINFERENCE') == '1' or ensure_output_is_inferred:
  493. raise e
  494. def assertReferenceChecks(
  495. self,
  496. device_option,
  497. op,
  498. inputs,
  499. reference,
  500. input_device_options=None,
  501. threshold=1e-4,
  502. output_to_grad=None,
  503. grad_reference=None,
  504. atol=None,
  505. outputs_to_check=None,
  506. ensure_outputs_are_inferred=False,
  507. ):
  508. """
  509. This runs the reference Python function implementation
  510. (effectively calling `reference(*inputs)`, and compares that
  511. to the output of output, with an absolute/relative tolerance
  512. given by the `threshold` parameter.
  513. Useful for checking the implementation matches the Python
  514. (typically NumPy) implementation of the same functionality.
  515. Usage example:
  516. @given(X=hu.tensor(), inplace=st.booleans(), **hu.gcs)
  517. def test_softsign(self, X, inplace, gc, dc):
  518. op = core.CreateOperator(
  519. "Softsign", ["X"], ["X" if inplace else "Y"])
  520. def softsign(X):
  521. return (X / (1 + np.abs(X)),)
  522. self.assertReferenceChecks(gc, op, [X], softsign)
  523. """
  524. op = copy.deepcopy(op)
  525. op.device_option.CopyFrom(device_option)
  526. with temp_workspace():
  527. if (len(op.input) > len(inputs)):
  528. raise ValueError(
  529. 'must supply an input for each input on the op: %s vs %s' %
  530. (op.input, inputs))
  531. _input_device_options = input_device_options or \
  532. core.InferOpBlobDevicesAsDict(op)[0]
  533. for (n, b) in zip(op.input, inputs):
  534. workspace.FeedBlob(
  535. n,
  536. b,
  537. device_option=_input_device_options.get(n, device_option)
  538. )
  539. net = core.Net("opnet")
  540. net.Proto().op.extend([op])
  541. test_shape_inference = False
  542. try:
  543. (shapes, types) = workspace.InferShapesAndTypes([net])
  544. test_shape_inference = True
  545. except RuntimeError as e:
  546. # Temporarily catch runtime errors when inferring shape
  547. # and type info
  548. logging.warning(str(e))
  549. if os.getenv('CAFFE2_ASSERT_SHAPEINFERENCE') == '1' or ensure_outputs_are_inferred:
  550. raise e
  551. workspace.RunNetOnce(net)
  552. reference_outputs = reference(*inputs)
  553. if not (isinstance(reference_outputs, tuple) or
  554. isinstance(reference_outputs, list)):
  555. raise RuntimeError(
  556. "You are providing a wrong reference implementation. A "
  557. "proper one should return a tuple/list of numpy arrays.")
  558. if not outputs_to_check:
  559. self.assertEqual(len(reference_outputs), len(op.output))
  560. outputs_to_check = list(range(len(op.output)))
  561. outs = []
  562. for (output_index, ref) in zip(outputs_to_check, reference_outputs):
  563. output_blob_name = op.output[output_index]
  564. output = workspace.FetchBlob(output_blob_name)
  565. if output.dtype.kind in ('S', 'O'):
  566. np.testing.assert_array_equal(output, ref)
  567. else:
  568. if atol is None:
  569. atol = threshold
  570. np.testing.assert_allclose(
  571. output, ref, atol=atol, rtol=threshold,
  572. err_msg=(
  573. 'Output {0} is not matching the reference'.format(
  574. output_blob_name,
  575. )),
  576. )
  577. if test_shape_inference:
  578. self._assertInferTensorChecks(
  579. output_blob_name, shapes, types, output,
  580. ensure_output_is_inferred=ensure_outputs_are_inferred)
  581. outs.append(output)
  582. if grad_reference is not None:
  583. assert output_to_grad is not None, \
  584. "If grad_reference is set," \
  585. "output_to_grad has to be set as well"
  586. with core.DeviceScope(device_option):
  587. self._assertGradReferenceChecks(
  588. op, inputs, reference_outputs,
  589. output_to_grad, grad_reference,
  590. threshold=threshold)
  591. return outs
  592. def assertValidationChecks(
  593. self,
  594. device_option,
  595. op,
  596. inputs,
  597. validator,
  598. input_device_options=None,
  599. as_kwargs=True,
  600. init_net=None,
  601. ):
  602. if as_kwargs:
  603. assert len(set(list(op.input) + list(op.output))) == \
  604. len(op.input) + len(op.output), \
  605. "in-place ops are not supported in as_kwargs mode"
  606. op = copy.deepcopy(op)
  607. op.device_option.CopyFrom(device_option)
  608. with temp_workspace():
  609. _input_device_options = input_device_options or \
  610. core.InferOpBlobDevicesAsDict(op)[0]
  611. for (n, b) in zip(op.input, inputs):
  612. workspace.FeedBlob(
  613. n,
  614. b,
  615. device_option=_input_device_options.get(n, device_option)
  616. )
  617. if init_net:
  618. workspace.RunNetOnce(init_net)
  619. workspace.RunOperatorOnce(op)
  620. outputs = [workspace.FetchBlob(n) for n in op.output]
  621. if as_kwargs:
  622. validator(**dict(zip(
  623. list(op.input) + list(op.output), inputs + outputs)))
  624. else:
  625. validator(inputs=inputs, outputs=outputs)
  626. def assertRunOpRaises(
  627. self,
  628. device_option,
  629. op,
  630. inputs,
  631. input_device_options=None,
  632. exception=(Exception,),
  633. regexp=None,
  634. ):
  635. op = copy.deepcopy(op)
  636. op.device_option.CopyFrom(device_option)
  637. with temp_workspace():
  638. _input_device_options = input_device_options or \
  639. core.InferOpBlobDevicesAsDict(op)[0]
  640. for (n, b) in zip(op.input, inputs):
  641. workspace.FeedBlob(
  642. n,
  643. b,
  644. device_option=_input_device_options.get(n, device_option)
  645. )
  646. if regexp is None:
  647. self.assertRaises(exception, workspace.RunOperatorOnce, op)
  648. else:
  649. self.assertRaisesRegex(
  650. exception, regexp, workspace.RunOperatorOnce, op)