Added GeluBackward: MHLO support (#1725)

pull/1743/head
pranavmulticore 2022-12-21 17:39:43 +05:30 committed by GitHub
parent 1d695239ff
commit 0f6008c802
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 53 additions and 0 deletions

View File

@ -176,6 +176,7 @@ MHLO_PASS_SET = {
"GatherModule_basic",
"Gather2DInputModdule_basic",
"GatherRandomIndexModule_basic",
"GeluBackwardModule_basic",
"HardTanhIntModule_basic",
"HardTanhModule_basic",
"HardsigmoidModule_basic",

View File

@ -1246,6 +1246,57 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
AtenGeluBackwardOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value input = adaptor.getSelf();
auto outType = this->getTypeConverter()
->convertType(op.getType())
.cast<TensorType>();
if (!outType) {
return op.emitError("only tensor type is supported");
}
// TODO: Handle approximate.
std::string approximate;
if (!matchPattern(op.getApproximate(), m_TorchConstantStr(approximate)) ||
approximate != "none") {
return rewriter.notifyMatchFailure(op, "Unsupported value of approximate");
}
// Create constant value
Value kAlpha =
chlo::getConstantLike(rewriter, loc, 0.70710678118654752440, input);
Value cstAlpha0 =
chlo::getConstantLike(rewriter, loc, 1.12837916709551257390, input);
Value half = chlo::getConstantLike(rewriter, loc, .5, input);
Value one = chlo::getConstantLike(rewriter, loc, 1.0, input);
Value negHalf = chlo::getConstantLike(rewriter, loc, -0.5, input);
// Compute
Value kBeta0 = rewriter.create<mhlo::MulOp>(loc, outType, kAlpha, cstAlpha0);
Value kBeta = rewriter.create<mhlo::MulOp>(loc, outType, kBeta0, half);
Value erfArg =
rewriter.create<mhlo::MulOp>(loc, outType, kAlpha, adaptor.getSelf());
Value erf = rewriter.create<mlir::chlo::ErfOp>(loc, outType, erfArg);
Value erfAdd = rewriter.create<mhlo::AddOp>(loc, outType, erf, one);
Value cdf = rewriter.create<mhlo::MulOp>(loc, outType, erfAdd, half);
Value inputSquared = rewriter.create<mhlo::MulOp>(
loc, outType, adaptor.getSelf(), adaptor.getSelf());
Value negHalfInputSquared =
rewriter.create<mhlo::MulOp>(loc, outType, inputSquared, negHalf);
Value expRes =
rewriter.create<mhlo::ExpOp>(loc, outType, negHalfInputSquared);
Value pdf = rewriter.create<mhlo::MulOp>(loc, outType, kBeta, expRes);
Value pdfTimesInput =
rewriter.create<mhlo::MulOp>(loc, outType, pdf, adaptor.getSelf());
Value pdfTimesInputAddCdf =
rewriter.create<mhlo::AddOp>(loc, outType, pdfTimesInput, cdf);
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, adaptor.getGradOutput(),
pdfTimesInputAddCdf);
return success();
}
void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) {
@ -1327,6 +1378,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenGeluOp);
INSERT_ATENOP_PATTERN(AtenErfOp);
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
INSERT_ATENOP_PATTERN(AtenCatOp);
INSERT_ATENOP_PATTERN(AtenClampOp);