From 061af696ce94c932152bdf64ca7eba3b4034b367 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 15 Dec 2023 11:37:49 -0800 Subject: [PATCH] [onnx] Lowering for `onnx.shape` to `torch` and `tensor` (#2648) Includes the lowering from the `aten` equivalent to `tensor` operations. --- include/torch-mlir/Conversion/Passes.td | 9 ++ .../Conversion/TorchToTensor/TorchToTensor.h | 23 +++++ lib/Conversion/CMakeLists.txt | 2 + lib/Conversion/Passes.cpp | 7 +- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 13 +++ lib/Conversion/TorchToTensor/CMakeLists.txt | 18 ++++ .../TorchToTensor/TorchToTensor.cpp | 93 +++++++++++++++++++ .../TorchToTensor/torch_to_tensor.mlir | 8 ++ 8 files changed, 170 insertions(+), 3 deletions(-) create mode 100644 include/torch-mlir/Conversion/TorchToTensor/TorchToTensor.h create mode 100644 lib/Conversion/TorchToTensor/CMakeLists.txt create mode 100644 lib/Conversion/TorchToTensor/TorchToTensor.cpp create mode 100644 test/Conversion/TorchToTensor/torch_to_tensor.mlir diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td index 3a130f472..ed58c6995 100644 --- a/include/torch-mlir/Conversion/Passes.td +++ b/include/torch-mlir/Conversion/Passes.td @@ -105,6 +105,15 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> { let constructor = "mlir::torch::createConvertTorchToLinalgPass()"; } +def ConvertTorchToTensor : Pass<"convert-torch-to-tensor", "func::FuncOp"> { + let summary = "Convert Torch ops to the Tensor dialect"; + let description = [{ + Converts any `Torch` operators that were expressible as `Tensor` dialect + operations. + }]; + let constructor = "mlir::torch::createConvertTorchToTensorPass()"; +} + def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> { let summary = "Convert Torch ops to TOSA ops"; let description = [{ diff --git a/include/torch-mlir/Conversion/TorchToTensor/TorchToTensor.h b/include/torch-mlir/Conversion/TorchToTensor/TorchToTensor.h new file mode 100644 index 000000000..9dd5a6542 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToTensor/TorchToTensor.h @@ -0,0 +1,23 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHTOTENSOR_TORCHTOTENSOR_H +#define TORCHMLIR_CONVERSION_TORCHTOTENSOR_TORCHTOTENSOR_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace torch { +std::unique_ptr> createConvertTorchToTensorPass(); +} // namespace torch +} // namespace mlir + +#endif // TORCHMLIR_CONVERSION_TORCHTOTENSOR_TORCHTOTENSOR_H diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index afbe775d3..dd9e94a50 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(TorchOnnxToTorch) add_subdirectory(TorchToLinalg) add_subdirectory(TorchToSCF) add_subdirectory(TorchToArith) +add_subdirectory(TorchToTensor) add_subdirectory(TorchToTosa) if(TORCH_MLIR_ENABLE_STABLEHLO) add_subdirectory(TorchToStablehlo) @@ -14,6 +15,7 @@ add_subdirectory(Utils) set(linked_libs TorchMLIRTorchToLinalg TorchMLIRTorchToSCF TorchMLIRTorchToArith + TorchMLIRTorchToTensor TorchMLIRTorchToTosa TorchMLIRTorchToTMTensor TorchMLIRTorchConversionToMLProgram diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 0dae24678..b9af2afa3 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -13,12 +13,13 @@ #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #endif // TORCH_MLIR_ENABLE_STABLEHLO +#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" +#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" -#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" -#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" -#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" +#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h" +#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" //===----------------------------------------------------------------------===// // Pass registration diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index a8fa8972f..482e20d6a 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -459,4 +459,17 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, operand, vAlpha, vScale, vInputScale); return success(); }); + + patterns.onOp("Shape", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); } diff --git a/lib/Conversion/TorchToTensor/CMakeLists.txt b/lib/Conversion/TorchToTensor/CMakeLists.txt new file mode 100644 index 000000000..21082d1d1 --- /dev/null +++ b/lib/Conversion/TorchToTensor/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_conversion_library(TorchMLIRTorchToTensor + TorchToTensor.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTensor + + DEPENDS + TorchMLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRTensorDialect + TorchMLIRTorchDialect + TorchMLIRConversionUtils +) + +torch_mlir_target_includes(TorchMLIRTorchToTensor) diff --git a/lib/Conversion/TorchToTensor/TorchToTensor.cpp b/lib/Conversion/TorchToTensor/TorchToTensor.cpp new file mode 100644 index 000000000..417fd17fc --- /dev/null +++ b/lib/Conversion/TorchToTensor/TorchToTensor.cpp @@ -0,0 +1,93 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v3.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-1.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h" + +#include "../PassDetail.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace { + +class ConvertAtenShapeToTensorPatternOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename Aten_ShapeAsTensorOp::Adaptor; + LogicalResult + matchAndRewrite(Aten_ShapeAsTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto operand = adaptor.getOperands()[0]; + auto operandTy = operand.getType().cast(); + auto resultTy = + getTypeConverter()->convertType(op.getType()).cast(); + + int64_t rank = operandTy.getRank(); + SmallVector dims; + for (int i = 0; i < rank; ++i) { + Value dim = rewriter.createOrFold(loc, operand, i); + dim = rewriter.createOrFold( + loc, resultTy.getElementType(), dim); + dims.push_back(dim); + } + + Value tensor = + rewriter.createOrFold(op.getLoc(), dims); + rewriter.replaceOp(op, tensor); + return success(); + } +}; + +class ConvertTorchToTensor + : public ConvertTorchToTensorBase { +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + TorchConversion::getBackendTypeConversionDependentDialects(registry); + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + target.addLegalDialect(); + target.addLegalDialect(); + target.addIllegalOp(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversion(target, typeConverter); + + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> +mlir::torch::createConvertTorchToTensorPass() { + return std::make_unique(); +} diff --git a/test/Conversion/TorchToTensor/torch_to_tensor.mlir b/test/Conversion/TorchToTensor/torch_to_tensor.mlir new file mode 100644 index 000000000..277dabc3b --- /dev/null +++ b/test/Conversion/TorchToTensor/torch_to_tensor.mlir @@ -0,0 +1,8 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-tensor | FileCheck %s + +// CHECK-LABEL: func.func @test_shape +func.func @test_shape(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3],si64> { + // CHECK: %[[SHAPE:.+]] = arith.constant dense<[3, 4, 5]> : tensor<3xi64> + %0 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3],si64> + return %0 : !torch.vtensor<[3],si64> +}