mirror of https://github.com/llvm/torch-mlir
Support onnx.If (#2825)
This is probably a decent PR for learning about blocks and regions. If you're here to learn about that, consider also looking at lib/Conversion/TorchToSCF/TorchToSCF.cpp While this doesn't include an e2e test, it is tested downstream in https://github.com/nod-ai/SHARK-TestSuite/blob/main/e2eshark/onnx/operators/If/model.py --------- Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>pull/3269/head
parent
315dc6c3e3
commit
33eef15e42
|
@ -97,6 +97,31 @@ struct OpBinder {
|
|||
return success();
|
||||
}
|
||||
|
||||
ParseResult tensorResultTypes(llvm::SmallVector<mlir::Type> &typeList) {
|
||||
for (auto result : op->getResults()) {
|
||||
auto t = toValidTensorType(result.getType());
|
||||
if (!t)
|
||||
return failure();
|
||||
typeList.push_back(t);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// The importer imports Onnx.GraphProto attributes as regions attached to the
|
||||
// op.
|
||||
ParseResult getRegionAtIndex(mlir::Region *®ion, int64_t idx) {
|
||||
if (idx >= op->getNumRegions())
|
||||
return failure();
|
||||
|
||||
region = &op->getRegion(idx);
|
||||
|
||||
if (region == nullptr) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx,
|
||||
int64_t idx) {
|
||||
if (idx >= op->getNumResults())
|
||||
|
|
|
@ -158,6 +158,60 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
alignCorners);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"If", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Value conditionTensor;
|
||||
if (binder.tensorOperand(conditionTensor)) {
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"condition bind failure");
|
||||
}
|
||||
|
||||
auto conditionType =
|
||||
conditionTensor.getType().cast<Torch::ValueTensorType>();
|
||||
if (!conditionType || conditionType.getSizes().size() != 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "condition must have one single element per "
|
||||
"https://onnx.ai/onnx/operators/onnx__If.html");
|
||||
auto conditionInt = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
conditionTensor);
|
||||
auto conditionBool = rewriter.create<Torch::AtenBoolIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::BoolType>(), conditionInt);
|
||||
|
||||
llvm::SmallVector<mlir::Type> resultTypes;
|
||||
if (binder.tensorResultTypes(resultTypes)) {
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"result type bind failure");
|
||||
}
|
||||
|
||||
Region *thenRegion, *elseRegion;
|
||||
if (binder.getRegionAtIndex(elseRegion, 0) ||
|
||||
binder.getRegionAtIndex(thenRegion, 1)) {
|
||||
return rewriter.notifyMatchFailure(binder.op, "region bind failure");
|
||||
}
|
||||
|
||||
auto primIfOp = rewriter.create<Torch::PrimIfOp>(
|
||||
binder.getLoc(), TypeRange(resultTypes), conditionBool);
|
||||
|
||||
auto inlineIfCase = [&](Region &srcRegion, Region &dstRegion) {
|
||||
rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin());
|
||||
};
|
||||
inlineIfCase(*thenRegion, primIfOp.getThenRegion());
|
||||
inlineIfCase(*elseRegion, primIfOp.getElseRegion());
|
||||
|
||||
auto replaceTerminator = [&](Region ®ion) {
|
||||
PatternRewriter::InsertionGuard guard(rewriter);
|
||||
Operation *terminator = region.front().getTerminator();
|
||||
rewriter.setInsertionPoint(terminator);
|
||||
rewriter.replaceOpWithNewOp<Torch::PrimIfYieldOp>(
|
||||
terminator, terminator->getOperands());
|
||||
};
|
||||
replaceTerminator(primIfOp.getThenRegion());
|
||||
replaceTerminator(primIfOp.getElseRegion());
|
||||
|
||||
rewriter.replaceOp(binder.op, primIfOp.getResults());
|
||||
return success();
|
||||
});
|
||||
patterns.onOp("Less", 13,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
|
|
|
@ -2562,16 +2562,12 @@ ONNX_XFAIL_SET = {
|
|||
"_ConvolutionDeprecated2DCudnnModule_basic",
|
||||
"_ConvolutionDeprecated2DDeterministicModule_basic",
|
||||
"_SoftmaxModule_basic",
|
||||
# Failure - onnx_import
|
||||
# Failure - onnx_lowering: onnx.AveragePool
|
||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||
# Failure - onnx_lowering: onnx.If
|
||||
"DiagonalModule_basic",
|
||||
"DiagonalModule_nonsquare",
|
||||
"DiagonalModule_transposed",
|
||||
"DiagonalModule_with_dims",
|
||||
"DiagonalModule_with_dims_and_offset",
|
||||
"DiagonalModule_with_negative_dims",
|
||||
"DiagonalModule_with_offset",
|
||||
# these diagonal modules are currently failing due to dynamic shape.
|
||||
# We are currently testing aten.diagonal using DiagonalWithStaticShapeModule instead.
|
||||
# when the issue is fixed, please remove DiagonalWithStaticShapeModule as well as the xfails here.
|
||||
"TileBigDimsSizeModule_basic",
|
||||
"TileSmallDimsSizeModule_basic",
|
||||
# Failure - onnx_lowering: onnx.MaxPool
|
||||
|
|
|
@ -39,6 +39,37 @@ def DiagonalModule_nonsquare(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class DiagonalWithStaticShapeModule(torch.nn.Module):
|
||||
"""
|
||||
Diagonal with static shape. The other diagonal modules are failing in onnx
|
||||
because DecomoposeAtenEyeMOp requires constants n, m, which are only constant
|
||||
when the shape is static.
|
||||
|
||||
Please remove this module and associated test once the issue is fixed.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([5, 9], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.diagonal(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: DiagonalWithStaticShapeModule())
|
||||
def DiagonalWithStaticShapeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 9))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class DiagonalTransposedModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -347,8 +347,14 @@ class NodeImporter:
|
|||
continue
|
||||
elif handler is False:
|
||||
# Active error.
|
||||
# try matching attribute type ID to name for a more descriptive error message
|
||||
try:
|
||||
attr_type_name = onnx.AttributeProto.AttributeType.Name(attr_type)
|
||||
except ValueError:
|
||||
attr_type_name = "UNKNOWN"
|
||||
raise OnnxImportError(
|
||||
f"ONNX importer does not support generic node attribute type {attr_type}. "
|
||||
f"ONNX importer does not support generic node attribute type {attr_type_name} "
|
||||
f"with ID {attr_type}. "
|
||||
f"This likely means that this is a special node which requires specific "
|
||||
f"handling in the importer: {onnx_attr}"
|
||||
)
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @test_ifop_basic
|
||||
// CHECK: %[[IF:.*]] = torch.prim.If %{{.*}} -> (!torch.vtensor<[1],f32>)
|
||||
// CHECK-DAG: %[[SUB:.*]] = torch.aten.sub.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
|
||||
// CHECK-DAG: torch.prim.If.yield %[[SUB]] : !torch.vtensor<[1],f32>
|
||||
// CHECK-DAG: } else {
|
||||
// CHECK-DAG: %[[ADD:.*]] = torch.aten.add.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
|
||||
// CHECK-DAG: torch.prim.If.yield %[[ADD]] : !torch.vtensor<[1],f32>
|
||||
func.func @test_ifop_basic(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "conditional_example", torch.onnx_meta.producer_version = ""} {
|
||||
%none = torch.constant.none
|
||||
%0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[1],f32> {
|
||||
%1 = torch.operator "onnx.Add"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
|
||||
torch.operator_terminator %1 : !torch.vtensor<[1],f32>
|
||||
}, {
|
||||
%1 = torch.operator "onnx.Sub"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
|
||||
torch.operator_terminator %1 : !torch.vtensor<[1],f32>
|
||||
}
|
||||
return %0 : !torch.vtensor<[1],f32>
|
||||
}
|
Loading…
Reference in New Issue