Cast static/dynamic shape for onnx.If branches to match result type (#3828)

pull/3848/head
jinchen 2024-11-01 12:10:59 -07:00 committed by GitHub
parent 3cfb7c8df6
commit 39d69db5ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 50 additions and 5 deletions

View File

@@ -211,15 +211,39 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
inlineIfCase(*thenRegion, primIfOp.getThenRegion());
inlineIfCase(*elseRegion, primIfOp.getElseRegion());
auto replaceTerminator = [&](Region &region) {
auto replaceTerminator = [&](Region &region) -> LogicalResult {
PatternRewriter::InsertionGuard guard(rewriter);
Operation *terminator = region.front().getTerminator();
rewriter.setInsertionPoint(terminator);
rewriter.replaceOpWithNewOp<Torch::PrimIfYieldOp>(
terminator, terminator->getOperands());
// cast result shape if there is static/dynamic difference
llvm::SmallVector<Value> terOperands = terminator->getOperands();
if (terOperands.size() != resultTypes.size())
return failure();
for (size_t i = 0; i < terOperands.size(); i++) {
mlir::Type terType = terOperands[i].getType();
int64_t terOpRank =
dyn_cast<Torch::ValueTensorType>(terType).getSizes().size();
int64_t resRank = dyn_cast<Torch::ValueTensorType>(resultTypes[i])
.getSizes()
.size();
if (terOpRank != resRank)
return failure();
if (terType != resultTypes[i]) {
Value cast = rewriter.create<Torch::TensorStaticInfoCastOp>(
binder.getLoc(), resultTypes[i], terOperands[i]);
terOperands[i] = cast;
}
}
rewriter.replaceOpWithNewOp<Torch::PrimIfYieldOp>(terminator,
terOperands);
return success();
};
replaceTerminator(primIfOp.getThenRegion());
replaceTerminator(primIfOp.getElseRegion());
if (failed(replaceTerminator(primIfOp.getThenRegion())) ||
failed(replaceTerminator(primIfOp.getElseRegion())))
return rewriter.notifyMatchFailure(binder.op,
"terminator replace failure");
rewriter.replaceOp(binder.op, primIfOp.getResults());
return success();

View File

@@ -18,3 +18,24 @@ func.func @test_ifop_basic(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<
}
return %0 : !torch.vtensor<[1],f32>
}
// -----

// Verifies the onnx.If lowering from the companion C++ change: the else
// branch yields a statically-shaped tensor (!torch.vtensor<[0],si64>) while
// the op's declared result type is dynamic (!torch.vtensor<[?],si64>), so the
// lowering must insert a torch.tensor_static_info_cast before the
// prim.If.yield so both branches yield the same type as the prim.If result.
// CHECK-LABEL: func.func @test_ifop_cast_shape
// CHECK: %[[IF:.*]] = torch.prim.If %{{.*}} -> (!torch.vtensor<[?],si64>)
// CHECK-DAG: %[[CAST:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[0],si64> to !torch.vtensor<[?],si64>
// CHECK-DAG: torch.prim.If.yield %[[CAST]] : !torch.vtensor<[?],si64>
// CHECK-DAG: } else {
// CHECK-DAG: %[[SQUEEZE:.*]] = torch.prims.squeeze %arg1, %{{.*}} : !torch.vtensor<[?,1],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
// CHECK-DAG: torch.prim.If.yield %[[SQUEEZE]] : !torch.vtensor<[?],si64>
func.func @test_ifop_cast_shape(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[?,1],si64>) -> !torch.vtensor<[?],si64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "conditional_example", torch.onnx_meta.producer_version = ""} {
  %0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[?],si64> {
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
    %2 = torch.operator "onnx.Squeeze"(%arg1, %1) : (!torch.vtensor<[?,1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],si64>
    torch.operator_terminator %2 : !torch.vtensor<[?],si64>
  }, {
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<0xsi64>} : () -> !torch.vtensor<[0],si64>
    torch.operator_terminator %1 : !torch.vtensor<[0],si64>
  }
  return %0 : !torch.vtensor<[?],si64>
}