[Torch] Address unnecessary dynamic shapes in argmax decomposition (#3889)

Addresses <https://github.com/iree-org/iree/issues/19262#issue>
Branch: main
Author: zjgarvey, 2024-11-22 18:03:29 -06:00 (committed by GitHub)
Parent: 0913b967ac
Commit: 99115dcdc8
3 changed files with 28 additions and 16 deletions
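For context: when dim is None, the argmax/argmin decomposition flattens the input to 1-D and reduces along dimension 0. Previously the flattened tensor was always typed [?], introducing a dynamic shape even when the input shape is fully static; the change below folds the flattened size to the product of the known sizes and skips the flatten entirely for rank-1 inputs. A minimal eager-PyTorch sketch of the same semantics (the helper name argmax_decomposed is made up for illustration; this is not torch-mlir code):

import torch

def argmax_decomposed(x: torch.Tensor, dim=None, keepdim: bool = False) -> torch.Tensor:
    # Eager-mode equivalent of the decomposition path, for illustration only.
    if dim is None:
        # dim=None reduces over all elements: flatten to 1-D, then reduce along dim 0.
        # A rank-1 input is already 1-D, so the flatten is skipped (the new fast path).
        if x.dim() > 1:
            x = x.flatten(0, x.dim() - 1)
        return x.max(dim=0).indices
    return x.max(dim=dim, keepdim=keepdim).indices

# With a fully static shape such as (2, 3, 4) the flattened size is just 2*3*4 = 24,
# so no dynamic dimension is needed.
x = torch.arange(24).reshape(2, 3, 4)
assert torch.equal(argmax_decomposed(x), torch.argmax(x))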


@@ -2593,16 +2593,22 @@ public:
     // first the input tensor is flattened to 1d tensor and then the reduction
     // happens on the 0th dimension.
     if (isa<Torch::NoneType>(dim.getType())) {
-      BaseTensorType flattenType =
-          cast<BaseTensorType>(inputType.getWithSizesAndDtype(
-              {kUnknownSize}, inputType.getOptionalDtype()));
-      Value zero =
-          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+      Value zero = rewriter.create<ConstantIntOp>(loc, 0);
+      Value falseValue = rewriter.create<ConstantBoolOp>(loc, false);
+      if (inputType.getSizes().size() > 1) {
+        int64_t flattenSize = Torch::kUnknownSize;
+        if (inputType.areAllSizesKnown()) {
+          flattenSize = 1;
+          for (int64_t sze : inputType.getSizes())
+            flattenSize *= sze;
+        }
+        auto flattenType = cast<BaseTensorType>(inputType.getWithSizesAndDtype(
+            {flattenSize}, inputType.getOptionalDtype()));
         Value end = rewriter.create<ConstantIntOp>(
             loc, rewriter.getI64IntegerAttr(inputRank - 1));
-      Value falseValue = rewriter.create<ConstantBoolOp>(loc, false);
         input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
                                                         zero, end);
+      }
       Value resultIndices =
           rewriter
               .create<DecompOpTy>(
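The flattenSize computation above only falls back to a dynamic dimension when some input size is genuinely unknown; otherwise it folds the sizes into a static product. A small Python sketch of that fold (the K_UNKNOWN_SIZE sentinel of -1 is an assumption standing in for Torch::kUnknownSize):

import math
from typing import Sequence

K_UNKNOWN_SIZE = -1  # assumed stand-in for Torch::kUnknownSize (marks a dynamic dim)

def flattened_size(sizes: Sequence[int]) -> int:
    # Mirrors the loop above: static product when every size is known,
    # otherwise the dynamic-size sentinel.
    if any(s == K_UNKNOWN_SIZE for s in sizes):
        return K_UNKNOWN_SIZE
    return math.prod(sizes)

assert flattened_size([2, 3, 4]) == 24                        # fully static -> flatten to [24]
assert flattened_size([2, K_UNKNOWN_SIZE]) == K_UNKNOWN_SIZE  # any unknown dim -> [?]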


@@ -545,10 +545,6 @@ FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
 FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "AddFloatIntModule_basic",
-    "ArgmaxIntModule_basic",
-    "ArgmaxIntModule_multiple_maxs",
-    "ArgmaxKeepdimModule_basic",
-    "ArgmaxModule_basic",
     "AtenKthvalueDynamicDimsModule_basic",
     "AtenKthvalueFloat64DynamicDimsModule_basic",
     "AtenKthvalueFloat64Module_basic",
@@ -618,9 +614,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "AnyBoolFalseModule_basic",
     "AnyBoolTrueModule_basic",
     "ArangeStartOutViewModule_basic",
-    "ArgminIntModule_basic",
-    "ArgminIntModule_multiple_mins",
-    "ArgminModule_basic",
     "AtenComplexImagModule_basic",
     "AtenComplexRealModule_basic",
     "AtenComplexViewModule_basic",

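The new lit test below exercises the rank-1 fast path: aten.argmax on a [20] input decomposes directly to torch.aten.max.dim with static, rank-0 result types and no flatten. A quick eager-mode check of the same fact (illustrative only, assuming standard PyTorch semantics):

import torch

x = torch.randint(0, 100, (20,), dtype=torch.int32)
values, indices = x.max(dim=0)      # what the rank-1 argmax decomposes to
assert values.dtype == torch.int32  # matches !torch.vtensor<[],si32>
assert indices.shape == ()          # rank-0, matching !torch.vtensor<[],si64>
assert indices.dtype == torch.int64
assert torch.equal(indices, torch.argmax(x))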

@@ -25,6 +25,19 @@ func.func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch
   return %0 : !torch.tensor
 }
 // -----
+// CHECK-LABEL: func.func @argmax_rank_1
+// CHECK: %[[I0:.*]] = torch.constant.int 0
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[VALUES:.*]], %[[INDICES:.*]] = torch.aten.max.dim %arg0, %[[I0]], %[[FALSE]] : !torch.vtensor<[20],si32>, !torch.int, !torch.bool -> !torch.vtensor<[],si32>, !torch.vtensor<[],si64>
+// CHECK: return %[[INDICES]] : !torch.vtensor<[],si64>
+func.func @argmax_rank_1(%arg0: !torch.vtensor<[20],si32>) -> !torch.vtensor<[],si64> {
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %7 = torch.aten.argmax %arg0, %none, %false : !torch.vtensor<[20],si32>, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
+  return %7 : !torch.vtensor<[],si64>
+}
+// -----
 // CHECK-LABEL: func.func @torch.aten.type_as$basic(
 // CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {