mirror of https://github.com/llvm/torch-mlir
Add support for the onnx.SequenceLength op. (#3362)
parent
2937753070
commit
513d89c16d
|
@ -97,6 +97,19 @@ struct OpBinder {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Operand matches of different arities.
|
||||||
|
ParseResult tensorListOperand(Value &value0) {
|
||||||
|
if (op->getNumOperands() != 1)
|
||||||
|
return failure();
|
||||||
|
value0 = op->getOperand(0);
|
||||||
|
auto tt = dyn_cast<Torch::ListType>(value0.getType());
|
||||||
|
if (!tt)
|
||||||
|
return failure();
|
||||||
|
if (!toValidTensorType(tt.getContainedType()))
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
ParseResult tensorListResultType(Torch::ListType &type0) {
|
ParseResult tensorListResultType(Torch::ListType &type0) {
|
||||||
if (op->getNumResults() != 1)
|
if (op->getNumResults() != 1)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
|
@ -532,6 +532,30 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
binder.op, resultType, operands);
|
binder.op, resultType, operands);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"SequenceLength", 11,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
// onnx.SequenceLength takes a sequence(list) of tensors, and returns
|
||||||
|
// a zero rank tensor with the length.
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value x;
|
||||||
|
if (binder.tensorListOperand(x) || binder.tensorResultType(resultType))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value cstFalse =
|
||||||
|
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
||||||
|
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||||
|
|
||||||
|
Value len = rewriter.create<Torch::AtenLenTOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(), x);
|
||||||
|
|
||||||
|
// AtenLenTOp returns a torch.int, so we have to
|
||||||
|
// put that in a tensor.
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenTensorIntOp>(
|
||||||
|
binder.op, resultType, len, none, none, cstFalse);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
});
|
||||||
patterns.onOp(
|
patterns.onOp(
|
||||||
"Sigmoid", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
"Sigmoid", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
Torch::ValueTensorType resultType;
|
Torch::ValueTensorType resultType;
|
||||||
|
|
|
@ -2099,6 +2099,21 @@ module {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_sequence_length
|
||||||
|
module {
|
||||||
|
func.func @test_sequence_length(%arg0: !torch.list<vtensor<[?,?,?],f32>>) -> !torch.vtensor<[],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[LEN:.+]] = torch.aten.len.t %arg0 : !torch.list<vtensor<[?,?,?],f32>> -> !torch.int
|
||||||
|
// CHECK: %[[LEN_AS_TEN:.+]] = torch.aten.tensor.int %[[LEN]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: return %[[LEN_AS_TEN]] : !torch.vtensor<[],si64>
|
||||||
|
%0 = torch.operator "onnx.SequenceLength"(%arg0) : (!torch.list<vtensor<[?,?,?],f32>>) -> !torch.vtensor<[],si64>
|
||||||
|
return %0 : !torch.vtensor<[],si64>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_sce_mean_3d
|
// CHECK-LABEL: func.func @test_sce_mean_3d
|
||||||
func.func @test_sce_mean_3d(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_sce_mean_3d(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
|
Loading…
Reference in New Issue