mirror of https://github.com/llvm/torch-mlir
[MLIR][ONNX] Add support for onnx.scan op (#3516)
This commit lowers the onnx.Scan op to torch.prim.Loop and adds the lowering to the ONNX pipeline.

Signed-off-by: Gaurav Shukla <gaurav.shukla@amd.com>
parent 7030445c15
commit 839fe90f86
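For context, onnx.Scan runs its body graph once per slice of each scan input (taken along the leading dimension), threading the loop-carried state between iterations and stacking the body's extra outputs into scan outputs. Below is a minimal C++ sketch of that recurrence, assuming one state tensor and one scan input; the names (Tensor, body, scan) are illustrative stand-ins for the Scan body graph, not torch-mlir API.

#include <cstddef>
#include <utility>
#include <vector>

using Tensor = std::vector<float>;

// Stand-in for the Scan body graph: takes the carried state and one slice
// of the scan input, returns the new state and one scan-output slice. This
// particular body computes a running elementwise sum, matching the
// test_scan_sum lit test below.
std::pair<Tensor, Tensor> body(const Tensor &state, const Tensor &slice) {
  Tensor sum(state.size());
  for (std::size_t j = 0; j < state.size(); ++j)
    sum[j] = state[j] + slice[j];
  return {sum, sum}; // new state, and the same value as the output slice
}

// Scan semantics: iterate over the leading dimension of the scan input.
std::pair<Tensor, std::vector<Tensor>>
scan(Tensor state, const std::vector<Tensor> &scanInput) {
  std::vector<Tensor> scanOutput; // stacked along a new leading dimension
  for (std::size_t i = 0; i < scanInput.size(); ++i) {
    auto [newState, outSlice] = body(state, scanInput[i]);
    state = std::move(newState);
    scanOutput.push_back(outSlice); // iteration i produces output row i
  }
  return {state, scanOutput}; // final state + stacked scan outputs
}

The lowering below materializes exactly this loop as torch.prim.Loop: the state tensors ride through as iteration arguments, and the stacked scan outputs are emulated with preallocated accumulators updated by slice_scatter.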
@@ -4238,4 +4238,115 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
                             uniqueResults[1], uniqueResults[2]});
        return success();
      });
  patterns.onOp(
      "Scan", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Location loc = binder.getLoc();
        SmallVector<Value> operands;
        int64_t numScanInputs;
        if (binder.tensorOperandsList(operands) || operands.size() == 0 ||
            binder.s64IntegerAttr(numScanInputs, "num_scan_inputs")) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "Failed to get required inputs");
        }
        SmallVector<Type> resultTypes;
        if (binder.tensorResultTypes(resultTypes)) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "result type bind failure");
        }
        Region *loopBodyIn;
        if (binder.getRegionAtIndex(loopBodyIn, 0)) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "Failed getting LoopBody Region");
        }

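        // Per the ONNX Scan spec, the leading operands are the initial values
        // of the loop-carried state variables; the trailing num_scan_inputs
        // operands are the tensors scanned along their leading dimension.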
        int64_t numInits = operands.size() - numScanInputs;
        SmallVector<Value> initVals(operands.begin(),
                                    operands.begin() + numInits);
        SmallVector<Value> scanInputs(operands.begin() + numInits,
                                      operands.end());
        if (scanInputs.size() < 1) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "Expects at least one scan input");
        }

        Value constZero = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(0));
        Value constOne = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(1));
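        // Seed each scan output with a statically shaped zero-filled
        // accumulator tensor; the loop body overwrites one slice of it per
        // iteration.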
        SmallVector<Type> scanOutTypes;
        for (unsigned i = numInits; i < resultTypes.size(); i++) {
          auto scanOutTy = cast<Torch::ValueTensorType>(resultTypes[i]);
          // TODO: Handle dynamic result types.
          if (!scanOutTy.hasSizes() || !scanOutTy.areAllSizesKnown()) {
            return rewriter.notifyMatchFailure(
                binder.op, "Expects result type to be static");
          }
          Value sizeList =
              createConstantIntList(binder, rewriter, scanOutTy.getSizes());
          initVals.push_back(Torch::createInitTensor(rewriter, loc, scanOutTy,
                                                     constZero, sizeList));
          scanOutTypes.push_back(resultTypes[i]);
        }
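        // The trip count is the extent of the scanned (leading) dimension of
        // the first scan input; onnx.Scan always iterates over the full
        // extent.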
        // Create torch.prim.Loop op.
        Value maxTripCount = rewriter.create<Torch::AtenSizeIntOp>(
            loc, scanInputs[0], constZero);
        auto constBoolTrue = rewriter.create<Torch::ConstantBoolOp>(
            loc, rewriter.getBoolAttr(true));
        auto primLoop = rewriter.create<Torch::PrimLoopOp>(
            loc, resultTypes, maxTripCount, constBoolTrue, initVals);
        rewriter.cloneRegionBefore(*loopBodyIn, primLoop.getRegion(),
                                   primLoop.getRegion().begin());

        // Insert an index variable as a torch.int argument in the loop body,
        // since the torch.prim.Loop body expects a torch.int as its first
        // argument.
        primLoop.getRegion().insertArgument(
            0u, rewriter.getType<Torch::IntType>(), loc);
        auto loopInd = primLoop.getRegion().getArgument(0);

        // The block arguments of onnx.Scan need to be replaced with slices
        // of the scan inputs.
        rewriter.setInsertionPointToStart(&primLoop.getRegion().front());
        for (unsigned i = 0; i < numScanInputs; i++) {
          auto loopBlockArg =
              primLoop.getRegion().getArgument(numInits + 1 + i);
          Value extract = rewriter.create<Torch::AtenSelectIntOp>(
              loc, loopBlockArg.getType(), scanInputs[i], constZero, loopInd);
          loopBlockArg.replaceAllUsesWith(extract);
        }
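        // The cloned body's original scan-input arguments are now unused;
        // erase them so the block signature matches what torch.prim.Loop
        // expects.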
        primLoop.getRegion().front().eraseArguments(numInits + 1,
                                                    /*count=*/numScanInputs);

        // Collect the output slices to form scan outputs and replace the
        // terminator.
        SmallVector<Location> locs(scanOutTypes.size(), loc);
        primLoop.getRegion().front().addArguments(scanOutTypes, locs);

        PatternRewriter::InsertionGuard guard(rewriter);
        Operation *terminator = primLoop.getRegion().front().getTerminator();
        auto terminatorOperands = terminator->getOperands();
        SmallVector<Value> resTerminatorOperands(
            terminatorOperands.begin(), terminatorOperands.begin() + numInits);
        SmallVector<Value> scanOutSlices(terminatorOperands.begin() + numInits,
                                         terminatorOperands.end());
        rewriter.setInsertionPoint(terminator);
        for (unsigned i = 0; i < scanOutSlices.size(); i++) {
          Value self = primLoop.getRegion().getArgument(numInits + 1 + i);
          FailureOr<Value> src = Torch::unsqueezeTensor(
              rewriter, binder.op, scanOutSlices[i], constZero);
          if (failed(src))
            return failure();
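          // Scatter the unsqueezed slice into the accumulator at the current
          // loop index, yielding the updated scan-output accumulator.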
          Value scanOut = rewriter.create<Torch::AtenSliceScatterOp>(
              loc, scanOutTypes[i], self, src.value(), constZero,
              /*start=*/loopInd,
              /*end=*/loopInd, constOne);
          resTerminatorOperands.push_back(scanOut);
        }

        Value terminatorCond = constBoolTrue;
        rewriter.replaceOpWithNewOp<Torch::PrimLoopConditionOp>(
            terminator, terminatorCond, resTerminatorOperands);
        rewriter.replaceOp(binder.op, primLoop);
        return success();
      });
}

@@ -3318,3 +3318,43 @@ func.func @test_unique_sorted_with_negative_axis(%arg0: !torch.vtensor<[3,3],f32
  %0:4 = torch.operator "onnx.Unique"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>)
  return %0#0, %0#1, %0#2, %0#3 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>
}

// -----

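// Scan with one loop-carried state tensor and one scan input: the body adds
// each row of %arg1 to the running sum, returning the sum as both the new
// state and the per-iteration scan output.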
// CHECK-LABEL: func.func @test_scan_sum(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[3,2],f32>) -> (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
func.func @test_scan_sum(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[3,2],f32>) -> (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = torch.constant.int 0
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
// CHECK: %[[VAL_5:.*]] = torch.constant.int 3
// CHECK: %[[VAL_6:.*]] = torch.constant.int 2
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_8:.*]] = torch.constant.none
// CHECK: %[[VAL_9:.*]] = torch.constant.int 6
// CHECK: %[[VAL_10:.*]] = torch.aten.full %[[VAL_7]], %[[VAL_3]], %[[VAL_9]], %[[VAL_8]], %[[VAL_8]], %[[VAL_8]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,2],f32>
// CHECK: %[[VAL_11:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_3]] : !torch.vtensor<[3,2],f32>, !torch.int -> !torch.int
// CHECK: %[[VAL_12:.*]] = torch.constant.bool true
// CHECK: %[[VAL_13:.*]]:2 = torch.prim.Loop %[[VAL_11]], %[[VAL_12]], init(%[[VAL_0]], %[[VAL_10]]) {
// CHECK: ^bb0(%[[VAL_14:.*]]: !torch.int, %[[VAL_15:.*]]: !torch.vtensor<[2],f32>, %[[VAL_16:.*]]: !torch.vtensor<[3,2],f32>):
// CHECK: %[[VAL_17:.*]] = torch.aten.select.int %[[VAL_1]], %[[VAL_3]], %[[VAL_14]] : !torch.vtensor<[3,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2],f32>
// CHECK: %[[VAL_18:.*]] = torch.constant.int 1
// CHECK: %[[VAL_19:.*]] = torch.aten.add.Tensor %[[VAL_15]], %[[VAL_17]], %[[VAL_18]] : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2],f32>
// CHECK: %[[VAL_20:.*]] = torch.constant.none
// CHECK: %[[VAL_21:.*]] = torch.aten.clone %[[VAL_19]], %[[VAL_20]] : !torch.vtensor<[2],f32>, !torch.none -> !torch.vtensor<[2],f32>
// CHECK: %[[VAL_22:.*]] = torch.aten.unsqueeze %[[VAL_21]], %[[VAL_3]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[1,2],f32>
// CHECK: %[[VAL_23:.*]] = torch.aten.slice_scatter %[[VAL_16]], %[[VAL_22]], %[[VAL_3]], %[[VAL_14]], %[[VAL_14]], %[[VAL_4]] : !torch.vtensor<[3,2],f32>, !torch.vtensor<[1,2],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,2],f32>
// CHECK: torch.prim.Loop.condition %[[VAL_12]], iter(%[[VAL_19]], %[[VAL_23]] : !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>)
// CHECK: } : (!torch.int, !torch.bool, !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) -> (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>)
// CHECK: return %[[VAL_24:.*]]#0, %[[VAL_24]]#1 : !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>
// CHECK: }
  %none = torch.constant.none
  %0:2 = torch.operator "onnx.Scan"(%arg0, %arg1) {torch.onnx.num_scan_inputs = 1 : si64} : (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) -> (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) {
  ^bb0(%arg2: !torch.vtensor<[2],f32>, %arg3: !torch.vtensor<[2],f32>):
    %1 = torch.operator "onnx.Add"(%arg2, %arg3) : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32>
    %2 = torch.operator "onnx.Identity"(%1) : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32>
    torch.operator_terminator %1, %2 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>
  }
  return %0#0, %0#1 : !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>
}