From 5a627c46b76f8cdc737aef3bda1b910836e33d88 Mon Sep 17 00:00:00 2001
From: Phaneesh Barwaria
Date: Fri, 28 Jun 2024 20:08:43 +0530
Subject: [PATCH] onnx.DFT basic support (#3463)

- adds support for DFT v20 on the FFT and IFFT path
- adds required skeleton code for IFFT ops to be recognised in TMlir
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 26 ++++++
 .../TorchOnnxToTorch/DefaultDomainAtoF.cpp    | 91 +++++++++++++++++++
 .../Transforms/AbstractInterpLibrary.cpp      | 47 ++++++++++
 .../build_tools/abstract_interp_lib_gen.py    | 20 ++++
 .../build_tools/torch_ods_gen.py              |  1 +
 .../TorchOnnxToTorch/simple_ops_a_to_f.mlir   | 48 ++++++++++
 6 files changed, 233 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index be5bc56d7..ae5f56aea 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -12418,6 +12418,32 @@ def Torch_AtenFftFftOp : Torch_Op<"aten.fft_fft", [
   }];
 }
 
+def Torch_AtenFftIfftOp : Torch_Op<"aten.fft_ifft", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::fft_ifft : (Tensor, int?, int, str?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchOptionalIntType:$n,
+    Torch_IntType:$dim,
+    AnyTorchOptionalStringType:$norm
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenFftIfftOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenFftIfftOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenFmodTensorOp : Torch_Op<"aten.fmod.Tensor", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 446298e89..a5cdc1020 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -2728,4 +2728,95 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         return success();
       });
+
+  patterns.onOp(
+      "DFT", 20, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Value inTensor, dftLength, axis;
+        Torch::ValueTensorType resultType;
+        int64_t inverse, onesided;
+        if (binder.tensorOperandAtIndex(inTensor, 0) ||
+            binder.s64IntegerAttr(inverse, "inverse", 0) ||
+            binder.s64IntegerAttr(onesided, "onesided", 0) ||
+            binder.tensorResultType(resultType))
+          return rewriter.notifyMatchFailure(
+              binder.op, "Input Tensor / attrs / resultType bind failed");
+        if (!binder.tensorOperandAtIndex(dftLength, 1)) {
+          // Convert to int and pass as n
+          dftLength = rewriter.create<Torch::AtenItemOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(), dftLength);
+        } else {
+          // Default for torch is None
+          dftLength = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        }
+        // Default is same for onnx and torch
+        if (!binder.tensorOperandAtIndex(axis, 2)) {
+          // convert to int and pass to dims
+          axis = rewriter.create<Torch::AtenItemOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(), axis);
+        } else {
+          // Default in torch is -1 and onnx is -2 (since -1 is for real / img)
+          axis = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(-2));
+        }
+
+        if (onesided == 1)
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "Unsupported option : onesided");
+        // norm default string attr
+        Value norm = rewriter.create<Torch::ConstantStrOp>(
+            binder.getLoc(), rewriter.getStringAttr(Twine("backward")));
+        // Convert from [....., 2] complex number repr for fft consumption.
+        Torch::ValueTensorType inType =
+            binder.toValidTensorType(inTensor.getType());
+        int64_t lastIndex = inType.getSizes().back();
+        if (lastIndex != 1 && lastIndex != 2)
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "Expected input tensor to have dims [..., 1] or [..., 2]");
+
+        // concat with zeros to make it [..., 2]
+        Value inForComplexVal = inTensor;
+        ArrayRef<int64_t> inForComplexSizes = inType.getSizes().drop_back();
+        if (lastIndex == 1) {
+          Value constZeroVal = rewriter.create<Torch::ConstantFloatOp>(
+              binder.getLoc(), rewriter.getF64FloatAttr(0));
+          Value constOne = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(1));
+          Value constZero = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(0));
+          Value padSizeList =
+              rewriter
+                  .create<Torch::PrimListConstructOp>(
+                      binder.getLoc(),
+                      Torch::ListType::get(rewriter.getType<Torch::IntType>()),
+                      SmallVector<Value>({constZero, constOne}))
+                  .getResult();
+          Value modeVal = rewriter.create<Torch::ConstantStrOp>(
+              binder.getLoc(), rewriter.getStringAttr("constant"));
+          SmallVector<int64_t> resSize(inForComplexSizes);
+          resSize.push_back(2);
+          inForComplexVal = rewriter.create<Torch::AtenPadOp>(
+              binder.getLoc(),
+              inType.getWithSizesAndDtype(resSize, inType.getOptionalDtype()),
+              inTensor, padSizeList, modeVal, constZeroVal);
+        }
+        Type inComplexTensorType = Torch::ValueTensorType::get(
+            binder.op->getContext(), inForComplexSizes,
+            mlir::ComplexType::get(inType.getDtype()));
+        Value inComplexTensor = rewriter.create<Torch::AtenViewAsComplexOp>(
+            binder.getLoc(), inComplexTensorType, inForComplexVal);
+        Value ftOp;
+        if (inverse == 0) {
+          ftOp = rewriter.create<Torch::AtenFftFftOp>(
+              binder.getLoc(), inComplexTensorType, inComplexTensor,
+              /*n = */ dftLength, /*dim = */ axis, /*norm = */ norm);
+        } else {
+          ftOp = rewriter.create<Torch::AtenFftIfftOp>(
+              binder.getLoc(), inComplexTensorType, inComplexTensor,
+              /*n = */ dftLength, /*dim = */ axis, /*norm = */ norm);
+        }
+        rewriter.replaceOpWithNewOp<Torch::AtenViewAsRealOp>(binder.op,
+                                                             resultType, ftOp);
+        return success();
+      });
 }
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 6974636c0..b05e1051c 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10369,6 +10369,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %14 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.fft_ifft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
 "    %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list<int>\n"
@@ -11984,6 +11987,50 @@
 "    }\n"
 "    return %6 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.fft_ifft\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n"
+"    %int10 = torch.constant.int 10\n"
+"    %int7 = torch.constant.int 7\n"
+"    %int9 = torch.constant.int 9\n"
+"    %int6 = torch.constant.int 6\n"
+"    %int8 = torch.constant.int 8\n"
+"    %int5 = torch.constant.int 5\n"
+"    %0 = torch.prim.Uninitialized : !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
+"    %3 = torch.prim.If %2 -> (!torch.int) {\n"
+"      torch.prim.If.yield %1#1 : !torch.int\n"
+"    } else {\n"
+"      %4 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
+"      %5 = torch.prim.If %4 -> (!torch.int) {\n"
+"        torch.prim.If.yield %int8 : !torch.int\n"
+"      } else {\n"
+"        %6 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n"
+"        %7 = torch.prim.If %6 -> (!torch.int) {\n"
+"          torch.prim.If.yield %int9 : !torch.int\n"
+"        } else {\n"
+"          %8 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
+"          %9 = torch.prim.If %8 -> (!torch.int) {\n"
+"            torch.prim.If.yield %int10 : !torch.int\n"
+"          } else {\n"
+"            %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
+"            %11 = torch.prim.If %10 -> (!torch.int) {\n"
+"              torch.prim.If.yield %int9 : !torch.int\n"
+"            } else {\n"
+"              torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"              torch.prim.If.yield %0 : !torch.int\n"
+"            }\n"
+"            torch.prim.If.yield %11 : !torch.int\n"
+"          }\n"
+"          torch.prim.If.yield %9 : !torch.int\n"
+"        }\n"
+"        torch.prim.If.yield %7 : !torch.int\n"
+"      }\n"
+"      torch.prim.If.yield %5 : !torch.int\n"
+"    }\n"
+"    return %3 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 0b356cc34..b3d7ec5a9 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -2038,6 +2038,9 @@ def aten〇stft〡shape(self: List[int], n_fft: int, hop_length: Optional[int] =
 
     return out
 
+def aten〇fft_ifft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]:
+    return self
+
 class DummyClassType:
     def __init__(self):
         pass
@@ -3406,6 +3409,23 @@ def aten〇stft〡dtype(self_rank_dtype: Tuple[int, int], n_fft: int, hop_length
 
     assert False, "Unsupported dtype"
 
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bfloat16}))
+def aten〇fft_ifft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    if is_complex_dtype(self_dtype):
+        return self_dtype
+    elif self_dtype == torch.float16:
+        return torch.complex32
+    elif self_dtype == torch.float32:
+        return torch.complex64
+    elif self_dtype == torch.float64:
+        return torch.complex128
+    elif is_integer_dtype(self_dtype):
+        return torch.complex64
+    else:
+        assert False, "Unsupported dtype"
+
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
     _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 90d3e1054..fe700d292 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -910,6 +910,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit(
         "aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)"
     )
     emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)")
+    emit("aten::fft_ifft : (Tensor, int?, int, str?) -> (Tensor)")
     emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")
     emit(
         "aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)"
     )
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
index 4b03fccee..cf92c04d8 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -2480,3 +2480,51 @@ func.func @test_col2im_strides(%arg0: !torch.vtensor<[1,9,4],f32>, %arg1: !torch
   %0 = torch.operator "onnx.Col2Im"(%arg0, %arg1, %arg2) {torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,9,4],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32>
   return %0 : !torch.vtensor<[1,1,5,5],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_dft_fft
+func.func @test_dft_fft(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} {
+  // CHECK-SAME: %[[INPUT_SIGNAL:.*]]: !torch.vtensor<[10,10,1],f32>,
+  // CHECK-SAME: %[[AXIS:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[SIG_LEN:.*]] = torch.constant.none
+  // CHECK: %[[DIM:.*]] = torch.aten.item %[[AXIS]] : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: %[[NORM:.*]] = torch.constant.str "backward"
+  // CHECK: %[[FILL_VAL:.*]] = torch.constant.float 0.000000e+00
+  // CHECK: %[[ONE:.*]] = torch.constant.int 1
+  // CHECK: %[[ZERO:.*]] = torch.constant.int 0
+  // CHECK: %[[PAD_DIM_LIST:.*]] = torch.prim.ListConstruct %[[ZERO]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[MODE:.*]] = torch.constant.str "constant"
+  // CHECK: %[[INPUT_PADDED:.*]] = torch.aten.pad %[[INPUT_SIGNAL]], %[[PAD_DIM_LIST]], %[[MODE]], %[[FILL_VAL]] : !torch.vtensor<[10,10,1],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[10,10,2],f32>
+  // CHECK: %[[INPUT_T_CMPLX:.*]] = torch.aten.view_as_complex %[[INPUT_PADDED]] : !torch.vtensor<[10,10,2],f32> -> !torch.vtensor<[10,10],complex<f32>>
+  // CHECK: %[[FFT_CMPLX:.*]] = torch.aten.fft_fft %[[INPUT_T_CMPLX]], %[[SIG_LEN]], %[[DIM]], %[[NORM]] : !torch.vtensor<[10,10],complex<f32>>, !torch.none, !torch.int, !torch.str -> !torch.vtensor<[10,10],complex<f32>>
+  // CHECK: %[[FFT_RES_REAL:.*]] = torch.aten.view_as_real %[[FFT_CMPLX]] : !torch.vtensor<[10,10],complex<f32>> -> !torch.vtensor<[10,10,2],f32>
+  // CHECK: return %[[FFT_RES_REAL]] : !torch.vtensor<[10,10,2],f32>
+  %none = torch.constant.none
+  %0 = torch.operator "onnx.DFT"(%arg0, %none, %arg1) : (!torch.vtensor<[10,10,1],f32>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32>
+  return %0 : !torch.vtensor<[10,10,2],f32>
+}
+
+// CHECK-LABEL: func.func @test_dft_inverse_real
+func.func @test_dft_inverse_real(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} {
+  // CHECK-SAME: %[[INPUT_SIGNAL:.*]]: !torch.vtensor<[10,10,1],f32>,
+  // CHECK-SAME: %[[AXIS:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[SIG_LEN:.*]] = torch.constant.none
+  // CHECK: %[[DIM:.*]] = torch.aten.item %[[AXIS]] : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: %[[NORM:.*]] = torch.constant.str "backward"
+  // CHECK: %[[FILL_VAL:.*]] = torch.constant.float 0.000000e+00
+  // CHECK: %[[ONE:.*]] = torch.constant.int 1
+  // CHECK: %[[ZERO:.*]] = torch.constant.int 0
+  // CHECK: %[[PAD_DIM_LIST:.*]] = torch.prim.ListConstruct %[[ZERO]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[MODE:.*]] = torch.constant.str "constant"
+  // CHECK: %[[INPUT_PADDED:.*]] = torch.aten.pad %[[INPUT_SIGNAL]], %[[PAD_DIM_LIST]], %[[MODE]], %[[FILL_VAL]] : !torch.vtensor<[10,10,1],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[10,10,2],f32>
+  // CHECK: %[[INPUT_T_CMPLX:.*]] = torch.aten.view_as_complex %[[INPUT_PADDED]] : !torch.vtensor<[10,10,2],f32> -> !torch.vtensor<[10,10],complex<f32>>
+  // CHECK: %[[IFFT_CMPLX:.*]] = torch.aten.fft_ifft %[[INPUT_T_CMPLX]], %[[SIG_LEN]], %[[DIM]], %[[NORM]] : !torch.vtensor<[10,10],complex<f32>>, !torch.none, !torch.int, !torch.str -> !torch.vtensor<[10,10],complex<f32>>
+  // CHECK: %[[IFFT_RES_REAL:.*]] = torch.aten.view_as_real %[[IFFT_CMPLX]] : !torch.vtensor<[10,10],complex<f32>> -> !torch.vtensor<[10,10,2],f32>
+  // CHECK: return %[[IFFT_RES_REAL]] : !torch.vtensor<[10,10,2],f32>
+  %none = torch.constant.none
+  %0 = torch.operator "onnx.DFT"(%arg0, %none, %arg1) {torch.onnx.inverse = 1 : si64} : (!torch.vtensor<[10,10,1],f32>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32>
+  return %0 : !torch.vtensor<[10,10,2],f32>
+}
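
Reviewer note (not part of the patch): the lowering maps onnx.DFT's packed real/imaginary
layout onto torch.fft ops, and the eager-PyTorch sketch below mirrors that data flow so the
[..., 1] / [..., 2] packing is easy to check by hand. The helper name `dft_like` and its
defaults are illustrative assumptions, not code from this repository.

    # Minimal sketch of the DFT lowering's data flow in eager PyTorch.
    import torch

    def dft_like(x: torch.Tensor, axis: int = -2, inverse: bool = False) -> torch.Tensor:
        # onnx.DFT packs complex values in the last dim: [..., 1] = real only,
        # [..., 2] = (real, imag). Pad a zero imaginary channel for real-only
        # input, mirroring the aten.pad the lowering emits.
        if x.shape[-1] == 1:
            x = torch.nn.functional.pad(x, (0, 1), mode="constant", value=0.0)
        assert x.shape[-1] == 2, "expected [..., 1] or [..., 2] input"
        # view_as_complex / view_as_real mirror AtenViewAsComplexOp / AtenViewAsRealOp.
        xc = torch.view_as_complex(x.contiguous())
        fft = torch.fft.ifft if inverse else torch.fft.fft
        # The ONNX axis (default -2) and the default "backward" norm are forwarded as-is.
        return torch.view_as_real(fft(xc, dim=axis, norm="backward"))

    print(dft_like(torch.randn(10, 10, 1)).shape)                # torch.Size([10, 10, 2])
    print(dft_like(torch.randn(10, 10, 1), inverse=True).shape)  # torch.Size([10, 10, 2])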