mirror of https://github.com/llvm/torch-mlir
[Torch] Address unnecessary dynamic shapes in argmax decomposition (#3889)
Addresses <https://github.com/iree-org/iree/issues/19262#issue>main
parent
0913b967ac
commit
99115dcdc8
|
@ -2593,16 +2593,22 @@ public:
|
|||
// first the input tensor is flattened to 1d tensor and then the reduction
|
||||
// happens on the 0th dimension.
|
||||
if (isa<Torch::NoneType>(dim.getType())) {
|
||||
BaseTensorType flattenType =
|
||||
cast<BaseTensorType>(inputType.getWithSizesAndDtype(
|
||||
{kUnknownSize}, inputType.getOptionalDtype()));
|
||||
Value zero =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
Value zero = rewriter.create<ConstantIntOp>(loc, 0);
|
||||
Value falseValue = rewriter.create<ConstantBoolOp>(loc, false);
|
||||
if (inputType.getSizes().size() > 1) {
|
||||
int64_t flattenSize = Torch::kUnknownSize;
|
||||
if (inputType.areAllSizesKnown()) {
|
||||
flattenSize = 1;
|
||||
for (int64_t sze : inputType.getSizes())
|
||||
flattenSize *= sze;
|
||||
}
|
||||
auto flattenType = cast<BaseTensorType>(inputType.getWithSizesAndDtype(
|
||||
{flattenSize}, inputType.getOptionalDtype()));
|
||||
Value end = rewriter.create<ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(inputRank - 1));
|
||||
Value falseValue = rewriter.create<ConstantBoolOp>(loc, false);
|
||||
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
|
||||
zero, end);
|
||||
}
|
||||
Value resultIndices =
|
||||
rewriter
|
||||
.create<DecompOpTy>(
|
||||
|
|
|
@ -545,10 +545,6 @@ FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
|
|||
|
||||
FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||
"AddFloatIntModule_basic",
|
||||
"ArgmaxIntModule_basic",
|
||||
"ArgmaxIntModule_multiple_maxs",
|
||||
"ArgmaxKeepdimModule_basic",
|
||||
"ArgmaxModule_basic",
|
||||
"AtenKthvalueDynamicDimsModule_basic",
|
||||
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
||||
"AtenKthvalueFloat64Module_basic",
|
||||
|
@ -618,9 +614,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"AnyBoolFalseModule_basic",
|
||||
"AnyBoolTrueModule_basic",
|
||||
"ArangeStartOutViewModule_basic",
|
||||
"ArgminIntModule_basic",
|
||||
"ArgminIntModule_multiple_mins",
|
||||
"ArgminModule_basic",
|
||||
"AtenComplexImagModule_basic",
|
||||
"AtenComplexRealModule_basic",
|
||||
"AtenComplexViewModule_basic",
|
||||
|
|
|
@ -25,6 +25,19 @@ func.func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch
|
|||
return %0 : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @argmax_rank_1
|
||||
// CHECK: %[[I0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[VALUES:.*]], %[[INDICES:.*]] = torch.aten.max.dim %arg0, %[[I0]], %[[FALSE]] : !torch.vtensor<[20],si32>, !torch.int, !torch.bool -> !torch.vtensor<[],si32>, !torch.vtensor<[],si64>
|
||||
// CHECK: return %[[INDICES]] : !torch.vtensor<[],si64>
|
||||
func.func @argmax_rank_1(%arg0: !torch.vtensor<[20],si32>) -> !torch.vtensor<[],si64> {
|
||||
%none = torch.constant.none
|
||||
%false = torch.constant.bool false
|
||||
%7 = torch.aten.argmax %arg0, %none, %false : !torch.vtensor<[20],si32>, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
|
||||
return %7 : !torch.vtensor<[],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.type_as$basic(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
|
||||
|
|
Loading…
Reference in New Issue