mirror of https://github.com/llvm/torch-mlir
implement lowering of torch.aten._linalg_slogdet (#3524)
parent
c7d972ed58
commit
2cdf3deae3
|
@ -8737,6 +8737,30 @@ def Torch_Aten_LinalgDetOp : Torch_Op<"aten._linalg_det", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenLinalgSlogdetOp : Torch_Op<"aten.linalg_slogdet", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::linalg_slogdet : (Tensor) -> (Tensor, Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$A
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$sign,
|
||||
AnyTorchOptionalTensorType:$logabsdet
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenLinalgSlogdetOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 2);
|
||||
}
|
||||
void AtenLinalgSlogdetOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 2);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenFrobeniusNormDimOp : Torch_Op<"aten.frobenius_norm.dim", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -6555,6 +6555,54 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
|
||||
" return %3 : !torch.tuple<int, int, int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.linalg_slogdet\"(%arg0: !torch.list<int>) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||
" %int-2 = torch.constant.int -2\n"
|
||||
" %int-1 = torch.constant.int -1\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
" %int3 = torch.constant.int 3\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %9 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %10 = torch.aten.eq.int %9, %int3 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %10 : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %3 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %4 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %5 = torch.aten.eq.int %3, %4 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %5 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %7 = torch.aten.eq.int %6, %int3 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %8 = torch.prim.If %7 -> (!torch.tuple<list<int>, list<int>>) {\n"
|
||||
" %9 = torch.aten.slice.t %arg0, %none, %int1, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
|
||||
" %10 = torch.aten.slice.t %arg0, %none, %int1, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
|
||||
" %11 = torch.prim.TupleConstruct %9, %10 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" torch.prim.If.yield %11 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" } else {\n"
|
||||
" %9 = torch.derefine %arg0 : !torch.list<int> to !torch.any\n"
|
||||
" %10 = func.call @__torch__.torch.jit._shape_functions.zero_dim_tensor(%9) : (!torch.any) -> !torch.list<int>\n"
|
||||
" %11 = torch.prim.TupleConstruct %10, %10 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" torch.prim.If.yield %11 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" return %8 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.detach\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -11770,6 +11818,61 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" func.func @\"__torch_mlir_dtype_fn.aten._softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
|
||||
" return %arg3 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.linalg_slogdet\"(%arg0: !torch.tuple<int, int>) -> !torch.tuple<int, int> {\n"
|
||||
" %int7 = torch.constant.int 7\n"
|
||||
" %int10 = torch.constant.int 10\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %int9 = torch.constant.int 9\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %int8 = torch.constant.int 8\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %int15 = torch.constant.int 15\n"
|
||||
" %int5 = torch.constant.int 5\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %3 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
|
||||
" %11 = torch.aten.ne.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %11 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %4 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %5 = torch.aten.eq.int %0#1, %int8 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %6 = torch.prim.If %5 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %11 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %11 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %7 = torch.prim.If %6 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" %8 = torch.aten.eq.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %9 = torch.prim.If %8 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int7 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %7 : !torch.int\n"
|
||||
" }\n"
|
||||
" %10 = torch.prim.TupleConstruct %0#1, %9 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||
" return %10 : !torch.tuple<int, int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.square\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
|
|
|
@ -2904,6 +2904,35 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// decompose aten.linalg_slogdet into: aten.sgn, aten.log, aten.abs
|
||||
// aten.linalg_det
|
||||
namespace {
|
||||
|
||||
class DecomposeAtenLinalgSlogdetOp
|
||||
: public OpRewritePattern<AtenLinalgSlogdetOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenLinalgSlogdetOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<Value, 2> results = op.getResults();
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.getA();
|
||||
Value determinant = rewriter.create<Torch::AtenLinalgDetOp>(
|
||||
loc, results[0].getType(), input);
|
||||
Value sign =
|
||||
rewriter.create<AtenSgnOp>(loc, determinant.getType(), determinant);
|
||||
Value abs_det =
|
||||
rewriter.create<AtenAbsOp>(loc, determinant.getType(), determinant);
|
||||
Value ln_abs_det =
|
||||
rewriter.create<AtenLogOp>(loc, abs_det.getType(), abs_det);
|
||||
rewriter.replaceAllUsesWith(results[0], sign);
|
||||
rewriter.replaceAllUsesWith(results[1], ln_abs_det);
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
|
||||
class DecomposeAten_LinalgDetOp : public OpRewritePattern<Aten_LinalgDetOp> {
|
||||
|
@ -9274,6 +9303,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenTrilIndicesOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgNormOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_LinalgDetOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgSlogdetOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenFakeQuantizePerTensorAffineCachemaskOp>(patterns);
|
||||
// More specific conv ops
|
||||
|
|
|
@ -406,6 +406,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenRenormOp>();
|
||||
target.addIllegalOp<AtenLinalgCrossOp>();
|
||||
target.addIllegalOp<Aten_LinalgDetOp>();
|
||||
target.addIllegalOp<AtenLinalgSlogdetOp>();
|
||||
target.addIllegalOp<AtenPixelShuffleOp>();
|
||||
target.addIllegalOp<AtenTOp>();
|
||||
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
|
||||
|
|
|
@ -451,6 +451,9 @@ FX_IMPORTER_XFAIL_SET = {
|
|||
"RsubInt0d_NumToTensor_Module_basic",
|
||||
"ScalarConstantTupleModule_basic",
|
||||
"ScalarImplicitFloatModule_basic",
|
||||
"SignAndLogarithmOfDeterminantModule_F32",
|
||||
"SignAndLogarithmOfDeterminantBatchedModule_F32",
|
||||
"SignAndLogarithmOfDeterminantDynamicModule_F32",
|
||||
"SortIntListReverse_basic",
|
||||
"SortIntList_basic",
|
||||
"SplitDimDynamicModule_basic",
|
||||
|
@ -2563,6 +2566,9 @@ ONNX_XFAIL_SET = {
|
|||
"ScatterReduceIntSumModule",
|
||||
"SelectScattertModule_basic",
|
||||
"SelectScattertStaticModule_basic",
|
||||
"SignAndLogarithmOfDeterminantModule_F32",
|
||||
"SignAndLogarithmOfDeterminantBatchedModule_F32",
|
||||
"SignAndLogarithmOfDeterminantDynamicModule_F32",
|
||||
"SliceEndSleStartModule_basic",
|
||||
"SliceOutOfUpperBoundIndexModule_basic",
|
||||
"SliceScatterModule_basic",
|
||||
|
@ -3429,6 +3435,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ScatterValueIntModule_basic",
|
||||
"SelectScattertModule_basic",
|
||||
"SelectScattertStaticModule_basic",
|
||||
"SignAndLogarithmOfDeterminantModule_F32",
|
||||
"SignAndLogarithmOfDeterminantBatchedModule_F32",
|
||||
"SignAndLogarithmOfDeterminantDynamicModule_F32",
|
||||
"SliceCopyEndGreaterThanDimSize_Module_basic",
|
||||
"SliceCopyNegative_Module_basic",
|
||||
"SliceCopyNonZeroDim_Module_basic",
|
||||
|
@ -4341,6 +4350,9 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
||||
"SelectScattertModule_basic",
|
||||
"SelectScattertStaticModule_basic",
|
||||
"SignAndLogarithmOfDeterminantModule_F32",
|
||||
"SignAndLogarithmOfDeterminantBatchedModule_F32",
|
||||
"SignAndLogarithmOfDeterminantDynamicModule_F32",
|
||||
"SliceCopyEndGreaterThanDimSize_Module_basic",
|
||||
"SliceCopyNegative_Module_basic",
|
||||
"SliceCopyNonZeroDim_Module_basic",
|
||||
|
|
|
@ -242,6 +242,14 @@ def aten〇_linalg_det〡shape(A: List[int]) -> Tuple[List[int], List[int], List
|
|||
def aten〇_linalg_det〡dtype(A_rank_dtype: Tuple[int, int]) -> Tuple[int, int, int]:
|
||||
return (A_rank_dtype[1], A_rank_dtype[1], A_rank_dtype[1])
|
||||
|
||||
def aten〇linalg_slogdet〡shape(A: List[int]) -> Tuple[List[int], List[int]]:
|
||||
assert len(A) == 2 or len(A) == 3
|
||||
assert A[-1] == A[-2]
|
||||
if len(A) == 3:
|
||||
return A[:1], A[:1]
|
||||
shape = upstream_shape_functions.zero_dim_tensor(A)
|
||||
return shape, shape
|
||||
|
||||
def aten〇detach〡shape(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
@ -3224,6 +3232,18 @@ def aten〇slice〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], dim: int = 0
|
|||
def aten〇_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int], dim: int, input_dtype: int) -> int:
|
||||
return input_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4,4),], error_types={*all_integer_dtypes(), torch.float16, torch.bfloat16}))
|
||||
def aten〇linalg_slogdet〡dtype(A_rank_dtype: Tuple[int, int]) -> Tuple[int, int]:
|
||||
self_rank, self_dtype = A_rank_dtype
|
||||
assert not is_integer_dtype(self_dtype)
|
||||
assert self_dtype != torch.float16 and self_dtype != torch.bfloat16
|
||||
det_type = self_dtype
|
||||
if self_dtype == torch.complex32 or self_dtype == torch.complex64:
|
||||
det_type = torch.float32
|
||||
if self_dtype == torch.complex128:
|
||||
det_type = torch.float64
|
||||
return self_dtype, det_type
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇square〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
|
|
@ -709,6 +709,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::linalg_qr : (Tensor, str) -> (Tensor, Tensor)")
|
||||
emit("aten::linalg_det : (Tensor) -> (Tensor)")
|
||||
emit("aten::_linalg_det : (Tensor) -> (Tensor, Tensor, Tensor)")
|
||||
emit("aten::linalg_slogdet : (Tensor) -> (Tensor, Tensor)")
|
||||
emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
|
||||
emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")
|
||||
emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)")
|
||||
|
|
|
@ -49,3 +49,45 @@ class DeterminantDynamicModule(torch.nn.Module):
|
|||
def DeterminantDynamicModule_F32(module, tu: TestUtils):
|
||||
A = tu.rand(3, 4, 4).to(dtype=torch.float32)
|
||||
module.forward(A)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SignAndLogarithmOfDeterminantModule(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args([None, [(4, 4), torch.float32, True]])
|
||||
def forward(self, A):
|
||||
return torch.linalg.slogdet(A)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SignAndLogarithmOfDeterminantModule())
|
||||
def SignAndLogarithmOfDeterminantModule_F32(module, tu: TestUtils):
|
||||
A = tu.rand(4, 4).to(dtype=torch.float32)
|
||||
module.forward(A)
|
||||
|
||||
|
||||
class SignAndLogarithmOfDeterminantBatchedModule(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args([None, [(3, 4, 4), torch.float32, True]])
|
||||
def forward(self, A):
|
||||
return torch.linalg.slogdet(A)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SignAndLogarithmOfDeterminantBatchedModule())
|
||||
def SignAndLogarithmOfDeterminantBatchedModule_F32(module, tu: TestUtils):
|
||||
A = tu.rand(3, 4, 4).to(dtype=torch.float32)
|
||||
module.forward(A)
|
||||
|
||||
|
||||
class SignAndLogarithmOfDeterminantDynamicModule(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args([None, [(-1, -1, -1), torch.float32, True]])
|
||||
def forward(self, A):
|
||||
return torch.linalg.slogdet(A)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SignAndLogarithmOfDeterminantBatchedModule())
|
||||
def SignAndLogarithmOfDeterminantDynamicModule_F32(module, tu: TestUtils):
|
||||
A = tu.rand(3, 4, 4).to(dtype=torch.float32)
|
||||
module.forward(A)
|
||||
|
|
Loading…
Reference in New Issue