mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.bitwise_not op
This commit adds lowering of `aten.bitwise_not` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/1354/head
parent
7dfadc2498
commit
e35741fb1d
|
@ -254,6 +254,8 @@ TOSA_PASS_SET = {
|
||||||
"ElementwiseMulScalarModule_float",
|
"ElementwiseMulScalarModule_float",
|
||||||
"ElementwiseCeilModule_basic",
|
"ElementwiseCeilModule_basic",
|
||||||
"ElementwiseReciprocalModule_basic",
|
"ElementwiseReciprocalModule_basic",
|
||||||
|
"ElementwiseNotIntegerModule_basic",
|
||||||
|
"ElementwiseNotInt32Module_basic",
|
||||||
"TypePromotionAlphaWiderModule_basic",
|
"TypePromotionAlphaWiderModule_basic",
|
||||||
"Conv2dWithPaddingDilationStrideStaticModule_basic",
|
"Conv2dWithPaddingDilationStrideStaticModule_basic",
|
||||||
"BatchNorm1DModule_basic",
|
"BatchNorm1DModule_basic",
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
|
#include "llvm/ADT/APSInt.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
|
@ -948,6 +949,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return b.create<arith::SelectOp>(loc, pred, scalar, zero);
|
return b.create<arith::SelectOp>(loc, pred, scalar, zero);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto bitwiseNot = dyn_cast<AtenBitwiseNotOp>(op)) {
|
||||||
|
Type elementType = converter->convertType(bitwiseNot.getType())
|
||||||
|
.cast<RankedTensorType>()
|
||||||
|
.getElementType();
|
||||||
|
if (elementType.isa<mlir::FloatType>()) {
|
||||||
|
bitwiseNot.emitError("Bitwise_Not does not support floating point dtype");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value allOnesVal = b.create<arith::ConstantOp>(
|
||||||
|
loc, b.getIntegerAttr(
|
||||||
|
elementType,
|
||||||
|
APSInt::getAllOnesValue(elementType.getIntOrFloatBitWidth())));
|
||||||
|
return b.create<arith::XOrIOp>(loc, payloadArgs[0], allOnesVal);
|
||||||
|
}
|
||||||
|
|
||||||
op->emitError("unimplemented lowering in "
|
op->emitError("unimplemented lowering in "
|
||||||
"createLinalgPayloadCalculationForElementwiseOp");
|
"createLinalgPayloadCalculationForElementwiseOp");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -995,7 +1012,8 @@ public:
|
||||||
AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
||||||
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
|
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
|
||||||
AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
|
AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
|
||||||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp>(op))
|
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
|
||||||
|
AtenBitwiseNotOp>(op))
|
||||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||||
|
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
|
@ -1470,7 +1488,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
||||||
AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
|
AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
|
||||||
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
|
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
|
||||||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
|
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
|
||||||
AtenRemainderScalarOp>();
|
AtenRemainderScalarOp, AtenBitwiseNotOp>();
|
||||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||||
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
||||||
|
|
|
@ -6369,6 +6369,10 @@ module {
|
||||||
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
}
|
}
|
||||||
|
func.func @"__torch_mlir_shape_fn.aten.bitwise_not"(%arg0: !torch.list<int>) -> !torch.list<int> {
|
||||||
|
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||||
|
return %0 : !torch.list<int>
|
||||||
|
}
|
||||||
func.func @"__torch_mlir_shape_fn.aten.logical_or"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
func.func @"__torch_mlir_shape_fn.aten.logical_or"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||||
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
|
|
|
@ -860,6 +860,9 @@ def aten〇maximum(self: List[int], other: List[int]) -> List[int]:
|
||||||
def aten〇bitwise_and〇Tensor(self: List[int], other: List[int]) -> List[int]:
|
def aten〇bitwise_and〇Tensor(self: List[int], other: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.broadcast(self, other)
|
return upstream_shape_functions.broadcast(self, other)
|
||||||
|
|
||||||
|
def aten〇bitwise_not(self: List[int]) -> List[int]:
|
||||||
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
def aten〇logical_or(self: List[int], other: List[int]) -> List[int]:
|
def aten〇logical_or(self: List[int], other: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.broadcast(self, other)
|
return upstream_shape_functions.broadcast(self, other)
|
||||||
|
|
||||||
|
|
|
@ -1531,6 +1531,50 @@ def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseNotIntegerModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.bitwise_not(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseNotIntegerModule())
|
||||||
|
def ElementwiseNotIntegerModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 4, low=-10, high=10))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseNotInt32Module(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.bitwise_not(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseNotInt32Module())
|
||||||
|
def ElementwiseNotInt32Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 4, low=-10, high=10).to(torch.int32))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseSubScalarIntModule(torch.nn.Module):
|
class ElementwiseSubScalarIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
Loading…
Reference in New Issue