[ONNX] Remove kernel shape and weight shape equivalence check from Onnx.Conv lowering (#3869)

This commit removes the equivalence check for kernel shape and weight
shape from the Onnx.conv lowering since those checks seem to be of no
use (not sure why were they part of the lowering in the first place).

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
pull/3732/merge
Vivek Khandelwal 2024-11-15 10:36:41 +05:30 committed by GitHub
parent 06d17897f0
commit fe2f64919d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 29 additions and 35 deletions

View File

@ -7,12 +7,10 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/DialectResourceBlobManager.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/Support/FormatVariadic.h"
#include <numeric>
using namespace mlir;
@ -1292,6 +1290,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
});
patterns.onOp(
"Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Location loc = binder.getLoc();
Torch::ValueTensorType resultType;
Value input, weight;
int64_t group;
@ -1316,14 +1315,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op,
"unsupported conversion: kernel_shape list size should have "
"number of values equal to weight_rank - 2");
} else {
for (unsigned i = 0; i < kernelShape.size(); i++) {
if (weightShape[i + 2] != kernelShape[i]) {
return rewriter.notifyMatchFailure(
binder.op, "unsupported conversion: kernel_shape value "
"should be equal to the weight tensor shape");
}
}
}
}
@ -1380,6 +1371,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
padding.resize_for_overwrite(2 * spatialRank);
for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) {
if (weightShape[dimIdx + 2] == Torch::kUnknownSize ||
inputShape[dimIdx + 2] == Torch::kUnknownSize)
return rewriter.notifyMatchFailure(
binder.op,
"expected weight and input tensor to have static shape");
const int64_t dilatedKernelSize =
dilations[dimIdx] * (weightShape[dimIdx + 2] - 1) + 1;
int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
@ -1405,10 +1401,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
if (padding.size() != 2 * (rank - 2)) {
for (int64_t i : padding) {
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
loc, rewriter.getI64IntegerAttr(i)));
}
paddingList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
loc,
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
cstPadding);
@ -1431,10 +1427,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
if (matchedPads) {
for (unsigned i = 0; i < padding.size() / 2; i++) {
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
loc, rewriter.getI64IntegerAttr(padding[i])));
}
paddingList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
loc,
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
cstPadding);
@ -1443,40 +1439,40 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
SmallVector<Value> inputPaddingList;
for (uint32_t i = 0; i < padding.size() / 2; i++) {
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(
padding[padding.size() / 2 - i - 1])));
loc, rewriter.getI64IntegerAttr(
padding[padding.size() / 2 - i - 1])));
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
loc,
rewriter.getI64IntegerAttr(padding[padding.size() - i - 1])));
inputPaddingList.emplace_back(
rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0)));
loc, rewriter.getI64IntegerAttr(0)));
}
// The conv op itself will have no padding since the actual padding
// is performed using the torch.pad preceding it.
paddingList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
loc,
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
inputPaddingList);
Value padsSizeList =
rewriter
.create<Torch::PrimListConstructOp>(
binder.getLoc(),
loc,
Torch::ListType::get(
rewriter.getType<Torch::IntType>()),
padsRearrange)
.getResult();
Value modeVal = rewriter.create<Torch::ConstantStrOp>(
binder.getLoc(), rewriter.getStringAttr("constant"));
loc, rewriter.getStringAttr("constant"));
Value constantValue;
if (isa<IntegerType>(inputTensorType.getDtype()))
constantValue = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
loc, rewriter.getI64IntegerAttr(0));
if (isa<FloatType>(inputTensorType.getDtype()))
constantValue = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(0.0f));
loc, rewriter.getF64FloatAttr(0.0f));
// Pad output shape must be computed explicitly from the pad values
SmallVector<int64_t> newInputShape(inputTensorType.getSizes());
for (uint32_t i = 0; i < padding.size() / 2; i++) {
@ -1486,46 +1482,44 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
auto padTy = rewriter.getType<Torch::ValueTensorType>(
newInputShape, inputTensorType.getDtype());
paddedInput = rewriter.create<Torch::AtenPadOp>(
binder.getLoc(), padTy, input, padsSizeList, modeVal,
constantValue);
loc, padTy, input, padsSizeList, modeVal, constantValue);
}
}
for (int64_t i : dilations) {
cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
loc, rewriter.getI64IntegerAttr(i)));
}
for (int64_t i : strides) {
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
loc, rewriter.getI64IntegerAttr(i)));
}
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
loc, rewriter.getI64IntegerAttr(0));
cstOutputPadding = {cstZero, cstZero};
Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
loc,
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstDilations);
Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
loc,
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstStrides);
Value outputPaddingList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
loc,
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstOutputPadding);
Value transposed =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Value transposed = rewriter.create<Torch::ConstantBoolOp>(loc, false);
Value bias;
if (binder.op->getNumOperands() == 3) {
if (binder.tensorOperandAtIndex(bias, 2)) {
return failure();
}
} else {
bias = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
bias = rewriter.create<Torch::ConstantNoneOp>(loc);
}
Value cstGroup = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(group));
loc, rewriter.getI64IntegerAttr(group));
rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
binder.op, resultType, paddedInput, weight, bias, stridesList,