e2e support: decompose aten.linalg_norm to aten.linalg_vector_norm (#2953)

Add e2e support for `aten.linalg_norm` by decomposing it to
`aten.linalg_vector_norm`.

Decomposing to `aten.linalg_matrix_norm` (the matrix-norm cases) is still unsupported.
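For the vector case handled here, `aten.linalg_norm` with the default `ord=None` over a single dimension is the same reduction as `aten.linalg_vector_norm` with `ord=2`, which is the identity the decomposition relies on. A minimal sketch of that equivalence at the PyTorch level (illustrative only, not part of this patch):

```python
import torch

x = torch.rand(3, 4, 5)

# linalg_norm over a single dim with the default ord=None ...
a = torch.linalg.norm(x, ord=None, dim=0, keepdim=False)
# ... matches linalg_vector_norm with its default ord=2, which is the
# constant the decomposition materializes when ord is None.
b = torch.linalg.vector_norm(x, ord=2, dim=0, keepdim=False)

assert torch.allclose(a, b)
```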

To Test: 

`python -m e2e_testing.main -v`
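To run only the new cases, the harness's test filter can be used, e.g. `python -m e2e_testing.main -v -f "LinalgNorm"` (assuming the `-f`/`--filter` regex option of `e2e_testing.main`).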

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
Ze Zhang 2024-03-05 16:31:01 -08:00 committed by GitHub
parent bc0527676b
commit aa7c9a9653
8 changed files with 185 additions and 0 deletions


@@ -7938,6 +7938,33 @@ def Torch_AtenLinalgVectorNormOp : Torch_Op<"aten.linalg_vector_norm", [
}];
}
def Torch_AtenLinalgNormOp : Torch_Op<"aten.linalg_norm", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchOptionalScalarType:$ord,
    AnyTorchOptionalListOfTorchIntType:$dim,
    Torch_BoolType:$keepdim,
    AnyTorchOptionalIntType:$dtype
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenLinalgNormOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 5, 1);
    }
    void AtenLinalgNormOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 5, 1);
    }
  }];
}
def Torch_AtenLinalgQrOp : Torch_Op<"aten.linalg_qr", [
    AllowsTypeRefinement,
    HasValueSemantics,


@@ -9336,6 +9336,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.linalg_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = torch.derefine %arg4 : !torch.optional<int> to !torch.any\n"
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.frobenius_norm.dim\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.derefine %arg1 : !torch.list<int> to !torch.optional<list<int>>\n"
@@ -12058,6 +12063,60 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.linalg_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<number>, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.__isnot__ %arg4, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
" %5 = torch.prim.unchecked_cast %arg4 : !torch.optional<int> -> !torch.int\n"
" %6 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n"
" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n"
" torch.prim.If %7 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %9 = torch.prim.If %8 -> (!torch.int) {\n"
" %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n"
" torch.prim.If %10 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %11 = torch.prim.TupleConstruct %0#0, %5 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" %12 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%11, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
" torch.prim.If.yield %12 : !torch.int\n"
" } else {\n"
" %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n"
" %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n"
" torch.prim.If %11 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.If.yield %5 : !torch.int\n"
" }\n"
" torch.prim.If.yield %9 : !torch.int\n"
" } else {\n"
" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
" torch.prim.If.yield %5 : !torch.int\n"
" }\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %int5 = torch.constant.int 5\n"


@@ -6971,6 +6971,37 @@ public:
};
} // namespace
namespace {
// Decompose AtenLinalgNormOp to AtenLinalgVectorNormOp only
class DecomposeAtenLinalgNormOp : public OpRewritePattern<AtenLinalgNormOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenLinalgNormOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    SmallVector<Value> dimList;
    if (!getListConstructElements(op.getDim(), dimList)) {
      return rewriter.notifyMatchFailure(
          op, "dim should come from a PrimListConstructOp");
    }
    if (dimList.size() != 1) {
      return rewriter.notifyMatchFailure(
          op, "Unimplemented: only dim size of 1 is supported");
    }
    // Default ord value is 2 for vector_norm.
    auto ord = op.getOrd();
    if (ord.getType().isa<Torch::NoneType>()) {
      ord = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2));
    }
    rewriter.replaceOpWithNewOp<Torch::AtenLinalgVectorNormOp>(
        op, op.getType(), op.getSelf(), ord, op.getDim(), op.getKeepdim(),
        op.getDtype());
    return success();
  }
};
} // namespace
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -7177,6 +7208,7 @@ public:
    addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgNormOp>(patterns);
    // More specific conv ops
    addPatternIfTargetOpIsIllegal<DecomposeAtenConvTbcOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenConv1dOp>(patterns);


@@ -520,6 +520,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
  target.addIllegalOp<AtenTileOp>();
  target.addIllegalOp<AtenReshapeAsOp>();
  target.addIllegalOp<AtenTriuOp>();
  target.addIllegalOp<AtenLinalgNormOp>();
  for (auto &opName : backendLegalOpsSet) {
    target.addLegalOp(
        OperationName(kTorchOpPrefix + opName.first().str(), context));


@@ -1112,6 +1112,7 @@ TOSA_PASS_SET = {
    "LiftFreshCopyModule_basic",
    "LinalgVectorNormKeepDimModule_basic",
    "LinalgVectorNormModule_basic",
    "LinalgNormKeepDimModule_basic",
    "MaskedFillScalarDefaultModule_basic",
    "MaskedFillScalarIntValueModule_basic",
    "MaskedFillScalarIntValueStaticModule_basic",
@@ -1885,6 +1886,8 @@ ONNX_XFAIL_SET = {
    "ScatterReduceIntSumModuleIncludeSelf",
    "TileBigDimsSizeModule_basic",
    "TileSmallDimsSizeModule_basic",
    "LinalgNormKeepDimModule_basic",
    "LinalgNormModule_basic",
    # Failure - onnx_lowering: onnx.AveragePool
    "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",


@@ -1722,6 +1722,9 @@ def aten〇nonzero_static〡shape(self: List[int], size: int, fill_value: int =
def aten〇linalg_vector_norm〡shape(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
    return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)

def aten〇linalg_norm〡shape(self: List[int], ord: Optional[float] = None, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
    return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)

def aten〇frobenius_norm〇dim〡shape(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]:
    return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)
@@ -3938,6 +3941,29 @@ def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Uni
        return dtype
    return aten〇std〡dtype(self_rank_dtype)

@check_dtype_function(
    _check_tensors_with_the_same_dtype(
        num_of_tensors=1,
        error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}) +
    _check_tensors_with_the_same_dtype(
        num_of_tensors=1,
        error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.complex64, torch.complex128}, dtype=torch.float64) +
    _check_tensors_with_the_same_dtype(
        num_of_tensors=1,
        error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16, torch.float16, torch.float32, torch.float64}, dtype=torch.complex128) +
    [ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dtype=torch.int32)])
def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[Union[int, float, complex]] = None, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> int:
    self_rank, self_dtype = self_rank_dtype
    assert not is_integer_dtype(self_dtype)
    if dtype is not None:
        assert not is_integer_dtype(dtype)
        if is_complex_dtype(self_dtype):
            assert is_complex_dtype(dtype)
            return aten〇std〡dtype((self_rank, dtype))
        assert not is_complex_dtype(dtype)
        return dtype
    return aten〇std〡dtype(self_rank_dtype)

@check_dtype_function(
    _check_tensors_with_the_same_dtype(
        num_of_tensors=1,


@@ -542,6 +542,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
    emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
    emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
    emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")
    emit("aten::linalg_qr : (Tensor, str) -> (Tensor, Tensor)")
    emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
    emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")


@@ -1228,6 +1228,42 @@ def LinalgVectorNormKeepDimModule_basic(module, tu: TestUtils):

# ==============================================================================

class LinalgNormModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.linalg_norm(a, ord=None, dim=[0], keepdim=False)

@register_test_case(module_factory=lambda: LinalgNormModule())
def LinalgNormModule_basic(module, tu: TestUtils):
    module.forward(torch.rand(3, 4, 5))

# ==============================================================================

class LinalgNormKeepDimModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.linalg_norm(a, ord=None, dim=[0], keepdim=True)

@register_test_case(module_factory=lambda: LinalgNormKeepDimModule())
def LinalgNormKeepDimModule_basic(module, tu: TestUtils):
    module.forward(torch.rand(3, 4, 5))

# ==============================================================================

class MseLossNoReductionModule(torch.nn.Module):
    def __init__(self):
        super().__init__()