[onnx] Lowering for `onnx.shape` to `torch` and `tensor` (#2648)

Includes the lowering from the `aten` equivalent to `tensor` operations.
pull/2658/head
Rob Suderman 2023-12-15 11:37:49 -08:00 committed by GitHub
parent 55e9401c5c
commit 061af696ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 170 additions and 3 deletions

View File

@ -105,6 +105,15 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToLinalgPass()"; let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
} }
def ConvertTorchToTensor : Pass<"convert-torch-to-tensor", "func::FuncOp"> {
let summary = "Convert Torch ops to the Tensor dialect";
let description = [{
Converts any `Torch` operators that were expressible as `Tensor` dialect
operations.
}];
let constructor = "mlir::torch::createConvertTorchToTensorPass()";
}
def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> { def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
let summary = "Convert Torch ops to TOSA ops"; let summary = "Convert Torch ops to TOSA ops";
let description = [{ let description = [{

View File

@ -0,0 +1,23 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOTENSOR_TORCHTOTENSOR_H
#define TORCHMLIR_CONVERSION_TORCHTOTENSOR_TORCHTOTENSOR_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTensorPass();
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOTENSOR_TORCHTOTENSOR_H

View File

@ -2,6 +2,7 @@ add_subdirectory(TorchOnnxToTorch)
add_subdirectory(TorchToLinalg) add_subdirectory(TorchToLinalg)
add_subdirectory(TorchToSCF) add_subdirectory(TorchToSCF)
add_subdirectory(TorchToArith) add_subdirectory(TorchToArith)
add_subdirectory(TorchToTensor)
add_subdirectory(TorchToTosa) add_subdirectory(TorchToTosa)
if(TORCH_MLIR_ENABLE_STABLEHLO) if(TORCH_MLIR_ENABLE_STABLEHLO)
add_subdirectory(TorchToStablehlo) add_subdirectory(TorchToStablehlo)
@ -14,6 +15,7 @@ add_subdirectory(Utils)
set(linked_libs TorchMLIRTorchToLinalg set(linked_libs TorchMLIRTorchToLinalg
TorchMLIRTorchToSCF TorchMLIRTorchToSCF
TorchMLIRTorchToArith TorchMLIRTorchToArith
TorchMLIRTorchToTensor
TorchMLIRTorchToTosa TorchMLIRTorchToTosa
TorchMLIRTorchToTMTensor TorchMLIRTorchToTMTensor
TorchMLIRTorchConversionToMLProgram TorchMLIRTorchConversionToMLProgram

View File

@ -13,12 +13,13 @@
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#endif // TORCH_MLIR_ENABLE_STABLEHLO #endif // TORCH_MLIR_ENABLE_STABLEHLO
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" #include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Pass registration // Pass registration

View File

@ -459,4 +459,17 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.op, resultType, operand, vAlpha, vScale, vInputScale); binder.op, resultType, operand, vAlpha, vScale, vInputScale);
return success(); return success();
}); });
patterns.onOp("Shape", 9,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::Aten_ShapeAsTensorOp>(
binder.op, resultType, operand);
return success();
});
} }

View File

@ -0,0 +1,18 @@
add_mlir_conversion_library(TorchMLIRTorchToTensor
TorchToTensor.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTensor
DEPENDS
TorchMLIRConversionPassIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRTensorDialect
TorchMLIRTorchDialect
TorchMLIRConversionUtils
)
torch_mlir_target_includes(TorchMLIRTorchToTensor)

View File

@ -0,0 +1,93 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v3.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-1.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
class ConvertAtenShapeToTensorPatternOp
: public OpConversionPattern<Aten_ShapeAsTensorOp> {
public:
using OpConversionPattern<Aten_ShapeAsTensorOp>::OpConversionPattern;
using OpAdaptor = typename Aten_ShapeAsTensorOp::Adaptor;
LogicalResult
matchAndRewrite(Aten_ShapeAsTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto operand = adaptor.getOperands()[0];
auto operandTy = operand.getType().cast<RankedTensorType>();
auto resultTy =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
int64_t rank = operandTy.getRank();
SmallVector<Value> dims;
for (int i = 0; i < rank; ++i) {
Value dim = rewriter.createOrFold<tensor::DimOp>(loc, operand, i);
dim = rewriter.createOrFold<arith::IndexCastOp>(
loc, resultTy.getElementType(), dim);
dims.push_back(dim);
}
Value tensor =
rewriter.createOrFold<tensor::FromElementsOp>(op.getLoc(), dims);
rewriter.replaceOp(op, tensor);
return success();
}
};
class ConvertTorchToTensor
: public ConvertTorchToTensorBase<ConvertTorchToTensor> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tensor::TensorDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<tensor::TensorDialect>();
target.addIllegalOp<Torch::Aten_ShapeAsTensorOp>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);
RewritePatternSet patterns(context);
patterns.add<ConvertAtenShapeToTensorPatternOp>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToTensorPass() {
return std::make_unique<ConvertTorchToTensor>();
}

View File

@ -0,0 +1,8 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-tensor | FileCheck %s
// CHECK-LABEL: func.func @test_shape
func.func @test_shape(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3],si64> {
// CHECK: %[[SHAPE:.+]] = arith.constant dense<[3, 4, 5]> : tensor<3xi64>
%0 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3],si64>
return %0 : !torch.vtensor<[3],si64>
}