mirror of https://github.com/llvm/torch-mlir
[TORCH] Improve type refinement for aten.cat. (#1908)
* [TORCH] Fix type refinement for aten.cat.
* Add test.
* Address comments.

---------

Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
parent 1e6608f90c
commit dca2b8a40a
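For context, this change makes torch-mlir's type refinement model PyTorch's dtype promotion for torch.cat: when the inputs have mixed dtypes, eager PyTorch promotes them to a common result dtype rather than rejecting the call. A minimal sketch of that eager behavior (plain PyTorch, independent of this patch; the shapes and dtypes mirror the new e2e test below):

import torch

x = torch.zeros(2, 2, 4, dtype=torch.bool)
y = torch.zeros(2, 1, 4, dtype=torch.int32)
z = torch.zeros(2, 3, 4, dtype=torch.int64)
# Mixed input dtypes are promoted; int64 is the widest here.
print(torch.cat([x, y, z], dim=-2).dtype)  # torch.int64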
@@ -277,6 +277,7 @@ STABLEHLO_PASS_SET = {
     "FlattenStaticModule_basic",
     "FlattenRank0Module_basic",
     "TensorsConcatNegativeDimModule_basic",
+    "TensorsConcatPromoteDTypeModule_basic",
     "LiftFreshCopyModule_basic",
     "Mlp2LayerModuleNoBias_basic",
     "NumelModule_basic",
@@ -803,6 +804,7 @@ LTC_XFAIL_SET = {
    "SubFloatModule_basic",
    "SubIntModule_basic",
    "TensorsConcatNegativeDimModule_basic",
+    "TensorsConcatPromoteDTypeModule_basic",
    "TensorToBoolZeroRank_basic",
    "TensorToBool_basic",
    "TensorToFloatZeroRank_basic",
@@ -1117,6 +1117,18 @@ public:
     RankedTensorType newResultType =
         typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+
+    auto outElemType = newResultType.getElementType();
+    auto dtypePromoteBody = [&](OpBuilder &builder, Location loc,
+                                ValueRange payloadArgs) {
+      Value elem = convertScalarToDtype(builder, loc, payloadArgs[0], outElemType);
+      builder.create<linalg::YieldOp>(loc, elem);
+    };
+    for (size_t i = 0; i < tensors.size(); ++i) {
+      tensors[i] = torch_to_linalg::createElementwiseLinalgGeneric(
+          rewriter, loc, {tensors[i]}, outElemType, dtypePromoteBody);
+    }
+
     int rank = newResultType.getRank();
     SmallVector<Value> offsets, sizes, strides;
     sizes.reserve(rank);
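The hunk above changes the TorchToLinalg lowering of aten.cat so that every input tensor is first cast elementwise to the promoted result element type (via an elementwise linalg.generic built from dtypePromoteBody), and only then concatenated. Expressed in plain PyTorch, the intended semantics are roughly the following sketch (assuming pairwise torch.promote_types matches the promotion the pass computes; this is an illustration, not the lowering itself):

import functools
import torch

def cat_with_promotion(tensors, dim=0):
    # Fold the input dtypes down to the common promoted dtype.
    out_dtype = functools.reduce(torch.promote_types,
                                 [t.dtype for t in tensors])
    # The per-input elementwise cast corresponds to the linalg.generic
    # bodies built by dtypePromoteBody in the hunk above.
    return torch.cat([t.to(out_dtype) for t in tensors], dim=dim)

assert cat_with_promotion([torch.zeros(2, dtype=torch.bool),
                           torch.ones(3, dtype=torch.int64)]).dtype == torch.int64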
@@ -112,24 +112,6 @@ static torch_upstream::TypeKind getTypeKind(Type type) {
   return torch_upstream::TypeKind::AnyType;
 }
 
-/// Returns the dtype that assumes information from both `lhs` and `rhs`.
-/// Returns `std::nullopt` if the types are contradictory. Note this can only
-/// be used on the `dtype` from tensors and can't be used on other types like
-/// scalar types.
-static std::optional<Type> meetElementTypes(Type lhs, Type rhs) {
-  auto isNullOrBuiltIn = [](Type type) { return !type || isBuiltInType(type); };
-  (void)isNullOrBuiltIn;
-  assert(isNullOrBuiltIn(lhs) && "`lhs` must be a builtin type");
-  assert(isNullOrBuiltIn(rhs) && "`rhs` must be a builtin type");
-
-  if (!lhs)
-    return rhs;
-  if (!rhs)
-    return lhs;
-  if (lhs == rhs)
-    return lhs;
-  return std::nullopt;
-}
-
 enum class OptionalKnowledge {
   unKnown,
@@ -1446,19 +1428,14 @@ void TypeAnalysis::visitAtenCatOp(AtenCatOp op,
     return;
   }
 
-  auto tensors = llvm::to_vector<4>(
-      llvm::map_range(listConstruct.getElements(), [&](Value v) -> ValueKnowledge {
-        return getLatticeElement(v)->getValue();
+  SmallVector<ValueKnowledge*> tensors = llvm::to_vector(
+      llvm::map_range(listConstruct.getElements(), [&](Value v) -> ValueKnowledge* {
+        return &getLatticeElement(v)->getValue();
       }));
-  for (auto tensor : tensors) {
-    auto newDtype = meetElementTypes(knowledge.dtype, tensor.dtype);
-    if (!newDtype.has_value()) {
-      incorporateKnowledge(op.getResult(), knowledge);
-      return;
-    }
-    knowledge.dtype = newDtype.value();
-  }
-  incorporateKnowledge(op.getResult(), knowledge);
+
+  knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
+      op->getContext(), tensors);
+  incorporateKnowledge(op->getResult(0), knowledge);
 }
 
 void TypeAnalysis::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
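Together, the two hunks above swap "meet" semantics for promotion semantics in TypeAnalysis::visitAtenCatOp: the deleted meetElementTypes returned std::nullopt on any dtype mismatch, so refinement previously gave up on mixed-dtype concatenation, whereas getPromotedResultTypeAssumingNonZeroRank follows PyTorch's type-promotion rules and widens instead. The difference in one line of eager PyTorch (for illustration only):

import torch

# "Meet" would reject bool vs. int64 as contradictory; promotion widens.
print(torch.promote_types(torch.bool, torch.int64))  # torch.int64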
@@ -621,6 +621,32 @@ def TensorsConcatNegativeDimModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class TensorsConcatPromoteDTypeModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.bool, True),
+        ([-1, -1, -1], torch.int32, True),
+        ([-1, -1, -1], torch.int64, True),
+    ])
+    def forward(self, x, y, z):
+        return torch.cat([x, y, z], dim=-2)
+
+
+@register_test_case(module_factory=lambda: TensorsConcatPromoteDTypeModule())
+def TensorsConcatPromoteDTypeModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(2, 2, 4, low=0, high=2).bool(),
+                   tu.randint(2, 1, 4, low=0, high=100).int(),
+                   tu.randint(2, 3, 4, low=0, high=100).long())
+
+
+# ==============================================================================
+
+
 class GatherModule(torch.nn.Module):
 
     def __init__(self):
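Note on the new e2e test above: the three inputs have shapes (2, 2, 4), (2, 1, 4), and (2, 3, 4) with dtypes bool, int32, and int64, so concatenating along dim=-2 must produce a (2, 6, 4) int64 tensor; the e2e harness checks the compiled result against eager PyTorch as the reference.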
@@ -163,6 +163,22 @@ func.func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[
   return %ret : !torch.tensor
 }
 
+// -----
+// CHECK-LABEL: func.func @torch.aten.cat$promote_type(
+// CHECK-SAME:     %[[T1:.*]]: !torch.tensor<[2,1,4],i1>,
+// CHECK-SAME:     %[[T2:.*]]: !torch.tensor<[2,3,4],si64>) -> !torch.tensor {
+// CHECK:          %[[INT1:.*]] = torch.constant.int 1
+// CHECK:          %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[2,1,4],i1>, !torch.tensor<[2,3,4],si64>) -> !torch.list<tensor>
+// CHECK:          %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,si64>
+// CHECK:          %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
+// CHECK:          return %[[CAST]] : !torch.tensor
+func.func @torch.aten.cat$promote_type(%t0: !torch.tensor<[2,1,4], i1>, %t1: !torch.tensor<[2,3,4], si64>) -> !torch.tensor {
+  %int1 = torch.constant.int 1
+  %tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[2,1,4], i1>, !torch.tensor<[2,3,4], si64>) -> !torch.list<tensor>
+  %ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
+  return %ret : !torch.tensor
+}
+
 // -----
 // CHECK-LABEL: func.func @torch.aten._shape_as_tensor(
 // CHECK-SAME:     %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor {