[TorchToLinalg] Implement lowering of torch.aten.rrelu_with_noise and torch.aten.rrelu_with_noise_backward ops (fix) (#3748)

Andrija Bosnjakovic 2024-10-25 18:01:05 +02:00 committed by GitHub
parent ad9dfe974e
commit 54d9e24013
9 changed files with 568 additions and 0 deletions


@@ -309,6 +309,61 @@ def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [
}];
}
def Torch_AtenRreluWithNoiseOp : Torch_Op<"aten.rrelu_with_noise", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$noise,
AnyTorchScalarType:$lower,
AnyTorchScalarType:$upper,
Torch_BoolType:$training,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRreluWithNoiseOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenRreluWithNoiseOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenRreluWithNoise_Op : Torch_Op<"aten.rrelu_with_noise_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::rrelu_with_noise_ : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
Torch_NonValueTensorType:$noise,
AnyTorchScalarType:$lower,
AnyTorchScalarType:$upper,
Torch_BoolType:$training,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRreluWithNoise_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenRreluWithNoise_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -16814,6 +16869,35 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
}];
}
def Torch_AtenRreluWithNoiseBackwardOp : Torch_Op<"aten.rrelu_with_noise_backward", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$self,
AnyTorchTensorType:$noise,
AnyTorchScalarType:$lower,
AnyTorchScalarType:$upper,
Torch_BoolType:$training,
Torch_BoolType:$self_is_result
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRreluWithNoiseBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenRreluWithNoiseBackwardOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}
def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
AllowsTypeRefinement,
HasValueSemantics,


@@ -6683,6 +6683,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.float, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.hardtanh_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -7285,6 +7289,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -12055,6 +12063,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
@@ -12247,6 +12263,47 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %true = torch.constant.bool true\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %7 : !torch.bool\n"
" }\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %7 : !torch.bool\n"
" }\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %6 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"


@@ -3489,6 +3489,59 @@ public:
};
} // namespace
namespace {
class DecomposeAtenRreluWithNoiseBackwardOp
: public OpRewritePattern<AtenRreluWithNoiseBackwardOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenRreluWithNoiseBackwardOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value gradOutput = op.getGradOutput();
Value self = op.getSelf();
Value noise = op.getNoise();
auto resType = cast<BaseTensorType>(op.getType());
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
bool training;
if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
return rewriter.notifyMatchFailure(op,
"training should be a bool constant");
}
bool selfIsResult = false;
if (!matchPattern(op.getSelfIsResult(),
m_TorchConstantBool(&selfIsResult)) ||
selfIsResult)
return rewriter.notifyMatchFailure(
op, "unimplemented: self_is_result should be false");
double lower, upper;
if (!matchPattern(op.getLower(), m_TorchConstantFloat(&lower)) ||
!matchPattern(op.getUpper(), m_TorchConstantFloat(&upper))) {
return rewriter.notifyMatchFailure(
op, "lower and upper should be float constants");
}
if (training && (upper - lower > 0.000001)) {
Value rreluWithNoiseBackwardOutput =
rewriter.create<AtenMulTensorOp>(loc, resType, gradOutput, noise);
rewriter.replaceOp(op, rreluWithNoiseBackwardOutput);
} else {
double negative_slope = (upper + lower) / 2;
Value cstNegativeSlope = rewriter.create<ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(negative_slope));
rewriter.replaceOpWithNewOp<AtenLeakyReluBackwardOp>(
op, resType, gradOutput, self, cstNegativeSlope,
op.getSelfIsResult());
}
return success();
}
};
} // namespace
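For reference, the backward decomposition above can be mirrored with a short eager-mode PyTorch sketch (the helper name rrelu_with_noise_backward_ref and the 1e-6 threshold are illustrative, not part of the patch): in training mode with a non-degenerate [lower, upper] range the gradient is simply grad_output * noise, otherwise it collapses to leaky_relu_backward with negative_slope = (lower + upper) / 2.

import torch

def rrelu_with_noise_backward_ref(grad_output, self_, noise, lower, upper,
                                  training, self_is_result=False):
    # Mirrors the pattern above; self_is_result=True is not handled.
    assert not self_is_result
    if training and (upper - lower) > 1e-6:
        return grad_output * noise
    negative_slope = (lower + upper) / 2
    # leaky_relu_backward: pass the gradient through where self > 0,
    # scale it by negative_slope elsewhere.
    return torch.where(self_ > 0, grad_output, grad_output * negative_slope)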
namespace {
class DecomposeAtenPreluOp : public OpRewritePattern<AtenPreluOp> {
public:
@@ -3588,6 +3641,82 @@ public:
};
} // namespace
namespace {
class DecomposeAtenRreluWithNoiseOp
: public OpRewritePattern<AtenRreluWithNoiseOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenRreluWithNoiseOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
Value noise = op.getNoise();
Value lower = op.getLower();
Value upper = op.getUpper();
auto resType = cast<BaseTensorType>(op.getType());
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
bool training;
if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
return rewriter.notifyMatchFailure(op, "training should be a constant");
}
Value constantZeroFloat =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
Value constantOneFloat =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
Value constantTwoFloat =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(2.0));
Value alpha;
if (training) {
Value none = rewriter.create<ConstantNoneOp>(loc);
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
loc, resType, self, constantZeroFloat, /*dtype=*/none,
/*layout=*/none,
/*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
/*from=*/lower, /*to=*/upper,
/*generator=*/none);
} else {
Value half = rewriter.create<AtenAddOp>(loc, constantTwoFloat.getType(),
lower, upper);
alpha = rewriter.create<AtenDivOp>(loc, constantTwoFloat.getType(), half,
constantTwoFloat);
}
Value zeroTensor =
createRank0Tensor(rewriter, loc, resType, constantZeroFloat);
Value positiveOutput =
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, self);
Value scaledSelf;
if (training) {
scaledSelf = rewriter.create<AtenMulTensorOp>(loc, resType, self, alpha);
auto boolResType = resType.getWithSizesAndDtype(resType.getSizes(),
rewriter.getI1Type());
Value oneTensor =
createRank0Tensor(rewriter, loc, resType, constantOneFloat);
Value not_positive = rewriter.create<AtenLtScalarOp>(
loc, boolResType, self, constantZeroFloat);
noise = rewriter.create<AtenWhereSelfOp>(loc, resType, not_positive,
alpha, oneTensor);
} else {
scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self, alpha);
}
Value negativeOutput =
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
Value rreluOutput = rewriter.create<AtenAddTensorOp>(
loc, resType, positiveOutput, negativeOutput, constantOneFloat);
rewriter.replaceOp(op, rreluOutput);
return success();
}
};
} // namespace
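The forward decomposition above amounts to max(0, x) + min(0, alpha * x). A minimal eager-mode PyTorch sketch (the helper rrelu_with_noise_ref is illustrative only, assuming float tensors): in training mode alpha is sampled per element from U(lower, upper) and recorded in noise for negative inputs; in eval mode alpha is the fixed scalar (lower + upper) / 2.

import torch

def rrelu_with_noise_ref(x, noise, lower, upper, training):
    zero = torch.zeros_like(x)
    if training:
        # Per-element negative slope sampled from U(lower, upper).
        alpha = torch.empty_like(x).uniform_(lower, upper)
        # noise records the slope actually applied to negative inputs.
        noise.copy_(torch.where(x < 0, alpha, torch.ones_like(x)))
        scaled = x * alpha
    else:
        scaled = x * ((lower + upper) / 2)
    return torch.maximum(zero, x) + torch.minimum(zero, scaled)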
// CELU(x)=max(0,x)+min(0,alpha*(exp(x/alpha)-1))
namespace {
class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
@@ -9924,6 +10053,9 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseBackwardOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);


@@ -498,6 +498,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenPadOp>();
target.addIllegalOp<AtenPreluOp>();
target.addIllegalOp<AtenRreluOp>();
target.addIllegalOp<AtenRreluWithNoiseOp>();
target.addIllegalOp<AtenRreluWithNoiseBackwardOp>();
target.addIllegalOp<AtenCeluOp>();
target.addIllegalOp<AtenToDtypeLayoutOp>();
target.addIllegalOp<AtenToDeviceOp>();


@@ -1207,6 +1207,10 @@ STABLEHLO_PASS_SET = {
"ElementwisePreluStaticModule_basic",
"ElementwiseReciprocalModule_basic",
"ElementwiseReluModule_basic",
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
"RreluWithNoiseBackwardEvalStaticModule_basic",
"RreluWithNoiseBackwardTrainStaticModule_basic",
"ElementwiseRemainderTensorModule_Float_basic",
"ElementwiseRemainderTensorModule_Float_NegativeDividend_basic",
"ElementwiseRemainderTensorModule_Int_Float_basic",
@@ -2106,6 +2110,7 @@ TOSA_PASS_SET = {
"ElementwiseReciprocalModule_basic",
"ElementwiseRelu6Module_basic",
"ElementwiseReluModule_basic",
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
"ElementwiseRemainderScalarModule_Float_NegativeDividend_basic",
"ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic",
"ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic",
@@ -2238,6 +2243,10 @@ TOSA_PASS_SET = {
"ReduceSumFloatModule_basic",
"ReduceSumSignedIntModule_basic",
"ReduceSumUnsignedIntModule_basic",
"RreluWithNoiseBackwardEvalModule_basic",
"RreluWithNoiseBackwardEvalStaticModule_basic",
"RreluWithNoiseBackwardTrainModule_basic",
"RreluWithNoiseBackwardTrainStaticModule_basic",
"RepeatModule_basic",
"RepeatInterleaveSelfIntNoDimModule_basic",
"ResNet18StaticModule_basic",
@@ -2436,6 +2445,10 @@ MAKE_FX_TOSA_PASS_SET = (
"ViewSizeFromOtherTensor_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
"RreluWithNoiseBackwardEvalModule_basic",
"RreluWithNoiseBackwardEvalStaticModule_basic",
"RreluWithNoiseBackwardTrainModule_basic",
"RreluWithNoiseBackwardTrainStaticModule_basic",
}
) - {
### Test failing in make_fx_tosa but not in tosa
@@ -2854,6 +2867,10 @@ ONNX_XFAIL_SET = {
"ElementwiseRemainderTensorModule_Int_basic",
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
"ElementwiseRreluWithNoiseEvalModule_basic",
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
"ElementwiseRreluWithNoiseTrainModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
"ElementwiseSgnModule_basic",
"EmptyStridedModule_basic",
"EmptyStridedSizeIntStrideModule_basic",
@@ -3002,6 +3019,11 @@ ONNX_XFAIL_SET = {
"ReduceL1NormComplexModule_basic",
"ReduceL2NormComplexModule_basic",
"ReduceL3NormKeepDimComplexModule_basic",
"RreluWithNoiseBackwardEvalModule_basic",
"RreluWithNoiseBackwardEvalStaticModule_basic",
"RreluWithNoiseBackwardTrainModule_basic",
"RreluWithNoiseBackwardTrainStaticModule_basic",
"RreluWithNoiseForwardBackwardModule_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReshapeExpandModule_basic",


@@ -298,6 +298,9 @@ def aten〇gelu_backward〡shape(grad_output: List[int], self: List[int], approx
def aten〇leaky_relu_backward〡shape(grad_output: List[int], self: List[int], negative_slope: float, self_is_result: bool) -> List[int]:
return upstream_shape_functions.unary(grad_output)
def aten〇rrelu_with_noise_backward〡shape(grad_output: List[int], self: List[int], noise: List[int], lower: float, upper: float, training: bool, self_is_result: bool) -> List[int]:
return upstream_shape_functions.unary(grad_output)
def aten〇hardtanh_backward〡shape(grad_output: List[int], self: List[int], min_val: float, max_val: float) -> List[int]:
return upstream_shape_functions.unary(grad_output)
@@ -634,6 +637,9 @@ def aten〇celu〡shape(self: List[int], alpha: float = 1.) -> List[int]:
def aten〇rrelu〡shape(self: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]:
return upstream_shape_functions.unary(self)
def aten〇rrelu_with_noise〡shape(self: List[int], noise: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]:
return upstream_shape_functions.unary(self)
def aten〇selu〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
@@ -3126,6 +3132,15 @@ def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int],
promoted_dtype = promote_dtypes(ranks, dtypes)
return promoted_dtype
@check_dtype_function([Invocation(TensorOfShape(3, 3, dtype=dtype), TensorOfShape(3, 3, dtype=dtype), TensorOfShape(3, 3, dtype=dtype), 0.1, 0.9, False, False) for dtype in _SORTED_TORCH_TYPES])
def aten〇rrelu_with_noise_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex], upper: Union[int, float, complex], training: bool, self_is_result: bool) -> int:
grad_output_rank, grad_output_dtype = grad_output_rank_dtype
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [grad_output_rank, self_rank]
dtypes = [grad_output_dtype, self_dtype]
promoted_dtype = promote_dtypes(ranks, dtypes)
return promoted_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇lift_fresh_copy〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
@@ -3293,6 +3308,15 @@ def aten〇rrelu〡dtype(self_rank_dtype: Tuple[int, int], lower: Union[int, flo
assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype)
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2, error_types={torch.bool, *all_integer_dtypes()}))
def aten〇rrelu_with_noise〡dtype(self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int:
self_rank, self_dtype = self_rank_dtype
noise_rank, noise_dtype = noise_rank_dtype
assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype)
assert is_float_dtype(noise_dtype) or is_complex_dtype(noise_dtype)
assert self_rank == noise_rank
return self_dtype
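As a quick eager-mode sanity check of this dtype rule (a sketch, not part of the patch), the result of aten.rrelu_with_noise keeps self's floating-point dtype and shape:

import torch

x = torch.randn(3, 4, dtype=torch.float32)
noise = torch.empty_like(x)
out = torch.ops.aten.rrelu_with_noise(x, noise, 0.125, 1.0 / 3.0, False)
assert out.dtype == x.dtype and out.shape == x.shape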
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype


@@ -302,6 +302,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::relu6 : (Tensor) -> (Tensor)",
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
"aten::rrelu : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)",
"aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)",
"aten::celu : (Tensor, Scalar) -> (Tensor)",
"aten::selu : (Tensor) -> (Tensor)",
"aten::sigmoid : (Tensor) -> (Tensor)",
@@ -1171,6 +1172,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)"
)
emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
emit(
"aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)"
)
# quantized ops
emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)")


@@ -322,3 +322,164 @@ class LeakyReluBackwardStaticModule(torch.nn.Module):
@register_test_case(module_factory=lambda: LeakyReluBackwardStaticModule())
def LeakyReluBackwardStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5))
# ==============================================================================
class RreluWithNoiseBackwardTrainModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, grad, input, noise):
return torch.ops.aten.rrelu_with_noise_backward(
grad,
input,
noise,
lower=0.1,
upper=0.9,
training=True,
self_is_result=False,
)
@register_test_case(module_factory=lambda: RreluWithNoiseBackwardTrainModule())
def RreluWithNoiseBackwardTrainModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))
class RreluWithNoiseBackwardTrainStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([3, 4, 5], torch.float32, True),
([3, 4, 5], torch.float32, True),
([3, 4, 5], torch.float32, True),
]
)
def forward(self, grad, input, noise):
return torch.ops.aten.rrelu_with_noise_backward(
grad,
input,
noise,
lower=0.1,
upper=0.9,
training=True,
self_is_result=False,
)
@register_test_case(module_factory=lambda: RreluWithNoiseBackwardTrainStaticModule())
def RreluWithNoiseBackwardTrainStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))
# ==============================================================================
class RreluWithNoiseBackwardEvalModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, grad, input, noise):
return torch.ops.aten.rrelu_with_noise_backward(
grad,
input,
noise,
lower=0.1,
upper=0.9,
training=False,
self_is_result=False,
)
@register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalModule())
def RreluWithNoiseBackwardEvalModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))
class RreluWithNoiseBackwardEvalStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([3, 4, 5], torch.float32, True),
([3, 4, 5], torch.float32, True),
([3, 4, 5], torch.float32, True),
]
)
def forward(self, grad, input, noise):
return torch.ops.aten.rrelu_with_noise_backward(
grad,
input,
noise,
lower=0.1,
upper=0.9,
training=False,
self_is_result=False,
)
@register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalStaticModule())
def RreluWithNoiseBackwardEvalStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))
class RreluWithNoiseForwardBackwardModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
]
)
def forward(self, grad, input, noise):
res = torch.ops.aten.rrelu_with_noise_backward(
grad,
input,
noise,
lower=0.4,
upper=0.6,
training=True,
self_is_result=False,
)
return torch.mean(res), torch.std(res)
@register_test_case(module_factory=lambda: RreluWithNoiseForwardBackwardModule())
def RreluWithNoiseForwardBackwardModule_basic(module, tu: TestUtils):
grad = tu.rand(256, 244)
input = tu.rand(256, 244, low=-1.0, high=1.0)
noise = tu.rand(256, 244)
torch.ops.aten.rrelu_with_noise(input, noise, lower=0.4, upper=0.6, training=True)
module.forward(grad, input, noise)


@@ -1179,6 +1179,88 @@ def ElementwiseRreluEvalStaticModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseRreluWithNoiseTrainModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
)
def forward(self, x, noise):
res = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.5, True)
return torch.mean(res), torch.std(res)
@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainModule())
def ElementwiseRreluWithNoiseTrainModule_basic(module, tu: TestUtils):
module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128))
# ==============================================================================
class ElementwiseRreluWithNoiseTrainStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[None, ([128, 128], torch.float32, True), ([128, 128], torch.float32, True)]
)
def forward(self, x, noise):
res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, True)
return torch.mean(res), torch.std(res)
@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainStaticModule())
def ElementwiseRreluWithNoiseTrainStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128))
# ==============================================================================
class ElementwiseRreluWithNoiseEvalModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
)
def forward(self, x, noise):
res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False)
return torch.mean(res), torch.std(res)
@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalModule())
def ElementwiseRreluWithNoiseEvalModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3))
# ==============================================================================
class ElementwiseRreluWithNoiseEvalStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([None, ([5, 3], torch.float32, True), ([5, 3], torch.float32, True)])
def forward(self, x, noise):
res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False)
return torch.mean(res), torch.std(res)
@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalStaticModule())
def ElementwiseRreluWithNoiseEvalStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3))
# ==============================================================================
class ElementwiseCeluStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()