mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] Implement lowering of torch.aten.rrelu_with_noise and torch.aten.rrelu_with_noise_backward ops (fix) (#3748)
parent
ad9dfe974e
commit
54d9e24013
|
@ -309,6 +309,61 @@ def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenRreluWithNoiseOp : Torch_Op<"aten.rrelu_with_noise", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$noise,
|
||||
AnyTorchScalarType:$lower,
|
||||
AnyTorchScalarType:$upper,
|
||||
Torch_BoolType:$training,
|
||||
AnyTorchOptionalGeneratorType:$generator
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenRreluWithNoiseOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||
}
|
||||
void AtenRreluWithNoiseOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 6, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenRreluWithNoise_Op : Torch_Op<"aten.rrelu_with_noise_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::rrelu_with_noise_ : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
Torch_NonValueTensorType:$self,
|
||||
Torch_NonValueTensorType:$noise,
|
||||
AnyTorchScalarType:$lower,
|
||||
AnyTorchScalarType:$upper,
|
||||
Torch_BoolType:$training,
|
||||
AnyTorchOptionalGeneratorType:$generator
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalNonValueTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenRreluWithNoise_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||
}
|
||||
void AtenRreluWithNoise_Op::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 6, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -16814,6 +16869,35 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenRreluWithNoiseBackwardOp : Torch_Op<"aten.rrelu_with_noise_backward", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$grad_output,
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$noise,
|
||||
AnyTorchScalarType:$lower,
|
||||
AnyTorchScalarType:$upper,
|
||||
Torch_BoolType:$training,
|
||||
Torch_BoolType:$self_is_result
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenRreluWithNoiseBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 7, 1);
|
||||
}
|
||||
void AtenRreluWithNoiseBackwardOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 7, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -6683,6 +6683,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.float, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.hardtanh_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -7285,6 +7289,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -12055,6 +12063,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
|
||||
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
@ -12247,6 +12263,47 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" torch.prim.If.yield %7 : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %3 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %5 = torch.prim.If %4 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
|
||||
" torch.prim.If.yield %7 : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %5 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %6 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %6 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
|
|
|
@ -3489,6 +3489,59 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenRreluWithNoiseBackwardOp
|
||||
: public OpRewritePattern<AtenRreluWithNoiseBackwardOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenRreluWithNoiseBackwardOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value gradOutput = op.getGradOutput();
|
||||
Value self = op.getSelf();
|
||||
Value noise = op.getNoise();
|
||||
auto resType = cast<BaseTensorType>(op.getType());
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
|
||||
bool training;
|
||||
if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"training should be a bool constant");
|
||||
}
|
||||
|
||||
bool selfIsResult = false;
|
||||
if (!matchPattern(op.getSelfIsResult(),
|
||||
m_TorchConstantBool(&selfIsResult)) ||
|
||||
selfIsResult)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: self_is_result should be false");
|
||||
|
||||
double lower, upper;
|
||||
if (!matchPattern(op.getLower(), m_TorchConstantFloat(&lower)) ||
|
||||
!matchPattern(op.getUpper(), m_TorchConstantFloat(&upper))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "lower and upper should be float constants");
|
||||
}
|
||||
|
||||
if (training && (upper - lower > 0.000001)) {
|
||||
Value rreluWithNoiseBackwardOutput =
|
||||
rewriter.create<AtenMulTensorOp>(loc, resType, gradOutput, noise);
|
||||
rewriter.replaceOp(op, rreluWithNoiseBackwardOutput);
|
||||
} else {
|
||||
double negative_slope = (upper + lower) / 2;
|
||||
Value cstNegativeSlope = rewriter.create<ConstantFloatOp>(
|
||||
loc, rewriter.getF64FloatAttr(negative_slope));
|
||||
rewriter.replaceOpWithNewOp<AtenLeakyReluBackwardOp>(
|
||||
op, resType, gradOutput, self, cstNegativeSlope,
|
||||
op.getSelfIsResult());
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenPreluOp : public OpRewritePattern<AtenPreluOp> {
|
||||
public:
|
||||
|
@ -3588,6 +3641,82 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenRreluWithNoiseOp
|
||||
: public OpRewritePattern<AtenRreluWithNoiseOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenRreluWithNoiseOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value self = op.getSelf();
|
||||
Value noise = op.getNoise();
|
||||
Value lower = op.getLower();
|
||||
Value upper = op.getUpper();
|
||||
auto resType = cast<BaseTensorType>(op.getType());
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
|
||||
bool training;
|
||||
if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
|
||||
return rewriter.notifyMatchFailure(op, "training should be a constant");
|
||||
}
|
||||
|
||||
Value constantZeroFloat =
|
||||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
|
||||
Value constantOneFloat =
|
||||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
||||
Value constantTwoFloat =
|
||||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(2.0));
|
||||
|
||||
Value alpha;
|
||||
if (training) {
|
||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
|
||||
loc, resType, self, constantZeroFloat, /*dtype=*/none,
|
||||
/*layout=*/none,
|
||||
/*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none);
|
||||
alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
|
||||
/*from=*/lower, /*to=*/upper,
|
||||
/*generator=*/none);
|
||||
} else {
|
||||
Value half = rewriter.create<AtenAddOp>(loc, constantTwoFloat.getType(),
|
||||
lower, upper);
|
||||
alpha = rewriter.create<AtenDivOp>(loc, constantTwoFloat.getType(), half,
|
||||
constantTwoFloat);
|
||||
}
|
||||
|
||||
Value zeroTensor =
|
||||
createRank0Tensor(rewriter, loc, resType, constantZeroFloat);
|
||||
Value positiveOutput =
|
||||
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, self);
|
||||
|
||||
Value scaledSelf;
|
||||
if (training) {
|
||||
scaledSelf = rewriter.create<AtenMulTensorOp>(loc, resType, self, alpha);
|
||||
auto boolResType = resType.getWithSizesAndDtype(resType.getSizes(),
|
||||
rewriter.getI1Type());
|
||||
Value oneTensor =
|
||||
createRank0Tensor(rewriter, loc, resType, constantOneFloat);
|
||||
Value not_positive = rewriter.create<AtenLtScalarOp>(
|
||||
loc, boolResType, self, constantZeroFloat);
|
||||
noise = rewriter.create<AtenWhereSelfOp>(loc, resType, not_positive,
|
||||
alpha, oneTensor);
|
||||
} else {
|
||||
scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self, alpha);
|
||||
}
|
||||
|
||||
Value negativeOutput =
|
||||
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
|
||||
Value rreluOutput = rewriter.create<AtenAddTensorOp>(
|
||||
loc, resType, positiveOutput, negativeOutput, constantOneFloat);
|
||||
rewriter.replaceOp(op, rreluOutput);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
|
||||
namespace {
|
||||
class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
|
||||
|
@ -9924,6 +10053,9 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseBackwardOp>(
|
||||
patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);
|
||||
|
|
|
@ -498,6 +498,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenPadOp>();
|
||||
target.addIllegalOp<AtenPreluOp>();
|
||||
target.addIllegalOp<AtenRreluOp>();
|
||||
target.addIllegalOp<AtenRreluWithNoiseOp>();
|
||||
target.addIllegalOp<AtenRreluWithNoiseBackwardOp>();
|
||||
target.addIllegalOp<AtenCeluOp>();
|
||||
target.addIllegalOp<AtenToDtypeLayoutOp>();
|
||||
target.addIllegalOp<AtenToDeviceOp>();
|
||||
|
|
|
@ -1207,6 +1207,10 @@ STABLEHLO_PASS_SET = {
|
|||
"ElementwisePreluStaticModule_basic",
|
||||
"ElementwiseReciprocalModule_basic",
|
||||
"ElementwiseReluModule_basic",
|
||||
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
|
||||
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
|
||||
"RreluWithNoiseBackwardEvalStaticModule_basic",
|
||||
"RreluWithNoiseBackwardTrainStaticModule_basic",
|
||||
"ElementwiseRemainderTensorModule_Float_basic",
|
||||
"ElementwiseRemainderTensorModule_Float_NegativeDividend_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
||||
|
@ -2106,6 +2110,7 @@ TOSA_PASS_SET = {
|
|||
"ElementwiseReciprocalModule_basic",
|
||||
"ElementwiseRelu6Module_basic",
|
||||
"ElementwiseReluModule_basic",
|
||||
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
|
||||
"ElementwiseRemainderScalarModule_Float_NegativeDividend_basic",
|
||||
"ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic",
|
||||
|
@ -2238,6 +2243,10 @@ TOSA_PASS_SET = {
|
|||
"ReduceSumFloatModule_basic",
|
||||
"ReduceSumSignedIntModule_basic",
|
||||
"ReduceSumUnsignedIntModule_basic",
|
||||
"RreluWithNoiseBackwardEvalModule_basic",
|
||||
"RreluWithNoiseBackwardEvalStaticModule_basic",
|
||||
"RreluWithNoiseBackwardTrainModule_basic",
|
||||
"RreluWithNoiseBackwardTrainStaticModule_basic",
|
||||
"RepeatModule_basic",
|
||||
"RepeatInterleaveSelfIntNoDimModule_basic",
|
||||
"ResNet18StaticModule_basic",
|
||||
|
@ -2436,6 +2445,10 @@ MAKE_FX_TOSA_PASS_SET = (
|
|||
"ViewSizeFromOtherTensor_basic",
|
||||
"RenormModuleFloat32NegativeDim_basic",
|
||||
"RenormModuleFloat32_basic",
|
||||
"RreluWithNoiseBackwardEvalModule_basic",
|
||||
"RreluWithNoiseBackwardEvalStaticModule_basic",
|
||||
"RreluWithNoiseBackwardTrainModule_basic",
|
||||
"RreluWithNoiseBackwardTrainStaticModule_basic",
|
||||
}
|
||||
) - {
|
||||
### Test failing in make_fx_tosa but not in tosa
|
||||
|
@ -2854,6 +2867,10 @@ ONNX_XFAIL_SET = {
|
|||
"ElementwiseRemainderTensorModule_Int_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
|
||||
"ElementwiseRreluWithNoiseEvalModule_basic",
|
||||
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
|
||||
"ElementwiseRreluWithNoiseTrainModule_basic",
|
||||
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
|
||||
"ElementwiseSgnModule_basic",
|
||||
"EmptyStridedModule_basic",
|
||||
"EmptyStridedSizeIntStrideModule_basic",
|
||||
|
@ -3002,6 +3019,11 @@ ONNX_XFAIL_SET = {
|
|||
"ReduceL1NormComplexModule_basic",
|
||||
"ReduceL2NormComplexModule_basic",
|
||||
"ReduceL3NormKeepDimComplexModule_basic",
|
||||
"RreluWithNoiseBackwardEvalModule_basic",
|
||||
"RreluWithNoiseBackwardEvalStaticModule_basic",
|
||||
"RreluWithNoiseBackwardTrainModule_basic",
|
||||
"RreluWithNoiseBackwardTrainStaticModule_basic",
|
||||
"RreluWithNoiseForwardBackwardModule_basic",
|
||||
"ReshapeAliasCollapseModule_basic",
|
||||
"ReshapeAliasExpandModule_basic",
|
||||
"ReshapeExpandModule_basic",
|
||||
|
|
|
@ -298,6 +298,9 @@ def aten〇gelu_backward〡shape(grad_output: List[int], self: List[int], approx
|
|||
def aten〇leaky_relu_backward〡shape(grad_output: List[int], self: List[int], negative_slope: float, self_is_result: bool) -> List[int]:
|
||||
return upstream_shape_functions.unary(grad_output)
|
||||
|
||||
def aten〇rrelu_with_noise_backward〡shape(grad_output: List[int], self: List[int], noise: List[int], lower: float, upper: float, training: bool, self_is_result: bool) -> List[int]:
|
||||
return upstream_shape_functions.unary(grad_output)
|
||||
|
||||
def aten〇hardtanh_backward〡shape(grad_output: List[int], self: List[int], min_val: float, max_val: float) -> List[int]:
|
||||
return upstream_shape_functions.unary(grad_output)
|
||||
|
||||
|
@ -634,6 +637,9 @@ def aten〇celu〡shape(self: List[int], alpha: float = 1.) -> List[int]:
|
|||
def aten〇rrelu〡shape(self: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇rrelu_with_noise〡shape(self: List[int], noise: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇selu〡shape(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
@ -3126,6 +3132,15 @@ def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int],
|
|||
promoted_dtype = promote_dtypes(ranks, dtypes)
|
||||
return promoted_dtype
|
||||
|
||||
@check_dtype_function([Invocation(TensorOfShape(3, 3, dtype=dtype), TensorOfShape(3, 3, dtype=dtype), TensorOfShape(3, 3, dtype=dtype), 0.1, 0.9, False, False) for dtype in _SORTED_TORCH_TYPES])
|
||||
def aten〇rrelu_with_noise_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex], upper: Union[int, float, complex], training: bool, self_is_result: bool) -> int:
|
||||
grad_output_rank, grad_output_dtype = grad_output_rank_dtype
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
ranks: List[Optional[int]] = [grad_output_rank, self_rank]
|
||||
dtypes = [grad_output_dtype, self_dtype]
|
||||
promoted_dtype = promote_dtypes(ranks, dtypes)
|
||||
return promoted_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇lift_fresh_copy〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
@ -3293,6 +3308,15 @@ def aten〇rrelu〡dtype(self_rank_dtype: Tuple[int, int], lower: Union[int, flo
|
|||
assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype)
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2, error_types={torch.bool, *all_integer_dtypes()}))
|
||||
def aten〇rrelu_with_noise〡dtype(self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
noise_rank, noise_dtype = noise_rank_dtype
|
||||
assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype)
|
||||
assert is_float_dtype(noise_dtype) or is_complex_dtype(noise_dtype)
|
||||
assert self_rank == noise_rank
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
|
||||
def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
|
|
@ -302,6 +302,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::relu6 : (Tensor) -> (Tensor)",
|
||||
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
|
||||
"aten::rrelu : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)",
|
||||
"aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)",
|
||||
"aten::celu : (Tensor, Scalar) -> (Tensor)",
|
||||
"aten::selu : (Tensor) -> (Tensor)",
|
||||
"aten::sigmoid : (Tensor) -> (Tensor)",
|
||||
|
@ -1171,6 +1172,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)"
|
||||
)
|
||||
emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
|
||||
emit(
|
||||
"aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)"
|
||||
)
|
||||
|
||||
# quantized ops
|
||||
emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)")
|
||||
|
|
|
@ -322,3 +322,164 @@ class LeakyReluBackwardStaticModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: LeakyReluBackwardStaticModule())
|
||||
def LeakyReluBackwardStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RreluWithNoiseBackwardTrainModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, grad, input, noise):
|
||||
return torch.ops.aten.rrelu_with_noise_backward(
|
||||
grad,
|
||||
input,
|
||||
noise,
|
||||
lower=0.1,
|
||||
upper=0.9,
|
||||
training=True,
|
||||
self_is_result=False,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: RreluWithNoiseBackwardTrainModule())
|
||||
def RreluWithNoiseBackwardTrainModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
class RreluWithNoiseBackwardTrainStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([3, 4, 5], torch.float32, True),
|
||||
([3, 4, 5], torch.float32, True),
|
||||
([3, 4, 5], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, grad, input, noise):
|
||||
return torch.ops.aten.rrelu_with_noise_backward(
|
||||
grad,
|
||||
input,
|
||||
noise,
|
||||
lower=0.1,
|
||||
upper=0.9,
|
||||
training=True,
|
||||
self_is_result=False,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: RreluWithNoiseBackwardTrainStaticModule())
|
||||
def RreluWithNoiseBackwardTrainStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RreluWithNoiseBackwardEvalModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, grad, input, noise):
|
||||
return torch.ops.aten.rrelu_with_noise_backward(
|
||||
grad,
|
||||
input,
|
||||
noise,
|
||||
lower=0.1,
|
||||
upper=0.9,
|
||||
training=False,
|
||||
self_is_result=False,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalModule())
|
||||
def RreluWithNoiseBackwardEvalModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
class RreluWithNoiseBackwardEvalStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([3, 4, 5], torch.float32, True),
|
||||
([3, 4, 5], torch.float32, True),
|
||||
([3, 4, 5], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, grad, input, noise):
|
||||
return torch.ops.aten.rrelu_with_noise_backward(
|
||||
grad,
|
||||
input,
|
||||
noise,
|
||||
lower=0.1,
|
||||
upper=0.9,
|
||||
training=False,
|
||||
self_is_result=False,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalStaticModule())
|
||||
def RreluWithNoiseBackwardEvalStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
class RreluWithNoiseForwardBackwardModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, grad, input, noise):
|
||||
res = torch.ops.aten.rrelu_with_noise_backward(
|
||||
grad,
|
||||
input,
|
||||
noise,
|
||||
lower=0.4,
|
||||
upper=0.6,
|
||||
training=True,
|
||||
self_is_result=False,
|
||||
)
|
||||
return torch.mean(res), torch.std(res)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: RreluWithNoiseForwardBackwardModule())
|
||||
def RreluWithNoiseForwardBackwardModule_basic(module, tu: TestUtils):
|
||||
grad = tu.rand(256, 244)
|
||||
input = tu.rand(256, 244, low=-1.0, high=1.0)
|
||||
noise = tu.rand(256, 244)
|
||||
torch.ops.aten.rrelu_with_noise(input, noise, lower=0.4, upper=0.6, training=True)
|
||||
module.forward(grad, input, noise)
|
||||
|
|
|
@ -1179,6 +1179,88 @@ def ElementwiseRreluEvalStaticModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseRreluWithNoiseTrainModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
|
||||
)
|
||||
def forward(self, x, noise):
|
||||
res = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.5, True)
|
||||
return torch.mean(res), torch.std(res)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainModule())
|
||||
def ElementwiseRreluWithNoiseTrainModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseRreluWithNoiseTrainStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[None, ([128, 128], torch.float32, True), ([128, 128], torch.float32, True)]
|
||||
)
|
||||
def forward(self, x, noise):
|
||||
res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, True)
|
||||
return torch.mean(res), torch.std(res)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainStaticModule())
|
||||
def ElementwiseRreluWithNoiseTrainStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseRreluWithNoiseEvalModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
|
||||
)
|
||||
def forward(self, x, noise):
|
||||
res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False)
|
||||
return torch.mean(res), torch.std(res)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalModule())
|
||||
def ElementwiseRreluWithNoiseEvalModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseRreluWithNoiseEvalStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([5, 3], torch.float32, True), ([5, 3], torch.float32, True)])
|
||||
def forward(self, x, noise):
|
||||
res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False)
|
||||
return torch.mean(res), torch.std(res)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalStaticModule())
|
||||
def ElementwiseRreluWithNoiseEvalStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseCeluStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue