From 1d41f7b6fe2cd3fb0bb0d6574bb48637f9ae9c84 Mon Sep 17 00:00:00 2001
From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com>
Date: Tue, 31 Oct 2023 22:56:54 -0500
Subject: [PATCH] Rework AtenEmptyStridedOp checks (#2537)

Now uses Values instead of constant ints. Trades a compile-time failure
for a runtime assert when the strides cannot be checked statically.
---
 e2e_testing/xfail_sets.py                     |   2 +
 .../torch-mlir/Dialect/Torch/Utils/Utils.h    |   6 +
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 232 ++++++++----------
 lib/Dialect/Torch/Utils/Utils.cpp             |  58 +++++
 .../test_suite/constant_alloc.py              |  24 +-
 5 files changed, 194 insertions(+), 128 deletions(-)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 7ab664dd6..62b40a3b5 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -764,6 +764,7 @@ STABLEHLO_PASS_SET = {
     "NewEmptyModuleNonDefaultIntDtype_basic",
     "NewEmptyStridedModuleDefaultDtype_basic",
     "EmptyStridedModule_basic",
+    "EmptyStridedSizeIntStrideModule_basic",
     "PermuteModule_basic",
     "PermuteNegativeIndexModule_basic",
     "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
@@ -1440,6 +1441,7 @@ LTC_XFAIL_SET = {
     "UniformStaticShapeModule_basic",
     "AtenEmbeddingBagStaticModule_basic",
     "EmptyStridedModule_basic",
+    "EmptyStridedSizeIntStrideModule_basic",
     "ElementwiseBitwiseAndScalarInt64Module_basic",
     "ElementwiseBitwiseAndScalarInt32Module_basic",
     "ElementwiseBitwiseAndScalarInt8Module_basic",
diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index 9b715ec0d..3fa871b33 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -104,6 +104,12 @@ inline bool isAssumingStrictSymbolicShapes(OpBuilder &builder) {
  return isAssumingStrictSymbolicShapes(builder.getBlock());
 }
 
+// Helper function for AtenEmptyStrided and friends that checks whether the
+// stride values are the default (contiguous) ones. Strides that are statically
+// known to be non-default are rejected with a match failure; strides that are
+// not statically known are checked by a runtime assert emitted into the IR.
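+// For example, for size [2, 3, 4] the default strides are [12, 4, 1]: each
+// stride[i] is the product of all sizes after dimension i.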
+LogicalResult checkDefaultStrideHelper(Operation *op, PatternRewriter &rewriter,
+                                       Value opSize, Value opStride,
+                                       Location loc);
+
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 26a6bbc5f..985bfead5 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -332,7 +332,8 @@ public:
         rewriter.create<AtenAddIntOp>(loc, one.getType(), start, length);
 
     rewriter.replaceOpWithNewOp<AtenSliceTensorOp>(
-        op, op.getResult().getType(), op.getSelf(), /*dim=*/dim, /*start=*/start,
+        op, op.getResult().getType(), op.getSelf(), /*dim=*/dim,
+        /*start=*/start,
         /*end=*/startPlusLength, /*step=*/one);
 
     return success();
@@ -404,16 +405,15 @@ public:
 } // namespace
 
 namespace {
-class DecomposeAtenZeroOp
-    : public OpRewritePattern<AtenZeroOp> {
+class DecomposeAtenZeroOp : public OpRewritePattern<AtenZeroOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenZeroOp op,
                                 PatternRewriter &rewriter) const override {
     Value zero = rewriter.create<ConstantIntOp>(op.getLoc(),
                                                 rewriter.getI64IntegerAttr(0));
-    rewriter.replaceOpWithNewOp<AtenFillScalarOp>(op, op.getType(), op.getSelf(),
-                                                  zero);
+    rewriter.replaceOpWithNewOp<AtenFillScalarOp>(op, op.getType(),
+                                                  op.getSelf(), zero);
     return success();
   }
 };
@@ -1139,14 +1139,21 @@ public:
     Value constantOne =
         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
     Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
-    Value maxZeroX = rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
-    Value positiveOutput = rewriter.create<AtenMulScalarOp>(loc, resType, maxZeroX, scale);
-    Value minZeroX = rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
-    Value scaledMinZeroX = rewriter.create<AtenMulScalarOp>(loc, resType, minZeroX, inputScale);
+    Value maxZeroX =
+        rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
+    Value positiveOutput =
+        rewriter.create<AtenMulScalarOp>(loc, resType, maxZeroX, scale);
+    Value minZeroX =
+        rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
+    Value scaledMinZeroX =
+        rewriter.create<AtenMulScalarOp>(loc, resType, minZeroX, inputScale);
     Value expX = rewriter.create<AtenExpOp>(loc, resType, scaledMinZeroX);
-    Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX, constantOne, constantOne);
-    Value scaledExpXM1 = rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, scale);
-    Value negativeOutput = rewriter.create<AtenMulScalarOp>(loc, resType, scaledExpXM1, alpha);
+    Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX,
+                                                    constantOne, constantOne);
+    Value scaledExpXM1 =
+        rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, scale);
+    Value negativeOutput =
+        rewriter.create<AtenMulScalarOp>(loc, resType, scaledExpXM1, alpha);
     Value eluOutput = rewriter.create<AtenAddTensorOp>(
         loc, resType, positiveOutput, negativeOutput, constantOne);
@@ -1419,8 +1426,8 @@ public:
         rewriter.create<PrimListConstructOp>(loc, listType, expandedSizes);
     Value reshapedDims =
         rewriter.create<PrimListConstructOp>(loc, listType, reshapedSizes);
-    auto reshaped = rewriter.create<AtenViewOp>(loc, unsqueezedType, op.getSelf(),
-                                                unsqueezedDims);
+    auto reshaped = rewriter.create<AtenViewOp>(loc, unsqueezedType,
+                                                op.getSelf(), unsqueezedDims);
     auto expanded = rewriter.create<AtenBroadcastToOp>(loc, expandedType,
                                                        reshaped, expandedDims);
@@ -1601,8 +1608,8 @@ public:
       return rewriter.notifyMatchFailure(
           op, "unimplemented: requires implicit to be false");
     }
-    rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.getSelf(),
-                                                   op.getSize());
+    rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(),
+                                                   op.getSelf(), op.getSize());
     return success();
   }
 };
@@ -1621,7 +1628,8 @@ public:
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
     Value selfTensor =
         createRank0Tensor(rewriter, loc, resType, op.getSelf());
-    Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther());
+    Value otherTensor =
+        createRank0Tensor(rewriter, loc, resType, op.getOther());
     rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
                                                  selfTensor, otherTensor);
     return success();
@@ -1642,7 +1650,8 @@ public:
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
-    Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther());
+    Value otherTensor =
+        createRank0Tensor(rewriter, loc, resType, op.getOther());
     rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
                                                  op.getSelf(), otherTensor);
     return success();
@@ -1686,8 +1695,8 @@ public:
     }
     Value mask = op.getMask();
     Value value = createRank0Tensor(rewriter, loc, resType, op.getValue());
-    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, mask,
-                                                 value, op.getSelf());
+    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, mask, value,
+                                                 op.getSelf());
     return success();
   }
 };
@@ -1747,8 +1756,8 @@ public:
     Value cstTrue = rewriter.create<ConstantBoolOp>(op.getLoc(), true);
     rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
         op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
-        op.getStride(), op.getPadding(), op.getDilation(), /*transposed=*/cstTrue,
-        op.getOutputPadding(), op.getGroups());
+        op.getStride(), op.getPadding(), op.getDilation(),
+        /*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups());
     return success();
   }
 };
@@ -2500,9 +2509,9 @@ public:
           op, "aten.std.dim expects input tensor of floating-point type");
     }
 
-    Value varDim =
-        rewriter.create<AtenVarDimOp>(op->getLoc(), op.getType(), self,
-                                      op.getDim(), op.getUnbiased(), op.getKeepdim());
+    Value varDim = rewriter.create<AtenVarDimOp>(
+        op->getLoc(), op.getType(), self, op.getDim(), op.getUnbiased(),
+        op.getKeepdim());
     rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), varDim);
     return success();
   }
@@ -2626,8 +2635,8 @@ public:
     Value one =
         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
     Value emptyTensor = rewriter.create<AtenFullLikeOp>(
-        loc, resultType, input, zero, op.getDtype(), op.getLayout(), op.getDevice(),
-        op.getPinMemory(), op.getMemoryFormat());
+        loc, resultType, input, zero, op.getDtype(), op.getLayout(),
+        op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
     rewriter.replaceOpWithNewOp<AtenUniformOp>(op, resultType, emptyTensor,
                                                /*from=*/zero, /*to=*/one,
                                                /*generator=*/none);
@@ -2829,7 +2838,8 @@ class DecomposeAtenNativeLayerNormOp
     SmallVector<Value> normalizedShapeSizesTorchInt;
     getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
     int64_t axis = inputRank - normalizedShapeSizesTorchInt.size();
-    auto reduceDimInts = llvm::to_vector<4>(llvm::seq<int64_t>(axis, inputRank));
+    auto reduceDimInts =
+        llvm::to_vector<4>(llvm::seq<int64_t>(axis, inputRank));
     auto reducedTy = op.getResult(1).getType();
     auto sizeListType = ListType::get(IntType::get(context));
@@ -2905,8 +2915,8 @@ public:
     Value sizeList =
         rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getSelf());
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
-        op, op.getType(), sizeList, op.getDtype(), op.getLayout(), op.getDevice(),
-        op.getPinMemory(), op.getMemoryFormat());
+        op, op.getType(), sizeList, op.getDtype(), op.getLayout(),
+        op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
     return success();
   }
 };
@@ -2927,8 +2937,8 @@ class DecomposeAtenArangeOp : public OpRewritePattern<AtenArangeOp> {
     step = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
 
     rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
-        op, op.getType(), start, op.getEnd(), step, op.getDtype(), op.getLayout(),
-        op.getDevice(), op.getPinMemory());
+        op, op.getType(), start, op.getEnd(), step, op.getDtype(),
+        op.getLayout(), op.getDevice(), op.getPinMemory());
     return success();
   }
 };
@@ -2947,8 +2957,8 @@ class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
     step = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
 
     rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
-        op, op.getType(), op.getStart(), op.getEnd(), step, op.getDtype(), op.getLayout(),
-        op.getDevice(), op.getPinMemory());
+        op, op.getType(), op.getStart(), op.getEnd(), step, op.getDtype(),
+        op.getLayout(), op.getDevice(), op.getPinMemory());
     return success();
   }
 };
@@ -3035,7 +3045,8 @@ class DecomposeAtenNativeBatchNormOp
         loc, ListType::get(IntType::get(context)), runningStatsShape);
 
     SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
-    runningStatsShapeInt[1] = runningMean.getType().cast<BaseTensorType>().getSizes()[0];
+    runningStatsShapeInt[1] =
+        runningMean.getType().cast<BaseTensorType>().getSizes()[0];
     Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
     Type reshapeType = ValueTensorType::get(
         context, llvm::ArrayRef(runningStatsShapeInt), dtype);
@@ -3320,11 +3331,10 @@ public:
           getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
     }
     rewriter.replaceOpWithNewOp<AtenFullOp>(
-        op, op.getType(), op.getSize(), op.getFillValue(), dtype, op.getLayout(), op.getDevice(),
-        op.getPinMemory());
+        op, op.getType(), op.getSize(), op.getFillValue(), dtype,
+        op.getLayout(), op.getDevice(), op.getPinMemory());
 
     return success();
-
   }
 };
 } // namespace
@@ -3338,7 +3348,8 @@ public:
                                 PatternRewriter &rewriter) const override {
     Value cstFalse = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
     rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
-        op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), op.getAccumulate(),
+        op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
+        op.getAccumulate(),
         /*unsafe=*/cstFalse);
     return success();
   }
@@ -3355,8 +3366,8 @@ class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
         Torch::ListType::get(Torch::IntType::get(op.getContext()));
     Value sizeList =
         rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getOther());
-    rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.getSelf(),
-                                                   sizeList);
+    rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(),
+                                                   op.getSelf(), sizeList);
     return success();
   }
 };
@@ -3378,8 +3389,9 @@ public:
     Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0,
                                                    resultDtype);
     Value emptyTensor = rewriter.create<AtenFullLikeOp>(
-        op.getLoc(), op.getType(), op.getSelf(), zero, op.getDtype(), op.getLayout(),
-        op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
+        op.getLoc(), op.getType(), op.getSelf(), zero, op.getDtype(),
+        op.getLayout(), op.getDevice(), op.getPinMemory(),
+        op.getMemoryFormat());
     rewriter.replaceOpWithNewOp<AtenCopyOp>(op, op.getType(), emptyTensor,
                                             op.getSelf(), op.getNonBlocking());
     return success();
@@ -3450,7 +3462,8 @@ public:
                                 PatternRewriter &rewriter) const override {
     Value cstFalse = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
     rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
-        op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), op.getAccumulate(),
+        op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
+        op.getAccumulate(),
         /*unsafe=*/cstFalse);
     return success();
   }
@@ -3539,9 +3552,9 @@ public:
           op, "unimplemented: layout is expected to be strided");
     }
 
-    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
-                                               op.getDtype(), op.getNonBlocking(),
-                                               op.getCopy(), op.getMemoryFormat());
+    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
+        op, op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(),
+        op.getCopy(), op.getMemoryFormat());
     return success();
   }
 };
@@ -3557,9 +3570,9 @@ public:
 
     // Device information isn't relevant to torch-mlir, so we can drop that info
     // here.
-    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
-                                               op.getDtype(), op.getNonBlocking(),
-                                               op.getCopy(), op.getMemoryFormat());
+    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
+        op, op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(),
+        op.getCopy(), op.getMemoryFormat());
 
     return success();
   }
@@ -3798,8 +3811,8 @@ class DecomposeAtenBaddbmmOp : public OpRewritePattern<AtenBaddbmmOp> {
   LogicalResult matchAndRewrite(AtenBaddbmmOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    Value bmm =
-        rewriter.create<AtenBmmOp>(loc, op.getType(), op.getBatch1(), op.getBatch2());
+    Value bmm = rewriter.create<AtenBmmOp>(loc, op.getType(), op.getBatch1(),
+                                           op.getBatch2());
     Value alphaTimesBmm =
         rewriter.create<AtenMulScalarOp>(loc, op.getType(), bmm, op.getAlpha());
     Value input = op.getSelf();
@@ -4160,7 +4173,8 @@ public:
                             resultType.getOptionalDtype())
             .cast<BaseTensorType>();
 
-    Value sub = createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
+    Value sub =
+        createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
     Value result = rewriter.create<AtenSquareOp>(loc, subType, sub);
     if (reductionType == torch_upstream::Reduction::None) {
       rewriter.replaceOp(op, result);
@@ -4242,7 +4256,8 @@ public:
                              rewriter.getF32Type())
             .cast<BaseTensorType>();
     Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
-        loc, floatResultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(),
+        loc, floatResultType, op.getSize(), /*dtype=*/none,
+        /*layout=*/op.getLayout(),
         /*device=*/op.getDevice(), /*pinMemory=*/op.getPinMemory(),
         /*memoryFormat=*/none);
@@ -4272,11 +4287,11 @@ public:
 
     Value low = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(0));
-
+
     rewriter.replaceOpWithNewOp<AtenRandintLowOp>(
-        op, resultType, low, op.getHigh(), op.getSize(), op.getDtype(), op.getLayout(),
-        op.getDevice(), op.getPinMemory());
-
+        op, resultType, low, op.getHigh(), op.getSize(), op.getDtype(),
+        op.getLayout(), op.getDevice(), op.getPinMemory());
+
     return success();
   }
 };
@@ -4294,10 +4309,11 @@ public:
     Location loc = op.getLoc();
     Value noneVal = rewriter.create<ConstantNoneOp>(loc);
     Value var = rewriter.create<AtenVarCorrectionOp>(
-        loc, op.getType(0), op.getSelf(), op.getDim(), op.getCorrection(), op.getKeepdim());
-    Value mean =
-        rewriter.create<AtenMeanDimOp>(loc, op.getType(0), op.getSelf(), op.getDim(),
-                                       op.getKeepdim(), /*dtype=*/noneVal);
+        loc, op.getType(0), op.getSelf(), op.getDim(), op.getCorrection(),
+        op.getKeepdim());
+    Value mean = rewriter.create<AtenMeanDimOp>(
+        loc, op.getType(0), op.getSelf(), op.getDim(), op.getKeepdim(),
+        /*dtype=*/noneVal);
     rewriter.replaceOp(op, {var, mean});
     return success();
   }
@@ -4394,14 +4410,16 @@ public:
                            /*device=*/op.getDevice(),
                            /*pin_memory=*/op.getPinMemory(),
                            /*memory_format=*/none);
-    Value uOne = rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorA,
-                                                /*from=*/low,
-                                                /*to=*/high,
-                                                /*generator=*/op.getGenerator());
-    Value uTwo = rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorB,
-                                                /*from=*/low,
-                                                /*to=*/high,
-                                                /*generator=*/op.getGenerator());
+    Value uOne =
+        rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorA,
+                                       /*from=*/low,
+                                       /*to=*/high,
+                                       /*generator=*/op.getGenerator());
+    Value uTwo =
+        rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorB,
+                                       /*from=*/low,
+                                       /*to=*/high,
+                                       /*generator=*/op.getGenerator());
 
     Value logUOne = rewriter.create<AtenLogOp>(loc, resultType, uOne);
     Value minusTwoLogUOne =
@@ -4526,37 +4544,18 @@ public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenNewEmptyStridedOp op,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<int64_t> sizeListInts, strideListInts;
-    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all size list elements must be constant ints");
-    if (!matchPattern(op.getStride(),
-                      m_TorchListOfConstantInts(strideListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all stride list elements must be constant ints");
+    Location loc = op.getLoc();
+    Value opSize = op.getSize();
+    Value opStride = op.getStride();
 
-    // We only support the cases with default stride values.
-    // For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
-    // Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
-    // stride[2] == 1.
-    bool isDefaultStride = true;
-    for (unsigned i = 0; i < strideListInts.size(); i++) {
-      int64_t defaultStride = 1;
-      for (unsigned j = i + 1; j < sizeListInts.size(); j++)
-        defaultStride *= sizeListInts[j];
-      if (defaultStride != strideListInts[i]) {
-        isDefaultStride = false;
-        break;
-      }
-    }
-
-    if (!isDefaultStride)
+    if (failed(checkDefaultStrideHelper(op, rewriter, opSize, opStride, loc)))
       return rewriter.notifyMatchFailure(
-          op, "only default strides supported for new_empty_strided op");
+          op, "unable to determine if stride is default");
 
     rewriter.replaceOpWithNewOp<AtenNewEmptyOp>(
         op, op.getType(), op.getSelf(), op.getSize(), op.getDtype(),
         op.getLayout(), op.getDevice(), op.getPinMemory());
+
     return success();
   }
 };
 } // namespace
@@ -4569,42 +4568,20 @@ public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenEmptyStridedOp op,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<int64_t> sizeListInts, strideListInts;
-    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all size list elements must be constant ints");
-    if (!matchPattern(op.getStride(),
-                      m_TorchListOfConstantInts(strideListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all stride list elements must be constant ints");
+    Location loc = op.getLoc();
+    Value opSize = op.getSize();
+    Value opStride = op.getStride();
 
-    // We only support the cases with default stride values.
-    // For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
-    // Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
-    // stride[2] == 1.
-    bool isDefaultStride = true;
-    for (unsigned i = 0; i < strideListInts.size(); i++) {
-      int64_t defaultStride = 1;
-      for (unsigned j = i + 1; j < sizeListInts.size(); j++)
-        defaultStride *= sizeListInts[j];
-      if (defaultStride != strideListInts[i]) {
-        isDefaultStride = false;
-        break;
-      }
-    }
-    if (!isDefaultStride)
+    if (failed(checkDefaultStrideHelper(op, rewriter, opSize, opStride, loc)))
       return rewriter.notifyMatchFailure(
-          op, "only default strides supported for new_empty_strided op");
+          op, "unable to determine if stride is default");
 
     Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
-        op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(), op.getDevice(),
-        op.getPinMemory(), /*memoryFormat=*/noneVal);
-
+        op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(),
+        op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal);
     return success();
-
-
   }
 };
 } // namespace
@@ -4993,8 +4970,8 @@ public:
     auto selectGreater =
        rewriter.create<AtenWhereScalarOp>(loc, outType, greater, one, zero);
 
-    rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(op, outType, greaterEqual,
-                                                        selectGreater, minusOne);
+    rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(
+        op, outType, greaterEqual, selectGreater, minusOne);
     return success();
   }
 };
@@ -5407,7 +5384,8 @@ public:
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
-    addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal(
+        patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index ddc95bd4b..14a264ada 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -333,3 +333,61 @@ bool Torch::isAssumingStrictSymbolicShapes(Block *block) {
   }
   return false;
 }
+
+LogicalResult Torch::checkDefaultStrideHelper(Operation *op,
+                                              PatternRewriter &rewriter,
+                                              Value opSize, Value opStride,
+                                              Location loc) {
+
+  SmallVector<int64_t> sizeListInts, strideListInts;
+  if (matchPattern(opSize, m_TorchListOfConstantInts(sizeListInts)) &&
+      matchPattern(opStride, m_TorchListOfConstantInts(strideListInts))) {
+
+    // We only support the cases with default stride values.
+    // For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
+    // Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
+    // stride[2] == 1.
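+    // In general, a contiguous tensor has stride[i] equal to the product of
+    // all sizes after dimension i, with the innermost stride equal to 1; the
+    // loop below checks exactly that.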
+    bool isDefaultStride = true;
+    for (unsigned i = 0; i < strideListInts.size(); i++) {
+      int64_t defaultStride = 1;
+      for (unsigned j = i + 1; j < sizeListInts.size(); j++)
+        defaultStride *= sizeListInts[j];
+      if (defaultStride != strideListInts[i]) {
+        isDefaultStride = false;
+        break;
+      }
+    }
+    if (!isDefaultStride)
+      return rewriter.notifyMatchFailure(
+          op, "only default strides supported for empty_strided op");
+
+    return success();
+
+  } else {
+    SmallVector<Value> sizeListValues;
+    if (!getListConstructElements(opSize, sizeListValues))
+      return rewriter.notifyMatchFailure(op, "couldn't get size list values");
+    SmallVector<Value> strideListValues;
+    if (!getListConstructElements(opStride, strideListValues))
+      return rewriter.notifyMatchFailure(op,
+                                         "couldn't get stride list values");
+    SmallVector<Value> boolVector;
+    for (unsigned i = 0; i < strideListValues.size(); i++) {
+      Value defaultStride = rewriter.createOrFold<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(1));
+      for (unsigned j = i + 1; j < sizeListValues.size(); j++) {
+        defaultStride = rewriter.createOrFold<AtenMulIntOp>(
+            loc, defaultStride, sizeListValues[j]);
+      }
+      boolVector.push_back(rewriter.createOrFold<AtenEqIntOp>(
+          loc, defaultStride, strideListValues[i]));
+    }
+    Value allBoolOpList = rewriter.createOrFold<PrimListConstructOp>(
+        loc, Torch::ListType::get(rewriter.getType<Torch::BoolType>()),
+        boolVector);
+    Value cmp = rewriter.createOrFold<AtenAllBoolOp>(loc, allBoolOpList);
+    rewriter.createOrFold<Torch::RuntimeAssertOp>(
+        loc, cmp, "not all strides are default");
+    return success();
+  }
+}
diff --git a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
index 27cf2eb4a..675e32757 100644
--- a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
+++ b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
@@ -1629,7 +1629,6 @@ class NewEmptyStridedModuleDefaultDtype(torch.nn.Module):
 def NewEmptyStridedModuleDefaultDtype_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3, 4))
 
-
 # ==============================================================================
 
 
@@ -1651,4 +1650,27 @@ class EmptyStridedModule(torch.nn.Module):
 
 @register_test_case(module_factory=lambda: EmptyStridedModule())
 def EmptyStridedModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 3, 4))
+
+# ==============================================================================
+
+
+class EmptyStridedSizeIntStrideModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        x = torch.ops.aten.empty_strided(a.size(), stride=[12, a.size(2), 1])
+        y = x.copy_(a)
+        return y
+
+
+@register_test_case(module_factory=lambda: EmptyStridedSizeIntStrideModule())
+def EmptyStridedSizeIntStrideModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3, 4))