Add lowering of aten.floor op

pull/405/head snapshot-20211106.68
Wang Kangyu 2021-11-06 22:19:01 +08:00 committed by Yi Zhang
parent 5ff823ace9
commit b33543af85
5 changed files with 50 additions and 3 deletions

View File

@ -378,3 +378,19 @@ class ElementwiseSqrtModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ElementwiseSqrtModule())
def ElementwiseSqrtModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
class ElementwiseFloorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.floor(a)
@register_test_case(module_factory=lambda: ElementwiseFloorModule())
def ElementwiseFloorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))

View File

@ -240,6 +240,34 @@ def Torch_AtenNeg_Op : Torch_Op<"aten.neg_", [
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
}
def Torch_AtenFloorOp : Torch_Op<"aten.floor", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::floor : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
}
def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::floor_ : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
}
def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [
AllowsTypeRefinement,
HasValueSemantics

View File

@ -1278,6 +1278,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<math::TanhOp>(loc, payloadArgs[0]);
if (isa<AtenExpOp>(op))
return b.create<math::ExpOp>(loc, payloadArgs[0]);
if (isa<AtenFloorOp>(op))
return b.create<math::FloorOp>(loc, payloadArgs[0]);
if (isa<AtenLogOp>(op))
return b.create<math::LogOp>(loc, payloadArgs[0]);
if (isa<AtenSqrtOp>(op))
@ -1666,7 +1668,7 @@ struct ConvertElementwiseOp : ConversionPattern {
AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
AtenMaximumOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
AtenSqrtOp>(op))
AtenSqrtOp, AtenFloorOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@ -2803,7 +2805,7 @@ public:
AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
AtenMaximumOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
AtenSqrtOp>();
AtenSqrtOp, AtenFloorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenUnsqueezeOp>();
patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);

View File

@ -231,7 +231,7 @@ public:
AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp,
AtenCopy_Op, AtenIndexPut_Op, AtenCopy_Op, AtenCumsumOp,
AtenLayerNormOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
AtenSqrtOp>(op)) {
AtenSqrtOp, AtenFloorOp>(op)) {
return getLatticeElement(op->getResult(0)).join(*operands[0]);
}

View File

@ -445,6 +445,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
"aten::exp : (Tensor) -> (Tensor)",
"aten::cos : (Tensor) -> (Tensor)",
"aten::neg : (Tensor) -> (Tensor)",
"aten::floor : (Tensor) -> (Tensor)",
"aten::bitwise_not : (Tensor) -> (Tensor)",
"aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
"aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",