[MLIR][TORCH] Add support for int8 dtype for sub, add, and bitwise_and ops

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/2502/head
Vivek Khandelwal 2023-10-03 11:59:56 +00:00
parent 32d9b20bde
commit ca6ce8974f
3 changed files with 101 additions and 8 deletions
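For reference, the eager-mode patterns this change targets are exactly the ones exercised by the new e2e tests further down; a minimal sketch (shapes and values mirror the tests):

import torch

x = torch.randint(0, 10, (3, 4), dtype=torch.int8)
y = torch.randint(0, 10, (3, 4), dtype=torch.int8)

torch.sub(x, y, alpha=2)     # aten.sub.Tensor: int8 operands with an integer alpha
torch.add(x, 3, 2)           # aten.add.Scalar: self + other * alpha, alpha passed positionally
torch.bitwise_and(x, 100)    # aten.bitwise_and.Scalar with an int8 result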


@@ -291,6 +291,9 @@ TORCHDYNAMO_XFAIL_SET = {
# Lowering not present for this case
"ElementwiseToDtypeI64ToUI8Module_basic",
# torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method add of type object at 0x7f4f8b05a720>(*(FakeTensor(..., size=(3, 4), dtype=torch.int8), 3, 2), **{}): Tensor with dtype torch.int64 is not the expected dtype of torch.int8!
"ElementwiseAddScalarInt8Module_basic",
}
if torch_version_for_comparison() < version.parse("2.1.0.dev"):
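The new TorchDynamo xfail above covers the scalar-add pattern from ElementwiseAddScalarInt8Module; a hypothetical minimal repro of the failure quoted in the comment (whether it actually raises depends on the PyTorch nightly in use):

import torch

def f(x):
    # Same call as the xfailed test: other=3, alpha=2 passed positionally.
    return torch.add(x, 3, 2)

x = torch.randint(0, 10, (3, 4), dtype=torch.int8)
f(x)                 # eager mode is fine
torch.compile(f)(x)  # expected to hit the TorchRuntimeError quoted above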
@@ -1261,6 +1264,8 @@ TOSA_PASS_SET = {
"SoftmaxIntNegDimModule_basic",
"_LogSoftmaxModule_basic",
"_SoftmaxModule_basic",
"ElementwiseAddScalarInt8Module_basic",
"ElementwiseSubTensorInt8Module_basic",
}
MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
@@ -1421,4 +1426,5 @@ LTC_XFAIL_SET = {
"EmptyStridedModule_basic",
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
}


@@ -309,8 +309,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
"bitwise_and.Scalar does not support non-integer input dtype.");
return nullptr;
}
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
Type resultElementType =
bitwiseAndScalar.getType().cast<BaseTensorType>().getDtype();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
Value other = convertScalarToDtype(b, loc, operands[1], dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
return b.create<arith::AndIOp>(loc, self, other);
}
if (auto bitwiseOrTensor = dyn_cast<AtenBitwiseOrTensorOp>(op)) {
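The bitwise_and.Scalar lowering now passes the op's result element type (resultElementType) as dstOriginalDtype to both convertScalarToDtype calls, so the tensor element and the scalar are converted with the actual torch dtype in view before the arith.andi. Elementwise this is cast-then-AND; a quick eager reference for the int8 case, mirroring the new test:

import torch

x = torch.randint(-128, 128, (3, 4), dtype=torch.int8)

# The Python scalar participates at the result element type (int8 here).
reference = torch.bitwise_and(x, torch.tensor(100, dtype=torch.int8))
assert torch.equal(torch.bitwise_and(x, 100), reference)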
@@ -542,9 +548,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(sub.getType())
.cast<RankedTensorType>()
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype);
Type resultElementType = sub.getType().cast<BaseTensorType>().getDtype();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
if (dtype.isa<mlir::FloatType>()) {
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
return b.create<arith::SubFOp>(loc, lhs, scaled);
@@ -575,9 +588,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(addScalar.getType())
.cast<RankedTensorType>()
.getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
Value alpha = convertScalarToDtype(b, loc, operands[2], dtype);
Type resultElementType =
addScalar.getType().cast<BaseTensorType>().getDtype();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
Value other = convertScalarToDtype(b, loc, operands[1], dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
Value alpha = convertScalarToDtype(b, loc, operands[2], dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
if (dtype.isa<mlir::FloatType>()) {
Value mult = b.create<arith::MulFOp>(loc, other, alpha);
return b.create<arith::AddFOp>(loc, self, mult);
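Both hunks above follow the same shape: the second operand is scaled by alpha and then combined with the first, with the result element type now threaded through every scalar conversion. For the integer cases added by this commit, the eager equivalents that the new tests check are (a sketch; inputs mirror the tests):

import torch

x = torch.randint(0, 10, (3, 4), dtype=torch.int8)
y = torch.randint(0, 10, (3, 4), dtype=torch.int8)

# aten.sub.Tensor: lhs - alpha * rhs
assert torch.equal(torch.sub(x, y, alpha=2), x - 2 * y)
# aten.add.Scalar: self + alpha * other (alpha passed positionally in the test)
assert torch.equal(torch.add(x, 3, 2), x + 2 * 3)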


@@ -2316,6 +2316,31 @@ def ElementwiseBitwiseNotInt32Module_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseSubTensorInt8Module(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int8, True),
        ([-1, -1], torch.int8, True),
    ])
    def forward(self, x, y):
        return torch.sub(x, y, alpha=2)


@register_test_case(module_factory=lambda: ElementwiseSubTensorInt8Module())
def ElementwiseSubTensorInt8Module_basic(module, tu: TestUtils):
    module.forward(
        tu.randint(3, 4, high=10).to(dtype=torch.int8),
        tu.randint(3, 4, high=10).to(dtype=torch.int8))

# ==============================================================================
class ElementwiseSubScalarIntModule(torch.nn.Module):
    def __init__(self):
@@ -2472,6 +2497,28 @@ def ElementwiseAddScalar_TensorLiteralInt32_Module_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseAddScalarInt8Module(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int8, True),
    ])
    def forward(self, x):
        return torch.add(x, 3, 2)


@register_test_case(module_factory=lambda: ElementwiseAddScalarInt8Module())
def ElementwiseAddScalarInt8Module_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, high=10).to(torch.int8))

# ==============================================================================
class ElementwiseCloneModule(torch.nn.Module):
    def __init__(self):
@@ -3619,3 +3666,22 @@ class ElementwiseBitwiseAndScalarInt32Module(torch.nn.Module):
@register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt32Module())
def ElementwiseBitwiseAndScalarInt32Module_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32))

class ElementwiseBitwiseAndScalarInt8Module(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int8, True),
    ])
    def forward(self, x):
        return torch.bitwise_and(x, 100)


@register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt8Module())
def ElementwiseBitwiseAndScalarInt8Module_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int8))