[ONNX] Remove kernel shape and weight shape equivalence check from Onnx.Conv lowering (#3869)

This commit removes the equivalence check between the kernel shape and the
weight shape from the Onnx.Conv lowering, since the check appears to serve no
purpose (it is not clear why it was part of the lowering in the first place).

Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
Vivek Khandelwal 2024-11-15 10:36:41 +05:30 committed by GitHub
parent 06d17897f0
commit fe2f64919d
1 changed file with 29 additions and 35 deletions
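For context, here is a minimal standalone sketch (not the torch-mlir code itself) of the two kernel_shape validations in question: the rank check that this patch keeps, and the per-dimension equality check that it removes. The helper names (kernelShapeRankOk, kernelShapeMatchesWeight) are illustrative only. A dynamic weight dimension can never satisfy the equality comparison, and for static shapes the kernel_shape attribute is expected to describe the weight's spatial dimensions anyway, so the check adds nothing.

#include <cstdint>
#include <vector>

// Check kept by this patch: kernel_shape, when supplied, must list one value
// per spatial dimension, i.e. weight_rank - 2 entries (ONNX Conv weights are
// laid out [M, C/group, k1, ..., kn]).
bool kernelShapeRankOk(const std::vector<int64_t> &kernelShape,
                       const std::vector<int64_t> &weightShape) {
  return kernelShape.empty() ||
         kernelShape.size() == weightShape.size() - 2;
}

// Check removed by this patch: every kernel_shape[i] had to equal the
// corresponding weight spatial size. A dynamic (unknown) weight dimension can
// never pass this comparison, and for static shapes the attribute is expected
// to agree with the weight in any case.
bool kernelShapeMatchesWeight(const std::vector<int64_t> &kernelShape,
                              const std::vector<int64_t> &weightShape) {
  for (size_t i = 0; i < kernelShape.size(); ++i)
    if (weightShape[i + 2] != kernelShape[i])
      return false;
  return true;
}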


@@ -7,12 +7,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/IR/DialectResourceBlobManager.h"
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
-#include "llvm/Support/FormatVariadic.h"
 #include <numeric>
 
 using namespace mlir;
@@ -1292,6 +1290,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
       });
   patterns.onOp(
       "Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Location loc = binder.getLoc();
         Torch::ValueTensorType resultType;
         Value input, weight;
         int64_t group;
@@ -1316,14 +1315,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                 binder.op,
                 "unsupported conversion: kernel_shape list size should have "
                 "number of values equal to weight_rank - 2");
-          } else {
-            for (unsigned i = 0; i < kernelShape.size(); i++) {
-              if (weightShape[i + 2] != kernelShape[i]) {
-                return rewriter.notifyMatchFailure(
-                    binder.op, "unsupported conversion: kernel_shape value "
-                               "should be equal to the weight tensor shape");
-              }
-            }
           }
         }
@@ -1380,6 +1371,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
           padding.resize_for_overwrite(2 * spatialRank);
           for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) {
+            if (weightShape[dimIdx + 2] == Torch::kUnknownSize ||
+                inputShape[dimIdx + 2] == Torch::kUnknownSize)
+              return rewriter.notifyMatchFailure(
+                  binder.op,
+                  "expected weight and input tensor to have static shape");
             const int64_t dilatedKernelSize =
                 dilations[dimIdx] * (weightShape[dimIdx + 2] - 1) + 1;
             int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
@@ -1405,10 +1401,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         if (padding.size() != 2 * (rank - 2)) {
           for (int64_t i : padding) {
             cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+                loc, rewriter.getI64IntegerAttr(i)));
           }
           paddingList = rewriter.create<Torch::PrimListConstructOp>(
-              binder.getLoc(),
+              loc,
              Torch::ListType::get(
                  Torch::IntType::get(binder.op->getContext())),
              cstPadding);
@@ -1431,10 +1427,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           if (matchedPads) {
             for (unsigned i = 0; i < padding.size() / 2; i++) {
               cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
+                  loc, rewriter.getI64IntegerAttr(padding[i])));
             }
             paddingList = rewriter.create<Torch::PrimListConstructOp>(
-                binder.getLoc(),
+                loc,
                Torch::ListType::get(
                    Torch::IntType::get(binder.op->getContext())),
                cstPadding);
@@ -1443,40 +1439,40 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             SmallVector<Value> inputPaddingList;
             for (uint32_t i = 0; i < padding.size() / 2; i++) {
               padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(), rewriter.getI64IntegerAttr(
-                                       padding[padding.size() / 2 - i - 1])));
+                  loc, rewriter.getI64IntegerAttr(
+                           padding[padding.size() / 2 - i - 1])));
               padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(),
+                  loc,
                   rewriter.getI64IntegerAttr(padding[padding.size() - i - 1])));
               inputPaddingList.emplace_back(
                   rewriter.create<Torch::ConstantIntOp>(
-                      binder.getLoc(), rewriter.getI64IntegerAttr(0)));
+                      loc, rewriter.getI64IntegerAttr(0)));
             }
             // The conv op itself will have no padding since the actual padding
             // is performed using the torch.pad preceding it.
             paddingList = rewriter.create<Torch::PrimListConstructOp>(
-                binder.getLoc(),
+                loc,
                 Torch::ListType::get(
                     Torch::IntType::get(binder.op->getContext())),
                 inputPaddingList);
             Value padsSizeList =
                 rewriter
                     .create<Torch::PrimListConstructOp>(
-                        binder.getLoc(),
+                        loc,
                         Torch::ListType::get(
                             rewriter.getType<Torch::IntType>()),
                         padsRearrange)
                     .getResult();
             Value modeVal = rewriter.create<Torch::ConstantStrOp>(
-                binder.getLoc(), rewriter.getStringAttr("constant"));
+                loc, rewriter.getStringAttr("constant"));
             Value constantValue;
             if (isa<IntegerType>(inputTensorType.getDtype()))
               constantValue = rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(), rewriter.getI64IntegerAttr(0));
+                  loc, rewriter.getI64IntegerAttr(0));
             if (isa<FloatType>(inputTensorType.getDtype()))
               constantValue = rewriter.create<Torch::ConstantFloatOp>(
-                  binder.getLoc(), rewriter.getF64FloatAttr(0.0f));
+                  loc, rewriter.getF64FloatAttr(0.0f));
             // Pad output shape must be computed explicitly from the pad values
             SmallVector<int64_t> newInputShape(inputTensorType.getSizes());
             for (uint32_t i = 0; i < padding.size() / 2; i++) {
@@ -1486,46 +1482,44 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             auto padTy = rewriter.getType<Torch::ValueTensorType>(
                 newInputShape, inputTensorType.getDtype());
             paddedInput = rewriter.create<Torch::AtenPadOp>(
-                binder.getLoc(), padTy, input, padsSizeList, modeVal,
-                constantValue);
+                loc, padTy, input, padsSizeList, modeVal, constantValue);
           }
         }
         for (int64_t i : dilations) {
           cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+              loc, rewriter.getI64IntegerAttr(i)));
         }
         for (int64_t i : strides) {
           cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+              loc, rewriter.getI64IntegerAttr(i)));
         }
         Value cstZero = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(0));
+            loc, rewriter.getI64IntegerAttr(0));
         cstOutputPadding = {cstZero, cstZero};
         Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
+            loc,
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
             cstDilations);
         Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
+            loc,
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
             cstStrides);
         Value outputPaddingList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
+            loc,
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
             cstOutputPadding);
-        Value transposed =
-            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        Value transposed = rewriter.create<Torch::ConstantBoolOp>(loc, false);
         Value bias;
         if (binder.op->getNumOperands() == 3) {
           if (binder.tensorOperandAtIndex(bias, 2)) {
             return failure();
           }
         } else {
-          bias = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+          bias = rewriter.create<Torch::ConstantNoneOp>(loc);
         }
         Value cstGroup = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(group));
+            loc, rewriter.getI64IntegerAttr(group));
         rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
             binder.op, resultType, paddedInput, weight, bias, stridesList,
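A side note on the hunk at line 1371 above: the newly added guard requires static spatial sizes because the SAME-style auto_pad handling that follows it is plain integer arithmetic on those sizes. Below is a minimal sketch of that arithmetic for one spatial dimension, assuming the truncated expression above continues with the usual ceil-mode output-size formula; the values and variable names are illustrative, not part of the patch.

#include <cstdint>
#include <cstdio>

int main() {
  // One spatial dimension with illustrative values: input extent 224,
  // stride 2, dilation 1, kernel (weight spatial) extent 7.
  int64_t in = 224, stride = 2, dilation = 1, kernel = 7;

  // Effective kernel extent once dilation is applied.
  int64_t dilatedKernelSize = dilation * (kernel - 1) + 1; // 7

  // Total padding needed so the output extent becomes ceil(in / stride).
  int64_t totalPad =
      ((in + stride - 1) / stride - 1) * stride + dilatedKernelSize - in;
  if (totalPad < 0)
    totalPad = 0;

  // Prints dilatedKernelSize=7 totalPad=5; SAME_UPPER/SAME_LOWER then split
  // totalPad between the two sides of the dimension.
  std::printf("dilatedKernelSize=%lld totalPad=%lld\n",
              (long long)dilatedKernelSize, (long long)totalPad);
  return 0;
}

Both dilatedKernelSize and totalPad read the sizes of weightShape[dimIdx + 2] and inputShape[dimIdx + 2], which is why dynamic shapes are rejected up front with a match failure rather than producing a meaningless padding value.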