mirror of https://github.com/llvm/torch-mlir
Update TorchMlirBackendImpl Methods (#1580)
* Fix LTC build
* Remove passing test from xfail set

pull/1576/head snapshot-20221114.657

parent eec9a7e022
commit 6909eaf7fc
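Taken together, the hunks below update the LTC backend for a newer pinned PyTorch: TorchMlirBackendImpl::GetComputationDataFromNode and the CreateLoweringContext / TorchMlirLoweringContext signatures switch from mutable Node* to const Node* (and c10::ArrayRef<const Node*>), the pinned PyTorch commit hash and nightly wheel advance, and UpSampleNearest2dDynamicFactor_basic is dropped from the eager-mode xfail set now that it passes. Two short, torch-free C++ sketches of the language rules involved follow the relevant hunks.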
@@ -22,7 +22,6 @@ EAGER_MODE_XFAIL_SET = {
     "QuantizedMLP_basic",
     "Matmul_vecmat",
     "BatchMlpLayerModule_basic",
-    "UpSampleNearest2dDynamicFactor_basic",
 }

 MHLO_PASS_SET = {
@@ -106,9 +106,9 @@ BackendDataPtr TorchMlirBackendImpl::CreateDataPlaceholder(
 }

 BackendDataPtr
-TorchMlirBackendImpl::GetComputationDataFromNode(Node* node) const {
+TorchMlirBackendImpl::GetComputationDataFromNode(const Node* node) const {
   PRINT_FUNCTION();
-  auto* device_data_node = dynamic_cast<DeviceData*>(node);
+  const auto* device_data_node = dynamic_cast<const DeviceData*>(node);
   if (!device_data_node) {
     return nullptr;
   }
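This is the hunk where constness actually bites: once the incoming node is a const Node*, the dynamic_cast down to DeviceData has to target a pointer-to-const as well, and the local variable picks up const too. A minimal, torch-free sketch of that rule, using stand-in Node and DeviceData types rather than the real torch::lazy classes:

#include <iostream>

// Stand-ins for the lazy-tensor node types; only the cast rule matters here.
struct Node { virtual ~Node() = default; };
struct DeviceData : Node { int handle = 42; };

int main() {
  DeviceData dd;
  const Node* node = &dd;  // callers now pass a pointer-to-const
  // dynamic_cast<DeviceData*>(node) would not compile: it would cast away const.
  const auto* device_data_node = dynamic_cast<const DeviceData*>(node);
  std::cout << (device_data_node ? device_data_node->handle : -1) << "\n";  // prints 42
  return 0;
}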
@@ -141,11 +141,11 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(

 std::unique_ptr<LoweringContext> TorchMlirBackendImpl::CreateLoweringContext(
     const std::string& name, BackendDevice device,
-    c10::ArrayRef<Node*> post_order, Util::EmissionMap emit_status) const {
+    c10::ArrayRef<const Node*> post_order, Util::EmissionMap emit_status) const {
   PRINT_FUNCTION();
   return std::make_unique<TorchMlirLoweringContext>(
       name, std::forward<BackendDevice>(device),
-      std::forward<c10::ArrayRef<Node*>>(post_order),
+      std::forward<c10::ArrayRef<const Node*>>(post_order),
       std::forward<Util::EmissionMap>(emit_status));
 }
@@ -103,7 +103,7 @@ public:

   // Gets backend data if the node is a device data node. Otherwise returns
   // nullptr.
-  virtual BackendDataPtr GetComputationDataFromNode(Node*) const override;
+  virtual BackendDataPtr GetComputationDataFromNode(const Node*) const override;

   virtual at::Tensor MakeTensorFromComputationData(
       const BackendDataPtr data,
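The header must change in the same commit as the PyTorch pin because these methods are marked override: an override only matches when its parameter types are identical to the base-class virtual, so keeping a mutable Node* parameter after the upstream interface moved to const Node* is a hard compile error rather than a silent overload (this is the LTC build break the commit message refers to). A torch-free sketch of that rule, with stand-in types rather than the real torch::lazy interface:

#include <cstdio>

// Stand-in base interface; only the override-matching rule is illustrated.
struct Node { virtual ~Node() = default; };

struct BackendInterface {
  virtual ~BackendInterface() = default;
  // The upstream interface now takes const Node*.
  virtual const Node* GetComputationDataFromNode(const Node* node) const = 0;
};

struct BackendImpl : BackendInterface {
  // Writing 'Node* node' here while keeping 'override' would fail to compile
  // ("marked 'override' but does not override"), so the signature must track
  // the base class exactly.
  const Node* GetComputationDataFromNode(const Node* node) const override {
    return node;  // stand-in for returning the node's backend data
  }
};

int main() {
  Node plain;
  BackendImpl impl;
  std::printf("%s\n", impl.GetComputationDataFromNode(&plain) ? "non-null" : "null");
  return 0;
}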
@@ -115,7 +115,7 @@ public:

   virtual std::unique_ptr<LoweringContext> CreateLoweringContext(
       const std::string& name, BackendDevice device,
-      c10::ArrayRef<Node*> post_order,
+      c10::ArrayRef<const Node*> post_order,
       Util::EmissionMap emit_status) const override;

   virtual std::unique_ptr<LoweringContext> CreateLoweringContext(
@@ -45,10 +45,10 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(

 TorchMlirLoweringContext::TorchMlirLoweringContext(
     const std::string& name, BackendDevice device,
-    c10::ArrayRef<torch::lazy::Node*> post_order, Util::EmissionMap emit_status)
+    c10::ArrayRef<const torch::lazy::Node*> post_order, Util::EmissionMap emit_status)
     : LoweringContext(
           name, std::forward<BackendDevice>(device),
-          std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
+          std::forward<c10::ArrayRef<const torch::lazy::Node*>>(post_order),
           std::forward<Util::EmissionMap>(emit_status)),
       graph_(std::make_shared<torch::jit::Graph>()),
       function_(
@@ -43,7 +43,7 @@ public:
       const std::string& name, torch::lazy::BackendDevice device);
   TorchMlirLoweringContext(
       const std::string& name, torch::lazy::BackendDevice device,
-      c10::ArrayRef<torch::lazy::Node*> post_order,
+      c10::ArrayRef<const torch::lazy::Node*> post_order,
       torch::lazy::Util::EmissionMap emit_status);

   void Lower(const Node* node);
@@ -1 +1 @@
-3b29687f0f43615cf2f959731f1a25c9aad1eeec
+637228bcc4d2566fb617bbf1c4abeff69b3bdae7
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 --pre
-torch==1.14.0.dev20221109
+torch==1.14.0.dev20221113
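The final two hunks move the PyTorch pin forward in lockstep: the recorded upstream commit hash and the nightly wheel requirement (torch==1.14.0.dev20221113, installed with --pre from the nightly index listed above) advance together, presumably so the LTC backend builds against a torch::lazy interface that matches the const Node* signatures above, which is the "Fix LTC build" part of the commit message.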