Mirror of https://github.com/llvm/torch-mlir
commit 940959589b (parent 0765449684)
@@ -623,4 +623,6 @@ LTC_XFAIL_SET = {
     "ElementwiseRemainderScalarModule_Float_basic",
     "ElementwiseRemainderScalarModule_Int_basic",
     "ElementwiseRemainderScalarModule_Bool_basic",
+    "AtenIntTensorByteDtypeModule_basic",
+    "AtenIntTensorCharDtypeModule_basic",
 }
@@ -78,8 +78,9 @@ SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
 // Convert a scalar value to the target type. The scalar value can be an element
 // from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
 // should be converted builtin types.
-Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
-                           Type dtype);
+Value convertScalarToDtype(
+    OpBuilder &b, Location loc, Value scalar, Type dtype,
+    llvm::Optional<Type> srcOriginalDtype = llvm::NoneType());

 // Return the number of elements of a tensor if the shape is static; otherwise,
 // return -1.
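Note on the new `srcOriginalDtype` parameter: after type conversion, both the Torch-level `ui8` (byte) and `si8` (char) become the signless builtin `i8`, so the converted value alone no longer says whether widening should zero-extend or sign-extend. The optional original dtype restores that information. A minimal NumPy sketch of the ambiguity (illustrative only, not torch-mlir API):

```python
import numpy as np

raw = np.array([0xC8], dtype=np.uint8)  # one 8-bit pattern: 0b11001000

# Unsigned interpretation widens by zero-extension (what arith.extui does).
as_byte = raw.astype(np.int64)[0]                # 200

# Signed interpretation widens by sign-extension (what arith.extsi does).
as_char = raw.view(np.int8).astype(np.int64)[0]  # -56

print(as_byte, as_char)  # 200 -56
```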
@@ -84,6 +84,8 @@ public:
     Value input = adaptor.a();
     SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
     int64_t inputRank = inputSizes.size();
+    Type inputDtype =
+        op.a().getType().template cast<BaseTensorType>().getDtype();

     // The `input` tensor must contain exactly one element, i.e., either the
     // `input` is a zero rank tensor or all the dimensions of the `input` tensor
@@ -97,7 +99,11 @@ public:
     Value constantZero =
         rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
     SmallVector<Value> indices(inputRank, constantZero);
-    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, input, indices);
+    Value result = rewriter.create<tensor::ExtractOp>(loc, input, indices);
+    Type resultType =
+        this->getTypeConverter()->convertType(op->getResult(0).getType());
+    rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result,
+                                                resultType, inputDtype));
     return success();
   }
 };
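The rewritten lowering extracts the single element by indexing with `inputRank` zeros and only then converts it to the op's result type, instead of replacing the op with the raw `tensor.extract`. A rough NumPy analogue of just the extraction step (the helper name is hypothetical):

```python
import numpy as np

def extract_single_element(t: np.ndarray):
    """Index a one-element tensor of any rank with rank-many zeros,
    mirroring the all-zero-index tensor.extract built above."""
    assert t.size == 1, "aten.Int.Tensor expects exactly one element"
    indices = (0,) * t.ndim  # SmallVector<Value> indices(inputRank, constantZero)
    return t[indices]

print(extract_single_element(np.array(7)))            # rank 0 -> 7
print(extract_single_element(np.full((1, 1, 1), 7)))  # rank 3 -> 7
```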
@@ -229,14 +229,12 @@ SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
 // Convert a scalar value to the target type. The scalar value can be an element
 // from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
 // should be converted builtin types.
-Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
-                           Type dtype) {
+Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
+                           llvm::Optional<Type> srcOriginalDtype) {
   Type scalarType = scalar.getType();
   if (scalarType == dtype)
     return scalar;

-  // TODO: For the byte(ui8) or char(i8) case, we need the unconverted dtype to
-  // be able to know if we need signed or unsigned conversion.
   auto isByteOrChar = [](Type type) {
     if (auto integerTy = type.dyn_cast<mlir::IntegerType>()) {
       return integerTy.getWidth() == 8;
@@ -244,11 +242,13 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
     return false;
   };

-  if (isByteOrChar(scalarType) || isByteOrChar(dtype)) {
-    // TODO: Handle to-boolean conversion(from-boolean conversion is handled).
-    mlir::emitError(loc)
-        << "unsupported byte, char or bool type for convertScalarToDtype "
-        << scalarType << "(scalar type) -> " << dtype << "(dtype)";
+  // We only support conversion from Byte or Char scalarType not to Byte or Char
+  // dtype.
+  if (isByteOrChar(dtype)) {
+    mlir::emitError(loc) << "unsupported: conversion to byte or char type for "
+                            "convertScalarToDtype "
+                         << scalarType << "(scalar type) -> " << dtype
+                         << "(dtype)";
     return nullptr;
   }

@@ -278,7 +278,8 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
       return b.create<arith::ExtFOp>(loc, dtype, scalar);
     }
     assert(scalarType.isa<mlir::IntegerType>());
-    if (scalarType.isSignlessInteger(1))
+    if (scalarType.isSignlessInteger(1) ||
+        (srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger()))
       return b.create<arith::UIToFPOp>(loc, dtype, scalar);
     // It's safe to use SIToFPOp because ui8/si8 are the only ones where
     // unsigned handling is needed, and we checked for that case above.
@@ -292,7 +293,8 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
     auto scalarInteger = scalarType.cast<mlir::IntegerType>();
     if (scalarInteger.getWidth() > dtypeInteger.getWidth())
       return b.create<arith::TruncIOp>(loc, dtype, scalar);
-    if (scalarType.isSignlessInteger(1))
+    if (scalarType.isSignlessInteger(1) ||
+        (srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger()))
       return b.create<arith::ExtUIOp>(loc, dtype, scalar);
     // Only scalarInteger width < dtypeInteger width can reach here.
     // It's safe to use ExtSIOp here because ui8/si8 are the only ones where
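Taken together, the two hunks above give the conversion this shape: a source that is `i1` or was originally unsigned goes through `arith.uitofp` / `arith.extui`, anything else goes through the signed variants, and a narrowing integer conversion always truncates. A compact Python restatement of the integer branch, treating the source as a raw two's-complement bit pattern (a sketch, not the torch-mlir API):

```python
def convert_int_scalar(bits: int, src_width: int, dst_width: int,
                       src_originally_unsigned: bool = False) -> int:
    """Mirror of the integer-to-integer branch of convertScalarToDtype."""
    bits &= (1 << src_width) - 1
    if src_width > dst_width:                      # arith.trunci
        return bits & ((1 << dst_width) - 1)       # keep only the low bits
    if src_width == 1 or src_originally_unsigned:  # arith.extui
        return bits                                # zero-extend
    sign_bit = 1 << (src_width - 1)                # arith.extsi
    return bits - (1 << src_width) if bits & sign_bit else bits

print(convert_int_scalar(0xC8, 8, 64, src_originally_unsigned=True))   # 200 (byte)
print(convert_int_scalar(0xC8, 8, 64, src_originally_unsigned=False))  # -56 (char)
print(convert_int_scalar(0x1F3, 16, 8))                                # 243 (0xF3)
```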
@@ -57,6 +57,10 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
     return torch_upstream::ScalarType::BFloat16;
   if (type.isF16())
     return torch_upstream::ScalarType::Half;
+  if (type.isUnsignedInteger(8))
+    return torch_upstream::ScalarType::Byte;
+  if (type.isSignedInteger(8))
+    return torch_upstream::ScalarType::Char;
   llvm::report_fatal_error("unhandled type for getScalarTypeForType");
 }

@@ -88,6 +92,9 @@ Type Torch::getTypeForScalarType(
     return mlir::FloatType::getBF16(context);
   case torch_upstream::ScalarType::Half:
     return mlir::FloatType::getF16(context);
+  case torch_upstream::ScalarType::Byte:
+  case torch_upstream::ScalarType::Char:
+    return mlir::IntegerType::get(context, 8, signedness);
   default:
     return Type();
   }
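With the new cases, `getScalarTypeForType` and `getTypeForScalarType` stay inverse mappings: torch's Byte corresponds to the unsigned 8-bit type and Char to the signed one. A small round-trip check of that correspondence in plain Python (the dictionaries below are illustrative, not the torch-mlir API):

```python
import torch

# torch dtype <-> (bit width, signed?) -- the 8-bit rows are the new entries.
torch_to_builtin = {
    torch.uint8: (8, False),   # ScalarType::Byte -> ui8
    torch.int8:  (8, True),    # ScalarType::Char -> si8
    torch.int32: (32, True),
    torch.int64: (64, True),
}
builtin_to_torch = {v: k for k, v in torch_to_builtin.items()}

for dtype, key in torch_to_builtin.items():
    assert builtin_to_torch[key] == dtype  # the two helpers invert each other
print("round trip ok")
```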
@@ -61,8 +61,13 @@ static bool isArgMemRefTypeValid(Type type) {
         return true;
       if (integerTy.isSignlessInteger(32))
         return true;
+      if (integerTy.isSignlessInteger(8))
+        return true;
+      if (integerTy.isSignedInteger(8))
+        return true;
       if (integerTy.isSignlessInteger(1))
         return true;

     }
   }
   return false;
@@ -139,7 +144,7 @@ static LogicalResult mungeFunction(
     auto type = arg.getType();
     if (!isArgMemRefTypeValid(type))
       return emitError(arg.getLoc(),
-                       "argument must be a memref of f32, f64, i32, i64, i1");
+                       "argument must be a memref of f32, f64, i32, i64, i8, i1");
     auto cast = b.create<memref::CastOp>(arg.getLoc(), type, arg);
     arg.replaceAllUsesExcept(cast, cast);
     arg.setType(getAbiTypeForMemRef(type));
@@ -356,6 +356,12 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
   case ScalarType::Half:
     return mlirDenseElementsAttrFloat16Get(
         shapedType, numElements, static_cast<const uint16_t *>(tensorData));
+  case ScalarType::Byte:
+    return mlirDenseElementsAttrUInt8Get(
+        shapedType, numElements, static_cast<const uint8_t *>(tensorData));
+  case ScalarType::Char:
+    return mlirDenseElementsAttrInt8Get(
+        shapedType, numElements, static_cast<const int8_t *>(tensorData));

   default:
     throwUnsupportedTensorError();
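The importer hands the tensor's raw storage to the dense-elements builders, and the `Byte`/`Char` cases only differ in whether those bytes are read as `uint8_t` or `int8_t`. The same buffer viewed both ways, in torch/NumPy (illustrative of the data layout, not of the importer API):

```python
import numpy as np
import torch

t_byte = torch.tensor([1, 255], dtype=torch.uint8)
t_char = torch.tensor([1, -1], dtype=torch.int8)

buf_byte = t_byte.numpy().tobytes()  # raw storage handed to the attr builder
buf_char = t_char.numpy().tobytes()
print(buf_byte, buf_char)            # b'\x01\xff' b'\x01\xff' -- identical bytes

# The ScalarType case decides the interpretation of those bytes.
print(np.frombuffer(buf_byte, dtype=np.uint8))  # [  1 255]
print(np.frombuffer(buf_byte, dtype=np.int8))   # [ 1 -1]
```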
@@ -21,7 +21,7 @@ __all__ = [


 def assert_arg_type_is_supported(ty):
-    SUPPORTED = [np.float32, np.float64, np.int32, np.int64, np.bool_]
+    SUPPORTED = [np.float32, np.float64, np.uint8, np.int8, np.int32, np.int64, np.bool_]
     assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"


@@ -29,11 +29,13 @@ memref_type_to_np_dtype = {
     "mrf32": np.float32,
     "mrf64": np.float64,
     "mri1": np.bool_,
+    "mri8": np.int8,
     "mri32": np.int32,
     "mri64": np.int64
 }
 elemental_type_to_ctype = {
     "i1": ctypes.c_bool,
+    "i8": ctypes.c_byte,
     "i64": ctypes.c_int,
     "f32": ctypes.c_float,
     "f64": ctypes.c_double
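On the return path the refbackend looks up a ctypes element type by name, so an `i8` result travels through `ctypes.c_byte`, the signed 8-bit C type matching the `"mri8" -> np.int8` entry above. A quick check of that assumption:

```python
import ctypes
import numpy as np

print(ctypes.sizeof(ctypes.c_byte))        # 1
print(ctypes.c_byte(0xC8).value)           # -56: the pattern is read back signed
print(np.int8(ctypes.c_byte(0xC8).value))  # -56, same as the np.int8 mapping
```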
@@ -311,3 +311,42 @@ class BoolIntConstantModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: BoolIntConstantModule())
 def BoolIntConstantModule_basic(module, tu: TestUtils):
     module.forward()
+
+# ==============================================================================
+
+class AtenIntTensorByteDtypeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.uint8, True),
+    ])
+
+    def forward(self, val):
+        return int(val)
+
+@register_test_case(module_factory=lambda: AtenIntTensorByteDtypeModule())
+def AtenIntTensorByteDtypeModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(low=-100, high=100).to(dtype=torch.uint8))
+
+
+# ==============================================================================
+
+class AtenIntTensorCharDtypeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.int8, True),
+    ])
+
+    def forward(self, val):
+        return int(val)
+
+@register_test_case(module_factory=lambda: AtenIntTensorCharDtypeModule())
+def AtenIntTensorCharDtypeModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8))
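The two new e2e modules cast a random value in [-100, 100) to uint8 / int8 and then call `int()` on the zero-rank tensor, so the reference behaviour the lowering must reproduce is PyTorch's own reading of the stored byte: unsigned for byte, signed for char. A quick eager-mode illustration of what the tests compare against:

```python
import torch

val = torch.tensor(-56)              # stand-in for tu.randint(low=-100, high=100)

as_byte = val.to(dtype=torch.uint8)  # stores the bit pattern 0xC8
as_char = val.to(dtype=torch.int8)

print(int(as_byte))  # 200 -> the lowering must zero-extend (arith.extui)
print(int(as_char))  # -56 -> the lowering must sign-extend (arith.extsi)
```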
@@ -95,6 +95,38 @@ func.func @torch.aten.Int.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],si64>

 // -----

+// CHECK-LABEL: func.func @torch.aten.Int.Tensor$zero_rank$byte_dtype
+// CHECK-SAME:  (%[[ARG:.*]]: !torch.vtensor<[],ui8>) -> !torch.int {
+// CHECK:       %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],ui8> -> tensor<i8>
+// CHECK:       %[[C1_I64:.*]] = arith.constant 1 : i64
+// CHECK:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK:       %[[EXTRACT:.*]] = tensor.extract %[[I]][] : tensor<i8>
+// CHECK:       %[[RES:.*]] = arith.extui %[[EXTRACT]] : i8 to i64
+// CHECK:       %[[RET:.*]] = torch_c.from_i64 %[[RES]]
+// CHECK:       return %[[RET]] : !torch.int
+func.func @torch.aten.Int.Tensor$zero_rank$byte_dtype(%arg0: !torch.vtensor<[],ui8>) -> !torch.int {
+  %0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],ui8> -> !torch.int
+  return %0 : !torch.int
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.Int.Tensor$zero_rank$char_dtype
+// CHECK-SAME:  (%[[ARG:.*]]: !torch.vtensor<[],si8>) -> !torch.int {
+// CHECK:       %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],si8> -> tensor<i8>
+// CHECK:       %[[C1_I64:.*]] = arith.constant 1 : i64
+// CHECK:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK:       %[[EXTRACT:.*]] = tensor.extract %[[I]][] : tensor<i8>
+// CHECK:       %[[RES:.*]] = arith.extsi %[[EXTRACT]] : i8 to i64
+// CHECK:       %[[RET:.*]] = torch_c.from_i64 %[[RES]]
+// CHECK:       return %[[RET]] : !torch.int
+func.func @torch.aten.Int.Tensor$zero_rank$char_dtype(%arg0: !torch.vtensor<[],si8>) -> !torch.int {
+  %0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si8> -> !torch.int
+  return %0 : !torch.int
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.Float.Tensor$zero_rank
 // CHECK-SAME:  (%[[ARG:.*]]: !torch.vtensor<[],f64>) -> !torch.float {
 // CHECK:       %[[F:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f64> -> tensor<f64>
@@ -23,6 +23,8 @@ class TestModule(torch.nn.Module):
         self.ones_bool = torch.ones(1, dtype=torch.bool)
         self.ones_bf16 = torch.ones(1, dtype=torch.bfloat16)
         self.ones_f16 = torch.ones(1, dtype=torch.half)
+        self.ones_ui8 = torch.ones(1, dtype=torch.uint8)
+        self.ones_i8 = torch.ones(1, dtype=torch.int8)
         self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8)
         self.ones_quint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.quint8)
         self.arange = torch.nn.Parameter(torch.arange(3.0))
@@ -36,6 +38,8 @@ class TestModule(torch.nn.Module):
     # CHECK: %[[ONES_BOOL:.*]] = torch.tensor.literal(dense<true> : tensor<1xi1>) : !torch.tensor<[1],i1>
     # CHECK: %[[ONES_BF16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xbf16>) : !torch.tensor<[1],bf16>
     # CHECK: %[[ONES_F16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf16>) : !torch.tensor<[1],f16>
+    # CHECK: %[[ONES_UI8:.*]] = torch.tensor.literal(dense<1> : tensor<1xui8>) : !torch.tensor<[1],ui8>
+    # CHECK: %[[ONES_I8:.*]] = torch.tensor.literal(dense<1> : tensor<1xsi8>) : !torch.tensor<[1],si8>
     # CHECK: %[[ONES_QINT8_DATA:.*]] = torch.tensor.literal(dense<1> : tensor<1xsi8>) : !torch.tensor<[1],si8>
     # CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
     # CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0
@@ -52,6 +56,8 @@ class TestModule(torch.nn.Module):
     # CHECK: torch.slot "ones_bool", %[[ONES_BOOL]] : !torch.tensor<[1],i1>
     # CHECK: torch.slot "ones_bf16", %[[ONES_BF16]] : !torch.tensor<[1],bf16>
     # CHECK: torch.slot "ones_f16", %[[ONES_F16]] : !torch.tensor<[1],f16>
+    # CHECK: torch.slot "ones_ui8", %[[ONES_UI8]] : !torch.tensor<[1],ui8>
+    # CHECK: torch.slot "ones_i8", %[[ONES_I8]] : !torch.tensor<[1],si8>
     # CHECK: torch.slot "ones_qint8", %[[ONES_QINT8]] : !torch.tensor<[1],!torch.qint8>
     # CHECK: torch.slot "ones_quint8", %[[ONES_QUINT8]] : !torch.tensor<[1],!torch.quint8>
     # CHECK: }