ExpandBase.h 883 B

1234567891011121314151617181920212223
  1. #include <ATen/core/TensorBase.h>
  2. // Broadcasting utilities for working with TensorBase
  3. namespace at {
  4. namespace internal {
  5. TORCH_API TensorBase expand_slow_path(const TensorBase &self, IntArrayRef size);
  6. } // namespace internal
  7. inline c10::MaybeOwned<TensorBase> expand_size(const TensorBase &self, IntArrayRef size) {
  8. if (size.equals(self.sizes())) {
  9. return c10::MaybeOwned<TensorBase>::borrowed(self);
  10. }
  11. return c10::MaybeOwned<TensorBase>::owned(
  12. at::internal::expand_slow_path(self, size));
  13. }
  14. c10::MaybeOwned<TensorBase> expand_size(TensorBase &&self, IntArrayRef size) = delete;
  15. inline c10::MaybeOwned<TensorBase> expand_inplace(const TensorBase &tensor, const TensorBase &to_expand) {
  16. return expand_size(to_expand, tensor.sizes());
  17. }
  18. c10::MaybeOwned<TensorBase> expand_inplace(const TensorBase &tensor, TensorBase &&to_expand) = delete;
  19. } // namespace at