[LINALG] Add the lowering of `aten.ne.Scalar` op

The lowering of `aten.ne.Scalar` op has been added to
the linalg backend.
pull/706/head snapshot-20220405.371
Prashant Kumar 2022-04-03 21:49:01 +05:30
parent 5620fe030e
commit fb8cb0c5f3
4 changed files with 81 additions and 14 deletions

View File

@ -154,4 +154,6 @@ TOSA_PASS_SET = {
"ReshapeCollapseModule_basic", "ReshapeCollapseModule_basic",
"ElementwiseGeluModule_basic", "ElementwiseGeluModule_basic",
"GeluBackwardModule_basic", "GeluBackwardModule_basic",
"ElementwiseNeIntScalarModule_basic",
"ElementwiseNeFloatTensorModule_basic",
} }

View File

@ -173,6 +173,22 @@ static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
b, loc, elementalType, lhs, rhs); b, loc, elementalType, lhs, rhs);
} }
static Value createEqual(OpBuilder &b, Location loc, Type elementalType,
Value lhs, Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::UEQ,
arith::CmpIPredicate::eq,
arith::CmpIPredicate::eq>(
b, loc, elementalType, lhs, rhs);
}
static Value createNotEqual(OpBuilder &b, Location loc, Type elementalType,
Value lhs, Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::UNE,
arith::CmpIPredicate::ne,
arith::CmpIPredicate::ne>(
b, loc, elementalType, lhs, rhs);
}
static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean, static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean,
Value sigma) { Value sigma) {
Type elementType = x.getType(); Type elementType = x.getType();
@ -607,9 +623,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value otherPromoted = Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (dtype.isa<mlir::FloatType>())
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
payloadArgs[0], otherPromoted);
if (dtype.isa<mlir::IntegerType>()) { if (dtype.isa<mlir::IntegerType>()) {
if (!operands[1].getType().isa<mlir::IntegerType>()) { if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float. // TODO: Promote tensor operand from integer to float.
@ -617,11 +630,24 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
"unimplemented: type promotion from tensor to scalar"); "unimplemented: type promotion from tensor to scalar");
return nullptr; return nullptr;
} }
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
payloadArgs[0], otherPromoted);
} }
eqScalar.emitError("unimplemented: dtype isn't supported"); return createEqual(b, loc, dtype, payloadArgs[0], otherPromoted);
return nullptr; }
if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
Type dtype = neScalar.self().getType().cast<BaseTensorType>().getDtype();
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (dtype.isa<mlir::IntegerType>()) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float.
neScalar.emitError(
"unimplemented: type promotion from tensor to scalar");
return nullptr;
}
}
return createNotEqual(b, loc, dtype, payloadArgs[0], otherPromoted);
} }
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) { if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
@ -629,8 +655,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value otherPromoted = Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
// TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share a // TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
// lot of code that can be refactored. // a lot of code that can be refactored.
if (dtype.isa<mlir::FloatType>()) if (dtype.isa<mlir::FloatType>())
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
payloadArgs[0], otherPromoted); payloadArgs[0], otherPromoted);
@ -657,8 +683,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value otherPromoted = Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
// TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code that // TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code
// can be refactored. // that can be refactored.
if (dtype.isa<mlir::FloatType>()) if (dtype.isa<mlir::FloatType>())
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE,
payloadArgs[0], otherPromoted); payloadArgs[0], otherPromoted);
@ -908,7 +934,8 @@ public:
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp>(op)) AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
AtenNeScalarOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@ -1626,7 +1653,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
AtenSinOp, AtenCosOp>(); AtenSinOp, AtenCosOp, AtenNeScalarOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context); patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>(); target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context); patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);

View File

@ -2367,7 +2367,7 @@ module {
torch.prim.If.yield %int0 : !torch.int torch.prim.If.yield %int0 : !torch.int
} else { } else {
%25 = torch.aten.__getitem__.t %arg0, %2 : !torch.list<int>, !torch.int -> !torch.int %25 = torch.aten.__getitem__.t %arg0, %2 : !torch.list<int>, !torch.int -> !torch.int
%26 = torch.aten.ge.int %12, %25 : !torch.int, !torch.int -> !torch.bool %26 = torch.aten.gt.int %12, %25 : !torch.int, !torch.int -> !torch.bool
%27 = torch.prim.If %26 -> (!torch.int) { %27 = torch.prim.If %26 -> (!torch.int) {
%28 = torch.aten.__getitem__.t %arg0, %2 : !torch.list<int>, !torch.int -> !torch.int %28 = torch.aten.__getitem__.t %arg0, %2 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %28 : !torch.int torch.prim.If.yield %28 : !torch.int

View File

@ -457,3 +457,41 @@ class ElementwiseEqIntTensorModule(torch.nn.Module):
def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils): def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5, ))) module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5, )))
# ==============================================================================
class ElementwiseNeFloatScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ne(x, 2.0)
@register_test_case(module_factory=lambda: ElementwiseNeFloatScalarModule())
def ElementwiseNeFloatTensorModule_basic(module, tu: TestUtils):
module.forward(
torch.tensor([[1.0, 2.2, 2.0], [6.0, 2.0, 3.1]]).to(torch.float32))
# ==============================================================================
class ElementwiseNeIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, x):
return torch.ne(x, 3)
@register_test_case(module_factory=lambda: ElementwiseNeIntScalarModule())
def ElementwiseNeIntScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(2, 4, (8, 5)))