mirror of https://github.com/llvm/torch-mlir
Default Device Ordinal API (#1079)
* Add default device ordinal API * Fix reference backendpull/1125/head
parent
de6c135dc3
commit
e37891b997
|
@ -192,5 +192,13 @@ BackendDevice TorchMlirBackendImpl::GetBackendDevice(c10::Device device) const {
|
|||
return BackendDevice(GetDefaultDeviceType(), device.index());
|
||||
}
|
||||
|
||||
int64_t TorchMlirBackendImpl::GetDefaultDeviceOrdinal() const {
|
||||
return default_device_ordinal;
|
||||
}
|
||||
|
||||
void TorchMlirBackendImpl::SetDefaultDeviceOrdinal(int64_t ordinal) {
|
||||
default_device_ordinal = ordinal;
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
||||
|
|
|
@ -153,6 +153,10 @@ public:
|
|||
// identity mappings.
|
||||
virtual BackendDevice GetBackendDevice(c10::Device device) const override;
|
||||
|
||||
virtual int64_t GetDefaultDeviceOrdinal() const override;
|
||||
|
||||
virtual void SetDefaultDeviceOrdinal(int64_t ordinal) override;
|
||||
|
||||
/**
|
||||
* Debug/Metrics
|
||||
* */
|
||||
|
@ -164,6 +168,9 @@ public:
|
|||
// virtual std::string GetComputationBackendText(
|
||||
// const ComputationPtr computation
|
||||
// ) const = 0;
|
||||
|
||||
protected:
|
||||
int64_t default_device_ordinal = 0;
|
||||
};
|
||||
|
||||
} // namespace lazy
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <torch/csrc/lazy/backend/backend_data.h>
|
||||
#include <torch/csrc/lazy/backend/backend_device.h>
|
||||
#include <torch/csrc/lazy/backend/lowering_context.h>
|
||||
|
@ -26,17 +27,19 @@ namespace torch {
|
|||
namespace lazy {
|
||||
|
||||
struct ReferenceLazyBackendDeviceType : public BackendDeviceType {
|
||||
ReferenceLazyBackendDeviceType(std::string device_type)
|
||||
ReferenceLazyBackendDeviceType(c10::DeviceType device_type)
|
||||
: device_type_(device_type) {}
|
||||
ReferenceLazyBackendDeviceType(int8_t device_type)
|
||||
: device_type_(static_cast<c10::DeviceType>(device_type)) {}
|
||||
|
||||
std::string toString() const override { return device_type_; }
|
||||
std::string toString() const override { return c10::DeviceTypeName(device_type_); }
|
||||
|
||||
std::string device_type_;
|
||||
c10::DeviceType device_type_;
|
||||
};
|
||||
|
||||
class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl {
|
||||
public:
|
||||
ReferenceLazyBackendImpl() : default_device_type_("Magic") {}
|
||||
ReferenceLazyBackendImpl() : default_device_type_(c10::DeviceType::Lazy) {}
|
||||
|
||||
/**
|
||||
* Configuration
|
||||
|
@ -128,7 +131,8 @@ public:
|
|||
return std::make_shared<BackendDeviceType>(default_device_type_);
|
||||
}
|
||||
|
||||
void SetDefaultDeviceType(std::string device_type) {
|
||||
|
||||
void SetDefaultDeviceType(int8_t device_type) override {
|
||||
default_device_type_ = ReferenceLazyBackendDeviceType(device_type);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue