[Torch] Address unnecessary dynamic shapes in argmax decomposition (#3889)

Addresses <https://github.com/iree-org/iree/issues/19262#issue>
main
zjgarvey 2024-11-22 18:03:29 -06:00 committed by GitHub
parent 0913b967ac
commit 99115dcdc8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 28 additions and 16 deletions

View File

@ -2593,16 +2593,22 @@ public:
// first the input tensor is flattened to 1d tensor and then the reduction // first the input tensor is flattened to 1d tensor and then the reduction
// happens on the 0th dimension. // happens on the 0th dimension.
if (isa<Torch::NoneType>(dim.getType())) { if (isa<Torch::NoneType>(dim.getType())) {
BaseTensorType flattenType = Value zero = rewriter.create<ConstantIntOp>(loc, 0);
cast<BaseTensorType>(inputType.getWithSizesAndDtype( Value falseValue = rewriter.create<ConstantBoolOp>(loc, false);
{kUnknownSize}, inputType.getOptionalDtype())); if (inputType.getSizes().size() > 1) {
Value zero = int64_t flattenSize = Torch::kUnknownSize;
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0)); if (inputType.areAllSizesKnown()) {
flattenSize = 1;
for (int64_t sze : inputType.getSizes())
flattenSize *= sze;
}
auto flattenType = cast<BaseTensorType>(inputType.getWithSizesAndDtype(
{flattenSize}, inputType.getOptionalDtype()));
Value end = rewriter.create<ConstantIntOp>( Value end = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputRank - 1)); loc, rewriter.getI64IntegerAttr(inputRank - 1));
Value falseValue = rewriter.create<ConstantBoolOp>(loc, false);
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input, input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
zero, end); zero, end);
}
Value resultIndices = Value resultIndices =
rewriter rewriter
.create<DecompOpTy>( .create<DecompOpTy>(

View File

@ -545,10 +545,6 @@ FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
FX_IMPORTER_STABLEHLO_XFAIL_SET = { FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"AddFloatIntModule_basic", "AddFloatIntModule_basic",
"ArgmaxIntModule_basic",
"ArgmaxIntModule_multiple_maxs",
"ArgmaxKeepdimModule_basic",
"ArgmaxModule_basic",
"AtenKthvalueDynamicDimsModule_basic", "AtenKthvalueDynamicDimsModule_basic",
"AtenKthvalueFloat64DynamicDimsModule_basic", "AtenKthvalueFloat64DynamicDimsModule_basic",
"AtenKthvalueFloat64Module_basic", "AtenKthvalueFloat64Module_basic",
@ -618,9 +614,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"AnyBoolFalseModule_basic", "AnyBoolFalseModule_basic",
"AnyBoolTrueModule_basic", "AnyBoolTrueModule_basic",
"ArangeStartOutViewModule_basic", "ArangeStartOutViewModule_basic",
"ArgminIntModule_basic",
"ArgminIntModule_multiple_mins",
"ArgminModule_basic",
"AtenComplexImagModule_basic", "AtenComplexImagModule_basic",
"AtenComplexRealModule_basic", "AtenComplexRealModule_basic",
"AtenComplexViewModule_basic", "AtenComplexViewModule_basic",

View File

@ -25,6 +25,19 @@ func.func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch
return %0 : !torch.tensor return %0 : !torch.tensor
} }
// -----
// CHECK-LABEL: func.func @argmax_rank_1
// CHECK: %[[I0:.*]] = torch.constant.int 0
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[VALUES:.*]], %[[INDICES:.*]] = torch.aten.max.dim %arg0, %[[I0]], %[[FALSE]] : !torch.vtensor<[20],si32>, !torch.int, !torch.bool -> !torch.vtensor<[],si32>, !torch.vtensor<[],si64>
// CHECK: return %[[INDICES]] : !torch.vtensor<[],si64>
func.func @argmax_rank_1(%arg0: !torch.vtensor<[20],si32>) -> !torch.vtensor<[],si64> {
%none = torch.constant.none
%false = torch.constant.bool false
%7 = torch.aten.argmax %arg0, %none, %false : !torch.vtensor<[20],si32>, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
return %7 : !torch.vtensor<[],si64>
}
// ----- // -----
// CHECK-LABEL: func.func @torch.aten.type_as$basic( // CHECK-LABEL: func.func @torch.aten.type_as$basic(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor { // CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {