Default Device Ordinal API (#1079)

* Add default device ordinal API

* Fix reference backend
pull/1125/head
Jae Hoon (Antonio) Kim 2022-07-19 09:19:12 -04:00 committed by Henry Tu
parent de6c135dc3
commit e37891b997
3 changed files with 24 additions and 5 deletions

View File

@ -192,5 +192,13 @@ BackendDevice TorchMlirBackendImpl::GetBackendDevice(c10::Device device) const {
return BackendDevice(GetDefaultDeviceType(), device.index());
}
int64_t TorchMlirBackendImpl::GetDefaultDeviceOrdinal() const {
return default_device_ordinal;
}
void TorchMlirBackendImpl::SetDefaultDeviceOrdinal(int64_t ordinal) {
default_device_ordinal = ordinal;
}
} // namespace lazy
} // namespace torch

View File

@ -153,6 +153,10 @@ public:
// identity mappings.
virtual BackendDevice GetBackendDevice(c10::Device device) const override;
virtual int64_t GetDefaultDeviceOrdinal() const override;
virtual void SetDefaultDeviceOrdinal(int64_t ordinal) override;
/**
* Debug/Metrics
* */
@ -164,6 +168,9 @@ public:
// virtual std::string GetComputationBackendText(
// const ComputationPtr computation
// ) const = 0;
protected:
int64_t default_device_ordinal = 0;
};
} // namespace lazy

View File

@ -7,6 +7,7 @@
//
//===----------------------------------------------------------------------===//
#include <c10/core/DeviceType.h>
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
@ -26,17 +27,19 @@ namespace torch {
namespace lazy {
struct ReferenceLazyBackendDeviceType : public BackendDeviceType {
ReferenceLazyBackendDeviceType(std::string device_type)
ReferenceLazyBackendDeviceType(c10::DeviceType device_type)
: device_type_(device_type) {}
ReferenceLazyBackendDeviceType(int8_t device_type)
: device_type_(static_cast<c10::DeviceType>(device_type)) {}
std::string toString() const override { return device_type_; }
std::string toString() const override { return c10::DeviceTypeName(device_type_); }
std::string device_type_;
c10::DeviceType device_type_;
};
class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl {
public:
ReferenceLazyBackendImpl() : default_device_type_("Magic") {}
ReferenceLazyBackendImpl() : default_device_type_(c10::DeviceType::Lazy) {}
/**
* Configuration
@ -128,7 +131,8 @@ public:
return std::make_shared<BackendDeviceType>(default_device_type_);
}
void SetDefaultDeviceType(std::string device_type) {
void SetDefaultDeviceType(int8_t device_type) override {
default_device_type_ = ReferenceLazyBackendDeviceType(device_type);
}