Expand definition of tensor subtype to include shape/dtype info (#1929)

Currently, the op `torch.tensor_static_info_cast` will not get
canonicalized away if the result type has any shape or dtype
information. This is because `isValidSubtype` only returns true when
the tensor types being compared are exactly the same or the supertype
has no shape and dtype information. Being unable to canonicalize away
the `torch.tensor_static_info_cast` gets in the way of further
optimizations, such as shape propagation.

This commit improves `isValidSubtype` by adding logic that compares
the shapes and dtypes of the two tensor types to determine if one type
is indeed a valid subtype of the other.

Fixes https://github.com/llvm/torch-mlir/issues/1926
pull/1931/head
Ramiro Leal-Cavazos 2023-03-10 16:43:57 -08:00 committed by GitHub
parent 66b1045a80
commit d310bb12bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 82 additions and 8 deletions

View File

@ -68,16 +68,32 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
return true;
}
// TODO: This is not subtyping according to PEP 483. See description
// of NonValueTensorType.
if (subtype.isa<NonValueTensorType>() && type.isa<NonValueTensorType>() &&
type ==
NonValueTensorType::getWithLeastStaticInformation(type.getContext()))
return true;
auto subtypeTensorType = subtype.dyn_cast<BaseTensorType>();
auto typeTensorType = type.dyn_cast<BaseTensorType>();
if (subtypeTensorType && typeTensorType) {
// Check that both tensors have the same `BaseTensorType` subtype.
// TODO: This is not subtyping according to PEP 483. See description
// of NonValueTensorType.
if (subtypeTensorType.isa<ValueTensorType>() !=
typeTensorType.isa<ValueTensorType>())
return false;
// `type` must not have more static information than `subtype`, and `type`
// must not disagree with `subtype`.
if (typeTensorType.hasDtype() &&
(!subtypeTensorType.hasDtype() ||
typeTensorType.getDtype() != subtypeTensorType.getDtype())) {
return false;
}
if (typeTensorType.hasSizes() &&
(!subtypeTensorType.hasSizes() ||
typeTensorType.getSizes() != subtypeTensorType.getSizes())) {
return false;
}
if (subtype.isa<ValueTensorType>() && type.isa<ValueTensorType>() &&
type == ValueTensorType::getWithLeastStaticInformation(type.getContext()))
return true;
}
return false;
}

View File

