add onnx loop support (#3408)

- Adds limited support for lowering onnx.Loop to primLoopOp.
- Later in the pipeline, `torch-to-scf` checks whether a loop is
for-like: a primLoopOp is for-like when its input condition is a
`trueBoolConstant`. To let the onnx-to-torch lowering take advantage of
this, the implementation checks for a specific op pattern in the
loopBody region, decides whether the loop is for-like, and uses the
appropriate input condition op (see the sketch after this list).
- To adapt the onnxLoopBody to the torchLoopBody, we adapt the input
block arguments and set the correct output condition variable in the
loop body.
- scanOutput variables are currently not supported.
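
A minimal sketch of the for-like pattern the lowering looks for, trimmed
from the lit test added below (value names and "..." elisions are
illustrative): the loop body's input condition only flows through
onnx.Identity into the terminator, so the resulting primLoopOp can be
given a constant-true input condition.

    // onnx.Loop body: cond_in -> onnx.Identity -> torch.operator_terminator
    ^bb0(%iter: !torch.vtensor<[],si64>, %cond_in: !torch.vtensor<[],i1>, %lcd: !torch.vtensor<[1],f32>):
      %cond_out = torch.operator "onnx.Identity"(%cond_in) : (!torch.vtensor<[],i1>) -> !torch.vtensor<[],i1>
      ...
      torch.operator_terminator %cond_out, %lcd_next : !torch.vtensor<[],i1>, !torch.vtensor<[1],f32>

    // resulting for-like loop: the input condition is the constant true
    %true = torch.constant.bool true
    %loop = torch.prim.Loop %max_trip_count_int, %true, init(%lcd_init) {
    ^bb0(%iter_int: !torch.int, %lcd_body: !torch.vtensor<[1],f32>):
      ...
      torch.prim.Loop.condition %true, iter(%lcd_next : !torch.vtensor<[1],f32>)
    } : (!torch.int, !torch.bool, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
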
Phaneesh Barwaria 2024-06-27 17:08:44 +05:30 committed by GitHub
parent 6678e1a256
commit 39d1332008
3 changed files with 226 additions and 1 deletions

@@ -209,6 +209,16 @@ struct OpBinder {
return success();
}
ParseResult tensorOperandTypes(llvm::SmallVector<mlir::Type> &typeList) {
for (auto operand : op->getOperands()) {
auto t = toValidTensorType(operand.getType());
if (!t)
return failure();
typeList.push_back(t);
}
return success();
}
// The importer imports Onnx.GraphProto attributes as regions attached to the
// op.
ParseResult getRegionAtIndex(mlir::Region *&region, int64_t idx) {

@@ -259,6 +259,159 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, operand);
return success();
});
patterns.onOp(
"Loop", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
// Get all operands (maxTripCount, cond, ....inits....)
llvm::SmallVector<Value> operands;
if (binder.tensorOperandsList(operands) || operands.size() == 0 ||
binder.getNumOperands() < 2) {
return rewriter.notifyMatchFailure(binder.op,
"Failed to get required operands");
}
llvm::SmallVector<mlir::Type> operandTypeVec;
if (binder.tensorOperandTypes(operandTypeVec) ||
operandTypeVec.size() == 0) {
return rewriter.notifyMatchFailure(binder.op,
"Failed to get operandTypes");
}
Region *loopBodyIn;
if (binder.getRegionAtIndex(loopBodyIn, 0)) {
return rewriter.notifyMatchFailure(binder.op,
"Failed getting LoopBody Region");
}
// MaxTripCount - tensor int64 scalar (or empty)
Value maxTripCountTensor = operands[0];
auto maxTripCountInt = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
maxTripCountTensor);
// Condition - tensor bool scalar (or empty)
Value conditionTensor = operands[1];
auto conditionInt = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
conditionTensor);
auto conditionBool = rewriter.create<Torch::AtenBoolIntOp>(
binder.getLoc(), rewriter.getType<Torch::BoolType>(), conditionInt);
// To be used for "for like" loop case
auto constBoolTrue = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getBoolAttr(true));
// Others (if present) - variadic (can be tensors and scalar values)
if (binder.getNumOperands() > 2) {
operandTypeVec.erase(operandTypeVec.begin(),
operandTypeVec.begin() + 2);
operands.erase(operands.begin(), operands.begin() + 2);
}
auto getOpName = [](Operation *op) -> std::string {
std::string name = op->getName().getStringRef().str();
if (name != "torch.operator")
return name;
// for unconverted onnx ops
return mlir::dyn_cast<StringAttr>(op->getAttr("name"))
.getValue()
.str();
};
// PrimLoop Op expects inputCondition to be boolConstantTrue
// to decide if the loopOp is `forlike`. Use loopIsForLike to
// ensure the appropriate inputCondition is set
// Case 1 : loopCondInp -> identity -> terminator(loopCondOut)
bool loopIsForLike = false;
auto case1ForLike = [&getOpName](Region *loopBody) -> bool {
Value onnxLoopBodyCondIn = loopBody->front().getArgument(1);
if (!onnxLoopBodyCondIn.hasOneUse())
return false;
Operation *inpCondUser = *onnxLoopBodyCondIn.getUsers().begin();
if (getOpName(inpCondUser) != "onnx.Identity") {
return false;
}
if (!inpCondUser->hasOneUse() ||
getOpName(*(inpCondUser->getUsers().begin())) !=
"torch.operator_terminator")
return false;
return true;
};
loopIsForLike = case1ForLike(loopBodyIn);
Value loopInitCondition =
loopIsForLike ? constBoolTrue : conditionBool.getResult();
auto loc = binder.getLoc();
mlir::ImplicitLocOpBuilder b(loc, rewriter);
auto loop = b.create<Torch::PrimLoopOp>(
TypeRange(operandTypeVec), maxTripCountInt, loopInitCondition,
ValueRange(operands));
rewriter.cloneRegionBefore(*loopBodyIn, loop.getRegion(),
loop.getRegion().begin());
// primLoopOp loopBody expects torch.int as first arg
// insert torch.int arg in loop body, convert to tensor,
// replace all uses of old arg, delete old arg.
auto loopVarArg = loop.getRegion().front().getArgument(0);
// insert new Arg
loop.getRegion().front().insertArgument(
0U, rewriter.getType<Torch::IntType>(), binder.getLoc());
auto newLoopVarArg = loop.getRegion().front().getArgument(0);
// convert int arg to tensor of original Type
rewriter.setInsertionPointToStart(&loop.getRegion().front());
Value loopVarVal = BlockArgument::Value(loopVarArg);
auto newTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
loop.getRegion().op_begin()->getLoc(), loopVarVal.getType(),
newLoopVarArg);
loopVarArg.replaceAllUsesWith(newTensor);
loop.getRegion().eraseArgument(1);
// primLoopOp loopBody has no condition arg
auto condArg = loop.getRegion().front().getArgument(1);
if (!condArg.use_empty())
condArg.replaceAllUsesWith(conditionTensor);
// replace terminator
PatternRewriter::InsertionGuard guard(rewriter);
Operation *terminator = loop.getRegion().front().getTerminator();
rewriter.setInsertionPoint(terminator);
// results - n loop carried dependencies and k scan outputs
// Fail when there are scanOutputs in onnxLoop (K>0);
// unsupported for now
if (terminator->getNumOperands() !=
loop.getRegion().getNumArguments() - 1) {
return rewriter.notifyMatchFailure(
binder.op, "scanOutputs in loop body unsupported");
}
// Get remaining operands from onnxLoopBody's terminator Op
// these are all the loop carried dependencies in the loop body
auto terminatorOperands = terminator->getOperands();
llvm::SmallVector<Value> remTerminatorOperands(
terminatorOperands.begin() + 1, terminatorOperands.end());
Value terminatorCond;
if (loopIsForLike) {
terminatorCond = constBoolTrue;
} else {
// Only use when loop is not forlike
Value terminatorCondTensor = terminatorOperands[0];
auto terminatorCondInt = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
terminatorCondTensor);
auto terminatorCondBool = rewriter.create<Torch::AtenBoolIntOp>(
binder.getLoc(), rewriter.getType<Torch::BoolType>(),
terminatorCondInt);
terminatorCond = terminatorCondBool.getResult();
}
rewriter.replaceOpWithNewOp<Torch::PrimLoopConditionOp>(
terminator, terminatorCond, remTerminatorOperands);
loop.getRegion().eraseArgument(1);
rewriter.replaceOp(binder.op, loop);
return success();
});
patterns.onOp("LSTM", 1, onnx_c::OnnxLstmExpander);
patterns.onOp(
"LogSoftmax", 13,
@@ -2197,7 +2350,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
return success();
});
patterns.onOp(
"Identity", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
"Identity", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value tensor;
if (binder.tensorOperand(tensor) ||

@@ -1652,3 +1652,65 @@ func.func @test_optional_has_element_list_tensor_input(%arg0: !torch.list<vtenso
%0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.list<vtensor<[4],f32>>) -> !torch.vtensor<[],i1>
return %0 : !torch.vtensor<[],i1>
}
// -----
// CHECK-LABEL: func.func @test_loop_forlike
func.func @test_loop_forlike(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],i1>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "loop_example", torch.onnx_meta.producer_version = ""} {
// CHECK-SAME: %[[MAX_TRIP_COUNT_INP:.*]]: !torch.vtensor<[],si64>,
// CHECK-SAME: %[[CONDITION_INP:.*]]: !torch.vtensor<[],i1>,
// CHECK-SAME: %[[LCD_1:.*]]: !torch.vtensor<[1],f32>
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[MAX_TRIP_COUNT_INT:.*]] = torch.aten.item %[[MAX_TRIP_COUNT_INP]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[CONDITION_INT:.*]] = torch.aten.item %[[CONDITION_INP]] : !torch.vtensor<[],i1> -> !torch.int
// CHECK: %[[CONDITION_BOOL:.*]] = torch.aten.Bool.int %[[CONDITION_INT]] : !torch.int -> !torch.bool
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[LOOP:.*]] = torch.prim.Loop %[[MAX_TRIP_COUNT_INT]], %[[TRUE]], init(%[[LCD_1]]) {
// CHECK: ^bb0(%[[ITER_NUM:.*]]: !torch.int, %[[LCD_1_BODY:.*]]: !torch.vtensor<[1],f32>):
// CHECK: %[[ITER_NUM_T:.*]] = torch.prim.NumToTensor.Scalar %[[ITER_NUM]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[CLONE_INP_COND:.*]] = torch.aten.clone %[[CONDITION_INP]], %[[NONE_1]] : !torch.vtensor<[],i1>, !torch.none -> !torch.vtensor<[],i1>
// CHECK: %[[CONST_ARR:.*]] = torch.vtensor.literal(dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<5xf32>) : !torch.vtensor<[5],f32>
// CHECK: %[[ONE_T:.*]] = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: %[[ONE_0:.*]] = torch.constant.int 1
// CHECK: %[[ADD_ONE_T:.*]] = torch.aten.add.Tensor %[[ITER_NUM_T]], %[[ONE_T]], %[[ONE_0]] : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ZERO_T:.*]] = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: %[[ZERO_0:.*]] = torch.constant.int 0
// CHECK: %[[ITER_NUM_RT:.*]] = torch.aten.unsqueeze %[[ITER_NUM_T]], %[[ZERO_0]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ZERO_1:.*]] = torch.constant.int 0
// CHECK: %[[ADD_ONE_RT:.*]] = torch.aten.unsqueeze %[[ADD_ONE_T]], %[[ZERO_1]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[NONE_2:.*]] = torch.constant.none
// CHECK: %[[ONE_1:.*]] = torch.constant.int 1
// CHECK: %[[ONE_SIZE_LIST:.*]] = torch.prim.ListConstruct %[[ONE_1]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[ONES_T:.*]] = torch.aten.ones %[[ONE_SIZE_LIST]], %[[NONE_2]], %[[NONE_2]], %[[NONE_2]], %[[NONE_2]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64>
// CHECK: %[[ZERO_2:.*]] = torch.constant.int 0
// CHECK: %[[ZERO_3:.*]] = torch.constant.int 0
// CHECK: %[[ZERO_T_1:.*]] = torch.prim.NumToTensor.Scalar %[[ZERO_3]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITER_NUM_INDEXED:.*]] = torch.aten.index_select %[[ITER_NUM_RT]], %[[ZERO_2]], %[[ZERO_T_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK: %[[ITER_NUM_INT:.*]] = torch.aten.item %[[ITER_NUM_INDEXED]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[INC_INDEXED:.*]] = torch.aten.index_select %[[ADD_ONE_RT]], %[[ZERO_2]], %[[ZERO_T_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK: %[[INC_INT:.*]] = torch.aten.item %[[INC_INDEXED]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[SLICE_INDEX_T:.*]] = torch.aten.index_select %[[ONES_T]], %[[ZERO_2]], %[[ZERO_T_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK: %[[INDEX_INT:.*]] = torch.aten.item %[[SLICE_INDEX_T]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[INPUT_SLICE:.*]] = torch.aten.slice.Tensor %[[CONST_ARR]], %[[ZERO_3]], %[[ITER_NUM_INT]], %[[INC_INT]], %[[INDEX_INT]] : !torch.vtensor<[5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: %[[ONE_2:.*]] = torch.constant.int 1
// CHECK: %[[INTERM_RES:.*]] = torch.aten.add.Tensor %[[LCD_1_BODY]], %[[INPUT_SLICE]], %[[ONE_2]] : !torch.vtensor<[1],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[INTERM_RES]] : !torch.vtensor<[1],f32>)
// CHECK: } : (!torch.int, !torch.bool, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
// CHECK: return %[[LOOP]] : !torch.vtensor<[1],f32>
%none = torch.constant.none
%0 = torch.operator "onnx.Loop"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si64>, !torch.vtensor<[],i1>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> {
^bb0(%arg3: !torch.vtensor<[],si64>, %arg4: !torch.vtensor<[],i1>, %arg5: !torch.vtensor<[1],f32>):
%1 = torch.operator "onnx.Identity"(%arg4) : (!torch.vtensor<[],i1>) -> !torch.vtensor<[],i1>
%2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<5xf32>} : () -> !torch.vtensor<[5],f32>
%3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<si64>} : () -> !torch.vtensor<[],si64>
%4 = torch.operator "onnx.Add"(%arg3, %3) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64>
%5 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<si64>} : () -> !torch.vtensor<[],si64>
%6 = torch.operator "onnx.Unsqueeze"(%arg3, %5) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[1],si64>
%7 = torch.operator "onnx.Unsqueeze"(%4, %5) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[1],si64>
%8 = torch.operator "onnx.Slice"(%2, %6, %7) : (!torch.vtensor<[5],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],f32>
%9 = torch.operator "onnx.Add"(%arg5, %8) : (!torch.vtensor<[1],f32>, !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32>
torch.operator_terminator %1, %9 : !torch.vtensor<[],i1>, !torch.vtensor<[1],f32>
}
return %0 : !torch.vtensor<[1],f32>
}