mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add support for int8 dtype for sub, add, and bitwise_and op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/2502/head
parent
32d9b20bde
commit
ca6ce8974f
|
@ -291,6 +291,9 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
|
||||
# Lowering not present for this case
|
||||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||
|
||||
# torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method add of type object at 0x7f4f8b05a720>(*(FakeTensor(..., size=(3, 4), dtype=torch.int8), 3, 2), **{}): Tensor with dtype torch.int64 is not the expected dtype of torch.int8!
|
||||
"ElementwiseAddScalarInt8Module_basic",
|
||||
}
|
||||
|
||||
if torch_version_for_comparison() < version.parse("2.1.0.dev"):
|
||||
|
@ -1261,6 +1264,8 @@ TOSA_PASS_SET = {
|
|||
"SoftmaxIntNegDimModule_basic",
|
||||
"_LogSoftmaxModule_basic",
|
||||
"_SoftmaxModule_basic",
|
||||
"ElementwiseAddScalarInt8Module_basic",
|
||||
"ElementwiseSubTensorInt8Module_basic",
|
||||
}
|
||||
|
||||
MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
||||
|
@ -1421,4 +1426,5 @@ LTC_XFAIL_SET = {
|
|||
"EmptyStridedModule_basic",
|
||||
"ElementwiseBitwiseAndScalarInt64Module_basic",
|
||||
"ElementwiseBitwiseAndScalarInt32Module_basic",
|
||||
"ElementwiseBitwiseAndScalarInt8Module_basic",
|
||||
}
|
||||
|
|
|
@ -309,8 +309,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
"bitwise_and.Scalar does not support non-integer input dtype.");
|
||||
return nullptr;
|
||||
}
|
||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||
Type resultElementType =
|
||||
bitwiseAndScalar.getType().cast<BaseTensorType>().getDtype();
|
||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
|
||||
/*srcOriginalDtype=*/std::nullopt,
|
||||
/*dstOriginalDtype=*/resultElementType);
|
||||
Value other = convertScalarToDtype(b, loc, operands[1], dtype,
|
||||
/*srcOriginalDtype=*/std::nullopt,
|
||||
/*dstOriginalDtype=*/resultElementType);
|
||||
return b.create<arith::AndIOp>(loc, self, other);
|
||||
}
|
||||
if (auto bitwiseOrTensor = dyn_cast<AtenBitwiseOrTensorOp>(op)) {
|
||||
|
@ -542,9 +548,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
Type dtype = converter->convertType(sub.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype);
|
||||
Type resultElementType = sub.getType().cast<BaseTensorType>().getDtype();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
|
||||
/*srcOriginalDtype=*/std::nullopt,
|
||||
/*dstOriginalDtype=*/resultElementType);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype,
|
||||
/*srcOriginalDtype=*/std::nullopt,
|
||||
/*dstOriginalDtype=*/resultElementType);
|
||||
Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype,
|
||||
/*srcOriginalDtype=*/std::nullopt,
|
||||
/*dstOriginalDtype=*/resultElementType);
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
|
||||
return b.create<arith::SubFOp>(loc, lhs, scaled);
|
||||
|
@ -575,9 +588,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
Type dtype = converter->convertType(addScalar.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||
Value alpha = convertScalarToDtype(b, loc, operands[2], dtype);
|
||||
Type resultElementType =
|
||||
addScalar.getType().cast<BaseTensorType>().getDtype();
|
||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
|
||||
/*srcOriginalDtype=*/std::nullopt,
|
||||
/*dstOriginalDtype=*/resultElementType);
|
||||
Value other = convertScalarToDtype(b, loc, operands[1], dtype,
|
||||
/*srcOriginalDtype=*/std::nullopt,
|
||||
/*dstOriginalDtype=*/resultElementType);
|
||||
Value alpha = convertScalarToDtype(b, loc, operands[2], dtype,
|
||||
/*srcOriginalDtype=*/std::nullopt,
|
||||
/*dstOriginalDtype=*/resultElementType);
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
Value mult = b.create<arith::MulFOp>(loc, other, alpha);
|
||||
return b.create<arith::AddFOp>(loc, self, mult);
|
||||
|
|
|
@ -2316,6 +2316,31 @@ def ElementwiseBitwiseNotInt32Module_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseSubTensorInt8Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int8, True),
|
||||
([-1, -1], torch.int8, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.sub(x, y, alpha=2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseSubTensorInt8Module())
|
||||
def ElementwiseSubTensorInt8Module_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.randint(3, 4, high=10).to(dtype=torch.int8),
|
||||
tu.randint(3, 4, high=10).to(dtype=torch.int8))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseSubScalarIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -2472,6 +2497,28 @@ def ElementwiseAddScalar_TensorLiteralInt32_Module_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAddScalarInt8Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int8, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.add(x, 3, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAddScalarInt8Module())
|
||||
def ElementwiseAddScalarInt8Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, high=10).to(torch.int8))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseCloneModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -3619,3 +3666,22 @@ class ElementwiseBitwiseAndScalarInt32Module(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt32Module())
|
||||
def ElementwiseBitwiseAndScalarInt32Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32))
|
||||
|
||||
|
||||
class ElementwiseBitwiseAndScalarInt8Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int8, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.bitwise_and(x, 100)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt8Module())
|
||||
def ElementwiseBitwiseAndScalarInt8Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int8))
|
||||
|
|
Loading…
Reference in New Issue