mirror of https://github.com/llvm/torch-mlir
[torch] Add OnnxToTorch lowering for `onnx.HannWindow` (#3276)
Adds OnnxToTorch lowering for the `onnx.HannWindow` op. Also factors out the common implementation shared between the window functions.

Branch: pull/3285/head
parent a46fe2c9db
commit 67d6a665a4
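Both lowerings in this commit emit a generalized cosine window. The closed form below is inferred from the coefficient constants in the diff rather than quoted from the commit itself:

  w[n] = a0 + a1 * cos(2*pi*n / N) + a2 * cos(4*pi*n / N),  n = 0, ..., size - 1

with (a0, a1, a2) = (0.5, -0.5, 0.0) for `onnx.HannWindow`, and a0 = 0.42, a2 = 0.08 for `onnx.BlackmanWindow` (its a1 constant is set in context outside the hunks shown here).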
@@ -35,6 +35,108 @@ static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
   return success();
 }
 
+namespace {
+LogicalResult windowFunctionImpl(OpBinder binder,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value size, Value a0, Value a1, Value a2,
+                                 Torch::ValueTensorType resultType,
+                                 int64_t output_datatype, int64_t periodic) {
+
+  Location loc = binder.getLoc();
+  ImplicitLocOpBuilder b(loc, rewriter);
+
+  double isPeriodicFp = static_cast<double>(periodic);
+
+  Value zero = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(0.0));
+  Value one = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(1.0));
+  Value two = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(2.0));
+
+  constexpr double pi = llvm::numbers::pi;
+  Value tau = b.create<Torch::ConstantFloatOp>(
+      rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));
+
+  Value noneVal = b.create<Torch::ConstantNoneOp>();
+  Value cstFalse = b.create<Torch::ConstantBoolOp>(false);
+  Value float32Type = b.create<Torch::ConstantIntOp>(
+      rewriter.getI64IntegerAttr(/*float32Type*/ 6));
+
+  // Create an f32 ValueTensorType with the same size as size, the
+  // operand
+  auto shapeOfOperand =
+      size.getType().dyn_cast<Torch::ValueTensorType>().getOptionalSizes();
+  auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
+      shapeOfOperand, rewriter.getF32Type());
+  Value periodicSizeFloat = b.create<Torch::AtenToDtypeOp>(
+      f32ResultType, size, float32Type, cstFalse, cstFalse, noneVal);
+  Value symmetricSizeFloat = b.create<Torch::AtenSubScalarOp>(
+      periodicSizeFloat.getType(), periodicSizeFloat, one, one);
+
+  Value isPeriodic =
+      b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(isPeriodicFp));
+  Value isSymmetricFloat = b.create<Torch::ConstantFloatOp>(
+      rewriter.getF64FloatAttr(1.0 - isPeriodicFp));
+
+  Value periodicComponent = b.create<Torch::AtenMulScalarOp>(
+      periodicSizeFloat.getType(), periodicSizeFloat, isPeriodic);
+  Value symmetricComponent = b.create<Torch::AtenMulScalarOp>(
+      symmetricSizeFloat.getType(), symmetricSizeFloat, isSymmetricFloat);
+  Value sizeFloat = b.create<Torch::AtenAddTensorOp>(
+      symmetricComponent.getType(), symmetricComponent, periodicComponent, one);
+
+  // Here, size can be used in the place of periodicSizeFloat, as the
+  // latter is just a float representation of the former.
+  Value scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, size);
+
+  Value rangeArr = b.create<Torch::AtenArangeStartStepOp>(
+      resultType, zero, scalarLimit, one, noneVal, noneVal, noneVal, noneVal);
+
+  Value rangeTimesTau =
+      b.create<Torch::AtenMulScalarOp>(resultType, rangeArr, tau);
+  Value rangeAngular =
+      b.create<Torch::AtenDivTensorOp>(resultType, rangeTimesTau, sizeFloat);
+  Value twoRangeAngular =
+      b.create<Torch::AtenMulScalarOp>(resultType, rangeAngular, two);
+
+  Value cosRangeAngular = b.create<Torch::AtenCosOp>(resultType, rangeAngular);
+  Value cosTwoRangeAngular =
+      b.create<Torch::AtenCosOp>(resultType, twoRangeAngular);
+
+  Value a1Component =
+      b.create<Torch::AtenMulScalarOp>(resultType, cosRangeAngular, a1);
+  Value a2Component =
+      b.create<Torch::AtenMulScalarOp>(resultType, cosTwoRangeAngular, a2);
+
+  // AtenSubScalarOp actually requires a tensor operand as the LHS, that
+  // is, operand #1. Therefore, to avoid errors, the onnx implementation
+  // has been modified. a1 has been changed to negative half, and the
+  // AtenSubScalarOp has been replaced with AtenAddScalarOp, as the add
+  // operation is commutative.
+  Value subA1Component =
+      b.create<Torch::AtenAddScalarOp>(resultType, a1Component, a0, one);
+  Value result = b.create<Torch::AtenAddTensorOp>(resultType, subA1Component,
+                                                  a2Component, one);
+
+  std::optional<int64_t> dtypeIntTorch =
+      onnxDtypeIntToTorchDtypeInt(output_datatype);
+  if (!dtypeIntTorch.has_value()) {
+    return rewriter.notifyMatchFailure(
+        binder.op, "unimplemented support for the given dtype conversion");
+  }
+  Value outputDtype = b.create<Torch::ConstantIntOp>(
+      rewriter.getType<Torch::IntType>(),
+      rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                              dtypeIntTorch.value()));
+
+  rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
+      binder.op, resultType, result, outputDtype,
+      /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+      /*memory_format=*/noneVal);
+
+  return success();
+}
+} // namespace
+
 // Simple rewrites for the default domain.
 // See: https://onnx.ai/onnx/operators/
 // For operators that are effectively version invariant, we register with
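A note on the size handling in the helper above: rather than branching on the `periodic` attribute, the implementation blends the two candidate lengths arithmetically, relying on `isPeriodicFp` being exactly 0.0 or 1.0:

  sizeFloat = isPeriodicFp * size + (1.0 - isPeriodicFp) * (size - 1)

so the denominator of the angular term is `size` for periodic windows and `size - 1` for symmetric ones.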
@@ -2252,7 +2354,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.tensorResultType(resultType)) {
           return failure();
         }
-        double isPeriodicFp = static_cast<double>(periodic);
         Value a0 = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(),
             rewriter.getFloatAttr(rewriter.getF64Type(), 0.42));
@@ -2262,104 +2363,44 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         Value a2 = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(),
             rewriter.getFloatAttr(rewriter.getF64Type(), 0.08));
-        Value zero = rewriter.create<Torch::ConstantFloatOp>(
-            binder.getLoc(), rewriter.getF64FloatAttr(0.0));
-        Value one = rewriter.create<Torch::ConstantFloatOp>(
-            binder.getLoc(), rewriter.getF64FloatAttr(1.0));
-        Value two = rewriter.create<Torch::ConstantFloatOp>(
-            binder.getLoc(), rewriter.getF64FloatAttr(2.0));
-
-        constexpr double pi = llvm::numbers::pi;
-        Value tau = rewriter.create<Torch::ConstantFloatOp>(
-            binder.getLoc(),
-            rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));
-
-        Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
-        Value cstFalse =
-            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
-        Value float32Type = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(/*float32Type*/ 6));
-
-        // Create an f32 ValueTensorType with thse same size as size, the
-        // operand
-        auto shapeOfOperand = size.getType()
-                                  .dyn_cast<Torch::ValueTensorType>()
-                                  .getOptionalSizes();
-        auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
-            shapeOfOperand, rewriter.getF32Type());
-        Value periodicSizeFloat = rewriter.create<Torch::AtenToDtypeOp>(
-            binder.getLoc(), f32ResultType, size, float32Type, cstFalse,
-            cstFalse, noneVal);
-        Value symmetricSizeFloat = rewriter.create<Torch::AtenSubScalarOp>(
-            binder.getLoc(), periodicSizeFloat.getType(), periodicSizeFloat,
-            one, one);
-
-        Value isPeriodic = rewriter.create<Torch::ConstantFloatOp>(
-            binder.getLoc(), rewriter.getF64FloatAttr(isPeriodicFp));
-        Value isSymmetricFloat = rewriter.create<Torch::ConstantFloatOp>(
-            binder.getLoc(), rewriter.getF64FloatAttr(1.0 - isPeriodicFp));
-
-        Value periodicComponent = rewriter.create<Torch::AtenMulScalarOp>(
-            binder.getLoc(), periodicSizeFloat.getType(), periodicSizeFloat,
-            isPeriodic);
-        Value symmetricComponent = rewriter.create<Torch::AtenMulScalarOp>(
-            binder.getLoc(), symmetricSizeFloat.getType(), symmetricSizeFloat,
-            isSymmetricFloat);
-        Value sizeFloat = rewriter.create<Torch::AtenAddTensorOp>(
-            binder.getLoc(), symmetricComponent.getType(), symmetricComponent,
-            periodicComponent, one);
-
-        // Here, size can be used in the place of periodicSizeFloat, as the
-        // latter is just a float representation of the former.
-        Value scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, size);
-
-        Value rangeArr = rewriter.create<Torch::AtenArangeStartStepOp>(
-            binder.getLoc(), resultType, zero, scalarLimit, one, noneVal,
-            noneVal, noneVal, noneVal);
-
-        Value rangeTimesTau = rewriter.create<Torch::AtenMulScalarOp>(
-            binder.getLoc(), resultType, rangeArr, tau);
-        Value rangeAngular = rewriter.create<Torch::AtenDivTensorOp>(
-            binder.getLoc(), resultType, rangeTimesTau, sizeFloat);
-        Value twoRangeAngular = rewriter.create<Torch::AtenMulScalarOp>(
-            binder.getLoc(), resultType, rangeAngular, two);
-
-        Value cosRangeAngular = rewriter.create<Torch::AtenCosOp>(
-            binder.getLoc(), resultType, rangeAngular);
-        Value cosTwoRangeAngular = rewriter.create<Torch::AtenCosOp>(
-            binder.getLoc(), resultType, twoRangeAngular);
-
-        Value a1Component = rewriter.create<Torch::AtenMulScalarOp>(
-            binder.getLoc(), resultType, cosRangeAngular, a1);
-        Value a2Component = rewriter.create<Torch::AtenMulScalarOp>(
-            binder.getLoc(), resultType, cosTwoRangeAngular, a2);
-
-        // AtenSubScalarOp actually requires a tensor operand as the LHS, that
-        // is, operand #1. Therefore, to avoid errors, the onnx implementation
-        // has been modified. a1 has been changed to negative half, and the
-        // AtenSubScalarOp has been replaced with AtenAddScalarOp, as the add
-        // operation is commutative.
-        Value subA1Component = rewriter.create<Torch::AtenAddScalarOp>(
-            binder.getLoc(), resultType, a1Component, a0, one);
-        Value result = rewriter.create<Torch::AtenAddTensorOp>(
-            binder.getLoc(), resultType, subA1Component, a2Component, one);
-
-        std::optional<int64_t> dtypeIntTorch =
-            onnxDtypeIntToTorchDtypeInt(output_datatype);
-        if (!dtypeIntTorch.has_value()) {
-          return rewriter.notifyMatchFailure(
-              binder.op,
-              "unimplemented support for the given dtype conversion");
-        }
-        Value outputDtype = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getType<Torch::IntType>(),
-            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                    dtypeIntTorch.value()));
-
-        rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
-            binder.op, resultType, result, outputDtype,
-            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
-            /*memory_format=*/noneVal);
+        auto windowFunctionResult =
+            windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
+                               output_datatype, periodic);
+
+        if (failed(windowFunctionResult))
+          return failure();
+
+        return success();
+      });
+
+  patterns.onOp(
+      "HannWindow", 17,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Value size;
+        Torch::ValueTensorType resultType;
+        int64_t periodic, output_datatype;
+        if (binder.tensorOperand(size) ||
+            binder.s64IntegerAttr(output_datatype, "output_datatype", 1) ||
+            binder.s64IntegerAttr(periodic, "periodic", 1) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+        Value a0 = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.5));
+        Value a1 = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
+        Value a2 = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.0));
+
+        auto windowFunctionResult =
+            windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
+                               output_datatype, periodic);
+
+        if (failed(windowFunctionResult))
+          return failure();
 
         return success();
       });
 }
@@ -2113,3 +2113,82 @@ func.func @test_blackmanwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor
   %0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
   return %0 : !torch.vtensor<[10],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_hannwindow
+func.func @test_hannwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-DAG: %[[A0:.+]] = torch.constant.float 5.000000e-01
+  // CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
+  // CHECK-DAG: %[[A2:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[ONE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[TWO:.+]] = torch.constant.float 2.000000e+00
+  // CHECK-DAG: %[[TAU:.+]] = torch.constant.float 6.2831853071795862
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[INT6_0:.+]] = torch.constant.int 6
+  // CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[ONE]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[ZERO]], %[[RANGELIM]], %[[ONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TAU]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[TWO]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[TWOCOSRANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[TWOCOSRANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
+  // CHECK: %[[CAST_1:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK: return %[[CAST_1]] : !torch.vtensor<[10],f32>
+
+  %0 = torch.operator "onnx.HannWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
+  return %0 : !torch.vtensor<[10],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_hannwindow_symmetric
+func.func @test_hannwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-DAG: %[[A0:.+]] = torch.constant.float 5.000000e-01
+  // CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
+  // CHECK-DAG: %[[A2:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[ONE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[TWO:.+]] = torch.constant.float 2.000000e+00
+  // CHECK-DAG: %[[TAU:.+]] = torch.constant.float 6.2831853071795862
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[INT6_0:.+]] = torch.constant.int 6
+  // CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[ONE]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[ZERO]], %[[RANGELIM]], %[[ONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TAU]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[TWO]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[TWOCOSRANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[TWOCOSRANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
+  // CHECK: %[[CAST_1:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK: return %[[CAST_1]] : !torch.vtensor<[10],f32>
+
+  %0 = torch.operator "onnx.HannWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
+  return %0 : !torch.vtensor<[10],f32>
+}
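Because the coefficient handling now lives in `windowFunctionImpl`, further generalized cosine windows only need a registration that supplies different constants. The following is a hypothetical sketch, not part of this commit, of what an `onnx.HammingWindow` lowering could look like; the textbook Hamming coefficients 0.54 and -0.46 are assumptions used for illustration and would need to be checked against the ONNX specification before use.

// Hypothetical example (not in this commit): reuse the factored-out helper
// for another generalized cosine window by changing only the coefficients.
patterns.onOp(
    "HammingWindow", 17,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Value size;
      Torch::ValueTensorType resultType;
      int64_t periodic, output_datatype;
      if (binder.tensorOperand(size) ||
          binder.s64IntegerAttr(output_datatype, "output_datatype", 1) ||
          binder.s64IntegerAttr(periodic, "periodic", 1) ||
          binder.tensorResultType(resultType)) {
        return failure();
      }
      // Textbook Hamming: w[n] = 0.54 - 0.46 * cos(2*pi*n/N); a1 is passed as
      // a negative value so the helper can use AtenAddScalarOp uniformly.
      Value a0 = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.54));
      Value a1 = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), -0.46));
      Value a2 = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.0));
      if (failed(windowFunctionImpl(binder, rewriter, size, a0, a1, a2,
                                    resultType, output_datatype, periodic)))
        return failure();
      return success();
    });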