Default Device Ordinal API (#1079)

* Add default device ordinal API

* Fix reference backend
pull/1125/head
Jae Hoon (Antonio) Kim 2022-07-19 09:19:12 -04:00 committed by Henry Tu
parent de6c135dc3
commit e37891b997
3 changed files with 24 additions and 5 deletions

View File

@ -192,5 +192,13 @@ BackendDevice TorchMlirBackendImpl::GetBackendDevice(c10::Device device) const {
return BackendDevice(GetDefaultDeviceType(), device.index()); return BackendDevice(GetDefaultDeviceType(), device.index());
} }
int64_t TorchMlirBackendImpl::GetDefaultDeviceOrdinal() const {
return default_device_ordinal;
}
void TorchMlirBackendImpl::SetDefaultDeviceOrdinal(int64_t ordinal) {
default_device_ordinal = ordinal;
}
} // namespace lazy } // namespace lazy
} // namespace torch } // namespace torch

View File

@ -153,6 +153,10 @@ public:
// identity mappings. // identity mappings.
virtual BackendDevice GetBackendDevice(c10::Device device) const override; virtual BackendDevice GetBackendDevice(c10::Device device) const override;
virtual int64_t GetDefaultDeviceOrdinal() const override;
virtual void SetDefaultDeviceOrdinal(int64_t ordinal) override;
/** /**
* Debug/Metrics * Debug/Metrics
* */ * */
@ -164,6 +168,9 @@ public:
// virtual std::string GetComputationBackendText( // virtual std::string GetComputationBackendText(
// const ComputationPtr computation // const ComputationPtr computation
// ) const = 0; // ) const = 0;
protected:
int64_t default_device_ordinal = 0;
}; };
} // namespace lazy } // namespace lazy

View File

@ -7,6 +7,7 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include <c10/core/DeviceType.h>
#include <torch/csrc/lazy/backend/backend_data.h> #include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h> #include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/lowering_context.h> #include <torch/csrc/lazy/backend/lowering_context.h>
@ -26,17 +27,19 @@ namespace torch {
namespace lazy { namespace lazy {
struct ReferenceLazyBackendDeviceType : public BackendDeviceType { struct ReferenceLazyBackendDeviceType : public BackendDeviceType {
ReferenceLazyBackendDeviceType(std::string device_type) ReferenceLazyBackendDeviceType(c10::DeviceType device_type)
: device_type_(device_type) {} : device_type_(device_type) {}
ReferenceLazyBackendDeviceType(int8_t device_type)
: device_type_(static_cast<c10::DeviceType>(device_type)) {}
std::string toString() const override { return device_type_; } std::string toString() const override { return c10::DeviceTypeName(device_type_); }
std::string device_type_; c10::DeviceType device_type_;
}; };
class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl { class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl {
public: public:
ReferenceLazyBackendImpl() : default_device_type_("Magic") {} ReferenceLazyBackendImpl() : default_device_type_(c10::DeviceType::Lazy) {}
/** /**
* Configuration * Configuration
@ -128,7 +131,8 @@ public:
return std::make_shared<BackendDeviceType>(default_device_type_); return std::make_shared<BackendDeviceType>(default_device_type_);
} }
void SetDefaultDeviceType(std::string device_type) {
void SetDefaultDeviceType(int8_t device_type) override {
default_device_type_ = ReferenceLazyBackendDeviceType(device_type); default_device_type_ = ReferenceLazyBackendDeviceType(device_type);
} }