From 940959589b186369925a9d53a8d0f6d39fdafac6 Mon Sep 17 00:00:00 2001
From: AmosLewis
Date: Mon, 19 Sep 2022 11:50:51 -0700
Subject: [PATCH] [MLIR][TORCH] Add Byte and Char Dtype support

---
 e2e_testing/xfail_sets.py                     |  2 +
 include/torch-mlir/Conversion/Utils/Utils.h   |  5 ++-
 .../TorchToLinalg/TensorScalarInterop.cpp     |  8 +++-
 lib/Conversion/Utils/Utils.cpp                | 24 ++++++------
 lib/Dialect/Torch/Utils/Utils.cpp             |  7 ++++
 lib/RefBackend/RefBackend.cpp                 |  7 +++-
 .../jit_ir/csrc/torch_to_mlir_utils.cpp       |  6 +++
 .../linalg_on_tensors_backends/refbackend.py  |  4 +-
 .../torch_mlir_e2e_test/test_suite/scalar.py  | 39 +++++++++++++++++++
 test/Conversion/TorchToLinalg/basic.mlir      | 32 +++++++++++++++
 .../importer/jit_ir/ivalue_import/tensors.py  |  6 +++
 11 files changed, 124 insertions(+), 16 deletions(-)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 5bc5bd086..eee0c0118 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -623,4 +623,6 @@ LTC_XFAIL_SET = {
     "ElementwiseRemainderScalarModule_Float_basic",
     "ElementwiseRemainderScalarModule_Int_basic",
     "ElementwiseRemainderScalarModule_Bool_basic",
+    "AtenIntTensorByteDtypeModule_basic",
+    "AtenIntTensorCharDtypeModule_basic",
 }
diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h
index 81b3b5484..bf944ee9c 100644
--- a/include/torch-mlir/Conversion/Utils/Utils.h
+++ b/include/torch-mlir/Conversion/Utils/Utils.h
@@ -78,8 +78,9 @@ SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
 // Convert a scalar value to the target type. The scalar value can be an element
 // from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
 // should be converted builtin types.
-Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
-                           Type dtype);
+Value convertScalarToDtype(
+    OpBuilder &b, Location loc, Value scalar, Type dtype,
+    llvm::Optional<Type> srcOriginalDtype = llvm::NoneType());
 
 // Return the number of elements of a tensor if the shape is static; otherwise,
 // return -1.
diff --git a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp
index f4d6b1c26..24b487c5b 100644
--- a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp
+++ b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp
@@ -84,6 +84,8 @@ public:
     Value input = adaptor.a();
     SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
     int64_t inputRank = inputSizes.size();
+    Type inputDtype =
+        op.a().getType().template cast<BaseTensorType>().getDtype();
 
     // The `input` tensor must contain exactly one element, i.e., either the
     // `input` is a zero rank tensor or all the dimensions of the `input` tensor
@@ -97,7 +99,11 @@ public:
     Value constantZero =
         rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
     SmallVector<Value> indices(inputRank, constantZero);
-    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, input, indices);
+    Value result = rewriter.create<tensor::ExtractOp>(loc, input, indices);
+    Type resultType =
+        this->getTypeConverter()->convertType(op->getResult(0).getType());
+    rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result,
+                                                resultType, inputDtype));
     return success();
   }
 };
diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp
index d8382423e..a0bf7bb67 100644
--- a/lib/Conversion/Utils/Utils.cpp
+++ b/lib/Conversion/Utils/Utils.cpp
@@ -229,14 +229,12 @@ SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
 // Convert a scalar value to the target type. The scalar value can be an element
 // from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
 // should be converted builtin types.
-Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
-                           Type dtype) {
+Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
+                           llvm::Optional<Type> srcOriginalDtype) {
   Type scalarType = scalar.getType();
   if (scalarType == dtype)
     return scalar;
 
-  // TODO: For the byte(ui8) or char(i8) case, we need the unconverted dtype to
-  // be able to know if we need signed or unsigned conversion.
   auto isByteOrChar = [](Type type) {
     if (auto integerTy = type.dyn_cast<mlir::IntegerType>()) {
       return integerTy.getWidth() == 8;
@@ -244,11 +242,13 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
     return false;
   };
 
-  if (isByteOrChar(scalarType) || isByteOrChar(dtype)) {
-    // TODO: Handle to-boolean conversion(from-boolean conversion is handled).
-    mlir::emitError(loc)
-        << "unsupported byte, char or bool type for convertScalarToDtype "
-        << scalarType << "(scalar type) -> " << dtype << "(dtype)";
+  // We only support conversion from Byte or Char scalar types, not conversion
+  // to Byte or Char dtypes.
+  if (isByteOrChar(dtype)) {
+    mlir::emitError(loc) << "unsupported: conversion to byte or char type for "
+                            "convertScalarToDtype "
+                         << scalarType << "(scalar type) -> " << dtype
+                         << "(dtype)";
     return nullptr;
   }
 
@@ -278,7 +278,8 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
       return b.create<arith::ExtFOp>(loc, dtype, scalar);
     }
     assert(scalarType.isa<mlir::IntegerType>());
-    if (scalarType.isSignlessInteger(1))
+    if (scalarType.isSignlessInteger(1) ||
+        (srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger()))
       return b.create<arith::UIToFPOp>(loc, dtype, scalar);
     // It's safe to use SIToFPOp because ui8/si8 are the only ones where
     // unsigned handling is needed, and we checked for that case above.
@@ -292,7 +293,8 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
     auto scalarInteger = scalarType.cast<mlir::IntegerType>();
     if (scalarInteger.getWidth() > dtypeInteger.getWidth())
       return b.create<arith::TruncIOp>(loc, dtype, scalar);
-    if (scalarType.isSignlessInteger(1))
+    if (scalarType.isSignlessInteger(1) ||
+        (srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger()))
       return b.create<arith::ExtUIOp>(loc, dtype, scalar);
     // Only scalarInteger width < dtypeInteger width can reach here.
     // It's safe to use ExtSIOp here because ui8/si8 are the only ones where
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index d9d375618..1ff3b1608 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -57,6 +57,10 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
     return torch_upstream::ScalarType::BFloat16;
   if (type.isF16())
     return torch_upstream::ScalarType::Half;
+  if (type.isUnsignedInteger(8))
+    return torch_upstream::ScalarType::Byte;
+  if (type.isSignedInteger(8))
+    return torch_upstream::ScalarType::Char;
   llvm::report_fatal_error("unhandled type for getScalarTypeForType");
 }
 
@@ -88,6 +92,9 @@ Type Torch::getTypeForScalarType(
     return mlir::FloatType::getBF16(context);
   case torch_upstream::ScalarType::Half:
     return mlir::FloatType::getF16(context);
+  case torch_upstream::ScalarType::Byte:
+  case torch_upstream::ScalarType::Char:
+    return mlir::IntegerType::get(context, 8, signedness);
   default:
     return Type();
   }
diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp
index 313aeab72..cde11a481 100644
--- a/lib/RefBackend/RefBackend.cpp
+++ b/lib/RefBackend/RefBackend.cpp
@@ -61,8 +61,13 @@ static bool isArgMemRefTypeValid(Type type) {
         return true;
       if (integerTy.isSignlessInteger(32))
         return true;
+      if (integerTy.isSignlessInteger(8))
+        return true;
+      if (integerTy.isSignedInteger(8))
+        return true;
       if (integerTy.isSignlessInteger(1))
         return true;
+
     }
   }
   return false;
@@ -139,7 +144,7 @@ static LogicalResult mungeFunction(
     auto type = arg.getType();
     if (!isArgMemRefTypeValid(type))
       return emitError(arg.getLoc(),
-                       "argument must be a memref of f32, f64, i32, i64, i1");
+                       "argument must be a memref of f32, f64, i32, i64, i8, i1");
     auto cast = b.create<memref::CastOp>(arg.getLoc(), type, arg);
     arg.replaceAllUsesExcept(cast, cast);
     arg.setType(getAbiTypeForMemRef(type));
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
index 471e2d570..ab3b638b1 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
@@ -356,6 +356,12 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
   case ScalarType::Half:
     return mlirDenseElementsAttrFloat16Get(
         shapedType, numElements, static_cast<const uint16_t *>(tensorData));
+  case ScalarType::Byte:
+    return mlirDenseElementsAttrUInt8Get(
+        shapedType, numElements, static_cast<const uint8_t *>(tensorData));
+  case ScalarType::Char:
+    return mlirDenseElementsAttrInt8Get(
+        shapedType, numElements, static_cast<const int8_t *>(tensorData));
 
   default:
     throwUnsupportedTensorError();
diff --git a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
index 09ab07c4d..ab72673fc 100644
--- a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
+++ b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
@@ -21,7 +21,7 @@ __all__ = [
 
 
 def assert_arg_type_is_supported(ty):
-    SUPPORTED = [np.float32, np.float64, np.int32, np.int64, np.bool_]
+    SUPPORTED = [np.float32, np.float64, np.uint8, np.int8, np.int32, np.int64, np.bool_]
     assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"
 
 
@@ -29,11 +29,13 @@ memref_type_to_np_dtype = {
     "mrf32": np.float32,
     "mrf64": np.float64,
     "mri1": np.bool_,
+    "mri8": np.int8,
     "mri32": np.int32,
     "mri64": np.int64
 }
 elemental_type_to_ctype = {
     "i1": ctypes.c_bool,
+    "i8": ctypes.c_byte,
     "i64": ctypes.c_int,
     "f32": ctypes.c_float,
     "f64": ctypes.c_double
diff --git a/python/torch_mlir_e2e_test/test_suite/scalar.py b/python/torch_mlir_e2e_test/test_suite/scalar.py
index cbf17a09a..16ce64bb0 100644
--- a/python/torch_mlir_e2e_test/test_suite/scalar.py
+++ b/python/torch_mlir_e2e_test/test_suite/scalar.py
@@ -311,3 +311,42 @@ class BoolIntConstantModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: BoolIntConstantModule())
 def BoolIntConstantModule_basic(module, tu: TestUtils):
     module.forward()
+
+# ==============================================================================
+
+class AtenIntTensorByteDtypeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.uint8, True),
+    ])
+
+    def forward(self, val):
+        return int(val)
+
+@register_test_case(module_factory=lambda: AtenIntTensorByteDtypeModule())
+def AtenIntTensorByteDtypeModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(low=-100, high=100).to(dtype=torch.uint8))
+
+
+# ==============================================================================
+
+class AtenIntTensorCharDtypeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.int8, True),
+    ])
+
+    def forward(self, val):
+        return int(val)
+
+@register_test_case(module_factory=lambda: AtenIntTensorCharDtypeModule())
+def AtenIntTensorCharDtypeModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8))
diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir
index fdb6742b1..2137d7d96 100644
--- a/test/Conversion/TorchToLinalg/basic.mlir
+++ b/test/Conversion/TorchToLinalg/basic.mlir
@@ -95,6 +95,38 @@ func.func @torch.aten.Int.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],si64>
 
 // -----
 
+// CHECK-LABEL:   func.func @torch.aten.Int.Tensor$zero_rank$byte_dtype
+// CHECK-SAME:      (%[[ARG:.*]]: !torch.vtensor<[],ui8>) -> !torch.int {
+// CHECK:           %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],ui8> -> tensor<i8>
+// CHECK:           %[[C1_I64:.*]] = arith.constant 1 : i64
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract %[[I]][] : tensor<i8>
+// CHECK:           %[[RES:.*]] = arith.extui %[[EXTRACT]] : i8 to i64
+// CHECK:           %[[RET:.*]] = torch_c.from_i64 %[[RES]]
+// CHECK:           return %[[RET]] : !torch.int
+func.func @torch.aten.Int.Tensor$zero_rank$byte_dtype(%arg0: !torch.vtensor<[],ui8>) -> !torch.int {
+  %0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],ui8> -> !torch.int
+  return %0 : !torch.int
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.Int.Tensor$zero_rank$char_dtype
+// CHECK-SAME:      (%[[ARG:.*]]: !torch.vtensor<[],si8>) -> !torch.int {
+// CHECK:           %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],si8> -> tensor<i8>
+// CHECK:           %[[C1_I64:.*]] = arith.constant 1 : i64
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract %[[I]][] : tensor<i8>
+// CHECK:           %[[RES:.*]] = arith.extsi %[[EXTRACT]] : i8 to i64
+// CHECK:           %[[RET:.*]] = torch_c.from_i64 %[[RES]]
+// CHECK:           return %[[RET]] : !torch.int
+func.func @torch.aten.Int.Tensor$zero_rank$char_dtype(%arg0: !torch.vtensor<[],si8>) -> !torch.int {
+  %0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si8> -> !torch.int
+  return %0 : !torch.int
+}
+
+// -----
+
 // CHECK-LABEL:   func.func @torch.aten.Float.Tensor$zero_rank
 // CHECK-SAME:      (%[[ARG:.*]]: !torch.vtensor<[],f64>) -> !torch.float {
 // CHECK:           %[[F:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f64> -> tensor<f64>
diff --git a/test/python/importer/jit_ir/ivalue_import/tensors.py b/test/python/importer/jit_ir/ivalue_import/tensors.py
index 314c245f5..1c0612bf7 100644
--- a/test/python/importer/jit_ir/ivalue_import/tensors.py
+++ b/test/python/importer/jit_ir/ivalue_import/tensors.py
@@ -23,6 +23,8 @@ class TestModule(torch.nn.Module):
         self.ones_bool = torch.ones(1, dtype=torch.bool)
         self.ones_bf16 = torch.ones(1, dtype=torch.bfloat16)
         self.ones_f16 = torch.ones(1, dtype=torch.half)
+        self.ones_ui8 = torch.ones(1, dtype=torch.uint8)
+        self.ones_i8 = torch.ones(1, dtype=torch.int8)
         self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8)
         self.ones_quint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.quint8)
         self.arange = torch.nn.Parameter(torch.arange(3.0))
@@ -36,6 +38,8 @@ class TestModule(torch.nn.Module):
 # CHECK: %[[ONES_BOOL:.*]] = torch.tensor.literal(dense<true> : tensor<1xi1>) : !torch.tensor<[1],i1>
 # CHECK: %[[ONES_BF16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xbf16>) : !torch.tensor<[1],bf16>
 # CHECK: %[[ONES_F16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf16>) : !torch.tensor<[1],f16>
+# CHECK: %[[ONES_UI8:.*]] = torch.tensor.literal(dense<1> : tensor<1xui8>) : !torch.tensor<[1],ui8>
+# CHECK: %[[ONES_I8:.*]] = torch.tensor.literal(dense<1> : tensor<1xsi8>) : !torch.tensor<[1],si8>
 # CHECK: %[[ONES_QINT8_DATA:.*]] = torch.tensor.literal(dense<1> : tensor<1xsi8>) : !torch.tensor<[1],si8>
 # CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
 # CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0
@@ -52,6 +56,8 @@ class TestModule(torch.nn.Module):
 # CHECK: torch.slot "ones_bool", %[[ONES_BOOL]] : !torch.tensor<[1],i1>
 # CHECK: torch.slot "ones_bf16", %[[ONES_BF16]] : !torch.tensor<[1],bf16>
 # CHECK: torch.slot "ones_f16", %[[ONES_F16]] : !torch.tensor<[1],f16>
+# CHECK: torch.slot "ones_ui8", %[[ONES_UI8]] : !torch.tensor<[1],ui8>
+# CHECK: torch.slot "ones_i8", %[[ONES_I8]] : !torch.tensor<[1],si8>
 # CHECK: torch.slot "ones_qint8", %[[ONES_QINT8]] : !torch.tensor<[1],!torch.qint8>
 # CHECK: torch.slot "ones_quint8", %[[ONES_QUINT8]] : !torch.tensor<[1],!torch.quint8>
 # CHECK: }
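
Note on the conversion rule exercised above: an i8 element that originally carried torch ui8 is widened with zero-extension (arith.extui / arith.uitofp), while one that carried si8 is widened with sign-extension (arith.extsi / arith.sitofp); that is why convertScalarToDtype now threads the original dtype through the optional srcOriginalDtype argument. A minimal sketch of the intended end-to-end semantics, written against plain PyTorch as the reference rather than the patched lowering (the concrete values below are illustrative and not taken from the patch):

import torch

# ui8 element whose high bit is set: zero-extension must keep it positive.
t = torch.tensor(200, dtype=torch.uint8)
assert int(t) == 200   # byte dtype -> arith.extui in the lowering above

# si8 element: sign-extension must preserve the negative value.
s = torch.tensor(-56, dtype=torch.int8)
assert int(s) == -56   # char dtype -> arith.extsi in the lowering above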