From 676fa8cc09771cab3f0844577304bcb3a5e90377 Mon Sep 17 00:00:00 2001 From: Branko Trifkovic <88882867+BaneTrifa@users.noreply.github.com> Date: Mon, 17 Jun 2024 19:40:57 +0200 Subject: [PATCH] Implement lowering of torch.aten.renorm (#3388) Closes [nod-ai/SHARK-Turbine/issues/689](https://github.com/nod-ai/SHARK-Turbine/issues/689) --------- Co-authored-by: Branko Trifkovic --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 27 ++++ lib/Dialect/Torch/IR/TorchOps.cpp | 74 +++++++++ .../Transforms/AbstractInterpLibrary.cpp | 17 +++ .../Torch/Transforms/DecomposeComplexOps.cpp | 140 ++++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 12 ++ .../build_tools/abstract_interp_lib_gen.py | 17 +++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/norm_like.py | 93 ++++++++++++ 9 files changed, 382 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 90e497117..550f9c47c 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6657,6 +6657,33 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [ }]; } +def Torch_AtenRenormOp : Torch_Op<"aten.renorm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::renorm : (Tensor, Scalar, int, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$p, + Torch_IntType:$dim, + AnyTorchScalarType:$maxnorm + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRenormOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenRenormOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; + let hasVerifier = 1; +} + def Torch_AtenNormScalarOp : Torch_Op<"aten.norm.Scalar", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 500a861da..b0bb55511 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -4655,6 +4655,80 @@ LogicalResult AtenNormScalarOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// AtenRenormOp +//===----------------------------------------------------------------------===// + +LogicalResult AtenRenormOp::verify() { + + auto selfType = cast(getSelf().getType()); + + if (!selfType.hasDtype() || !selfType.hasSizes()) + return success(); + + auto inShape = selfType.getSizes(); + int64_t selfRank = inShape.size(); + auto selfDtype = selfType.getDtype(); + + if (!isa(selfDtype)) + return emitOpError( + "expected a float or complex type for input tensor, but got ") + << selfDtype; + + // According to the Pytoch documentation tensor need to be at least rank 2 + if (selfRank <= 1) + return emitOpError("renorm: input needs at least 2 dimensions, got ") + << selfRank << " dimensions"; + + // Check if argument p is valid + auto pType = getP().getType(); + + if (isa(pType)) + return emitOpError("renorm: p must be real-valued"); + + // The argument 'p' can be either an integer or a floating-point number, + // so we need to consider both options and check if 'p' is within the correct + // range + int64_t pInt = 1; + double_t pDouble = 1; + if (!matchPattern(getP(), 
m_TorchConstantInt(&pInt)) && + !matchPattern(getP(), m_TorchConstantFloat(&pDouble))) + return success(); + + if (pInt <= 0 || pDouble <= 0) + return emitOpError("renorm: non-positive norm not supported"); + + // Check if argument maxnorm is valid + auto maxnormType = getMaxnorm().getType(); + if (isa(maxnormType)) + return emitOpError("renorm: maxnorm must be real-valued"); + + // The argument 'maxnorm' can be either an integer or a floating-point number, + // so we need to consider both options and check if 'maxnorm' is within the + // correct range + int64_t maxnormInt = 0; + double_t maxnormDouble = 0; + if (!matchPattern(getMaxnorm(), m_TorchConstantInt(&maxnormInt)) && + !matchPattern(getMaxnorm(), m_TorchConstantFloat(&maxnormDouble))) + return success(); + + if (maxnormInt < 0 || maxnormDouble < 0) + return emitOpError("renorm: expected maxnorm to be >= 0"); + + // Get the dimension + int64_t dim; + if (!matchPattern(getDim(), m_TorchConstantInt(&dim))) + return success(); + + // check if is dim is in the correct range + if (dim >= selfRank || dim < -selfRank) + return emitOpError("Dimension out of range (expected to be in range of [") + << -selfRank << ", " << selfRank - 1 << "], but got " << dim; + + return success(); +} + //===----------------------------------------------------------------------===// // AtenPermuteOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index c587fd9f9..71767fe14 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -10119,6 +10119,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %2 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.renorm\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.float) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.norm.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" " %false = torch.constant.bool false\n" " %none = torch.constant.none\n" @@ -13162,6 +13165,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.renorm\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.number) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %true = torch.constant.bool true\n" " %int5 = torch.constant.int 5\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp 
index bc3ba0c07..0c1584160 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2069,6 +2069,145 @@ public: }; } // namespace +// https://github.com/pytorch/pytorch/blob/9dec41b684a4284c4e052e295314c23f0f942fec/torch/_refs/__init__.py#L3229 +// Decompose aten.renorm into: linalg_vector_norm +namespace { +class DecomposeAtenRenormOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRenormOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value dim = op.getDim(); + Value p = op.getP(); + Value maxnorm = op.getMaxnorm(); + + // Prepare all necessary variables + auto ndim = getTensorRank(self); + auto resType = cast(self.getType()); + + if (!resType.hasDtype() || !resType.hasSizes()) { + return rewriter.notifyMatchFailure(op, + "result should have dtype and sizes"); + } + + Type dtype = resType.getDtype(); + if (isa(dtype)) { + return rewriter.notifyMatchFailure( + op, "lowering of aten.renorm for complex inputs dtype is " + "currently unimplemented"); + } + + SmallVector inputSize(resType.getSizes()); + + // Convert dim from Value to int + int64_t dimInt; + if (!matchPattern(dim, m_TorchConstantInt(&dimInt))) + return rewriter.notifyMatchFailure(op, + "Unimplemented: dim not constant int"); + + // Define all constants + Value cstTrue = rewriter.create(loc, true); + Value cstZero = rewriter.create(loc, 0); + Value cstOne = rewriter.create(loc, 1); + Value cstNone = rewriter.create(loc); + + // Arragne reduce_dims tensor (vector), [0, 1, ... , dim-1, dim+1, ... , + // ndim-1] + llvm::SmallVector reduceDimsVector; + for (u_int64_t i = 0; i < ndim; i++) { + if (i == (u_int64_t)dimInt) + continue; + + Value constI = rewriter.create( + loc, rewriter.getI64IntegerAttr(i)); + + reduceDimsVector.push_back(constI); + } + + Value reduceDimsList = rewriter.create( + loc, + rewriter.getType(rewriter.getType()), + reduceDimsVector); + + // Make output shape for linalg.vector_norm operation + SmallVector inputSizeValue; + for (u_int64_t i = 0; i < inputSize.size(); i++) { + if (i != (u_int64_t)dimInt) + inputSize[i] = 1; + + inputSizeValue.push_back( + rewriter.create(loc, inputSize[i])); + } + + // Prepare arguments for linalg.vector_norm + Value dtypeValue; + Type vectorNormOutType; + + if (isa(dtype)) { + dtype = cast(rewriter.getF32Type()); + dtypeValue = getDtypeIntValueForType(rewriter, loc, dtype); + vectorNormOutType = resType.getWithSizesAndDtype(inputSize, dtype); + } else { + dtypeValue = cstNone; + vectorNormOutType = resType.getWithSizesAndDtype(inputSize, dtype); + } + + auto norm = rewriter.create( + loc, vectorNormOutType, self, p, reduceDimsList, cstTrue, dtypeValue); + + // Define epsiolon constant 10^-7 + mlir::FloatType f64Type = rewriter.getF64Type(); + Value epsValue = rewriter.create( + loc, rewriter.getFloatAttr(f64Type, 1e-7)); + + Value normPlusEps = rewriter.create( + loc, vectorNormOutType, norm, epsValue, cstOne); + + Value maxnormTensorValue = rewriter.create( + loc, normPlusEps.getType(), normPlusEps, maxnorm, cstNone, cstNone, + cstNone, cstNone, cstNone); + + // Divide maxnorm and normPlusEps + auto divideMaxnormAndNorm = rewriter.create( + loc, vectorNormOutType, maxnormTensorValue, normPlusEps); + + // Next few lines corespond to this pythorch code: norm_factor = + // torch.where(norm > maxnorm, maxnorm / (norm + eps), 1.0) + auto boolTensorType = rewriter.getType( + 
cast(vectorNormOutType).getOptionalSizes(), + rewriter.getI1Type()); + + Value greaterThanMaxnorm = + rewriter.create(loc, boolTensorType, norm, maxnorm); + + Value cstOnetensor = rewriter.create( + loc, normPlusEps.getType(), normPlusEps, cstOne, cstNone, cstNone, + cstNone, cstNone, cstNone); + + auto normFactor = rewriter.create( + loc, vectorNormOutType, greaterThanMaxnorm, divideMaxnormAndNorm, + cstOnetensor); + + // Converte norm_factor to input dtype + Value normFactorFinal = rewriter.create( + loc, resType.getWithSizesAndDtype(inputSize, resType.getDtype()), + normFactor, getDtypeIntValueForType(rewriter, loc, resType.getDtype())); + + // Multiply input tensor with norm factor + auto output = rewriter.create(loc, self.getType(), self, + normFactorFinal); + + rewriter.replaceOpWithNewOp(op, self.getType(), output, + /*memory_format*/ cstZero); + + return success(); + } +}; +} // namespace + // Decompose aten.linalg_cross into: aten.broadcast_to, aten.index_select, // aten.add.Tensor and aten.mull.Tensor. See // https://github.com/pytorch/pytorch/blob/ed3c256b61f05720843454a9282aa7c903da2c81/torch/_refs/linalg/__init__.py#L70. @@ -8081,6 +8220,7 @@ public: addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index fb5dd7ea8..fc56700f2 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -402,6 +402,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 058ada5b4..bdb726052 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1473,6 +1473,9 @@ STABLEHLO_PASS_SET = { "ElementwiseLogSigmoidModule_basic", "ElementwiseHardshrinkStaticModule_basic", "ElementwiseSoftshrinkStaticModule_basic", + "RenormModuleFloat16_basic", + "RenormModuleFloat32NegativeDim_basic", + "RenormModuleFloat32_basic", } STABLEHLO_CRASHING_SET = set() @@ -1949,6 +1952,8 @@ TOSA_PASS_SET = { "LinspaceOneSizeModule_basic", "LinspaceTwoSizeModule_basic", "TorchPrimLoopForLikeTensorArgModule_basic", + "RenormModuleFloat32NegativeDim_basic", + "RenormModuleFloat32_basic", } MAKE_FX_TOSA_PASS_SET = ( @@ -1982,6 +1987,8 @@ MAKE_FX_TOSA_PASS_SET = ( "ViewSizeDimLedByCollapsedOnesModule_basic", "ViewSizeFromOtherTensor_basic", "ScaledDotProductAttentionDifferentModule_basic", + "RenormModuleFloat32NegativeDim_basic", + "RenormModuleFloat32_basic", } ) - { ### Test failing in make_fx_tosa but not in tosa @@ -2695,6 +2702,11 @@ ONNX_XFAIL_SET = { "IndexPutHackedTwin3DIntNonAccumulateModule_basic", # RuntimeError: unsupported input type: Device "PrimsIotaModule_basic", + # Error: 'aten::renorm' to ONNX opset version 17 is not supported. 
+ "RenormModuleFloat16_basic", + "RenormModuleFloat32NegativeDim_basic", + "RenormModuleFloat32_basic", + "RenormModuleFloat32DynamicDims_basic", # Failure - unknown "BernoulliModule_basic", "Conv_Transpose1dModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index da2681e76..d8e5f51d7 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1998,6 +1998,9 @@ def aten〇linalg_norm〡shape(self: List[int], ord: Optional[float] = None, dim def aten〇frobenius_norm〇dim〡shape(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0) +def aten〇renorm〡shape(self: List[int], p: float, dim: int, maxnorm: float) -> List[int]: + return self + def aten〇norm〇Scalar〡shape(self: List[int], p: float = 2) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, None, False, None) @@ -4416,6 +4419,20 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U return dtype return aten〇std〡dtype(self_rank_dtype) +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(3,3)], + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}, + p=1, + dim=0, + maxnorm=5) +) +def aten〇renorm〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int, float, complex], dim: int, maxnorm: Union[int, float, complex]) -> int: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + + return self_dtype + @check_dtype_function( _check_tensors_with_the_same_dtype( num_of_tensors=1, diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 5a0632bed..ade2e2b22 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -587,6 +587,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit( "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)" ) + emit("aten::renorm : (Tensor, Scalar, int, Scalar) -> (Tensor)", has_verifier=True) emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True) emit("aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)") emit( diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py index f4c9e39d1..69926259d 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py @@ -633,3 +633,96 @@ class AtenInstanceNormModule(torch.nn.Module): @register_test_case(module_factory=lambda: AtenInstanceNormModule()) def AtenInstanceNormModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 2, 1, 3), tu.rand(2), tu.rand(2)) + + +# ============================================================================== +class RenormModuleFloat32(torch.nn.Module): + def __init__(self): + super().__init__() + self.p = 2 + self.dim = 1 + self.maxnorm = 10 + + @export + @annotate_args( + [ + None, + ([3, 3], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.renorm(x, 
self.p, self.dim, self.maxnorm) + + +@register_test_case(module_factory=lambda: RenormModuleFloat32()) +def RenormModuleFloat32_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 3)) + + +class RenormModuleFloat16(torch.nn.Module): + def __init__(self): + super().__init__() + self.p = 2.1 + self.dim = 1 + self.maxnorm = 10 + + @export + @annotate_args( + [ + None, + ([3, 4, 5], torch.float16, True), + ] + ) + def forward(self, x): + return torch.ops.aten.renorm(x, self.p, self.dim, self.maxnorm) + + +@register_test_case(module_factory=lambda: RenormModuleFloat16()) +def RenormModuleFloat16_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5).to(torch.float16)) + + +class RenormModuleFloat32NegativeDim(torch.nn.Module): + def __init__(self): + super().__init__() + self.p = 2.3 + self.dim = -1 + self.maxnorm = 5.2 + + @export + @annotate_args( + [ + None, + ([1, 4, 5, 2], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.renorm(x, self.p, self.dim, self.maxnorm) + + +@register_test_case(module_factory=lambda: RenormModuleFloat32NegativeDim()) +def RenormModuleFloat32NegativeDim_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 4, 5, 2).to(torch.float32)) + + +class RenormModuleFloat32DynamicDims(torch.nn.Module): + def __init__(self): + super().__init__() + self.p = 2 + self.dim = 1 + self.maxnorm = 10 + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.renorm(x, self.p, self.dim, self.maxnorm) + + +@register_test_case(module_factory=lambda: RenormModuleFloat32DynamicDims()) +def RenormModuleFloat32DynamicDims_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 3))
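
For reference, the decomposition implemented by `DecomposeAtenRenormOp` above can be phrased in eager-mode PyTorch roughly as follows. This is an illustrative sketch, not the torch-mlir code itself: the helper name `renorm_decomposed` is made up, and the explicit negative-dim normalization is an assumption about how the lowering ends up handling `dim = -1`.

```python
import torch

def renorm_decomposed(x: torch.Tensor, p: float, dim: int, maxnorm: float) -> torch.Tensor:
    # Normalize a negative dim (e.g. -1) into the range [0, ndim).
    dim = dim % x.dim()
    # Reduce over every dimension except `dim`, keeping dims so the result
    # broadcasts against the input (aten.linalg_vector_norm with keepdim=True
    # in the decomposition).
    reduce_dims = [d for d in range(x.dim()) if d != dim]
    # Half-precision inputs are accumulated in float32, matching the dtype
    # promotion branch in the pattern.
    acc_dtype = torch.float32 if x.dtype in (torch.float16, torch.bfloat16) else None
    norm = torch.linalg.vector_norm(x, ord=p, dim=reduce_dims, keepdim=True, dtype=acc_dtype)
    # norm_factor = where(norm > maxnorm, maxnorm / (norm + eps), 1.0)
    eps = 1e-7
    norm_factor = torch.where(norm > maxnorm, maxnorm / (norm + eps), torch.ones_like(norm))
    # Cast the factor back to the input dtype, scale, and make the result
    # contiguous (aten.to.dtype, aten.mul.Tensor, aten.contiguous).
    return (x * norm_factor.to(x.dtype)).contiguous()

x = torch.randn(3, 3)
assert torch.allclose(renorm_decomposed(x, 2, 1, 1.0), torch.renorm(x, 2, 1, 1.0))
```

The `1e-7` guard comes from the `torch/_refs` implementation linked in the comment above the pattern, so the decomposed form tracks eager `torch.renorm` numerically, which is what the `RenormModule*` e2e tests added here verify.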