Add `aten.gt.Tensor` op

The `aten.gt.Tensor` op has been added to the Torch dialect, and its lowering
to the linalg dialect has been implemented.

Signed-off-by: Prashant Kumar <prashant@nod-labs.com>
pull/474/head snapshot-20211212.140
Prashant Kumar 2021-12-10 14:10:36 +05:30
parent a778f990e9
commit 528354de84
3 changed files with 83 additions and 2 deletions
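
For context, `aten.gt.Tensor` is the tensor-tensor variant of `torch.gt`: an elementwise greater-than that broadcasts its two operands and returns a `torch.bool` tensor. A minimal eager-mode sketch of the semantics the new lowering has to reproduce (illustrative only, not part of this change):

    import torch

    x = torch.rand(3, 5)                 # lhs: rank-2 float tensor
    y = torch.rand(5)                    # rhs: rank-1, broadcast against the last dim of x
    result = torch.gt(x, y)              # equivalent to x > y
    print(result.shape, result.dtype)    # torch.Size([3, 5]) torch.bool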

@@ -343,6 +343,42 @@ class ElementwiseGtScalarModule(torch.nn.Module):
def ElementwiseGtScalarModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5))


class ElementwiseGtFloatTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, x, y):
        return torch.gt(x, y)


@register_test_case(module_factory=lambda: ElementwiseGtFloatTensorModule())
def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5), tu.rand(5))


class ElementwiseGtIntTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, x, y):
        return torch.gt(x, y)


@register_test_case(module_factory=lambda: ElementwiseGtIntTensorModule())
def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5,)))

# ==============================================================================
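
The two new tests cover both element types handled by the lowering, float32 and int64, each with a rank-2 lhs and a rank-1 rhs so the broadcasting path is exercised. For the integer case the comparison must be signed; a hypothetical eager-mode check (not part of the test suite) that only holds under signed semantics:

    import torch

    lhs = torch.tensor([-2, 3], dtype=torch.int64)
    rhs = torch.tensor([1, 1], dtype=torch.int64)
    print(torch.gt(lhs, rhs))   # tensor([False,  True]) -- only correct with a signed compare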

@@ -1669,6 +1669,32 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
      return b.create<arith::MulIOp>(loc, lhs, rhs);
    }
  }
  if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
    AtenGtTensorOp::Adaptor adaptor(operands);
    Type lhsDtype = payloadArgs[0].getType();
    Type rhsDtype = payloadArgs[1].getType();

    // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
    // to be handled.
    if (lhsDtype != rhsDtype)
      gtTensor.emitError("unimplemented: different lhs and rhs dtype");

    Type elementalType =
        gtTensor.self().getType().cast<BaseTensorType>().getDtype();

    if (elementalType.isa<mlir::FloatType>())
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                     payloadArgs[0], payloadArgs[1]);
    if (IntegerType intType = elementalType.dyn_cast<mlir::IntegerType>()) {
      if (intType.isUnsigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
                                       payloadArgs[0], payloadArgs[1]);
      if (intType.isSigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                       payloadArgs[0], payloadArgs[1]);
    }
    gtTensor.emitError("unimplemented: dtype isn't supported.");
  }
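
The TODO above marks the one semantic gap relative to eager PyTorch: `torch.gt` type-promotes mixed-dtype operands before comparing, whereas this payload generator requires the element types to match and emits an "unimplemented" error otherwise. A small eager-mode illustration of the case left unhandled (assuming standard PyTorch promotion rules):

    import torch

    lhs = torch.rand(3)      # float32
    rhs = torch.arange(3)    # int64
    # Eager PyTorch promotes the operands before comparing; the lowering above
    # instead requires lhsDtype == rhsDtype and reports "unimplemented".
    print(torch.gt(lhs, rhs))
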
  if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
    AtenDivTensorOp::Adaptor adaptor(operands);
    Type dtype = converter->convertType(div.getType())
@@ -2070,7 +2096,7 @@ struct ConvertElementwiseOp : ConversionPattern {
             AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
             AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
             AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenWhereSelfOp,
             AtenCeilOp>(op))
             AtenCeilOp, AtenGtTensorOp>(op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -3640,7 +3666,7 @@ public:
        AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
        AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp,
        AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
        AtenWhereSelfOp>();
        AtenWhereSelfOp, AtenGtTensorOp>();
    patterns.add<ConvertElementwiseOp>(typeConverter, context);
    target.addIllegalOp<AtenSqueezeOp>();
    patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);

@@ -320,6 +320,8 @@ public:
            AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
            AtenMinimumOp, AtenMaximumOp, AtenBitwiseAndTensorOp>(op)) {
      return visitBinaryBroadcastingOp(op, operands);
    } else if (isa<AtenGtTensorOp>(op)) {
      return visitBinaryBroadcastingComparisonOp(op, operands);
    } else if (auto whereSelf = llvm::dyn_cast<AtenWhereSelfOp>(op)) {
      return visitAtenWhereSelfOp(whereSelf, operands);
    } else if (auto lerpTensor = llvm::dyn_cast<AtenLerpTensorOp>(op)) {
@@ -505,6 +507,8 @@ private:
      Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult visitBinaryBroadcastingOp(
      Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult visitBinaryBroadcastingComparisonOp(
      Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult
  visitAtenWhereSelfOp(AtenWhereSelfOp op,
                       ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@@ -884,6 +888,21 @@ ChangeResult TypeAnalyzer::visitBinaryBroadcastingOp(
  return getLatticeElement(op->getResult(0)).join(knowledge);
}

ChangeResult TypeAnalyzer::visitBinaryBroadcastingComparisonOp(
    Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto lhs = operands[0]->getValue();
  auto rhs = operands[1]->getValue();
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(getContext());
  if (lhs.hasSizes && rhs.hasSizes) {
    knowledge.hasSizes = true;
    knowledge.sizes.resize(std::max(lhs.sizes.size(), rhs.sizes.size()),
                           kUnknownSize);
  }
  knowledge.dtype = IntegerType::get(op->getContext(), 1);
  return getLatticeElement(op->getResult(0)).join(knowledge);
}
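
Unlike `visitBinaryBroadcastingOp`, which propagates the operand dtype to the result, the comparison variant always produces an `i1` (`torch.bool`) element type; only the result rank is derived from the operands, with every dimension left as `kUnknownSize`. The corresponding eager-mode behaviour (illustrative only):

    import torch

    lhs = torch.rand(2, 3, 5)
    rhs = torch.rand(5)
    out = torch.gt(lhs, rhs)
    print(out.dtype)   # torch.bool -- modeled as i1 in the lattice
    print(out.dim())   # 3 -- the broadcast rank, max(lhs rank, rhs rank)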

ChangeResult TypeAnalyzer::visitAtenWhereSelfOp(
    AtenWhereSelfOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto condition = operands[0]->getValue();