mirror of https://github.com/llvm/torch-mlir
Default Device Ordinal API (#1079)
* Add default device ordinal API * Fix reference backendpull/1125/head
parent
de6c135dc3
commit
e37891b997
|
@ -192,5 +192,13 @@ BackendDevice TorchMlirBackendImpl::GetBackendDevice(c10::Device device) const {
|
||||||
return BackendDevice(GetDefaultDeviceType(), device.index());
|
return BackendDevice(GetDefaultDeviceType(), device.index());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64_t TorchMlirBackendImpl::GetDefaultDeviceOrdinal() const {
|
||||||
|
return default_device_ordinal;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TorchMlirBackendImpl::SetDefaultDeviceOrdinal(int64_t ordinal) {
|
||||||
|
default_device_ordinal = ordinal;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace lazy
|
} // namespace lazy
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
@ -153,6 +153,10 @@ public:
|
||||||
// identity mappings.
|
// identity mappings.
|
||||||
virtual BackendDevice GetBackendDevice(c10::Device device) const override;
|
virtual BackendDevice GetBackendDevice(c10::Device device) const override;
|
||||||
|
|
||||||
|
virtual int64_t GetDefaultDeviceOrdinal() const override;
|
||||||
|
|
||||||
|
virtual void SetDefaultDeviceOrdinal(int64_t ordinal) override;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Debug/Metrics
|
* Debug/Metrics
|
||||||
* */
|
* */
|
||||||
|
@ -164,6 +168,9 @@ public:
|
||||||
// virtual std::string GetComputationBackendText(
|
// virtual std::string GetComputationBackendText(
|
||||||
// const ComputationPtr computation
|
// const ComputationPtr computation
|
||||||
// ) const = 0;
|
// ) const = 0;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int64_t default_device_ordinal = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace lazy
|
} // namespace lazy
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include <c10/core/DeviceType.h>
|
||||||
#include <torch/csrc/lazy/backend/backend_data.h>
|
#include <torch/csrc/lazy/backend/backend_data.h>
|
||||||
#include <torch/csrc/lazy/backend/backend_device.h>
|
#include <torch/csrc/lazy/backend/backend_device.h>
|
||||||
#include <torch/csrc/lazy/backend/lowering_context.h>
|
#include <torch/csrc/lazy/backend/lowering_context.h>
|
||||||
|
@ -26,17 +27,19 @@ namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
|
||||||
struct ReferenceLazyBackendDeviceType : public BackendDeviceType {
|
struct ReferenceLazyBackendDeviceType : public BackendDeviceType {
|
||||||
ReferenceLazyBackendDeviceType(std::string device_type)
|
ReferenceLazyBackendDeviceType(c10::DeviceType device_type)
|
||||||
: device_type_(device_type) {}
|
: device_type_(device_type) {}
|
||||||
|
ReferenceLazyBackendDeviceType(int8_t device_type)
|
||||||
|
: device_type_(static_cast<c10::DeviceType>(device_type)) {}
|
||||||
|
|
||||||
std::string toString() const override { return device_type_; }
|
std::string toString() const override { return c10::DeviceTypeName(device_type_); }
|
||||||
|
|
||||||
std::string device_type_;
|
c10::DeviceType device_type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl {
|
class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl {
|
||||||
public:
|
public:
|
||||||
ReferenceLazyBackendImpl() : default_device_type_("Magic") {}
|
ReferenceLazyBackendImpl() : default_device_type_(c10::DeviceType::Lazy) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Configuration
|
* Configuration
|
||||||
|
@ -128,7 +131,8 @@ public:
|
||||||
return std::make_shared<BackendDeviceType>(default_device_type_);
|
return std::make_shared<BackendDeviceType>(default_device_type_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetDefaultDeviceType(std::string device_type) {
|
|
||||||
|
void SetDefaultDeviceType(int8_t device_type) override {
|
||||||
default_device_type_ = ReferenceLazyBackendDeviceType(device_type);
|
default_device_type_ = ReferenceLazyBackendDeviceType(device_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue