functional_test.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import unittest
  2. from caffe2.python import core
  3. from hypothesis import given
  4. import hypothesis.strategies as st
  5. import caffe2.python.hypothesis_test_util as hu
  6. from caffe2.python import workspace
  7. from caffe2.python.functional import Functional
  8. import numpy as np
  9. @st.composite
  10. def _tensor_splits(draw, add_axis=False):
  11. """Generates (axis, split_info, tensor_splits) tuples."""
  12. tensor = draw(hu.tensor(min_value=4)) # Each dim has at least 4 elements.
  13. axis = draw(st.integers(0, len(tensor.shape) - 1))
  14. if add_axis:
  15. # Simple case: get individual slices along one axis, where each of them
  16. # is (N-1)-dimensional. The axis will be added back upon concatenation.
  17. return (
  18. axis, np.ones(tensor.shape[axis], dtype=np.int32), [
  19. np.array(tensor.take(i, axis=axis))
  20. for i in range(tensor.shape[axis])
  21. ]
  22. )
  23. else:
  24. # General case: pick some (possibly consecutive, even non-unique)
  25. # indices at which we will split the tensor, along the given axis.
  26. splits = sorted(
  27. draw(
  28. st.
  29. lists(elements=st.integers(0, tensor.shape[axis]), max_size=4)
  30. ) + [0, tensor.shape[axis]]
  31. )
  32. return (
  33. axis, np.array(np.diff(splits), dtype=np.int32), [
  34. tensor.take(range(splits[i], splits[i + 1]), axis=axis)
  35. for i in range(len(splits) - 1)
  36. ],
  37. )
  38. class TestFunctional(hu.HypothesisTestCase):
  39. @given(X=hu.tensor(), engine=st.sampled_from(["", "CUDNN"]), **hu.gcs)
  40. def test_relu(self, X, engine, gc, dc):
  41. X += 0.02 * np.sign(X)
  42. X[X == 0.0] += 0.02
  43. output = Functional.Relu(X, device_option=gc)
  44. Y_l = output[0]
  45. Y_d = output["output_0"]
  46. with workspace.WorkspaceGuard("tmp_workspace"):
  47. op = core.CreateOperator("Relu", ["X"], ["Y"], engine=engine)
  48. workspace.FeedBlob("X", X)
  49. workspace.RunOperatorOnce(op)
  50. Y_ref = workspace.FetchBlob("Y")
  51. np.testing.assert_array_equal(
  52. Y_l, Y_ref, err_msg='Functional Relu result mismatch'
  53. )
  54. np.testing.assert_array_equal(
  55. Y_d, Y_ref, err_msg='Functional Relu result mismatch'
  56. )
  57. @given(tensor_splits=_tensor_splits(), **hu.gcs)
  58. def test_concat(self, tensor_splits, gc, dc):
  59. # Input Size: 1 -> inf
  60. axis, _, splits = tensor_splits
  61. concat_result, split_info = Functional.Concat(*splits, axis=axis, device_option=gc)
  62. concat_result_ref = np.concatenate(splits, axis=axis)
  63. split_info_ref = np.array([a.shape[axis] for a in splits])
  64. np.testing.assert_array_equal(
  65. concat_result,
  66. concat_result_ref,
  67. err_msg='Functional Concat result mismatch'
  68. )
  69. np.testing.assert_array_equal(
  70. split_info,
  71. split_info_ref,
  72. err_msg='Functional Concat split info mismatch'
  73. )
  74. @given(tensor_splits=_tensor_splits(), split_as_arg=st.booleans(), **hu.gcs)
  75. def test_split(self, tensor_splits, split_as_arg, gc, dc):
  76. # Output Size: 1 - inf
  77. axis, split_info, splits = tensor_splits
  78. split_as_arg = True
  79. if split_as_arg:
  80. input_tensors = [np.concatenate(splits, axis=axis)]
  81. kwargs = dict(axis=axis, split=split_info, num_output=len(splits))
  82. else:
  83. input_tensors = [np.concatenate(splits, axis=axis), split_info]
  84. kwargs = dict(axis=axis, num_output=len(splits))
  85. result = Functional.Split(*input_tensors, device_option=gc, **kwargs)
  86. def split_ref(input, split=split_info):
  87. s = np.cumsum([0] + list(split))
  88. return [
  89. np.array(input.take(np.arange(s[i], s[i + 1]), axis=axis))
  90. for i in range(len(split))
  91. ]
  92. result_ref = split_ref(*input_tensors)
  93. for i, ref in enumerate(result_ref):
  94. np.testing.assert_array_equal(
  95. result[i], ref, err_msg='Functional Relu result mismatch'
  96. )
  97. if __name__ == '__main__':
  98. unittest.main()