[LINALG] Add the lowering of `aten.ne.Scalar` op

The lowering of the `aten.ne.Scalar` op has been added to
the linalg backend. The comparison is emitted through the shared
`createEqual`/`createNotEqual` helpers, so float element types lower to
`arith.cmpf une` and integer element types to `arith.cmpi ne`. E2e tests
covering float and integer inputs are added as well.
pull/706/head snapshot-20220405.371
Prashant Kumar 2022-04-03 21:49:01 +05:30
parent 5620fe030e
commit fb8cb0c5f3
4 changed files with 81 additions and 14 deletions
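
For context on what is being lowered: `torch.ne(tensor, scalar)` compares every element of the tensor against the scalar (after the scalar has been converted to the tensor's element type) and produces a boolean tensor of the same shape; in the linalg backend this per-element comparison becomes the payload of the generated `linalg.generic`. A minimal C++ sketch of the element-wise semantics, for illustration only (the function below is hypothetical, not torch-mlir code):

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustration only: element-wise `ne` of a rank-1 integer "tensor" against a
// scalar, mirroring what the linalg payload computes for each element.
std::vector<bool> neScalar(const std::vector<int64_t> &input, int64_t scalar) {
  std::vector<bool> result(input.size());
  for (std::size_t i = 0; i < input.size(); ++i)
    result[i] = (input[i] != scalar); // arith.cmpi "ne" for integer dtypes
  return result;
}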


@@ -154,4 +154,6 @@ TOSA_PASS_SET = {
    "ReshapeCollapseModule_basic",
    "ElementwiseGeluModule_basic",
    "GeluBackwardModule_basic",
    "ElementwiseNeIntScalarModule_basic",
    "ElementwiseNeFloatTensorModule_basic",
}


@@ -173,6 +173,22 @@ static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
      b, loc, elementalType, lhs, rhs);
}

static Value createEqual(OpBuilder &b, Location loc, Type elementalType,
                         Value lhs, Value rhs) {
  return createComparisonTemplate<arith::CmpFPredicate::UEQ,
                                  arith::CmpIPredicate::eq,
                                  arith::CmpIPredicate::eq>(
      b, loc, elementalType, lhs, rhs);
}

static Value createNotEqual(OpBuilder &b, Location loc, Type elementalType,
                            Value lhs, Value rhs) {
  return createComparisonTemplate<arith::CmpFPredicate::UNE,
                                  arith::CmpIPredicate::ne,
                                  arith::CmpIPredicate::ne>(
      b, loc, elementalType, lhs, rhs);
}

static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean,
                            Value sigma) {
  Type elementType = x.getType();
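
Both helpers above delegate to `createComparisonTemplate`, which already exists in this file and is not shown in the diff. Conceptually it selects the comparison op from the element type: float element types get an `arith.cmpf` with the given float predicate, and integer element types get an `arith.cmpi` with one of the two integer predicates (two are taken so that signed and unsigned integers can use different predicates where that matters; for `eq`/`ne` they are identical). The following is only a rough sketch of such a helper under that assumption, not a quotation of the torch-mlir implementation; the parameter names, the order of the two integer predicates, and the error handling are guesses, and the includes already present in this lowering file are assumed:

// Sketch (assumed structure): pick the comparison op based on the element
// type. For eq/ne the two integer predicates are the same, so their order
// does not matter for createEqual/createNotEqual.
template <arith::CmpFPredicate fpred, arith::CmpIPredicate unsignedPred,
          arith::CmpIPredicate signedPred>
static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
                                      Value lhs, Value rhs) {
  if (type.isa<mlir::FloatType>())
    return b.create<arith::CmpFOp>(loc, fpred, lhs, rhs);
  if (auto intType = type.dyn_cast<mlir::IntegerType>()) {
    if (intType.isUnsigned())
      return b.create<arith::CmpIOp>(loc, unsignedPred, lhs, rhs);
    // Signed and signless integers are treated alike in this sketch.
    return b.create<arith::CmpIOp>(loc, signedPred, lhs, rhs);
  }
  llvm_unreachable("unhandled element type for comparison");
}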
@@ -607,9 +623,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    Value otherPromoted =
        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());

    if (dtype.isa<mlir::FloatType>())
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
                                     payloadArgs[0], otherPromoted);
    if (dtype.isa<mlir::IntegerType>()) {
      if (!operands[1].getType().isa<mlir::IntegerType>()) {
        // TODO: Promote tensor operand from integer to float.
@@ -617,20 +630,33 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
            "unimplemented: type promotion from tensor to scalar");
        return nullptr;
      }
      return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                     payloadArgs[0], otherPromoted);
    }
    eqScalar.emitError("unimplemented: dtype isn't supported");
    return createEqual(b, loc, dtype, payloadArgs[0], otherPromoted);
  }

  if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
    Type dtype = neScalar.self().getType().cast<BaseTensorType>().getDtype();
    Value otherPromoted =
        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());

    if (dtype.isa<mlir::IntegerType>()) {
      if (!operands[1].getType().isa<mlir::IntegerType>()) {
        // TODO: Promote tensor operand from integer to float.
        neScalar.emitError(
            "unimplemented: type promotion from tensor to scalar");
        return nullptr;
      }
    }
    return createNotEqual(b, loc, dtype, payloadArgs[0], otherPromoted);
  }

  if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
    Type dtype = ltScalar.self().getType().cast<BaseTensorType>().getDtype();
    Value otherPromoted =
        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());

    // TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share a
    // lot of code that can be refactored.
    // TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
    // a lot of code that can be refactored.
    if (dtype.isa<mlir::FloatType>())
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
                                     payloadArgs[0], otherPromoted);
@@ -657,8 +683,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    Value otherPromoted =
        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());

    // TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code that
    // can be refactored.
    // TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code
    // that can be refactored.
    if (dtype.isa<mlir::FloatType>())
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE,
                                     payloadArgs[0], otherPromoted);
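
The TODO comments in this and the previous hunk point at the same cleanup: the scalar comparison lowerings still hand-write their float/integer branches even though helpers such as `createLessThan` (and the `createEqual`/`createNotEqual` added by this commit) already encapsulate that choice. A hypothetical refactor of the `aten.lt.Scalar` case, shown for illustration only and not part of this change (the existing check that rejects promoting an integer tensor against a non-integer scalar would still have to be kept):

// Hypothetical refactor (not in this commit): reuse the shared comparison
// helper instead of separate float and integer branches.
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
  Type dtype = ltScalar.self().getType().cast<BaseTensorType>().getDtype();
  Value otherPromoted =
      convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
  return createLessThan(b, loc, dtype, payloadArgs[0], otherPromoted);
}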
@@ -908,7 +934,8 @@ public:
             AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
             AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
             AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
             AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp>(op))
             AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
             AtenNeScalarOp>(op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1626,7 +1653,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
      AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
      AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
      AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
      AtenSinOp, AtenCosOp>();
      AtenSinOp, AtenCosOp, AtenNeScalarOp>();
  patterns.add<ConvertElementwiseOp>(typeConverter, context);
  target.addIllegalOp<AtenNllLossForwardOp>();
  patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);


@@ -2367,7 +2367,7 @@ module {
torch.prim.If.yield %int0 : !torch.int
} else {
%25 = torch.aten.__getitem__.t %arg0, %2 : !torch.list<int>, !torch.int -> !torch.int
%26 = torch.aten.ge.int %12, %25 : !torch.int, !torch.int -> !torch.bool
%26 = torch.aten.gt.int %12, %25 : !torch.int, !torch.int -> !torch.bool
%27 = torch.prim.If %26 -> (!torch.int) {
%28 = torch.aten.__getitem__.t %arg0, %2 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %28 : !torch.int


@@ -457,3 +457,41 @@ class ElementwiseEqIntTensorModule(torch.nn.Module):
def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5, )))

# ==============================================================================

class ElementwiseNeFloatScalarModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ne(x, 2.0)

@register_test_case(module_factory=lambda: ElementwiseNeFloatScalarModule())
def ElementwiseNeFloatTensorModule_basic(module, tu: TestUtils):
    module.forward(
        torch.tensor([[1.0, 2.2, 2.0], [6.0, 2.0, 3.1]]).to(torch.float32))

# ==============================================================================

class ElementwiseNeIntScalarModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, x):
        return torch.ne(x, 3)

@register_test_case(module_factory=lambda: ElementwiseNeIntScalarModule())
def ElementwiseNeIntScalarModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(2, 4, (8, 5)))