mirror of https://github.com/llvm/torch-mlir
[Torch] Address unnecessary dynamic shapes in argmax decomposition (#3889)
Addresses <https://github.com/iree-org/iree/issues/19262#issue>main
parent
0913b967ac
commit
99115dcdc8
|
@ -2593,16 +2593,22 @@ public:
|
||||||
// first the input tensor is flattened to 1d tensor and then the reduction
|
// first the input tensor is flattened to 1d tensor and then the reduction
|
||||||
// happens on the 0th dimension.
|
// happens on the 0th dimension.
|
||||||
if (isa<Torch::NoneType>(dim.getType())) {
|
if (isa<Torch::NoneType>(dim.getType())) {
|
||||||
BaseTensorType flattenType =
|
Value zero = rewriter.create<ConstantIntOp>(loc, 0);
|
||||||
cast<BaseTensorType>(inputType.getWithSizesAndDtype(
|
Value falseValue = rewriter.create<ConstantBoolOp>(loc, false);
|
||||||
{kUnknownSize}, inputType.getOptionalDtype()));
|
if (inputType.getSizes().size() > 1) {
|
||||||
Value zero =
|
int64_t flattenSize = Torch::kUnknownSize;
|
||||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
if (inputType.areAllSizesKnown()) {
|
||||||
|
flattenSize = 1;
|
||||||
|
for (int64_t sze : inputType.getSizes())
|
||||||
|
flattenSize *= sze;
|
||||||
|
}
|
||||||
|
auto flattenType = cast<BaseTensorType>(inputType.getWithSizesAndDtype(
|
||||||
|
{flattenSize}, inputType.getOptionalDtype()));
|
||||||
Value end = rewriter.create<ConstantIntOp>(
|
Value end = rewriter.create<ConstantIntOp>(
|
||||||
loc, rewriter.getI64IntegerAttr(inputRank - 1));
|
loc, rewriter.getI64IntegerAttr(inputRank - 1));
|
||||||
Value falseValue = rewriter.create<ConstantBoolOp>(loc, false);
|
|
||||||
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
|
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
|
||||||
zero, end);
|
zero, end);
|
||||||
|
}
|
||||||
Value resultIndices =
|
Value resultIndices =
|
||||||
rewriter
|
rewriter
|
||||||
.create<DecompOpTy>(
|
.create<DecompOpTy>(
|
||||||
|
|
|
@ -545,10 +545,6 @@ FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
|
||||||
|
|
||||||
FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"AddFloatIntModule_basic",
|
"AddFloatIntModule_basic",
|
||||||
"ArgmaxIntModule_basic",
|
|
||||||
"ArgmaxIntModule_multiple_maxs",
|
|
||||||
"ArgmaxKeepdimModule_basic",
|
|
||||||
"ArgmaxModule_basic",
|
|
||||||
"AtenKthvalueDynamicDimsModule_basic",
|
"AtenKthvalueDynamicDimsModule_basic",
|
||||||
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
||||||
"AtenKthvalueFloat64Module_basic",
|
"AtenKthvalueFloat64Module_basic",
|
||||||
|
@ -618,9 +614,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"AnyBoolFalseModule_basic",
|
"AnyBoolFalseModule_basic",
|
||||||
"AnyBoolTrueModule_basic",
|
"AnyBoolTrueModule_basic",
|
||||||
"ArangeStartOutViewModule_basic",
|
"ArangeStartOutViewModule_basic",
|
||||||
"ArgminIntModule_basic",
|
|
||||||
"ArgminIntModule_multiple_mins",
|
|
||||||
"ArgminModule_basic",
|
|
||||||
"AtenComplexImagModule_basic",
|
"AtenComplexImagModule_basic",
|
||||||
"AtenComplexRealModule_basic",
|
"AtenComplexRealModule_basic",
|
||||||
"AtenComplexViewModule_basic",
|
"AtenComplexViewModule_basic",
|
||||||
|
|
|
@ -25,6 +25,19 @@ func.func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch
|
||||||
return %0 : !torch.tensor
|
return %0 : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: func.func @argmax_rank_1
|
||||||
|
// CHECK: %[[I0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[VALUES:.*]], %[[INDICES:.*]] = torch.aten.max.dim %arg0, %[[I0]], %[[FALSE]] : !torch.vtensor<[20],si32>, !torch.int, !torch.bool -> !torch.vtensor<[],si32>, !torch.vtensor<[],si64>
|
||||||
|
// CHECK: return %[[INDICES]] : !torch.vtensor<[],si64>
|
||||||
|
func.func @argmax_rank_1(%arg0: !torch.vtensor<[20],si32>) -> !torch.vtensor<[],si64> {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%false = torch.constant.bool false
|
||||||
|
%7 = torch.aten.argmax %arg0, %none, %false : !torch.vtensor<[20],si32>, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
|
||||||
|
return %7 : !torch.vtensor<[],si64>
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: func.func @torch.aten.type_as$basic(
|
// CHECK-LABEL: func.func @torch.aten.type_as$basic(
|
||||||
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
|
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
|
||||||
|
|
Loading…
Reference in New Issue