diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
index d4ace352a..919146c6a 100644
--- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
+++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
@@ -38,6 +38,13 @@ Value createConstantIntList(OpBinder binder,
 
 Type getQTorchTypeFromTorchIntType(Type ty);
 
+template <typename T>
+Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
+                Value &ofItem) {
+  return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
+                                            rewriter.getType<T>(), ofItem);
+}
+
 LogicalResult OnnxLstmExpander(OpBinder binder,
                                ConversionPatternRewriter &rewriter);
 
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 716ea3d6e..bd5c57fac 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -2240,4 +2240,126 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.op, resultType, cstEquation, tensorList, /*path=*/cstNone);
         return success();
       });
+  patterns.onOp(
+      "BlackmanWindow", 17,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Value size;
+        Torch::ValueTensorType resultType;
+        int64_t periodic, output_datatype;
+        if (binder.tensorOperand(size) ||
+            binder.s64IntegerAttr(output_datatype, "output_datatype", 1) ||
+            binder.s64IntegerAttr(periodic, "periodic", 1) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+        double isPeriodicFp = static_cast<double>(periodic);
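+        // Blackman window coefficients, per the ONNX definition:
+        //   w[n] = 0.42 - 0.5 * cos(2*pi*n/N) + 0.08 * cos(4*pi*n/N)
+        // a1 is stored as the negative half so the first cosine term can be
+        // added rather than subtracted (see the comment above the
+        // AtenAddScalarOp use below).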
+        Value a0 = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), 0.42));
+        Value a1 = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
+        Value a2 = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), 0.08));
+        Value zero = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getF64FloatAttr(0.0));
+        Value one = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getF64FloatAttr(1.0));
+        Value two = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getF64FloatAttr(2.0));
+
+        constexpr double pi = llvm::numbers::pi;
+        Value tau = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));
+
+        Value noneVal =
+            rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        Value cstFalse =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        Value float32Type = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(/*float32Type*/ 6));
+
+        // Create an f32 ValueTensorType with the same shape as size, the
+        // operand.
+        auto shapeOfOperand = size.getType()
+                                  .dyn_cast<Torch::ValueTensorType>()
+                                  .getOptionalSizes();
+        auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
+            shapeOfOperand, rewriter.getF32Type());
+        Value periodicSizeFloat = rewriter.create<Torch::AtenToDtypeOp>(
+            binder.getLoc(), f32ResultType, size, float32Type, cstFalse,
+            cstFalse, noneVal);
+        Value symmetricSizeFloat = rewriter.create<Torch::AtenSubScalarOp>(
+            binder.getLoc(), periodicSizeFloat.getType(), periodicSizeFloat,
+            one, one);
+
+        // Blend the two candidate periods: a periodic window uses N = size,
+        // a symmetric one uses N = size - 1.
+        Value isPeriodic = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getF64FloatAttr(isPeriodicFp));
+        Value isSymmetricFloat = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getF64FloatAttr(1.0 - isPeriodicFp));
+
+        Value periodicComponent = rewriter.create<Torch::AtenMulScalarOp>(
+            binder.getLoc(), periodicSizeFloat.getType(), periodicSizeFloat,
+            isPeriodic);
+        Value symmetricComponent = rewriter.create<Torch::AtenMulScalarOp>(
+            binder.getLoc(), symmetricSizeFloat.getType(), symmetricSizeFloat,
+            isSymmetricFloat);
+        Value sizeFloat = rewriter.create<Torch::AtenAddTensorOp>(
+            binder.getLoc(), symmetricComponent.getType(), symmetricComponent,
+            periodicComponent, one);
+
+        // Here, size can be used in place of periodicSizeFloat, as the
+        // latter is just a float representation of the former.
+        Value scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, size);
+
+        Value rangeArr = rewriter.create<Torch::AtenArangeStartStepOp>(
+            binder.getLoc(), resultType, zero, scalarLimit, one, noneVal,
+            noneVal, noneVal, noneVal);
+
+        Value rangeTimesTau = rewriter.create<Torch::AtenMulScalarOp>(
+            binder.getLoc(), resultType, rangeArr, tau);
+        Value rangeAngular = rewriter.create<Torch::AtenDivTensorOp>(
+            binder.getLoc(), resultType, rangeTimesTau, sizeFloat);
+        Value twoRangeAngular = rewriter.create<Torch::AtenMulScalarOp>(
+            binder.getLoc(), resultType, rangeAngular, two);
+
+        Value cosRangeAngular = rewriter.create<Torch::AtenCosOp>(
+            binder.getLoc(), resultType, rangeAngular);
+        Value cosTwoRangeAngular = rewriter.create<Torch::AtenCosOp>(
+            binder.getLoc(), resultType, twoRangeAngular);
+
+        Value a1Component = rewriter.create<Torch::AtenMulScalarOp>(
+            binder.getLoc(), resultType, cosRangeAngular, a1);
+        Value a2Component = rewriter.create<Torch::AtenMulScalarOp>(
+            binder.getLoc(), resultType, cosTwoRangeAngular, a2);
+
+        // AtenSubScalarOp requires its first operand, the LHS, to be a
+        // tensor. Therefore, the ONNX formulation has been adjusted: a1 is
+        // the negative half, and AtenSubScalarOp is replaced with
+        // AtenAddScalarOp, which is valid because addition is commutative.
+        Value subA1Component = rewriter.create<Torch::AtenAddScalarOp>(
+            binder.getLoc(), resultType, a1Component, a0, one);
+        Value result = rewriter.create<Torch::AtenAddTensorOp>(
+            binder.getLoc(), resultType, subA1Component, a2Component, one);
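+        // Map the ONNX dtype attribute to the corresponding torch dtype
+        // integer (the default output_datatype of 1, ONNX FLOAT, maps to
+        // torch's 6, f32).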
+        std::optional<int64_t> dtypeIntTorch =
+            onnxDtypeIntToTorchDtypeInt(output_datatype);
+        if (!dtypeIntTorch.has_value()) {
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "unimplemented support for the given dtype conversion");
+        }
+        Value outputDtype = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                    dtypeIntTorch.value()));
+
+        rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
+            binder.op, resultType, result, outputDtype,
+            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+            /*memory_format=*/noneVal);
+        return success();
+      });
 }
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 197d9c536..5f9da3faa 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -31,15 +31,7 @@ using namespace mlir::torch::onnx_c;
 // thing here, so we simplify.
 
 // utilities
-// Templatized function to get an item op of a type
 namespace {
-template <typename T>
-Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
-                Value &ofItem) {
-  return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
-                                            rewriter.getType<T>(), ofItem);
-}
-
 // In case the ReduceSum Op was not the first operation performed on the data,
 // we provide the original operand through storeResult, which will be modified
 // if the result will be passed onto another operation, and will be used for
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
index eb2cde696..a068acbf2 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -2035,3 +2035,81 @@ func.func @test_eyelike_dynamic(%arg0: !torch.vtensor<[3,?],f32>) -> !torch.vten
   %0 = torch.operator "onnx.EyeLike"(%arg0) {torch.onnx.k = -1 : si64} : (!torch.vtensor<[3,?],f32>) -> !torch.vtensor<[3,?],f32>
   return %0 : !torch.vtensor<[3,?],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_blackmanwindow_symmetric
+func.func @test_blackmanwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-DAG: %[[A0:.+]] = torch.constant.float 4.200000e-01
+  // CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
+  // CHECK-DAG: %[[A2:.+]] = torch.constant.float 8.000000e-02
+  // CHECK-DAG: %[[FLOAT0_0:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[FLOAT2:.+]] = torch.constant.float 2.000000e+00
+  // CHECK-DAG: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[INT6:.+]] = torch.constant.int 6
+  // CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[FLOAT1]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[FLOAT0_0]], %[[RANGELIM]], %[[FLOAT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[FLOAT2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[COSTWORANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[COSTWORANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK: return %[[CAST]] : !torch.vtensor<[10],f32>
+  %0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
+  return %0 : !torch.vtensor<[10],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_blackmanwindow
+func.func @test_blackmanwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-DAG: %[[A0:.+]] = torch.constant.float 4.200000e-01
+  // CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
+  // CHECK-DAG: %[[A2:.+]] = torch.constant.float 8.000000e-02
+  // CHECK-DAG: %[[FLOAT0_0:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[FLOAT2:.+]] = torch.constant.float 2.000000e+00
+  // CHECK-DAG: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[INT6:.+]] = torch.constant.int 6
+  // CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[FLOAT1]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[FLOAT0_0]], %[[RANGELIM]], %[[FLOAT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[FLOAT2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[COSTWORANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[COSTWORANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK: return %[[CAST]] : !torch.vtensor<[10],f32>
+  %0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
+  return %0 : !torch.vtensor<[10],f32>
+}
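
For reference, a minimal standalone sketch of the arithmetic this lowering emits (illustration only, not part of the patch; blackmanWindowRef is a hypothetical name): the effective period N is size for a periodic window and size - 1 for a symmetric one, and each element is a0 + a1 * cos(2*pi*n/N) + a2 * cos(4*pi*n/N) with a0 = 0.42, a1 = -0.5, a2 = 0.08.

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Hypothetical reference implementation mirroring the emitted op graph.
    std::vector<float> blackmanWindowRef(int64_t size, bool periodic) {
      constexpr double pi = 3.14159265358979323846;
      const double tau = 2.0 * pi;
      // Mirrors sizeFloat above: periodic blends in size, symmetric blends
      // in size - 1.
      const double n = periodic ? static_cast<double>(size)
                                : static_cast<double>(size - 1);
      std::vector<float> out(size);
      for (int64_t i = 0; i < size; ++i) {
        const double angular = tau * i / n;                // rangeAngular
        const double w = 0.42 - 0.5 * std::cos(angular)    // a0 + a1 * cos(x)
                         + 0.08 * std::cos(2.0 * angular); // + a2 * cos(2x)
        out[i] = static_cast<float>(w);
      }
      return out;
    }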