mirror of https://github.com/llvm/torch-mlir
[ONNX] Remove kernel shape and weight shape equivalence check from Onnx.Conv lowering (#3869)
This commit removes the equivalence check for kernel shape and weight shape from the Onnx.conv lowering since those checks seem to be of no use (not sure why were they part of the lowering in the first place). Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/3871/merge
parent
06d17897f0
commit
fe2f64919d
|
@ -7,12 +7,10 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/DialectResourceBlobManager.h"
|
||||
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
|
||||
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -1292,6 +1290,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
});
|
||||
patterns.onOp(
|
||||
"Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Location loc = binder.getLoc();
|
||||
Torch::ValueTensorType resultType;
|
||||
Value input, weight;
|
||||
int64_t group;
|
||||
|
@ -1316,14 +1315,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
binder.op,
|
||||
"unsupported conversion: kernel_shape list size should have "
|
||||
"number of values equal to weight_rank - 2");
|
||||
} else {
|
||||
for (unsigned i = 0; i < kernelShape.size(); i++) {
|
||||
if (weightShape[i + 2] != kernelShape[i]) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unsupported conversion: kernel_shape value "
|
||||
"should be equal to the weight tensor shape");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1380,6 +1371,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
|
||||
padding.resize_for_overwrite(2 * spatialRank);
|
||||
for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) {
|
||||
if (weightShape[dimIdx + 2] == Torch::kUnknownSize ||
|
||||
inputShape[dimIdx + 2] == Torch::kUnknownSize)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"expected weight and input tensor to have static shape");
|
||||
const int64_t dilatedKernelSize =
|
||||
dilations[dimIdx] * (weightShape[dimIdx + 2] - 1) + 1;
|
||||
int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
|
||||
|
@ -1405,10 +1401,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
if (padding.size() != 2 * (rank - 2)) {
|
||||
for (int64_t i : padding) {
|
||||
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||
loc, rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
paddingList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
loc,
|
||||
Torch::ListType::get(
|
||||
Torch::IntType::get(binder.op->getContext())),
|
||||
cstPadding);
|
||||
|
@ -1431,10 +1427,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
if (matchedPads) {
|
||||
for (unsigned i = 0; i < padding.size() / 2; i++) {
|
||||
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
|
||||
loc, rewriter.getI64IntegerAttr(padding[i])));
|
||||
}
|
||||
paddingList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
loc,
|
||||
Torch::ListType::get(
|
||||
Torch::IntType::get(binder.op->getContext())),
|
||||
cstPadding);
|
||||
|
@ -1443,40 +1439,40 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
SmallVector<Value> inputPaddingList;
|
||||
for (uint32_t i = 0; i < padding.size() / 2; i++) {
|
||||
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(
|
||||
loc, rewriter.getI64IntegerAttr(
|
||||
padding[padding.size() / 2 - i - 1])));
|
||||
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(),
|
||||
loc,
|
||||
rewriter.getI64IntegerAttr(padding[padding.size() - i - 1])));
|
||||
inputPaddingList.emplace_back(
|
||||
rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(0)));
|
||||
loc, rewriter.getI64IntegerAttr(0)));
|
||||
}
|
||||
// The conv op itself will have no padding since the actual padding
|
||||
// is performed using the torch.pad preceding it.
|
||||
paddingList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
loc,
|
||||
Torch::ListType::get(
|
||||
Torch::IntType::get(binder.op->getContext())),
|
||||
inputPaddingList);
|
||||
Value padsSizeList =
|
||||
rewriter
|
||||
.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
loc,
|
||||
Torch::ListType::get(
|
||||
rewriter.getType<Torch::IntType>()),
|
||||
padsRearrange)
|
||||
.getResult();
|
||||
Value modeVal = rewriter.create<Torch::ConstantStrOp>(
|
||||
binder.getLoc(), rewriter.getStringAttr("constant"));
|
||||
loc, rewriter.getStringAttr("constant"));
|
||||
Value constantValue;
|
||||
|
||||
if (isa<IntegerType>(inputTensorType.getDtype()))
|
||||
constantValue = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(0));
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
if (isa<FloatType>(inputTensorType.getDtype()))
|
||||
constantValue = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(), rewriter.getF64FloatAttr(0.0f));
|
||||
loc, rewriter.getF64FloatAttr(0.0f));
|
||||
// Pad output shape must be computed explicitly from the pad values
|
||||
SmallVector<int64_t> newInputShape(inputTensorType.getSizes());
|
||||
for (uint32_t i = 0; i < padding.size() / 2; i++) {
|
||||
|
@ -1486,46 +1482,44 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
auto padTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
newInputShape, inputTensorType.getDtype());
|
||||
paddedInput = rewriter.create<Torch::AtenPadOp>(
|
||||
binder.getLoc(), padTy, input, padsSizeList, modeVal,
|
||||
constantValue);
|
||||
loc, padTy, input, padsSizeList, modeVal, constantValue);
|
||||
}
|
||||
}
|
||||
for (int64_t i : dilations) {
|
||||
cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||
loc, rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
for (int64_t i : strides) {
|
||||
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||
loc, rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(0));
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
cstOutputPadding = {cstZero, cstZero};
|
||||
|
||||
Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
loc,
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
cstDilations);
|
||||
Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
loc,
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
cstStrides);
|
||||
Value outputPaddingList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
loc,
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
cstOutputPadding);
|
||||
Value transposed =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
||||
Value transposed = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
||||
Value bias;
|
||||
if (binder.op->getNumOperands() == 3) {
|
||||
if (binder.tensorOperandAtIndex(bias, 2)) {
|
||||
return failure();
|
||||
}
|
||||
} else {
|
||||
bias = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
bias = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
}
|
||||
Value cstGroup = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(group));
|
||||
loc, rewriter.getI64IntegerAttr(group));
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
|
||||
binder.op, resultType, paddedInput, weight, bias, stridesList,
|
||||
|
|
Loading…
Reference in New Issue