[TORCH][MLIR] Add E2E support for `aten.ceil` op

This commit adds lowering of `aten.ceil` op as a part of element-wise
ops lowering.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/472/head snapshot-20211211.138
Gaurav Shukla 2021-12-11 22:50:24 +05:30 committed by Gaurav Shukla
parent 03b6edce68
commit a778f990e9
5 changed files with 52 additions and 4 deletions

View File

@ -521,6 +521,22 @@ class ElementwiseFloorModule(torch.nn.Module):
def ElementwiseFloorModule_basic(module, tu: TestUtils): def ElementwiseFloorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4)) module.forward(tu.rand(3, 4))
class ElementwiseCeilModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ceil(a)
@register_test_case(module_factory=lambda: ElementwiseCeilModule())
def ElementwiseCeilModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
class ElementwisePowModule(torch.nn.Module): class ElementwisePowModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@ -298,6 +298,34 @@ def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
} }
def Torch_AtenCeilOp : Torch_Op<"aten.ceil", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::ceil : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
}
def Torch_AtenCeil_Op : Torch_Op<"aten.ceil_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::ceil_ : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
}
def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [ def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics HasValueSemantics

View File

@ -1512,6 +1512,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<math::ExpOp>(loc, payloadArgs[0]); return b.create<math::ExpOp>(loc, payloadArgs[0]);
if (isa<AtenFloorOp>(op)) if (isa<AtenFloorOp>(op))
return b.create<math::FloorOp>(loc, payloadArgs[0]); return b.create<math::FloorOp>(loc, payloadArgs[0]);
if (isa<AtenCeilOp>(op))
return b.create<math::CeilOp>(loc, payloadArgs[0]);
if (isa<AtenLogOp>(op)) if (isa<AtenLogOp>(op))
return b.create<math::LogOp>(loc, payloadArgs[0]); return b.create<math::LogOp>(loc, payloadArgs[0]);
if (isa<AtenSqrtOp>(op)) if (isa<AtenSqrtOp>(op))
@ -2067,7 +2069,8 @@ struct ConvertElementwiseOp : ConversionPattern {
AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op, AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp, AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenWhereSelfOp>(op)) AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenWhereSelfOp,
AtenCeilOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@ -3635,8 +3638,8 @@ public:
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp,
AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenAbsOp, AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
AtenWhereSelfOp>(); AtenWhereSelfOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context); patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenSqueezeOp>(); target.addIllegalOp<AtenSqueezeOp>();

View File

@ -242,7 +242,7 @@ public:
AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp,
AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
AtenDropoutOp, AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenDropoutOp, AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp,
AtenAddIntOp, AtenAbsOp, AtenReciprocalOp>(op)) { AtenAddIntOp, AtenAbsOp, AtenReciprocalOp, AtenCeilOp>(op)) {
return getLatticeElement(op->getResult(0)).join(*operands[0]); return getLatticeElement(op->getResult(0)).join(*operands[0]);
} }

View File

@ -447,6 +447,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
"aten::cos : (Tensor) -> (Tensor)", "aten::cos : (Tensor) -> (Tensor)",
"aten::neg : (Tensor) -> (Tensor)", "aten::neg : (Tensor) -> (Tensor)",
"aten::floor : (Tensor) -> (Tensor)", "aten::floor : (Tensor) -> (Tensor)",
"aten::ceil : (Tensor) -> (Tensor)",
"aten::bitwise_not : (Tensor) -> (Tensor)", "aten::bitwise_not : (Tensor) -> (Tensor)",
"aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", "aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
"aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", "aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",