[MLIR][ONNX] Add support for onnx.scan op (#3516)

This commit lowers the onnx.Scan op to the torch.prim.Loop op and registers
the lowering in the ONNX-to-Torch pipeline.

Signed-off-by: Gaurav Shukla <gaurav.shukla@amd.com>
Gaurav Shukla 2024-08-05 15:37:26 +05:30 committed by GitHub
parent 7030445c15
commit 839fe90f86
2 changed files with 151 additions and 0 deletions
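For context, onnx.Scan iterates over the leading dimension of each scan input, carrying state variables across iterations and stacking the per-iteration outputs. Below is a minimal, illustrative C++ sketch of that semantics (not part of the commit; values and names are made up to mirror the test_scan_sum case added in the test file):

```cpp
#include <array>
#include <vector>

int main() {
  // Carried state (the onnx.Scan "state variable"), seeded from the op input.
  std::array<float, 2> state = {0.0f, 0.0f};
  // Scan input, iterated over its leading dimension (trip count = 3).
  std::vector<std::array<float, 2>> scanInput = {
      {1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}};
  // Scan output: one slice per iteration, stacked along a new leading dim.
  std::vector<std::array<float, 2>> scanOutput;

  for (const auto &slice : scanInput) {
    for (size_t j = 0; j < state.size(); ++j)
      state[j] += slice[j];      // body: state = onnx.Add(state, slice)
    scanOutput.push_back(state); // body also yields a scan-output slice
  }
  // Results: 'state' is the final carried value; 'scanOutput' holds the
  // stacked slices -- exactly the two results of the Scan op.
  return 0;
}
```

The lowering below realizes this with a torch.prim.Loop whose iteration arguments carry both the state variables and pre-allocated scan-output tensors that are filled in one slice per trip.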


@@ -4238,4 +4238,115 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
uniqueResults[1], uniqueResults[2]});
return success();
});
patterns.onOp(
"Scan", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Location loc = binder.getLoc();
SmallVector<Value> operands;
int64_t numScanInputs;
if (binder.tensorOperandsList(operands) || operands.size() == 0 ||
binder.s64IntegerAttr(numScanInputs, "num_scan_inputs")) {
return rewriter.notifyMatchFailure(binder.op,
"Failed to get required inputs");
}
SmallVector<Type> resultTypes;
if (binder.tensorResultTypes(resultTypes)) {
return rewriter.notifyMatchFailure(binder.op,
"result type bind failure");
}
Region *loopBodyIn;
if (binder.getRegionAtIndex(loopBodyIn, 0)) {
return rewriter.notifyMatchFailure(binder.op,
"Failed getting LoopBody Region");
}
int64_t numInits = operands.size() - numScanInputs;
SmallVector<Value> initVals(operands.begin(),
operands.begin() + numInits);
SmallVector<Value> scanInputs(operands.begin() + numInits,
operands.end());
if (scanInputs.size() < 1) {
return rewriter.notifyMatchFailure(binder.op,
"Expects at least one scan input");
}
Value constZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value constOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
SmallVector<Type> scanOutTypes;
for (unsigned i = numInits; i < resultTypes.size(); i++) {
auto scanOutTy = cast<Torch::ValueTensorType>(resultTypes[i]);
// TODO: Handle dynamic result types.
if (!scanOutTy.hasSizes() || !scanOutTy.areAllSizesKnown()) {
return rewriter.notifyMatchFailure(
binder.op, "Expects result type to be static");
}
Value sizeList =
createConstantIntList(binder, rewriter, scanOutTy.getSizes());
initVals.push_back(Torch::createInitTensor(rewriter, loc, scanOutTy,
constZero, sizeList));
scanOutTypes.push_back(resultTypes[i]);
}
// Create torch.prim.Loop op.
Value maxTripCount = rewriter.create<Torch::AtenSizeIntOp>(
loc, scanInputs[0], constZero);
auto constBoolTrue = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getBoolAttr(true));
auto primLoop = rewriter.create<Torch::PrimLoopOp>(
loc, resultTypes, maxTripCount, constBoolTrue, initVals);
rewriter.cloneRegionBefore(*loopBodyIn, primLoop.getRegion(),
primLoop.getRegion().begin());
// Insert the loop index as a torch.int argument in the loop body, since
// the prim.Loop body expects a torch.int as its first argument.
primLoop.getRegion().insertArgument(
0u, rewriter.getType<Torch::IntType>(), loc);
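// The body block signature is now (loop index : !torch.int, carried
// states..., scan-input block arguments from the onnx.Scan body).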
auto loopInd = primLoop.getRegion().getArgument(0);
// The block arguments of onnx.scan need to be replaced with slices of
// the scan inputs.
rewriter.setInsertionPointToStart(&primLoop.getRegion().front());
for (unsigned i = 0; i < numScanInputs; i++) {
auto loopBlockArg =
primLoop.getRegion().getArgument(numInits + 1 + i);
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
loc, loopBlockArg.getType(), scanInputs[i], constZero, loopInd);
loopBlockArg.replaceAllUsesWith(extract);
}
primLoop.getRegion().front().eraseArguments(numInits + 1,
/*count=*/numScanInputs);
// Collect the output slices to form scan outputs and replace the
// terminator.
SmallVector<Location> locs(scanOutTypes.size(), loc);
primLoop.getRegion().front().addArguments(scanOutTypes, locs);
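// The body block signature is now (loop index, carried states...,
// scan-output accumulators...), matching prim.Loop's iteration arguments.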
PatternRewriter::InsertionGuard guard(rewriter);
Operation *terminator = primLoop.getRegion().front().getTerminator();
auto terminatorOperands = terminator->getOperands();
SmallVector<Value> resTerminatorOperands(
terminatorOperands.begin(), terminatorOperands.begin() + numInits);
SmallVector<Value> scanOutSlices(terminatorOperands.begin() + numInits,
terminatorOperands.end());
rewriter.setInsertionPoint(terminator);
for (unsigned i = 0; i < scanOutSlices.size(); i++) {
Value self = primLoop.getRegion().getArgument(numInits + 1 + i);
FailureOr<Value> src = Torch::unsqueezeTensor(
rewriter, binder.op, scanOutSlices[i], constZero);
if (failed(src))
return failure();
Value scanOut = rewriter.create<Torch::AtenSliceScatterOp>(
loc, scanOutTypes[i], self, src.value(), constZero,
/*start=*/loopInd,
/*end=*/loopInd, constOne);
resTerminatorOperands.push_back(scanOut);
}
Value terminatorCond = constBoolTrue;
rewriter.replaceOpWithNewOp<Torch::PrimLoopConditionOp>(
terminator, terminatorCond, resTerminatorOperands);
rewriter.replaceOp(binder.op, primLoop);
return success();
});
}


@@ -3318,3 +3318,43 @@ func.func @test_unique_sorted_with_negative_axis(%arg0: !torch.vtensor<[3,3],f32
%0:4 = torch.operator "onnx.Unique"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>)
return %0#0, %0#1, %0#2, %0#3 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>
}
// -----
// CHECK-LABEL: func.func @test_scan_sum(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[3,2],f32>) -> (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
func.func @test_scan_sum(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[3,2],f32>) -> (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = torch.constant.int 0
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
// CHECK: %[[VAL_5:.*]] = torch.constant.int 3
// CHECK: %[[VAL_6:.*]] = torch.constant.int 2
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_8:.*]] = torch.constant.none
// CHECK: %[[VAL_9:.*]] = torch.constant.int 6
// CHECK: %[[VAL_10:.*]] = torch.aten.full %[[VAL_7]], %[[VAL_3]], %[[VAL_9]], %[[VAL_8]], %[[VAL_8]], %[[VAL_8]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,2],f32>
// CHECK: %[[VAL_11:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_3]] : !torch.vtensor<[3,2],f32>, !torch.int -> !torch.int
// CHECK: %[[VAL_12:.*]] = torch.constant.bool true
// CHECK: %[[VAL_13:.*]]:2 = torch.prim.Loop %[[VAL_11]], %[[VAL_12]], init(%[[VAL_0]], %[[VAL_10]]) {
// CHECK: ^bb0(%[[VAL_14:.*]]: !torch.int, %[[VAL_15:.*]]: !torch.vtensor<[2],f32>, %[[VAL_16:.*]]: !torch.vtensor<[3,2],f32>):
// CHECK: %[[VAL_17:.*]] = torch.aten.select.int %[[VAL_1]], %[[VAL_3]], %[[VAL_14]] : !torch.vtensor<[3,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2],f32>
// CHECK: %[[VAL_18:.*]] = torch.constant.int 1
// CHECK: %[[VAL_19:.*]] = torch.aten.add.Tensor %[[VAL_15]], %[[VAL_17]], %[[VAL_18]] : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2],f32>
// CHECK: %[[VAL_20:.*]] = torch.constant.none
// CHECK: %[[VAL_21:.*]] = torch.aten.clone %[[VAL_19]], %[[VAL_20]] : !torch.vtensor<[2],f32>, !torch.none -> !torch.vtensor<[2],f32>
// CHECK: %[[VAL_22:.*]] = torch.aten.unsqueeze %[[VAL_21]], %[[VAL_3]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[1,2],f32>
// CHECK: %[[VAL_23:.*]] = torch.aten.slice_scatter %[[VAL_16]], %[[VAL_22]], %[[VAL_3]], %[[VAL_14]], %[[VAL_14]], %[[VAL_4]] : !torch.vtensor<[3,2],f32>, !torch.vtensor<[1,2],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,2],f32>
// CHECK: torch.prim.Loop.condition %[[VAL_12]], iter(%[[VAL_19]], %[[VAL_23]] : !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>)
// CHECK: } : (!torch.int, !torch.bool, !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) -> (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>)
// CHECK: return %[[VAL_24:.*]]#0, %[[VAL_24]]#1 : !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>
// CHECK: }
%none = torch.constant.none
%0:2 = torch.operator "onnx.Scan"(%arg0, %arg1) {torch.onnx.num_scan_inputs = 1 : si64} : (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) -> (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) {
^bb0(%arg2: !torch.vtensor<[2],f32>, %arg3: !torch.vtensor<[2],f32>):
%1 = torch.operator "onnx.Add"(%arg2, %arg3) : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32>
%2 = torch.operator "onnx.Identity"(%1) : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32>
torch.operator_terminator %1, %2 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>
}
return %0#0, %0#1 : !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>
}
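As with the other cases in this file, the CHECK lines above are verified by running the test through torch-mlir-opt's -convert-torch-onnx-to-torch pass and piping the output to FileCheck, assuming the file's usual lit RUN line.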