mirror of https://github.com/llvm/torch-mlir
Added GeluBackward: MHLO support (#1725)
parent
1d695239ff
commit
0f6008c802
|
@ -176,6 +176,7 @@ MHLO_PASS_SET = {
|
|||
"GatherModule_basic",
|
||||
"Gather2DInputModdule_basic",
|
||||
"GatherRandomIndexModule_basic",
|
||||
"GeluBackwardModule_basic",
|
||||
"HardTanhIntModule_basic",
|
||||
"HardTanhModule_basic",
|
||||
"HardsigmoidModule_basic",
|
||||
|
|
|
@ -1246,6 +1246,57 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
||||
AtenGeluBackwardOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Location loc = op.getLoc();
|
||||
Value input = adaptor.getSelf();
|
||||
auto outType = this->getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.cast<TensorType>();
|
||||
if (!outType) {
|
||||
return op.emitError("only tensor type is supported");
|
||||
}
|
||||
// TODO: Handle approximate.
|
||||
std::string approximate;
|
||||
if (!matchPattern(op.getApproximate(), m_TorchConstantStr(approximate)) ||
|
||||
approximate != "none") {
|
||||
return rewriter.notifyMatchFailure(op, "Unsupported value of approximate");
|
||||
}
|
||||
// Create constant value
|
||||
Value kAlpha =
|
||||
chlo::getConstantLike(rewriter, loc, 0.70710678118654752440, input);
|
||||
Value cstAlpha0 =
|
||||
chlo::getConstantLike(rewriter, loc, 1.12837916709551257390, input);
|
||||
Value half = chlo::getConstantLike(rewriter, loc, .5, input);
|
||||
Value one = chlo::getConstantLike(rewriter, loc, 1.0, input);
|
||||
Value negHalf = chlo::getConstantLike(rewriter, loc, -0.5, input);
|
||||
|
||||
// Compute
|
||||
Value kBeta0 = rewriter.create<mhlo::MulOp>(loc, outType, kAlpha, cstAlpha0);
|
||||
Value kBeta = rewriter.create<mhlo::MulOp>(loc, outType, kBeta0, half);
|
||||
Value erfArg =
|
||||
rewriter.create<mhlo::MulOp>(loc, outType, kAlpha, adaptor.getSelf());
|
||||
Value erf = rewriter.create<mlir::chlo::ErfOp>(loc, outType, erfArg);
|
||||
Value erfAdd = rewriter.create<mhlo::AddOp>(loc, outType, erf, one);
|
||||
Value cdf = rewriter.create<mhlo::MulOp>(loc, outType, erfAdd, half);
|
||||
Value inputSquared = rewriter.create<mhlo::MulOp>(
|
||||
loc, outType, adaptor.getSelf(), adaptor.getSelf());
|
||||
Value negHalfInputSquared =
|
||||
rewriter.create<mhlo::MulOp>(loc, outType, inputSquared, negHalf);
|
||||
Value expRes =
|
||||
rewriter.create<mhlo::ExpOp>(loc, outType, negHalfInputSquared);
|
||||
Value pdf = rewriter.create<mhlo::MulOp>(loc, outType, kBeta, expRes);
|
||||
Value pdfTimesInput =
|
||||
rewriter.create<mhlo::MulOp>(loc, outType, pdf, adaptor.getSelf());
|
||||
Value pdfTimesInputAddCdf =
|
||||
rewriter.create<mhlo::AddOp>(loc, outType, pdfTimesInput, cdf);
|
||||
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, adaptor.getGradOutput(),
|
||||
pdfTimesInputAddCdf);
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
|
@ -1327,6 +1378,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_ATENOP_PATTERN(AtenReluOp);
|
||||
INSERT_ATENOP_PATTERN(AtenGeluOp);
|
||||
INSERT_ATENOP_PATTERN(AtenErfOp);
|
||||
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
|
||||
|
||||
INSERT_ATENOP_PATTERN(AtenCatOp);
|
||||
INSERT_ATENOP_PATTERN(AtenClampOp);
|
||||
|
|
Loading…
Reference in New Issue