mirror of https://github.com/llvm/torch-mlir
Add support for the onnx.SequenceLength op. (#3362)
parent 2937753070
commit 513d89c16d
@@ -97,6 +97,19 @@ struct OpBinder {
    return success();
  }

  // Operand matches of different arities.
  ParseResult tensorListOperand(Value &value0) {
    if (op->getNumOperands() != 1)
      return failure();
    value0 = op->getOperand(0);
    auto tt = dyn_cast<Torch::ListType>(value0.getType());
    if (!tt)
      return failure();
    if (!toValidTensorType(tt.getContainedType()))
      return failure();
    return success();
  }

  ParseResult tensorListResultType(Torch::ListType &type0) {
    if (op->getNumResults() != 1)
      return failure();
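The hunk is cut off partway through the second helper. By analogy with tensorListOperand above, the rest of tensorListResultType presumably inspects the single result the same way; the lines below are only a sketch of that continuation under that assumption, not the verbatim upstream body.

    // Sketch (not upstream code): mirror tensorListOperand, but look at the
    // single result's type instead of the single operand's type.
    auto tt = dyn_cast<Torch::ListType>(op->getResult(0).getType());
    if (!tt)
      return failure();
    // The element type held by the list must still be a valid tensor type.
    if (!toValidTensorType(tt.getContainedType()))
      return failure();
    type0 = tt;
    return success();
  }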
@@ -532,6 +532,30 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
            binder.op, resultType, operands);
        return success();
      });
  patterns.onOp(
      "SequenceLength", 11,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        // onnx.SequenceLength takes a sequence (list) of tensors and returns
        // a zero-rank tensor holding its length.
        Torch::ValueTensorType resultType;
        Value x;
        if (binder.tensorListOperand(x) || binder.tensorResultType(resultType))
          return failure();

        Value cstFalse =
            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

        Value len = rewriter.create<Torch::AtenLenTOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(), x);

        // AtenLenTOp returns a plain torch.int, so wrap that value in a
        // tensor to match the expected zero-rank result.
        rewriter.replaceOpWithNewOp<Torch::AtenTensorIntOp>(
            binder.op, resultType, len, none, none, cstFalse);

        return success();
      });
  patterns.onOp(
      "Sigmoid", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
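A note on the design choice: ONNX SequenceLength yields a zero-rank int64 tensor, while torch.aten.len.t produces a bare !torch.int, so the length has to be re-wrapped via torch.aten.tensor.int. The trailing operands handed to AtenTensorIntOp appear to line up with the aten::tensor.int schema (t, dtype, device, requires_grad); the restatement below is illustration only and assumes that ordering.

        // Same call as in the pattern above, annotated with the parameters it
        // fills in aten::tensor.int(int t, ScalarType? dtype=None,
        //                           Device? device=None, bool requires_grad=False):
        rewriter.replaceOpWithNewOp<Torch::AtenTensorIntOp>(
            binder.op, resultType,
            /*t=*/len,                   // the list length computed by AtenLenTOp
            /*dtype=*/none,              // default dtype; the si64 result type constrains it
            /*device=*/none,             // no explicit device
            /*requires_grad=*/cstFalse); // the constant false created above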
@@ -2099,6 +2099,21 @@ module {

// -----

// CHECK-LABEL: func.func @test_sequence_length
module {
  func.func @test_sequence_length(%arg0: !torch.list<vtensor<[?,?,?],f32>>) -> !torch.vtensor<[],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
    // CHECK: %[[FALSE:.+]] = torch.constant.bool false
    // CHECK: %[[NONE:.+]] = torch.constant.none
    // CHECK: %[[LEN:.+]] = torch.aten.len.t %arg0 : !torch.list<vtensor<[?,?,?],f32>> -> !torch.int
    // CHECK: %[[LEN_AS_TEN:.+]] = torch.aten.tensor.int %[[LEN]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
    // CHECK: return %[[LEN_AS_TEN]] : !torch.vtensor<[],si64>
    %0 = torch.operator "onnx.SequenceLength"(%arg0) : (!torch.list<vtensor<[?,?,?],f32>>) -> !torch.vtensor<[],si64>
    return %0 : !torch.vtensor<[],si64>
  }
}

// -----

// CHECK-LABEL: func.func @test_sce_mean_3d
func.func @test_sce_mean_3d(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[NONE:.+]] = torch.constant.none
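For context on the test hunk: the // ----- separators split the file into independent cases for split-input-file mode, CHECK-LABEL anchors FileCheck at the lowered function, and captures like %[[LEN:.+]] bind SSA value names so later CHECK lines can refer back to them. The file's RUN line (not visible in this hunk) drives the ONNX-to-Torch conversion through torch-mlir-opt and pipes the output into FileCheck, so the case can be exercised locally with llvm-lit on the containing test file (likely test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir, though that path is an assumption) from a torch-mlir build tree.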