onnx.DFT basic support (#3463)

- adds support for DFT v20 on the FFT and IFFT path
- adds the required skeleton code for the IFFT op to be recognised in Torch-MLIR (an illustrative sketch of the ONNX-to-PyTorch mapping follows the commit metadata below)
Phaneesh Barwaria 2024-06-28 20:08:43 +05:30 committed by GitHub
parent 7e6d76e997
commit 5a627c46b7
6 changed files with 233 additions and 0 deletions
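
For context on the mapping this commit implements: ONNX's DFT represents complex data as a real tensor whose last dimension holds the components (size 1 for real-only input, size 2 for real/imaginary), while torch.fft.fft and torch.fft.ifft operate on native complex tensors. A minimal eager PyTorch sketch of that correspondence (illustrative only, not part of the commit):

    import torch

    # ONNX DFT input: real signal with a trailing component dim of size 1.
    x = torch.randn(10, 10, 1)

    # Pad the component dim to [real, imag] so it can be viewed as complex,
    # mirroring the aten.pad the lowering emits for [..., 1] inputs.
    x_ri = torch.nn.functional.pad(x, (0, 1))   # shape [10, 10, 2]
    x_c = torch.view_as_complex(x_ri)           # shape [10, 10], complex64

    # Transform along the signal dim (axis 1 of the [10, 10, 2] layout).
    y_c = torch.fft.fft(x_c, dim=1, norm="backward")      # inverse=0 path
    y_inv = torch.fft.ifft(y_c, dim=1, norm="backward")   # inverse=1 path

    # Back to the ONNX [..., 2] real/imag representation.
    y = torch.view_as_real(y_c)                 # shape [10, 10, 2]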

@@ -12418,6 +12418,32 @@ def Torch_AtenFftFftOp : Torch_Op<"aten.fft_fft", [
}];
}
def Torch_AtenFftIfftOp : Torch_Op<"aten.fft_ifft", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::fft_ifft : (Tensor, int?, int, str?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalIntType:$n,
Torch_IntType:$dim,
AnyTorchOptionalStringType:$norm
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenFftIfftOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenFftIfftOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenFmodTensorOp : Torch_Op<"aten.fmod.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,

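The generated op mirrors the registry entry aten::fft_ifft : (Tensor, int?, int, str?) -> (Tensor) added in the torch_ods_gen.py hunk further down, which in turn follows the eager PyTorch API. A quick illustration of that signature (plain PyTorch, nothing from this patch):

    import torch

    t = torch.randn(8, dtype=torch.float32)
    # aten::fft_ifft(Tensor self, int? n=None, int dim=-1, str? norm=None) -> Tensor
    out = torch.fft.ifft(t, n=None, dim=-1, norm=None)
    print(out.shape, out.dtype)  # torch.Size([8]) torch.complex64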
@@ -2728,4 +2728,95 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
return success();
});
patterns.onOp(
"DFT", 20, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Value inTensor, dftLength, axis;
Torch::ValueTensorType resultType;
int64_t inverse, onesided;
if (binder.tensorOperandAtIndex(inTensor, 0) ||
binder.s64IntegerAttr(inverse, "inverse", 0) ||
binder.s64IntegerAttr(onesided, "onesided", 0) ||
binder.tensorResultType(resultType))
return rewriter.notifyMatchFailure(
binder.op, "Input Tensor / attrs / resultType bind failed");
if (!binder.tensorOperandAtIndex(dftLength, 1)) {
// Convert to int and pass as n
dftLength = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), dftLength);
} else {
// Default for torch is None
dftLength = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
}
// Bind the optional axis operand.
if (!binder.tensorOperandAtIndex(axis, 2)) {
// Convert to int and pass as dim
axis = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), axis);
} else {
// Default in torch is -1; onnx default is -2 (since -1 is the real/imag component dim)
axis = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(-2));
}
if (onesided == 1)
return rewriter.notifyMatchFailure(binder.op,
"Unsupported option : onesided");
// norm default string attr
Value norm = rewriter.create<Torch::ConstantStrOp>(
binder.getLoc(), rewriter.getStringAttr(Twine("backward")));
// Convert from the [..., 2] complex number repr for fft consumption.
Torch::ValueTensorType inType =
binder.toValidTensorType(inTensor.getType());
int64_t lastIndex = inType.getSizes().back();
if (lastIndex != 1 && lastIndex != 2)
return rewriter.notifyMatchFailure(
binder.op,
"Expected input tensor to have dims [..., 1] or [..., 2]");
// Pad with zeros to make it [..., 2]
Value inForComplexVal = inTensor;
ArrayRef<int64_t> inForComplexSizes = inType.getSizes().drop_back();
if (lastIndex == 1) {
Value constZeroVal = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(0));
Value constOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value constZero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
Value padSizeList =
rewriter
.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(rewriter.getType<Torch::IntType>()),
SmallVector<Value>({constZero, constOne}))
.getResult();
Value modeVal = rewriter.create<Torch::ConstantStrOp>(
binder.getLoc(), rewriter.getStringAttr("constant"));
SmallVector<int64_t> resSize(inForComplexSizes);
resSize.push_back(2);
inForComplexVal = rewriter.create<Torch::AtenPadOp>(
binder.getLoc(),
inType.getWithSizesAndDtype(resSize, inType.getOptionalDtype()),
inTensor, padSizeList, modeVal, constZeroVal);
}
Type inComplexTensorType = Torch::ValueTensorType::get(
binder.op->getContext(), inForComplexSizes,
mlir::ComplexType::get(inType.getDtype()));
Value inComplexTensor = rewriter.create<Torch::AtenViewAsComplexOp>(
binder.getLoc(), inComplexTensorType, inForComplexVal);
Value ftOp;
if (inverse == 0) {
ftOp = rewriter.create<Torch::AtenFftFftOp>(
binder.getLoc(), inComplexTensorType, inComplexTensor,
/*n = */ dftLength, /*dim = */ axis, /*norm = */ norm);
} else {
ftOp = rewriter.create<Torch::AtenFftIfftOp>(
binder.getLoc(), inComplexTensorType, inComplexTensor,
/*n = */ dftLength, /*dim = */ axis, /*norm = */ norm);
}
rewriter.replaceOpWithNewOp<Torch::AtenViewAsRealOp>(binder.op,
resultType, ftOp);
return success();
});
}

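The pattern scalarizes the optional dft_length and axis operands (0-d si64 tensors in ONNX) with aten.item before forwarding them as the n and dim arguments of the torch fft ops. In eager terms the lowering corresponds roughly to the sketch below; the helper name onnx_dft_like is hypothetical, and the -2 default mirrors the pattern's ConstantIntOp rather than any PyTorch default:

    import torch

    def onnx_dft_like(x, dft_length=None, axis=None, inverse=False):
        # x: real tensor shaped [..., 1] or [..., 2]; last dim = components.
        if x.shape[-1] == 1:
            x = torch.nn.functional.pad(x, (0, 1))  # zero imaginary part
        xc = torch.view_as_complex(x.contiguous())
        n = int(dft_length.item()) if dft_length is not None else None
        dim = int(axis.item()) if axis is not None else -2
        fn = torch.fft.ifft if inverse else torch.fft.fft
        return torch.view_as_real(fn(xc, n=n, dim=dim, norm="backward"))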
@@ -10369,6 +10369,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %14 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.fft_ifft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
" %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list<int>\n"
@@ -11984,6 +11987,50 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.fft_ifft\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n"
" %int10 = torch.constant.int 10\n"
" %int7 = torch.constant.int 7\n"
" %int9 = torch.constant.int 9\n"
" %int6 = torch.constant.int 6\n"
" %int8 = torch.constant.int 8\n"
" %int5 = torch.constant.int 5\n"
" %0 = torch.prim.Uninitialized : !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.int) {\n"
" torch.prim.If.yield %1#1 : !torch.int\n"
" } else {\n"
" %4 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.int) {\n"
" torch.prim.If.yield %int8 : !torch.int\n"
" } else {\n"
" %6 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n"
" %7 = torch.prim.If %6 -> (!torch.int) {\n"
" torch.prim.If.yield %int9 : !torch.int\n"
" } else {\n"
" %8 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
" %9 = torch.prim.If %8 -> (!torch.int) {\n"
" torch.prim.If.yield %int10 : !torch.int\n"
" } else {\n"
" %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %11 = torch.prim.If %10 -> (!torch.int) {\n"
" torch.prim.If.yield %int9 : !torch.int\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield %0 : !torch.int\n"
" }\n"
" torch.prim.If.yield %11 : !torch.int\n"
" }\n"
" torch.prim.If.yield %9 : !torch.int\n"
" }\n"
" torch.prim.If.yield %7 : !torch.int\n"
" }\n"
" torch.prim.If.yield %5 : !torch.int\n"
" }\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

@@ -2038,6 +2038,9 @@ def aten〇stft〡shape(self: List[int], n_fft: int, hop_length: Optional[int] =
return out
def aten〇fft_ifft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]:
return self
class DummyClassType:
def __init__(self):
pass
@@ -3406,6 +3409,23 @@ def aten〇stft〡dtype(self_rank_dtype: Tuple[int, int], n_fft: int, hop_length
assert False, "Unsupported dtype"
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bfloat16}))
def aten〇fft_ifft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int:
self_rank, self_dtype = self_rank_dtype
if is_complex_dtype(self_dtype):
return self_dtype
elif self_dtype == torch.float16:
return torch.complex32
elif self_dtype == torch.float32:
return torch.complex64
elif self_dtype == torch.float64:
return torch.complex128
elif is_integer_dtype(self_dtype):
return torch.complex64
else:
assert False, "Unsupported dtype"
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))

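The promotion table above (complex dtypes pass through; float32/float64 map to complex64/complex128; integer dtypes map to complex64) matches eager PyTorch behaviour, which is what the check_dtype_function decorator verifies. A quick spot check (float16 -> complex32 is analogous, but eager half-precision FFT support is device-dependent, so it is omitted here):

    import torch

    for src, expected in [(torch.float32, torch.complex64),
                          (torch.float64, torch.complex128),
                          (torch.int32, torch.complex64),
                          (torch.complex64, torch.complex64)]:
        out = torch.fft.ifft(torch.zeros(4, dtype=src))
        assert out.dtype == expected, (src, out.dtype)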
@@ -910,6 +910,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)"
)
emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)")
emit("aten::fft_ifft : (Tensor, int?, int, str?) -> (Tensor)")
emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")
emit(
"aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)"

@@ -2480,3 +2480,51 @@ func.func @test_col2im_strides(%arg0: !torch.vtensor<[1,9,4],f32>, %arg1: !torch
%0 = torch.operator "onnx.Col2Im"(%arg0, %arg1, %arg2) {torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,9,4],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32>
return %0 : !torch.vtensor<[1,1,5,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_dft_fft
func.func @test_dft_fft(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} {
// CHECK-SAME: %[[INPUT_SIGNAL:.*]]: !torch.vtensor<[10,10,1],f32>,
// CHECK-SAME: %[[AXIS:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[SIG_LEN:.*]] = torch.constant.none
// CHECK: %[[DIM:.*]] = torch.aten.item %[[AXIS]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[NORM:.*]] = torch.constant.str "backward"
// CHECK: %[[FILL_VAL:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[ONE:.*]] = torch.constant.int 1
// CHECK: %[[ZERO:.*]] = torch.constant.int 0
// CHECK: %[[PAD_DIM_LIST:.*]] = torch.prim.ListConstruct %[[ZERO]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[MODE:.*]] = torch.constant.str "constant"
// CHECK: %[[INPUT_PADDED:.*]] = torch.aten.pad %[[INPUT_SIGNAL]], %[[PAD_DIM_LIST]], %[[MODE]], %[[FILL_VAL]] : !torch.vtensor<[10,10,1],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[10,10,2],f32>
// CHECK: %[[INPUT_T_CMPLX:.*]] = torch.aten.view_as_complex %[[INPUT_PADDED]] : !torch.vtensor<[10,10,2],f32> -> !torch.vtensor<[10,10],complex<f32>>
// CHECK: %[[FFT_CMPLX:.*]] = torch.aten.fft_fft %[[INPUT_T_CMPLX]], %[[SIG_LEN]], %[[DIM]], %[[NORM]] : !torch.vtensor<[10,10],complex<f32>>, !torch.none, !torch.int, !torch.str -> !torch.vtensor<[10,10],complex<f32>>
// CHECK: %[[FFT_RES_REAL:.*]] = torch.aten.view_as_real %[[FFT_CMPLX]] : !torch.vtensor<[10,10],complex<f32>> -> !torch.vtensor<[10,10,2],f32>
// CHECK: return %[[FFT_RES_REAL]] : !torch.vtensor<[10,10,2],f32>
%none = torch.constant.none
%0 = torch.operator "onnx.DFT"(%arg0, %none, %arg1) : (!torch.vtensor<[10,10,1],f32>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32>
return %0 : !torch.vtensor<[10,10,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_dft_inverse_real
func.func @test_dft_inverse_real(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} {
// CHECK-SAME: %[[INPUT_SIGNAL:.*]]: !torch.vtensor<[10,10,1],f32>,
// CHECK-SAME: %[[AXIS:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[SIG_LEN:.*]] = torch.constant.none
// CHECK: %[[DIM:.*]] = torch.aten.item %[[AXIS]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[NORM:.*]] = torch.constant.str "backward"
// CHECK: %[[FILL_VAL:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[ONE:.*]] = torch.constant.int 1
// CHECK: %[[ZERO:.*]] = torch.constant.int 0
// CHECK: %[[PAD_DIM_LIST:.*]] = torch.prim.ListConstruct %[[ZERO]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[MODE:.*]] = torch.constant.str "constant"
// CHECK: %[[INPUT_PADDED:.*]] = torch.aten.pad %[[INPUT_SIGNAL]], %[[PAD_DIM_LIST]], %[[MODE]], %[[FILL_VAL]] : !torch.vtensor<[10,10,1],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[10,10,2],f32>
// CHECK: %[[INPUT_T_CMPLX:.*]] = torch.aten.view_as_complex %[[INPUT_PADDED]] : !torch.vtensor<[10,10,2],f32> -> !torch.vtensor<[10,10],complex<f32>>
// CHECK: %[[IFFT_CMPLX:.*]] = torch.aten.fft_ifft %[[INPUT_T_CMPLX]], %[[SIG_LEN]], %[[DIM]], %[[NORM]] : !torch.vtensor<[10,10],complex<f32>>, !torch.none, !torch.int, !torch.str -> !torch.vtensor<[10,10],complex<f32>>
// CHECK: %[[IFFT_RES_REAL:.*]] = torch.aten.view_as_real %[[IFFT_CMPLX]] : !torch.vtensor<[10,10],complex<f32>> -> !torch.vtensor<[10,10,2],f32>
// CHECK: return %[[IFFT_RES_REAL]] : !torch.vtensor<[10,10,2],f32>
%none = torch.constant.none
%0 = torch.operator "onnx.DFT"(%arg0, %none, %arg1) {torch.onnx.inverse = 1 : si64} : (!torch.vtensor<[10,10,1],f32>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32>
return %0 : !torch.vtensor<[10,10,2],f32>
}
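
These lit tests drive the new pattern end to end via the RUN line at the top of the test file (outside this hunk); in torch-mlir that is typically along the lines of torch-mlir-opt piped into FileCheck with the -convert-torch-onnx-to-torch pass, with the exact flags given by the file's existing RUN line.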