Update TorchMlirBackendImpl Methods (#1580)

* Fix LTC build

* Remove passing test from xfail set
pull/1576/head snapshot-20221114.657
Gleb Kazantaev 2022-11-14 00:37:49 -05:00 committed by GitHub
parent eec9a7e022
commit 6909eaf7fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 11 additions and 12 deletions

View File

@ -22,7 +22,6 @@ EAGER_MODE_XFAIL_SET = {
"QuantizedMLP_basic",
"Matmul_vecmat",
"BatchMlpLayerModule_basic",
"UpSampleNearest2dDynamicFactor_basic",
}
MHLO_PASS_SET = {

View File

@ -106,9 +106,9 @@ BackendDataPtr TorchMlirBackendImpl::CreateDataPlaceholder(
}
BackendDataPtr
TorchMlirBackendImpl::GetComputationDataFromNode(Node* node) const {
TorchMlirBackendImpl::GetComputationDataFromNode(const Node* node) const {
PRINT_FUNCTION();
auto* device_data_node = dynamic_cast<DeviceData*>(node);
const auto* device_data_node = dynamic_cast<const DeviceData*>(node);
if (!device_data_node) {
return nullptr;
}
@ -141,11 +141,11 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
std::unique_ptr<LoweringContext> TorchMlirBackendImpl::CreateLoweringContext(
const std::string& name, BackendDevice device,
c10::ArrayRef<Node*> post_order, Util::EmissionMap emit_status) const {
c10::ArrayRef<const Node*> post_order, Util::EmissionMap emit_status) const {
PRINT_FUNCTION();
return std::make_unique<TorchMlirLoweringContext>(
name, std::forward<BackendDevice>(device),
std::forward<c10::ArrayRef<Node*>>(post_order),
std::forward<c10::ArrayRef<const Node*>>(post_order),
std::forward<Util::EmissionMap>(emit_status));
}

View File

@ -103,7 +103,7 @@ public:
// Gets backend data if the node is a device data node. Otherwise returns
// nullptr.
virtual BackendDataPtr GetComputationDataFromNode(Node*) const override;
virtual BackendDataPtr GetComputationDataFromNode(const Node*) const override;
virtual at::Tensor MakeTensorFromComputationData(
const BackendDataPtr data,
@ -115,7 +115,7 @@ public:
virtual std::unique_ptr<LoweringContext> CreateLoweringContext(
const std::string& name, BackendDevice device,
c10::ArrayRef<Node*> post_order,
c10::ArrayRef<const Node*> post_order,
Util::EmissionMap emit_status) const override;
virtual std::unique_ptr<LoweringContext> CreateLoweringContext(

View File

@ -45,10 +45,10 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(
TorchMlirLoweringContext::TorchMlirLoweringContext(
const std::string& name, BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order, Util::EmissionMap emit_status)
c10::ArrayRef<const torch::lazy::Node*> post_order, Util::EmissionMap emit_status)
: LoweringContext(
name, std::forward<BackendDevice>(device),
std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
std::forward<c10::ArrayRef<const torch::lazy::Node*>>(post_order),
std::forward<Util::EmissionMap>(emit_status)),
graph_(std::make_shared<torch::jit::Graph>()),
function_(

View File

@ -43,7 +43,7 @@ public:
const std::string& name, torch::lazy::BackendDevice device);
TorchMlirLoweringContext(
const std::string& name, torch::lazy::BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order,
c10::ArrayRef<const torch::lazy::Node*> post_order,
torch::lazy::Util::EmissionMap emit_status);
void Lower(const Node* node);

View File

@ -1 +1 @@
3b29687f0f43615cf2f959731f1a25c9aad1eeec
637228bcc4d2566fb617bbf1c4abeff69b3bdae7

View File

@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
torch==1.14.0.dev20221109
torch==1.14.0.dev20221113