mirror of https://github.com/llvm/torch-mlir
Support onnx.If (#2825)
This is probably a decent PR for learning about blocks and regions. If you're here to learn about that, consider also looking at lib/Conversion/TorchToSCF/TorchToSCF.cpp While this doesn't include an e2e test, it is tested downstream in https://github.com/nod-ai/SHARK-TestSuite/blob/main/e2eshark/onnx/operators/If/model.py --------- Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>pull/3269/head
parent
315dc6c3e3
commit
33eef15e42
|
@ -97,6 +97,31 @@ struct OpBinder {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ParseResult tensorResultTypes(llvm::SmallVector<mlir::Type> &typeList) {
|
||||||
|
for (auto result : op->getResults()) {
|
||||||
|
auto t = toValidTensorType(result.getType());
|
||||||
|
if (!t)
|
||||||
|
return failure();
|
||||||
|
typeList.push_back(t);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// The importer imports Onnx.GraphProto attributes as regions attached to the
|
||||||
|
// op.
|
||||||
|
ParseResult getRegionAtIndex(mlir::Region *®ion, int64_t idx) {
|
||||||
|
if (idx >= op->getNumRegions())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
region = &op->getRegion(idx);
|
||||||
|
|
||||||
|
if (region == nullptr) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx,
|
ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx,
|
||||||
int64_t idx) {
|
int64_t idx) {
|
||||||
if (idx >= op->getNumResults())
|
if (idx >= op->getNumResults())
|
||||||
|
|
|
@ -158,6 +158,60 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
alignCorners);
|
alignCorners);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"If", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Value conditionTensor;
|
||||||
|
if (binder.tensorOperand(conditionTensor)) {
|
||||||
|
return rewriter.notifyMatchFailure(binder.op,
|
||||||
|
"condition bind failure");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto conditionType =
|
||||||
|
conditionTensor.getType().cast<Torch::ValueTensorType>();
|
||||||
|
if (!conditionType || conditionType.getSizes().size() != 1)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "condition must have one single element per "
|
||||||
|
"https://onnx.ai/onnx/operators/onnx__If.html");
|
||||||
|
auto conditionInt = rewriter.create<Torch::AtenItemOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
conditionTensor);
|
||||||
|
auto conditionBool = rewriter.create<Torch::AtenBoolIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::BoolType>(), conditionInt);
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::Type> resultTypes;
|
||||||
|
if (binder.tensorResultTypes(resultTypes)) {
|
||||||
|
return rewriter.notifyMatchFailure(binder.op,
|
||||||
|
"result type bind failure");
|
||||||
|
}
|
||||||
|
|
||||||
|
Region *thenRegion, *elseRegion;
|
||||||
|
if (binder.getRegionAtIndex(elseRegion, 0) ||
|
||||||
|
binder.getRegionAtIndex(thenRegion, 1)) {
|
||||||
|
return rewriter.notifyMatchFailure(binder.op, "region bind failure");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto primIfOp = rewriter.create<Torch::PrimIfOp>(
|
||||||
|
binder.getLoc(), TypeRange(resultTypes), conditionBool);
|
||||||
|
|
||||||
|
auto inlineIfCase = [&](Region &srcRegion, Region &dstRegion) {
|
||||||
|
rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin());
|
||||||
|
};
|
||||||
|
inlineIfCase(*thenRegion, primIfOp.getThenRegion());
|
||||||
|
inlineIfCase(*elseRegion, primIfOp.getElseRegion());
|
||||||
|
|
||||||
|
auto replaceTerminator = [&](Region ®ion) {
|
||||||
|
PatternRewriter::InsertionGuard guard(rewriter);
|
||||||
|
Operation *terminator = region.front().getTerminator();
|
||||||
|
rewriter.setInsertionPoint(terminator);
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::PrimIfYieldOp>(
|
||||||
|
terminator, terminator->getOperands());
|
||||||
|
};
|
||||||
|
replaceTerminator(primIfOp.getThenRegion());
|
||||||
|
replaceTerminator(primIfOp.getElseRegion());
|
||||||
|
|
||||||
|
rewriter.replaceOp(binder.op, primIfOp.getResults());
|
||||||
|
return success();
|
||||||
|
});
|
||||||
patterns.onOp("Less", 13,
|
patterns.onOp("Less", 13,
|
||||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
Torch::ValueTensorType resultType;
|
Torch::ValueTensorType resultType;
|
||||||
|
|
|
@ -2562,16 +2562,12 @@ ONNX_XFAIL_SET = {
|
||||||
"_ConvolutionDeprecated2DCudnnModule_basic",
|
"_ConvolutionDeprecated2DCudnnModule_basic",
|
||||||
"_ConvolutionDeprecated2DDeterministicModule_basic",
|
"_ConvolutionDeprecated2DDeterministicModule_basic",
|
||||||
"_SoftmaxModule_basic",
|
"_SoftmaxModule_basic",
|
||||||
|
# Failure - onnx_import
|
||||||
# Failure - onnx_lowering: onnx.AveragePool
|
# Failure - onnx_lowering: onnx.AveragePool
|
||||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||||
# Failure - onnx_lowering: onnx.If
|
# these diagonal modules are currently failing due to dynamic shape.
|
||||||
"DiagonalModule_basic",
|
# We are currently testing aten.diagonal using DiagonalWithStaticShapeModule instead.
|
||||||
"DiagonalModule_nonsquare",
|
# when the issue is fixed, please remove DiagonalWithStaticShapeModule as well as the xfails here.
|
||||||
"DiagonalModule_transposed",
|
|
||||||
"DiagonalModule_with_dims",
|
|
||||||
"DiagonalModule_with_dims_and_offset",
|
|
||||||
"DiagonalModule_with_negative_dims",
|
|
||||||
"DiagonalModule_with_offset",
|
|
||||||
"TileBigDimsSizeModule_basic",
|
"TileBigDimsSizeModule_basic",
|
||||||
"TileSmallDimsSizeModule_basic",
|
"TileSmallDimsSizeModule_basic",
|
||||||
# Failure - onnx_lowering: onnx.MaxPool
|
# Failure - onnx_lowering: onnx.MaxPool
|
||||||
|
|
|
@ -39,6 +39,37 @@ def DiagonalModule_nonsquare(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class DiagonalWithStaticShapeModule(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Diagonal with static shape. The other diagonal modules are failing in onnx
|
||||||
|
because DecomoposeAtenEyeMOp requires constants n, m, which are only constant
|
||||||
|
when the shape is static.
|
||||||
|
|
||||||
|
Please remove this module and associated test once the issue is fixed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([5, 9], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.diagonal(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: DiagonalWithStaticShapeModule())
|
||||||
|
def DiagonalWithStaticShapeModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(5, 9))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class DiagonalTransposedModule(torch.nn.Module):
|
class DiagonalTransposedModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -347,8 +347,14 @@ class NodeImporter:
|
||||||
continue
|
continue
|
||||||
elif handler is False:
|
elif handler is False:
|
||||||
# Active error.
|
# Active error.
|
||||||
|
# try matching attribute type ID to name for a more descriptive error message
|
||||||
|
try:
|
||||||
|
attr_type_name = onnx.AttributeProto.AttributeType.Name(attr_type)
|
||||||
|
except ValueError:
|
||||||
|
attr_type_name = "UNKNOWN"
|
||||||
raise OnnxImportError(
|
raise OnnxImportError(
|
||||||
f"ONNX importer does not support generic node attribute type {attr_type}. "
|
f"ONNX importer does not support generic node attribute type {attr_type_name} "
|
||||||
|
f"with ID {attr_type}. "
|
||||||
f"This likely means that this is a special node which requires specific "
|
f"This likely means that this is a special node which requires specific "
|
||||||
f"handling in the importer: {onnx_attr}"
|
f"handling in the importer: {onnx_attr}"
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_ifop_basic
|
||||||
|
// CHECK: %[[IF:.*]] = torch.prim.If %{{.*}} -> (!torch.vtensor<[1],f32>)
|
||||||
|
// CHECK-DAG: %[[SUB:.*]] = torch.aten.sub.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
|
||||||
|
// CHECK-DAG: torch.prim.If.yield %[[SUB]] : !torch.vtensor<[1],f32>
|
||||||
|
// CHECK-DAG: } else {
|
||||||
|
// CHECK-DAG: %[[ADD:.*]] = torch.aten.add.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
|
||||||
|
// CHECK-DAG: torch.prim.If.yield %[[ADD]] : !torch.vtensor<[1],f32>
|
||||||
|
func.func @test_ifop_basic(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "conditional_example", torch.onnx_meta.producer_version = ""} {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[1],f32> {
|
||||||
|
%1 = torch.operator "onnx.Add"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
|
||||||
|
torch.operator_terminator %1 : !torch.vtensor<[1],f32>
|
||||||
|
}, {
|
||||||
|
%1 = torch.operator "onnx.Sub"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
|
||||||
|
torch.operator_terminator %1 : !torch.vtensor<[1],f32>
|
||||||
|
}
|
||||||
|
return %0 : !torch.vtensor<[1],f32>
|
||||||
|
}
|
Loading…
Reference in New Issue