WrapDimUtils.h 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. #pragma once
  2. #include <c10/core/WrapDimMinimal.h>
  3. #include <c10/core/TensorImpl.h>
  4. #include <c10/util/irange.h>
  5. #include <ATen/core/Tensor.h>
  6. #include <ATen/core/IListRef.h>
  7. namespace at {
  8. static inline int64_t maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, bool wrap_scalar=true) {
  9. // if dim_post_expr is 0 and wrap_scalar is true, then dim must be in the range [-1, 0].
  10. // This is a special case for scalar tensors and manifests in e.g. torch.sum(scalar_tensor, 0)
  11. // Otherwise, dim should be in the range [-dim_post_expr, dim_post_expr-1].
  12. return c10::maybe_wrap_dim(dim, dim_post_expr, wrap_scalar);
  13. }
  14. static inline int64_t maybe_wrap_dim(int64_t dim, TensorImpl *tensor) {
  15. return maybe_wrap_dim(dim, tensor->dim());
  16. }
  17. static inline int64_t maybe_wrap_dim(int64_t dim, TensorList tensors) {
  18. if (tensors.size() == 0) {
  19. // can't wrap empty TensorList; rely on underlying implementation to throw error if necessary.
  20. return dim;
  21. }
  22. return maybe_wrap_dim(dim, tensors[0].dim());
  23. }
  24. static inline int64_t maybe_wrap_dim(int64_t dim, const std::vector<std::vector<int64_t>> & tensor_sizes) {
  25. if (tensor_sizes.size() == 0) {
  26. // can't wrap empty list; rely on underlying implementation to throw error if necessary
  27. return dim;
  28. }
  29. return maybe_wrap_dim(dim, tensor_sizes[0].size());
  30. }
  31. // wrap each dim in the dims array, taking dim_post_expr as the true number of dimensions
  32. static inline void maybe_wrap_dims_n(int64_t* dims, int64_t ndims, int64_t dim_post_expr) {
  33. if (dim_post_expr <= 0) {
  34. dim_post_expr = 1; // this will make range [-1, 0]
  35. }
  36. int64_t min = -dim_post_expr;
  37. int64_t max = dim_post_expr - 1;
  38. for (const auto i : c10::irange(ndims)) {
  39. auto &dim = dims[i];
  40. if (dim < min || dim > max) {
  41. TORCH_CHECK_INDEX(false,
  42. "Dimension out of range (expected to be in range of [",
  43. min, ", ", max, "], but got ", dim, ")");
  44. }
  45. if (dim < 0) dim += dim_post_expr;
  46. }
  47. }
  48. // Wrap each dim in a contiguous container, taking dim_post_expr as the true number of dimensions
  49. // E.g. could also be std::array or c10::SmallVector
  50. template <typename Container>
  51. inline void maybe_wrap_dims(Container& dims, int64_t dim_post_expr) {
  52. return maybe_wrap_dims_n(dims.data(), dims.size(), dim_post_expr);
  53. }
  54. // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible
  55. // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors
  56. // to be "skipped" (both for wrap dimension behavior and dimension size checking).
  57. // We maintain this behavior for backwards compatibility, but only for this specific size
  58. // (i.e. other empty sizes are not skipped).
  59. static inline int64_t legacy_cat_wrap_dim(int64_t dim, const std::vector<std::vector<int64_t>>& tensor_sizes) {
  60. for (auto& sizes : tensor_sizes) {
  61. if (sizes == std::vector<int64_t>({0})) {
  62. continue;
  63. }
  64. return maybe_wrap_dim(dim, sizes.size());
  65. }
  66. return dim;
  67. }
  68. static inline int64_t legacy_cat_wrap_dim(int64_t dim, ITensorListRef tensors) {
  69. for (auto& tensor : tensors) {
  70. if (tensor.dim() == 1 && tensor.sizes()[0] == 0) {
  71. continue;
  72. }
  73. return maybe_wrap_dim(dim, tensor.dim());
  74. }
  75. return dim;
  76. }
  77. // wrap negative dims in a vector
  78. static inline void wrap_all_dims(std::vector<int64_t>& dims_to_wrap, int64_t tensor_total_dims) {
  79. for (const auto i : c10::irange(dims_to_wrap.size())) {
  80. dims_to_wrap[i] = maybe_wrap_dim(dims_to_wrap[i], tensor_total_dims);
  81. }
  82. }
  83. }