mirror of https://github.com/llvm/torch-mlir
e2e support aten.linalg_norm to aten.linalg_vector_norm (#2953)
Add e2d support for `aten.linalg_norm` by decompose it to `aten.linalg_vector_norm`. Lowering to `aten.linalg_matrix_norm` is still unsupported. To Test: `python -m e2e_testing.main -v` --------- Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>pull/2983/head
parent
bc0527676b
commit
aa7c9a9653
|
@ -7938,6 +7938,33 @@ def Torch_AtenLinalgVectorNormOp : Torch_Op<"aten.linalg_vector_norm", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenLinalgNormOp : Torch_Op<"aten.linalg_norm", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalScalarType:$ord,
|
||||
AnyTorchOptionalListOfTorchIntType:$dim,
|
||||
Torch_BoolType:$keepdim,
|
||||
AnyTorchOptionalIntType:$dtype
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenLinalgNormOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 5, 1);
|
||||
}
|
||||
void AtenLinalgNormOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 5, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenLinalgQrOp : Torch_Op<"aten.linalg_qr", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -9336,6 +9336,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||
" return %1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.linalg_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||
" %0 = torch.derefine %arg4 : !torch.optional<int> to !torch.any\n"
|
||||
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||
" return %1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.frobenius_norm.dim\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %0 = torch.derefine %arg1 : !torch.list<int> to !torch.optional<list<int>>\n"
|
||||
|
@ -12058,6 +12063,60 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.linalg_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<number>, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.int {\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %3 = torch.aten.__isnot__ %arg4, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
|
||||
" %5 = torch.prim.unchecked_cast %arg4 : !torch.optional<int> -> !torch.int\n"
|
||||
" %6 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n"
|
||||
" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %7 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %9 = torch.prim.If %8 -> (!torch.int) {\n"
|
||||
" %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n"
|
||||
" torch.prim.If %10 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %11 = torch.prim.TupleConstruct %0#0, %5 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||
" %12 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%11, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
|
||||
" torch.prim.If.yield %12 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n"
|
||||
" %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %11 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %5 : !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %9 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
|
||||
" torch.prim.If.yield %5 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %int5 = torch.constant.int 5\n"
|
||||
|
|
|
@ -6971,6 +6971,37 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose AtenLinalgNormOp to AtenLinalgVectorNormOp only
|
||||
class DecomposeAtenLinalgNormOp : public OpRewritePattern<AtenLinalgNormOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenLinalgNormOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
SmallVector<Value> dimList;
|
||||
if (!getListConstructElements(op.getDim(), dimList)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "dim should comes from a PrimListConstructOp");
|
||||
}
|
||||
if (dimList.size() != 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented: only dim size of 1 is supported");
|
||||
}
|
||||
|
||||
// default ord value is 2 for vector_norm
|
||||
auto ord = op.getOrd();
|
||||
if (ord.getType().isa<Torch::NoneType>()) {
|
||||
ord = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2));
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenLinalgVectorNormOp>(
|
||||
op, op.getType(), op.getSelf(), ord, op.getDim(), op.getKeepdim(),
|
||||
op.getDtype());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeComplexOpsPass
|
||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||
|
@ -7177,6 +7208,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgNormOp>(patterns);
|
||||
// More specific conv ops
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTbcOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenConv1dOp>(patterns);
|
||||
|
|
|
@ -520,6 +520,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenTileOp>();
|
||||
target.addIllegalOp<AtenReshapeAsOp>();
|
||||
target.addIllegalOp<AtenTriuOp>();
|
||||
target.addIllegalOp<AtenLinalgNormOp>();
|
||||
for (auto &opName : backendLegalOpsSet) {
|
||||
target.addLegalOp(
|
||||
OperationName(kTorchOpPrefix + opName.first().str(), context));
|
||||
|
|
|
@ -1112,6 +1112,7 @@ TOSA_PASS_SET = {
|
|||
"LiftFreshCopyModule_basic",
|
||||
"LinalgVectorNormKeepDimModule_basic",
|
||||
"LinalgVectorNormModule_basic",
|
||||
"LinalgNormKeepDimModule_basic",
|
||||
"MaskedFillScalarDefaultModule_basic",
|
||||
"MaskedFillScalarIntValueModule_basic",
|
||||
"MaskedFillScalarIntValueStaticModule_basic",
|
||||
|
@ -1885,6 +1886,8 @@ ONNX_XFAIL_SET = {
|
|||
"ScatterReduceIntSumModuleIncludeSelf",
|
||||
"TileBigDimsSizeModule_basic",
|
||||
"TileSmallDimsSizeModule_basic",
|
||||
"LinalgNormKeepDimModule_basic",
|
||||
"LinalgNormModule_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.AveragePool
|
||||
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
|
||||
|
|
|
@ -1722,6 +1722,9 @@ def aten〇nonzero_static〡shape(self: List[int], size: int, fill_value: int =
|
|||
def aten〇linalg_vector_norm〡shape(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
|
||||
|
||||
def aten〇linalg_norm〡shape(self: List[int], ord: Optional[float] = None, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
|
||||
|
||||
def aten〇frobenius_norm〇dim〡shape(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)
|
||||
|
||||
|
@ -3938,6 +3941,29 @@ def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Uni
|
|||
return dtype
|
||||
return aten〇std〡dtype(self_rank_dtype)
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(
|
||||
num_of_tensors=1,
|
||||
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}) +
|
||||
_check_tensors_with_the_same_dtype(
|
||||
num_of_tensors=1,
|
||||
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.complex64, torch.complex128}, dtype=torch.float64) +
|
||||
_check_tensors_with_the_same_dtype(
|
||||
num_of_tensors=1,
|
||||
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16, torch.float16, torch.float32, torch.float64}, dtype=torch.complex128) +
|
||||
[ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dtype=torch.int32)])
|
||||
def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[Union[int, float, complex]] = None, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
assert not is_integer_dtype(self_dtype)
|
||||
if dtype is not None:
|
||||
assert not is_integer_dtype(dtype)
|
||||
if is_complex_dtype(self_dtype):
|
||||
assert is_complex_dtype(dtype)
|
||||
return aten〇std〡dtype((self_rank, dtype))
|
||||
assert not is_complex_dtype(dtype)
|
||||
return dtype
|
||||
return aten〇std〡dtype(self_rank_dtype)
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(
|
||||
num_of_tensors=1,
|
||||
|
|
|
@ -542,6 +542,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
|
||||
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
|
||||
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
|
||||
emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")
|
||||
emit("aten::linalg_qr : (Tensor, str) -> (Tensor, Tensor)")
|
||||
emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
|
||||
emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")
|
||||
|
|
|
@ -1228,6 +1228,42 @@ def LinalgVectorNormKeepDimModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class LinalgNormModule(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.linalg_norm(a, ord=None, dim=[0], keepdim=False)
|
||||
|
||||
@register_test_case(module_factory=lambda: LinalgNormModule())
|
||||
def LinalgNormModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class LinalgNormKeepDimModule(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.linalg_norm(a, ord=None, dim=[0], keepdim=True)
|
||||
|
||||
@register_test_case(module_factory=lambda: LinalgNormKeepDimModule())
|
||||
def LinalgNormKeepDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MseLossNoReductionModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue