mirror of https://github.com/llvm/torch-mlir
add onnx loop support (#3408)
- Adds limited support for lowering onnx.Loop to primLoopOp - lower in the pipeline`torch-to-scf` there is a check to see if loop is for like. A primLoopOp is for like when the input condition is a `trueBoolConstant`. To adapt the onnx to torch lowering to take advantage of it, the implementation checks for specific op patterns in the loodBody region and decides if loop is for like and uses the right input condition op. - to adapt the onnxLoopBody to torchLoopBody, we need to adapt the input block arguments and set the correct output condition variable in the loop body. - scanOutput variables are currently not supported.pull/3508/head
parent
6678e1a256
commit
39d1332008
|
@ -209,6 +209,16 @@ struct OpBinder {
|
|||
return success();
|
||||
}
|
||||
|
||||
ParseResult tensorOperandTypes(llvm::SmallVector<mlir::Type> &typeList) {
|
||||
for (auto operand : op->getOperands()) {
|
||||
auto t = toValidTensorType(operand.getType());
|
||||
if (!t)
|
||||
return failure();
|
||||
typeList.push_back(t);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// The importer imports Onnx.GraphProto attributes as regions attached to the
|
||||
// op.
|
||||
ParseResult getRegionAtIndex(mlir::Region *®ion, int64_t idx) {
|
||||
|
|
|
@ -259,6 +259,159 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
binder.op, resultType, operand);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"Loop", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
// Get all operands (maxTripCount, cond, ....inits....)
|
||||
llvm::SmallVector<Value> operands;
|
||||
if (binder.tensorOperandsList(operands) || operands.size() == 0 ||
|
||||
binder.getNumOperands() < 2) {
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"Failed to get required operands");
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::Type> operandTypeVec;
|
||||
if (binder.tensorOperandTypes(operandTypeVec) ||
|
||||
operandTypeVec.size() == 0) {
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"Failed to get operandTypes");
|
||||
}
|
||||
|
||||
Region *loopBodyIn;
|
||||
if (binder.getRegionAtIndex(loopBodyIn, 0)) {
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"Failed getting LoopBody Region");
|
||||
}
|
||||
|
||||
// MaxTripCount - tensor int64 scalar (or empty)
|
||||
Value maxTripCountTensor = operands[0];
|
||||
auto maxTripCountInt = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
maxTripCountTensor);
|
||||
|
||||
// Condition - tensor bool scalar (or empty)
|
||||
Value conditionTensor = operands[1];
|
||||
auto conditionInt = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
conditionTensor);
|
||||
auto conditionBool = rewriter.create<Torch::AtenBoolIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::BoolType>(), conditionInt);
|
||||
// To be used for "for like" loop case
|
||||
auto constBoolTrue = rewriter.create<Torch::ConstantBoolOp>(
|
||||
binder.getLoc(), rewriter.getBoolAttr(true));
|
||||
|
||||
// Others (if present) - variadic (can be tensors and scalar values)
|
||||
if (binder.getNumOperands() > 2) {
|
||||
operandTypeVec.erase(operandTypeVec.begin(),
|
||||
operandTypeVec.begin() + 2);
|
||||
operands.erase(operands.begin(), operands.begin() + 2);
|
||||
}
|
||||
|
||||
auto getOpName = [](Operation *op) -> std::string {
|
||||
std::string name = op->getName().getStringRef().str();
|
||||
if (name != "torch.operator")
|
||||
return name;
|
||||
// for unconverted onnx ops
|
||||
return mlir::dyn_cast<StringAttr>(op->getAttr("name"))
|
||||
.getValue()
|
||||
.str();
|
||||
};
|
||||
|
||||
// PrimLoop Op expectes inputCondition to be boolConstantTrue
|
||||
// to decide if the loopOp is `forlike`. Use loopIsForLike to
|
||||
// ensure appropriate inputCondition is set
|
||||
// Case 1 : loopCondInp -> identity -> terminator(loopCondOut)
|
||||
bool loopIsForLike = false;
|
||||
auto case1ForLike = [&getOpName](Region *loopBody) -> bool {
|
||||
Value onnxLoopBodyCondIn = loopBody->front().getArgument(1);
|
||||
if (!onnxLoopBodyCondIn.hasOneUse())
|
||||
return false;
|
||||
Operation *inpCondUser = *onnxLoopBodyCondIn.getUsers().begin();
|
||||
if (getOpName(inpCondUser) != "onnx.Identity") {
|
||||
return false;
|
||||
}
|
||||
if (!inpCondUser->hasOneUse() ||
|
||||
getOpName(*(inpCondUser->getUsers().begin())) !=
|
||||
"torch.operator_terminator")
|
||||
return false;
|
||||
return true;
|
||||
};
|
||||
loopIsForLike = case1ForLike(loopBodyIn);
|
||||
|
||||
Value loopInitCondition =
|
||||
loopIsForLike ? constBoolTrue : conditionBool.getResult();
|
||||
auto loc = binder.getLoc();
|
||||
mlir::ImplicitLocOpBuilder b(loc, rewriter);
|
||||
auto loop = b.create<Torch::PrimLoopOp>(
|
||||
TypeRange(operandTypeVec), maxTripCountInt, loopInitCondition,
|
||||
ValueRange(operands));
|
||||
|
||||
rewriter.cloneRegionBefore(*loopBodyIn, loop.getRegion(),
|
||||
loop.getRegion().begin());
|
||||
|
||||
// primLoopOp loopBody expects torch.int as first arg
|
||||
// insert torch.int arg in loop body, convert to tensor,
|
||||
// replace all uses of old arg, delete old arg.
|
||||
auto loopVarArg = loop.getRegion().front().getArgument(0);
|
||||
// insert new Arg
|
||||
loop.getRegion().front().insertArgument(
|
||||
0U, rewriter.getType<Torch::IntType>(), binder.getLoc());
|
||||
auto newLoopVarArg = loop.getRegion().front().getArgument(0);
|
||||
|
||||
// convert int arg to tensor of original Type
|
||||
rewriter.setInsertionPointToStart(&loop.getRegion().front());
|
||||
Value loopVarVal = BlockArgument::Value(loopVarArg);
|
||||
auto newTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
|
||||
loop.getRegion().op_begin()->getLoc(), loopVarVal.getType(),
|
||||
newLoopVarArg);
|
||||
|
||||
loopVarArg.replaceAllUsesWith(newTensor);
|
||||
loop.getRegion().eraseArgument(1);
|
||||
|
||||
// primLoopOp loopBody has no condition arg
|
||||
auto condArg = loop.getRegion().front().getArgument(1);
|
||||
if (!condArg.use_empty())
|
||||
condArg.replaceAllUsesWith(conditionTensor);
|
||||
|
||||
// replace terminator
|
||||
PatternRewriter::InsertionGuard guard(rewriter);
|
||||
Operation *terminator = loop.getRegion().front().getTerminator();
|
||||
rewriter.setInsertionPoint(terminator);
|
||||
|
||||
// results - n loop carried dependencies and k scan outputs
|
||||
// Fail when there are scanOutputs in onnxLoop (K>0);
|
||||
// unsupported for now
|
||||
if (terminator->getNumOperands() !=
|
||||
loop.getRegion().getNumArguments() - 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "scanOutputs in loop body unsupported");
|
||||
}
|
||||
|
||||
// Get remaining operands from onnxLoopBody's terminator Op
|
||||
// these are all the loop carried dependencies in the loop body
|
||||
auto terminatorOperands = terminator->getOperands();
|
||||
llvm::SmallVector<Value> remTerminatorOperands(
|
||||
terminatorOperands.begin() + 1, terminatorOperands.end());
|
||||
Value terminatorCond;
|
||||
if (loopIsForLike) {
|
||||
terminatorCond = constBoolTrue;
|
||||
} else {
|
||||
// Only use when loop is not forlike
|
||||
Value terminatorCondTensor = terminatorOperands[0];
|
||||
auto terminatorCondInt = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
terminatorCondTensor);
|
||||
auto terminatorCondBool = rewriter.create<Torch::AtenBoolIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::BoolType>(),
|
||||
terminatorCondInt);
|
||||
terminatorCond = terminatorCondBool.getResult();
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<Torch::PrimLoopConditionOp>(
|
||||
terminator, terminatorCond, remTerminatorOperands);
|
||||
|
||||
loop.getRegion().eraseArgument(1);
|
||||
rewriter.replaceOp(binder.op, loop);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp("LSTM", 1, onnx_c::OnnxLstmExpander);
|
||||
patterns.onOp(
|
||||
"LogSoftmax", 13,
|
||||
|
@ -2197,7 +2350,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"Identity", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
"Identity", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value tensor;
|
||||
if (binder.tensorOperand(tensor) ||
|
||||
|
|
|
@ -1652,3 +1652,65 @@ func.func @test_optional_has_element_list_tensor_input(%arg0: !torch.list<vtenso
|
|||
%0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.list<vtensor<[4],f32>>) -> !torch.vtensor<[],i1>
|
||||
return %0 : !torch.vtensor<[],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_loop_forlike
|
||||
func.func @test_loop_forlike(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],i1>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "loop_example", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-SAME: %[[MAX_TRIP_COUNT_INP:.*]]: !torch.vtensor<[],si64>,
|
||||
// CHECK-SAME: %[[CONDITION_INP:.*]]: !torch.vtensor<[],i1>,
|
||||
// CHECK-SAME: %[[LCD_1:.*]]: !torch.vtensor<[1],f32>
|
||||
// CHECK: %[[NONE_0:.*]] = torch.constant.none
|
||||
// CHECK: %[[MAX_TRIP_COUNT_INT:.*]] = torch.aten.item %[[MAX_TRIP_COUNT_INP]] : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[CONDITION_INT:.*]] = torch.aten.item %[[CONDITION_INP]] : !torch.vtensor<[],i1> -> !torch.int
|
||||
// CHECK: %[[CONDITION_BOOL:.*]] = torch.aten.Bool.int %[[CONDITION_INT]] : !torch.int -> !torch.bool
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[LOOP:.*]] = torch.prim.Loop %[[MAX_TRIP_COUNT_INT]], %[[TRUE]], init(%[[LCD_1]]) {
|
||||
// CHECK: ^bb0(%[[ITER_NUM:.*]]: !torch.int, %[[LCD_1_BODY:.*]]: !torch.vtensor<[1],f32>):
|
||||
// CHECK: %[[ITER_NUM_T:.*]] = torch.prim.NumToTensor.Scalar %[[ITER_NUM]] : !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: %[[NONE_1:.*]] = torch.constant.none
|
||||
// CHECK: %[[CLONE_INP_COND:.*]] = torch.aten.clone %[[CONDITION_INP]], %[[NONE_1]] : !torch.vtensor<[],i1>, !torch.none -> !torch.vtensor<[],i1>
|
||||
// CHECK: %[[CONST_ARR:.*]] = torch.vtensor.literal(dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<5xf32>) : !torch.vtensor<[5],f32>
|
||||
// CHECK: %[[ONE_T:.*]] = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||
// CHECK: %[[ONE_0:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[ADD_ONE_T:.*]] = torch.aten.add.Tensor %[[ITER_NUM_T]], %[[ONE_T]], %[[ONE_0]] : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: %[[ZERO_T:.*]] = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||
// CHECK: %[[ZERO_0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[ITER_NUM_RT:.*]] = torch.aten.unsqueeze %[[ITER_NUM_T]], %[[ZERO_0]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ZERO_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[ADD_ONE_RT:.*]] = torch.aten.unsqueeze %[[ADD_ONE_T]], %[[ZERO_1]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[NONE_2:.*]] = torch.constant.none
|
||||
// CHECK: %[[ONE_1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[ONE_SIZE_LIST:.*]] = torch.prim.ListConstruct %[[ONE_1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[ONES_T:.*]] = torch.aten.ones %[[ONE_SIZE_LIST]], %[[NONE_2]], %[[NONE_2]], %[[NONE_2]], %[[NONE_2]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ZERO_2:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[ZERO_3:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[ZERO_T_1:.*]] = torch.prim.NumToTensor.Scalar %[[ZERO_3]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITER_NUM_INDEXED:.*]] = torch.aten.index_select %[[ITER_NUM_RT]], %[[ZERO_2]], %[[ZERO_T_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITER_NUM_INT:.*]] = torch.aten.item %[[ITER_NUM_INDEXED]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[INC_INDEXED:.*]] = torch.aten.index_select %[[ADD_ONE_RT]], %[[ZERO_2]], %[[ZERO_T_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[INC_INT:.*]] = torch.aten.item %[[INC_INDEXED]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[SLICE_INDEX_T:.*]] = torch.aten.index_select %[[ONES_T]], %[[ZERO_2]], %[[ZERO_T_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[INDEX_INT:.*]] = torch.aten.item %[[SLICE_INDEX_T]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[INPUT_SLICE:.*]] = torch.aten.slice.Tensor %[[CONST_ARR]], %[[ZERO_3]], %[[ITER_NUM_INT]], %[[INC_INT]], %[[INDEX_INT]] : !torch.vtensor<[5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
|
||||
// CHECK: %[[ONE_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[INTERM_RES:.*]] = torch.aten.add.Tensor %[[LCD_1_BODY]], %[[INPUT_SLICE]], %[[ONE_2]] : !torch.vtensor<[1],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[1],f32>
|
||||
// CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[INTERM_RES]] : !torch.vtensor<[1],f32>)
|
||||
// CHECK: } : (!torch.int, !torch.bool, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
|
||||
// CHECK: return %[[LOOP]] : !torch.vtensor<[1],f32>
|
||||
%none = torch.constant.none
|
||||
%0 = torch.operator "onnx.Loop"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si64>, !torch.vtensor<[],i1>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> {
|
||||
^bb0(%arg3: !torch.vtensor<[],si64>, %arg4: !torch.vtensor<[],i1>, %arg5: !torch.vtensor<[1],f32>):
|
||||
%1 = torch.operator "onnx.Identity"(%arg4) : (!torch.vtensor<[],i1>) -> !torch.vtensor<[],i1>
|
||||
%2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<5xf32>} : () -> !torch.vtensor<[5],f32>
|
||||
%3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<si64>} : () -> !torch.vtensor<[],si64>
|
||||
%4 = torch.operator "onnx.Add"(%arg3, %3) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64>
|
||||
%5 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<si64>} : () -> !torch.vtensor<[],si64>
|
||||
%6 = torch.operator "onnx.Unsqueeze"(%arg3, %5) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[1],si64>
|
||||
%7 = torch.operator "onnx.Unsqueeze"(%4, %5) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[1],si64>
|
||||
%8 = torch.operator "onnx.Slice"(%2, %6, %7) : (!torch.vtensor<[5],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],f32>
|
||||
%9 = torch.operator "onnx.Add"(%arg5, %8) : (!torch.vtensor<[1],f32>, !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32>
|
||||
torch.operator_terminator %1, %9 : !torch.vtensor<[],i1>, !torch.vtensor<[1],f32>
|
||||
}
|
||||
return %0 : !torch.vtensor<[1],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue