CPUApplyUtils.h 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. #pragma once
  2. #include <ATen/Parallel.h>
  3. #include <ATen/TensorUtils.h>
  4. #include <ATen/CollapseDims.h>
  5. #include <c10/util/irange.h>
  6. #include <limits>
  7. #include <utility>
  8. #include <cstring>
  9. namespace at {
  10. /*
  11. * The basic strategy for apply is as follows:
  12. *
  13. * 1. Starting with the outermost index, loop until we reach a dimension where
  14. * the data is no longer contiguous, i.e. the stride at that dimension is not
  15. * equal to the size of the tensor defined by the outer dimensions. Let's call
  16. * this outer (contiguous) tensor A. Note that if the Tensor is contiguous, then
  17. * A is equal to the entire Tensor. Let's call the inner tensor B.
  18. *
  19. * 2. We loop through the indices in B, starting at its outermost dimension. For
  20. * example, if B is a 2x2 matrix, then we do:
  21. *
  22. * B[0][0]
  23. * B[0][1]
  24. * B[1][0]
  25. * B[1][1]
  26. *
  27. * We set the offset into the underlying storage as (storageOffset + stride_B *
  28. * index_B), i.e. basically we compute the offset into the storage as we would
  29. * normally for a Tensor. But because we are guaranteed the subsequent data is
  30. * contiguous in memory, we can simply loop for sizeof(A) iterations and perform
  31. * the operation, without having to follow the order described by the strides of
  32. * A.
  33. *
  34. * 3. As an optimization, we merge dimensions of A that are contiguous in
  35. * memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor,
  36. * then the first two dimensions can be merged for the purposes of APPLY,
  37. * reducing the number of nested loops.
  38. */
  39. inline Tensor sort_strides(Tensor& tensor_) {
  40. IntArrayRef strides = tensor_.strides();
  41. std::vector<int64_t> indices;
  42. indices.reserve(tensor_.ndimension());
  43. for (const auto i : c10::irange(tensor_.ndimension())) {
  44. indices.push_back(i);
  45. }
  46. std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
  47. return strides[i1] > strides[i2];
  48. });
  49. Tensor tensor = tensor_.permute(indices);
  50. return tensor;
  51. }
  52. template <typename T, int N>
  53. struct strided_tensor_iter_fixed {
  54. public:
  55. T* data_ = NULL;
  56. int64_t dim_ = 0;
  57. int64_t counter_[N] = {0};
  58. int64_t sizes_[N] = {0};
  59. int64_t strides_[N] = {0};
  60. strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
  61. void operator=(strided_tensor_iter_fixed const& x) = delete;
  62. strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default;
  63. strided_tensor_iter_fixed(Tensor& tensor, bool sort_strides = false)
  64. : data_(tensor.data_ptr<T>()) {
  65. (void)sort_strides; // Suppress unused variable warning
  66. std::memset(counter_, 0, sizeof(int64_t) * N);
  67. if (tensor.dim() > 0) {
  68. std::memcpy(
  69. sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t));
  70. std::memcpy(
  71. strides_,
  72. tensor.strides().data(),
  73. tensor.dim() * sizeof(int64_t));
  74. }
  75. dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension()));
  76. }
  77. };
  78. template <typename T>
  79. struct strided_tensor_iter {
  80. private:
  81. public:
  82. T* data_ = NULL;
  83. int64_t dim_;
  84. std::vector<int64_t> counter_;
  85. std::vector<int64_t> sizes_;
  86. std::vector<int64_t> strides_;
  87. strided_tensor_iter(strided_tensor_iter const&) = delete;
  88. void operator=(strided_tensor_iter const& x) = delete;
  89. strided_tensor_iter(strided_tensor_iter&&) = default;
  90. strided_tensor_iter(Tensor& tensor)
  91. : data_(tensor.data_ptr<T>()),
  92. dim_(tensor.ndimension()),
  93. counter_(dim_, 0),
  94. sizes_(tensor.sizes().vec()),
  95. strides_(tensor.strides().vec()) {
  96. dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_));
  97. }
  98. };
  99. inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
  100. if (tensors.size() == 0)
  101. return true;
  102. int64_t all_numel = tensors[0].numel();
  103. for (const auto i : c10::irange(1, tensors.size())) {
  104. if (tensors[i].numel() != all_numel)
  105. return false;
  106. }
  107. return true;
  108. }
  109. inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
  110. std::ostringstream oss;
  111. oss << "inconsistent tensor size, expected ";
  112. for (size_t i = 0; i < tensors.size() - 1; i++) {
  113. oss << tensors[i].sizes() << ", ";
  114. }
  115. oss << "and " << tensors[tensors.size() - 1].sizes()
  116. << " to have the same number of elements, but got ";
  117. for (size_t i = 0; i < tensors.size() - 1; i++) {
  118. oss << tensors[i].numel() << ", ";
  119. }
  120. oss << "and " << tensors[tensors.size() - 1].numel()
  121. << " elements respectively";
  122. return oss.str();
  123. }
  124. inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
  125. checkDeviceType("CPU_tensor_apply", tensors, kCPU);
  126. checkLayout("CPU_tensor_apply", tensors, kStrided);
  127. if (!_all_equal_numel(tensors))
  128. AT_ERROR(_all_equal_numel_error(tensors));
  129. // An empty tensor has no elements
  130. for (auto& t : tensors)
  131. if (t.numel() == 0)
  132. return false;
  133. return true;
  134. }
  135. inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
  136. int64_t dim = 0;
  137. for (auto& t : tensors)
  138. dim = std::max(dim, t.ndimension());
  139. return dim;
  140. }
  /// End of recursion: no iterators remain to be advanced.
  inline void iterate(int64_t /*size*/) {}

  /// Advances every given iterator by `size` positions along its last
  /// dimension (index `dim_ - 1`): the counter for that dimension grows by
  /// `size` and the data pointer moves by `size` times that dimension's stride.
  /// No wrap-around is performed here.
  ///
  /// @param size      Number of positions to step.
  /// @param iter      First iterator to advance.
  /// @param iter_tail Remaining iterators, advanced the same way.
  template <typename Arg, typename... Args>
  inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
    const auto last = iter.dim_ - 1;
    iter.counter_[last] += size;
    iter.data_ += size * iter.strides_[last];
    iterate(size, iter_tail...);
  }
  /// End of recursion: with no iterators left, nothing prevents continuing.
  inline bool iterate_continue() {
    return true;
  }

  /// Reports whether all given iterators can take another step along their
  /// last dimension, i.e. each one's last counter is still below the matching
  /// size.
  ///
  /// @return false as soon as one iterator has reached the end of its last
  ///         dimension, true otherwise.
  template <typename Arg, typename... Args>
  inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
    const auto last = iter.dim_ - 1;
    if (iter.counter_[last] >= iter.sizes_[last]) {
      return false;
    }
    return iterate_continue(iter_tail...);
  }
  /// End of recursion: with no iterators left there is no upper bound, so the
  /// largest int64_t value is returned.
  inline int64_t max_iterate_size() {
    return std::numeric_limits<int64_t>::max();
  }

  /// Returns the smallest number of steps any of the given iterators still has
  /// left along its last dimension (size minus counter for index `dim_ - 1`).
  template <typename Arg, typename... Args>
  inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
    const auto last = iter.dim_ - 1;
    const int64_t remaining_here = iter.sizes_[last] - iter.counter_[last];
    const int64_t remaining_rest = max_iterate_size(iter_tail...);
    return remaining_rest < remaining_here ? remaining_rest : remaining_here;
  }
  /// End of recursion: no iterators remain to be adjusted.
  inline void iterate_overflow() {}

  /// Carries a finished last dimension into the dimensions before it, for each
  /// given iterator.
  ///
  /// When an iterator's last counter equals the last size, the dimensions are
  /// visited from last to second: every counter that equals its size is reset
  /// to zero, the preceding counter is incremented, and the data pointer is
  /// moved back by the full extent of the finished dimension and forward by
  /// one stride of the preceding dimension. The first dimension's counter is
  /// never reset.
  template <typename Arg, typename... Args>
  inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
    const int64_t last = iter.dim_ - 1;
    if (iter.counter_[last] == iter.sizes_[last]) {
      int64_t d = last;
      while (d > 0) {
        if (iter.counter_[d] == iter.sizes_[d]) {
          iter.counter_[d] = 0;
          ++iter.counter_[d - 1];
          // Rewind the finished dimension, then step the preceding one.
          iter.data_ = iter.data_ - (iter.sizes_[d] * iter.strides_[d]) +
              iter.strides_[d - 1];
        }
        --d;
      }
    }
    iterate_overflow(iter_tail...);
  }
  /// End of recursion: no iterators remain to be moved.
  inline void forward(int64_t /*offset*/) {}

  /// Moves every given iterator ahead by `offset` elements.
  ///
  /// The offset is decomposed into per-dimension indices, starting from the
  /// last dimension: each dimension takes `offset mod size` steps and the
  /// quotient is passed on to the dimension before it. Counters and the data
  /// pointer are both updated by the resulting steps.
  ///
  /// @param offset    Linear number of elements to skip.
  /// @param iter      First iterator to move.
  /// @param iter_tail Remaining iterators, each moved by the same offset.
  template <typename Arg, typename... Args>
  inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
    int64_t remaining = offset;
    for (int64_t d = iter.dim_; d-- > 0;) {
      const int64_t extent = iter.sizes_[d];
      const int64_t step = remaining % extent;
      remaining /= extent;
      iter.data_ += step * iter.strides_[d];
      iter.counter_[d] += step;
    }
    forward(offset, iter_tail...);
  }
  /// End of recursion: with no iterators the maximum dimension count is 0.
  inline int64_t max_dim() {
    return 0;
  }

  /// Returns the largest `dim_` among the given iterators.
  template <typename Arg, typename... Args>
  inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
    const int64_t rest = max_dim(iter_tail...);
    return iter.dim_ < rest ? rest : iter.dim_;
  }
  /// Zero-argument overload; does nothing.
  inline void apply_op() {}

  /// Calls `op` on `numel` consecutive element tuples drawn from `iters`.
  ///
  /// `op` receives the dereferenced data pointer of every iterator, so it is
  /// handed references to the current elements.
  ///
  /// @param numel  Number of element tuples to process.
  /// @param offset Number of elements to skip first; applied via forward()
  ///               only when positive.
  /// @param op     Callable invoked once per element tuple.
  /// @param iters  Iterators, taken by value; they are advanced locally.
  template <typename Op, typename... Args>
  inline void
  apply_op(int64_t numel, int64_t offset, const Op& op, Args... iters) {
    // Single element and no iterator reports any dimension: one direct call.
    const bool single_zero_dim = (numel == 1) && (max_dim(iters...) == 0);
    if (single_zero_dim) {
      op(*iters.data_...);
      return;
    }

    if (offset > 0) {
      forward(offset, iters...);
    }

    // Splitting this into chunks helps the compiler create faster assembly:
    // the inner loop walks the last dimension, the outer loop handles the
    // carry into the preceding dimensions.
    int64_t processed = 0;
    while (processed < numel) {
      while (iterate_continue(iters...) && processed < numel) {
        op(*iters.data_...);
        iterate(1, iters...);
        ++processed;
      }
      iterate_overflow(iters...);
    }
  }
  220. /*
  Apply a pointwise operator to a sequence of tensors.
  The calling convention for op is a function/functor that takes the same
  number of references of type scalar as the number of given tensors (apply_op
  invokes op with each iterator's dereferenced data pointer). For example, to
  compute a = b * c, op would be of the form:
  [](scalar& a_val, const scalar& b_val, const scalar& c_val) { a_val =
  b_val * c_val; };
  227. */
  228. template <typename scalar1, typename scalar2, typename Op>
  229. inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
  230. if (!_apply_preamble({tensor1, tensor2}))
  231. return;
  232. if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
  233. apply_op(
  234. tensor1.numel(),
  235. 0,
  236. op,
  237. strided_tensor_iter_fixed<scalar1, 8>(tensor1),
  238. strided_tensor_iter_fixed<scalar2, 8>(tensor2));
  239. } else {
  240. apply_op(
  241. tensor1.numel(),
  242. 0,
  243. op,
  244. strided_tensor_iter<scalar1>(tensor1),
  245. strided_tensor_iter<scalar2>(tensor2));
  246. }
  247. }
  248. template <typename scalar1, typename scalar2, typename scalar3, typename Op>
  249. inline void
  250. CPU_tensor_apply3(Tensor tensor1, Tensor tensor2, Tensor tensor3, const Op op) {
  251. if (!_apply_preamble({tensor1, tensor2, tensor3}))
  252. return;
  253. if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
  254. apply_op(
  255. tensor1.numel(),
  256. 0,
  257. op,
  258. strided_tensor_iter_fixed<scalar1, 8>(tensor1),
  259. strided_tensor_iter_fixed<scalar2, 8>(tensor2),
  260. strided_tensor_iter_fixed<scalar3, 8>(tensor3));
  261. } else {
  262. apply_op(
  263. tensor1.numel(),
  264. 0,
  265. op,
  266. strided_tensor_iter<scalar1>(tensor1),
  267. strided_tensor_iter<scalar2>(tensor2),
  268. strided_tensor_iter<scalar3>(tensor3));
  269. }
  270. }
  271. template <
  272. typename scalar1,
  273. typename scalar2,
  274. typename scalar3,
  275. typename scalar4,
  276. typename Op>
  277. inline void CPU_tensor_apply4(
  278. Tensor tensor1,
  279. Tensor tensor2,
  280. Tensor tensor3,
  281. Tensor tensor4,
  282. const Op op) {
  283. if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
  284. return;
  285. if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
  286. apply_op(
  287. tensor1.numel(),
  288. 0,
  289. op,
  290. strided_tensor_iter_fixed<scalar1, 8>(tensor1),
  291. strided_tensor_iter_fixed<scalar2, 8>(tensor2),
  292. strided_tensor_iter_fixed<scalar3, 8>(tensor3),
  293. strided_tensor_iter_fixed<scalar4, 8>(tensor4));
  294. } else {
  295. apply_op(
  296. tensor1.numel(),
  297. 0,
  298. op,
  299. strided_tensor_iter<scalar1>(tensor1),
  300. strided_tensor_iter<scalar2>(tensor2),
  301. strided_tensor_iter<scalar3>(tensor3),
  302. strided_tensor_iter<scalar4>(tensor4));
  303. }
  304. }
  305. } // namespace at