mirror of https://github.com/llvm/torch-mlir
onnx.DFT basic support (#3463)
- adds support for DFT v20 on the FFT and IFFT path - adds required skeleton code for IFFT ops to be recognised in TMlirpull/3519/head
parent
7e6d76e997
commit
5a627c46b7
|
@ -12418,6 +12418,32 @@ def Torch_AtenFftFftOp : Torch_Op<"aten.fft_fft", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenFftIfftOp : Torch_Op<"aten.fft_ifft", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::fft_ifft : (Tensor, int?, int, str?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalIntType:$n,
|
||||
Torch_IntType:$dim,
|
||||
AnyTorchOptionalStringType:$norm
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenFftIfftOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenFftIfftOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenFmodTensorOp : Torch_Op<"aten.fmod.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -2728,4 +2728,95 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
|
||||
return success();
|
||||
});
|
||||
|
||||
patterns.onOp(
|
||||
"DFT", 20, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Value inTensor, dftLength, axis;
|
||||
Torch::ValueTensorType resultType;
|
||||
int64_t inverse, onesided;
|
||||
if (binder.tensorOperandAtIndex(inTensor, 0) ||
|
||||
binder.s64IntegerAttr(inverse, "inverse", 0) ||
|
||||
binder.s64IntegerAttr(onesided, "onesided", 0) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Input Tensor / attrs / resultType bind failed");
|
||||
if (!binder.tensorOperandAtIndex(dftLength, 1)) {
|
||||
// Convert to int and pass as n
|
||||
dftLength = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), dftLength);
|
||||
} else {
|
||||
// Default for torch is None
|
||||
dftLength = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
}
|
||||
// Default is same for onnx and torch
|
||||
if (!binder.tensorOperandAtIndex(axis, 2)) {
|
||||
// convert to int and pass to dims
|
||||
axis = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), axis);
|
||||
} else {
|
||||
// Default in torch is -1 and onnx is -2 (since -1 is for real / img)
|
||||
axis = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(-2));
|
||||
}
|
||||
|
||||
if (onesided == 1)
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"Unsupported option : onesided");
|
||||
// norm default string attr
|
||||
Value norm = rewriter.create<Torch::ConstantStrOp>(
|
||||
binder.getLoc(), rewriter.getStringAttr(Twine("backward")));
|
||||
// Convert from [....., 2] complex number repr for fft consumption.
|
||||
Torch::ValueTensorType inType =
|
||||
binder.toValidTensorType(inTensor.getType());
|
||||
int64_t lastIndex = inType.getSizes().back();
|
||||
if (lastIndex != 1 && lastIndex != 2)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"Expected input tensor to have dims [..., 1] or [..., 2]");
|
||||
|
||||
// concat with zeros to make it [..., 2]
|
||||
Value inForComplexVal = inTensor;
|
||||
ArrayRef<int64_t> inForComplexSizes = inType.getSizes().drop_back();
|
||||
if (lastIndex == 1) {
|
||||
Value constZeroVal = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(), rewriter.getF64FloatAttr(0));
|
||||
Value constOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(1));
|
||||
Value constZero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(0));
|
||||
Value padSizeList =
|
||||
rewriter
|
||||
.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(rewriter.getType<Torch::IntType>()),
|
||||
SmallVector<Value>({constZero, constOne}))
|
||||
.getResult();
|
||||
Value modeVal = rewriter.create<Torch::ConstantStrOp>(
|
||||
binder.getLoc(), rewriter.getStringAttr("constant"));
|
||||
SmallVector<int64_t> resSize(inForComplexSizes);
|
||||
resSize.push_back(2);
|
||||
inForComplexVal = rewriter.create<Torch::AtenPadOp>(
|
||||
binder.getLoc(),
|
||||
inType.getWithSizesAndDtype(resSize, inType.getOptionalDtype()),
|
||||
inTensor, padSizeList, modeVal, constZeroVal);
|
||||
}
|
||||
Type inComplexTensorType = Torch::ValueTensorType::get(
|
||||
binder.op->getContext(), inForComplexSizes,
|
||||
mlir::ComplexType::get(inType.getDtype()));
|
||||
Value inComplexTensor = rewriter.create<Torch::AtenViewAsComplexOp>(
|
||||
binder.getLoc(), inComplexTensorType, inForComplexVal);
|
||||
Value ftOp;
|
||||
if (inverse == 0) {
|
||||
ftOp = rewriter.create<Torch::AtenFftFftOp>(
|
||||
binder.getLoc(), inComplexTensorType, inComplexTensor,
|
||||
/*n = */ dftLength, /*dim = */ axis, /*norm = */ norm);
|
||||
} else {
|
||||
ftOp = rewriter.create<Torch::AtenFftIfftOp>(
|
||||
binder.getLoc(), inComplexTensorType, inComplexTensor,
|
||||
/*n = */ dftLength, /*dim = */ axis, /*norm = */ norm);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenViewAsRealOp>(binder.op,
|
||||
resultType, ftOp);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
|
|
@ -10369,6 +10369,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %14 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.fft_ifft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.int) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
|
||||
" %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list<int>\n"
|
||||
|
@ -11984,6 +11987,50 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %6 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.fft_ifft\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n"
|
||||
" %int10 = torch.constant.int 10\n"
|
||||
" %int7 = torch.constant.int 7\n"
|
||||
" %int9 = torch.constant.int 9\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %int8 = torch.constant.int 8\n"
|
||||
" %int5 = torch.constant.int 5\n"
|
||||
" %0 = torch.prim.Uninitialized : !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %3 = torch.prim.If %2 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %1#1 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %4 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %5 = torch.prim.If %4 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int8 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %6 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %7 = torch.prim.If %6 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int9 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %8 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %9 = torch.prim.If %8 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int10 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %11 = torch.prim.If %10 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int9 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %11 : !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %9 : !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %7 : !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %5 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %3 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
|
|
|
@ -2038,6 +2038,9 @@ def aten〇stft〡shape(self: List[int], n_fft: int, hop_length: Optional[int] =
|
|||
|
||||
return out
|
||||
|
||||
def aten〇fft_ifft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]:
|
||||
return self
|
||||
|
||||
class DummyClassType:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
@ -3406,6 +3409,23 @@ def aten〇stft〡dtype(self_rank_dtype: Tuple[int, int], n_fft: int, hop_length
|
|||
|
||||
assert False, "Unsupported dtype"
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bfloat16}))
|
||||
def aten〇fft_ifft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
if is_complex_dtype(self_dtype):
|
||||
return self_dtype
|
||||
elif self_dtype == torch.float16:
|
||||
return torch.complex32
|
||||
elif self_dtype == torch.float32:
|
||||
return torch.complex64
|
||||
elif self_dtype == torch.float64:
|
||||
return torch.complex128
|
||||
elif is_integer_dtype(self_dtype):
|
||||
return torch.complex64
|
||||
else:
|
||||
assert False, "Unsupported dtype"
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
|
||||
|
|
|
@ -910,6 +910,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)"
|
||||
)
|
||||
emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)")
|
||||
emit("aten::fft_ifft : (Tensor, int?, int, str?) -> (Tensor)")
|
||||
emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||
emit(
|
||||
"aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)"
|
||||
|
|
|
@ -2480,3 +2480,51 @@ func.func @test_col2im_strides(%arg0: !torch.vtensor<[1,9,4],f32>, %arg1: !torch
|
|||
%0 = torch.operator "onnx.Col2Im"(%arg0, %arg1, %arg2) {torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,9,4],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32>
|
||||
return %0 : !torch.vtensor<[1,1,5,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_dft_fft
|
||||
func.func @test_dft_fft(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-SAME: %[[INPUT_SIGNAL:.*]]: !torch.vtensor<[10,10,1],f32>,
|
||||
// CHECK-SAME: %[[AXIS:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[SIG_LEN:.*]] = torch.constant.none
|
||||
// CHECK: %[[DIM:.*]] = torch.aten.item %[[AXIS]] : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[NORM:.*]] = torch.constant.str "backward"
|
||||
// CHECK: %[[FILL_VAL:.*]] = torch.constant.float 0.000000e+00
|
||||
// CHECK: %[[ONE:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[ZERO:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[PAD_DIM_LIST:.*]] = torch.prim.ListConstruct %[[ZERO]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[MODE:.*]] = torch.constant.str "constant"
|
||||
// CHECK: %[[INPUT_PADDED:.*]] = torch.aten.pad %[[INPUT_SIGNAL]], %[[PAD_DIM_LIST]], %[[MODE]], %[[FILL_VAL]] : !torch.vtensor<[10,10,1],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[10,10,2],f32>
|
||||
// CHECK: %[[INPUT_T_CMPLX:.*]] = torch.aten.view_as_complex %[[INPUT_PADDED]] : !torch.vtensor<[10,10,2],f32> -> !torch.vtensor<[10,10],complex<f32>>
|
||||
// CHECK: %[[FFT_CMPLX:.*]] = torch.aten.fft_fft %[[INPUT_T_CMPLX]], %[[SIG_LEN]], %[[DIM]], %[[NORM]] : !torch.vtensor<[10,10],complex<f32>>, !torch.none, !torch.int, !torch.str -> !torch.vtensor<[10,10],complex<f32>>
|
||||
// CHECK: %[[FFT_RES_REAL:.*]] = torch.aten.view_as_real %[[FFT_CMPLX]] : !torch.vtensor<[10,10],complex<f32>> -> !torch.vtensor<[10,10,2],f32>
|
||||
// CHECK: return %[[FFT_RES_REAL]] : !torch.vtensor<[10,10,2],f32>
|
||||
%none = torch.constant.none
|
||||
%0 = torch.operator "onnx.DFT"(%arg0, %none, %arg1) : (!torch.vtensor<[10,10,1],f32>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32>
|
||||
return %0 : !torch.vtensor<[10,10,2],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @test_dft_inverse_real
|
||||
func.func @test_dft_inverse_real(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-SAME: %[[INPUT_SIGNAL:.*]]: !torch.vtensor<[10,10,1],f32>,
|
||||
// CHECK-SAME: %[[AXIS:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[SIG_LEN:.*]] = torch.constant.none
|
||||
// CHECK: %[[DIM:.*]] = torch.aten.item %[[AXIS]] : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[NORM:.*]] = torch.constant.str "backward"
|
||||
// CHECK: %[[FILL_VAL:.*]] = torch.constant.float 0.000000e+00
|
||||
// CHECK: %[[ONE:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[ZERO:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[PAD_DIM_LIST:.*]] = torch.prim.ListConstruct %[[ZERO]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[MODE:.*]] = torch.constant.str "constant"
|
||||
// CHECK: %[[INPUT_PADDED:.*]] = torch.aten.pad %[[INPUT_SIGNAL]], %[[PAD_DIM_LIST]], %[[MODE]], %[[FILL_VAL]] : !torch.vtensor<[10,10,1],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[10,10,2],f32>
|
||||
// CHECK: %[[INPUT_T_CMPLX:.*]] = torch.aten.view_as_complex %[[INPUT_PADDED]] : !torch.vtensor<[10,10,2],f32> -> !torch.vtensor<[10,10],complex<f32>>
|
||||
// CHECK: %[[IFFT_CMPLX:.*]] = torch.aten.fft_ifft %[[INPUT_T_CMPLX]], %[[SIG_LEN]], %[[DIM]], %[[NORM]] : !torch.vtensor<[10,10],complex<f32>>, !torch.none, !torch.int, !torch.str -> !torch.vtensor<[10,10],complex<f32>>
|
||||
// CHECK: %[[IFFT_RES_REAL:.*]] = torch.aten.view_as_real %[[IFFT_CMPLX]] : !torch.vtensor<[10,10],complex<f32>> -> !torch.vtensor<[10,10,2],f32>
|
||||
// CHECK: return %[[IFFT_RES_REAL]] : !torch.vtensor<[10,10,2],f32>
|
||||
%none = torch.constant.none
|
||||
%0 = torch.operator "onnx.DFT"(%arg0, %none, %arg1) {torch.onnx.inverse = 1 : si64} : (!torch.vtensor<[10,10,1],f32>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32>
|
||||
return %0 : !torch.vtensor<[10,10,2],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue