mirror of https://github.com/llvm/torch-mlir
add forward+backward test
parent d0d70a4bff
commit e70be44cfc
@@ -3563,7 +3563,7 @@ public:
     Value noise = op.getNoise();
     Value lower = op.getLower();
     Value upper = op.getUpper();
-    auto resType = cast<BaseTensorType>(self.getType());
+    auto resType = cast<BaseTensorType>(op.getType());
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
@@ -3609,13 +3609,12 @@ public:
                                                       rewriter.getI1Type());
       Value oneTensor =
           createRank0Tensor(rewriter, loc, resType, constantOneFloat);
-      Value not_positive = rewriter.create<AtenLeScalarOp>(
+      Value not_positive = rewriter.create<AtenLtScalarOp>(
           loc, boolResType, self, constantZeroFloat);
       noise = rewriter.create<AtenWhereSelfOp>(loc, resType, not_positive,
-                                               noise, oneTensor);
+                                               alpha, oneTensor);
     } else {
       scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self, alpha);
-      noise = alpha;
     }
 
     Value negativeOutput =
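Note: the two hunks above adjust the training-mode branch of the aten.rrelu_with_noise decomposition, which lowers the op as max(0, x) + min(0, x * slope). A rough eager-mode sketch of the semantics being decomposed, assuming the behavior described in the PyTorch RReLU docs (`rrelu_with_noise_reference` is a hypothetical helper, not part of this patch):

import torch

def rrelu_with_noise_reference(x, noise, lower=0.125, upper=1.0 / 3.0, training=False):
    if training:
        # Sample a per-element slope from U(lower, upper); the slope applies
        # where the input is negative, and 1 everywhere else.
        alpha = torch.empty_like(x).uniform_(lower, upper)
        slope = torch.where(x < 0, alpha, torch.ones_like(x))
    else:
        # Eval mode collapses the random slope to the midpoint of the range.
        slope = torch.full_like(x, (lower + upper) / 2)
    zero = torch.zeros_like(x)
    # The positive part passes through; the negative part is scaled by the slope.
    return torch.maximum(zero, x) + torch.minimum(zero, x * slope)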
@@ -3628,6 +3627,93 @@ public:
 };
 } // namespace
 
+// namespace {
+// class DecomposeAtenRreluWithNoiseOp
+//     : public OpRewritePattern<AtenRreluWithNoiseOp> {
+// public:
+//   using OpRewritePattern::OpRewritePattern;
+//   LogicalResult matchAndRewrite(AtenRreluWithNoiseOp op,
+//                                 PatternRewriter &rewriter) const override {
+//     Location loc = op.getLoc();
+//     Value self = op.getSelf();
+//     Value noise = op.getNoise();
+//     Value lower = op.getLower();
+//     Value upper = op.getUpper();
+//     auto resType = cast<BaseTensorType>(op.getType());
+//     if (!resType.hasDtype()) {
+//       return rewriter.notifyMatchFailure(op, "result should have dtype");
+//     }
+
+//     bool training;
+//     if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
+//       return rewriter.notifyMatchFailure(op, "training should be a
+//       constant");
+//     }
+
+//     Value constantZeroFloat =
+//         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
+//     Value constantOneFloat =
+//         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+//     Value constantTwoFloat =
+//         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(2.0));
+
+//     // Value alpha;
+//     // if (training) {
+//     //   Value none = rewriter.create<ConstantNoneOp>(loc);
+//     //   Value emptyTensor = rewriter.create<AtenFullLikeOp>(
+//     //       loc, resType, self, constantZeroFloat, /*dtype=*/none,
+//     //       /*layout=*/none,
+//     //       /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
+//     //   alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
+//     //                                          /*from=*/lower, /*to=*/upper,
+//     //                                          /*generator=*/none);
+//     // } else {
+//     //   Value half = rewriter.create<AtenAddOp>(loc,
+//     constantTwoFloat.getType(),
+//     //   lower, upper);
+//     //   alpha = rewriter.create<AtenDivOp>(loc, constantTwoFloat.getType(),
+//     half,
+//     //   constantTwoFloat);
+//     // }
+
+//     Value zeroTensor =
+//         createRank0Tensor(rewriter, loc, resType, constantZeroFloat);
+//     Value positiveOutput =
+//         rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, self);
+
+//     Value scaledSelf;
+//     if (training) {
+//       scaledSelf = rewriter.create<AtenMulTensorOp>(loc, resType, self,
+//       noise); auto boolResType =
+//       resType.getWithSizesAndDtype(resType.getSizes(),
+//                                                       rewriter.getI1Type());
+//       Value oneTensor =
+//           createRank0Tensor(rewriter, loc, resType, constantOneFloat);
+//       Value not_positive = rewriter.create<AtenLeScalarOp>(
+//           loc, boolResType, self, constantZeroFloat);
+//       noise = rewriter.create<AtenWhereSelfOp>(loc, resType, not_positive,
+//                                                noise, oneTensor);
+//     } else {
+//       Value half = rewriter.create<AtenAddOp>(loc,
+//       constantTwoFloat.getType(),
+//       lower, upper);
+//       Value alpha = rewriter.create<AtenDivOp>(loc,
+//       constantTwoFloat.getType(), half,
+//       constantTwoFloat);
+//       scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self,
+//       alpha);
+//     }
+
+//     Value negativeOutput =
+//         rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
+//     Value rreluOutput = rewriter.create<AtenAddTensorOp>(
+//         loc, resType, positiveOutput, negativeOutput, constantOneFloat);
+//     rewriter.replaceOp(op, rreluOutput);
+//     return success();
+//   }
+// };
+// } // namespace
+
 // CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
 namespace {
 class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
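Note: the trailing context line states the CELU identity that DecomposeAtenCeluOp implements. As a quick eager-mode cross-check of that formula (a sketch; `celu_reference` is a hypothetical name):

import torch

def celu_reference(x, alpha=1.0):
    # CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    zero = torch.zeros_like(x)
    return torch.maximum(zero, x) + torch.minimum(
        zero, alpha * (torch.exp(x / alpha) - 1)
    )

# Agrees with the built-in up to floating-point error:
x = torch.randn(8)
assert torch.allclose(celu_reference(x), torch.nn.functional.celu(x, alpha=1.0))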
@@ -448,3 +448,41 @@ class RreluWithNoiseBackwardEvalStaticModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalStaticModule())
 def RreluWithNoiseBackwardEvalStaticModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))
+
+
+class RreluWithNoiseForwardBackwardModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+            ([-1, -1], torch.float32, True),
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, grad, input, noise):
+        torch.ops.aten.rrelu_with_noise(
+            input, noise, lower=0.4, upper=0.6, training=True
+        )
+        res = torch.ops.aten.rrelu_with_noise_backward(
+            grad,
+            input,
+            noise,
+            lower=0.4,
+            upper=0.6,
+            training=True,
+            self_is_result=False,
+        )
+        return torch.mean(res), torch.std(res)
+
+
+@register_test_case(module_factory=lambda: RreluWithNoiseForwardBackwardModule())
+def RreluWithNoiseForwardBackwardModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(256, 244),
+        tu.rand(256, 244, low=-1.0, high=1.0),
+        tu.rand(256, 244, low=0.4, high=0.6),
+    )
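Note: the new test runs aten.rrelu_with_noise_backward immediately after the forward op, so both decompositions are exercised in one module. A hedged eager-mode sketch of the backward semantics, assuming the usual ATen definition (the helper name is hypothetical):

import torch

def rrelu_with_noise_backward_reference(grad, x, noise, lower=0.4, upper=0.6, training=True):
    if training and upper > lower:
        # After a training-mode forward pass, `noise` already encodes the
        # slope: the sampled value where x was negative and 1 elsewhere.
        return grad * noise
    # Otherwise this reduces to a leaky_relu backward with the midpoint slope.
    slope = (lower + upper) / 2
    return grad * torch.where(x > 0, torch.ones_like(x), torch.full_like(x, slope))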
@@ -1188,13 +1188,13 @@ class ElementwiseRreluWithNoiseTrainModule(torch.nn.Module):
         [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
     )
     def forward(self, x, noise):
-        res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, True)
+        res = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.5, True)
         return torch.mean(res), torch.std(res)
 
 
 @register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainModule())
 def ElementwiseRreluWithNoiseTrainModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(1024, 1536), torch.zeros((1024, 1536)))
+    module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128))
 
 
 # ==============================================================================
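Note: because the training path samples random slopes, an elementwise comparison against a golden tensor would be flaky; comparing mean and std over a reasonably large tensor is stable across backends. An eager-mode illustration with the updated inputs (assuming the e2e harness compares exactly these statistics):

import torch

x = torch.empty(128, 128).uniform_(-1, 1)
noise = torch.rand(128, 128)
# Slopes are drawn from [0.2, 0.5]; only distribution statistics are reproducible.
res = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.5, True)
print(res.mean(), res.std())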
@@ -1206,16 +1206,16 @@ class ElementwiseRreluWithNoiseTrainStaticModule(torch.nn.Module):
 
     @export
     @annotate_args(
-        [None, ([1024, 1536], torch.float32, True), ([1024, 1536], torch.float32, True)]
+        [None, ([128, 128], torch.float32, True), ([128, 128], torch.float32, True)]
     )
     def forward(self, x, noise):
-        res = torch.ops.aten.rrelu_with_noise(x, noise, 0.1, 0.9, True)
+        res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, True)
         return torch.mean(res), torch.std(res)
 
 
 @register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainStaticModule())
 def ElementwiseRreluWithNoiseTrainStaticModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(1024, 1536), torch.zeros((1024, 1536)))
+    module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128))
 
 
 # ==============================================================================
@@ -1236,7 +1236,7 @@ class ElementwiseRreluWithNoiseEvalModule(torch.nn.Module):
 
 @register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalModule())
 def ElementwiseRreluWithNoiseEvalModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(5, 3, low=-1, high=1), torch.zeros((5, 3)))
+    module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3))
 
 
 # ==============================================================================
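Note: with training=False the random range is unused, and rrelu_with_noise should reduce to leaky_relu with the midpoint slope, which is presumably why a random noise tensor is now safe to pass instead of zeros. A sketch of that assumed property (the 0.4/0.6 bounds are illustrative, not taken from this hunk):

import torch

x = torch.empty(5, 3).uniform_(-1, 1)
noise = torch.rand(5, 3)  # expected to be ignored in eval mode
res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False)
expected = torch.nn.functional.leaky_relu(x, negative_slope=0.5)
assert torch.allclose(res, expected)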
@@ -1255,7 +1255,7 @@ class ElementwiseRreluWithNoiseEvalStaticModule(torch.nn.Module):
 
 @register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalStaticModule())
 def ElementwiseRreluWithNoiseEvalStaticModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(5, 3, low=-1, high=1), torch.zeros((5, 3)))
+    module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3))
 
 
 # ==============================================================================