mirror of https://github.com/llvm/torch-mlir
[LINALG] Add the lowering of `aten.ne.Scalar` op
The lowering of `aten.ne.Scalar` op has been added to the linalg backend.pull/706/head snapshot-20220405.371
parent
5620fe030e
commit
fb8cb0c5f3
|
@ -154,4 +154,6 @@ TOSA_PASS_SET = {
|
|||
"ReshapeCollapseModule_basic",
|
||||
"ElementwiseGeluModule_basic",
|
||||
"GeluBackwardModule_basic",
|
||||
"ElementwiseNeIntScalarModule_basic",
|
||||
"ElementwiseNeFloatTensorModule_basic",
|
||||
}
|
||||
|
|
|
@ -173,6 +173,22 @@ static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
|
|||
b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
|
||||
static Value createEqual(OpBuilder &b, Location loc, Type elementalType,
|
||||
Value lhs, Value rhs) {
|
||||
return createComparisonTemplate<arith::CmpFPredicate::UEQ,
|
||||
arith::CmpIPredicate::eq,
|
||||
arith::CmpIPredicate::eq>(
|
||||
b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
|
||||
static Value createNotEqual(OpBuilder &b, Location loc, Type elementalType,
|
||||
Value lhs, Value rhs) {
|
||||
return createComparisonTemplate<arith::CmpFPredicate::UNE,
|
||||
arith::CmpIPredicate::ne,
|
||||
arith::CmpIPredicate::ne>(
|
||||
b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
|
||||
static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean,
|
||||
Value sigma) {
|
||||
Type elementType = x.getType();
|
||||
|
@ -607,9 +623,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
Value otherPromoted =
|
||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||
|
||||
if (dtype.isa<mlir::FloatType>())
|
||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
|
||||
payloadArgs[0], otherPromoted);
|
||||
if (dtype.isa<mlir::IntegerType>()) {
|
||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||
// TODO: Promote tensor operand from integer to float.
|
||||
|
@ -617,20 +630,33 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
"unimplemented: type promotion from tensor to scalar");
|
||||
return nullptr;
|
||||
}
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
|
||||
payloadArgs[0], otherPromoted);
|
||||
}
|
||||
eqScalar.emitError("unimplemented: dtype isn't supported");
|
||||
return createEqual(b, loc, dtype, payloadArgs[0], otherPromoted);
|
||||
}
|
||||
|
||||
if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
|
||||
Type dtype = neScalar.self().getType().cast<BaseTensorType>().getDtype();
|
||||
Value otherPromoted =
|
||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||
|
||||
if (dtype.isa<mlir::IntegerType>()) {
|
||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||
// TODO: Promote tensor operand from integer to float.
|
||||
neScalar.emitError(
|
||||
"unimplemented: type promotion from tensor to scalar");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return createNotEqual(b, loc, dtype, payloadArgs[0], otherPromoted);
|
||||
}
|
||||
|
||||
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
|
||||
Type dtype = ltScalar.self().getType().cast<BaseTensorType>().getDtype();
|
||||
Value otherPromoted =
|
||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||
|
||||
// TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share a
|
||||
// lot of code that can be refactored.
|
||||
// TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
|
||||
// a lot of code that can be refactored.
|
||||
if (dtype.isa<mlir::FloatType>())
|
||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||
payloadArgs[0], otherPromoted);
|
||||
|
@ -657,8 +683,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
Value otherPromoted =
|
||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||
|
||||
// TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code that
|
||||
// can be refactored.
|
||||
// TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code
|
||||
// that can be refactored.
|
||||
if (dtype.isa<mlir::FloatType>())
|
||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE,
|
||||
payloadArgs[0], otherPromoted);
|
||||
|
@ -908,7 +934,8 @@ public:
|
|||
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
|
||||
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
|
||||
AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
||||
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp>(op))
|
||||
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
|
||||
AtenNeScalarOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
|
@ -1626,7 +1653,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
|
||||
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
|
||||
AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
|
||||
AtenSinOp, AtenCosOp>();
|
||||
AtenSinOp, AtenCosOp, AtenNeScalarOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
|
||||
|
|
|
@ -2367,7 +2367,7 @@ module {
|
|||
torch.prim.If.yield %int0 : !torch.int
|
||||
} else {
|
||||
%25 = torch.aten.__getitem__.t %arg0, %2 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%26 = torch.aten.ge.int %12, %25 : !torch.int, !torch.int -> !torch.bool
|
||||
%26 = torch.aten.gt.int %12, %25 : !torch.int, !torch.int -> !torch.bool
|
||||
%27 = torch.prim.If %26 -> (!torch.int) {
|
||||
%28 = torch.aten.__getitem__.t %arg0, %2 : !torch.list<int>, !torch.int -> !torch.int
|
||||
torch.prim.If.yield %28 : !torch.int
|
||||
|
|
|
@ -457,3 +457,41 @@ class ElementwiseEqIntTensorModule(torch.nn.Module):
|
|||
def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5, )))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseNeFloatScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ne(x, 2.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseNeFloatScalarModule())
|
||||
def ElementwiseNeFloatTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.tensor([[1.0, 2.2, 2.0], [6.0, 2.0, 3.1]]).to(torch.float32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseNeIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ne(x, 3)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseNeIntScalarModule())
|
||||
def ElementwiseNeIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(2, 4, (8, 5)))
|
||||
|
|
Loading…
Reference in New Issue