mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.bitwise_and.tensor op
This commit adds lowering of `aten.bitwise_and.tensor` op. Signed-Off By: Vivek Khandelwal vivek@nod-labs.compull/398/head snapshot-20211202.120
parent
46a0668b3b
commit
46a2189a41
|
@ -552,3 +552,24 @@ class ElementwiseDivScalarModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: ElementwiseDivScalarModule())
|
@register_test_case(module_factory=lambda: ElementwiseDivScalarModule())
|
||||||
def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4))
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
class ElementwiseAndIntegerModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.bitwise_and(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseAndIntegerModule())
|
||||||
|
def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(-10, 10, (3, 4)).to(torch.int32),
|
||||||
|
torch.randint(-10, 10, (3, 4)))
|
||||||
|
|
|
@ -1494,6 +1494,21 @@ def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [
|
||||||
let assemblyFormat = "$self `,` $dim `,` $half_to_float attr-dict `:` type($self) `,` type($dim) `,` type($half_to_float) `->` type($result)";
|
let assemblyFormat = "$self `,` $dim `,` $half_to_float attr-dict `:` type($self) `,` type($dim) `,` type($half_to_float) `->` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenBitwiseAndTensorOp : Torch_Op<"aten.bitwise_and.Tensor", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchTensorType:$other
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
|
def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
|
||||||
AllowsTypeRefinement
|
AllowsTypeRefinement
|
||||||
]> {
|
]> {
|
||||||
|
|
|
@ -1394,6 +1394,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return b.create<math::SqrtOp>(loc, payloadArgs[0]);
|
return b.create<math::SqrtOp>(loc, payloadArgs[0]);
|
||||||
if (isa<AtenRsqrtOp>(op))
|
if (isa<AtenRsqrtOp>(op))
|
||||||
return b.create<math::RsqrtOp>(loc, payloadArgs[0]);
|
return b.create<math::RsqrtOp>(loc, payloadArgs[0]);
|
||||||
|
if (auto bitwiseAndTensor = dyn_cast<AtenBitwiseAndTensorOp>(op)) {
|
||||||
|
if (bitwiseAndTensor.getType()
|
||||||
|
.cast<ValueTensorType>()
|
||||||
|
.getDtype()
|
||||||
|
.isa<mlir::FloatType>()) {
|
||||||
|
bitwiseAndTensor.emitError(
|
||||||
|
"Bitwise_And does not support floating point dtype");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
Type dtype = converter->convertType(bitwiseAndTensor.getType())
|
||||||
|
.cast<RankedTensorType>()
|
||||||
|
.getElementType();
|
||||||
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
|
return b.create<arith::AndIOp>(loc, lhs, rhs);
|
||||||
|
}
|
||||||
if (isa<AtenLog2Op>(op))
|
if (isa<AtenLog2Op>(op))
|
||||||
return b.create<math::Log2Op>(loc, payloadArgs[0]);
|
return b.create<math::Log2Op>(loc, payloadArgs[0]);
|
||||||
if (isa<AtenAbsOp>(op))
|
if (isa<AtenAbsOp>(op))
|
||||||
|
@ -1895,13 +1911,14 @@ struct ConvertElementwiseOp : ConversionPattern {
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
|
if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp,
|
||||||
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
|
AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
|
||||||
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
|
AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
|
||||||
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
|
AtenExpOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp,
|
||||||
AtenMulScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
|
AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
|
||||||
AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenDivScalarOp,
|
AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
|
||||||
AtenAbsOp, AtenReciprocalOp>(op))
|
AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
|
||||||
|
AtenBitwiseAndTensorOp>(op))
|
||||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||||
|
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
|
@ -3193,12 +3210,12 @@ public:
|
||||||
target.addIllegalOp<AtenBatchNormOp>();
|
target.addIllegalOp<AtenBatchNormOp>();
|
||||||
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
|
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
|
||||||
target.addIllegalOp<
|
target.addIllegalOp<
|
||||||
AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
|
AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
|
||||||
AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp,
|
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
|
||||||
AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
|
AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp,
|
||||||
AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
|
AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
|
||||||
AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenAbsOp,
|
AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenAbsOp,
|
||||||
AtenReciprocalOp>();
|
AtenReciprocalOp, AtenBitwiseAndTensorOp>();
|
||||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenSqueezeOp>();
|
target.addIllegalOp<AtenSqueezeOp>();
|
||||||
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
|
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
|
||||||
|
|
|
@ -304,7 +304,7 @@ public:
|
||||||
return visitBinaryTensorScalarOp(op, operands);
|
return visitBinaryTensorScalarOp(op, operands);
|
||||||
} else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
|
} else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
|
||||||
AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
|
AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
|
||||||
AtenMinimumOp, AtenMaximumOp>(op)) {
|
AtenMinimumOp, AtenMaximumOp, AtenBitwiseAndTensorOp>(op)) {
|
||||||
return visitBinaryBroadcastingOp(op, operands);
|
return visitBinaryBroadcastingOp(op, operands);
|
||||||
} else if (auto lerpTensor = llvm::dyn_cast<AtenLerpTensorOp>(op)) {
|
} else if (auto lerpTensor = llvm::dyn_cast<AtenLerpTensorOp>(op)) {
|
||||||
return visitAtenLerpTensorOp(lerpTensor, operands);
|
return visitAtenLerpTensorOp(lerpTensor, operands);
|
||||||
|
|
|
@ -470,6 +470,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
||||||
"aten::rsqrt : (Tensor) -> (Tensor)",
|
"aten::rsqrt : (Tensor) -> (Tensor)",
|
||||||
"aten::abs : (Tensor) -> (Tensor)",
|
"aten::abs : (Tensor) -> (Tensor)",
|
||||||
"aten::reciprocal : (Tensor) -> (Tensor)",
|
"aten::reciprocal : (Tensor) -> (Tensor)",
|
||||||
|
"aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||||
|
|
||||||
]:
|
]:
|
||||||
emit_with_mutating_variants(key)
|
emit_with_mutating_variants(key)
|
||||||
# Elementwise tensor compute ops that don't have the standard mutating
|
# Elementwise tensor compute ops that don't have the standard mutating
|
||||||
|
|
Loading…
Reference in New Issue