mirror of https://github.com/llvm/torch-mlir
Add aten::gelu lowering (#374)
* Print more exception info on error during test execution
* Fix formatting
* Add aten::gelu lowering

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>

pull/376/head snapshot-20211026.45
parent a6943ef90c
commit e276dbbaa6
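For background: the lowering uses the exact, erf-based definition of GELU, GELU(x) = x * Phi(x), where Phi(x) = 0.5 * (1 + erf(x / sqrt(2))) is the standard normal CDF. The new buildNormalCdf/buildUnitNormalCdf helpers below assemble exactly that expression from arith and math dialect ops, and the Linalg elementwise payload for AtenGeluOp multiplies the input by the resulting CDF.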
@@ -198,6 +198,28 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ElementwiseGeluModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.gelu = torch.nn.GELU()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.gelu(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseGeluModule())
+def ElementwiseGeluModule_basic(module, tu: TestUtils):
+    module.forward(2*tu.rand(5, 3) - 0.5)
+
+
+# ==============================================================================
+
 class ElementwiseSigmoidModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
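For context, torch.nn.GELU() defaults to the exact erf formulation, which is what the new lowering computes; a standalone sanity check of that equivalence (illustrative only, not part of the commit or the test suite):

    import torch

    x = 2 * torch.rand(5, 3) - 0.5  # same input distribution as the e2e test above
    reference = x * 0.5 * (1.0 + torch.erf(x / (2.0 ** 0.5)))
    assert torch.allclose(torch.nn.GELU()(x), reference, atol=1e-6)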
@@ -762,6 +762,20 @@ def Torch_AtenMaskedFill_ScalarOp : Torch_Op<"aten.masked_fill_.Scalar", [
   let assemblyFormat = "$self `,` $mask `,` $value attr-dict `:` type($self) `,` type($mask) `,` type($value) `->` type($result)";
 }
 
+def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::gelu : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
+}
+
 def Torch_AtenTriuOp : Torch_Op<"aten.triu", [
     AllowsTypeRefinement,
     HasValueSemantics
@@ -77,7 +77,8 @@ static Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
   assert(dim.getType().isa<IntegerType>() &&
          "dim arg of toPositiveDim must be integer type");
   Value dimAddInputRank = b.create<arith::AddIOp>(loc, dim, inputRank);
-  Value cst0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
+  Value cst0 =
+      b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
   Value predDimGEZero =
       b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
   Value dimInt = b.create<SelectOp>(loc, predDimGEZero, dim, dimAddInputRank);
@@ -89,7 +90,8 @@ static void assertIsValidDim(OpBuilder &b, Location loc, Value dim,
                              Value inputRank) {
   assert(dim.getType().isa<IntegerType>() &&
          "dim arg of assertIsValidDim must be integer type");
-  Value cst0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
+  Value cst0 =
+      b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
   Value predGEZero =
       b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
   b.create<AssertOp>(loc, predGEZero,
@@ -270,6 +272,29 @@ static bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
   return true;
 }
 
+static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean,
+                            Value sigma) {
+  Type elementType = x.getType();
+  Value xMinusMean = b.create<arith::SubFOp>(loc, x, mean);
+  Value two = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 2));
+  Value sqrt2 = b.create<math::SqrtOp>(loc, two);
+  Value erfArg = b.create<arith::DivFOp>(loc, xMinusMean, sqrt2);
+  Value erf = b.create<math::ErfOp>(loc, erfArg);
+  Value one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
+  Value erfPlus1 = b.create<arith::AddFOp>(loc, one, erf);
+  Value oneHalf =
+      b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.5));
+  Value normalCdf = b.create<arith::MulFOp>(loc, oneHalf, erfPlus1);
+  return normalCdf;
+}
+
+static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
+  Type elementType = x.getType();
+  Value zero = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0));
+  Value one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
+  return buildNormalCdf(b, loc, x, zero, one);
+}
+
 namespace {
 class ConvertAtenAdaptiveAvgPool2dOp
     : public OpConversionPattern<AtenAdaptiveAvgPool2dOp> {
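Rendered as plain Python, the expression that buildNormalCdf/buildUnitNormalCdf emit (one math.erf plus a handful of arith ops) looks like the following sketch; note that, as in the helpers above, only the unit-normal case (mean = 0, sigma = 1) is actually exercised by the GELU lowering:

    import math

    def normal_cdf(x: float, mean: float = 0.0, sigma: float = 1.0) -> float:
        # Mirrors the op sequence above: (x - mean) is divided by sqrt(2);
        # sigma is carried as a parameter but not used in the division.
        erf_arg = (x - mean) / math.sqrt(2.0)
        return 0.5 * (1.0 + math.erf(erf_arg))

    def gelu(x: float) -> float:
        # GELU(x) = x * Phi(x), matching the AtenGeluOp payload below.
        return x * normal_cdf(x)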
@@ -1117,6 +1142,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
                                         payloadArgs[0], constZero);
     return b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
   }
+  if (auto gelu = dyn_cast<AtenGeluOp>(op)) {
+    if (!gelu.getType()
+             .cast<ValueTensorType>()
+             .getDtype()
+             .isa<mlir::FloatType>()) {
+      gelu.emitError("unimplemented: non-floating point dtype");
+      return nullptr;
+    }
+    Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[0]);
+    return b.create<arith::MulFOp>(loc, payloadArgs[0], cdf);
+  }
   if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
     AtenAddTensorOp::Adaptor adaptor(operands);
     if (add.alpha().getType().isa<Torch::FloatType>()) {
@@ -1396,9 +1432,9 @@ struct ConvertElementwiseOp : ConversionPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isa<AtenTanhOp, AtenReluOp, AtenAddTensorOp, AtenMulTensorOp,
-             AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
-             AtenExpOp>(op))
+    if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
+             AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
+             AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -2322,7 +2358,7 @@ public:
     patterns.add<ConvertAtenLinearOp>(typeConverter, context);
     target.addIllegalOp<AtenBatchNormOp>();
     patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
-    target.addIllegalOp<AtenTanhOp, AtenReluOp, AtenAddTensorOp,
+    target.addIllegalOp<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
                         AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
                         AtenLerpTensorOp, AtenSigmoidOp>();
     patterns.add<ConvertElementwiseOp>(typeConverter, context);
@@ -190,14 +190,15 @@ public:
   visitOperation(Operation *op,
                  ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
     if (isa<TensorStaticInfoCastOp, CopyToValueTensorOp, CopyToNonValueTensorOp,
-            AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenAddScalarOp,
-            AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp, AtenFmodScalarOp,
-            AtenFloorDivideScalarOp, AtenEqScalarOp, AtenGeScalarOp,
-            AtenGtScalarOp, AtenNeScalarOp, AtenBitwiseNotOp, AtenToDtypeOp,
-            AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp,
-            AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
-            AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
-            AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp>(op)) {
+            AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenGeluOp,
+            AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp,
+            AtenFmodScalarOp, AtenFloorDivideScalarOp, AtenEqScalarOp,
+            AtenGeScalarOp, AtenGtScalarOp, AtenNeScalarOp, AtenBitwiseNotOp,
+            AtenToDtypeOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
+            DerefineOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
+            AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp,
+            AtenCopy_Op, AtenIndexPut_Op, AtenCopy_Op, AtenCumsumOp,
+            AtenLayerNormOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
 
@@ -16,6 +16,7 @@
 
 #include "PassDetail.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Approximation.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -179,11 +180,13 @@ class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
     auto *context = &getContext();
     RewritePatternSet patterns(context);
     populateExpandTanhPattern(patterns);
+    patterns.add<math::ErfPolynomialApproximation>(patterns.getContext());
     ConversionTarget target(*context);
     target.addLegalDialect<StandardOpsDialect>();
     target.addLegalDialect<math::MathDialect>();
     target.addLegalDialect<arith::ArithmeticDialect>();
     target.addIllegalOp<math::TanhOp>();
+    target.addIllegalOp<math::ErfOp>();
     if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
       return signalPassFailure();
     }
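Presumably the motivation for the two added lines above: math.erf has no direct lowering to LLVM, so the ExpandOpsForLLVM pass now marks math::ErfOp illegal and rewrites it with upstream MLIR's ErfPolynomialApproximation pattern, the same treatment math.tanh already receives via populateExpandTanhPattern.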
@@ -464,6 +464,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     ]:
         emit_with_mutating_variants(key)
 
+    emit("aten::gelu : (Tensor) -> (Tensor)")
+
     emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
     emit_with_mutating_variants("aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")
 
@@ -25,6 +25,7 @@ from typing import Any, Callable, List, NamedTuple, Optional, TypeVar, Union, Di
 
 import io
 import pickle
+import traceback
 
 import torch
 
@@ -298,16 +299,10 @@ def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
             golden_trace = generate_golden_trace(test)
             compiled = config.compile(test.program_factory())
         except Exception as e:
-            # Useful for debugging:
-            # ```
-            # raise
-            # ```
-            # This will give the full traceback rather than giving just
-            # the stringified exception in the report.
-            # TODO: Capture the traceback and make it available in the report.
             results.append(
                 TestResult(unique_name=test.unique_name,
-                           compilation_error=str(e),
+                           compilation_error="".join(traceback.format_exception(
+                               type(e), e, e.__traceback__)),
                            runtime_error=None,
                            trace=None,
                            golden_trace=None))
@@ -319,7 +314,8 @@ def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
             results.append(
                 TestResult(unique_name=test.unique_name,
                            compilation_error=None,
-                           runtime_error=str(e),
+                           runtime_error="".join(traceback.format_exception(
+                               type(e), e, e.__traceback__)),
                            trace=None,
                            golden_trace=None))
             continue
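To illustrate the reporting change in isolation (a toy example, not part of the harness): str(e) yields only the exception message, while traceback.format_exception returns the full multi-line traceback text that now ends up in the report:

    import traceback

    try:
        raise ValueError("bad dtype")
    except Exception as e:
        short = str(e)  # old report text: just "bad dtype"
        full = "".join(traceback.format_exception(type(e), e, e.__traceback__))
        # new report text: the full "Traceback (most recent call last): ..." block
        assert short in full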