@ -1201,6 +1201,26 @@ func.func @torch.tensor_static_info_cast$refine(%arg0: !torch.vtensor<[], f32>)
return %1 : !torch.vtensor
}
// Canonicalization test: the cast drops only dtype information
// ([],f32 -> [],unk), so the result type is a valid supertype of the
// operand type. The CHECK lines require the cast to fold away, with the
// consumer (relu) reading the more refined operand directly.
// CHECK-LABEL: func.func @torch.tensor_static_info_cast$refine$dtype(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor {
// CHECK-NEXT: %[[RESULT:.*]] = torch.aten.relu %[[ARG]] : !torch.vtensor<[],f32> -> !torch.vtensor
// CHECK-NEXT: return %[[RESULT]] : !torch.vtensor
func.func @torch.tensor_static_info_cast$refine$dtype(%arg0: !torch.vtensor<[], f32>) -> !torch.vtensor {
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[],f32> to !torch.vtensor<[],unk>
%1 = torch.aten.relu %0 : !torch.vtensor<[],unk> -> !torch.vtensor
return %1 : !torch.vtensor
}
// Canonicalization test: the cast drops only shape information
// ([],f32 -> *,f32), so the result type is a valid supertype of the
// operand type. The CHECK lines require the cast to fold away, with the
// consumer (relu) reading the more refined operand directly.
// CHECK-LABEL: func.func @torch.tensor_static_info_cast$refine$shape(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor {
// CHECK-NEXT: %[[RESULT:.*]] = torch.aten.relu %[[ARG]] : !torch.vtensor<[],f32> -> !torch.vtensor
// CHECK-NEXT: return %[[RESULT]] : !torch.vtensor
func.func @torch.tensor_static_info_cast$refine$shape(%arg0: !torch.vtensor<[], f32>) -> !torch.vtensor {
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[],f32> to !torch.vtensor<*,f32>
%1 = torch.aten.relu %0 : !torch.vtensor<*,f32> -> !torch.vtensor
return %1 : !torch.vtensor
}
// CHECK-LABEL: func.func @torch.tensor_static_info_cast$no_refine(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor to !torch.vtensor<[],f32>
@ -1212,6 +1232,28 @@ func.func @torch.tensor_static_info_cast$no_refine(%arg0: !torch.vtensor) -> !to
return %1 : !torch.vtensor
}
// Negative canonicalization test: the cast ADDS dtype information
// ([],unk -> [],f32), so the result is not a supertype of the operand and
// the cast must NOT fold. The CHECK lines require the cast op to remain
// and relu to consume the cast result.
// CHECK-LABEL: func.func @torch.tensor_static_info_cast$no_refine$dtype(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],unk>) -> !torch.vtensor {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[],unk> to !torch.vtensor<[],f32>
// CHECK: %[[RESULT:.*]] = torch.aten.relu %[[CAST]] : !torch.vtensor<[],f32> -> !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor
func.func @torch.tensor_static_info_cast$no_refine$dtype(%arg0: !torch.vtensor<[],unk>) -> !torch.vtensor {
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[],unk> to !torch.vtensor<[],f32>
%1 = torch.aten.relu %0 : !torch.vtensor<[],f32> -> !torch.vtensor
return %1 : !torch.vtensor
}
// Negative canonicalization test: the cast ADDS shape information
// (*,f32 -> [],f32), so the result is not a supertype of the operand and
// the cast must NOT fold. The CHECK lines require the cast op to remain
// and relu to consume the cast result.
// CHECK-LABEL: func.func @torch.tensor_static_info_cast$no_refine$shape(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<*,f32> to !torch.vtensor<[],f32>
// CHECK: %[[RESULT:.*]] = torch.aten.relu %[[CAST]] : !torch.vtensor<[],f32> -> !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor
func.func @torch.tensor_static_info_cast$no_refine$shape(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor<[],f32>
%1 = torch.aten.relu %0 : !torch.vtensor<[],f32> -> !torch.vtensor
return %1 : !torch.vtensor
}
// CHECK-LABEL: func.func @torch.tensor_static_info_cast$refine_allowed_ops(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.tuple<vtensor, vtensor> {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[],f32> to !torch.vtensor

View File

@ -265,3 +265,19 @@ torch.global_slot.module_initializer {
@tensor(%1 : !torch.tensor)
]
}
// -----
// Verifier test: operand and result ranks disagree ([] is rank 0, [?] is
// rank 1). Neither type is a subtype of the other, so the op must be
// rejected with the cast-incompatible diagnostic below.
func.func @torch.tensor_static_info_cast$shape_mismatch(%arg0: !torch.vtensor<[],unk>) -> !torch.vtensor<[?],unk> {
// expected-error@+1 {{'torch.tensor_static_info_cast' op operand type '!torch.vtensor<[],unk>' and result type '!torch.vtensor<[?],unk>' are cast incompatible}}
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[],unk> to !torch.vtensor<[?],unk>
return %0 : !torch.vtensor<[?],unk>
}
// -----
// Verifier test: operand and result dtypes conflict (f32 vs f64).
// Neither type is a subtype of the other, so the op must be rejected
// with the cast-incompatible diagnostic below.
func.func @torch.tensor_static_info_cast$dtype_mismatch(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<*,f64> {
// expected-error@+1 {{'torch.tensor_static_info_cast' op operand type '!torch.vtensor<*,f32>' and result type '!torch.vtensor<*,f64>' are cast incompatible}}
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor<*,f64>
return %0 : !torch.vtensor<*,f64>
}