mirror of https://github.com/llvm/torch-mlir
[LINALG] Add lowering for aten::round op.
-- Added the lowering for aten::round op. -- Added the folding for integer cases.pull/1488/head
parent
8f76c74be9
commit
3a2cd23380
|
@ -279,6 +279,7 @@ MHLO_PASS_SET = {
|
|||
"Permute0RankModule_basic",
|
||||
"UnsafeViewCollapseModule_basic",
|
||||
"UnsafeViewDynamicExpandModule_basic",
|
||||
"AtenRoundIntModule_basic",
|
||||
}
|
||||
|
||||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
|
@ -470,6 +471,7 @@ TOSA_PASS_SET = {
|
|||
"ToDtypeBoolLayoutNoneStaticModule_basic",
|
||||
"ToCopyBoolDTypeStaticModule_basic",
|
||||
"HardTanhIntModule_basic",
|
||||
"AtenRoundIntModule_basic"
|
||||
}
|
||||
|
||||
LTC_XFAIL_SET = {
|
||||
|
|
|
@ -3255,6 +3255,52 @@ def Torch_AtenTriu_Op : Torch_Op<"aten.triu_", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenRoundOp : Torch_Op<"aten.round", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::round : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenRoundOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenRoundOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenRound_Op : Torch_Op<"aten.round_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::round_ : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenRound_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenRound_Op::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -252,6 +252,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
payloadArgs[0], constZero);
|
||||
return b.create<arith::SelectOp>(loc, pred, payloadArgs[0], constZero);
|
||||
}
|
||||
if (auto round = dyn_cast<AtenRoundOp>(op)) {
|
||||
if (!round.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
round.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
return b.create<math::RoundOp>(loc, payloadArgs[0]);
|
||||
}
|
||||
if (auto lrelu = dyn_cast<AtenLeakyReluOp>(op)) {
|
||||
if (!lrelu.getType()
|
||||
.cast<ValueTensorType>()
|
||||
|
@ -1029,7 +1039,7 @@ public:
|
|||
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
|
||||
AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
|
||||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
|
||||
AtenBitwiseNotOp>(op))
|
||||
AtenBitwiseNotOp, AtenRoundOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
|
@ -1504,7 +1514,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
||||
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
|
||||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
|
||||
AtenRemainderScalarOp, AtenBitwiseNotOp>();
|
||||
AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
||||
|
|
|
@ -673,6 +673,18 @@ OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenRoundOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenRoundOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (auto selfType = self().getType().dyn_cast<BaseTensorType>()) {
|
||||
if (selfType.hasDtype() && selfType.getDtype().isa<mlir::IntegerType>())
|
||||
return self();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenTypeAsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -701,8 +701,8 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
|
||||
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp,
|
||||
AtenRollOp, AtenPowTensorTensorOp, AtenLiftFreshCopyOp,
|
||||
AtenIndexTensorHackedTwinOp, AtenUpsampleNearest2dVecOp, AtenMishOp>(
|
||||
op)) {
|
||||
AtenIndexTensorHackedTwinOp, AtenUpsampleNearest2dVecOp, AtenMishOp,
|
||||
AtenRoundOp>(op)) {
|
||||
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
||||
}
|
||||
|
||||
|
|
|
@ -5582,6 +5582,10 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.round\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten._softmax\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
|
|
@ -403,6 +403,9 @@ def aten〇relu(self: List[int]) -> List[int]:
|
|||
def aten〇relu6(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇round(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇_softmax(self: List[int], dim: int, half_to_float: bool) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
|
|
@ -327,6 +327,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::bernoulli_.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)")
|
||||
|
||||
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants(
|
||||
"aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")
|
||||
emit_with_mutating_variants(
|
||||
|
|
|
@ -2273,3 +2273,42 @@ class AtenTriuWithNegDiagonalModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: AtenTriuWithNegDiagonalModule())
|
||||
def AtenTriuWithNegDiagonalModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 1, 5, 9))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenRoundFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.round(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenRoundFloatModule())
|
||||
def AtenRoundFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 5, low = -3.0, high = 3.0))
|
||||
|
||||
class AtenRoundIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.round(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenRoundIntModule())
|
||||
def AtenRoundIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(5, 5, low = -10))
|
||||
|
|
Loading…
Reference in New Issue