mirror of https://github.com/llvm/torch-mlir
[TORCH] Improve type refinement for aten.cat. (#1908)
* [TORCH] Fix type refinement for aten.cat. * Add test. * Address comments. * Update. * Update. * Update. * Update. * Update. --------- Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>pull/1747/head snapshot-20230310.773
parent
1e6608f90c
commit
dca2b8a40a
|
@ -277,6 +277,7 @@ STABLEHLO_PASS_SET = {
|
|||
"FlattenStaticModule_basic",
|
||||
"FlattenRank0Module_basic",
|
||||
"TensorsConcatNegativeDimModule_basic",
|
||||
"TensorsConcatPromoteDTypeModule_basic",
|
||||
"LiftFreshCopyModule_basic",
|
||||
"Mlp2LayerModuleNoBias_basic",
|
||||
"NumelModule_basic",
|
||||
|
@ -803,6 +804,7 @@ LTC_XFAIL_SET = {
|
|||
"SubFloatModule_basic",
|
||||
"SubIntModule_basic",
|
||||
"TensorsConcatNegativeDimModule_basic",
|
||||
"TensorsConcatPromoteDTypeModule_basic",
|
||||
"TensorToBoolZeroRank_basic",
|
||||
"TensorToBool_basic",
|
||||
"TensorToFloatZeroRank_basic",
|
||||
|
|
|
@ -1117,6 +1117,18 @@ public:
|
|||
|
||||
RankedTensorType newResultType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
|
||||
auto outElemType = newResultType.getElementType();
|
||||
auto dtypePromoteBody = [&](OpBuilder &builder, Location loc,
|
||||
ValueRange payloadArgs) {
|
||||
Value elem = convertScalarToDtype(builder, loc, payloadArgs[0], outElemType);
|
||||
builder.create<linalg::YieldOp>(loc, elem);
|
||||
};
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
tensors[i] = torch_to_linalg::createElementwiseLinalgGeneric(
|
||||
rewriter, loc, {tensors[i]}, outElemType, dtypePromoteBody);
|
||||
}
|
||||
|
||||
int rank = newResultType.getRank();
|
||||
SmallVector<Value> offsets, sizes, strides;
|
||||
sizes.reserve(rank);
|
||||
|
|
|
@ -112,24 +112,6 @@ static torch_upstream::TypeKind getTypeKind(Type type) {
|
|||
return torch_upstream::TypeKind::AnyType;
|
||||
}
|
||||
|
||||
/// Returns the dtype that assumes information from both `lhs` and `rhs`.
|
||||
/// Returns `std::nullopt` if the types are contradictory. Note this can only
|
||||
/// be used on the `dtype` from tensors and can't be used on other types like
|
||||
/// scalar types.
|
||||
static std::optional<Type> meetElementTypes(Type lhs, Type rhs) {
|
||||
auto isNullOrBuiltIn = [](Type type) { return !type || isBuiltInType(type); };
|
||||
(void)isNullOrBuiltIn;
|
||||
assert(isNullOrBuiltIn(lhs) && "`lhs` must be a builtin type");
|
||||
assert(isNullOrBuiltIn(rhs) && "`rhs` must be a builtin type");
|
||||
|
||||
if (!lhs)
|
||||
return rhs;
|
||||
if (!rhs)
|
||||
return lhs;
|
||||
if (lhs == rhs)
|
||||
return lhs;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
enum class OptionalKnowledge {
|
||||
unKnown,
|
||||
|
@ -1446,19 +1428,14 @@ void TypeAnalysis::visitAtenCatOp(AtenCatOp op,
|
|||
return;
|
||||
}
|
||||
|
||||
auto tensors = llvm::to_vector<4>(
|
||||
llvm::map_range(listConstruct.getElements(), [&](Value v) -> ValueKnowledge {
|
||||
return getLatticeElement(v)->getValue();
|
||||
SmallVector<ValueKnowledge*> tensors = llvm::to_vector(
|
||||
llvm::map_range(listConstruct.getElements(), [&](Value v) -> ValueKnowledge* {
|
||||
return &getLatticeElement(v)->getValue();
|
||||
}));
|
||||
for (auto tensor : tensors) {
|
||||
auto newDtype = meetElementTypes(knowledge.dtype, tensor.dtype);
|
||||
if (!newDtype.has_value()) {
|
||||
incorporateKnowledge(op.getResult(), knowledge);
|
||||
return;
|
||||
}
|
||||
knowledge.dtype = newDtype.value();
|
||||
}
|
||||
incorporateKnowledge(op.getResult(), knowledge);
|
||||
|
||||
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
|
||||
op->getContext(), tensors);
|
||||
incorporateKnowledge(op->getResult(0), knowledge);
|
||||
}
|
||||
|
||||
void TypeAnalysis::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
|
||||
|
|
|
@ -621,6 +621,32 @@ def TensorsConcatNegativeDimModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class TensorsConcatPromoteDTypeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.bool, True),
|
||||
([-1, -1, -1], torch.int32, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y, z):
|
||||
return torch.cat([x, y, z], dim=-2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TensorsConcatPromoteDTypeModule())
|
||||
def TensorsConcatPromoteDTypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(2, 2, 4, low=0, high=2).bool(),
|
||||
tu.randint(2, 1, 4, low=0, high=100).int(),
|
||||
tu.randint(2, 3, 4, low=0, high=100).long())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class GatherModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
|
@ -163,6 +163,22 @@ func.func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[
|
|||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.cat$promote_type(
|
||||
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[2,1,4],i1>,
|
||||
// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],si64>) -> !torch.tensor {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[2,1,4],i1>, !torch.tensor<[2,3,4],si64>) -> !torch.list<tensor>
|
||||
// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,si64>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten.cat$promote_type(%t0: !torch.tensor<[2,1,4], i1>, %t1: !torch.tensor<[2,3,4], si64>) -> !torch.tensor {
|
||||
%int1 = torch.constant.int 1
|
||||
%tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[2,1,4], i1>, !torch.tensor<[2,3,4], si64>) -> !torch.list<tensor>
|
||||
%ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten._shape_as_tensor(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor {
|
||||
|
|
Loading…
Reference in New Issue