[MLIR][TORCH] Add OnnxToTorch support for BlackmanWindow function (#3181)

Implements OnnxToTorch lowering for the ONNX BlackmanWindow function.
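For review context, the pattern materializes the standard Blackman window that the ONNX operator specifies (the symbols w[n], N, and n below are introduced here for explanation and do not appear in the patch):

    w[n] = 0.42 - 0.5\cos\!\left(\frac{2\pi n}{N}\right) + 0.08\cos\!\left(\frac{4\pi n}{N}\right), \qquad n = 0, \dots, \text{size} - 1

where N = size when periodic = 1 and N = size - 1 when periodic = 0; the result is then cast to the dtype requested by the output_datatype attribute.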
Vinayak Dev 2024-04-30 21:51:27 +05:30 committed by GitHub
parent f32ada993d
commit 05f8b69bf6
4 changed files with 207 additions and 8 deletions

@@ -38,6 +38,13 @@ Value createConstantIntList(OpBinder binder,
Type getQTorchTypeFromTorchIntType(Type ty);
template <typename T>
Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
                Value &ofItem) {
  return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
                                            rewriter.getType<T>(), ofItem);
}
LogicalResult OnnxLstmExpander(OpBinder binder,
                               ConversionPatternRewriter &rewriter);
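The helper hoisted into this header is what the new BlackmanWindow pattern uses to turn the 0-d size tensor into a !torch.int scalar. A minimal call-site sketch, mirroring its use further down in this change (`size` is the tensor bound by binder.tensorOperand):

    // Extract the window length from the 0-d ONNX `size` operand as a
    // !torch.int scalar, so it can feed the arange op's `end` argument.
    Value scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, size);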

@@ -2240,4 +2240,126 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
        binder.op, resultType, cstEquation, tensorList, /*path=*/cstNone);
        return success();
      });
  patterns.onOp(
      "BlackmanWindow", 17,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Value size;
        Torch::ValueTensorType resultType;
        int64_t periodic, output_datatype;
        if (binder.tensorOperand(size) ||
            binder.s64IntegerAttr(output_datatype, "output_datatype", 1) ||
            binder.s64IntegerAttr(periodic, "periodic", 1) ||
            binder.tensorResultType(resultType)) {
          return failure();
        }
        double isPeriodicFp = static_cast<double>(periodic);
        Value a0 = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(),
            rewriter.getFloatAttr(rewriter.getF64Type(), 0.42));
        Value a1 = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(),
            rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
        Value a2 = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(),
            rewriter.getFloatAttr(rewriter.getF64Type(), 0.08));
        Value zero = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getF64FloatAttr(0.0));
        Value one = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getF64FloatAttr(1.0));
        Value two = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getF64FloatAttr(2.0));
        constexpr double pi = llvm::numbers::pi;
        Value tau = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(),
            rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));
        Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        Value cstFalse =
            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
        Value float32Type = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(/*float32Type*/ 6));
        // Create an f32 ValueTensorType with the same sizes as the `size`
        // operand.
        auto shapeOfOperand = size.getType()
                                  .dyn_cast<Torch::ValueTensorType>()
                                  .getOptionalSizes();
        auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
            shapeOfOperand, rewriter.getF32Type());
        Value periodicSizeFloat = rewriter.create<Torch::AtenToDtypeOp>(
            binder.getLoc(), f32ResultType, size, float32Type, cstFalse,
            cstFalse, noneVal);
        Value symmetricSizeFloat = rewriter.create<Torch::AtenSubScalarOp>(
            binder.getLoc(), periodicSizeFloat.getType(), periodicSizeFloat,
            one, one);
        Value isPeriodic = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getF64FloatAttr(isPeriodicFp));
        Value isSymmetricFloat = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getF64FloatAttr(1.0 - isPeriodicFp));
        Value periodicComponent = rewriter.create<Torch::AtenMulScalarOp>(
            binder.getLoc(), periodicSizeFloat.getType(), periodicSizeFloat,
            isPeriodic);
        Value symmetricComponent = rewriter.create<Torch::AtenMulScalarOp>(
            binder.getLoc(), symmetricSizeFloat.getType(), symmetricSizeFloat,
            isSymmetricFloat);
        Value sizeFloat = rewriter.create<Torch::AtenAddTensorOp>(
            binder.getLoc(), symmetricComponent.getType(), symmetricComponent,
            periodicComponent, one);
        // Here, size can be used in place of periodicSizeFloat, as the
        // latter is just a float representation of the former.
        Value scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, size);
        Value rangeArr = rewriter.create<Torch::AtenArangeStartStepOp>(
            binder.getLoc(), resultType, zero, scalarLimit, one, noneVal,
            noneVal, noneVal, noneVal);
        Value rangeTimesTau = rewriter.create<Torch::AtenMulScalarOp>(
            binder.getLoc(), resultType, rangeArr, tau);
        Value rangeAngular = rewriter.create<Torch::AtenDivTensorOp>(
            binder.getLoc(), resultType, rangeTimesTau, sizeFloat);
        Value twoRangeAngular = rewriter.create<Torch::AtenMulScalarOp>(
            binder.getLoc(), resultType, rangeAngular, two);
        Value cosRangeAngular = rewriter.create<Torch::AtenCosOp>(
            binder.getLoc(), resultType, rangeAngular);
        Value cosTwoRangeAngular = rewriter.create<Torch::AtenCosOp>(
            binder.getLoc(), resultType, twoRangeAngular);
        Value a1Component = rewriter.create<Torch::AtenMulScalarOp>(
            binder.getLoc(), resultType, cosRangeAngular, a1);
        Value a2Component = rewriter.create<Torch::AtenMulScalarOp>(
            binder.getLoc(), resultType, cosTwoRangeAngular, a2);
        // AtenSubScalarOp requires a tensor as its LHS (operand #1), so the
        // scalar-minus-tensor term of the ONNX formula cannot be emitted
        // directly. Instead, a1 is defined as -0.5 and the subtraction is
        // expressed as an AtenAddScalarOp, since addition is commutative.
        Value subA1Component = rewriter.create<Torch::AtenAddScalarOp>(
            binder.getLoc(), resultType, a1Component, a0, one);
        Value result = rewriter.create<Torch::AtenAddTensorOp>(
            binder.getLoc(), resultType, subA1Component, a2Component, one);
        std::optional<int64_t> dtypeIntTorch =
            onnxDtypeIntToTorchDtypeInt(output_datatype);
        if (!dtypeIntTorch.has_value()) {
          return rewriter.notifyMatchFailure(
              binder.op,
              "unimplemented support for the given dtype conversion");
        }
        Value outputDtype = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                    dtypeIntTorch.value()));
        rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
            binder.op, resultType, result, outputDtype,
            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
            /*memory_format=*/noneVal);
        return success();
      });
}
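A brief note on the structure above, as an aid for review: rather than branching on the periodic attribute, the pattern blends the two candidate denominators arithmetically (p and N below are explanatory names, not identifiers from the patch):

    N = p \cdot \text{size} + (1 - p)\cdot(\text{size} - 1), \qquad p = \text{periodic} \in \{0, 1\}

so one straight-line sequence of tensor ops handles both the periodic and the symmetric window, and the negated a1 lets the a0 term be folded in with AtenAddScalarOp rather than a scalar-minus-tensor subtraction.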

@@ -31,15 +31,7 @@ using namespace mlir::torch::onnx_c;
// thing here, so we simplify.
// utilities
// Templatized function to get an item op of a type
namespace {
template <typename T>
Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
                Value &ofItem) {
  return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
                                            rewriter.getType<T>(), ofItem);
}
// In case the ReduceSum Op was not the first operation performed on the data,
// we provide the original operand through storeResult, which will be modified
// if the result will be passed onto another operation, and will be used for

@@ -2035,3 +2035,81 @@ func.func @test_eyelike_dynamic(%arg0: !torch.vtensor<[3,?],f32>) -> !torch.vten
%0 = torch.operator "onnx.EyeLike"(%arg0) {torch.onnx.k = -1 : si64} : (!torch.vtensor<[3,?],f32>) -> !torch.vtensor<[3,?],f32>
return %0 : !torch.vtensor<[3,?],f32>
}
// -----
// CHECK-LABEL: func.func @test_blackmanwindow_symmetric
func.func @test_blackmanwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 4.200000e-01
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 8.000000e-02
// CHECK-DAG: %[[FLOAT0_0:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[FLOAT2:.+]] = torch.constant.float 2.000000e+00
// CHECK-DAG: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[INT6:.+]] = torch.constant.int 6
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[FLOAT1]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[FLOAT0_0]], %[[RANGELIM]], %[[FLOAT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[FLOAT2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSTWORANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[COSTWORANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
// CHECK: return %[[CAST]] : !torch.vtensor<[10],f32>
%0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
return %0 : !torch.vtensor<[10],f32>
}
// -----
// CHECK-LABEL: func.func @test_blackmanwindow
func.func @test_blackmanwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 4.200000e-01
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 8.000000e-02
// CHECK-DAG: %[[FLOAT0_0:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[FLOAT2:.+]] = torch.constant.float 2.000000e+00
// CHECK-DAG: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[INT6:.+]] = torch.constant.int 6
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[FLOAT1]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[FLOAT0_0]], %[[RANGELIM]], %[[FLOAT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[FLOAT2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSTWORANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[COSTWORANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
// CHECK: return %[[CAST]] : !torch.vtensor<[10],f32>
%0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
return %0 : !torch.vtensor<[10],f32>
}
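As an aside for anyone reproducing the FileCheck runs locally: these cases are appended to the existing TorchOnnxToTorch lit test file, whose RUN line is expected to follow the usual pattern below (this is an assumption about the test harness, not text quoted from the patch):

    // RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch -split-input-file -verify-diagnostics | FileCheck %s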