diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 964b045a9..945b28984 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -8737,6 +8737,30 @@ def Torch_Aten_LinalgDetOp : Torch_Op<"aten._linalg_det", [ }]; } +def Torch_AtenLinalgSlogdetOp : Torch_Op<"aten.linalg_slogdet", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::linalg_slogdet : (Tensor) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$A + ); + let results = (outs + AnyTorchOptionalTensorType:$sign, + AnyTorchOptionalTensorType:$logabsdet + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLinalgSlogdetOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 2); + } + void AtenLinalgSlogdetOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 2); + } + }]; +} + def Torch_AtenFrobeniusNormDimOp : Torch_Op<"aten.frobenius_norm.dim", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 90afe5ee3..96e6b4bd3 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6555,6 +6555,54 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" " return %3 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.linalg_slogdet\"(%arg0: !torch.list) -> !torch.tuple, list> {\n" +" %int-2 = torch.constant.int -2\n" +" %int-1 = torch.constant.int -1\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %9 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %10 = torch.aten.eq.int %9, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %10 : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.eq.int %3, %4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %7 = torch.aten.eq.int %6, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.tuple, list>) {\n" +" %9 = torch.aten.slice.t %arg0, %none, %int1, %int1 : !torch.list, !torch.none, !torch.int, !torch.int -> !torch.list\n" +" %10 = torch.aten.slice.t %arg0, %none, %int1, %int1 : !torch.list, !torch.none, !torch.int, !torch.int -> !torch.list\n" +" %11 = torch.prim.TupleConstruct %9, %10 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" torch.prim.If.yield %11 : !torch.tuple, list>\n" +" } else {\n" +" %9 = torch.derefine %arg0 : !torch.list to !torch.any\n" +" %10 = func.call @__torch__.torch.jit._shape_functions.zero_dim_tensor(%9) : (!torch.any) -> !torch.list\n" +" %11 = torch.prim.TupleConstruct %10, %10 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" torch.prim.If.yield %11 : !torch.tuple, list>\n" +" }\n" +" return %8 : !torch.tuple, list>\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.detach\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -11770,6 +11818,61 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_dtype_fn.aten._softmax_backward_data\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" " return %arg3 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.linalg_slogdet\"(%arg0: !torch.tuple) -> !torch.tuple {\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %true = torch.constant.bool true\n" +" %int8 = torch.constant.int 8\n" +" %false = torch.constant.bool false\n" +" %int15 = torch.constant.int 15\n" +" %int5 = torch.constant.int 5\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %11 = torch.aten.ne.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.eq.int %0#1, %int8 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %11 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" }\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" %8 = torch.aten.eq.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" torch.prim.If.yield %int7 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %7 : !torch.int\n" +" }\n" +" %10 = torch.prim.TupleConstruct %0#1, %9 : !torch.int, !torch.int -> !torch.tuple\n" +" return %10 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.square\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int11 = torch.constant.int 11\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index df044f52f..46b218535 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2904,6 +2904,35 @@ public: }; } // namespace +// decompose aten.linalg_slogdet into: aten.sgn, aten.log, aten.abs +// aten.linalg_det +namespace { + +class DecomposeAtenLinalgSlogdetOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLinalgSlogdetOp op, + PatternRewriter &rewriter) const override { + SmallVector results = op.getResults(); + Location loc = op.getLoc(); + Value input = op.getA(); + Value determinant = rewriter.create( + loc, results[0].getType(), input); + Value sign = + rewriter.create(loc, determinant.getType(), determinant); + Value abs_det = + rewriter.create(loc, determinant.getType(), determinant); + Value ln_abs_det = + rewriter.create(loc, abs_det.getType(), abs_det); + rewriter.replaceAllUsesWith(results[0], sign); + rewriter.replaceAllUsesWith(results[1], ln_abs_det); + rewriter.eraseOp(op); + return success(); + } +}; +} // namespace + namespace { class DecomposeAten_LinalgDetOp : public OpRewritePattern { @@ -9274,6 +9303,7 @@ public: addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal< DecomposeAtenFakeQuantizePerTensorAffineCachemaskOp>(patterns); // More specific conv ops diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 3adb96d1f..31ad13158 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -406,6 +406,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 1b173d3ec..3e9bd5913 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -451,6 +451,9 @@ FX_IMPORTER_XFAIL_SET = { "RsubInt0d_NumToTensor_Module_basic", "ScalarConstantTupleModule_basic", "ScalarImplicitFloatModule_basic", + "SignAndLogarithmOfDeterminantModule_F32", + "SignAndLogarithmOfDeterminantBatchedModule_F32", + "SignAndLogarithmOfDeterminantDynamicModule_F32", "SortIntListReverse_basic", "SortIntList_basic", "SplitDimDynamicModule_basic", @@ -2563,6 +2566,9 @@ ONNX_XFAIL_SET = { "ScatterReduceIntSumModule", "SelectScattertModule_basic", "SelectScattertStaticModule_basic", + "SignAndLogarithmOfDeterminantModule_F32", + "SignAndLogarithmOfDeterminantBatchedModule_F32", + "SignAndLogarithmOfDeterminantDynamicModule_F32", "SliceEndSleStartModule_basic", "SliceOutOfUpperBoundIndexModule_basic", "SliceScatterModule_basic", @@ -3429,6 +3435,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "ScatterValueIntModule_basic", "SelectScattertModule_basic", "SelectScattertStaticModule_basic", + "SignAndLogarithmOfDeterminantModule_F32", + "SignAndLogarithmOfDeterminantBatchedModule_F32", + "SignAndLogarithmOfDeterminantDynamicModule_F32", "SliceCopyEndGreaterThanDimSize_Module_basic", "SliceCopyNegative_Module_basic", "SliceCopyNonZeroDim_Module_basic", @@ -4341,6 +4350,9 @@ ONNX_TOSA_XFAIL_SET = { "SelectIntNegativeDimAndIndexStaticModule_basic", "SelectScattertModule_basic", "SelectScattertStaticModule_basic", + "SignAndLogarithmOfDeterminantModule_F32", + "SignAndLogarithmOfDeterminantBatchedModule_F32", + "SignAndLogarithmOfDeterminantDynamicModule_F32", "SliceCopyEndGreaterThanDimSize_Module_basic", "SliceCopyNegative_Module_basic", "SliceCopyNonZeroDim_Module_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 316851313..fa4ee0a37 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -242,6 +242,14 @@ def aten〇_linalg_det〡shape(A: List[int]) -> Tuple[List[int], List[int], List def aten〇_linalg_det〡dtype(A_rank_dtype: Tuple[int, int]) -> Tuple[int, int, int]: return (A_rank_dtype[1], A_rank_dtype[1], A_rank_dtype[1]) +def aten〇linalg_slogdet〡shape(A: List[int]) -> Tuple[List[int], List[int]]: + assert len(A) == 2 or len(A) == 3 + assert A[-1] == A[-2] + if len(A) == 3: + return A[:1], A[:1] + shape = upstream_shape_functions.zero_dim_tensor(A) + return shape, shape + def aten〇detach〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -3224,6 +3232,18 @@ def aten〇slice〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], dim: int = 0 def aten〇_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int], dim: int, input_dtype: int) -> int: return input_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4,4),], error_types={*all_integer_dtypes(), torch.float16, torch.bfloat16})) +def aten〇linalg_slogdet〡dtype(A_rank_dtype: Tuple[int, int]) -> Tuple[int, int]: + self_rank, self_dtype = A_rank_dtype + assert not is_integer_dtype(self_dtype) + assert self_dtype != torch.float16 and self_dtype != torch.bfloat16 + det_type = self_dtype + if self_dtype == torch.complex32 or self_dtype == torch.complex64: + det_type = torch.float32 + if self_dtype == torch.complex128: + det_type = torch.float64 + return self_dtype, det_type + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇square〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 62ef59d50..4c1767754 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -709,6 +709,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::linalg_qr : (Tensor, str) -> (Tensor, Tensor)") emit("aten::linalg_det : (Tensor) -> (Tensor)") emit("aten::_linalg_det : (Tensor) -> (Tensor, Tensor, Tensor)") + emit("aten::linalg_slogdet : (Tensor) -> (Tensor, Tensor)") emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)") emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)") emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/linalg_algorithms.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/linalg_algorithms.py index 0bb620591..9b7610033 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/linalg_algorithms.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/linalg_algorithms.py @@ -49,3 +49,45 @@ class DeterminantDynamicModule(torch.nn.Module): def DeterminantDynamicModule_F32(module, tu: TestUtils): A = tu.rand(3, 4, 4).to(dtype=torch.float32) module.forward(A) + + +# ============================================================================== + + +class SignAndLogarithmOfDeterminantModule(torch.nn.Module): + @export + @annotate_args([None, [(4, 4), torch.float32, True]]) + def forward(self, A): + return torch.linalg.slogdet(A) + + +@register_test_case(module_factory=lambda: SignAndLogarithmOfDeterminantModule()) +def SignAndLogarithmOfDeterminantModule_F32(module, tu: TestUtils): + A = tu.rand(4, 4).to(dtype=torch.float32) + module.forward(A) + + +class SignAndLogarithmOfDeterminantBatchedModule(torch.nn.Module): + @export + @annotate_args([None, [(3, 4, 4), torch.float32, True]]) + def forward(self, A): + return torch.linalg.slogdet(A) + + +@register_test_case(module_factory=lambda: SignAndLogarithmOfDeterminantBatchedModule()) +def SignAndLogarithmOfDeterminantBatchedModule_F32(module, tu: TestUtils): + A = tu.rand(3, 4, 4).to(dtype=torch.float32) + module.forward(A) + + +class SignAndLogarithmOfDeterminantDynamicModule(torch.nn.Module): + @export + @annotate_args([None, [(-1, -1, -1), torch.float32, True]]) + def forward(self, A): + return torch.linalg.slogdet(A) + + +@register_test_case(module_factory=lambda: SignAndLogarithmOfDeterminantBatchedModule()) +def SignAndLogarithmOfDeterminantDynamicModule_F32(module, tu: TestUtils): + A = tu.rand(3, 4, 4).to(dtype=torch.float32) + module.forward(A)