mirror of https://github.com/llvm/torch-mlir
Cast static/dynamic shape for onnx.If branches to match result type (#3828)
parent
3cfb7c8df6
commit
39d69db5ca
|
@ -211,15 +211,39 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
inlineIfCase(*thenRegion, primIfOp.getThenRegion());
|
||||
inlineIfCase(*elseRegion, primIfOp.getElseRegion());
|
||||
|
||||
auto replaceTerminator = [&](Region ®ion) {
|
||||
auto replaceTerminator = [&](Region ®ion) -> LogicalResult {
|
||||
PatternRewriter::InsertionGuard guard(rewriter);
|
||||
Operation *terminator = region.front().getTerminator();
|
||||
rewriter.setInsertionPoint(terminator);
|
||||
rewriter.replaceOpWithNewOp<Torch::PrimIfYieldOp>(
|
||||
terminator, terminator->getOperands());
|
||||
|
||||
// cast result shape if there is static/dynamic difference
|
||||
llvm::SmallVector<Value> terOperands = terminator->getOperands();
|
||||
if (terOperands.size() != resultTypes.size())
|
||||
return failure();
|
||||
for (size_t i = 0; i < terOperands.size(); i++) {
|
||||
mlir::Type terType = terOperands[i].getType();
|
||||
int64_t terOpRank =
|
||||
dyn_cast<Torch::ValueTensorType>(terType).getSizes().size();
|
||||
int64_t resRank = dyn_cast<Torch::ValueTensorType>(resultTypes[i])
|
||||
.getSizes()
|
||||
.size();
|
||||
if (terOpRank != resRank)
|
||||
return failure();
|
||||
if (terType != resultTypes[i]) {
|
||||
Value cast = rewriter.create<Torch::TensorStaticInfoCastOp>(
|
||||
binder.getLoc(), resultTypes[i], terOperands[i]);
|
||||
terOperands[i] = cast;
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::PrimIfYieldOp>(terminator,
|
||||
terOperands);
|
||||
return success();
|
||||
};
|
||||
replaceTerminator(primIfOp.getThenRegion());
|
||||
replaceTerminator(primIfOp.getElseRegion());
|
||||
if (failed(replaceTerminator(primIfOp.getThenRegion())) ||
|
||||
failed(replaceTerminator(primIfOp.getElseRegion())))
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"terminator replace failure");
|
||||
|
||||
rewriter.replaceOp(binder.op, primIfOp.getResults());
|
||||
return success();
|
||||
|
|
|
@ -18,3 +18,24 @@ func.func @test_ifop_basic(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<
|
|||
}
|
||||
return %0 : !torch.vtensor<[1],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_ifop_cast_shape
|
||||
// CHECK: %[[IF:.*]] = torch.prim.If %{{.*}} -> (!torch.vtensor<[?],si64>)
|
||||
// CHECK-DAG: %[[CAST:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[0],si64> to !torch.vtensor<[?],si64>
|
||||
// CHECK-DAG: torch.prim.If.yield %[[CAST]] : !torch.vtensor<[?],si64>
|
||||
// CHECK-DAG: } else {
|
||||
// CHECK-DAG: %[[SQUEEZE:.*]] = torch.prims.squeeze %arg1, %{{.*}} : !torch.vtensor<[?,1],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
|
||||
// CHECK-DAG: torch.prim.If.yield %[[SQUEEZE]] : !torch.vtensor<[?],si64>
|
||||
func.func @test_ifop_cast_shape(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[?,1],si64>) -> !torch.vtensor<[?],si64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "conditional_example", torch.onnx_meta.producer_version = ""} {
|
||||
%0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[?],si64> {
|
||||
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
|
||||
%2 = torch.operator "onnx.Squeeze"(%arg1, %1) : (!torch.vtensor<[?,1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],si64>
|
||||
torch.operator_terminator %2 : !torch.vtensor<[?],si64>
|
||||
}, {
|
||||
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<0xsi64>} : () -> !torch.vtensor<[0],si64>
|
||||
torch.operator_terminator %1 : !torch.vtensor<[0],si64>
|
||||
}
|
||||
return %0 : !torch.vtensor<[?],si64>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue