Add aten::gelu lowering (#374)

* Print more exception info on error during test execution

* Fix formatting

* Add aten::gelu lowering

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
Boian Petkantchin 2021-10-25 16:16:01 -07:00 committed by GitHub
parent a6943ef90c
commit e276dbbaa6
7 changed files with 97 additions and 23 deletions
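For context: `aten::gelu` computes GELU(x) = x * Phi(x), where Phi is the standard normal CDF, Phi(x) = 0.5 * (1 + erf(x / sqrt(2))). A quick PyTorch sanity check of that identity (illustrative only, not part of this patch):

import math
import torch

x = 2 * torch.rand(5, 3) - 0.5
# PyTorch's default GELU is the exact erf-based form, not the tanh approximation.
expected = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
assert torch.allclose(torch.nn.GELU()(x), expected, atol=1e-6)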


@@ -198,6 +198,28 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseGeluModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.gelu = torch.nn.GELU()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.gelu(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseGeluModule())
+def ElementwiseGeluModule_basic(module, tu: TestUtils):
+    module.forward(2*tu.rand(5, 3) - 0.5)
+
+# ==============================================================================
+
+
 class ElementwiseSigmoidModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
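Since `tu.rand` draws from [0, 1), the input `2*tu.rand(5, 3) - 0.5` covers [-0.5, 1.5), exercising GELU on both negative and positive values.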


@@ -762,6 +762,20 @@ def Torch_AtenMaskedFill_ScalarOp : Torch_Op<"aten.masked_fill_.Scalar", [
   let assemblyFormat = "$self `,` $mask `,` $value attr-dict `:` type($self) `,` type($mask) `,` type($value) `->` type($result)";
 }
 
+def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::gelu : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
+}
+
 def Torch_AtenTriuOp : Torch_Op<"aten.triu", [
     AllowsTypeRefinement,
     HasValueSemantics
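The two traits match the other generated elementwise ops: `HasValueSemantics` marks the op as a pure value-to-value computation, and `AllowsTypeRefinement` permits the refinement passes to narrow the result type.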


@@ -77,7 +77,8 @@ static Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
   assert(dim.getType().isa<IntegerType>() &&
          "dim arg of toPositiveDim must be integer type");
   Value dimAddInputRank = b.create<arith::AddIOp>(loc, dim, inputRank);
-  Value cst0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
+  Value cst0 =
+      b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
   Value predDimGEZero =
       b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
   Value dimInt = b.create<SelectOp>(loc, predDimGEZero, dim, dimAddInputRank);
@@ -89,7 +90,8 @@ static void assertIsValidDim(OpBuilder &b, Location loc, Value dim,
                              Value inputRank) {
   assert(dim.getType().isa<IntegerType>() &&
          "dim arg of assertIsValidDim must be integer type");
-  Value cst0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
+  Value cst0 =
+      b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
   Value predGEZero =
       b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
   b.create<AssertOp>(loc, predGEZero,
@@ -270,6 +272,29 @@ static bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
   return true;
 }
 
+static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean,
+                            Value sigma) {
+  Type elementType = x.getType();
+  Value xMinusMean = b.create<arith::SubFOp>(loc, x, mean);
+  Value two = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 2));
+  Value sqrt2 = b.create<math::SqrtOp>(loc, two);
+  Value erfArg = b.create<arith::DivFOp>(loc, xMinusMean, sqrt2);
+  Value erf = b.create<math::ErfOp>(loc, erfArg);
+  Value one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
+  Value erfPlus1 = b.create<arith::AddFOp>(loc, one, erf);
+  Value oneHalf =
+      b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.5));
+  Value normalCdf = b.create<arith::MulFOp>(loc, oneHalf, erfPlus1);
+  return normalCdf;
+}
+
+static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
+  Type elementType = x.getType();
+  Value zero = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0));
+  Value one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
+  return buildNormalCdf(b, loc, x, zero, one);
+}
+
 namespace {
 class ConvertAtenAdaptiveAvgPool2dOp
     : public OpConversionPattern<AtenAdaptiveAvgPool2dOp> {
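`buildNormalCdf` implements Phi(x) = 0.5 * (1 + erf((x - mean) / sqrt(2))). Note that `sigma` is accepted but never used in the body, so the result is only a true normal CDF for sigma = 1; its sole caller, `buildUnitNormalCdf`, passes mean = 0 and sigma = 1.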
@@ -1117,6 +1142,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
                                     payloadArgs[0], constZero);
     return b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
   }
+  if (auto gelu = dyn_cast<AtenGeluOp>(op)) {
+    if (!gelu.getType()
+             .cast<ValueTensorType>()
+             .getDtype()
+             .isa<mlir::FloatType>()) {
+      gelu.emitError("unimplemented: non-floating point dtype");
+      return nullptr;
+    }
+    Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[0]);
+    return b.create<arith::MulFOp>(loc, payloadArgs[0], cdf);
+  }
   if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
     AtenAddTensorOp::Adaptor adaptor(operands);
     if (add.alpha().getType().isa<Torch::FloatType>()) {
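With the CDF helper in place, the `aten.gelu` payload is simply x * Phi(x) per element, i.e. the exact erf-based GELU; non-floating-point dtypes are rejected up front since the erf decomposition is only defined on floats here.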
@@ -1396,9 +1432,9 @@ struct ConvertElementwiseOp : ConversionPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isa<AtenTanhOp, AtenReluOp, AtenAddTensorOp, AtenMulTensorOp,
-             AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
-             AtenExpOp>(op))
+    if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
+             AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
+             AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -2322,7 +2358,7 @@ public:
     patterns.add<ConvertAtenLinearOp>(typeConverter, context);
     target.addIllegalOp<AtenBatchNormOp>();
     patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
-    target.addIllegalOp<AtenTanhOp, AtenReluOp, AtenAddTensorOp,
+    target.addIllegalOp<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
                         AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
                         AtenLerpTensorOp, AtenSigmoidOp>();
     patterns.add<ConvertElementwiseOp>(typeConverter, context);


@@ -190,14 +190,15 @@ public:
   visitOperation(Operation *op,
                  ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
     if (isa<TensorStaticInfoCastOp, CopyToValueTensorOp, CopyToNonValueTensorOp,
-            AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenAddScalarOp,
-            AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp, AtenFmodScalarOp,
-            AtenFloorDivideScalarOp, AtenEqScalarOp, AtenGeScalarOp,
-            AtenGtScalarOp, AtenNeScalarOp, AtenBitwiseNotOp, AtenToDtypeOp,
-            AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp,
-            AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
-            AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
-            AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp>(op)) {
+            AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenGeluOp,
+            AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp,
+            AtenFmodScalarOp, AtenFloorDivideScalarOp, AtenEqScalarOp,
+            AtenGeScalarOp, AtenGtScalarOp, AtenNeScalarOp, AtenBitwiseNotOp,
+            AtenToDtypeOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
+            DerefineOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
+            AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp,
+            AtenCopy_Op, AtenIndexPut_Op, AtenCopy_Op, AtenCumsumOp,
+            AtenLayerNormOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
 
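Adding `AtenGeluOp` to this list lets the type refinement pass treat it like the other elementwise ops here: the result's knowledge (the dtype/shape lattice value) is joined directly from the first operand's.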


@@ -16,6 +16,7 @@
 
 #include "PassDetail.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Approximation.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -179,11 +180,13 @@ class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
     auto *context = &getContext();
     RewritePatternSet patterns(context);
     populateExpandTanhPattern(patterns);
+    patterns.add<math::ErfPolynomialApproximation>(patterns.getContext());
     ConversionTarget target(*context);
     target.addLegalDialect<StandardOpsDialect>();
     target.addLegalDialect<math::MathDialect>();
     target.addLegalDialect<arith::ArithmeticDialect>();
     target.addIllegalOp<math::TanhOp>();
+    target.addIllegalOp<math::ErfOp>();
     if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
       return signalPassFailure();
     }
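`math.erf` has no direct lowering to LLVM, so the pass marks it illegal and rewrites it via the upstream `ErfPolynomialApproximation` pattern, mirroring how `math.tanh` is expanded just above.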


@@ -464,6 +464,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     ]:
         emit_with_mutating_variants(key)
 
+    emit("aten::gelu : (Tensor) -> (Tensor)")
+
     emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
     emit_with_mutating_variants("aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")


@@ -25,6 +25,7 @@ from typing import Any, Callable, List, NamedTuple, Optional, TypeVar, Union, Di
 import io
 import pickle
+import traceback
 
 import torch
 
@@ -298,16 +299,10 @@ def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
             golden_trace = generate_golden_trace(test)
             compiled = config.compile(test.program_factory())
         except Exception as e:
-            # Useful for debugging:
-            # ```
-            # raise
-            # ```
-            # This will give the full traceback rather than giving just
-            # the stringified exception in the report.
-            # TODO: Capture the traceback and make it available in the report.
             results.append(
                 TestResult(unique_name=test.unique_name,
-                           compilation_error=str(e),
+                           compilation_error="".join(traceback.format_exception(
+                               type(e), e, e.__traceback__)),
                            runtime_error=None,
                            trace=None,
                            golden_trace=None))
@@ -319,7 +314,8 @@ def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
             results.append(
                 TestResult(unique_name=test.unique_name,
                            compilation_error=None,
-                           runtime_error=str(e),
+                           runtime_error="".join(traceback.format_exception(
+                               type(e), e, e.__traceback__)),
                            trace=None,
                            golden_trace=None))
             continue
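Both branches build the same string; as a sketch, the repeated expression could be factored into a small helper (hypothetical, not in this patch):

import traceback

def _format_exception(e: BaseException) -> str:
    # Full traceback as one string, rather than just str(e).
    return "".join(traceback.format_exception(type(e), e, e.__traceback__))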