mirror of https://github.com/llvm/torch-mlir
[ONNX] Remove kernel shape and weight shape equivalence check from Onnx.Conv lowering (#3869)
This commit removes the equivalence check for kernel shape and weight shape from the Onnx.conv lowering since those checks seem to be of no use (not sure why were they part of the lowering in the first place). Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/3871/merge
parent
06d17897f0
commit
fe2f64919d
|
@ -7,12 +7,10 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/IR/DialectResourceBlobManager.h"
|
|
||||||
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
|
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
|
||||||
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
|
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
@ -1292,6 +1290,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
});
|
});
|
||||||
patterns.onOp(
|
patterns.onOp(
|
||||||
"Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
"Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Location loc = binder.getLoc();
|
||||||
Torch::ValueTensorType resultType;
|
Torch::ValueTensorType resultType;
|
||||||
Value input, weight;
|
Value input, weight;
|
||||||
int64_t group;
|
int64_t group;
|
||||||
|
@ -1316,14 +1315,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
binder.op,
|
binder.op,
|
||||||
"unsupported conversion: kernel_shape list size should have "
|
"unsupported conversion: kernel_shape list size should have "
|
||||||
"number of values equal to weight_rank - 2");
|
"number of values equal to weight_rank - 2");
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < kernelShape.size(); i++) {
|
|
||||||
if (weightShape[i + 2] != kernelShape[i]) {
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
binder.op, "unsupported conversion: kernel_shape value "
|
|
||||||
"should be equal to the weight tensor shape");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1380,6 +1371,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
|
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
|
||||||
padding.resize_for_overwrite(2 * spatialRank);
|
padding.resize_for_overwrite(2 * spatialRank);
|
||||||
for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) {
|
for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) {
|
||||||
|
if (weightShape[dimIdx + 2] == Torch::kUnknownSize ||
|
||||||
|
inputShape[dimIdx + 2] == Torch::kUnknownSize)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op,
|
||||||
|
"expected weight and input tensor to have static shape");
|
||||||
const int64_t dilatedKernelSize =
|
const int64_t dilatedKernelSize =
|
||||||
dilations[dimIdx] * (weightShape[dimIdx + 2] - 1) + 1;
|
dilations[dimIdx] * (weightShape[dimIdx + 2] - 1) + 1;
|
||||||
int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
|
int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
|
||||||
|
@ -1405,10 +1401,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
if (padding.size() != 2 * (rank - 2)) {
|
if (padding.size() != 2 * (rank - 2)) {
|
||||||
for (int64_t i : padding) {
|
for (int64_t i : padding) {
|
||||||
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
|
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
loc, rewriter.getI64IntegerAttr(i)));
|
||||||
}
|
}
|
||||||
paddingList = rewriter.create<Torch::PrimListConstructOp>(
|
paddingList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
binder.getLoc(),
|
loc,
|
||||||
Torch::ListType::get(
|
Torch::ListType::get(
|
||||||
Torch::IntType::get(binder.op->getContext())),
|
Torch::IntType::get(binder.op->getContext())),
|
||||||
cstPadding);
|
cstPadding);
|
||||||
|
@ -1431,10 +1427,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
if (matchedPads) {
|
if (matchedPads) {
|
||||||
for (unsigned i = 0; i < padding.size() / 2; i++) {
|
for (unsigned i = 0; i < padding.size() / 2; i++) {
|
||||||
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
|
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
|
loc, rewriter.getI64IntegerAttr(padding[i])));
|
||||||
}
|
}
|
||||||
paddingList = rewriter.create<Torch::PrimListConstructOp>(
|
paddingList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
binder.getLoc(),
|
loc,
|
||||||
Torch::ListType::get(
|
Torch::ListType::get(
|
||||||
Torch::IntType::get(binder.op->getContext())),
|
Torch::IntType::get(binder.op->getContext())),
|
||||||
cstPadding);
|
cstPadding);
|
||||||
|
@ -1443,40 +1439,40 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
SmallVector<Value> inputPaddingList;
|
SmallVector<Value> inputPaddingList;
|
||||||
for (uint32_t i = 0; i < padding.size() / 2; i++) {
|
for (uint32_t i = 0; i < padding.size() / 2; i++) {
|
||||||
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
|
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(
|
loc, rewriter.getI64IntegerAttr(
|
||||||
padding[padding.size() / 2 - i - 1])));
|
padding[padding.size() / 2 - i - 1])));
|
||||||
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
|
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(),
|
loc,
|
||||||
rewriter.getI64IntegerAttr(padding[padding.size() - i - 1])));
|
rewriter.getI64IntegerAttr(padding[padding.size() - i - 1])));
|
||||||
inputPaddingList.emplace_back(
|
inputPaddingList.emplace_back(
|
||||||
rewriter.create<Torch::ConstantIntOp>(
|
rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(0)));
|
loc, rewriter.getI64IntegerAttr(0)));
|
||||||
}
|
}
|
||||||
// The conv op itself will have no padding since the actual padding
|
// The conv op itself will have no padding since the actual padding
|
||||||
// is performed using the torch.pad preceding it.
|
// is performed using the torch.pad preceding it.
|
||||||
paddingList = rewriter.create<Torch::PrimListConstructOp>(
|
paddingList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
binder.getLoc(),
|
loc,
|
||||||
Torch::ListType::get(
|
Torch::ListType::get(
|
||||||
Torch::IntType::get(binder.op->getContext())),
|
Torch::IntType::get(binder.op->getContext())),
|
||||||
inputPaddingList);
|
inputPaddingList);
|
||||||
Value padsSizeList =
|
Value padsSizeList =
|
||||||
rewriter
|
rewriter
|
||||||
.create<Torch::PrimListConstructOp>(
|
.create<Torch::PrimListConstructOp>(
|
||||||
binder.getLoc(),
|
loc,
|
||||||
Torch::ListType::get(
|
Torch::ListType::get(
|
||||||
rewriter.getType<Torch::IntType>()),
|
rewriter.getType<Torch::IntType>()),
|
||||||
padsRearrange)
|
padsRearrange)
|
||||||
.getResult();
|
.getResult();
|
||||||
Value modeVal = rewriter.create<Torch::ConstantStrOp>(
|
Value modeVal = rewriter.create<Torch::ConstantStrOp>(
|
||||||
binder.getLoc(), rewriter.getStringAttr("constant"));
|
loc, rewriter.getStringAttr("constant"));
|
||||||
Value constantValue;
|
Value constantValue;
|
||||||
|
|
||||||
if (isa<IntegerType>(inputTensorType.getDtype()))
|
if (isa<IntegerType>(inputTensorType.getDtype()))
|
||||||
constantValue = rewriter.create<Torch::ConstantIntOp>(
|
constantValue = rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(0));
|
loc, rewriter.getI64IntegerAttr(0));
|
||||||
if (isa<FloatType>(inputTensorType.getDtype()))
|
if (isa<FloatType>(inputTensorType.getDtype()))
|
||||||
constantValue = rewriter.create<Torch::ConstantFloatOp>(
|
constantValue = rewriter.create<Torch::ConstantFloatOp>(
|
||||||
binder.getLoc(), rewriter.getF64FloatAttr(0.0f));
|
loc, rewriter.getF64FloatAttr(0.0f));
|
||||||
// Pad output shape must be computed explicitly from the pad values
|
// Pad output shape must be computed explicitly from the pad values
|
||||||
SmallVector<int64_t> newInputShape(inputTensorType.getSizes());
|
SmallVector<int64_t> newInputShape(inputTensorType.getSizes());
|
||||||
for (uint32_t i = 0; i < padding.size() / 2; i++) {
|
for (uint32_t i = 0; i < padding.size() / 2; i++) {
|
||||||
|
@ -1486,46 +1482,44 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
auto padTy = rewriter.getType<Torch::ValueTensorType>(
|
auto padTy = rewriter.getType<Torch::ValueTensorType>(
|
||||||
newInputShape, inputTensorType.getDtype());
|
newInputShape, inputTensorType.getDtype());
|
||||||
paddedInput = rewriter.create<Torch::AtenPadOp>(
|
paddedInput = rewriter.create<Torch::AtenPadOp>(
|
||||||
binder.getLoc(), padTy, input, padsSizeList, modeVal,
|
loc, padTy, input, padsSizeList, modeVal, constantValue);
|
||||||
constantValue);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int64_t i : dilations) {
|
for (int64_t i : dilations) {
|
||||||
cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
|
cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
loc, rewriter.getI64IntegerAttr(i)));
|
||||||
}
|
}
|
||||||
for (int64_t i : strides) {
|
for (int64_t i : strides) {
|
||||||
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
|
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
loc, rewriter.getI64IntegerAttr(i)));
|
||||||
}
|
}
|
||||||
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
|
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(0));
|
loc, rewriter.getI64IntegerAttr(0));
|
||||||
cstOutputPadding = {cstZero, cstZero};
|
cstOutputPadding = {cstZero, cstZero};
|
||||||
|
|
||||||
Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
|
Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
binder.getLoc(),
|
loc,
|
||||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||||
cstDilations);
|
cstDilations);
|
||||||
Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
|
Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
binder.getLoc(),
|
loc,
|
||||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||||
cstStrides);
|
cstStrides);
|
||||||
Value outputPaddingList = rewriter.create<Torch::PrimListConstructOp>(
|
Value outputPaddingList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
binder.getLoc(),
|
loc,
|
||||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||||
cstOutputPadding);
|
cstOutputPadding);
|
||||||
Value transposed =
|
Value transposed = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
||||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
|
||||||
Value bias;
|
Value bias;
|
||||||
if (binder.op->getNumOperands() == 3) {
|
if (binder.op->getNumOperands() == 3) {
|
||||||
if (binder.tensorOperandAtIndex(bias, 2)) {
|
if (binder.tensorOperandAtIndex(bias, 2)) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
bias = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
bias = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||||
}
|
}
|
||||||
Value cstGroup = rewriter.create<Torch::ConstantIntOp>(
|
Value cstGroup = rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(group));
|
loc, rewriter.getI64IntegerAttr(group));
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
|
rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
|
||||||
binder.op, resultType, paddedInput, weight, bias, stridesList,
|
binder.op, resultType, paddedInput, weight, bias, stridesList,
|
||||||
|
|
Loading…
Reference in New Issue