From dca2b8a40a49cee07ab72ff42e8c1c13b4da91c0 Mon Sep 17 00:00:00 2001
From: Ziheng Jiang
Date: Thu, 9 Mar 2023 16:17:35 -0800
Subject: [PATCH] [TORCH] Improve type refinement for aten.cat. (#1908)

* [TORCH] Fix type refinement for aten.cat.

* Add test.

* Address comments.

* Update.

* Update.

* Update.

* Update.

* Update.

---------

Co-authored-by: Ziheng Jiang
---
 e2e_testing/xfail_sets.py                     |  2 +
 lib/Conversion/TorchToLinalg/DataMovement.cpp | 12 ++++
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  | 37 ++++---------
 .../torch_mlir_e2e_test/test_suite/basic.py   | 26 +++++++++++++
 test/Dialect/Torch/refine-types-ops.mlir      | 16 ++++++++
 5 files changed, 63 insertions(+), 30 deletions(-)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 9ad4157a5..eb9bd3371 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -277,6 +277,7 @@ STABLEHLO_PASS_SET = {
     "FlattenStaticModule_basic",
     "FlattenRank0Module_basic",
     "TensorsConcatNegativeDimModule_basic",
+    "TensorsConcatPromoteDTypeModule_basic",
     "LiftFreshCopyModule_basic",
     "Mlp2LayerModuleNoBias_basic",
     "NumelModule_basic",
@@ -803,6 +804,7 @@ LTC_XFAIL_SET = {
     "SubFloatModule_basic",
     "SubIntModule_basic",
     "TensorsConcatNegativeDimModule_basic",
+    "TensorsConcatPromoteDTypeModule_basic",
     "TensorToBoolZeroRank_basic",
     "TensorToBool_basic",
     "TensorToFloatZeroRank_basic",
diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index a316e1e9d..293649de5 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -1117,6 +1117,18 @@ public:
 
     RankedTensorType newResultType =
         typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+
+    auto outElemType = newResultType.getElementType();
+    auto dtypePromoteBody = [&](OpBuilder &builder, Location loc,
+                                ValueRange payloadArgs) {
+      Value elem = convertScalarToDtype(builder, loc, payloadArgs[0], outElemType);
+      builder.create<linalg::YieldOp>(loc, elem);
+    };
+    for (size_t i = 0; i < tensors.size(); ++i) {
+      tensors[i] = torch_to_linalg::createElementwiseLinalgGeneric(
+          rewriter, loc, {tensors[i]}, outElemType, dtypePromoteBody);
+    }
+
     int rank = newResultType.getRank();
     SmallVector<Value> offsets, sizes, strides;
     sizes.reserve(rank);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index af07d0fc9..a5a928b6e 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -112,24 +112,6 @@ static torch_upstream::TypeKind getTypeKind(Type type) {
   return torch_upstream::TypeKind::AnyType;
 }
 
-/// Returns the dtype that assumes information from both `lhs` and `rhs`.
-/// Returns `std::nullopt` if the types are contradictory. Note this can only
-/// be used on the `dtype` from tensors and can't be used on other types like
-/// scalar types.
-static std::optional<Type> meetElementTypes(Type lhs, Type rhs) {
-  auto isNullOrBuiltIn = [](Type type) { return !type || isBuiltInType(type); };
-  (void)isNullOrBuiltIn;
-  assert(isNullOrBuiltIn(lhs) && "`lhs` must be a builtin type");
-  assert(isNullOrBuiltIn(rhs) && "`rhs` must be a builtin type");
-
-  if (!lhs)
-    return rhs;
-  if (!rhs)
-    return lhs;
-  if (lhs == rhs)
-    return lhs;
-  return std::nullopt;
-}
 
 enum class OptionalKnowledge {
   unKnown,
@@ -1446,19 +1428,14 @@ void TypeAnalysis::visitAtenCatOp(AtenCatOp op,
     return;
   }
 
-  auto tensors = llvm::to_vector<4>(
-      llvm::map_range(listConstruct.getElements(), [&](Value v) -> ValueKnowledge {
-        return getLatticeElement(v)->getValue();
-      }));
-  for (auto tensor : tensors) {
-    auto newDtype = meetElementTypes(knowledge.dtype, tensor.dtype);
-    if (!newDtype.has_value()) {
-      incorporateKnowledge(op.getResult(), knowledge);
-      return;
-    }
-    knowledge.dtype = newDtype.value();
-  }
-  incorporateKnowledge(op.getResult(), knowledge);
+  SmallVector<ValueKnowledge*> tensors = llvm::to_vector(
+      llvm::map_range(listConstruct.getElements(), [&](Value v) -> ValueKnowledge* {
+        return &getLatticeElement(v)->getValue();
+      }));
+
+  knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
+      op->getContext(), tensors);
+  incorporateKnowledge(op->getResult(0), knowledge);
 }
 
 void TypeAnalysis::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index 376e3bf18..b7889cb22 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -621,6 +621,32 @@ def TensorsConcatNegativeDimModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class TensorsConcatPromoteDTypeModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.bool, True),
+        ([-1, -1, -1], torch.int32, True),
+        ([-1, -1, -1], torch.int64, True),
+    ])
+    def forward(self, x, y, z):
+        return torch.cat([x, y, z], dim=-2)
+
+
+@register_test_case(module_factory=lambda: TensorsConcatPromoteDTypeModule())
+def TensorsConcatPromoteDTypeModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(2, 2, 4, low=0, high=2).bool(),
+                   tu.randint(2, 1, 4, low=0, high=100).int(),
+                   tu.randint(2, 3, 4, low=0, high=100).long())
+
+
+# ==============================================================================
+
+
 class GatherModule(torch.nn.Module):
 
     def __init__(self):
diff --git a/test/Dialect/Torch/refine-types-ops.mlir b/test/Dialect/Torch/refine-types-ops.mlir
index 261b0ecc8..3c90de228 100644
--- a/test/Dialect/Torch/refine-types-ops.mlir
+++ b/test/Dialect/Torch/refine-types-ops.mlir
@@ -163,6 +163,22 @@ func.func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[
   return %ret : !torch.tensor
 }
 
+// -----
+// CHECK-LABEL: func.func @torch.aten.cat$promote_type(
+// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[2,1,4],i1>,
+// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],si64>) -> !torch.tensor {
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[2,1,4],i1>, !torch.tensor<[2,3,4],si64>) -> !torch.list<tensor>
+// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,si64>
+// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
+// CHECK: return %[[CAST]] : !torch.tensor
+func.func @torch.aten.cat$promote_type(%t0: !torch.tensor<[2,1,4], i1>, %t1: !torch.tensor<[2,3,4], si64>) -> !torch.tensor {
+  %int1 = torch.constant.int 1
+  %tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[2,1,4], i1>, !torch.tensor<[2,3,4], si64>) -> !torch.list<tensor>
+  %ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
+  return %ret : !torch.tensor
+}
+
 // -----
 // CHECK-LABEL: func.func @torch.aten._shape_as_tensor(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor {
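
For reference, the dtype promotion that the new TensorsConcatPromoteDTypeModule_basic test exercises matches eager PyTorch: torch.cat promotes mixed input dtypes to a common result dtype, which is what the refinement now computes through getPromotedResultTypeAssumingNonZeroRank and what the linalg lowering materializes by converting each input to the promoted element type before concatenation. A minimal standalone sketch of that behavior in plain PyTorch, outside the torch-mlir e2e harness (shapes and dtypes simply mirror the test):

import torch

# Same shapes/dtypes as TensorsConcatPromoteDTypeModule_basic:
# bool, int32, and int64 inputs concatenated along dim=-2.
x = torch.randint(0, 2, (2, 2, 4)).bool()
y = torch.randint(0, 100, (2, 1, 4), dtype=torch.int32)
z = torch.randint(0, 100, (2, 3, 4), dtype=torch.int64)

out = torch.cat([x, y, z], dim=-2)
print(out.shape)  # torch.Size([2, 6, 4])
print(out.dtype)  # torch.int64, the promoted result dtype

# The result dtype follows torch.promote_types over the inputs.
expected = torch.promote_types(torch.promote_types(torch.bool, torch.int32),
                               torch.int64)
assert out.dtype == expected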