add forward+backward test

pull/3645/head
Andrija Bosnjakovic 2024-09-20 17:14:37 +02:00
parent d0d70a4bff
commit e70be44cfc
3 changed files with 135 additions and 11 deletions

View File

@@ -3563,7 +3563,7 @@ public:
     Value noise = op.getNoise();
     Value lower = op.getLower();
     Value upper = op.getUpper();
-    auto resType = cast<BaseTensorType>(self.getType());
+    auto resType = cast<BaseTensorType>(op.getType());
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
@ -3609,13 +3609,12 @@ public:
rewriter.getI1Type());
Value oneTensor =
createRank0Tensor(rewriter, loc, resType, constantOneFloat);
Value not_positive = rewriter.create<AtenLeScalarOp>(
Value not_positive = rewriter.create<AtenLtScalarOp>(
loc, boolResType, self, constantZeroFloat);
noise = rewriter.create<AtenWhereSelfOp>(loc, resType, not_positive,
noise, oneTensor);
alpha, oneTensor);
} else {
scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self, alpha);
noise = alpha;
}
Value negativeOutput =
@@ -3628,6 +3627,93 @@ public:
 };
 } // namespace
+
+// namespace {
+// class DecomposeAtenRreluWithNoiseOp
+//     : public OpRewritePattern<AtenRreluWithNoiseOp> {
+// public:
+//   using OpRewritePattern::OpRewritePattern;
+//   LogicalResult matchAndRewrite(AtenRreluWithNoiseOp op,
+//                                 PatternRewriter &rewriter) const override {
+//     Location loc = op.getLoc();
+//     Value self = op.getSelf();
+//     Value noise = op.getNoise();
+//     Value lower = op.getLower();
+//     Value upper = op.getUpper();
+//     auto resType = cast<BaseTensorType>(op.getType());
+//     if (!resType.hasDtype()) {
+//       return rewriter.notifyMatchFailure(op, "result should have dtype");
+//     }
+//     bool training;
+//     if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
+//       return rewriter.notifyMatchFailure(op, "training should be a constant");
+//     }
+//     Value constantZeroFloat =
+//         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
+//     Value constantOneFloat =
+//         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+//     Value constantTwoFloat =
+//         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(2.0));
+//     // Value alpha;
+//     // if (training) {
+//     //   Value none = rewriter.create<ConstantNoneOp>(loc);
+//     //   Value emptyTensor = rewriter.create<AtenFullLikeOp>(
+//     //       loc, resType, self, constantZeroFloat, /*dtype=*/none,
+//     //       /*layout=*/none, /*device=*/none, /*pin_memory=*/none,
+//     //       /*memory_format=*/none);
+//     //   alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
+//     //                                          /*from=*/lower, /*to=*/upper,
+//     //                                          /*generator=*/none);
+//     // } else {
+//     //   Value half = rewriter.create<AtenAddOp>(
+//     //       loc, constantTwoFloat.getType(), lower, upper);
+//     //   alpha = rewriter.create<AtenDivOp>(loc, constantTwoFloat.getType(),
+//     //                                      half, constantTwoFloat);
+//     // }
+//     Value zeroTensor =
+//         createRank0Tensor(rewriter, loc, resType, constantZeroFloat);
+//     Value positiveOutput =
+//         rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, self);
+//     Value scaledSelf;
+//     if (training) {
+//       scaledSelf = rewriter.create<AtenMulTensorOp>(loc, resType, self, noise);
+//       auto boolResType = resType.getWithSizesAndDtype(resType.getSizes(),
+//                                                       rewriter.getI1Type());
+//       Value oneTensor =
+//           createRank0Tensor(rewriter, loc, resType, constantOneFloat);
+//       Value not_positive = rewriter.create<AtenLeScalarOp>(
+//           loc, boolResType, self, constantZeroFloat);
+//       noise = rewriter.create<AtenWhereSelfOp>(loc, resType, not_positive,
+//                                                noise, oneTensor);
+//     } else {
+//       Value half = rewriter.create<AtenAddOp>(
+//           loc, constantTwoFloat.getType(), lower, upper);
+//       Value alpha = rewriter.create<AtenDivOp>(loc, constantTwoFloat.getType(),
+//                                                half, constantTwoFloat);
+//       scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self, alpha);
+//     }
+//     Value negativeOutput =
+//         rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
+//     Value rreluOutput = rewriter.create<AtenAddTensorOp>(
+//         loc, resType, positiveOutput, negativeOutput, constantOneFloat);
+//     rewriter.replaceOp(op, rreluOutput);
+//     return success();
+//   }
+// };
+// } // namespace
+
 // CELU(x)=max(0,x)+min(0,alpha*(exp(x/alpha)-1))
 namespace {
 class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
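Note: for readers comparing the pattern against eager PyTorch, here is a minimal Python sketch of the semantics the decomposition above implements: rrelu(x) = max(0, x) + min(0, a * x), where the slope a is sampled per element from U(lower, upper) in training mode and fixed at (lower + upper) / 2 in eval mode. The helper name and the strict x < 0 test (matching the AtenLtScalarOp change above) are ours; this is illustrative code, not part of the commit.

import torch

def rrelu_with_noise_sketch(x, noise, lower=0.125, upper=1.0 / 3, training=False):
    # Sketch only: mirrors the max/min decomposition used by the pattern.
    if training:
        # Per-element slope sampled from U(lower, upper); record it in `noise`
        # for negative elements, 1 elsewhere (the AtenWhereSelfOp step).
        alpha = torch.empty_like(x).uniform_(lower, upper)
        noise = torch.where(x < 0, alpha, torch.ones_like(x))
        scaled = x * noise  # AtenMulTensorOp
    else:
        alpha = (lower + upper) / 2  # fixed mean slope in eval mode
        scaled = x * alpha  # AtenMulScalarOp
    pos = torch.clamp(x, min=0)  # AtenMaximumOp against a rank-0 zero tensor
    neg = torch.clamp(scaled, max=0)  # AtenMinimumOp
    return pos + neg, noise  # AtenAddTensorOp with alpha = 1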

View File

@@ -448,3 +448,41 @@ class RreluWithNoiseBackwardEvalStaticModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalStaticModule())
 def RreluWithNoiseBackwardEvalStaticModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))
+
+
+class RreluWithNoiseForwardBackwardModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+            ([-1, -1], torch.float32, True),
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, grad, input, noise):
+        torch.ops.aten.rrelu_with_noise(
+            input, noise, lower=0.4, upper=0.6, training=True
+        )
+        res = torch.ops.aten.rrelu_with_noise_backward(
+            grad,
+            input,
+            noise,
+            lower=0.4,
+            upper=0.6,
+            training=True,
+            self_is_result=False,
+        )
+        return torch.mean(res), torch.std(res)
+
+
+@register_test_case(module_factory=lambda: RreluWithNoiseForwardBackwardModule())
+def RreluWithNoiseForwardBackwardModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(256, 244),
+        tu.rand(256, 244, low=-1.0, high=1.0),
+        tu.rand(256, 244, low=0.4, high=0.6),
+    )
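Note on the ordering inside forward above: in training mode rrelu_with_noise populates the noise buffer with the effective per-element slopes, and rrelu_with_noise_backward then consumes that buffer, so the forward call must precede the backward. A hedged sketch of the backward semantics follows (assumption: this mirrors ATen's reference behavior for the self_is_result=False case tested here; the helper name is ours):

import torch

def rrelu_with_noise_backward_sketch(grad, x, noise, lower, upper,
                                     training, self_is_result=False):
    # Sketch only. In training (with a non-degenerate [lower, upper] range)
    # the forward pass already stored the slopes in `noise`, so the gradient
    # is a plain elementwise product.
    if training and (upper - lower) > 1e-6:
        return grad * noise
    # Otherwise behaves like leaky_relu_backward with the mean slope.
    negative_slope = (lower + upper) / 2
    return torch.where(x > 0, grad, grad * negative_slope)

Because the slopes are random, the test compares torch.mean and torch.std of the result rather than exact elementwise values.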

View File

@@ -1188,13 +1188,13 @@ class ElementwiseRreluWithNoiseTrainModule(torch.nn.Module):
         [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
     )
     def forward(self, x, noise):
-        res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, True)
+        res = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.5, True)
         return torch.mean(res), torch.std(res)
 
 
 @register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainModule())
 def ElementwiseRreluWithNoiseTrainModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(1024, 1536), torch.zeros((1024, 1536)))
+    module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128))
 
 
 # ==============================================================================
@@ -1206,16 +1206,16 @@ class ElementwiseRreluWithNoiseTrainStaticModule(torch.nn.Module):
     @export
     @annotate_args(
-        [None, ([1024, 1536], torch.float32, True), ([1024, 1536], torch.float32, True)]
+        [None, ([128, 128], torch.float32, True), ([128, 128], torch.float32, True)]
     )
     def forward(self, x, noise):
-        res = torch.ops.aten.rrelu_with_noise(x, noise, 0.1, 0.9, True)
+        res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, True)
         return torch.mean(res), torch.std(res)
 
 
 @register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainStaticModule())
 def ElementwiseRreluWithNoiseTrainStaticModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(1024, 1536), torch.zeros((1024, 1536)))
+    module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128))
 
 
 # ==============================================================================
@@ -1236,7 +1236,7 @@ class ElementwiseRreluWithNoiseEvalModule(torch.nn.Module):
 
 @register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalModule())
 def ElementwiseRreluWithNoiseEvalModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(5, 3, low=-1, high=1), torch.zeros((5, 3)))
+    module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3))
 
 
 # ==============================================================================
@@ -1255,7 +1255,7 @@ class ElementwiseRreluWithNoiseEvalStaticModule(torch.nn.Module):
 
 @register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalStaticModule())
 def ElementwiseRreluWithNoiseEvalStaticModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(5, 3, low=-1, high=1), torch.zeros((5, 3)))
+    module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3))
 
 
 # ==============================================================================
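Note on the test changes above: swapping the torch.zeros noise buffers for tu.rand presumably ensures the decomposition actually overwrites noise rather than passing only for an all-zero buffer, and the smaller 128x128 shapes keep the statistical check cheap. Since the op is random, the tests compare mean and std instead of exact values; for x ~ U(-1, 1) and slope a ~ U(lower, upper), E[out] = E[max(0, x)] + E[min(0, a*x)] = 1/4 - E[a]/4 = (1 - (lower + upper)/2) / 4. A hypothetical eager-mode sanity check, assuming standard PyTorch semantics for the op (not part of the suite):

import torch

# For lower=0.4, upper=0.6 the expected mean is (1 - 0.5) / 4 = 0.125.
lower, upper = 0.4, 0.6
x = torch.empty(128, 128).uniform_(-1, 1)
noise = torch.rand(128, 128)
out = torch.ops.aten.rrelu_with_noise(x, noise, lower, upper, True)
expected_mean = (1 - (lower + upper) / 2) / 4
print(out.mean().item(), "vs expected ~", expected_mean)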