diff --git a/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp b/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp index 7d7f16125..38b804c22 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp @@ -192,5 +192,13 @@ BackendDevice TorchMlirBackendImpl::GetBackendDevice(c10::Device device) const { return BackendDevice(GetDefaultDeviceType(), device.index()); } +int64_t TorchMlirBackendImpl::GetDefaultDeviceOrdinal() const { + return default_device_ordinal; +} + +void TorchMlirBackendImpl::SetDefaultDeviceOrdinal(int64_t ordinal) { + default_device_ordinal = ordinal; +} + } // namespace lazy } // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/backend_impl.h b/python/torch_mlir/csrc/base_lazy_backend/backend_impl.h index cae7266f2..5ef11f662 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/backend_impl.h +++ b/python/torch_mlir/csrc/base_lazy_backend/backend_impl.h @@ -153,6 +153,10 @@ public: // identity mappings. virtual BackendDevice GetBackendDevice(c10::Device device) const override; + virtual int64_t GetDefaultDeviceOrdinal() const override; + + virtual void SetDefaultDeviceOrdinal(int64_t ordinal) override; + /** * Debug/Metrics * */ @@ -164,6 +168,9 @@ public: // virtual std::string GetComputationBackendText( // const ComputationPtr computation // ) const = 0; + +protected: + int64_t default_device_ordinal = 0; }; } // namespace lazy diff --git a/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp b/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp index 8599d6db4..de6fb858f 100644 --- a/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp +++ b/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp @@ -7,6 +7,7 @@ // //===----------------------------------------------------------------------===// +#include #include #include #include @@ -26,17 +27,19 @@ namespace torch { namespace lazy { struct ReferenceLazyBackendDeviceType : public BackendDeviceType { - ReferenceLazyBackendDeviceType(std::string device_type) + ReferenceLazyBackendDeviceType(c10::DeviceType device_type) : device_type_(device_type) {} + ReferenceLazyBackendDeviceType(int8_t device_type) + : device_type_(static_cast(device_type)) {} - std::string toString() const override { return device_type_; } + std::string toString() const override { return c10::DeviceTypeName(device_type_); } - std::string device_type_; + c10::DeviceType device_type_; }; class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl { public: - ReferenceLazyBackendImpl() : default_device_type_("Magic") {} + ReferenceLazyBackendImpl() : default_device_type_(c10::DeviceType::Lazy) {} /** * Configuration @@ -128,7 +131,8 @@ public: return std::make_shared(default_device_type_); } - void SetDefaultDeviceType(std::string device_type) { + + void SetDefaultDeviceType(int8_t device_type) override { default_device_type_ = ReferenceLazyBackendDeviceType(device_type); }