[MLIR][TORCH] Add Byte and Char Dtype support

pull/1443/head snapshot-20220930.612
AmosLewis 2022-09-19 11:50:51 -07:00 committed by Vivek Khandelwal
parent 0765449684
commit 940959589b
11 changed files with 124 additions and 16 deletions

View File

@ -623,4 +623,6 @@ LTC_XFAIL_SET = {
"ElementwiseRemainderScalarModule_Float_basic",
"ElementwiseRemainderScalarModule_Int_basic",
"ElementwiseRemainderScalarModule_Bool_basic",
"AtenIntTensorByteDtypeModule_basic",
"AtenIntTensorCharDtypeModule_basic",
}

View File

@ -78,8 +78,9 @@ SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
// Convert a scalar value to the target type. The scalar value can be an element
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
should already be converted to builtin types.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
Type dtype);
Value convertScalarToDtype(
    OpBuilder &b, Location loc, Value scalar, Type dtype,
    llvm::Optional<Type> srcOriginalDtype = llvm::None);
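For illustration, a minimal sketch of how callers use the extended signature (not part of the patch; the builder `b`, location `loc`, i8 `scalar`, and the type variables are hypothetical placeholders, and the usual MLIR headers are assumed):

// Old-style call: srcOriginalDtype defaults to llvm::None, so existing call
// sites keep their behavior.
Value asF32 = convertScalarToDtype(b, loc, scalar, b.getF32Type());

// New-style call: the unconverted Torch dtype (ui8 for Byte) tells the helper
// that the signless i8 scalar holds an unsigned value, so it zero-extends
// (arith.extui) rather than sign-extends when widening to i64.
mlir::Type ui8Ty =
    mlir::IntegerType::get(b.getContext(), 8, mlir::IntegerType::Unsigned);
Value asI64 = convertScalarToDtype(b, loc, scalar, b.getI64Type(),
                                   /*srcOriginalDtype=*/ui8Ty);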
// Return the number of elements of a tensor if the shape is static; otherwise,
// return -1.

View File

@ -84,6 +84,8 @@ public:
Value input = adaptor.a();
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
int64_t inputRank = inputSizes.size();
Type inputDtype =
op.a().getType().template cast<BaseTensorType>().getDtype();
// The `input` tensor must contain exactly one element, i.e., either the
// `input` is a zero rank tensor or all the dimensions of the `input` tensor
@ -97,7 +99,11 @@ public:
Value constantZero =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
SmallVector<Value> indices(inputRank, constantZero);
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, input, indices);
Value result = rewriter.create<tensor::ExtractOp>(loc, input, indices);
Type resultType =
this->getTypeConverter()->convertType(op->getResult(0).getType());
rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result,
resultType, inputDtype));
return success();
}
};

View File

@ -229,14 +229,12 @@ SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
// Convert a scalar value to the target type. The scalar value can be an element
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
should already be converted to builtin types.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
Type dtype) {
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
llvm::Optional<Type> srcOriginalDtype) {
Type scalarType = scalar.getType();
if (scalarType == dtype)
return scalar;
// TODO: For the byte(ui8) or char(i8) case, we need the unconverted dtype to
// be able to know if we need signed or unsigned conversion.
auto isByteOrChar = [](Type type) {
if (auto integerTy = type.dyn_cast<mlir::IntegerType>()) {
return integerTy.getWidth() == 8;
@ -244,11 +242,13 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
return false;
};
if (isByteOrChar(scalarType) || isByteOrChar(dtype)) {
// TODO: Handle to-boolean conversion(from-boolean conversion is handled).
mlir::emitError(loc)
<< "unsupported byte, char or bool type for convertScalarToDtype "
<< scalarType << "(scalar type) -> " << dtype << "(dtype)";
// We only support conversion from a Byte or Char scalarType, not conversion
// to a Byte or Char dtype.
if (isByteOrChar(dtype)) {
mlir::emitError(loc) << "unsupported: conversion to byte or char type for "
"convertScalarToDtype "
<< scalarType << "(scalar type) -> " << dtype
<< "(dtype)";
return nullptr;
}
@ -278,7 +278,8 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
return b.create<arith::ExtFOp>(loc, dtype, scalar);
}
assert(scalarType.isa<mlir::IntegerType>());
if (scalarType.isSignlessInteger(1))
if (scalarType.isSignlessInteger(1) ||
(srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger()))
return b.create<arith::UIToFPOp>(loc, dtype, scalar);
// It's safe to use SIToFPOp because ui8/si8 are the only ones where
// unsigned handling is needed, and we checked for that case above.
@ -292,7 +293,8 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
auto scalarInteger = scalarType.cast<mlir::IntegerType>();
if (scalarInteger.getWidth() > dtypeInteger.getWidth())
return b.create<arith::TruncIOp>(loc, dtype, scalar);
if (scalarType.isSignlessInteger(1))
if (scalarType.isSignlessInteger(1) ||
(srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger()))
return b.create<arith::ExtUIOp>(loc, dtype, scalar);
// Only scalarInteger width < dtypeInteger width can reach here.
// It's safe to use ExtSIOp here because ui8/si8 are the only ones where
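In short, the patch adds one predicate to both the int-to-float and int-to-int paths. A standalone sketch of that decision (hypothetical helper name, illustration only):

// The scalar is always a signless builtin integer by the time it reaches
// convertScalarToDtype, so sign information must come from either the 1-bit
// width (bool) or the unconverted Torch dtype.
static bool useUnsignedConversion(mlir::Type scalarType,
                                  llvm::Optional<mlir::Type> srcOriginalDtype) {
  return scalarType.isSignlessInteger(1) ||
         (srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger());
}
// true  -> arith.uitofp / arith.extui (Byte, ui8): bit pattern 0xC8 becomes 200.
// false -> arith.sitofp / arith.extsi (Char, si8, and wider signless ints):
//          the same bit pattern 0xC8 becomes -56.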

View File

@ -57,6 +57,10 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
return torch_upstream::ScalarType::BFloat16;
if (type.isF16())
return torch_upstream::ScalarType::Half;
if (type.isUnsignedInteger(8))
return torch_upstream::ScalarType::Byte;
if (type.isSignedInteger(8))
return torch_upstream::ScalarType::Char;
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
}
@ -88,6 +92,9 @@ Type Torch::getTypeForScalarType(
return mlir::FloatType::getBF16(context);
case torch_upstream::ScalarType::Half:
return mlir::FloatType::getF16(context);
case torch_upstream::ScalarType::Byte:
case torch_upstream::ScalarType::Char:
return mlir::IntegerType::get(context, 8, signedness);
default:
return Type();
}
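Taken together with getScalarTypeForType above, the 8-bit mapping round-trips on signedness rather than on width alone. A small sketch (assuming an MLIRContext *ctx and the usual torch-mlir headers; not part of the patch):

// Byte and Char share the same 8-bit storage; only the signedness of the
// MLIR integer type tells them apart.
mlir::Type ui8 = mlir::IntegerType::get(ctx, 8, mlir::IntegerType::Unsigned);
mlir::Type si8 = mlir::IntegerType::get(ctx, 8, mlir::IntegerType::Signed);
assert(Torch::getScalarTypeForType(ui8) == torch_upstream::ScalarType::Byte);
assert(Torch::getScalarTypeForType(si8) == torch_upstream::ScalarType::Char);
// Going the other way, getTypeForScalarType returns an 8-bit IntegerType for
// both Byte and Char and relies on the caller-supplied signedness to choose
// between ui8 and si8.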

View File

@ -61,8 +61,13 @@ static bool isArgMemRefTypeValid(Type type) {
return true;
if (integerTy.isSignlessInteger(32))
return true;
if (integerTy.isSignlessInteger(8))
return true;
if (integerTy.isSignedInteger(8))
return true;
if (integerTy.isSignlessInteger(1))
return true;
}
}
return false;
@ -139,7 +144,7 @@ static LogicalResult mungeFunction(
auto type = arg.getType();
if (!isArgMemRefTypeValid(type))
return emitError(arg.getLoc(),
"argument must be a memref of f32, f64, i32, i64, i1");
"argument must be a memref of f32, f64, i32, i64, i8, i1");
auto cast = b.create<memref::CastOp>(arg.getLoc(), type, arg);
arg.replaceAllUsesExcept(cast, cast);
arg.setType(getAbiTypeForMemRef(type));

View File

@ -356,6 +356,12 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
case ScalarType::Half:
return mlirDenseElementsAttrFloat16Get(
shapedType, numElements, static_cast<const uint16_t *>(tensorData));
case ScalarType::Byte:
return mlirDenseElementsAttrUInt8Get(
shapedType, numElements, static_cast<const uint8_t *>(tensorData));
case ScalarType::Char:
return mlirDenseElementsAttrInt8Get(
shapedType, numElements, static_cast<const int8_t *>(tensorData));
default:
throwUnsupportedTensorError();

View File

@ -21,7 +21,7 @@ __all__ = [
def assert_arg_type_is_supported(ty):
    SUPPORTED = [np.float32, np.float64, np.int32, np.int64, np.bool_]
    SUPPORTED = [np.float32, np.float64, np.uint8, np.int8, np.int32, np.int64, np.bool_]
    assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"
@ -29,11 +29,13 @@ memref_type_to_np_dtype = {
    "mrf32": np.float32,
    "mrf64": np.float64,
    "mri1": np.bool_,
    "mri8": np.int8,
    "mri32": np.int32,
    "mri64": np.int64
}

elemental_type_to_ctype = {
    "i1": ctypes.c_bool,
    "i8": ctypes.c_byte,
    "i64": ctypes.c_int,
    "f32": ctypes.c_float,
    "f64": ctypes.c_double

View File

@ -311,3 +311,42 @@ class BoolIntConstantModule(torch.nn.Module):
@register_test_case(module_factory=lambda: BoolIntConstantModule())
def BoolIntConstantModule_basic(module, tu: TestUtils):
    module.forward()


# ==============================================================================

class AtenIntTensorByteDtypeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([], torch.uint8, True),
    ])
    def forward(self, val):
        return int(val)


@register_test_case(module_factory=lambda: AtenIntTensorByteDtypeModule())
def AtenIntTensorByteDtypeModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(low=-100, high=100).to(dtype=torch.uint8))


# ==============================================================================

class AtenIntTensorCharDtypeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([], torch.int8, True),
    ])
    def forward(self, val):
        return int(val)


@register_test_case(module_factory=lambda: AtenIntTensorCharDtypeModule())
def AtenIntTensorCharDtypeModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8))

View File

@ -95,6 +95,38 @@ func.func @torch.aten.Int.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],si64>
// -----
// CHECK-LABEL: func.func @torch.aten.Int.Tensor$zero_rank$byte_dtype
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],ui8>) -> !torch.int {
// CHECK: %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],ui8> -> tensor<i8>
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[I]][] : tensor<i8>
// CHECK: %[[RES:.*]] = arith.extui %[[EXTRACT]] : i8 to i64
// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[RES]]
// CHECK: return %[[RET]] : !torch.int
func.func @torch.aten.Int.Tensor$zero_rank$byte_dtype(%arg0: !torch.vtensor<[],ui8>) -> !torch.int {
%0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],ui8> -> !torch.int
return %0 : !torch.int
}
// -----
// CHECK-LABEL: func.func @torch.aten.Int.Tensor$zero_rank$char_dtype
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],si8>) -> !torch.int {
// CHECK: %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],si8> -> tensor<i8>
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[I]][] : tensor<i8>
// CHECK: %[[RES:.*]] = arith.extsi %[[EXTRACT]] : i8 to i64
// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[RES]]
// CHECK: return %[[RET]] : !torch.int
func.func @torch.aten.Int.Tensor$zero_rank$char_dtype(%arg0: !torch.vtensor<[],si8>) -> !torch.int {
%0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si8> -> !torch.int
return %0 : !torch.int
}
// -----
// CHECK-LABEL: func.func @torch.aten.Float.Tensor$zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],f64>) -> !torch.float {
// CHECK: %[[F:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f64> -> tensor<f64>

View File

@ -23,6 +23,8 @@ class TestModule(torch.nn.Module):
self.ones_bool = torch.ones(1, dtype=torch.bool)
self.ones_bf16 = torch.ones(1, dtype=torch.bfloat16)
self.ones_f16 = torch.ones(1, dtype=torch.half)
self.ones_ui8 = torch.ones(1, dtype=torch.uint8)
self.ones_i8 = torch.ones(1, dtype=torch.int8)
self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8)
self.ones_quint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.quint8)
self.arange = torch.nn.Parameter(torch.arange(3.0))
@ -36,6 +38,8 @@ class TestModule(torch.nn.Module):
# CHECK: %[[ONES_BOOL:.*]] = torch.tensor.literal(dense<true> : tensor<1xi1>) : !torch.tensor<[1],i1>
# CHECK: %[[ONES_BF16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xbf16>) : !torch.tensor<[1],bf16>
# CHECK: %[[ONES_F16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf16>) : !torch.tensor<[1],f16>
# CHECK: %[[ONES_UI8:.*]] = torch.tensor.literal(dense<1> : tensor<1xui8>) : !torch.tensor<[1],ui8>
# CHECK: %[[ONES_I8:.*]] = torch.tensor.literal(dense<1> : tensor<1xsi8>) : !torch.tensor<[1],si8>
# CHECK: %[[ONES_QINT8_DATA:.*]] = torch.tensor.literal(dense<1> : tensor<1xsi8>) : !torch.tensor<[1],si8>
# CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
# CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0
@ -52,6 +56,8 @@ class TestModule(torch.nn.Module):
# CHECK: torch.slot "ones_bool", %[[ONES_BOOL]] : !torch.tensor<[1],i1>
# CHECK: torch.slot "ones_bf16", %[[ONES_BF16]] : !torch.tensor<[1],bf16>
# CHECK: torch.slot "ones_f16", %[[ONES_F16]] : !torch.tensor<[1],f16>
# CHECK: torch.slot "ones_ui8", %[[ONES_UI8]] : !torch.tensor<[1],ui8>
# CHECK: torch.slot "ones_i8", %[[ONES_I8]] : !torch.tensor<[1],si8>
# CHECK: torch.slot "ones_qint8", %[[ONES_QINT8]] : !torch.tensor<[1],!torch.qint8>
# CHECK: torch.slot "ones_quint8", %[[ONES_QUINT8]] : !torch.tensor<[1],!torch.quint8>
# CHECK: }