mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add support for conversion to int8 dtype
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/2491/head
parent
71ac62f3a8
commit
c434736ee9
|
@ -288,6 +288,9 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
|
||||
# AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
|
||||
"AtenEmbeddingBagStaticModule_basic",
|
||||
|
||||
# Lowering not present for this case
|
||||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||
}
|
||||
|
||||
if torch_version_for_comparison() < version.parse("2.1.0.dev"):
|
||||
|
|
|
@ -87,7 +87,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
|
|||
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
|
||||
// should be converted builtin types.
|
||||
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
||||
std::optional<Type> srcOriginalDtype = std::nullopt);
|
||||
std::optional<Type> srcOriginalDtype = std::nullopt,
|
||||
std::optional<Type> dstOriginalDtype = std::nullopt);
|
||||
|
||||
Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value torchOptionalInt, Value builtinInt,
|
||||
|
|
|
@ -988,7 +988,23 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
Type dtype = converter->convertType(atenToDtype.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
Value result = convertScalarToDtype(b, loc, input, dtype);
|
||||
Type resultElementType;
|
||||
int64_t dtypeInt;
|
||||
if (!matchPattern(atenToDtype.getDtype(), m_TorchConstantInt(&dtypeInt))) {
|
||||
atenToDtype.emitError("unimplemented: dtype must be a constant integer");
|
||||
return nullptr;
|
||||
}
|
||||
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
|
||||
atenToDtype->getContext(), (torch_upstream::ScalarType)dtypeInt,
|
||||
IntegerType::Signless);
|
||||
if (failed(maybeResultElementType)) {
|
||||
atenToDtype.emitError("unable to convert `dtypeInt` to builtin type");
|
||||
return nullptr;
|
||||
}
|
||||
resultElementType = *maybeResultElementType;
|
||||
Value result = convertScalarToDtype(b, loc, input, dtype,
|
||||
/*srcOriginalDtype=*/std::nullopt,
|
||||
/*dstOriginalDtype=*/resultElementType);
|
||||
return result;
|
||||
}
|
||||
if (auto divScalar = dyn_cast<AtenDivScalarOp>(op)) {
|
||||
|
|
|
@ -249,7 +249,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
|
|||
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
|
||||
// should be converted builtin types.
|
||||
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
||||
std::optional<Type> srcOriginalDtype) {
|
||||
std::optional<Type> srcOriginalDtype,
|
||||
std::optional<Type> dstOriginalDtype) {
|
||||
Type scalarType = scalar.getType();
|
||||
if (scalarType == dtype)
|
||||
return scalar;
|
||||
|
@ -261,15 +262,21 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
|||
return false;
|
||||
};
|
||||
|
||||
// We only support conversion from Byte or Char scalarType not to Byte or Char
|
||||
// dtype.
|
||||
// We don't support conversion to Byte dtype.
|
||||
if (isByteOrChar(dtype)) {
|
||||
mlir::emitError(loc) << "unsupported: conversion to byte or char type for "
|
||||
"convertScalarToDtype "
|
||||
<< scalarType << "(scalar type) -> " << dtype
|
||||
<< "(dtype)";
|
||||
if (!dstOriginalDtype.has_value()) {
|
||||
mlir::emitError(loc)
|
||||
<< "unimplemented: for conversion to byte or char type "
|
||||
"dstOriginalDtype has to be passed to convertScalarToDtype";
|
||||
return nullptr;
|
||||
}
|
||||
if (dstOriginalDtype->isUnsignedInteger()) {
|
||||
mlir::emitError(loc)
|
||||
<< "unsupported: conversion to byte type for convertScalarToDtype "
|
||||
<< scalarType << "(scalar type) -> " << dtype << "(dtype)";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// If the dtype is i1, i.e., a boolean type.
|
||||
if (dtype.isSignlessInteger(1)) {
|
||||
|
|
|
@ -14,6 +14,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
|
|||
"NativeGroupNormBackwardModule_basic",
|
||||
"QuantizedMLP_basic",
|
||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||
}
|
||||
|
||||
# TODO: Delete once torch 2.1.0 is released
|
||||
|
|
|
@ -1642,6 +1642,44 @@ def ElementwiseToDtypeIdentityModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseToDtypeI64ToI8Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-1, -1], torch.int64, True)])
|
||||
def forward(self, x):
|
||||
return x.to(torch.int8)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseToDtypeI64ToI8Module())
|
||||
def ElementwiseToDtypeI64ToI8Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseToDtypeI64ToUI8Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-1, -1], torch.int64, True)])
|
||||
def forward(self, x):
|
||||
return x.to(torch.uint8)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseToDtypeI64ToUI8Module())
|
||||
def ElementwiseToDtypeI64ToUI8Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLog2Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue