[TORCH] Improve type refinement for aten.cat. (#1908)

* [TORCH] Fix type refinement for aten.cat.

* Add test.

* Address comments.

* Update.

* Update.

* Update.

* Update.

* Update.

---------

Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
pull/1747/head snapshot-20230310.773
Ziheng Jiang 2023-03-09 16:17:35 -08:00 committed by GitHub
parent 1e6608f90c
commit dca2b8a40a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 63 additions and 30 deletions

View File

@ -277,6 +277,7 @@ STABLEHLO_PASS_SET = {
"FlattenStaticModule_basic", "FlattenStaticModule_basic",
"FlattenRank0Module_basic", "FlattenRank0Module_basic",
"TensorsConcatNegativeDimModule_basic", "TensorsConcatNegativeDimModule_basic",
"TensorsConcatPromoteDTypeModule_basic",
"LiftFreshCopyModule_basic", "LiftFreshCopyModule_basic",
"Mlp2LayerModuleNoBias_basic", "Mlp2LayerModuleNoBias_basic",
"NumelModule_basic", "NumelModule_basic",
@ -803,6 +804,7 @@ LTC_XFAIL_SET = {
"SubFloatModule_basic", "SubFloatModule_basic",
"SubIntModule_basic", "SubIntModule_basic",
"TensorsConcatNegativeDimModule_basic", "TensorsConcatNegativeDimModule_basic",
"TensorsConcatPromoteDTypeModule_basic",
"TensorToBoolZeroRank_basic", "TensorToBoolZeroRank_basic",
"TensorToBool_basic", "TensorToBool_basic",
"TensorToFloatZeroRank_basic", "TensorToFloatZeroRank_basic",

View File

@ -1117,6 +1117,18 @@ public:
RankedTensorType newResultType = RankedTensorType newResultType =
typeConverter->convertType(op.getType()).cast<RankedTensorType>(); typeConverter->convertType(op.getType()).cast<RankedTensorType>();
auto outElemType = newResultType.getElementType();
auto dtypePromoteBody = [&](OpBuilder &builder, Location loc,
ValueRange payloadArgs) {
Value elem = convertScalarToDtype(builder, loc, payloadArgs[0], outElemType);
builder.create<linalg::YieldOp>(loc, elem);
};
for (size_t i = 0; i < tensors.size(); ++i) {
tensors[i] = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, {tensors[i]}, outElemType, dtypePromoteBody);
}
int rank = newResultType.getRank(); int rank = newResultType.getRank();
SmallVector<Value> offsets, sizes, strides; SmallVector<Value> offsets, sizes, strides;
sizes.reserve(rank); sizes.reserve(rank);

View File

@ -112,24 +112,6 @@ static torch_upstream::TypeKind getTypeKind(Type type) {
return torch_upstream::TypeKind::AnyType; return torch_upstream::TypeKind::AnyType;
} }
/// Returns the dtype that assumes information from both `lhs` and `rhs`.
/// Returns `std::nullopt` if the types are contradictory. Note this can only
/// be used on the `dtype` from tensors and can't be used on other types like
/// scalar types.
static std::optional<Type> meetElementTypes(Type lhs, Type rhs) {
auto isNullOrBuiltIn = [](Type type) { return !type || isBuiltInType(type); };
(void)isNullOrBuiltIn;
assert(isNullOrBuiltIn(lhs) && "`lhs` must be a builtin type");
assert(isNullOrBuiltIn(rhs) && "`rhs` must be a builtin type");
if (!lhs)
return rhs;
if (!rhs)
return lhs;
if (lhs == rhs)
return lhs;
return std::nullopt;
}
enum class OptionalKnowledge { enum class OptionalKnowledge {
unKnown, unKnown,
@ -1446,19 +1428,14 @@ void TypeAnalysis::visitAtenCatOp(AtenCatOp op,
return; return;
} }
auto tensors = llvm::to_vector<4>( SmallVector<ValueKnowledge*> tensors = llvm::to_vector(
llvm::map_range(listConstruct.getElements(), [&](Value v) -> ValueKnowledge { llvm::map_range(listConstruct.getElements(), [&](Value v) -> ValueKnowledge* {
return getLatticeElement(v)->getValue(); return &getLatticeElement(v)->getValue();
})); }));
for (auto tensor : tensors) {
auto newDtype = meetElementTypes(knowledge.dtype, tensor.dtype); knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
if (!newDtype.has_value()) { op->getContext(), tensors);
incorporateKnowledge(op.getResult(), knowledge); incorporateKnowledge(op->getResult(0), knowledge);
return;
}
knowledge.dtype = newDtype.value();
}
incorporateKnowledge(op.getResult(), knowledge);
} }
void TypeAnalysis::visitNumToTensorOp(PrimNumToTensorScalarOp op) { void TypeAnalysis::visitNumToTensorOp(PrimNumToTensorScalarOp op) {

View File

@ -621,6 +621,32 @@ def TensorsConcatNegativeDimModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class TensorsConcatPromoteDTypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.bool, True),
([-1, -1, -1], torch.int32, True),
([-1, -1, -1], torch.int64, True),
])
def forward(self, x, y, z):
return torch.cat([x, y, z], dim=-2)
@register_test_case(module_factory=lambda: TensorsConcatPromoteDTypeModule())
def TensorsConcatPromoteDTypeModule_basic(module, tu: TestUtils):
module.forward(tu.randint(2, 2, 4, low=0, high=2).bool(),
tu.randint(2, 1, 4, low=0, high=100).int(),
tu.randint(2, 3, 4, low=0, high=100).long())
# ==============================================================================
class GatherModule(torch.nn.Module): class GatherModule(torch.nn.Module):
def __init__(self): def __init__(self):

View File

@ -163,6 +163,22 @@ func.func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[
return %ret : !torch.tensor return %ret : !torch.tensor
} }
// -----
// CHECK-LABEL: func.func @torch.aten.cat$promote_type(
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[2,1,4],i1>,
// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],si64>) -> !torch.tensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[2,1,4],i1>, !torch.tensor<[2,3,4],si64>) -> !torch.list<tensor>
// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.cat$promote_type(%t0: !torch.tensor<[2,1,4], i1>, %t1: !torch.tensor<[2,3,4], si64>) -> !torch.tensor {
%int1 = torch.constant.int 1
%tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[2,1,4], i1>, !torch.tensor<[2,3,4], si64>) -> !torch.list<tensor>
%ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
return %ret : !torch.tensor
}
// ----- // -----
// CHECK-LABEL: func.func @torch.aten._shape_as_tensor( // CHECK-LABEL: func.func @torch.aten._shape_as_tensor(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor { // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor {