From 528354de84c5892d539fce17c74b2c928a645991 Mon Sep 17 00:00:00 2001
From: Prashant Kumar
Date: Fri, 10 Dec 2021 14:10:36 +0530
Subject: [PATCH] Add `aten.gt.Tensor` op

The `aten.gt.Tensor` op has been added to the Torch dialect, and the
lowering of the op to the Linalg dialect has been implemented.

Signed-off-by: Prashant Kumar
---
 e2e_testing/torchscript/elementwise.py        | 36 +++++++++++++++++++
 .../TorchToLinalg/TorchToLinalg.cpp           | 30 ++++++++++++++--
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  | 19 ++++++++++
 3 files changed, 83 insertions(+), 2 deletions(-)

diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py
index f297d3fdc..ab8e3a70f 100644
--- a/e2e_testing/torchscript/elementwise.py
+++ b/e2e_testing/torchscript/elementwise.py
@@ -343,6 +343,42 @@ class ElementwiseGtScalarModule(torch.nn.Module):
 def ElementwiseGtScalarModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5))
 
+class ElementwiseGtFloatTensorModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        return torch.gt(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseGtFloatTensorModule())
+def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5), tu.rand(5))
+
+class ElementwiseGtIntTensorModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1], torch.int64, True),
+    ])
+    def forward(self, x, y):
+        return torch.gt(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseGtIntTensorModule())
+def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5,)))
+
 
 
 # ==============================================================================
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 807e0b77a..f6a46a32a 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1669,6 +1669,32 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
       return b.create(loc, lhs, rhs);
     }
   }
+  if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
+    AtenGtTensorOp::Adaptor adaptor(operands);
+    Type lhsDtype = payloadArgs[0].getType();
+    Type rhsDtype = payloadArgs[1].getType();
+
+    // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
+    // to be handled.
+    if (lhsDtype != rhsDtype)
+      gtTensor.emitError("unimplemented: different lhs and rhs dtype");
+
+    Type elementalType =
+        gtTensor.self().getType().cast<BaseTensorType>().getDtype();
+
+    if (elementalType.isa<mlir::FloatType>())
+      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
+                                     payloadArgs[0], payloadArgs[1]);
+    if (IntegerType intType = elementalType.dyn_cast<mlir::IntegerType>()) {
+      if (intType.isUnsigned())
+        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
+                                       payloadArgs[0], payloadArgs[1]);
+      if (intType.isSigned())
+        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
+                                       payloadArgs[0], payloadArgs[1]);
+    }
+    gtTensor.emitError("unimplemented: dtype isn't supported.");
+  }
   if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
     AtenDivTensorOp::Adaptor adaptor(operands);
     Type dtype = converter->convertType(div.getType())
@@ -2070,7 +2096,7 @@ struct ConvertElementwiseOp : ConversionPattern {
             AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
             AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
             AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenWhereSelfOp,
-            AtenCeilOp>(op))
+            AtenCeilOp, AtenGtTensorOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -3640,7 +3666,7 @@ public:
         AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
         AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp,
         AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
-        AtenWhereSelfOp>();
+        AtenWhereSelfOp, AtenGtTensorOp>();
     patterns.add<ConvertElementwiseOp>(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 4563631e2..a0142683f 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -320,6 +320,8 @@ public:
                    AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
                    AtenMinimumOp, AtenMaximumOp, AtenBitwiseAndTensorOp>(op)) {
       return visitBinaryBroadcastingOp(op, operands);
+    } else if (isa<AtenGtTensorOp>(op)) {
+      return visitBinaryBroadcastingComparisonOp(op, operands);
     } else if (auto whereSelf = llvm::dyn_cast<AtenWhereSelfOp>(op)) {
       return visitAtenWhereSelfOp(whereSelf, operands);
     } else if (auto lerpTensor = llvm::dyn_cast<AtenLerpTensorOp>(op)) {
@@ -505,6 +507,8 @@ private:
       Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   ChangeResult visitBinaryBroadcastingOp(
       Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult visitBinaryBroadcastingComparisonOp(
+      Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   ChangeResult
   visitAtenWhereSelfOp(AtenWhereSelfOp op,
                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@@ -884,6 +888,21 @@ ChangeResult TypeAnalyzer::visitBinaryBroadcastingOp(
   return getLatticeElement(op->getResult(0)).join(knowledge);
 }
 
+ChangeResult TypeAnalyzer::visitBinaryBroadcastingComparisonOp(
+    Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto lhs = operands[0]->getValue();
+  auto rhs = operands[1]->getValue();
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(getContext());
+  if (lhs.hasSizes && rhs.hasSizes) {
+    knowledge.hasSizes = true;
+    knowledge.sizes.resize(std::max(lhs.sizes.size(), rhs.sizes.size()),
+                           kUnknownSize);
+  }
+  knowledge.dtype = IntegerType::get(op->getContext(), 1);
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
+
 ChangeResult TypeAnalyzer::visitAtenWhereSelfOp(
     AtenWhereSelfOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto condition = operands[0]->getValue();
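
Note on the intended semantics: the new e2e tests and the RefineTypes change encode that a broadcasting `torch.gt` between two tensors yields an i1 (bool) tensor whose rank is the larger of the operand ranks. Below is a minimal plain-PyTorch sketch of that behavior, run outside the torch-mlir e2e harness; the shapes mirror ElementwiseGtFloatTensorModule_basic and ElementwiseGtIntTensorModule_basic, and the variable names are purely illustrative.

import torch

# Broadcasting greater-than on float tensors: (3, 5) vs (5,) -> bool tensor of shape (3, 5).
x = torch.rand(3, 5)
y = torch.rand(5)
result = torch.gt(x, y)
assert result.dtype == torch.bool  # matches the i1 dtype set in visitBinaryBroadcastingComparisonOp
assert result.shape == (3, 5)      # rank is max(lhs rank, rhs rank); sizes follow broadcast rules

# The same op on int64 tensors, mirroring ElementwiseGtIntTensorModule_basic.
xi = torch.randint(10, (3, 5))
yi = torch.randint(10, (5,))
print(torch.gt(xi, yi).dtype)      # torch.bool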