diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index d36c453d5..03cf60589 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -4238,4 +4238,115 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( uniqueResults[1], uniqueResults[2]}); return success(); }); + patterns.onOp( + "Scan", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); + SmallVector operands; + int64_t numScanInputs; + if (binder.tensorOperandsList(operands) || operands.size() == 0 || + binder.s64IntegerAttr(numScanInputs, "num_scan_inputs")) { + return rewriter.notifyMatchFailure(binder.op, + "Failed to get required inputs"); + } + SmallVector resultTypes; + if (binder.tensorResultTypes(resultTypes)) { + return rewriter.notifyMatchFailure(binder.op, + "result type bind failure"); + } + Region *loopBodyIn; + if (binder.getRegionAtIndex(loopBodyIn, 0)) { + return rewriter.notifyMatchFailure(binder.op, + "Failed getting LoopBody Region"); + } + + int64_t numInits = operands.size() - numScanInputs; + SmallVector initVals(operands.begin(), + operands.begin() + numInits); + SmallVector scanInputs(operands.begin() + numInits, + operands.end()); + if (scanInputs.size() < 1) { + return rewriter.notifyMatchFailure(binder.op, + "Expects at least one scan input"); + } + + Value constZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value constOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + SmallVector scanOutTypes; + for (unsigned i = numInits; i < resultTypes.size(); i++) { + auto scanOutTy = cast(resultTypes[i]); + // TODO: Handle dynamic result types. + if (!scanOutTy.hasSizes() || !scanOutTy.areAllSizesKnown()) { + return rewriter.notifyMatchFailure( + binder.op, "Expects result type to be static"); + } + Value sizeList = + createConstantIntList(binder, rewriter, scanOutTy.getSizes()); + initVals.push_back(Torch::createInitTensor(rewriter, loc, scanOutTy, + constZero, sizeList)); + scanOutTypes.push_back(resultTypes[i]); + } + // Create torch.prim.Loop op. + Value maxTripCount = rewriter.create( + loc, scanInputs[0], constZero); + auto constBoolTrue = rewriter.create( + binder.getLoc(), rewriter.getBoolAttr(true)); + auto primLoop = rewriter.create( + loc, resultTypes, maxTripCount, constBoolTrue, initVals); + rewriter.cloneRegionBefore(*loopBodyIn, primLoop.getRegion(), + primLoop.getRegion().begin()); + + // Insert index var as torch.int argument in the loop body, as + // the primLoopOp loopBody expects torch.int as first argument. + primLoop.getRegion().insertArgument( + 0u, rewriter.getType(), loc); + auto loopInd = primLoop.getRegion().getArgument(0); + + // The block arguments of onnx.scan needs to be replaced with + // slice of scan inputs. + rewriter.setInsertionPointToStart(&primLoop.getRegion().front()); + for (unsigned i = 0; i < numScanInputs; i++) { + auto loopBlockArg = + primLoop.getRegion().getArgument(numInits + 1 + i); + Value extract = rewriter.create( + loc, loopBlockArg.getType(), scanInputs[i], constZero, loopInd); + loopBlockArg.replaceAllUsesWith(extract); + } + primLoop.getRegion().front().eraseArguments(numInits + 1, + /*count=*/numScanInputs); + + // Collect the output slices to form scan outputs and replace the + // terminator. + SmallVector locs(scanOutTypes.size(), loc); + primLoop.getRegion().front().addArguments(scanOutTypes, locs); + + PatternRewriter::InsertionGuard guard(rewriter); + Operation *terminator = primLoop.getRegion().front().getTerminator(); + auto terminatorOperands = terminator->getOperands(); + SmallVector resTerminatorOperands( + terminatorOperands.begin(), terminatorOperands.begin() + numInits); + SmallVector scanOutSlices(terminatorOperands.begin() + numInits, + terminatorOperands.end()); + rewriter.setInsertionPoint(terminator); + for (unsigned i = 0; i < scanOutSlices.size(); i++) { + Value self = BlockArgument::Value( + primLoop.getRegion().getArgument(numInits + 1 + i)); + FailureOr src = Torch::unsqueezeTensor( + rewriter, binder.op, scanOutSlices[i], constZero); + if (failed(src)) + return failure(); + Value scanOut = rewriter.create( + loc, scanOutTypes[i], self, src.value(), constZero, + /*start=*/loopInd, + /*end=*/loopInd, constOne); + resTerminatorOperands.push_back(scanOut); + } + + Value terminatorCond = constBoolTrue; + rewriter.replaceOpWithNewOp( + terminator, terminatorCond, resTerminatorOperands); + rewriter.replaceOp(binder.op, primLoop); + return success(); + }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index bed62329a..41e4391a8 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -3318,3 +3318,43 @@ func.func @test_unique_sorted_with_negative_axis(%arg0: !torch.vtensor<[3,3],f32 %0:4 = torch.operator "onnx.Unique"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>) return %0#0, %0#1, %0#2, %0#3 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64> } + +// ----- + +// CHECK-LABEL: func.func @test_scan_sum( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[3,2],f32>) -> (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_scan_sum(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[3,2],f32>) -> (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_2:.*]] = torch.constant.none + // CHECK: %[[VAL_3:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_5:.*]] = torch.constant.int 3 + // CHECK: %[[VAL_6:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_8:.*]] = torch.constant.none + // CHECK: %[[VAL_9:.*]] = torch.constant.int 6 + // CHECK: %[[VAL_10:.*]] = torch.aten.full %[[VAL_7]], %[[VAL_3]], %[[VAL_9]], %[[VAL_8]], %[[VAL_8]], %[[VAL_8]] : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: %[[VAL_11:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_3]] : !torch.vtensor<[3,2],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_12:.*]] = torch.constant.bool true + // CHECK: %[[VAL_13:.*]]:2 = torch.prim.Loop %[[VAL_11]], %[[VAL_12]], init(%[[VAL_0]], %[[VAL_10]]) { + // CHECK: ^bb0(%[[VAL_14:.*]]: !torch.int, %[[VAL_15:.*]]: !torch.vtensor<[2],f32>, %[[VAL_16:.*]]: !torch.vtensor<[3,2],f32>): + // CHECK: %[[VAL_17:.*]] = torch.aten.select.int %[[VAL_1]], %[[VAL_3]], %[[VAL_14]] : !torch.vtensor<[3,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2],f32> + // CHECK: %[[VAL_18:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_19:.*]] = torch.aten.add.Tensor %[[VAL_15]], %[[VAL_17]], %[[VAL_18]] : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2],f32> + // CHECK: %[[VAL_20:.*]] = torch.constant.none + // CHECK: %[[VAL_21:.*]] = torch.aten.clone %[[VAL_19]], %[[VAL_20]] : !torch.vtensor<[2],f32>, !torch.none -> !torch.vtensor<[2],f32> + // CHECK: %[[VAL_22:.*]] = torch.aten.unsqueeze %[[VAL_21]], %[[VAL_3]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[1,2],f32> + // CHECK: %[[VAL_23:.*]] = torch.aten.slice_scatter %[[VAL_16]], %[[VAL_22]], %[[VAL_3]], %[[VAL_14]], %[[VAL_14]], %[[VAL_4]] : !torch.vtensor<[3,2],f32>, !torch.vtensor<[1,2],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,2],f32> + // CHECK: torch.prim.Loop.condition %[[VAL_12]], iter(%[[VAL_19]], %[[VAL_23]] : !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) + // CHECK: } : (!torch.int, !torch.bool, !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) -> (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) + // CHECK: return %[[VAL_24:.*]]#0, %[[VAL_24]]#1 : !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32> + // CHECK: } + %none = torch.constant.none + %0:2 = torch.operator "onnx.Scan"(%arg0, %arg1) {torch.onnx.num_scan_inputs = 1 : si64} : (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) -> (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) { + ^bb0(%arg2: !torch.vtensor<[2],f32>, %arg3: !torch.vtensor<[2],f32>): + %1 = torch.operator "onnx.Add"(%arg2, %arg3) : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> + %2 = torch.operator "onnx.Identity"(%1) : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> + torch.operator_terminator %1, %2 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32> + } + return %0#0, %0#1 : !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32> +}