From c5a1da1910f8e1a5dac748eb2806833bd4f1b0c2 Mon Sep 17 00:00:00 2001
From: ptrifunovic98 <156185835+ptrifunovic98@users.noreply.github.com>
Date: Mon, 26 Feb 2024 17:46:56 +0100
Subject: [PATCH] Implement lowering of torch.aten.norm.Scalar (#2899)

Closes [nod-ai/SHARK-Turbine#365](https://github.com/nod-ai/SHARK-Turbine/issues/365)
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 25 +++++++++
 lib/Conversion/TorchToLinalg/Reduction.cpp    | 53 ++++++++++++++++---
 lib/Dialect/Torch/IR/TorchOps.cpp             | 36 +++++++++++++
 .../Transforms/AbstractInterpLibrary.cpp      | 32 +++++++++++
 projects/pt1/e2e_testing/xfail_sets.py        |  1 +
 .../build_tools/abstract_interp_lib_gen.py    | 18 +++++++
 .../build_tools/torch_ods_gen.py              |  1 +
 .../test_suite/reduction.py                   | 19 +++++++
 8 files changed, 177 insertions(+), 8 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index cc8be7c69..dc1203de9 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6325,6 +6325,31 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
   }];
 }
 
+def Torch_AtenNormScalarOp : Torch_Op<"aten.norm.Scalar", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$p
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNormScalarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenNormScalarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 def Torch_AtenNormScalarOptDimOp : Torch_Op<"aten.norm.ScalarOpt_dim", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp
index a21615ad8..e05076499 100644
--- a/lib/Conversion/TorchToLinalg/Reduction.cpp
+++ b/lib/Conversion/TorchToLinalg/Reduction.cpp
@@ -275,7 +275,8 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
                             elementType.getIntOrFloatBitWidth())));
   }
 
-  if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op))
+  if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) ||
+      isa<AtenNormScalarOp>(op))
     return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
 
   if (isa(op)) {
@@ -341,6 +342,26 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
       if (intType.isSigned())
         return b.create<arith::MinSIOp>(loc, self, result);
     }
+  } else if (isa<AtenNormScalarOp>(op)) {
+    // This creates payload for only the first of the two linalg.generic ops.
+    // TODO: Short-circuit operations if `p` is zero or one.
+    Value elem = payloadArgs[0];
+    Value result = payloadArgs[1];
+
+    // TODO: Fix this part to support complex elements.
+    if (elem.getType().isa<mlir::ComplexType>()) {
+      op->emitError("lowering of complex input type for torch.aten.norm.Scalar "
+                    "is currently unimplemented");
+      return nullptr;
+    }
+
+    Value self = convertScalarToDtype(b, loc, elem, resultElementType);
+
+    auto abs = b.create<math::AbsFOp>(loc, self);
+    AtenNormScalarOp::Adaptor adaptor(operands);
+    Value p = convertScalarToDtype(b, loc, adaptor.getP(), resultElementType);
+    auto pow = b.create<math::PowFOp>(loc, abs, p);
+    return b.create<arith::AddFOp>(loc, pow, result);
   } else if (isa<AtenLinalgVectorNormOp>(op)) {
     // This creates payload for only the first of the two linalg.generic ops.
    // TODO: Short-circuit operations if `ord` is zero or one.
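
Note on the lowering above: the payload built for `AtenNormScalarOp` takes the absolute
value of each element, raises it to the power `p`, and accumulates with an add, so the
first linalg.generic produces sum(|x|^p); the second linalg.generic, created by the
templated helper in the following hunks, applies the final 1/p power. As a rough sanity
check of that decomposition (illustrative only, not part of the patch), the same
computation can be written directly against PyTorch:

    import torch

    def norm_scalar_reference(x: torch.Tensor, p: float) -> torch.Tensor:
        # First reduction: abs, raise to p, accumulate by addition.
        acc = x.abs().pow(p).sum()
        # Second reduction: raise the accumulated sum to 1/p.
        return acc.pow(1.0 / p)

    x = torch.rand(3, 4, 5)
    for p in (1.0, 2.0, 3.0):
        assert torch.allclose(torch.ops.aten.norm(x, p), norm_scalar_reference(x, p))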
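
The op definition above also sets `hasVerifier = 1`; the verifier and dtype function added
later in this patch restrict `aten.norm.Scalar` to float and complex input dtypes and map
complex inputs to the corresponding real result dtype. For orientation, eager-mode PyTorch
behaves the same way (illustrative check only, assuming a reasonably recent PyTorch
release):

    import torch

    # Float inputs keep their dtype.
    assert torch.ops.aten.norm(torch.rand(4, dtype=torch.float32), 2).dtype == torch.float32
    # Complex inputs produce the corresponding real dtype (complex64 -> float32).
    assert torch.ops.aten.norm(torch.rand(4, dtype=torch.complex64), 2).dtype == torch.float32
    # Integer inputs are rejected, which is what the new verifier and dtype
    # function encode on the torch-mlir side.
    try:
        torch.ops.aten.norm(torch.arange(4), 2)
    except RuntimeError as e:
        print("integer input rejected:", e)
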
@@ -433,7 +454,7 @@ private:
                          ConversionPatternRewriter &rewriter) const {
     auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
 
-    if (isa(op)) {
+    if (isa(op)) {
       opInfo.tensorOperand = operands[0];
       auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
 
@@ -484,10 +505,12 @@ private:
     return err ? Value{} : powOp;
   }
 
-  FailureOr<Value> createSecondReductionForVectorNormOp(
-      Location loc, Type elemType, AtenLinalgVectorNormOp op, Value ordOp,
-      Value firstReduction, const torch_to_linalg::ReductionOpInfo &opInfo,
-      ConversionPatternRewriter &rewriter) const {
+  template <typename TOp>
+  FailureOr<Value>
+  createSecondReductionForNormOp(Location loc, Type elemType, TOp op,
+                                 Value ordOp, Value firstReduction,
+                                 const torch_to_linalg::ReductionOpInfo &opInfo,
+                                 ConversionPatternRewriter &rewriter) const {
     // Cast `ord` to float so that we can readily pass it math.powf.
     Value ordValue = convertScalarToDtype(rewriter, loc, ordOp, elemType);
 
@@ -544,13 +567,15 @@ private:
   LogicalResult
   validateReductionElementType(Operation *op, Type elemType,
                                ConversionPatternRewriter &rewriter) const {
-    if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op)) &&
+    if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) ||
+         isa<AtenNormScalarOp>(op)) &&
         !elemType.isa<mlir::FloatType>())
       return rewriter.notifyMatchFailure(
           op, "only float types are valid for vector norm ops");
     if (isa(op) && elemType.isa<mlir::IntegerType>() &&
         elemType.getIntOrFloatBitWidth() == 8)
       return rewriter.notifyMatchFailure(op, "uint8 is not supported");
 
+    // No checks for all other reduction operations
     return success();
   }
 
@@ -587,11 +612,22 @@ public:
       return rewriter.notifyMatchFailure(
           op, "failed to create linalg.generic operation for reduction");
 
+    // If this is aten.norm.Scalar op, then we need to generate another
+    // linalg.generic op that references the first linalg.generic op.
+    if (isa<AtenNormScalarOp>(op)) {
+      AtenNormScalarOp::Adaptor adaptor(operands);
+      FailureOr<Value> secondReduceOp = createSecondReductionForNormOp(
+          loc, elemType, op, adaptor.getP(), reduceOp, *opInfo, rewriter);
+      if (failed(secondReduceOp))
+        return secondReduceOp;
+      reduceOp = *secondReduceOp;
+    }
+
     // If this is aten.linalg_vector_norm op, then we need to generate another
     // linalg.generic op that references the first linalg.generic op.
     if (auto normOp = dyn_cast<AtenLinalgVectorNormOp>(op)) {
       AtenLinalgVectorNormOp::Adaptor adaptor(operands);
-      FailureOr<Value> secondReduceOp = createSecondReductionForVectorNormOp(
+      FailureOr<Value> secondReduceOp = createSecondReductionForNormOp(
           loc, elemType, normOp, adaptor.getOrd(), reduceOp, *opInfo, rewriter);
       if (failed(secondReduceOp))
         return secondReduceOp;
       reduceOp = *secondReduceOp;
@@ -627,6 +663,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenNormScalarOp>();
   target.addIllegalOp<AtenLinalgVectorNormOp>();
   target.addIllegalOp<AtenFrobeniusNormDimOp>();
   patterns.add<ConvertReductionOp>(typeConverter, context);
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index da6f71015..ef3098eb1 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -3767,6 +3767,42 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AtenNormScalarOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtenNormScalarOp::verify() {
+
+  // Verification of input type for torch.aten.norm.Scalar.
+  // Per PyTorch docs, only float and complex types are valid for norm
+  // operation.
+
+  auto inTensor = getSelf().getType().cast<BaseTensorType>();
+
+  // If no dtype is specified, it will default to a float one.
+  if (!inTensor.hasDtype()) {
+    return success();
+  }
+
+  auto inTensorDtype = inTensor.getDtype();
+
+  // Check if dtype is one of those supported by norm operation.
+  // ComplexType will match any torch complex types, but each float must be
+  // checked individually.
+  if (!inTensorDtype.isa<mlir::ComplexType, mlir::Float16Type,
+                         mlir::Float32Type, mlir::Float64Type>()) {
+    return emitOpError(
+               "expected a float or complex type for input tensor, but got ")
+           << inTensorDtype;
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AtenPermuteOp
+//===----------------------------------------------------------------------===//
+
 LogicalResult AtenPermuteOp::verify() {
 
   // Verification of the permute op for input & output dimensions with
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index bfc2fc6a1..a8327b0e0 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9339,6 +9339,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
 "    return %2 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.norm.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
+"    %false = torch.constant.bool false\n"
+"    %none = torch.constant.none\n"
+"    %0 = torch.derefine %none : !torch.none to !torch.optional<list<int>>\n"
+"    %1 = torch.derefine %none : !torch.none to !torch.any\n"
+"    %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %false, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.norm.ScalarOpt_dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
 "    %int0 = torch.constant.int 0\n"
 "    %0 = torch.derefine %arg2 : !torch.list<int> to !torch.optional<list<int>>\n"
@@ -12038,6 +12046,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
+"    %true = torch.constant.bool true\n"
+"    %int5 = torch.constant.int 5\n"
+"    %int8 = torch.constant.int 8\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = torch.aten.eq.int %0#1, %int8 : !torch.int, !torch.int -> !torch.bool\n"
+"    %4 = torch.prim.If %3 -> (!torch.int) {\n"
+"      torch.prim.If.yield %int5 : !torch.int\n"
+"    } else {\n"
+"      %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
+"      torch.prim.If.yield %5 : !torch.int\n"
+"    }\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.tensor.float\"(%arg0: !torch.float, %arg1: !torch.optional<int>, %arg2: !torch.optional<Device>, %arg3: !torch.bool) -> !torch.int {\n"
 "    %int6 = torch.constant.int 6\n"
 "    %none = torch.constant.none\n"
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index e749b5834..70f26fe42 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1667,6 +1667,7 @@ ONNX_XFAIL_SET = {
     "NllLossModule_ignore_index_out_of_bounds_basic",
     "NllLossModule_mean_basic",
     "NllLossModule_sum_basic",
+    "NormScalarModule_basic",
     "NormScalarOptDimKeepDimModule_basic",
     "NormScalarOptDimModule_basic",
     "NormalFunctionalModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 403d124ad..99f4f2200 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1722,6 +1722,9 @@ def aten〇linalg_vector_norm〡shape(self: List[int], ord: float = 2, dim: Opti
 def aten〇frobenius_norm〇dim〡shape(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]:
     return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)
 
+def aten〇norm〇Scalar〡shape(self: List[int], p: float = 2) -> List[int]:
+    return upstream_shape_functions.sum_mean_dim(self, None, False, None)
+
 def aten〇norm〇ScalarOpt_dim〡shape(self: List[int], p: Optional[float], dim: List[int], keepdim: bool = False) -> List[int]:
     return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)
 
@@ -3924,6 +3927,21 @@ def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Uni
         return dtype
     return aten〇std〡dtype(self_rank_dtype)
 
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(
+        num_of_tensors=1,
+        error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}))
+def aten〇norm〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int, float, complex] = 2) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert not is_integer_dtype(self_dtype)
+    # The following check is added because aten〇std〡dtype
+    # does not handle complex32 transformation to float,
+    # so it is done manually (torch.half == torch.float16).
+    # Should possibly be added to aten〇std〡dtype.
+    if self_dtype == torch.complex32:
+        return torch.half
+    return aten〇std〡dtype(self_rank_dtype)
+
 @check_dtype_function([Invocation(0.0),
                        Invocation(0.0, dtype=torch.int32),
                        Invocation(0.0, dtype=torch.float16),
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 51c196421..cc41a99be 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -449,6 +449,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit(
         "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
     )
+    emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True)
     emit(
         "aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)"
     )
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
index 2c61524bd..d0d6c2ea2 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -1100,6 +1100,25 @@ def ReduceL3NormKeepDimModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class NormScalarModule(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.p = 3.0
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.norm(a, self.p)
+
+@register_test_case(module_factory=lambda: NormScalarModule())
+def NormScalarModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
 class NormScalarOptDimModule(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()