Support onnx.If (#2825)

This is probably a decent PR for learning about blocks and regions.

If you're here to learn about that, consider also looking at
lib/Conversion/TorchToSCF/TorchToSCF.cpp

While this doesn't include an e2e test, it is tested downstream in
https://github.com/nod-ai/SHARK-TestSuite/blob/main/e2eshark/onnx/operators/If/model.py

---------

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
pull/3269/head
Xida Ren (Cedar) 2024-04-30 14:36:40 -04:00 committed by GitHub
parent 315dc6c3e3
commit 33eef15e42
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 141 additions and 9 deletions

View File

@ -97,6 +97,31 @@ struct OpBinder {
return success(); return success();
} }
ParseResult tensorResultTypes(llvm::SmallVector<mlir::Type> &typeList) {
for (auto result : op->getResults()) {
auto t = toValidTensorType(result.getType());
if (!t)
return failure();
typeList.push_back(t);
}
return success();
}
// The importer imports Onnx.GraphProto attributes as regions attached to the
// op.
ParseResult getRegionAtIndex(mlir::Region *&region, int64_t idx) {
if (idx >= op->getNumRegions())
return failure();
region = &op->getRegion(idx);
if (region == nullptr) {
return failure();
}
return success();
}
ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx, ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx,
int64_t idx) { int64_t idx) {
if (idx >= op->getNumResults()) if (idx >= op->getNumResults())

View File

@ -158,6 +158,60 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
alignCorners); alignCorners);
return success(); return success();
}); });
patterns.onOp(
"If", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Value conditionTensor;
if (binder.tensorOperand(conditionTensor)) {
return rewriter.notifyMatchFailure(binder.op,
"condition bind failure");
}
auto conditionType =
conditionTensor.getType().cast<Torch::ValueTensorType>();
if (!conditionType || conditionType.getSizes().size() != 1)
return rewriter.notifyMatchFailure(
binder.op, "condition must have one single element per "
"https://onnx.ai/onnx/operators/onnx__If.html");
auto conditionInt = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
conditionTensor);
auto conditionBool = rewriter.create<Torch::AtenBoolIntOp>(
binder.getLoc(), rewriter.getType<Torch::BoolType>(), conditionInt);
llvm::SmallVector<mlir::Type> resultTypes;
if (binder.tensorResultTypes(resultTypes)) {
return rewriter.notifyMatchFailure(binder.op,
"result type bind failure");
}
Region *thenRegion, *elseRegion;
if (binder.getRegionAtIndex(elseRegion, 0) ||
binder.getRegionAtIndex(thenRegion, 1)) {
return rewriter.notifyMatchFailure(binder.op, "region bind failure");
}
auto primIfOp = rewriter.create<Torch::PrimIfOp>(
binder.getLoc(), TypeRange(resultTypes), conditionBool);
auto inlineIfCase = [&](Region &srcRegion, Region &dstRegion) {
rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin());
};
inlineIfCase(*thenRegion, primIfOp.getThenRegion());
inlineIfCase(*elseRegion, primIfOp.getElseRegion());
auto replaceTerminator = [&](Region &region) {
PatternRewriter::InsertionGuard guard(rewriter);
Operation *terminator = region.front().getTerminator();
rewriter.setInsertionPoint(terminator);
rewriter.replaceOpWithNewOp<Torch::PrimIfYieldOp>(
terminator, terminator->getOperands());
};
replaceTerminator(primIfOp.getThenRegion());
replaceTerminator(primIfOp.getElseRegion());
rewriter.replaceOp(binder.op, primIfOp.getResults());
return success();
});
patterns.onOp("Less", 13, patterns.onOp("Less", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) { [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType; Torch::ValueTensorType resultType;

View File

@ -2562,16 +2562,12 @@ ONNX_XFAIL_SET = {
"_ConvolutionDeprecated2DCudnnModule_basic", "_ConvolutionDeprecated2DCudnnModule_basic",
"_ConvolutionDeprecated2DDeterministicModule_basic", "_ConvolutionDeprecated2DDeterministicModule_basic",
"_SoftmaxModule_basic", "_SoftmaxModule_basic",
# Failure - onnx_import
# Failure - onnx_lowering: onnx.AveragePool # Failure - onnx_lowering: onnx.AveragePool
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
# Failure - onnx_lowering: onnx.If # these diagonal modules are currently failing due to dynamic shape.
"DiagonalModule_basic", # We are currently testing aten.diagonal using DiagonalWithStaticShapeModule instead.
"DiagonalModule_nonsquare", # when the issue is fixed, please remove DiagonalWithStaticShapeModule as well as the xfails here.
"DiagonalModule_transposed",
"DiagonalModule_with_dims",
"DiagonalModule_with_dims_and_offset",
"DiagonalModule_with_negative_dims",
"DiagonalModule_with_offset",
"TileBigDimsSizeModule_basic", "TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic", "TileSmallDimsSizeModule_basic",
# Failure - onnx_lowering: onnx.MaxPool # Failure - onnx_lowering: onnx.MaxPool

View File

@ -39,6 +39,37 @@ def DiagonalModule_nonsquare(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class DiagonalWithStaticShapeModule(torch.nn.Module):
"""
Diagonal with static shape. The other diagonal modules are failing in onnx
because DecomoposeAtenEyeMOp requires constants n, m, which are only constant
when the shape is static.
Please remove this module and associated test once the issue is fixed.
"""
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([5, 9], torch.float32, True),
]
)
def forward(self, a):
return torch.ops.aten.diagonal(a)
@register_test_case(module_factory=lambda: DiagonalWithStaticShapeModule())
def DiagonalWithStaticShapeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 9))
# ==============================================================================
class DiagonalTransposedModule(torch.nn.Module): class DiagonalTransposedModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@ -347,8 +347,14 @@ class NodeImporter:
continue continue
elif handler is False: elif handler is False:
# Active error. # Active error.
# try matching attribute type ID to name for a more descriptive error message
try:
attr_type_name = onnx.AttributeProto.AttributeType.Name(attr_type)
except ValueError:
attr_type_name = "UNKNOWN"
raise OnnxImportError( raise OnnxImportError(
f"ONNX importer does not support generic node attribute type {attr_type}. " f"ONNX importer does not support generic node attribute type {attr_type_name} "
f"with ID {attr_type}. "
f"This likely means that this is a special node which requires specific " f"This likely means that this is a special node which requires specific "
f"handling in the importer: {onnx_attr}" f"handling in the importer: {onnx_attr}"
) )

View File

@ -0,0 +1,20 @@
// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s
// CHECK-LABEL: func.func @test_ifop_basic
// CHECK: %[[IF:.*]] = torch.prim.If %{{.*}} -> (!torch.vtensor<[1],f32>)
// CHECK-DAG: %[[SUB:.*]] = torch.aten.sub.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK-DAG: torch.prim.If.yield %[[SUB]] : !torch.vtensor<[1],f32>
// CHECK-DAG: } else {
// CHECK-DAG: %[[ADD:.*]] = torch.aten.add.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK-DAG: torch.prim.If.yield %[[ADD]] : !torch.vtensor<[1],f32>
func.func @test_ifop_basic(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "conditional_example", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
%0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[1],f32> {
%1 = torch.operator "onnx.Add"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
torch.operator_terminator %1 : !torch.vtensor<[1],f32>
}, {
%1 = torch.operator "onnx.Sub"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
torch.operator_terminator %1 : !torch.vtensor<[1],f32>
}
return %0 : !torch.vtensor<[1],f32>
}