[MLIR][TORCH] Add E2E support for aten.bitwise_not op

This commit adds lowering of the `aten.bitwise_not` op to Linalg, along with its shape function and e2e tests.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/1354/head
Vivek Khandelwal 2022-09-07 12:28:42 +05:30
parent 7dfadc2498
commit e35741fb1d
5 changed files with 73 additions and 2 deletions
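
For context, `aten.bitwise_not` is defined only for integer and boolean tensors; on signed integers it is the two's-complement identity `~x == -x - 1`. A minimal PyTorch sketch of the semantics the new lowering has to match (illustrative only, not part of this patch):

```python
import torch

x = torch.tensor([-3, 0, 7], dtype=torch.int64)
print(torch.bitwise_not(x))  # tensor([ 2, -1, -8])
print(-x - 1)                # same result: ~x == -x - 1 for signed integers

# Float inputs are rejected by PyTorch, which the lowering below mirrors by
# emitting an error for floating-point element types (exact message may differ).
try:
    torch.bitwise_not(torch.tensor([1.0]))
except RuntimeError as err:
    print(err)
```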

@@ -254,6 +254,8 @@ TOSA_PASS_SET = {
     "ElementwiseMulScalarModule_float",
     "ElementwiseCeilModule_basic",
     "ElementwiseReciprocalModule_basic",
+    "ElementwiseNotIntegerModule_basic",
+    "ElementwiseNotInt32Module_basic",
     "TypePromotionAlphaWiderModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_basic",
     "BatchNorm1DModule_basic",

@@ -23,6 +23,7 @@
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "llvm/ADT/APSInt.h"

 using namespace mlir;
 using namespace mlir::torch;
@@ -948,6 +949,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return b.create<arith::SelectOp>(loc, pred, scalar, zero);
   }
+
+  if (auto bitwiseNot = dyn_cast<AtenBitwiseNotOp>(op)) {
+    Type elementType = converter->convertType(bitwiseNot.getType())
+                           .cast<RankedTensorType>()
+                           .getElementType();
+    if (elementType.isa<mlir::FloatType>()) {
+      bitwiseNot.emitError("Bitwise_Not does not support floating point dtype");
+      return nullptr;
+    }
+
+    Value allOnesVal = b.create<arith::ConstantOp>(
+        loc, b.getIntegerAttr(
+                 elementType,
+                 APSInt::getAllOnesValue(elementType.getIntOrFloatBitWidth())));
+    return b.create<arith::XOrIOp>(loc, payloadArgs[0], allOnesVal);
+  }

   op->emitError("unimplemented lowering in "
                 "createLinalgPayloadCalculationForElementwiseOp");
   return nullptr;
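
Since the `arith` dialect has no dedicated bitwise-not operation, the payload above builds it from `arith.xori`: XOR-ing a value with an all-ones constant of the element's bit width flips every bit. A small Python sketch of that identity (the helper name and masking are illustrative assumptions, not code from this patch):

```python
def bitwise_not_via_xor(x: int, bit_width: int) -> int:
    """Emulate the lowering: XOR with an all-ones mask of the given width."""
    mask = (1 << bit_width) - 1        # analogue of APSInt::getAllOnesValue
    result = (x ^ mask) & mask         # unsigned view of the flipped bits
    # Reinterpret as a signed two's-complement value of the same width.
    if result >= 1 << (bit_width - 1):
        result -= 1 << bit_width
    return result

assert bitwise_not_via_xor(5, 32) == -6   # same as torch.bitwise_not on int32
assert bitwise_not_via_xor(-1, 32) == 0
```
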
@@ -995,7 +1012,8 @@ public:
              AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
              AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
              AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
-             AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp>(op))
+             AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
+             AtenBitwiseNotOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1470,7 +1488,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
       AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
       AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
-      AtenRemainderScalarOp>();
+      AtenRemainderScalarOp, AtenBitwiseNotOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossForwardOp>();
   patterns.add<ConvertAtenDetachOp>(typeConverter, context);

@@ -6369,6 +6369,10 @@ module {
     %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten.bitwise_not"(%arg0: !torch.list<int>) -> !torch.list<int> {
+    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.logical_or"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
     %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>

@@ -860,6 +860,9 @@ def aten〇maximum(self: List[int], other: List[int]) -> List[int]:
 def aten〇bitwise_and〇Tensor(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)

+def aten〇bitwise_not(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇logical_or(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)
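
The new shape function simply forwards to `upstream_shape_functions.unary`, i.e. `bitwise_not` is shape-preserving. Roughly, the upstream helper is assumed to behave like this sketch (not the actual upstream implementation):

```python
from typing import List

def unary(self: List[int]) -> List[int]:
    # An elementwise unary op leaves the input shape unchanged.
    return list(self)

assert unary([3, 4]) == [3, 4]
```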

@@ -1531,6 +1531,50 @@ def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
 # ==============================================================================

+
+class ElementwiseNotIntegerModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.bitwise_not(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseNotIntegerModule())
+def ElementwiseNotIntegerModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-10, high=10))
+
+
+# ==============================================================================
+
+
+class ElementwiseNotInt32Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, x):
+        return torch.bitwise_not(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseNotInt32Module())
+def ElementwiseNotInt32Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-10, high=10).to(torch.int32))
+
+
+# ==============================================================================
+
+
 class ElementwiseSubScalarIntModule(torch.nn.Module):
     def __init__(self):