DeviceGuard.h 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. #pragma once
  2. #include <c10/core/DeviceGuard.h>
  3. #include <ATen/core/Tensor.h>
  4. #include <c10/core/ScalarType.h> // TensorList whyyyyy
  5. namespace at {
  6. // Are you here because you're wondering why DeviceGuard(tensor) no
  7. // longer works? For code organization reasons, we have temporarily(?)
  8. // removed this constructor from DeviceGuard. The new way to
  9. // spell it is:
  10. //
  11. // OptionalDeviceGuard guard(device_of(tensor));
  12. /// Return the Device of a Tensor, if the Tensor is defined.
  13. inline c10::optional<Device> device_of(const Tensor& t) {
  14. if (t.defined()) {
  15. return c10::make_optional(t.device());
  16. } else {
  17. return c10::nullopt;
  18. }
  19. }
  20. inline c10::optional<Device> device_of(const c10::optional<Tensor>& t) {
  21. return t.has_value() ? device_of(t.value()) : nullopt;
  22. }
  23. /// Return the Device of a TensorList, if the list is non-empty and
  24. /// the first Tensor is defined. (This function implicitly assumes
  25. /// that all tensors in the list have the same device.)
  26. inline c10::optional<Device> device_of(TensorList t) {
  27. if (!t.empty()) {
  28. return device_of(t.front());
  29. } else {
  30. return c10::nullopt;
  31. }
  32. }
  33. } // namespace at