From 1778314620b796de7a7aba61f00396cecbd29a0b Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Wed, 3 Jan 2024 12:52:59 -0500 Subject: [PATCH] add basic cumsum. this doesn't support the exclusive and reverse attrs (#2717) fixes #2711 --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 56 +++++++++++++++++++ .../unsupported_fb_opt_ops.mlir | 9 +++ 2 files changed, 65 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 3d2cf8aae..86f23bee1 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -836,6 +836,62 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); + patterns.onOp( + "CumSum", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); + Torch::ValueTensorType resultType; + Value operand; + Value axisTensor; + if (binder.tensorOperands(operand, axisTensor) || + binder.tensorResultType(resultType)) + return failure(); + + int64_t exclusive; + int64_t reverse; + // if bind succeeds and either is set, fail because not implemented + if (binder.s64IntegerAttr(exclusive, "exclusive", 0)) + if (exclusive != 0) + return rewriter.notifyMatchFailure( + binder.op, "unsupported onnx.CumSum conversion: exclusive"); + if (binder.s64IntegerAttr(reverse, "reverse", 0)) + if (reverse != 0) + return rewriter.notifyMatchFailure( + binder.op, "unsupported onnx.CumSum conversion: reverse"); + + // deal with neg axis: if (axis < 0) axis += rank + int64_t rank = + cast(operand.getType()).getSizes().size(); + Value rankVal = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + rank)); + Value zero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + + Value axisScalar = rewriter.create( + binder.getLoc(), rewriter.getType(), axisTensor); + Value isNegative = + rewriter.create(binder.getLoc(), axisScalar, zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, rankVal); + Value dim = rewriter.create( + binder.getLoc(), axisScalar, finalOffset); + + Torch::BaseTensorType resultTensorType = resultType.cast(); + if (!resultTensorType.hasDtype()) { + return rewriter.notifyMatchFailure( + binder.op, "expected result type to have a dtype"); + } + // resultTensorType.print(llvm::outs()); + Value resultDType = + Torch::getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype()); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, dim, resultDType); + return success(); + }); patterns.onOp("Div", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir index 3ed9f1c6e..6659935ff 100644 --- a/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir +++ b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir @@ -36,4 +36,13 @@ func.func @reduce_mean_operation(%arg0: !torch.vtensor<[1,64,768],f32>) // The ReduceMean operation as provided. %211 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.axes = [-1 : si64]} : (!torch.vtensor<[1,64,768],f32>) -> !torch.vtensor<[1,64,1],f32> return %211 : !torch.vtensor<[1,64,1],f32> +} + +// ----- +// Fixed. +func.func @cumsum_operation(%arg0: !torch.vtensor<[2,3],f64>, + %arg1: !torch.vtensor<[],si32>) + -> !torch.vtensor<[2,3],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %212 = torch.operator "onnx.CumSum"(%arg0, %arg1) : (!torch.vtensor<[2,3],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64> + return %212 : !torch.vtensor<[2,3],f64> } \ No newline at end of file