[MLIR][TORCH] Add E2E support for aten.bitwise_and.tensor op

This commit adds lowering of `aten.bitwise_and.tensor` op.

Signed-Off By: Vivek Khandelwal vivek@nod-labs.com
pull/398/head snapshot-20211202.120
Vivek Khandelwal 2021-11-23 21:54:47 +05:30
parent 46a0668b3b
commit 46a2189a41
5 changed files with 69 additions and 14 deletions

View File

@ -552,3 +552,24 @@ class ElementwiseDivScalarModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ElementwiseDivScalarModule()) @register_test_case(module_factory=lambda: ElementwiseDivScalarModule())
def ElementwiseDivScalarModule_basic(module, tu: TestUtils): def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4)) module.forward(tu.rand(3, 4))
# ==============================================================================
class ElementwiseAndIntegerModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
([-1, -1], torch.int64, True),
])
def forward(self, x, y):
return torch.bitwise_and(x, y)
@register_test_case(module_factory=lambda: ElementwiseAndIntegerModule())
def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-10, 10, (3, 4)).to(torch.int32),
torch.randint(-10, 10, (3, 4)))

View File

@ -1494,6 +1494,21 @@ def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [
let assemblyFormat = "$self `,` $dim `,` $half_to_float attr-dict `:` type($self) `,` type($dim) `,` type($half_to_float) `->` type($result)"; let assemblyFormat = "$self `,` $dim `,` $half_to_float attr-dict `:` type($self) `,` type($dim) `,` type($half_to_float) `->` type($result)";
} }
def Torch_AtenBitwiseAndTensorOp : Torch_Op<"aten.bitwise_and.Tensor", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}
def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [ def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
AllowsTypeRefinement AllowsTypeRefinement
]> { ]> {

View File

@ -1394,6 +1394,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<math::SqrtOp>(loc, payloadArgs[0]); return b.create<math::SqrtOp>(loc, payloadArgs[0]);
if (isa<AtenRsqrtOp>(op)) if (isa<AtenRsqrtOp>(op))
return b.create<math::RsqrtOp>(loc, payloadArgs[0]); return b.create<math::RsqrtOp>(loc, payloadArgs[0]);
if (auto bitwiseAndTensor = dyn_cast<AtenBitwiseAndTensorOp>(op)) {
if (bitwiseAndTensor.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
bitwiseAndTensor.emitError(
"Bitwise_And does not support floating point dtype");
return nullptr;
}
Type dtype = converter->convertType(bitwiseAndTensor.getType())
.cast<RankedTensorType>()
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::AndIOp>(loc, lhs, rhs);
}
if (isa<AtenLog2Op>(op)) if (isa<AtenLog2Op>(op))
return b.create<math::Log2Op>(loc, payloadArgs[0]); return b.create<math::Log2Op>(loc, payloadArgs[0]);
if (isa<AtenAbsOp>(op)) if (isa<AtenAbsOp>(op))
@ -1895,13 +1911,14 @@ struct ConvertElementwiseOp : ConversionPattern {
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp, if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp,
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp, AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenExpOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp,
AtenMulScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenDivScalarOp, AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
AtenAbsOp, AtenReciprocalOp>(op)) AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@ -3193,12 +3210,12 @@ public:
target.addIllegalOp<AtenBatchNormOp>(); target.addIllegalOp<AtenBatchNormOp>();
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context); patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
target.addIllegalOp< target.addIllegalOp<
AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp, AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp,
AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenAbsOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenAbsOp,
AtenReciprocalOp>(); AtenReciprocalOp, AtenBitwiseAndTensorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context); patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenSqueezeOp>(); target.addIllegalOp<AtenSqueezeOp>();
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context); patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);

View File

@ -304,7 +304,7 @@ public:
return visitBinaryTensorScalarOp(op, operands); return visitBinaryTensorScalarOp(op, operands);
} else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp, } else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp, AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
AtenMinimumOp, AtenMaximumOp>(op)) { AtenMinimumOp, AtenMaximumOp, AtenBitwiseAndTensorOp>(op)) {
return visitBinaryBroadcastingOp(op, operands); return visitBinaryBroadcastingOp(op, operands);
} else if (auto lerpTensor = llvm::dyn_cast<AtenLerpTensorOp>(op)) { } else if (auto lerpTensor = llvm::dyn_cast<AtenLerpTensorOp>(op)) {
return visitAtenLerpTensorOp(lerpTensor, operands); return visitAtenLerpTensorOp(lerpTensor, operands);

View File

@ -470,6 +470,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
"aten::rsqrt : (Tensor) -> (Tensor)", "aten::rsqrt : (Tensor) -> (Tensor)",
"aten::abs : (Tensor) -> (Tensor)", "aten::abs : (Tensor) -> (Tensor)",
"aten::reciprocal : (Tensor) -> (Tensor)", "aten::reciprocal : (Tensor) -> (Tensor)",
"aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
]: ]:
emit_with_mutating_variants(key) emit_with_mutating_variants(key)
# Elementwise tensor compute ops that don't have the standard mutating # Elementwise tensor compute ops that don't have the standard mutating