[MLIR][TORCH] Add support for conversion to int8 dtype

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/2491/head
Vivek Khandelwal 2023-09-29 12:19:18 +00:00
parent 71ac62f3a8
commit c434736ee9
6 changed files with 76 additions and 10 deletions

View File

@ -288,6 +288,9 @@ TORCHDYNAMO_XFAIL_SET = {
# AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
"AtenEmbeddingBagStaticModule_basic",
# Lowering not present for this case
"ElementwiseToDtypeI64ToUI8Module_basic",
}
if torch_version_for_comparison() < version.parse("2.1.0.dev"):

View File

@ -87,7 +87,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
// should be converted builtin types.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
std::optional<Type> srcOriginalDtype = std::nullopt);
std::optional<Type> srcOriginalDtype = std::nullopt,
std::optional<Type> dstOriginalDtype = std::nullopt);
Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
Value torchOptionalInt, Value builtinInt,

View File

@ -988,7 +988,23 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(atenToDtype.getType())
.cast<RankedTensorType>()
.getElementType();
Value result = convertScalarToDtype(b, loc, input, dtype);
Type resultElementType;
int64_t dtypeInt;
if (!matchPattern(atenToDtype.getDtype(), m_TorchConstantInt(&dtypeInt))) {
atenToDtype.emitError("unimplemented: dtype must be a constant integer");
return nullptr;
}
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
atenToDtype->getContext(), (torch_upstream::ScalarType)dtypeInt,
IntegerType::Signless);
if (failed(maybeResultElementType)) {
atenToDtype.emitError("unable to convert `dtypeInt` to builtin type");
return nullptr;
}
resultElementType = *maybeResultElementType;
Value result = convertScalarToDtype(b, loc, input, dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
return result;
}
if (auto divScalar = dyn_cast<AtenDivScalarOp>(op)) {

View File

@ -249,7 +249,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
// should be converted builtin types.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
std::optional<Type> srcOriginalDtype) {
std::optional<Type> srcOriginalDtype,
std::optional<Type> dstOriginalDtype) {
Type scalarType = scalar.getType();
if (scalarType == dtype)
return scalar;
@ -261,15 +262,21 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
return false;
};
// We only support conversion from Byte or Char scalarType not to Byte or Char
// dtype.
// We don't support conversion to Byte dtype.
if (isByteOrChar(dtype)) {
mlir::emitError(loc) << "unsupported: conversion to byte or char type for "
"convertScalarToDtype "
<< scalarType << "(scalar type) -> " << dtype
<< "(dtype)";
if (!dstOriginalDtype.has_value()) {
mlir::emitError(loc)
<< "unimplemented: for conversion to byte or char type "
"dstOriginalDtype has to be passed to convertScalarToDtype";
return nullptr;
}
if (dstOriginalDtype->isUnsignedInteger()) {
mlir::emitError(loc)
<< "unsupported: conversion to byte type for convertScalarToDtype "
<< scalarType << "(scalar type) -> " << dtype << "(dtype)";
return nullptr;
}
}
// If the dtype is i1, i.e., a boolean type.
if (dtype.isSignlessInteger(1)) {

View File

@ -14,6 +14,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"NativeGroupNormBackwardModule_basic",
"QuantizedMLP_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
}
# TODO: Delete once torch 2.1.0 is released

View File

@ -1642,6 +1642,44 @@ def ElementwiseToDtypeIdentityModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseToDtypeI64ToI8Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([None, ([-1, -1], torch.int64, True)])
def forward(self, x):
return x.to(torch.int8)
@register_test_case(module_factory=lambda: ElementwiseToDtypeI64ToI8Module())
def ElementwiseToDtypeI64ToI8Module_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-100, high=100))
# ==============================================================================
class ElementwiseToDtypeI64ToUI8Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([None, ([-1, -1], torch.int64, True)])
def forward(self, x):
return x.to(torch.uint8)
@register_test_case(module_factory=lambda: ElementwiseToDtypeI64ToUI8Module())
def ElementwiseToDtypeI64ToUI8Module_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-100, high=100))
# ==============================================================================
class ElementwiseLog2Module(torch.nn.Module):
def __init__(self):