mirror of https://github.com/llvm/torch-mlir
[LINALG] Add the lowering of `aten.ne.Scalar` op
The lowering of `aten.ne.Scalar` op has been added to the linalg backend.pull/706/head snapshot-20220405.371
parent
5620fe030e
commit
fb8cb0c5f3
|
@ -154,4 +154,6 @@ TOSA_PASS_SET = {
|
||||||
"ReshapeCollapseModule_basic",
|
"ReshapeCollapseModule_basic",
|
||||||
"ElementwiseGeluModule_basic",
|
"ElementwiseGeluModule_basic",
|
||||||
"GeluBackwardModule_basic",
|
"GeluBackwardModule_basic",
|
||||||
|
"ElementwiseNeIntScalarModule_basic",
|
||||||
|
"ElementwiseNeFloatTensorModule_basic",
|
||||||
}
|
}
|
||||||
|
|
|
@ -173,6 +173,22 @@ static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
|
||||||
b, loc, elementalType, lhs, rhs);
|
b, loc, elementalType, lhs, rhs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value createEqual(OpBuilder &b, Location loc, Type elementalType,
|
||||||
|
Value lhs, Value rhs) {
|
||||||
|
return createComparisonTemplate<arith::CmpFPredicate::UEQ,
|
||||||
|
arith::CmpIPredicate::eq,
|
||||||
|
arith::CmpIPredicate::eq>(
|
||||||
|
b, loc, elementalType, lhs, rhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value createNotEqual(OpBuilder &b, Location loc, Type elementalType,
|
||||||
|
Value lhs, Value rhs) {
|
||||||
|
return createComparisonTemplate<arith::CmpFPredicate::UNE,
|
||||||
|
arith::CmpIPredicate::ne,
|
||||||
|
arith::CmpIPredicate::ne>(
|
||||||
|
b, loc, elementalType, lhs, rhs);
|
||||||
|
}
|
||||||
|
|
||||||
static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean,
|
static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean,
|
||||||
Value sigma) {
|
Value sigma) {
|
||||||
Type elementType = x.getType();
|
Type elementType = x.getType();
|
||||||
|
@ -607,9 +623,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value otherPromoted =
|
Value otherPromoted =
|
||||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||||
|
|
||||||
if (dtype.isa<mlir::FloatType>())
|
|
||||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
|
|
||||||
payloadArgs[0], otherPromoted);
|
|
||||||
if (dtype.isa<mlir::IntegerType>()) {
|
if (dtype.isa<mlir::IntegerType>()) {
|
||||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||||
// TODO: Promote tensor operand from integer to float.
|
// TODO: Promote tensor operand from integer to float.
|
||||||
|
@ -617,20 +630,33 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
"unimplemented: type promotion from tensor to scalar");
|
"unimplemented: type promotion from tensor to scalar");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
|
|
||||||
payloadArgs[0], otherPromoted);
|
|
||||||
}
|
}
|
||||||
eqScalar.emitError("unimplemented: dtype isn't supported");
|
return createEqual(b, loc, dtype, payloadArgs[0], otherPromoted);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
|
||||||
|
Type dtype = neScalar.self().getType().cast<BaseTensorType>().getDtype();
|
||||||
|
Value otherPromoted =
|
||||||
|
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||||
|
|
||||||
|
if (dtype.isa<mlir::IntegerType>()) {
|
||||||
|
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||||
|
// TODO: Promote tensor operand from integer to float.
|
||||||
|
neScalar.emitError(
|
||||||
|
"unimplemented: type promotion from tensor to scalar");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
return createNotEqual(b, loc, dtype, payloadArgs[0], otherPromoted);
|
||||||
|
}
|
||||||
|
|
||||||
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
|
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
|
||||||
Type dtype = ltScalar.self().getType().cast<BaseTensorType>().getDtype();
|
Type dtype = ltScalar.self().getType().cast<BaseTensorType>().getDtype();
|
||||||
Value otherPromoted =
|
Value otherPromoted =
|
||||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||||
|
|
||||||
// TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share a
|
// TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
|
||||||
// lot of code that can be refactored.
|
// a lot of code that can be refactored.
|
||||||
if (dtype.isa<mlir::FloatType>())
|
if (dtype.isa<mlir::FloatType>())
|
||||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||||
payloadArgs[0], otherPromoted);
|
payloadArgs[0], otherPromoted);
|
||||||
|
@ -657,8 +683,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value otherPromoted =
|
Value otherPromoted =
|
||||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||||
|
|
||||||
// TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code that
|
// TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code
|
||||||
// can be refactored.
|
// that can be refactored.
|
||||||
if (dtype.isa<mlir::FloatType>())
|
if (dtype.isa<mlir::FloatType>())
|
||||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE,
|
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE,
|
||||||
payloadArgs[0], otherPromoted);
|
payloadArgs[0], otherPromoted);
|
||||||
|
@ -908,7 +934,8 @@ public:
|
||||||
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
|
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
|
||||||
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
|
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
|
||||||
AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
||||||
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp>(op))
|
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
|
||||||
|
AtenNeScalarOp>(op))
|
||||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||||
|
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
|
@ -1626,7 +1653,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
||||||
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
|
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
|
||||||
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
|
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
|
||||||
AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
|
AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
|
||||||
AtenSinOp, AtenCosOp>();
|
AtenSinOp, AtenCosOp, AtenNeScalarOp>();
|
||||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||||
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
|
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
|
||||||
|
|
|
@ -2367,7 +2367,7 @@ module {
|
||||||
torch.prim.If.yield %int0 : !torch.int
|
torch.prim.If.yield %int0 : !torch.int
|
||||||
} else {
|
} else {
|
||||||
%25 = torch.aten.__getitem__.t %arg0, %2 : !torch.list<int>, !torch.int -> !torch.int
|
%25 = torch.aten.__getitem__.t %arg0, %2 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
%26 = torch.aten.ge.int %12, %25 : !torch.int, !torch.int -> !torch.bool
|
%26 = torch.aten.gt.int %12, %25 : !torch.int, !torch.int -> !torch.bool
|
||||||
%27 = torch.prim.If %26 -> (!torch.int) {
|
%27 = torch.prim.If %26 -> (!torch.int) {
|
||||||
%28 = torch.aten.__getitem__.t %arg0, %2 : !torch.list<int>, !torch.int -> !torch.int
|
%28 = torch.aten.__getitem__.t %arg0, %2 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
torch.prim.If.yield %28 : !torch.int
|
torch.prim.If.yield %28 : !torch.int
|
||||||
|
|
|
@ -457,3 +457,41 @@ class ElementwiseEqIntTensorModule(torch.nn.Module):
|
||||||
def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
|
def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5, )))
|
module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5, )))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseNeFloatScalarModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ne(x, 2.0)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseNeFloatScalarModule())
|
||||||
|
def ElementwiseNeFloatTensorModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
torch.tensor([[1.0, 2.2, 2.0], [6.0, 2.0, 3.1]]).to(torch.float32))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseNeIntScalarModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ne(x, 3)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseNeIntScalarModule())
|
||||||
|
def ElementwiseNeIntScalarModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(2, 4, (8, 5)))
|
||||||
|
|
Loading…
Reference in New Issue