mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] Implement lowering of torch.aten.rrelu_with_noise and torch.aten.rrelu_with_noise_backward ops (fix) (#3748)
parent
ad9dfe974e
commit
54d9e24013
|
@ -309,6 +309,61 @@ def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenRreluWithNoiseOp : Torch_Op<"aten.rrelu_with_noise", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchTensorType:$noise,
|
||||||
|
AnyTorchScalarType:$lower,
|
||||||
|
AnyTorchScalarType:$upper,
|
||||||
|
Torch_BoolType:$training,
|
||||||
|
AnyTorchOptionalGeneratorType:$generator
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenRreluWithNoiseOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||||
|
}
|
||||||
|
void AtenRreluWithNoiseOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 6, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenRreluWithNoise_Op : Torch_Op<"aten.rrelu_with_noise_", [
|
||||||
|
IsTrailingUnderscoreInplaceVariant,
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::rrelu_with_noise_ : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_NonValueTensorType:$self,
|
||||||
|
Torch_NonValueTensorType:$noise,
|
||||||
|
AnyTorchScalarType:$lower,
|
||||||
|
AnyTorchScalarType:$upper,
|
||||||
|
Torch_BoolType:$training,
|
||||||
|
AnyTorchOptionalGeneratorType:$generator
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalNonValueTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenRreluWithNoise_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||||
|
}
|
||||||
|
void AtenRreluWithNoise_Op::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 6, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
|
def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
@ -16814,6 +16869,35 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenRreluWithNoiseBackwardOp : Torch_Op<"aten.rrelu_with_noise_backward", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$grad_output,
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchTensorType:$noise,
|
||||||
|
AnyTorchScalarType:$lower,
|
||||||
|
AnyTorchScalarType:$upper,
|
||||||
|
Torch_BoolType:$training,
|
||||||
|
Torch_BoolType:$self_is_result
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenRreluWithNoiseBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 7, 1);
|
||||||
|
}
|
||||||
|
void AtenRreluWithNoiseBackwardOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 7, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
|
def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -6683,6 +6683,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.float, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.list<int> {\n"
|
||||||
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
|
" return %0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.hardtanh_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.hardtanh_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
@ -7285,6 +7289,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.list<int> {\n"
|
||||||
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
|
" return %0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
@ -12055,6 +12063,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||||
" return %4 : !torch.int\n"
|
" return %4 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.int {\n"
|
||||||
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
|
||||||
|
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||||
|
" return %4 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
" return %0#1 : !torch.int\n"
|
" return %0#1 : !torch.int\n"
|
||||||
|
@ -12247,6 +12263,47 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" }\n"
|
" }\n"
|
||||||
" return %0#1 : !torch.int\n"
|
" return %0#1 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||||
|
" %true = torch.constant.bool true\n"
|
||||||
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||||
|
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
|
||||||
|
" torch.prim.If.yield %true : !torch.bool\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||||
|
" torch.prim.If.yield %7 : !torch.bool\n"
|
||||||
|
" }\n"
|
||||||
|
" torch.prim.If %3 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
|
||||||
|
" %5 = torch.prim.If %4 -> (!torch.bool) {\n"
|
||||||
|
" torch.prim.If.yield %true : !torch.bool\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
|
||||||
|
" torch.prim.If.yield %7 : !torch.bool\n"
|
||||||
|
" }\n"
|
||||||
|
" torch.prim.If %5 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" %6 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" torch.prim.If %6 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" return %0#1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||||
" %none = torch.constant.none\n"
|
" %none = torch.constant.none\n"
|
||||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||||
|
|
|
@ -3489,6 +3489,59 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenRreluWithNoiseBackwardOp
|
||||||
|
: public OpRewritePattern<AtenRreluWithNoiseBackwardOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenRreluWithNoiseBackwardOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
Value gradOutput = op.getGradOutput();
|
||||||
|
Value self = op.getSelf();
|
||||||
|
Value noise = op.getNoise();
|
||||||
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
|
if (!resType.hasDtype()) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool training;
|
||||||
|
if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"training should be a bool constant");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool selfIsResult = false;
|
||||||
|
if (!matchPattern(op.getSelfIsResult(),
|
||||||
|
m_TorchConstantBool(&selfIsResult)) ||
|
||||||
|
selfIsResult)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: self_is_result should be false");
|
||||||
|
|
||||||
|
double lower, upper;
|
||||||
|
if (!matchPattern(op.getLower(), m_TorchConstantFloat(&lower)) ||
|
||||||
|
!matchPattern(op.getUpper(), m_TorchConstantFloat(&upper))) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "lower and upper should be float constants");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (training && (upper - lower > 0.000001)) {
|
||||||
|
Value rreluWithNoiseBackwardOutput =
|
||||||
|
rewriter.create<AtenMulTensorOp>(loc, resType, gradOutput, noise);
|
||||||
|
rewriter.replaceOp(op, rreluWithNoiseBackwardOutput);
|
||||||
|
} else {
|
||||||
|
double negative_slope = (upper + lower) / 2;
|
||||||
|
Value cstNegativeSlope = rewriter.create<ConstantFloatOp>(
|
||||||
|
loc, rewriter.getF64FloatAttr(negative_slope));
|
||||||
|
rewriter.replaceOpWithNewOp<AtenLeakyReluBackwardOp>(
|
||||||
|
op, resType, gradOutput, self, cstNegativeSlope,
|
||||||
|
op.getSelfIsResult());
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeAtenPreluOp : public OpRewritePattern<AtenPreluOp> {
|
class DecomposeAtenPreluOp : public OpRewritePattern<AtenPreluOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -3588,6 +3641,82 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenRreluWithNoiseOp
|
||||||
|
: public OpRewritePattern<AtenRreluWithNoiseOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenRreluWithNoiseOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
Value self = op.getSelf();
|
||||||
|
Value noise = op.getNoise();
|
||||||
|
Value lower = op.getLower();
|
||||||
|
Value upper = op.getUpper();
|
||||||
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
|
if (!resType.hasDtype()) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool training;
|
||||||
|
if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "training should be a constant");
|
||||||
|
}
|
||||||
|
|
||||||
|
Value constantZeroFloat =
|
||||||
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
|
||||||
|
Value constantOneFloat =
|
||||||
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
||||||
|
Value constantTwoFloat =
|
||||||
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(2.0));
|
||||||
|
|
||||||
|
Value alpha;
|
||||||
|
if (training) {
|
||||||
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||||
|
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
|
||||||
|
loc, resType, self, constantZeroFloat, /*dtype=*/none,
|
||||||
|
/*layout=*/none,
|
||||||
|
/*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none);
|
||||||
|
alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
|
||||||
|
/*from=*/lower, /*to=*/upper,
|
||||||
|
/*generator=*/none);
|
||||||
|
} else {
|
||||||
|
Value half = rewriter.create<AtenAddOp>(loc, constantTwoFloat.getType(),
|
||||||
|
lower, upper);
|
||||||
|
alpha = rewriter.create<AtenDivOp>(loc, constantTwoFloat.getType(), half,
|
||||||
|
constantTwoFloat);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value zeroTensor =
|
||||||
|
createRank0Tensor(rewriter, loc, resType, constantZeroFloat);
|
||||||
|
Value positiveOutput =
|
||||||
|
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, self);
|
||||||
|
|
||||||
|
Value scaledSelf;
|
||||||
|
if (training) {
|
||||||
|
scaledSelf = rewriter.create<AtenMulTensorOp>(loc, resType, self, alpha);
|
||||||
|
auto boolResType = resType.getWithSizesAndDtype(resType.getSizes(),
|
||||||
|
rewriter.getI1Type());
|
||||||
|
Value oneTensor =
|
||||||
|
createRank0Tensor(rewriter, loc, resType, constantOneFloat);
|
||||||
|
Value not_positive = rewriter.create<AtenLtScalarOp>(
|
||||||
|
loc, boolResType, self, constantZeroFloat);
|
||||||
|
noise = rewriter.create<AtenWhereSelfOp>(loc, resType, not_positive,
|
||||||
|
alpha, oneTensor);
|
||||||
|
} else {
|
||||||
|
scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self, alpha);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value negativeOutput =
|
||||||
|
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
|
||||||
|
Value rreluOutput = rewriter.create<AtenAddTensorOp>(
|
||||||
|
loc, resType, positiveOutput, negativeOutput, constantOneFloat);
|
||||||
|
rewriter.replaceOp(op, rreluOutput);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
|
// CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
|
class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
|
||||||
|
@ -9924,6 +10053,9 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseBackwardOp>(
|
||||||
|
patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);
|
||||||
|
|
|
@ -498,6 +498,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
target.addIllegalOp<AtenPadOp>();
|
target.addIllegalOp<AtenPadOp>();
|
||||||
target.addIllegalOp<AtenPreluOp>();
|
target.addIllegalOp<AtenPreluOp>();
|
||||||
target.addIllegalOp<AtenRreluOp>();
|
target.addIllegalOp<AtenRreluOp>();
|
||||||
|
target.addIllegalOp<AtenRreluWithNoiseOp>();
|
||||||
|
target.addIllegalOp<AtenRreluWithNoiseBackwardOp>();
|
||||||
target.addIllegalOp<AtenCeluOp>();
|
target.addIllegalOp<AtenCeluOp>();
|
||||||
target.addIllegalOp<AtenToDtypeLayoutOp>();
|
target.addIllegalOp<AtenToDtypeLayoutOp>();
|
||||||
target.addIllegalOp<AtenToDeviceOp>();
|
target.addIllegalOp<AtenToDeviceOp>();
|
||||||
|
|
|
@ -1207,6 +1207,10 @@ STABLEHLO_PASS_SET = {
|
||||||
"ElementwisePreluStaticModule_basic",
|
"ElementwisePreluStaticModule_basic",
|
||||||
"ElementwiseReciprocalModule_basic",
|
"ElementwiseReciprocalModule_basic",
|
||||||
"ElementwiseReluModule_basic",
|
"ElementwiseReluModule_basic",
|
||||||
|
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
|
||||||
|
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
|
||||||
|
"RreluWithNoiseBackwardEvalStaticModule_basic",
|
||||||
|
"RreluWithNoiseBackwardTrainStaticModule_basic",
|
||||||
"ElementwiseRemainderTensorModule_Float_basic",
|
"ElementwiseRemainderTensorModule_Float_basic",
|
||||||
"ElementwiseRemainderTensorModule_Float_NegativeDividend_basic",
|
"ElementwiseRemainderTensorModule_Float_NegativeDividend_basic",
|
||||||
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
||||||
|
@ -2106,6 +2110,7 @@ TOSA_PASS_SET = {
|
||||||
"ElementwiseReciprocalModule_basic",
|
"ElementwiseReciprocalModule_basic",
|
||||||
"ElementwiseRelu6Module_basic",
|
"ElementwiseRelu6Module_basic",
|
||||||
"ElementwiseReluModule_basic",
|
"ElementwiseReluModule_basic",
|
||||||
|
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
|
||||||
"ElementwiseRemainderScalarModule_Float_NegativeDividend_basic",
|
"ElementwiseRemainderScalarModule_Float_NegativeDividend_basic",
|
||||||
"ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic",
|
"ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic",
|
||||||
"ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic",
|
"ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic",
|
||||||
|
@ -2238,6 +2243,10 @@ TOSA_PASS_SET = {
|
||||||
"ReduceSumFloatModule_basic",
|
"ReduceSumFloatModule_basic",
|
||||||
"ReduceSumSignedIntModule_basic",
|
"ReduceSumSignedIntModule_basic",
|
||||||
"ReduceSumUnsignedIntModule_basic",
|
"ReduceSumUnsignedIntModule_basic",
|
||||||
|
"RreluWithNoiseBackwardEvalModule_basic",
|
||||||
|
"RreluWithNoiseBackwardEvalStaticModule_basic",
|
||||||
|
"RreluWithNoiseBackwardTrainModule_basic",
|
||||||
|
"RreluWithNoiseBackwardTrainStaticModule_basic",
|
||||||
"RepeatModule_basic",
|
"RepeatModule_basic",
|
||||||
"RepeatInterleaveSelfIntNoDimModule_basic",
|
"RepeatInterleaveSelfIntNoDimModule_basic",
|
||||||
"ResNet18StaticModule_basic",
|
"ResNet18StaticModule_basic",
|
||||||
|
@ -2436,6 +2445,10 @@ MAKE_FX_TOSA_PASS_SET = (
|
||||||
"ViewSizeFromOtherTensor_basic",
|
"ViewSizeFromOtherTensor_basic",
|
||||||
"RenormModuleFloat32NegativeDim_basic",
|
"RenormModuleFloat32NegativeDim_basic",
|
||||||
"RenormModuleFloat32_basic",
|
"RenormModuleFloat32_basic",
|
||||||
|
"RreluWithNoiseBackwardEvalModule_basic",
|
||||||
|
"RreluWithNoiseBackwardEvalStaticModule_basic",
|
||||||
|
"RreluWithNoiseBackwardTrainModule_basic",
|
||||||
|
"RreluWithNoiseBackwardTrainStaticModule_basic",
|
||||||
}
|
}
|
||||||
) - {
|
) - {
|
||||||
### Test failing in make_fx_tosa but not in tosa
|
### Test failing in make_fx_tosa but not in tosa
|
||||||
|
@ -2854,6 +2867,10 @@ ONNX_XFAIL_SET = {
|
||||||
"ElementwiseRemainderTensorModule_Int_basic",
|
"ElementwiseRemainderTensorModule_Int_basic",
|
||||||
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
|
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
|
||||||
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
|
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
|
||||||
|
"ElementwiseRreluWithNoiseEvalModule_basic",
|
||||||
|
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
|
||||||
|
"ElementwiseRreluWithNoiseTrainModule_basic",
|
||||||
|
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
|
||||||
"ElementwiseSgnModule_basic",
|
"ElementwiseSgnModule_basic",
|
||||||
"EmptyStridedModule_basic",
|
"EmptyStridedModule_basic",
|
||||||
"EmptyStridedSizeIntStrideModule_basic",
|
"EmptyStridedSizeIntStrideModule_basic",
|
||||||
|
@ -3002,6 +3019,11 @@ ONNX_XFAIL_SET = {
|
||||||
"ReduceL1NormComplexModule_basic",
|
"ReduceL1NormComplexModule_basic",
|
||||||
"ReduceL2NormComplexModule_basic",
|
"ReduceL2NormComplexModule_basic",
|
||||||
"ReduceL3NormKeepDimComplexModule_basic",
|
"ReduceL3NormKeepDimComplexModule_basic",
|
||||||
|
"RreluWithNoiseBackwardEvalModule_basic",
|
||||||
|
"RreluWithNoiseBackwardEvalStaticModule_basic",
|
||||||
|
"RreluWithNoiseBackwardTrainModule_basic",
|
||||||
|
"RreluWithNoiseBackwardTrainStaticModule_basic",
|
||||||
|
"RreluWithNoiseForwardBackwardModule_basic",
|
||||||
"ReshapeAliasCollapseModule_basic",
|
"ReshapeAliasCollapseModule_basic",
|
||||||
"ReshapeAliasExpandModule_basic",
|
"ReshapeAliasExpandModule_basic",
|
||||||
"ReshapeExpandModule_basic",
|
"ReshapeExpandModule_basic",
|
||||||
|
|
|
@ -298,6 +298,9 @@ def aten〇gelu_backward〡shape(grad_output: List[int], self: List[int], approx
|
||||||
def aten〇leaky_relu_backward〡shape(grad_output: List[int], self: List[int], negative_slope: float, self_is_result: bool) -> List[int]:
|
def aten〇leaky_relu_backward〡shape(grad_output: List[int], self: List[int], negative_slope: float, self_is_result: bool) -> List[int]:
|
||||||
return upstream_shape_functions.unary(grad_output)
|
return upstream_shape_functions.unary(grad_output)
|
||||||
|
|
||||||
|
def aten〇rrelu_with_noise_backward〡shape(grad_output: List[int], self: List[int], noise: List[int], lower: float, upper: float, training: bool, self_is_result: bool) -> List[int]:
|
||||||
|
return upstream_shape_functions.unary(grad_output)
|
||||||
|
|
||||||
def aten〇hardtanh_backward〡shape(grad_output: List[int], self: List[int], min_val: float, max_val: float) -> List[int]:
|
def aten〇hardtanh_backward〡shape(grad_output: List[int], self: List[int], min_val: float, max_val: float) -> List[int]:
|
||||||
return upstream_shape_functions.unary(grad_output)
|
return upstream_shape_functions.unary(grad_output)
|
||||||
|
|
||||||
|
@ -634,6 +637,9 @@ def aten〇celu〡shape(self: List[int], alpha: float = 1.) -> List[int]:
|
||||||
def aten〇rrelu〡shape(self: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]:
|
def aten〇rrelu〡shape(self: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]:
|
||||||
return upstream_shape_functions.unary(self)
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
|
def aten〇rrelu_with_noise〡shape(self: List[int], noise: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]:
|
||||||
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
def aten〇selu〡shape(self: List[int]) -> List[int]:
|
def aten〇selu〡shape(self: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.unary(self)
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
|
@ -3126,6 +3132,15 @@ def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int],
|
||||||
promoted_dtype = promote_dtypes(ranks, dtypes)
|
promoted_dtype = promote_dtypes(ranks, dtypes)
|
||||||
return promoted_dtype
|
return promoted_dtype
|
||||||
|
|
||||||
|
@check_dtype_function([Invocation(TensorOfShape(3, 3, dtype=dtype), TensorOfShape(3, 3, dtype=dtype), TensorOfShape(3, 3, dtype=dtype), 0.1, 0.9, False, False) for dtype in _SORTED_TORCH_TYPES])
|
||||||
|
def aten〇rrelu_with_noise_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex], upper: Union[int, float, complex], training: bool, self_is_result: bool) -> int:
|
||||||
|
grad_output_rank, grad_output_dtype = grad_output_rank_dtype
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
ranks: List[Optional[int]] = [grad_output_rank, self_rank]
|
||||||
|
dtypes = [grad_output_dtype, self_dtype]
|
||||||
|
promoted_dtype = promote_dtypes(ranks, dtypes)
|
||||||
|
return promoted_dtype
|
||||||
|
|
||||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||||
def aten〇lift_fresh_copy〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
def aten〇lift_fresh_copy〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||||
self_rank, self_dtype = self_rank_dtype
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
@ -3293,6 +3308,15 @@ def aten〇rrelu〡dtype(self_rank_dtype: Tuple[int, int], lower: Union[int, flo
|
||||||
assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype)
|
assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype)
|
||||||
return self_dtype
|
return self_dtype
|
||||||
|
|
||||||
|
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2, error_types={torch.bool, *all_integer_dtypes()}))
|
||||||
|
def aten〇rrelu_with_noise〡dtype(self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int:
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
noise_rank, noise_dtype = noise_rank_dtype
|
||||||
|
assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype)
|
||||||
|
assert is_float_dtype(noise_dtype) or is_complex_dtype(noise_dtype)
|
||||||
|
assert self_rank == noise_rank
|
||||||
|
return self_dtype
|
||||||
|
|
||||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
|
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
|
||||||
def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||||
self_rank, self_dtype = self_rank_dtype
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
|
|
@ -302,6 +302,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
"aten::relu6 : (Tensor) -> (Tensor)",
|
"aten::relu6 : (Tensor) -> (Tensor)",
|
||||||
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
|
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
|
||||||
"aten::rrelu : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)",
|
"aten::rrelu : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)",
|
||||||
|
"aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)",
|
||||||
"aten::celu : (Tensor, Scalar) -> (Tensor)",
|
"aten::celu : (Tensor, Scalar) -> (Tensor)",
|
||||||
"aten::selu : (Tensor) -> (Tensor)",
|
"aten::selu : (Tensor) -> (Tensor)",
|
||||||
"aten::sigmoid : (Tensor) -> (Tensor)",
|
"aten::sigmoid : (Tensor) -> (Tensor)",
|
||||||
|
@ -1171,6 +1172,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
"aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)"
|
"aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)"
|
||||||
)
|
)
|
||||||
emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
|
emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
|
||||||
|
emit(
|
||||||
|
"aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)"
|
||||||
|
)
|
||||||
|
|
||||||
# quantized ops
|
# quantized ops
|
||||||
emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)")
|
emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)")
|
||||||
|
|
|
@ -322,3 +322,164 @@ class LeakyReluBackwardStaticModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: LeakyReluBackwardStaticModule())
|
@register_test_case(module_factory=lambda: LeakyReluBackwardStaticModule())
|
||||||
def LeakyReluBackwardStaticModule_basic(module, tu: TestUtils):
|
def LeakyReluBackwardStaticModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5))
|
module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class RreluWithNoiseBackwardTrainModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, grad, input, noise):
|
||||||
|
return torch.ops.aten.rrelu_with_noise_backward(
|
||||||
|
grad,
|
||||||
|
input,
|
||||||
|
noise,
|
||||||
|
lower=0.1,
|
||||||
|
upper=0.9,
|
||||||
|
training=True,
|
||||||
|
self_is_result=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: RreluWithNoiseBackwardTrainModule())
|
||||||
|
def RreluWithNoiseBackwardTrainModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))
|
||||||
|
|
||||||
|
|
||||||
|
class RreluWithNoiseBackwardTrainStaticModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([3, 4, 5], torch.float32, True),
|
||||||
|
([3, 4, 5], torch.float32, True),
|
||||||
|
([3, 4, 5], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, grad, input, noise):
|
||||||
|
return torch.ops.aten.rrelu_with_noise_backward(
|
||||||
|
grad,
|
||||||
|
input,
|
||||||
|
noise,
|
||||||
|
lower=0.1,
|
||||||
|
upper=0.9,
|
||||||
|
training=True,
|
||||||
|
self_is_result=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: RreluWithNoiseBackwardTrainStaticModule())
|
||||||
|
def RreluWithNoiseBackwardTrainStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class RreluWithNoiseBackwardEvalModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, grad, input, noise):
|
||||||
|
return torch.ops.aten.rrelu_with_noise_backward(
|
||||||
|
grad,
|
||||||
|
input,
|
||||||
|
noise,
|
||||||
|
lower=0.1,
|
||||||
|
upper=0.9,
|
||||||
|
training=False,
|
||||||
|
self_is_result=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalModule())
|
||||||
|
def RreluWithNoiseBackwardEvalModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))
|
||||||
|
|
||||||
|
|
||||||
|
class RreluWithNoiseBackwardEvalStaticModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([3, 4, 5], torch.float32, True),
|
||||||
|
([3, 4, 5], torch.float32, True),
|
||||||
|
([3, 4, 5], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, grad, input, noise):
|
||||||
|
return torch.ops.aten.rrelu_with_noise_backward(
|
||||||
|
grad,
|
||||||
|
input,
|
||||||
|
noise,
|
||||||
|
lower=0.1,
|
||||||
|
upper=0.9,
|
||||||
|
training=False,
|
||||||
|
self_is_result=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalStaticModule())
|
||||||
|
def RreluWithNoiseBackwardEvalStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))
|
||||||
|
|
||||||
|
|
||||||
|
class RreluWithNoiseForwardBackwardModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, grad, input, noise):
|
||||||
|
res = torch.ops.aten.rrelu_with_noise_backward(
|
||||||
|
grad,
|
||||||
|
input,
|
||||||
|
noise,
|
||||||
|
lower=0.4,
|
||||||
|
upper=0.6,
|
||||||
|
training=True,
|
||||||
|
self_is_result=False,
|
||||||
|
)
|
||||||
|
return torch.mean(res), torch.std(res)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: RreluWithNoiseForwardBackwardModule())
|
||||||
|
def RreluWithNoiseForwardBackwardModule_basic(module, tu: TestUtils):
|
||||||
|
grad = tu.rand(256, 244)
|
||||||
|
input = tu.rand(256, 244, low=-1.0, high=1.0)
|
||||||
|
noise = tu.rand(256, 244)
|
||||||
|
torch.ops.aten.rrelu_with_noise(input, noise, lower=0.4, upper=0.6, training=True)
|
||||||
|
module.forward(grad, input, noise)
|
||||||
|
|
|
@ -1179,6 +1179,88 @@ def ElementwiseRreluEvalStaticModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRreluWithNoiseTrainModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
|
||||||
|
)
|
||||||
|
def forward(self, x, noise):
|
||||||
|
res = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.5, True)
|
||||||
|
return torch.mean(res), torch.std(res)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainModule())
|
||||||
|
def ElementwiseRreluWithNoiseTrainModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRreluWithNoiseTrainStaticModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[None, ([128, 128], torch.float32, True), ([128, 128], torch.float32, True)]
|
||||||
|
)
|
||||||
|
def forward(self, x, noise):
|
||||||
|
res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, True)
|
||||||
|
return torch.mean(res), torch.std(res)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainStaticModule())
|
||||||
|
def ElementwiseRreluWithNoiseTrainStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRreluWithNoiseEvalModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
|
||||||
|
)
|
||||||
|
def forward(self, x, noise):
|
||||||
|
res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False)
|
||||||
|
return torch.mean(res), torch.std(res)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalModule())
|
||||||
|
def ElementwiseRreluWithNoiseEvalModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRreluWithNoiseEvalStaticModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([None, ([5, 3], torch.float32, True), ([5, 3], torch.float32, True)])
|
||||||
|
def forward(self, x, noise):
|
||||||
|
res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False)
|
||||||
|
return torch.mean(res), torch.std(res)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalStaticModule())
|
||||||
|
def ElementwiseRreluWithNoiseEvalStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseCeluStaticModule(torch.nn.Module):
|
class ElementwiseCeluStaticModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue