mirror of https://github.com/llvm/torch-mlir
Expand definition of tensor subtype to include shape/dtype info (#1929)
Currently, the op `torch.tensor_static_info_cast` will not get canonicalized away if the result type has any shape or dtype information. This is because `isValidSubtype` only returns true when the tensor types being compared are exactly the same or the supertype has no shape and dtype information. Being unable to canonicalize away the `torch.tensor_static_info_cast` gets in the way of further optimizations, such as shape propagation. This commit improves `isValidSubtype` by adding logic that compares the shapes and dtypes of the two tensor types to determine of one type is indeed a valid subtype of the other. Fixes https://github.com/llvm/torch-mlir/issues/1926pull/1931/head
parent
66b1045a80
commit
d310bb12bd
|
@ -68,16 +68,32 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
|
|||
return true;
|
||||
}
|
||||
|
||||
// TODO: This is not subtyping according to PEP 483. See description
|
||||
// of NonValueTensorType.
|
||||
if (subtype.isa<NonValueTensorType>() && type.isa<NonValueTensorType>() &&
|
||||
type ==
|
||||
NonValueTensorType::getWithLeastStaticInformation(type.getContext()))
|
||||
return true;
|
||||
auto subtypeTensorType = subtype.dyn_cast<BaseTensorType>();
|
||||
auto typeTensorType = type.dyn_cast<BaseTensorType>();
|
||||
if (subtypeTensorType && typeTensorType) {
|
||||
// Check that both tensors have the same `BaseTensorType` subtype.
|
||||
// TODO: This is not subtyping according to PEP 483. See description
|
||||
// of NonValueTensorType.
|
||||
if (subtypeTensorType.isa<ValueTensorType>() !=
|
||||
typeTensorType.isa<ValueTensorType>())
|
||||
return false;
|
||||
|
||||
// `type` must not have more static information than `subtype`, and `type`
|
||||
// must not disagree with `subtype`.
|
||||
if (typeTensorType.hasDtype() &&
|
||||
(!subtypeTensorType.hasDtype() ||
|
||||
typeTensorType.getDtype() != subtypeTensorType.getDtype())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (typeTensorType.hasSizes() &&
|
||||
(!subtypeTensorType.hasSizes() ||
|
||||
typeTensorType.getSizes() != subtypeTensorType.getSizes())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (subtype.isa<ValueTensorType>() && type.isa<ValueTensorType>() &&
|
||||
type == ValueTensorType::getWithLeastStaticInformation(type.getContext()))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -1201,6 +1201,26 @@ func.func @torch.tensor_static_info_cast$refine(%arg0: !torch.vtensor<[], f32>)
|
|||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.tensor_static_info_cast$refine$dtype(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor {
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = torch.aten.relu %[[ARG]] : !torch.vtensor<[],f32> -> !torch.vtensor
|
||||
// CHECK-NEXT: return %[[RESULT]] : !torch.vtensor
|
||||
func.func @torch.tensor_static_info_cast$refine$dtype(%arg0: !torch.vtensor<[], f32>) -> !torch.vtensor {
|
||||
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[],f32> to !torch.vtensor<[],unk>
|
||||
%1 = torch.aten.relu %0 : !torch.vtensor<[],unk> -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.tensor_static_info_cast$refine$shape(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor {
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = torch.aten.relu %[[ARG]] : !torch.vtensor<[],f32> -> !torch.vtensor
|
||||
// CHECK-NEXT: return %[[RESULT]] : !torch.vtensor
|
||||
func.func @torch.tensor_static_info_cast$refine$shape(%arg0: !torch.vtensor<[], f32>) -> !torch.vtensor {
|
||||
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[],f32> to !torch.vtensor<*,f32>
|
||||
%1 = torch.aten.relu %0 : !torch.vtensor<*,f32> -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.tensor_static_info_cast$no_refine(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor to !torch.vtensor<[],f32>
|
||||
|
@ -1212,6 +1232,28 @@ func.func @torch.tensor_static_info_cast$no_refine(%arg0: !torch.vtensor) -> !to
|
|||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.tensor_static_info_cast$no_refine$dtype(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],unk>) -> !torch.vtensor {
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[],unk> to !torch.vtensor<[],f32>
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.relu %[[CAST]] : !torch.vtensor<[],f32> -> !torch.vtensor
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor
|
||||
func.func @torch.tensor_static_info_cast$no_refine$dtype(%arg0: !torch.vtensor<[],unk>) -> !torch.vtensor {
|
||||
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[],unk> to !torch.vtensor<[],f32>
|
||||
%1 = torch.aten.relu %0 : !torch.vtensor<[],f32> -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.tensor_static_info_cast$no_refine$shape(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<*,f32> to !torch.vtensor<[],f32>
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.relu %[[CAST]] : !torch.vtensor<[],f32> -> !torch.vtensor
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor
|
||||
func.func @torch.tensor_static_info_cast$no_refine$shape(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor<[],f32>
|
||||
%1 = torch.aten.relu %0 : !torch.vtensor<[],f32> -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.tensor_static_info_cast$refine_allowed_ops(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.tuple<vtensor, vtensor> {
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[],f32> to !torch.vtensor
|
||||
|
|
|
@ -265,3 +265,19 @@ torch.global_slot.module_initializer {
|
|||
@tensor(%1 : !torch.tensor)
|
||||
]
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @torch.tensor_static_info_cast$shape_mismatch(%arg0: !torch.vtensor<[],unk>) -> !torch.vtensor<[?],unk> {
|
||||
// expected-error@+1 {{'torch.tensor_static_info_cast' op operand type '!torch.vtensor<[],unk>' and result type '!torch.vtensor<[?],unk>' are cast incompatible}}
|
||||
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[],unk> to !torch.vtensor<[?],unk>
|
||||
return %0 : !torch.vtensor<[?],unk>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @torch.tensor_static_info_cast$dtype_mismatch(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<*,f64> {
|
||||
// expected-error@+1 {{'torch.tensor_static_info_cast' op operand type '!torch.vtensor<*,f32>' and result type '!torch.vtensor<*,f64>' are cast incompatible}}
|
||||
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor<*,f64>
|
||||
return %0 : !torch.vtensor<*,f64>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue