mirror of https://github.com/llvm/torch-mlir
[ONNX] Basic Support for DeformConv (#3469)
This adds a torchvision op to torch-mlir and a path from onnx.DeformConv to torchvision.deform_conv2d. I'm not implementing the torch->linalg lowering for the torchvision op yet, but am posting this PR to get feedback on some of the choices being made here and to flesh out the ONNX frontend a bit.
parent e346c911f7
commit 368fabf0c1
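For orientation, a rough sketch of the op this conversion targets, as it would appear after convert-torch-onnx-to-torch. The operand order follows the generated ODS definition below; the SSA value names are made up and the tensor shapes are illustrative only, mirroring the lit test at the end of this diff:

    %out = torch.torchvision.deform_conv2d %input, %weight, %offset, %mask, %bias,
           %stride_h, %stride_w, %pad_h, %pad_w, %dilation_h, %dilation_w,
           %group, %offset_group, %use_mask
         : !torch.vtensor<[1,1,7,6],f32>, !torch.vtensor<[1,1,2,2],f32>,
           !torch.vtensor<[1,8,6,5],f32>, !torch.vtensor<[1,4,6,5],f32>,
           !torch.vtensor<[1],f32>, !torch.int, !torch.int, !torch.int, !torch.int,
           !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool
           -> !torch.vtensor<[1,1,6,5],f32>

When the ONNX op carries no mask (or no bias), the conversion materializes an all-ones mask (or all-zeros bias) of the appropriate shape before building this op, as the pattern below shows.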
@@ -16660,6 +16660,42 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
   }];
 }
 
+def Torch_TorchvisionDeformConv2dOp : Torch_Op<"torchvision.deform_conv2d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `torchvision::deform_conv2d : (Tensor, Tensor, Tensor, Tensor, Tensor, int, int, int, int, int, int, int, int, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchTensorType:$offset,
+    AnyTorchTensorType:$mask,
+    AnyTorchTensorType:$bias,
+    Torch_IntType:$stride_h,
+    Torch_IntType:$stride_w,
+    Torch_IntType:$pad_h,
+    Torch_IntType:$pad_w,
+    Torch_IntType:$dilation_h,
+    Torch_IntType:$dilation_w,
+    Torch_IntType:$groups,
+    Torch_IntType:$offset_groups,
+    Torch_BoolType:$use_mask
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult TorchvisionDeformConv2dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 14, 1);
+    }
+    void TorchvisionDeformConv2dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 14, 1);
+    }
+  }];
+}
+
 def Torch_TorchvisionRoiAlignOp : Torch_Op<"torchvision.roi_align", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -1837,6 +1837,141 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.op, resultType, transposedInput, reshapeSizesList);
         return success();
       });
+  patterns.onOp(
+      "DeformConv", 19,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        auto loc = binder.getLoc();
+
+        // get operands
+        llvm::SmallVector<Value> operands;
+        Torch::ValueTensorType resultType;
+        if (binder.tensorOperandsList(operands) ||
+            binder.tensorResultType(resultType))
+          return failure();
+        if (operands.size() < 3 || operands.size() > 5)
+          return failure();
+        auto inputType =
+            dyn_cast<Torch::ValueTensorType>(operands[0].getType());
+        if (!inputType || !inputType.hasSizes() ||
+            inputType.getSizes().size() != 4)
+          return rewriter.notifyMatchFailure(
+              binder.op, "Unsupported: DeformConv with input rank != 4");
+        unsigned rank = inputType.getSizes().size();
+        auto weightType =
+            dyn_cast<Torch::ValueTensorType>(operands[1].getType());
+        if (!weightType || !weightType.hasSizes())
+          return failure();
+        auto offsetType =
+            dyn_cast<Torch::ValueTensorType>(operands[2].getType());
+        if (!offsetType || !offsetType.hasSizes())
+          return failure();
+
+        // get attributes
+        SmallVector<int64_t> dilations, kernelShape, pads, strides;
+        SmallVector<int64_t> defaultDilations(rank - 2, 0);
+        SmallVector<int64_t> defaultPads(2 * (rank - 2), 0);
+        SmallVector<int64_t> defaultStrides(rank - 2, 1);
+        int64_t group, offsetGroup;
+        if (binder.s64IntegerArrayAttr(dilations, "dilations",
+                                       defaultDilations) ||
+            binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {}) ||
+            binder.s64IntegerArrayAttr(pads, "pads", defaultPads) ||
+            binder.s64IntegerArrayAttr(strides, "strides", defaultStrides) ||
+            binder.s64IntegerAttr(group, "group", 1) ||
+            binder.s64IntegerAttr(offsetGroup, "offset_group", 1))
+          return failure();
+
+        for (unsigned i = 0; i < rank - 2; i++) {
+          if (pads[i] != pads[rank + i - 2])
+            return rewriter.notifyMatchFailure(
+                binder.op, "unsupported: asymmetric padding");
+        }
+
+        // Identify and assign names to operands
+        Value input, weight, offset, bias, mask;
+        bool useMask = false;
+        input = operands[0];
+        weight = operands[1];
+        offset = operands[2];
+        if (operands.size() == 4) {
+          auto unknownOpdRank = Torch::getTensorRank(operands[3]);
+          if (!unknownOpdRank)
+            return failure();
+          if (*unknownOpdRank == 1)
+            bias = operands[3];
+          else if (*unknownOpdRank == rank) {
+            mask = operands[3];
+            useMask = true;
+          } else
+            llvm_unreachable("onnx.DeformConv: optional 4th operand of "
+                             "unexpected rank encountered");
+        }
+        if (operands.size() == 5) {
+          bias = operands[3];
+          mask = operands[4];
+          useMask = true;
+        }
+
+        // assign default operand values if necessary
+        ArrayRef<int64_t> weightSizes = weightType.getSizes();
+        ArrayRef<int64_t> offsetSizes = offsetType.getSizes();
+        if (!bias) {
+          int64_t outputChannels = weightSizes[0];
+          SmallVector<int64_t> biasShape(1, outputChannels);
+          Value biasShapeList = mlir::torch::onnx_c::createConstantIntList(
+              binder, rewriter, biasShape);
+          Value cstZero = Torch::getConstantWithGivenDtypeAndValue(
+              rewriter, loc, 0.0f, inputType.getDtype());
+          bias =
+              Torch::createInitTensor(rewriter, loc,
+                                      rewriter.getType<Torch::ValueTensorType>(
+                                          biasShape, inputType.getDtype()),
+                                      cstZero, biasShapeList);
+        }
+        if (!mask) {
+          int64_t batchSize = inputType.getSizes()[0];
+          int64_t kernelHeight = weightSizes[2];
+          int64_t kernelWidth = weightSizes[3];
+          int64_t outputHeight = offsetSizes[2];
+          int64_t outputWidth = offsetSizes[3];
+          int64_t maskDimOne = offsetGroup * kernelHeight * kernelWidth;
+          SmallVector<int64_t> maskShape(
+              {batchSize, maskDimOne, outputHeight, outputWidth});
+          Value cstOne = Torch::getConstantWithGivenDtypeAndValue(
+              rewriter, loc, 1.0f, inputType.getDtype());
+          Value maskShapeList = mlir::torch::onnx_c::createConstantIntList(
+              binder, rewriter, maskShape);
+          mask =
+              Torch::createInitTensor(rewriter, loc,
+                                      rewriter.getType<Torch::ValueTensorType>(
+                                          maskShape, inputType.getDtype()),
+                                      cstOne, maskShapeList);
+        }
+
+        // get attributes as constant values
+        SmallVector<Value> dilationValues, padValues, strideValues;
+        for (auto i : dilations)
+          dilationValues.push_back(rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(i)));
+        for (auto i : pads)
+          padValues.push_back(rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(i)));
+        for (auto i : strides)
+          strideValues.push_back(rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(i)));
+        Value groupValue = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(group));
+        Value offsetGroupValue = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(offsetGroup));
+        Value useMaskValue = rewriter.create<Torch::ConstantBoolOp>(
+            loc, rewriter.getBoolAttr(useMask));
+        rewriter.replaceOpWithNewOp<Torch::TorchvisionDeformConv2dOp>(
+            binder.op, resultType, input, weight, offset, mask, bias,
+            strideValues[0], strideValues[1], padValues[0], padValues[1],
+            dilationValues[0], dilationValues[1], groupValue, offsetGroupValue,
+            useMaskValue);
+        return success();
+      });
   patterns.onOp(
       "DequantizeLinear", 1,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
@@ -9492,6 +9492,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = torch.prim.ListConstruct %int9, %int10 : (!torch.int, !torch.int) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.torchvision.deform_conv2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.int, %arg9: !torch.int, %arg10: !torch.int, %arg11: !torch.int, %arg12: !torch.int, %arg13: !torch.bool) -> !torch.list<int> {\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int2 = torch.constant.int 2\n"
+"    %int3 = torch.constant.int 3\n"
+"    %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %1 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %2 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %3 = torch.aten.__getitem__.t %arg2, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"    return %4 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.torchvision.deform_conv2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>, %arg4: !torch.tuple<int, int>, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.int, %arg9: !torch.int, %arg10: !torch.int, %arg11: !torch.int, %arg12: !torch.int, %arg13: !torch.bool) -> !torch.int {\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
+"    return %0 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.conv2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -29,6 +29,9 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
     "InterpolateDynamicModule_scales_recompute_bilinear",
     "ElementwiseFloatTensorGtIntTensorModule_basic",
     "AtenIntMM_basic",
+    # unimplemented lowering torch -> linalg for torchvision.deform_conv2d
+    # this is added to check the torch.onnx.export -> import_onnx -> torch path
+    "DeformConv2D_basic",
 }
 
 LINALG_CRASHING_SET = {
@@ -383,6 +386,7 @@ FX_IMPORTER_XFAIL_SET = {
     "ConvolutionBackwardModule2DStrided_basic",
     "ConvolutionBackwardModule2D_basic",
     "CumsumModule_basic",
+    "DeformConv2D_basic",
     "DivFloatModule_basic",
     "DivIntModule_basic",
     "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
@@ -554,6 +558,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "ConvolutionBackwardModule2DStrided_basic",
     "ConvolutionBackwardModule2D_basic",
     "CumsumModule_basic",
+    "DeformConv2D_basic",
    "DiagonalModule_basic",
     "DiagonalModule_nonsquare",
     "DiagonalModule_transposed",
@@ -2357,19 +2362,12 @@ ONNX_XFAIL_SET = {
     "DivIntModule_basic",
     "ElementwiseAcoshIntModule_basic",
     "ElementwiseAcoshModule_basic",
-    "ElementwiseAndScalarModule_basic",
-    "ElementwiseAndScalarStaticShapeModule_basic",
     "ElementwiseAsinhIntModule_basic",
     "ElementwiseAsinhModule_basic",
     "ElementwiseAtanhIntModule_basic",
     "ElementwiseAtanhModule_basic",
     "ElementwiseAtenIsneginfOpModule_basic",
     "ElementwiseAtenIsposinfOpModule_basic",
-    "ElementwiseBitwiseAndModule_basic",
-    "ElementwiseBitwiseAndScalarInt32Module_basic",
-    "ElementwiseBitwiseAndScalarInt64Module_basic",
-    "ElementwiseBitwiseAndScalarInt8Module_basic",
-    "ElementwiseBitwiseAndStaticShapeModule_basic",
     "ElementwiseBitwiseNotInt32Module_basic",
     "ElementwiseBitwiseNotInt64Module_basic",
     "ElementwiseBitwiseOrModule_basic",
@@ -2710,6 +2708,8 @@ ONNX_XFAIL_SET = {
     "IndexPutHackedTwin3DIntNonAccumulateModule_basic",
     # RuntimeError: unsupported input type: Device
     "PrimsIotaModule_basic",
+    # unimplemented torchvision.deform_conv2d torch->linalg
+    "DeformConv2D_basic",
     # Error: 'aten::renorm' to ONNX opset version 17 is not supported.
     "RenormModuleFloat16_basic",
     "RenormModuleFloat32NegativeDim_basic",
@@ -2759,6 +2759,14 @@ if torch_version_for_comparison() < version.parse("2.4.0.dev"):
     "ElementwiseBitwiseLeftShiftInt32Module_basic",
     "ElementwiseBitwiseLeftShiftInt64Module_basic",
     "ElementwiseBitwiseLeftShiftInt8Module_basic",
+    # bitwise and support has been added in torch nightly
+    "ElementwiseAndScalarModule_basic",
+    "ElementwiseAndScalarStaticShapeModule_basic",
+    "ElementwiseBitwiseAndModule_basic",
+    "ElementwiseBitwiseAndScalarInt32Module_basic",
+    "ElementwiseBitwiseAndScalarInt64Module_basic",
+    "ElementwiseBitwiseAndScalarInt8Module_basic",
+    "ElementwiseBitwiseAndStaticShapeModule_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.4.0.dev"):
@@ -2930,6 +2938,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "CumsumModule_basic",
     "CumsumStaticModule_basic",
     "CumsumStaticNegativeDimModule_basic",
+    "DeformConv2D_basic",
     "DiagonalModule_basic",
     "DiagonalModule_nonsquare",
     "DiagonalModule_transposed",
@@ -3724,6 +3733,7 @@ ONNX_TOSA_XFAIL_SET = {
     "CumsumModule_basic",
     "CumsumStaticModule_basic",
     "CumsumStaticNegativeDimModule_basic",
+    "DeformConv2D_basic",
     "DiagonalModule_basic",
     "DiagonalModule_nonsquare",
     "DiagonalModule_transposed",
@@ -8,7 +8,6 @@ import argparse
 import os
 
 import torch
-import torchvision
 from torch import device
 import torch.jit._shape_functions as upstream_shape_functions
 
@@ -1639,6 +1638,12 @@ def aten〇view_as_real〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     assert False, "Unsupported dtype"
 
+def torchvision〇deform_conv2d〡shape(input: List[int], weight: List[int], offset: List[int], mask: List[int], bias: List[int], stride_h: int, stride_w: int, pad_h: int, pad_w: int, dilation_h: int, dilation_w: int, groups: int, offset_groups: int, use_mask: bool) -> List[int]:
+    return [input[0], weight[0], offset[2], offset[3]]
+
+def torchvision〇deform_conv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], offset_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], bias_rank_dtype: Tuple[int, int], stride_h: int, stride_w: int, pad_h: int, pad_w: int, dilation_h: int, dilation_w: int, groups: int, offset_groups: int, use_mask: bool) -> int:
+    return input_rank_dtype[1]
+
 def aten〇conv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), groups: int = 1) -> List[int]:
     return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups)
 
@@ -5117,6 +5122,9 @@ def _maybe_import_op_extensions(args: argparse.Namespace):
 
 def main(args):
     _maybe_import_op_extensions(args)
+    # importing torchvision will register torchvision ops with the JITOperatorRegistry
+    import torchvision
+
     asm = generate_library(globals())
     # We're about to put quotes around the string, so escape the `"` characters.
     asm = asm.replace("\"", "\\\"")
@@ -1155,6 +1155,13 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
         traits=["HasValueSemantics"],
     )
 
+    # ==========================================================================
+    # `torchvision::` namespace.
+    # ==========================================================================
+
+    emit(
+        "torchvision::deform_conv2d : (Tensor, Tensor, Tensor, Tensor, Tensor, int, int, int, int, int, int, int, int, bool) -> (Tensor)"
+    )
     emit(
         "torchvision::roi_align : (Tensor, Tensor, float, int, int, int, bool) -> (Tensor)"
     )
@@ -1180,6 +1187,7 @@ def _maybe_import_op_extensions(args: argparse.Namespace):
 
 def main(args: argparse.Namespace):
     _maybe_import_op_extensions(args)
+    # importing torchvision will register torchvision ops with the JITOperatorRegistry
     import torchvision
 
     registry = Registry.load()
@@ -9,6 +9,7 @@ from typing import Any
 import io
 import onnx
 import torch
+from torch.onnx._constants import ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET as max_opset_ver
 import torch_mlir
 
 from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
@@ -78,7 +79,12 @@ def convert_onnx(model, inputs):
 
     examples = tuple(examples)
     torch.onnx.export(
-        model, examples, buffer, input_names=input_names, dynamic_axes=dynamic_tensors
+        model,
+        examples,
+        buffer,
+        input_names=input_names,
+        dynamic_axes=dynamic_tensors,
+        opset_version=max_opset_ver,
     )
     buffer = buffer.getvalue()
     return import_onnx(buffer)
@@ -1256,3 +1256,90 @@ def ConvTranspose2DQInt8_basic(module, tu: TestUtils):
         tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8),
         torch.rand(Cout),
     )
+
+
+# torchvision.deform_conv2d
+
+import torchvision
+
+# This section defines a torch->onnx path for this torchvision op so we can test the onnx paths e2e.
+
+# Create symbolic function
+from torch.onnx.symbolic_helper import parse_args, _get_tensor_sizes
+
+
+@parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i", "i", "i", "i", "b")
+def symbolic_deform_conv2d_forward(
+    g,
+    input,
+    weight,
+    offset,
+    mask,
+    bias,
+    stride_h,
+    stride_w,
+    pad_h,
+    pad_w,
+    dilation_h,
+    dilation_w,
+    groups,
+    offset_groups,
+    use_mask,
+):
+    args = [input, weight, offset, bias]
+    if use_mask:
+        args.append(mask)
+    weight_size = _get_tensor_sizes(weight)
+    kwargs = {
+        "dilations_i": [dilation_h, dilation_w],
+        "group_i": groups,
+        "kernel_shape_i": weight_size[2:],
+        "offset_group_i": offset_groups,
+        # NB: ONNX supports asymmetric padding, whereas PyTorch supports only
+        # symmetric padding
+        "pads_i": [pad_h, pad_w, pad_h, pad_w],
+        "strides_i": [stride_h, stride_w],
+    }
+    return g.op("DeformConv", *args, **kwargs)
+
+
+# Register symbolic function
+from torch.onnx import register_custom_op_symbolic
+
+register_custom_op_symbolic(
+    "torchvision::deform_conv2d", symbolic_deform_conv2d_forward, 19
+)
+
+N = 1
+Cin = 1
+Hin = 7
+Win = 6
+Cout = 1
+Hker = 2
+Wker = 2
+offset_groups = 1
+Hout = 6
+Wout = 5
+offset_dim1 = 2 * offset_groups * Hker * Wker
+
+
+class DeformableConvModule(torch.nn.Module):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([N, Cin, Hin, Win], torch.float32, True),
+            ([N, offset_dim1, Hout, Wout], torch.float32, True),
+            ([Cout, Cin, Hker, Wker], torch.float32, True),
+        ]
+    )
+    def forward(self, input, offset, weight):
+        return torchvision.ops.deform_conv2d(input, offset, weight)
+
+
+@register_test_case(module_factory=lambda: DeformableConvModule())
+def DeformConv2D_basic(module, tu: TestUtils):
+    input = tu.rand(N, Cin, Hin, Win)
+    offset = tu.rand(N, offset_dim1, Hout, Wout)
+    weight = tu.rand(Cout, Cin, Hker, Wker)
+    module.forward(input, offset, weight)
@@ -735,6 +735,19 @@ func.func @test_asinh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
 
 // -----
 
+// CHECK-LABEL: @test_deform_conv
+func.func @test_deform_conv(%arg0: !torch.vtensor<[1,1,7,6],f32>, %arg1: !torch.vtensor<[1,8,6,5],f32>, %arg2: !torch.vtensor<[1,1,2,2],f32>, %arg3: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,6,5],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
+  // CHECK: %[[cstOne:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[mask:.*]] = torch.aten.full %[[sizeList:.*]], %[[cstOne]]
+  // CHECK-SAME: -> !torch.vtensor<[1,4,6,5],f32>
+  // CHECK: torch.torchvision.deform_conv2d %arg0, %arg2, %arg1, %[[mask]], %arg3
+  // CHECK-SAME: : !torch.vtensor<[1,1,7,6],f32>, !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,8,6,5],f32>, !torch.vtensor<[1,4,6,5],f32>, !torch.vtensor<[1],f32>, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[1,1,6,5],f32>
+  %1 = torch.operator "onnx.DeformConv"(%arg0, %arg2, %arg1, %arg3) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.offset_group = 1 : si64, torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[1,1,7,6],f32>, !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,8,6,5],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,6,5],f32>
+  return %1 : !torch.vtensor<[1,1,6,5],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_dequantizelinear_si8
 func.func @test_dequantizelinear_si8(%arg0: !torch.vtensor<[6],si8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
   %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],f32>