mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.erf op.
Signed-Off-By: Prateek Gupta <prateek@nod-labs.com>pull/652/head snapshot-20220309.314
parent
1a2a9e066f
commit
3d9ba5e525
|
@ -545,6 +545,26 @@ def ElementwiseLogModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseErfModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.erf(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseErfModule())
|
||||
def ElementwiseErfModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseSqrtModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -246,6 +246,34 @@ def Torch_AtenHardswish_Op : Torch_Op<"aten.hardswish_", [
|
|||
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenErfOp : Torch_Op<"aten.erf", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::erf : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenErf_Op : Torch_Op<"aten.erf_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::erf_ : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenSiluOp : Torch_Op<"aten.silu", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -1622,6 +1622,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return b.create<math::CeilOp>(loc, payloadArgs[0]);
|
||||
if (isa<AtenLogOp>(op))
|
||||
return b.create<math::LogOp>(loc, payloadArgs[0]);
|
||||
if (isa<AtenErfOp>(op))
|
||||
return b.create<math::ErfOp>(loc, payloadArgs[0]);
|
||||
if (isa<AtenSqrtOp>(op))
|
||||
return b.create<math::SqrtOp>(loc, payloadArgs[0]);
|
||||
if (isa<AtenRsqrtOp>(op))
|
||||
|
@ -2487,12 +2489,12 @@ struct ConvertElementwiseOp : ConversionPattern {
|
|||
AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
|
||||
AtenExpOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp,
|
||||
AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
|
||||
AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
|
||||
AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
|
||||
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp,
|
||||
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
|
||||
AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
|
||||
AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
||||
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp,
|
||||
AtenLog2Op, AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp,
|
||||
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
|
||||
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
|
||||
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
|
||||
AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
||||
AtenThresholdBackwardOp, AtenCloneOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
||||
|
@ -4518,12 +4520,13 @@ public:
|
|||
AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
|
||||
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
|
||||
AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp,
|
||||
AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
|
||||
AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp,
|
||||
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
|
||||
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
|
||||
AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
|
||||
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp>();
|
||||
AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp,
|
||||
AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op,
|
||||
AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
|
||||
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
|
||||
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
|
||||
AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
||||
AtenCloneOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSqueezeOp>();
|
||||
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
|
||||
|
|
|
@ -232,7 +232,7 @@ public:
|
|||
AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulli_TensorOp,
|
||||
PseudoAtenBernoulliFloatOp, PseudoAtenBernoulliTensorOp,
|
||||
PseudoAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
|
||||
AtenHardswishOp, AtenSiluOp, AtenHardtanhOp>(op)) {
|
||||
AtenHardswishOp, AtenErfOp, AtenSiluOp, AtenHardtanhOp>(op)) {
|
||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||
}
|
||||
|
||||
|
|
|
@ -453,6 +453,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
"aten::sigmoid : (Tensor) -> (Tensor)",
|
||||
"aten::hardsigmoid : (Tensor) -> (Tensor)",
|
||||
"aten::hardswish : (Tensor) -> (Tensor)",
|
||||
"aten::erf : (Tensor) -> (Tensor)",
|
||||
"aten::silu : (Tensor) -> (Tensor)",
|
||||
"aten::sin : (Tensor) -> (Tensor)",
|
||||
"aten::exp : (Tensor) -> (Tensor)",
|
||||
|
|
Loading…
Reference in New Issue