aten_test.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. from caffe2.python import core
  2. from hypothesis import given
  3. import caffe2.python.hypothesis_test_util as hu
  4. import hypothesis.strategies as st
  5. import numpy as np
  6. class TestATen(hu.HypothesisTestCase):
  7. @given(inputs=hu.tensors(n=2), **hu.gcs)
  8. def test_add(self, inputs, gc, dc):
  9. op = core.CreateOperator(
  10. "ATen",
  11. ["X", "Y"],
  12. ["Z"],
  13. operator="add")
  14. def ref(X, Y):
  15. return [X + Y]
  16. self.assertReferenceChecks(gc, op, inputs, ref)
  17. @given(inputs=hu.tensors(n=2, dtype=np.float16), **hu.gcs_gpu_only)
  18. def test_add_half(self, inputs, gc, dc):
  19. op = core.CreateOperator(
  20. "ATen",
  21. ["X", "Y"],
  22. ["Z"],
  23. operator="add")
  24. def ref(X, Y):
  25. return [X + Y]
  26. self.assertReferenceChecks(gc, op, inputs, ref)
  27. @given(inputs=hu.tensors(n=1), **hu.gcs)
  28. def test_pow(self, inputs, gc, dc):
  29. op = core.CreateOperator(
  30. "ATen",
  31. ["S"],
  32. ["Z"],
  33. operator="pow", exponent=2.0)
  34. def ref(X):
  35. return [np.square(X)]
  36. self.assertReferenceChecks(gc, op, inputs, ref)
  37. @given(x=st.integers(min_value=2, max_value=8), **hu.gcs)
  38. def test_sort(self, x, gc, dc):
  39. inputs = [np.random.permutation(x)]
  40. op = core.CreateOperator(
  41. "ATen",
  42. ["S"],
  43. ["Z", "I"],
  44. operator="sort")
  45. def ref(X):
  46. return [np.sort(X), np.argsort(X)]
  47. self.assertReferenceChecks(gc, op, inputs, ref)
  48. @given(inputs=hu.tensors(n=1), **hu.gcs)
  49. def test_sum(self, inputs, gc, dc):
  50. op = core.CreateOperator(
  51. "ATen",
  52. ["S"],
  53. ["Z"],
  54. operator="sum")
  55. def ref(X):
  56. return [np.sum(X)]
  57. self.assertReferenceChecks(gc, op, inputs, ref)
  58. @given(**hu.gcs)
  59. def test_index_uint8(self, gc, dc):
  60. # Indexing with uint8 is deprecated, but we need to provide backward compatibility for some old models exported through ONNX
  61. op = core.CreateOperator(
  62. "ATen",
  63. ['self', 'mask'],
  64. ["Z"],
  65. operator="index")
  66. def ref(self, mask):
  67. return (self[mask.astype(np.bool_)],)
  68. tensor = np.random.randn(2, 3, 4).astype(np.float32)
  69. mask = np.array([[1, 0, 0], [1, 1, 0]]).astype(np.uint8)
  70. self.assertReferenceChecks(gc, op, [tensor, mask], ref)
  71. @given(**hu.gcs)
  72. def test_index_put(self, gc, dc):
  73. op = core.CreateOperator(
  74. "ATen",
  75. ['self', 'indices', 'values'],
  76. ["Z"],
  77. operator="index_put")
  78. def ref(self, indices, values):
  79. self[indices] = values
  80. return (self,)
  81. tensor = np.random.randn(3, 3).astype(np.float32)
  82. mask = np.array([[True, True, True], [True, False, False], [True, True, False]])
  83. values = np.random.randn(6).astype(np.float32)
  84. self.assertReferenceChecks(gc, op, [tensor, mask, values], ref)
  85. @given(**hu.gcs)
  86. def test_unique(self, gc, dc):
  87. op = core.CreateOperator(
  88. "ATen",
  89. ['self'],
  90. ["output"],
  91. sorted=True,
  92. return_inverse=True,
  93. # return_counts=False,
  94. operator="_unique")
  95. def ref(self):
  96. index, _ = np.unique(self, return_index=False, return_inverse=True, return_counts=False)
  97. return (index,)
  98. tensor = np.array([1, 2, 6, 4, 2, 3, 2])
  99. print(ref(tensor))
  100. self.assertReferenceChecks(gc, op, [tensor], ref)
  101. if __name__ == "__main__":
  102. import unittest
  103. unittest.main()