lowerd Sqrt to linalg

reused clang-format, as changes got deleted
pull/400/head
nodlabs 2021-11-04 16:53:06 -07:00 committed by Yi Zhang
parent 2ce47dc8e4
commit 5ff823ace9
5 changed files with 41 additions and 3 deletions

View File

@ -360,3 +360,21 @@ class ElementwiseLogModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ElementwiseLogModule()) @register_test_case(module_factory=lambda: ElementwiseLogModule())
def ElementwiseLogModule_basic(module, tu: TestUtils): def ElementwiseLogModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4)) module.forward(tu.rand(3, 4))
class ElementwiseSqrtModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.sqrt(a)
@register_test_case(module_factory=lambda: ElementwiseSqrtModule())
def ElementwiseSqrtModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))

View File

@ -2700,3 +2700,17 @@ def Torch_AtenEqDeviceOp : Torch_Op<"aten.eq.device", [
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)"; let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
} }
def Torch_AtenSqrtOp : Torch_Op<"aten.sqrt", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::sqrt : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$a
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$a `,` attr-dict `:` type($a) `->` type($result)";
}

View File

@ -1280,6 +1280,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<math::ExpOp>(loc, payloadArgs[0]); return b.create<math::ExpOp>(loc, payloadArgs[0]);
if (isa<AtenLogOp>(op)) if (isa<AtenLogOp>(op))
return b.create<math::LogOp>(loc, payloadArgs[0]); return b.create<math::LogOp>(loc, payloadArgs[0]);
if (isa<AtenSqrtOp>(op))
return b.create<math::SqrtOp>(loc, payloadArgs[0]);
if (isa<AtenSigmoidOp>(op)) { if (isa<AtenSigmoidOp>(op)) {
Type elementType = payloadArgs[0].getType(); Type elementType = payloadArgs[0].getType();
auto one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1)); auto one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
@ -1663,7 +1665,8 @@ struct ConvertElementwiseOp : ConversionPattern {
if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp, if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp, AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
AtenMaximumOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp>(op)) AtenMaximumOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
AtenSqrtOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@ -2799,7 +2802,8 @@ public:
target.addIllegalOp<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp, target.addIllegalOp<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
AtenMaximumOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp>(); AtenMaximumOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
AtenSqrtOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context); patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenUnsqueezeOp>(); target.addIllegalOp<AtenUnsqueezeOp>();
patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context); patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);

View File

@ -230,7 +230,8 @@ public:
DerefineOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, DerefineOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp, AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp,
AtenCopy_Op, AtenIndexPut_Op, AtenCopy_Op, AtenCumsumOp, AtenCopy_Op, AtenIndexPut_Op, AtenCopy_Op, AtenCumsumOp,
AtenLayerNormOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp>(op)) { AtenLayerNormOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
AtenSqrtOp>(op)) {
return getLatticeElement(op->getResult(0)).join(*operands[0]); return getLatticeElement(op->getResult(0)).join(*operands[0]);
} }

View File

@ -508,6 +508,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)") emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)")
emit("aten::mean.dim : (Tensor, int[], bool, int?) -> (Tensor)") emit("aten::mean.dim : (Tensor, int[], bool, int?) -> (Tensor)")
emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)")
emit("aten::sqrt : (Tensor) -> (Tensor)")
# Misc tensor ops. # Misc tensor ops.
emit("aten::unsqueeze : (Tensor, int) -> (Tensor)") emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")