diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 36b2243af..206d70ffb 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -309,6 +309,61 @@ def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [
   }];
 }
 
+def Torch_AtenRreluWithNoiseOp : Torch_Op<"aten.rrelu_with_noise", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$noise,
+    AnyTorchScalarType:$lower,
+    AnyTorchScalarType:$upper,
+    Torch_BoolType:$training,
+    AnyTorchOptionalGeneratorType:$generator
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRreluWithNoiseOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 1);
+    }
+    void AtenRreluWithNoiseOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 1);
+    }
+  }];
+}
+
+def Torch_AtenRreluWithNoise_Op : Torch_Op<"aten.rrelu_with_noise_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::rrelu_with_noise_ : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self,
+    Torch_NonValueTensorType:$noise,
+    AnyTorchScalarType:$lower,
+    AnyTorchScalarType:$upper,
+    Torch_BoolType:$training,
+    AnyTorchOptionalGeneratorType:$generator
+  );
+  let results = (outs
+    AnyTorchOptionalNonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRreluWithNoise_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 1);
+    }
+    void AtenRreluWithNoise_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 1);
+    }
+  }];
+}
+
 def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
   AllowsTypeRefinement,
   HasValueSemantics,
@@ -16814,6 +16869,35 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
   }];
 }
 
+def Torch_AtenRreluWithNoiseBackwardOp : Torch_Op<"aten.rrelu_with_noise_backward", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$grad_output,
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$noise,
+    AnyTorchScalarType:$lower,
+    AnyTorchScalarType:$upper,
+    Torch_BoolType:$training,
+    Torch_BoolType:$self_is_result
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRreluWithNoiseBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 7, 1);
+    }
+    void AtenRreluWithNoiseBackwardOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 7, 1);
+    }
+  }];
+}
+
 def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
   AllowsTypeRefinement,
   HasValueSemantics,
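Before reading the lowering code below, it helps to pin down what the op computes. The following is a rough Python sketch of aten::rrelu_with_noise semantics as described in the PyTorch RReLU documentation (Xu et al., 2015), not the actual ATen kernel: in training mode each negative element is scaled by its own slope drawn from U(lower, upper), and that slope is recorded in `noise`; in eval mode a fixed slope of (lower + upper) / 2 is used.

    import torch

    def rrelu_with_noise_reference(x, noise, lower=0.125, upper=1.0 / 3.0,
                                   training=False):
        # Sketch only: the real op mutates `noise` in place and accepts an
        # optional generator; both are simplified away here.
        if training:
            alpha = torch.empty_like(x).uniform_(lower, upper)
            noise = torch.where(x <= 0, alpha, torch.ones_like(x))
            return torch.where(x <= 0, x * alpha, x), noise
        slope = (lower + upper) / 2
        return torch.where(x <= 0, x * slope, x), noise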
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index f2963f7c8..46cb3e6b7 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6683,6 +6683,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.float, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.hardtanh_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
@@ -7285,6 +7289,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
@@ -12055,6 +12063,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 "    return %4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
@@ -12247,6 +12263,47 @@
 "    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %true = torch.constant.bool true\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %3 = torch.prim.If %2 -> (!torch.bool) {\n"
+"      torch.prim.If.yield %true : !torch.bool\n"
+"    } else {\n"
+"      %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"      torch.prim.If.yield %7 : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %3 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
+"    %5 = torch.prim.If %4 -> (!torch.bool) {\n"
+"      torch.prim.If.yield %true : !torch.bool\n"
+"    } else {\n"
+"      %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
+"      torch.prim.If.yield %7 : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %5 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %6 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %6 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 9b24d0e95..1fefb59a4 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3489,6 +3489,59 @@ public:
 };
 } // namespace
 
+namespace {
+class DecomposeAtenRreluWithNoiseBackwardOp
+    : public OpRewritePattern<AtenRreluWithNoiseBackwardOp> {
+public:
+  using OpRewritePattern<AtenRreluWithNoiseBackwardOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenRreluWithNoiseBackwardOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value gradOutput = op.getGradOutput();
+    Value self = op.getSelf();
+    Value noise = op.getNoise();
+    auto resType = cast<BaseTensorType>(op.getType());
+    if (!resType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
+
+    bool training;
+    if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "training should be a bool constant");
+    }
+
+    bool selfIsResult = false;
+    if (!matchPattern(op.getSelfIsResult(),
+                      m_TorchConstantBool(&selfIsResult)) ||
+        selfIsResult)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: self_is_result should be false");
+
+    double lower, upper;
+    if (!matchPattern(op.getLower(), m_TorchConstantFloat(&lower)) ||
+        !matchPattern(op.getUpper(), m_TorchConstantFloat(&upper))) {
+      return rewriter.notifyMatchFailure(
+          op, "lower and upper should be float constants");
+    }
+
+    if (training && (upper - lower > 0.000001)) {
+      Value rreluWithNoiseBackwardOutput = rewriter.create<AtenMulTensorOp>(
+          loc, resType, gradOutput, noise);
+      rewriter.replaceOp(op, rreluWithNoiseBackwardOutput);
+    } else {
+      double negativeSlope = (upper + lower) / 2;
+      Value cstNegativeSlope = rewriter.create<ConstantFloatOp>(
+          loc, rewriter.getF64FloatAttr(negativeSlope));
+      rewriter.replaceOpWithNewOp<AtenLeakyReluBackwardOp>(
+          op, resType, gradOutput, self, cstNegativeSlope,
+          op.getSelfIsResult());
+    }
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenPreluOp : public OpRewritePattern<AtenPreluOp> {
 public:
@@ -3588,6 +3641,82 @@ public:
 };
 } // namespace
 
+namespace {
+class DecomposeAtenRreluWithNoiseOp
+    : public OpRewritePattern<AtenRreluWithNoiseOp> {
+public:
+  using OpRewritePattern<AtenRreluWithNoiseOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenRreluWithNoiseOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    Value noise = op.getNoise();
+    Value lower = op.getLower();
+    Value upper = op.getUpper();
+    auto resType = cast<BaseTensorType>(op.getType());
+    if (!resType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
+
+    bool training;
+    if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
+      return rewriter.notifyMatchFailure(op, "training should be a constant");
+    }
+
+    Value constantZeroFloat =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
+    Value constantOneFloat =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+    Value constantTwoFloat =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(2.0));
+
+    Value alpha;
+    if (training) {
+      Value none = rewriter.create<ConstantNoneOp>(loc);
+      Value emptyTensor = rewriter.create<AtenFullLikeOp>(
+          loc, resType, self, constantZeroFloat, /*dtype=*/none,
+          /*layout=*/none,
+          /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
+      alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
+                                             /*from=*/lower, /*to=*/upper,
+                                             /*generator=*/none);
+    } else {
+      Value half = rewriter.create<AtenAddOp>(loc, constantTwoFloat.getType(),
+                                              lower, upper);
+      alpha = rewriter.create<AtenDivOp>(loc, constantTwoFloat.getType(), half,
+                                         constantTwoFloat);
+    }
+
+    Value zeroTensor =
+        createRank0Tensor(rewriter, loc, resType, constantZeroFloat);
+    Value positiveOutput =
+        rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, self);
+
+    Value scaledSelf;
+    if (training) {
+      scaledSelf = rewriter.create<AtenMulTensorOp>(loc, resType, self, alpha);
+      auto boolResType = resType.getWithSizesAndDtype(resType.getSizes(),
+                                                      rewriter.getI1Type());
+      Value oneTensor =
+          createRank0Tensor(rewriter, loc, resType, constantOneFloat);
+      Value notPositive = rewriter.create<AtenLeScalarOp>(
+          loc, boolResType, self, constantZeroFloat);
+      noise = rewriter.create<AtenWhereSelfOp>(loc, resType, notPositive,
+                                               alpha, oneTensor);
+    } else {
+      scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self, alpha);
+    }
+
+    Value negativeOutput =
+        rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
+    Value rreluOutput = rewriter.create<AtenAddTensorOp>(
+        loc, resType, positiveOutput, negativeOutput, constantOneFloat);
+    rewriter.replaceOp(op, rreluOutput);
+    return success();
+  }
+};
+} // namespace
+
 // CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
 namespace {
 class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
 public:
@@ -9924,6 +10053,9 @@ public:
   addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenRreluOp>(patterns);
+  addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseOp>(patterns);
+  addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseBackwardOp>(
+      patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenHardshrinkOp>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenSoftshrinkOp>(patterns);
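The forward pattern above builds max(0, x) + min(0, alpha * x), matching the reference sketch given earlier, and additionally rewrites `noise` to hold each element's slope so the backward op can consume it. The backward pattern then has two cases, which the following hedged Python sketch mirrors (sketch of the rewrite logic, not the upstream ATen kernel; names are illustrative):

    import torch

    def rrelu_with_noise_backward_decomposed(grad, x, noise, lower, upper,
                                             training):
        # Mirrors DecomposeAtenRreluWithNoiseBackwardOp; self_is_result=True
        # is rejected by the pattern and is therefore not modeled here.
        if training and (upper - lower) > 1e-6:
            # The forward op stored each element's random slope in `noise`
            # (1.0 for positive inputs), so the chain rule is grad * noise.
            return grad * noise
        # Eval mode, or a degenerate [lower, upper] interval: identical to
        # leaky_relu_backward with the fixed slope (lower + upper) / 2.
        slope = (lower + upper) / 2
        return torch.where(x > 0, grad, grad * slope)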
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index ebc43faa5..feb63db0b 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -498,6 +498,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenLeakyReluOp>();
   target.addIllegalOp<AtenLeakyReluBackwardOp>();
   target.addIllegalOp<AtenRreluOp>();
+  target.addIllegalOp<AtenRreluWithNoiseOp>();
+  target.addIllegalOp<AtenRreluWithNoiseBackwardOp>();
   target.addIllegalOp<AtenCeluOp>();
   target.addIllegalOp<AtenHardshrinkOp>();
   target.addIllegalOp<AtenSoftshrinkOp>();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 553a27924..e370a1d8b 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1207,6 +1207,10 @@ STABLEHLO_PASS_SET = {
     "ElementwisePreluStaticModule_basic",
     "ElementwiseReciprocalModule_basic",
     "ElementwiseReluModule_basic",
+    "ElementwiseRreluWithNoiseEvalStaticModule_basic",
+    "ElementwiseRreluWithNoiseTrainStaticModule_basic",
+    "RreluWithNoiseBackwardEvalStaticModule_basic",
+    "RreluWithNoiseBackwardTrainStaticModule_basic",
     "ElementwiseRemainderTensorModule_Float_basic",
     "ElementwiseRemainderTensorModule_Float_NegativeDividend_basic",
     "ElementwiseRemainderTensorModule_Int_Float_basic",
@@ -2106,6 +2110,7 @@ TOSA_PASS_SET = {
     "ElementwiseReciprocalModule_basic",
    "ElementwiseRelu6Module_basic",
     "ElementwiseReluModule_basic",
+    "ElementwiseRreluWithNoiseEvalStaticModule_basic",
     "ElementwiseRemainderScalarModule_Float_NegativeDividend_basic",
     "ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic",
     "ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic",
@@ -2238,6 +2243,10 @@
     "ReduceSumFloatModule_basic",
     "ReduceSumSignedIntModule_basic",
     "ReduceSumUnsignedIntModule_basic",
+    "RreluWithNoiseBackwardEvalModule_basic",
+    "RreluWithNoiseBackwardEvalStaticModule_basic",
+    "RreluWithNoiseBackwardTrainModule_basic",
+    "RreluWithNoiseBackwardTrainStaticModule_basic",
     "RepeatModule_basic",
     "RepeatInterleaveSelfIntNoDimModule_basic",
     "ResNet18StaticModule_basic",
@@ -2436,6 +2445,10 @@ MAKE_FX_TOSA_PASS_SET = (
         "ViewSizeFromOtherTensor_basic",
         "RenormModuleFloat32NegativeDim_basic",
         "RenormModuleFloat32_basic",
+        "RreluWithNoiseBackwardEvalModule_basic",
+        "RreluWithNoiseBackwardEvalStaticModule_basic",
+        "RreluWithNoiseBackwardTrainModule_basic",
+        "RreluWithNoiseBackwardTrainStaticModule_basic",
     }
 ) - {
     ### Test failing in make_fx_tosa but not in tosa
@@ -2854,6 +2867,10 @@ ONNX_XFAIL_SET = {
     "ElementwiseRemainderTensorModule_Int_basic",
     "ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
     "ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
+    "ElementwiseRreluWithNoiseEvalModule_basic",
+    "ElementwiseRreluWithNoiseEvalStaticModule_basic",
+    "ElementwiseRreluWithNoiseTrainModule_basic",
+    "ElementwiseRreluWithNoiseTrainStaticModule_basic",
     "ElementwiseSgnModule_basic",
     "EmptyStridedModule_basic",
     "EmptyStridedSizeIntStrideModule_basic",
@@ -3002,6 +3019,11 @@ ONNX_XFAIL_SET = {
     "ReduceL1NormComplexModule_basic",
     "ReduceL2NormComplexModule_basic",
     "ReduceL3NormKeepDimComplexModule_basic",
+    "RreluWithNoiseBackwardEvalModule_basic",
+    "RreluWithNoiseBackwardEvalStaticModule_basic",
+    "RreluWithNoiseBackwardTrainModule_basic",
+    "RreluWithNoiseBackwardTrainStaticModule_basic",
+    "RreluWithNoiseForwardBackwardModule_basic",
    "ReshapeAliasCollapseModule_basic",
     "ReshapeAliasExpandModule_basic",
     "ReshapeExpandModule_basic",
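If I read the harness conventions correctly, the *_PASS_SET collections list e2e tests expected to pass when lowered through the named backend (StableHLO, TOSA, make_fx+TOSA), while ONNX_XFAIL_SET lists expected failures on the ONNX import path. So the new tests are enabled where the decomposition lowers cleanly (mostly the static-shape variants) and marked as expected failures for ONNX, which does not yet route these ops through the decomposition.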
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index d632e9815..1cb9678ec 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -298,6 +298,9 @@ def aten〇gelu_backward〡shape(grad_output: List[int], self: List[int], approx
 def aten〇leaky_relu_backward〡shape(grad_output: List[int], self: List[int], negative_slope: float, self_is_result: bool) -> List[int]:
     return upstream_shape_functions.unary(grad_output)
 
+def aten〇rrelu_with_noise_backward〡shape(grad_output: List[int], self: List[int], noise: List[int], lower: float, upper: float, training: bool, self_is_result: bool) -> List[int]:
+    return upstream_shape_functions.unary(grad_output)
+
 def aten〇hardtanh_backward〡shape(grad_output: List[int], self: List[int], min_val: float, max_val: float) -> List[int]:
     return upstream_shape_functions.unary(grad_output)
 
@@ -634,6 +637,9 @@ def aten〇celu〡shape(self: List[int], alpha: float = 1.) -> List[int]:
 def aten〇rrelu〡shape(self: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇rrelu_with_noise〡shape(self: List[int], noise: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇selu〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -3126,6 +3132,15 @@ def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int],
     promoted_dtype = promote_dtypes(ranks, dtypes)
     return promoted_dtype
 
+@check_dtype_function([Invocation(TensorOfShape(3, 3, dtype=dtype), TensorOfShape(3, 3, dtype=dtype), TensorOfShape(3, 3, dtype=dtype), 0.1, 0.9, False, False) for dtype in _SORTED_TORCH_TYPES])
+def aten〇rrelu_with_noise_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex], upper: Union[int, float, complex], training: bool, self_is_result: bool) -> int:
+    grad_output_rank, grad_output_dtype = grad_output_rank_dtype
+    self_rank, self_dtype = self_rank_dtype
+    ranks: List[Optional[int]] = [grad_output_rank, self_rank]
+    dtypes = [grad_output_dtype, self_dtype]
+    promoted_dtype = promote_dtypes(ranks, dtypes)
+    return promoted_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇lift_fresh_copy〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
@@ -3293,6 +3308,15 @@ def aten〇rrelu〡dtype(self_rank_dtype: Tuple[int, int], lower: Union[int, flo
     assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype)
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2, error_types={torch.bool, *all_integer_dtypes()}))
+def aten〇rrelu_with_noise〡dtype(self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    noise_rank, noise_dtype = noise_rank_dtype
+    assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype)
+    assert is_float_dtype(noise_dtype) or is_complex_dtype(noise_dtype)
+    assert self_rank == noise_rank
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
 def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 31984d727..17f7faa10 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -302,6 +302,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
         "aten::relu6 : (Tensor) -> (Tensor)",
         "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
         "aten::rrelu : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)",
+        "aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)",
         "aten::celu : (Tensor, Scalar) -> (Tensor)",
         "aten::selu : (Tensor) -> (Tensor)",
         "aten::sigmoid : (Tensor) -> (Tensor)",
@@ -1171,6 +1172,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
         "aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)"
     )
     emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
+    emit(
+        "aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)"
+    )
 
     # quantized ops
     emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py
index e209d15b2..5e6e09390 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py
@@ -322,3 +322,164 @@ class LeakyReluBackwardStaticModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: LeakyReluBackwardStaticModule())
 def LeakyReluBackwardStaticModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5))
+
+
+# ==============================================================================
+
+
+class RreluWithNoiseBackwardTrainModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, grad, input, noise):
+        return torch.ops.aten.rrelu_with_noise_backward(
+            grad,
+            input,
+            noise,
+            lower=0.1,
+            upper=0.9,
+            training=True,
+            self_is_result=False,
+        )
+
+
+@register_test_case(module_factory=lambda: RreluWithNoiseBackwardTrainModule())
+def RreluWithNoiseBackwardTrainModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))
+
+
+class RreluWithNoiseBackwardTrainStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3, 4, 5], torch.float32, True),
+            ([3, 4, 5], torch.float32, True),
+            ([3, 4, 5], torch.float32, True),
+        ]
+    )
+    def forward(self, grad, input, noise):
+        return torch.ops.aten.rrelu_with_noise_backward(
+            grad,
+            input,
+            noise,
+            lower=0.1,
+            upper=0.9,
+            training=True,
+            self_is_result=False,
+        )
+
+
+@register_test_case(module_factory=lambda: RreluWithNoiseBackwardTrainStaticModule())
+def RreluWithNoiseBackwardTrainStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))
+
+
+# ==============================================================================
+
+
+class RreluWithNoiseBackwardEvalModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, grad, input, noise):
+        return torch.ops.aten.rrelu_with_noise_backward(
+            grad,
+            input,
+            noise,
+            lower=0.1,
+            upper=0.9,
+            training=False,
+            self_is_result=False,
+        )
+
+
+@register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalModule())
+def RreluWithNoiseBackwardEvalModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))
+
+
+class RreluWithNoiseBackwardEvalStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3, 4, 5], torch.float32, True),
+            ([3, 4, 5], torch.float32, True),
+            ([3, 4, 5], torch.float32, True),
+        ]
+    )
+    def forward(self, grad, input, noise):
+        return torch.ops.aten.rrelu_with_noise_backward(
+            grad,
+            input,
+            noise,
+            lower=0.1,
+            upper=0.9,
+            training=False,
+            self_is_result=False,
+        )
+
+
+@register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalStaticModule())
+def RreluWithNoiseBackwardEvalStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))
+
+
+class RreluWithNoiseForwardBackwardModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+            ([-1, -1], torch.float32, True),
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, grad, input, noise):
+        res = torch.ops.aten.rrelu_with_noise_backward(
+            grad,
+            input,
+            noise,
+            lower=0.4,
+            upper=0.6,
+            training=True,
+            self_is_result=False,
+        )
+        return torch.mean(res), torch.std(res)
+
+
+@register_test_case(module_factory=lambda: RreluWithNoiseForwardBackwardModule())
+def RreluWithNoiseForwardBackwardModule_basic(module, tu: TestUtils):
+    grad = tu.rand(256, 244)
+    input = tu.rand(256, 244, low=-1.0, high=1.0)
+    noise = tu.rand(256, 244)
+    torch.ops.aten.rrelu_with_noise(input, noise, lower=0.4, upper=0.6, training=True)
+    module.forward(grad, input, noise)
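A note on the test design above: because the training-mode backward decomposition multiplies grad by `noise`, RreluWithNoiseForwardBackwardModule first runs the forward op eagerly to populate `noise` with the per-element slopes before invoking the compiled module. And since training mode is randomized, that test (like the training-mode elementwise tests below) compares only the mean and standard deviation of the result rather than per-element values.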
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index ed5254353..a62b901a9 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1179,6 +1179,88 @@ def ElementwiseRreluEvalStaticModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseRreluWithNoiseTrainModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
+    )
+    def forward(self, x, noise):
+        res = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.5, True)
+        return torch.mean(res), torch.std(res)
+
+
+@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainModule())
+def ElementwiseRreluWithNoiseTrainModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128))
+
+
+# ==============================================================================
+
+
+class ElementwiseRreluWithNoiseTrainStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [None, ([128, 128], torch.float32, True), ([128, 128], torch.float32, True)]
+    )
+    def forward(self, x, noise):
+        res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, True)
+        return torch.mean(res), torch.std(res)
+
+
+@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainStaticModule())
+def ElementwiseRreluWithNoiseTrainStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128))
+
+
+# ==============================================================================
+
+
+class ElementwiseRreluWithNoiseEvalModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
+    )
+    def forward(self, x, noise):
+        res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False)
+        return torch.mean(res), torch.std(res)
+
+
+@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalModule())
+def ElementwiseRreluWithNoiseEvalModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3))
+
+
+# ==============================================================================
+
+
+class ElementwiseRreluWithNoiseEvalStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([5, 3], torch.float32, True), ([5, 3], torch.float32, True)])
+    def forward(self, x, noise):
+        res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False)
+        return torch.mean(res), torch.std(res)
+
+
+@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalStaticModule())
+def ElementwiseRreluWithNoiseEvalStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3))
+
+
+# ==============================================================================
+
+
 class ElementwiseCeluStaticModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
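The eval-mode expectations above can be sanity-checked against stock PyTorch, where rrelu_with_noise with training=False must agree with leaky_relu using the fixed slope (lower + upper) / 2. A quick standalone check (not part of the patch):

    import torch

    x = torch.randn(5, 3)
    noise = torch.empty_like(x)
    # Same positional call shape as in the tests: (self, noise, lower, upper, training).
    out = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False)
    expected = torch.nn.functional.leaky_relu(x, negative_slope=0.5)
    assert torch.allclose(out, expected)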