[TORCH][MLIR] Add E2E support for `aten.eq` and `aten.lt` ops

- Added E2E support for the `aten.eq.Tensor` and `aten.lt.Tensor` ops. Both
  operands are expected to have the same dtype; type promotion is not
  addressed as part of this commit.
- Added E2E support for the `aten.eq.Scalar` and `aten.lt.Scalar` ops.
  Promotion of the tensor operand's type to the scalar operand's type is
  not handled in this commit.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/477/head
Gaurav Shukla 2021-12-15 19:45:10 +05:30 committed by Gaurav Shukla
parent 0cd95b5c68
commit eddc09aa55
5 changed files with 368 additions and 9 deletions
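
Before the per-file diffs, a quick eager-mode illustration of the op semantics this commit lowers. This is plain PyTorch for orientation, not code from the change itself; the values mirror the test data used below.

import torch

x = torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]])
y = torch.tensor([1.0, 2.4, 6.0])

torch.eq(x, y)    # aten.eq.Tensor: elementwise, broadcasts y; same dtypes required here
torch.lt(x, y)    # aten.lt.Tensor: same constraints
torch.eq(x, 6.0)  # aten.eq.Scalar: scalar compared against every element
torch.lt(x, 0.6)  # aten.lt.Scalar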


@@ -419,6 +419,194 @@ def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseLtFloatScalarModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.lt(x, 0.6)


@register_test_case(module_factory=lambda: ElementwiseLtFloatScalarModule())
def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5))


class ElementwiseLtIntScalarModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, x):
        return torch.lt(x, 0)


@register_test_case(module_factory=lambda: ElementwiseLtIntScalarModule())
def ElementwiseLtIntScalarModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(-10, 15, (3, 4)))


class ElementwiseLtDiffWidthScalarModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, x):
        return torch.lt(x, 2)


@register_test_case(module_factory=lambda: ElementwiseLtDiffWidthScalarModule())
def ElementwiseLtDiffWidthScalarModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))


class ElementwiseLtFloatTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, x, y):
        return torch.lt(x, y)


@register_test_case(module_factory=lambda: ElementwiseLtFloatTensorModule())
def ElementwiseLtFloatTensorModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5), tu.rand(5))


class ElementwiseLtIntTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, x, y):
        return torch.lt(x, y)


@register_test_case(module_factory=lambda: ElementwiseLtIntTensorModule())
def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5,)))

# ==============================================================================
class ElementwiseEqFloatScalarModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.eq(x, 6.0)


@register_test_case(module_factory=lambda: ElementwiseEqFloatScalarModule())
def ElementwiseEqFloatScalarModule_basic(module, tu: TestUtils):
    module.forward(torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]])
                   .to(torch.float32))


class ElementwiseEqIntScalarModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, x):
        return torch.eq(x, 2)


@register_test_case(module_factory=lambda: ElementwiseEqIntScalarModule())
def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(2, 4, (5, 8)))


class ElementwiseEqDiffWidthScalarModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, x):
        return torch.eq(x, 2)


@register_test_case(module_factory=lambda: ElementwiseEqDiffWidthScalarModule())
def ElementwiseEqDiffWidthScalarModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(2, 4, (5, 8)).to(torch.int32))


class ElementwiseEqFloatTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, x, y):
        return torch.eq(x, y)


@register_test_case(module_factory=lambda: ElementwiseEqFloatTensorModule())
def ElementwiseEqFloatTensorModule_basic(module, tu: TestUtils):
    module.forward(torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]])
                   .to(torch.float32),
                   torch.tensor([1.0, 2.4, 6.0]).to(torch.float32))


class ElementwiseEqIntTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, x, y):
        return torch.eq(x, y)


@register_test_case(module_factory=lambda: ElementwiseEqIntTensorModule())
def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5,)))
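
For reference, the e2e harness checks the compiled results against eager-mode PyTorch, whose comparison ops broadcast and return `torch.bool` tensors. A minimal eager-only sketch of the semantics these tests exercise (plain PyTorch, independent of the harness):

import torch

# Tensor-tensor comparison broadcasts a (3, 5) tensor against a (5,) tensor,
# mirroring ElementwiseLtFloatTensorModule above.
x = torch.rand(3, 5)
y = torch.rand(5)
assert torch.lt(x, y).shape == (3, 5)
assert torch.lt(x, y).dtype == torch.bool

# Tensor-scalar comparison, mirroring ElementwiseEqIntScalarModule above.
z = torch.randint(2, 4, (5, 8))
assert torch.eq(z, 2).dtype == torch.bool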
# ==============================================================================
class ElementwiseClampModule(torch.nn.Module):
    def __init__(self):

@@ -570,6 +570,36 @@ def Torch_AtenGt_TensorOp : Torch_Op<"aten.gt_.Tensor", [
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}

def Torch_AtenLtTensorOp : Torch_Op<"aten.lt.Tensor", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::lt.Tensor : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}

def Torch_AtenLt_TensorOp : Torch_Op<"aten.lt_.Tensor", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::lt_.Tensor : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}

def Torch_AtenNeTensorOp : Torch_Op<"aten.ne.Tensor", [
    AllowsTypeRefinement,
    HasValueSemantics

@@ -844,6 +874,36 @@ def Torch_AtenGe_ScalarOp : Torch_Op<"aten.ge_.Scalar", [
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}

def Torch_AtenLtScalarOp : Torch_Op<"aten.lt.Scalar", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchScalarType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}

def Torch_AtenLt_ScalarOp : Torch_Op<"aten.lt_.Scalar", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::lt_.Scalar : (Tensor, Scalar) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchScalarType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}

def Torch_AtenFmodScalarOp : Torch_Op<"aten.fmod.Scalar", [
    AllowsTypeRefinement,
    HasValueSemantics
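
The generator emits two defs per overload: a value-semantic op (`HasValueSemantics`) and its trailing-underscore in-place variant (`IsTrailingUnderscoreInplaceVariant`). A small illustration of that pairing in eager PyTorch (not code from this change):

import torch

x = torch.tensor([1.0, 2.0, 3.0])

# Value-semantic form (`aten::lt.Scalar`): returns a fresh bool tensor.
mask = torch.lt(x, 2.5)

# In-place form (`aten::lt_.Scalar`): mutates x itself; the comparison
# result is written back into the float tensor as 0.0/1.0 values.
x.lt_(2.5)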


@@ -1691,8 +1691,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
    // to be handled.
    if (lhsDtype != rhsDtype) {
      gtTensor.emitError("unimplemented: different lhs and rhs dtype");
      return nullptr;
    }
    Type elementalType =
        gtTensor.self().getType().cast<BaseTensorType>().getDtype();

@@ -1709,6 +1711,61 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
                                     payloadArgs[0], payloadArgs[1]);
    }
    gtTensor.emitError("unimplemented: dtype isn't supported.");
    return nullptr;
  }
  if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
    AtenEqTensorOp::Adaptor adaptor(operands);
    Type lhsDtype = payloadArgs[0].getType();
    Type rhsDtype = payloadArgs[1].getType();
    // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
    // to be handled.
    if (lhsDtype != rhsDtype) {
      eqTensor.emitError("unimplemented: lhs and rhs dtype must be same");
      return nullptr;
    }
    Type elementalType =
        eqTensor.self().getType().cast<BaseTensorType>().getDtype();
    if (elementalType.isa<mlir::FloatType>())
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
                                     payloadArgs[0], payloadArgs[1]);
    if (elementalType.isa<mlir::IntegerType>()) {
      return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                     payloadArgs[0], payloadArgs[1]);
    }
    eqTensor.emitError("unimplemented: dtype isn't supported.");
    return nullptr;
  }
  if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
    AtenLtTensorOp::Adaptor adaptor(operands);
    Type lhsDtype = payloadArgs[0].getType();
    Type rhsDtype = payloadArgs[1].getType();
    // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
    // to be handled.
    if (lhsDtype != rhsDtype) {
      ltTensor.emitError("unimplemented: lhs and rhs dtype must be same");
      return nullptr;
    }
    Type elementalType =
        ltTensor.self().getType().cast<BaseTensorType>().getDtype();
    if (elementalType.isa<mlir::FloatType>())
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
                                     payloadArgs[0], payloadArgs[1]);
    if (IntegerType intType = elementalType.dyn_cast<mlir::IntegerType>()) {
      if (intType.isUnsigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                       payloadArgs[0], payloadArgs[1]);
      if (intType.isSigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                       payloadArgs[0], payloadArgs[1]);
    }
    ltTensor.emitError("unimplemented: dtype isn't supported.");
    return nullptr;
  }
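
The signed/unsigned split above matters because `arith.cmpi` has a distinct predicate for each interpretation of the bits. A quick eager-mode illustration of the behavior the `ult`/`slt` choice must reproduce (plain PyTorch; `torch.uint8` is the unsigned dtype it exposes):

import torch

# 200 stored in a uint8 tensor: unsigned comparison says 200 > 100, so lt
# is False. Reinterpreting the same bits as signed int8 gives -56, and a
# signed comparison (slt) would wrongly report -56 < 100 as True.
a = torch.tensor([200], dtype=torch.uint8)
b = torch.tensor([100], dtype=torch.uint8)
assert not torch.lt(a, b).item()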
  if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
    AtenDivTensorOp::Adaptor adaptor(operands);

@@ -1764,6 +1821,56 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    return nullptr;
  }
  if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
    Type dtype = eqScalar.self().getType().cast<BaseTensorType>().getDtype();
    Value otherPromoted =
        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
    if (dtype.isa<mlir::FloatType>())
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
                                     payloadArgs[0], otherPromoted);
    if (dtype.isa<mlir::IntegerType>()) {
      if (!operands[1].getType().isa<mlir::IntegerType>()) {
        // TODO: Promote tensor operand from integer to float.
        eqScalar.emitError(
            "unimplemented: type promotion from tensor to scalar");
        return nullptr;
      }
      return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                     payloadArgs[0], otherPromoted);
    }
    eqScalar.emitError("unimplemented: dtype isn't supported");
    return nullptr;
  }
  if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
    Type dtype = ltScalar.self().getType().cast<BaseTensorType>().getDtype();
    Value otherPromoted =
        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
    // TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share a
    // lot of code that can be refactored.
    if (dtype.isa<mlir::FloatType>())
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
                                     payloadArgs[0], otherPromoted);
    if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
      if (!operands[1].getType().isa<mlir::IntegerType>()) {
        // TODO: Promote tensor operand from integer to float.
        ltScalar.emitError(
            "unimplemented: type promotion from tensor to scalar");
        return nullptr;
      }
      if (intType.isUnsigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                       payloadArgs[0], otherPromoted);
      if (intType.isSigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                       payloadArgs[0], otherPromoted);
    }
    ltScalar.emitError("unimplemented: dtype isn't supported.");
    return nullptr;
  }
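
`convertScalarToDtype` above promotes the scalar operand to the tensor's element type, but the reverse direction (promoting an integer tensor to compare against a float scalar) is left as a TODO, guarded by the `emitError`. Eager PyTorch does perform that promotion; a sketch of the eager behavior this lowering does not yet cover:

import torch

x = torch.tensor([1, 2, 3])  # int64 tensor

# Eager mode promotes the int tensor to compare against the float scalar,
# so 2 < 2.5 is True. The lowering in this commit would instead report
# "unimplemented: type promotion from tensor to scalar".
print(torch.lt(x, 2.5))  # tensor([ True,  True, False])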
  if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
    Type dtype = converter->convertType(whereSelf.getType())
                     .cast<RankedTensorType>()

@@ -2130,8 +2237,9 @@ struct ConvertElementwiseOp : ConversionPattern {
            AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
            AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
            AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
            AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenEqScalarOp,
            AtenLtScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
            AtenEqTensorOp, AtenLtTensorOp>(op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))

@@ -3788,7 +3896,8 @@ public:
        AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
        AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp,
        AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
        AtenEqScalarOp, AtenLtScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
        AtenEqTensorOp, AtenLtTensorOp>();
    patterns.add<ConvertElementwiseOp>(typeConverter, context);
    target.addIllegalOp<AtenSqueezeOp>();
    patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);


@@ -247,8 +247,8 @@ public:
    }
    // These comparison ops return a tensor with 1-bit integer dtype.
    if (isa<AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp, AtenLtScalarOp,
            AtenNeScalarOp>(op)) {
      auto operand = operands[0]->getValue();
      auto knowledge =
          ValueKnowledge::getNotNonePessimisticValueState(op->getContext());

@@ -317,10 +317,10 @@ public:
               op)) {
      return visitBinaryTensorScalarOp(op, operands);
    } else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
                   AtenDivTensorOp, Aten__And__TensorOp, AtenMinimumOp,
                   AtenMaximumOp, AtenBitwiseAndTensorOp>(op)) {
      return visitBinaryBroadcastingOp(op, operands);
    } else if (isa<AtenEqTensorOp, AtenGtTensorOp, AtenLtTensorOp>(op)) {
      return visitBinaryBroadcastingComparisonOp(op, operands);
    } else if (auto whereSelf = llvm::dyn_cast<AtenWhereSelfOp>(op)) {
      return visitAtenWhereSelfOp(whereSelf, operands);
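
The 1-bit integer dtype mentioned in the comment surfaces as `torch.bool` at the PyTorch level; every comparison op in these lists produces it regardless of input dtype. A minimal eager-mode check (plain PyTorch, not part of the commit):

import torch

for fn in (torch.eq, torch.gt, torch.lt):
    assert fn(torch.rand(2, 3), torch.rand(3)).dtype == torch.bool
    assert fn(torch.randint(5, (2, 3)), 2).dtype == torch.bool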


@@ -461,6 +461,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
        "aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
        "aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)",
        "aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)",
        "aten::lt.Tensor : (Tensor, Tensor) -> (Tensor)",
        "aten::ne.Tensor : (Tensor, Tensor) -> (Tensor)",
        "aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
        "aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",

@@ -470,6 +471,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
        "aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)",
        "aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)",
        "aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)",
        "aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)",
        "aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)",
        "aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)",
        "aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)",