mirror of https://github.com/llvm/torch-mlir

[MLIR][TORCH] Add support for conversion to int8 dtype

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>

branch: pull/2491/head
parent: 71ac62f3a8
commit: c434736ee9
@@ -288,6 +288,9 @@ TORCHDYNAMO_XFAIL_SET = {
     # AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
     "AtenEmbeddingBagStaticModule_basic",
+
+    # Lowering not present for this case
+    "ElementwiseToDtypeI64ToUI8Module_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.1.0.dev"):
@@ -87,7 +87,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
 // from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
 // should be converted builtin types.
 Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
-                           std::optional<Type> srcOriginalDtype = std::nullopt);
+                           std::optional<Type> srcOriginalDtype = std::nullopt,
+                           std::optional<Type> dstOriginalDtype = std::nullopt);
 
 Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
                          Value torchOptionalInt, Value builtinInt,
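Both optional dtypes are defaulted, so existing call sites such as convertScalarToDtype(b, loc, scalar, dtype) keep compiling unchanged; only lowerings that target byte/char element types need the new argument. A hypothetical call site (a minimal sketch, not part of this patch; b, loc, and scalar are illustrative):

// Converting to a byte-sized builtin type now needs the destination's
// original, signedness-carrying dtype: the signless builtin i8 alone
// cannot distinguish torch.int8 (Char) from torch.uint8 (Byte).
Value converted = convertScalarToDtype(
    b, loc, scalar, /*dtype=*/b.getIntegerType(8),
    /*srcOriginalDtype=*/std::nullopt,
    /*dstOriginalDtype=*/b.getIntegerType(8, /*isSigned=*/true));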
@@ -988,7 +988,23 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Type dtype = converter->convertType(atenToDtype.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
-    Value result = convertScalarToDtype(b, loc, input, dtype);
+    Type resultElementType;
+    int64_t dtypeInt;
+    if (!matchPattern(atenToDtype.getDtype(), m_TorchConstantInt(&dtypeInt))) {
+      atenToDtype.emitError("unimplemented: dtype must be a constant integer");
+      return nullptr;
+    }
+    FailureOr<Type> maybeResultElementType = getTypeForScalarType(
+        atenToDtype->getContext(), (torch_upstream::ScalarType)dtypeInt,
+        IntegerType::Signless);
+    if (failed(maybeResultElementType)) {
+      atenToDtype.emitError("unable to convert `dtypeInt` to builtin type");
+      return nullptr;
+    }
+    resultElementType = *maybeResultElementType;
+    Value result = convertScalarToDtype(b, loc, input, dtype,
+                                        /*srcOriginalDtype=*/std::nullopt,
+                                        /*dstOriginalDtype=*/resultElementType);
     return result;
   }
   if (auto divScalar = dyn_cast<AtenDivScalarOp>(op)) {
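Here the constant dtype operand is decoded as PyTorch's ScalarType enum and mapped to a builtin type, and that same type is forwarded as dstOriginalDtype so convertScalarToDtype can reject the cases it cannot represent; a non-constant dtype operand or an unmappable ScalarType makes the pattern bail out with nullptr. For torch.uint8 the mapped type evidently still carries its unsigned signedness, which is what the new guard below keys on — hence the ElementwiseToDtypeI64ToUI8Module xfail entries.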
@@ -249,7 +249,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
 // from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
 // should be converted builtin types.
 Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
-                           std::optional<Type> srcOriginalDtype) {
+                           std::optional<Type> srcOriginalDtype,
+                           std::optional<Type> dstOriginalDtype) {
   Type scalarType = scalar.getType();
   if (scalarType == dtype)
     return scalar;
@@ -261,14 +262,20 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
     return false;
   };
 
-  // We only support conversion from Byte or Char scalarType not to Byte or Char
-  // dtype.
+  // We don't support conversion to Byte dtype.
   if (isByteOrChar(dtype)) {
-    mlir::emitError(loc) << "unsupported: conversion to byte or char type for "
-                            "convertScalarToDtype "
-                         << scalarType << "(scalar type) -> " << dtype
-                         << "(dtype)";
-    return nullptr;
+    if (!dstOriginalDtype.has_value()) {
+      mlir::emitError(loc)
+          << "unimplemented: for conversion to byte or char type "
+             "dstOriginalDtype has to be passed to convertScalarToDtype";
+      return nullptr;
+    }
+    if (dstOriginalDtype->isUnsignedInteger()) {
+      mlir::emitError(loc)
+          << "unsupported: conversion to byte type for convertScalarToDtype "
+          << scalarType << "(scalar type) -> " << dtype << "(dtype)";
+      return nullptr;
+    }
   }
 
   // If the dtype is i1, i.e., a boolean type.
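In caller terms the new guard behaves roughly as follows (a sketch, not from the patch; b, loc, and scalar are illustrative):

// Signed/signless 8-bit destination (torch.int8): conversion proceeds.
Value ok = convertScalarToDtype(b, loc, scalar, b.getIntegerType(8),
                                /*srcOriginalDtype=*/std::nullopt,
                                /*dstOriginalDtype=*/b.getIntegerType(8));
// Unsigned destination (torch.uint8 / Byte): emits an error, returns nullptr.
Value bad = convertScalarToDtype(
    b, loc, scalar, b.getIntegerType(8),
    /*srcOriginalDtype=*/std::nullopt,
    /*dstOriginalDtype=*/b.getIntegerType(8, /*isSigned=*/false));
// Omitting dstOriginalDtype for a byte/char destination is likewise an error.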
@@ -14,6 +14,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
     "NativeGroupNormBackwardModule_basic",
     "QuantizedMLP_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
+    "ElementwiseToDtypeI64ToUI8Module_basic",
 }
 
 # TODO: Delete once torch 2.1.0 is released
@@ -1642,6 +1642,44 @@ def ElementwiseToDtypeIdentityModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseToDtypeI64ToI8Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1], torch.int64, True)])
+    def forward(self, x):
+        return x.to(torch.int8)
+
+
+@register_test_case(module_factory=lambda: ElementwiseToDtypeI64ToI8Module())
+def ElementwiseToDtypeI64ToI8Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-100, high=100))
+
+
+# ==============================================================================
+
+
+class ElementwiseToDtypeI64ToUI8Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1], torch.int64, True)])
+    def forward(self, x):
+        return x.to(torch.uint8)
+
+
+@register_test_case(module_factory=lambda: ElementwiseToDtypeI64ToUI8Module())
+def ElementwiseToDtypeI64ToUI8Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-100, high=100))
+
+
+# ==============================================================================
+
+
 class ElementwiseLog2Module(torch.nn.Module):
 
     def __init__(self):
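Of the two new e2e tests, the int8 variant is left out of the xfail sets and so is expected to pass end to end, while the uint8 variant is registered but simultaneously added to COMMON_TORCH_MLIR_LOWERING_XFAILS and TORCHDYNAMO_XFAIL_SET above ("Lowering not present for this case"), consistent with the unsupported-byte-type guard in convertScalarToDtype.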