mirror of https://github.com/llvm/torch-mlir
[torch] Add OnnxToTorch lowering for `onnx.HammingWindow` (#3283)
Adds OnnxToTorch lowering for the `onnx.HammingWindow` op.pull/3296/head
parent
e60160d793
commit
6f911ba3d7
|
@ -2300,15 +2300,14 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
binder.tensorResultType(resultType)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Location loc = binder.getLoc();
|
||||
Value a0 = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getFloatAttr(rewriter.getF64Type(), 0.42));
|
||||
loc, rewriter.getF64FloatAttr(0.42));
|
||||
Value a1 = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
|
||||
loc, rewriter.getF64FloatAttr(-0.5));
|
||||
Value a2 = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getFloatAttr(rewriter.getF64Type(), 0.08));
|
||||
loc, rewriter.getF64FloatAttr(0.08));
|
||||
|
||||
auto windowFunctionResult =
|
||||
windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
|
||||
|
@ -2332,13 +2331,45 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
binder.tensorResultType(resultType)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Location loc = binder.getLoc();
|
||||
Value a0 = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.5));
|
||||
loc, rewriter.getF64FloatAttr(0.5));
|
||||
Value a1 = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
|
||||
loc, rewriter.getF64FloatAttr(-0.5));
|
||||
Value a2 = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.0));
|
||||
loc, rewriter.getF64FloatAttr(0.0));
|
||||
|
||||
auto windowFunctionResult =
|
||||
windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
|
||||
output_datatype, periodic);
|
||||
|
||||
if (failed(windowFunctionResult))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
});
|
||||
|
||||
patterns.onOp(
|
||||
"HammingWindow", 17,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Value size;
|
||||
Torch::ValueTensorType resultType;
|
||||
int64_t periodic, output_datatype;
|
||||
if (binder.tensorOperand(size) ||
|
||||
binder.s64IntegerAttr(output_datatype, "output_datatype", 1) ||
|
||||
binder.s64IntegerAttr(periodic, "periodic", 1) ||
|
||||
binder.tensorResultType(resultType)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Location loc = binder.getLoc();
|
||||
Value a0 = rewriter.create<Torch::ConstantFloatOp>(
|
||||
loc, rewriter.getF64FloatAttr(0.543478));
|
||||
Value a1 = rewriter.create<Torch::ConstantFloatOp>(
|
||||
loc, rewriter.getF64FloatAttr(-0.456522));
|
||||
Value a2 = rewriter.create<Torch::ConstantFloatOp>(
|
||||
loc, rewriter.getF64FloatAttr(0.0));
|
||||
|
||||
auto windowFunctionResult =
|
||||
windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
|
||||
|
|
|
@ -2153,3 +2153,83 @@ func.func @test_hannwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.v
|
|||
%0 = torch.operator "onnx.HannWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
|
||||
return %0 : !torch.vtensor<[10],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_hammingwindow_symmetric
|
||||
func.func @test_hammingwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 5.434780e-01
|
||||
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -4.565220e-01
|
||||
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 0.000000e+00
|
||||
// CHECK-DAG: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
|
||||
// CHECK-DAG: %[[ONE:.+]] = torch.constant.float 1.000000e+00
|
||||
// CHECK-DAG: %[[TWO:.+]] = torch.constant.float 2.000000e+00
|
||||
// CHECK-DAG: %[[TAU:.+]] = torch.constant.float 6.2831853071795862
|
||||
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK-DAG: %[[INT6_0:.+]] = torch.constant.int 6
|
||||
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[ONE]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 0.000000e+00
|
||||
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 1.000000e+00
|
||||
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
|
||||
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[ZERO]], %[[RANGELIM]], %[[ONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TAU]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[TWO]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[TWOCOSRANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[TWOCOSRANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
|
||||
// CHECK: %[[CAST_1:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
|
||||
// CHECK: return %[[CAST_1]] : !torch.vtensor<[10],f32>
|
||||
|
||||
%0 = torch.operator "onnx.HammingWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
|
||||
return %0 : !torch.vtensor<[10],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_hammingwindow
|
||||
func.func @test_hammingwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 5.434780e-01
|
||||
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -4.565220e-01
|
||||
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 0.000000e+00
|
||||
// CHECK-DAG: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
|
||||
// CHECK-DAG: %[[ONE:.+]] = torch.constant.float 1.000000e+00
|
||||
// CHECK-DAG: %[[TWO:.+]] = torch.constant.float 2.000000e+00
|
||||
// CHECK-DAG: %[[TAU:.+]] = torch.constant.float 6.2831853071795862
|
||||
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK-DAG: %[[INT6_0:.+]] = torch.constant.int 6
|
||||
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[ONE]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 1.000000e+00
|
||||
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 0.000000e+00
|
||||
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
|
||||
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[ZERO]], %[[RANGELIM]], %[[ONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TAU]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[TWO]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[TWOCOSRANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[TWOCOSRANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
|
||||
// CHECK: %[[CAST_1:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
|
||||
// CHECK: return %[[CAST_1]] : !torch.vtensor<[10],f32>
|
||||
|
||||
%0 = torch.operator "onnx.HammingWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
|
||||
return %0 : !torch.vtensor<[10],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue