Update TorchMlirBackendImpl Methods (#1580)

* Fix LTC build

* Remove passing test from xfail set
pull/1576/head snapshot-20221114.657
Gleb Kazantaev 2022-11-14 00:37:49 -05:00 committed by GitHub
parent eec9a7e022
commit 6909eaf7fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 11 additions and 12 deletions

View File

@ -22,7 +22,6 @@ EAGER_MODE_XFAIL_SET = {
"QuantizedMLP_basic", "QuantizedMLP_basic",
"Matmul_vecmat", "Matmul_vecmat",
"BatchMlpLayerModule_basic", "BatchMlpLayerModule_basic",
"UpSampleNearest2dDynamicFactor_basic",
} }
MHLO_PASS_SET = { MHLO_PASS_SET = {

View File

@ -106,9 +106,9 @@ BackendDataPtr TorchMlirBackendImpl::CreateDataPlaceholder(
} }
BackendDataPtr BackendDataPtr
TorchMlirBackendImpl::GetComputationDataFromNode(Node* node) const { TorchMlirBackendImpl::GetComputationDataFromNode(const Node* node) const {
PRINT_FUNCTION(); PRINT_FUNCTION();
auto* device_data_node = dynamic_cast<DeviceData*>(node); const auto* device_data_node = dynamic_cast<const DeviceData*>(node);
if (!device_data_node) { if (!device_data_node) {
return nullptr; return nullptr;
} }
@ -141,11 +141,11 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
std::unique_ptr<LoweringContext> TorchMlirBackendImpl::CreateLoweringContext( std::unique_ptr<LoweringContext> TorchMlirBackendImpl::CreateLoweringContext(
const std::string& name, BackendDevice device, const std::string& name, BackendDevice device,
c10::ArrayRef<Node*> post_order, Util::EmissionMap emit_status) const { c10::ArrayRef<const Node*> post_order, Util::EmissionMap emit_status) const {
PRINT_FUNCTION(); PRINT_FUNCTION();
return std::make_unique<TorchMlirLoweringContext>( return std::make_unique<TorchMlirLoweringContext>(
name, std::forward<BackendDevice>(device), name, std::forward<BackendDevice>(device),
std::forward<c10::ArrayRef<Node*>>(post_order), std::forward<c10::ArrayRef<const Node*>>(post_order),
std::forward<Util::EmissionMap>(emit_status)); std::forward<Util::EmissionMap>(emit_status));
} }

View File

@ -103,7 +103,7 @@ public:
// Gets backend data if the node is a device data node. Otherwise returns // Gets backend data if the node is a device data node. Otherwise returns
// nullptr. // nullptr.
virtual BackendDataPtr GetComputationDataFromNode(Node*) const override; virtual BackendDataPtr GetComputationDataFromNode(const Node*) const override;
virtual at::Tensor MakeTensorFromComputationData( virtual at::Tensor MakeTensorFromComputationData(
const BackendDataPtr data, const BackendDataPtr data,
@ -115,7 +115,7 @@ public:
virtual std::unique_ptr<LoweringContext> CreateLoweringContext( virtual std::unique_ptr<LoweringContext> CreateLoweringContext(
const std::string& name, BackendDevice device, const std::string& name, BackendDevice device,
c10::ArrayRef<Node*> post_order, c10::ArrayRef<const Node*> post_order,
Util::EmissionMap emit_status) const override; Util::EmissionMap emit_status) const override;
virtual std::unique_ptr<LoweringContext> CreateLoweringContext( virtual std::unique_ptr<LoweringContext> CreateLoweringContext(

View File

@ -45,10 +45,10 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(
TorchMlirLoweringContext::TorchMlirLoweringContext( TorchMlirLoweringContext::TorchMlirLoweringContext(
const std::string& name, BackendDevice device, const std::string& name, BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order, Util::EmissionMap emit_status) c10::ArrayRef<const torch::lazy::Node*> post_order, Util::EmissionMap emit_status)
: LoweringContext( : LoweringContext(
name, std::forward<BackendDevice>(device), name, std::forward<BackendDevice>(device),
std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order), std::forward<c10::ArrayRef<const torch::lazy::Node*>>(post_order),
std::forward<Util::EmissionMap>(emit_status)), std::forward<Util::EmissionMap>(emit_status)),
graph_(std::make_shared<torch::jit::Graph>()), graph_(std::make_shared<torch::jit::Graph>()),
function_( function_(

View File

@ -43,7 +43,7 @@ public:
const std::string& name, torch::lazy::BackendDevice device); const std::string& name, torch::lazy::BackendDevice device);
TorchMlirLoweringContext( TorchMlirLoweringContext(
const std::string& name, torch::lazy::BackendDevice device, const std::string& name, torch::lazy::BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order, c10::ArrayRef<const torch::lazy::Node*> post_order,
torch::lazy::Util::EmissionMap emit_status); torch::lazy::Util::EmissionMap emit_status);
void Lower(const Node* node); void Lower(const Node* node);

View File

@ -1 +1 @@
3b29687f0f43615cf2f959731f1a25c9aad1eeec 637228bcc4d2566fb617bbf1c4abeff69b3bdae7

View File

@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre --pre
torch==1.14.0.dev20221109 torch==1.14.0.dev20221113