Rework AtenEmptyStridedOp checks (#2537)

Now using Values instead of constant ints. Trades a compile-time failure
for a runtime assert.
pull/2538/head snapshot-20231101.1009
Daniel Garvey 2023-10-31 22:56:54 -05:00 committed by GitHub
parent 4199feffed
commit 1d41f7b6fe
5 changed files with 194 additions and 128 deletions

View File

@ -764,6 +764,7 @@ STABLEHLO_PASS_SET = {
"NewEmptyModuleNonDefaultIntDtype_basic",
"NewEmptyStridedModuleDefaultDtype_basic",
"EmptyStridedModule_basic",
"EmptyStridedSizeIntStrideModule_basic",
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
@ -1440,6 +1441,7 @@ LTC_XFAIL_SET = {
"UniformStaticShapeModule_basic",
"AtenEmbeddingBagStaticModule_basic",
"EmptyStridedModule_basic",
"EmptyStridedSizeIntStrideModule_basic",
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",

View File

@ -104,6 +104,12 @@ inline bool isAssumingStrictSymbolicShapes(OpBuilder &builder) {
return isAssumingStrictSymbolicShapes(builder.getBlock());
}
// Helper function for AtenEmptyStrided and friends that checks whether the
// stride values are default (contiguous). Emits a runtime assert when the
// check cannot be decided statically.
LogicalResult checkDefaultStrideHelper(Operation *op, PatternRewriter &rewriter,
Value opSize, Value opStride,
Location loc);
} // namespace Torch
} // namespace torch
} // namespace mlir
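
The "default" strides the helper checks for are the contiguous (row-major) strides implied by the size list. As a minimal illustration in plain Python (not part of the patch), the rule looks like this:

def default_strides(sizes):
    # stride[i] is the product of all sizes to the right of dimension i.
    strides, running = [], 1
    for size in reversed(sizes):
        strides.append(running)
        running *= size
    return list(reversed(strides))

assert default_strides([2, 3, 4]) == [12, 4, 1]  # the example from the patch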

View File

@ -332,7 +332,8 @@ public:
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, length);
rewriter.replaceOpWithNewOp<AtenSliceTensorOp>(
op, op.getResult().getType(), op.getSelf(), /*dim=*/dim, /*start=*/start,
op, op.getResult().getType(), op.getSelf(), /*dim=*/dim,
/*start=*/start,
/*end=*/startPlusLength, /*step=*/one);
return success();
@ -404,16 +405,15 @@ public:
} // namespace
namespace {
class DecomposeAtenZeroOp
: public OpRewritePattern<AtenZeroOp> {
class DecomposeAtenZeroOp : public OpRewritePattern<AtenZeroOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenZeroOp op,
PatternRewriter &rewriter) const override {
Value zero = rewriter.create<ConstantIntOp>(op.getLoc(),
rewriter.getI64IntegerAttr(0));
rewriter.replaceOpWithNewOp<AtenFillScalarOp>(op, op.getType(), op.getSelf(),
zero);
rewriter.replaceOpWithNewOp<AtenFillScalarOp>(op, op.getType(),
op.getSelf(), zero);
return success();
}
};
@ -1139,14 +1139,21 @@ public:
Value constantOne =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
Value maxZeroX = rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
Value positiveOutput = rewriter.create<AtenMulScalarOp>(loc, resType, maxZeroX, scale);
Value minZeroX = rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
Value scaledMinZeroX = rewriter.create<AtenMulScalarOp>(loc, resType, minZeroX, inputScale);
Value maxZeroX =
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
Value positiveOutput =
rewriter.create<AtenMulScalarOp>(loc, resType, maxZeroX, scale);
Value minZeroX =
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
Value scaledMinZeroX =
rewriter.create<AtenMulScalarOp>(loc, resType, minZeroX, inputScale);
Value expX = rewriter.create<AtenExpOp>(loc, resType, scaledMinZeroX);
Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX, constantOne, constantOne);
Value scaledExpXM1 = rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, scale);
Value negativeOutput = rewriter.create<AtenMulScalarOp>(loc, resType, scaledExpXM1, alpha);
Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX,
constantOne, constantOne);
Value scaledExpXM1 =
rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, scale);
Value negativeOutput =
rewriter.create<AtenMulScalarOp>(loc, resType, scaledExpXM1, alpha);
Value eluOutput = rewriter.create<AtenAddTensorOp>(
loc, resType, positiveOutput, negativeOutput, constantOne);
@ -1419,8 +1426,8 @@ public:
rewriter.create<PrimListConstructOp>(loc, listType, expandedSizes);
Value reshapedDims =
rewriter.create<PrimListConstructOp>(loc, listType, reshapedSizes);
auto reshaped = rewriter.create<AtenViewOp>(loc, unsqueezedType, op.getSelf(),
unsqueezedDims);
auto reshaped = rewriter.create<AtenViewOp>(loc, unsqueezedType,
op.getSelf(), unsqueezedDims);
auto expanded = rewriter.create<AtenBroadcastToOp>(loc, expandedType,
reshaped, expandedDims);
@ -1601,8 +1608,8 @@ public:
return rewriter.notifyMatchFailure(
op, "unimplemented: requires implicit to be false");
}
rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.getSelf(),
op.getSize());
rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(),
op.getSelf(), op.getSize());
return success();
}
};
@ -1621,7 +1628,8 @@ public:
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf());
Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther());
Value otherTensor =
createRank0Tensor(rewriter, loc, resType, op.getOther());
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
selfTensor, otherTensor);
return success();
@ -1642,7 +1650,8 @@ public:
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther());
Value otherTensor =
createRank0Tensor(rewriter, loc, resType, op.getOther());
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
op.getSelf(), otherTensor);
return success();
@ -1686,8 +1695,8 @@ public:
}
Value mask = op.getMask();
Value value = createRank0Tensor(rewriter, loc, resType, op.getValue());
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, mask,
value, op.getSelf());
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, mask, value,
op.getSelf());
return success();
}
};
@ -1747,8 +1756,8 @@ public:
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true);
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
op.getStride(), op.getPadding(), op.getDilation(), /*transposed=*/cstTrue,
op.getOutputPadding(), op.getGroups());
op.getStride(), op.getPadding(), op.getDilation(),
/*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups());
return success();
}
};
@ -2500,9 +2509,9 @@ public:
op, "aten.std.dim expects input tensor of floating-point type");
}
Value varDim =
rewriter.create<AtenVarDimOp>(op->getLoc(), op.getType(), self,
op.getDim(), op.getUnbiased(), op.getKeepdim());
Value varDim = rewriter.create<AtenVarDimOp>(
op->getLoc(), op.getType(), self, op.getDim(), op.getUnbiased(),
op.getKeepdim());
rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), varDim);
return success();
}
@ -2626,8 +2635,8 @@ public:
Value one =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
loc, resultType, input, zero, op.getDtype(), op.getLayout(), op.getDevice(),
op.getPinMemory(), op.getMemoryFormat());
loc, resultType, input, zero, op.getDtype(), op.getLayout(),
op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
rewriter.replaceOpWithNewOp<AtenUniformOp>(op, resultType, emptyTensor,
/*from=*/zero, /*to=*/one,
/*generator=*/none);
@ -2829,7 +2838,8 @@ class DecomposeAtenNativeLayerNormOp
SmallVector<Value> normalizedShapeSizesTorchInt;
getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
int64_t axis = inputRank - normalizedShapeSizesTorchInt.size();
auto reduceDimInts = llvm::to_vector<4>(llvm::seq<int64_t>(axis, inputRank));
auto reduceDimInts =
llvm::to_vector<4>(llvm::seq<int64_t>(axis, inputRank));
auto reducedTy = op.getResult(1).getType();
auto sizeListType = ListType::get(IntType::get(context));
@ -2905,8 +2915,8 @@ public:
Value sizeList =
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getSelf());
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
op, op.getType(), sizeList, op.getDtype(), op.getLayout(), op.getDevice(),
op.getPinMemory(), op.getMemoryFormat());
op, op.getType(), sizeList, op.getDtype(), op.getLayout(),
op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
return success();
}
};
@ -2927,8 +2937,8 @@ class DecomposeAtenArangeOp : public OpRewritePattern<AtenArangeOp> {
step = rewriter.create<Torch::ConstantIntOp>(loc,
rewriter.getI64IntegerAttr(1));
rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
op, op.getType(), start, op.getEnd(), step, op.getDtype(), op.getLayout(),
op.getDevice(), op.getPinMemory());
op, op.getType(), start, op.getEnd(), step, op.getDtype(),
op.getLayout(), op.getDevice(), op.getPinMemory());
return success();
}
};
@ -2947,8 +2957,8 @@ class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
step = rewriter.create<Torch::ConstantIntOp>(loc,
rewriter.getI64IntegerAttr(1));
rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
op, op.getType(), op.getStart(), op.getEnd(), step, op.getDtype(), op.getLayout(),
op.getDevice(), op.getPinMemory());
op, op.getType(), op.getStart(), op.getEnd(), step, op.getDtype(),
op.getLayout(), op.getDevice(), op.getPinMemory());
return success();
}
};
@ -3035,7 +3045,8 @@ class DecomposeAtenNativeBatchNormOp
loc, ListType::get(IntType::get(context)), runningStatsShape);
SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
runningStatsShapeInt[1] = runningMean.getType().cast<BaseTensorType>().getSizes()[0];
runningStatsShapeInt[1] =
runningMean.getType().cast<BaseTensorType>().getSizes()[0];
Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
Type reshapeType = ValueTensorType::get(
context, llvm::ArrayRef(runningStatsShapeInt), dtype);
@ -3320,11 +3331,10 @@ public:
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
}
rewriter.replaceOpWithNewOp<AtenFullOp>(
op, op.getType(), op.getSize(), op.getFillValue(), dtype, op.getLayout(), op.getDevice(),
op.getPinMemory());
op, op.getType(), op.getSize(), op.getFillValue(), dtype,
op.getLayout(), op.getDevice(), op.getPinMemory());
return success();
}
};
} // namespace
@ -3338,7 +3348,8 @@ public:
PatternRewriter &rewriter) const override {
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), op.getAccumulate(),
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
op.getAccumulate(),
/*unsafe=*/cstFalse);
return success();
}
@ -3355,8 +3366,8 @@ class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
Torch::ListType::get(Torch::IntType::get(op.getContext()));
Value sizeList =
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getOther());
rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.getSelf(),
sizeList);
rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(),
op.getSelf(), sizeList);
return success();
}
};
@ -3378,8 +3389,9 @@ public:
Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0,
resultDtype);
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
op.getLoc(), op.getType(), op.getSelf(), zero, op.getDtype(), op.getLayout(),
op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
op.getLoc(), op.getType(), op.getSelf(), zero, op.getDtype(),
op.getLayout(), op.getDevice(), op.getPinMemory(),
op.getMemoryFormat());
rewriter.replaceOpWithNewOp<AtenCopyOp>(op, op.getType(), emptyTensor,
op.getSelf(), op.getNonBlocking());
return success();
@ -3450,7 +3462,8 @@ public:
PatternRewriter &rewriter) const override {
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), op.getAccumulate(),
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
op.getAccumulate(),
/*unsafe=*/cstFalse);
return success();
}
@ -3539,9 +3552,9 @@ public:
op, "unimplemented: layout is expected to be strided");
}
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
op.getDtype(), op.getNonBlocking(),
op.getCopy(), op.getMemoryFormat());
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
op, op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(),
op.getCopy(), op.getMemoryFormat());
return success();
}
};
@ -3557,9 +3570,9 @@ public:
// Device information isn't relevant to torch-mlir, so we can drop that info
// here.
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
op.getDtype(), op.getNonBlocking(),
op.getCopy(), op.getMemoryFormat());
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
op, op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(),
op.getCopy(), op.getMemoryFormat());
return success();
}
@ -3798,8 +3811,8 @@ class DecomposeAtenBaddbmmOp : public OpRewritePattern<AtenBaddbmmOp> {
LogicalResult matchAndRewrite(AtenBaddbmmOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value bmm =
rewriter.create<AtenBmmOp>(loc, op.getType(), op.getBatch1(), op.getBatch2());
Value bmm = rewriter.create<AtenBmmOp>(loc, op.getType(), op.getBatch1(),
op.getBatch2());
Value alphaTimesBmm =
rewriter.create<AtenMulScalarOp>(loc, op.getType(), bmm, op.getAlpha());
Value input = op.getSelf();
@ -4160,7 +4173,8 @@ public:
resultType.getOptionalDtype())
.cast<BaseTensorType>();
Value sub = createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
Value sub =
createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
Value result = rewriter.create<AtenSquareOp>(loc, subType, sub);
if (reductionType == torch_upstream::Reduction::None) {
rewriter.replaceOp(op, result);
@ -4242,7 +4256,8 @@ public:
rewriter.getF32Type())
.cast<BaseTensorType>();
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, floatResultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(),
loc, floatResultType, op.getSize(), /*dtype=*/none,
/*layout=*/op.getLayout(),
/*device=*/op.getDevice(), /*pinMemory=*/op.getPinMemory(),
/*memoryFormat=*/none);
@ -4272,11 +4287,11 @@ public:
Value low = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
rewriter.replaceOpWithNewOp<AtenRandintLowOp>(
op, resultType, low, op.getHigh(), op.getSize(), op.getDtype(), op.getLayout(),
op.getDevice(), op.getPinMemory());
op, resultType, low, op.getHigh(), op.getSize(), op.getDtype(),
op.getLayout(), op.getDevice(), op.getPinMemory());
return success();
}
};
@ -4294,10 +4309,11 @@ public:
Location loc = op.getLoc();
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
Value var = rewriter.create<AtenVarCorrectionOp>(
loc, op.getType(0), op.getSelf(), op.getDim(), op.getCorrection(), op.getKeepdim());
Value mean =
rewriter.create<AtenMeanDimOp>(loc, op.getType(0), op.getSelf(), op.getDim(),
op.getKeepdim(), /*dtype=*/noneVal);
loc, op.getType(0), op.getSelf(), op.getDim(), op.getCorrection(),
op.getKeepdim());
Value mean = rewriter.create<AtenMeanDimOp>(
loc, op.getType(0), op.getSelf(), op.getDim(), op.getKeepdim(),
/*dtype=*/noneVal);
rewriter.replaceOp(op, {var, mean});
return success();
}
@ -4394,14 +4410,16 @@ public:
/*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
/*memory_format=*/none);
Value uOne = rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorA,
/*from=*/low,
/*to=*/high,
/*generator=*/op.getGenerator());
Value uTwo = rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorB,
/*from=*/low,
/*to=*/high,
/*generator=*/op.getGenerator());
Value uOne =
rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorA,
/*from=*/low,
/*to=*/high,
/*generator=*/op.getGenerator());
Value uTwo =
rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorB,
/*from=*/low,
/*to=*/high,
/*generator=*/op.getGenerator());
Value logUOne = rewriter.create<AtenLogOp>(loc, resultType, uOne);
Value minusTwoLogUOne =
@ -4526,37 +4544,18 @@ public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNewEmptyStridedOp op,
PatternRewriter &rewriter) const override {
SmallVector<int64_t> sizeListInts, strideListInts;
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
return rewriter.notifyMatchFailure(
op, "all size list elements must be constant ints");
if (!matchPattern(op.getStride(),
m_TorchListOfConstantInts(strideListInts)))
return rewriter.notifyMatchFailure(
op, "all stride list elements must be constant ints");
Location loc = op.getLoc();
Value opSize = op.getSize();
Value opStride = op.getStride();
// We only support the cases with default stride values.
// For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
// Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
// stride[2] == 1.
bool isDefaultStride = true;
for (unsigned i = 0; i < strideListInts.size(); i++) {
int64_t defaultStride = 1;
for (unsigned j = i + 1; j < sizeListInts.size(); j++)
defaultStride *= sizeListInts[j];
if (defaultStride != strideListInts[i]) {
isDefaultStride = false;
break;
}
}
if (!isDefaultStride)
if (failed(checkDefaultStrideHelper(op, rewriter, opSize, opStride, loc)))
return rewriter.notifyMatchFailure(
op, "only default strides supported for new_empty_strided op");
op, "Unable to determine if stride is default");
rewriter.replaceOpWithNewOp<AtenNewEmptyOp>(
op, op.getType(), op.getSelf(), op.getSize(), op.getDtype(),
op.getLayout(), op.getDevice(), op.getPinMemory());
return success();
}
};
@ -4569,42 +4568,20 @@ public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenEmptyStridedOp op,
PatternRewriter &rewriter) const override {
SmallVector<int64_t> sizeListInts, strideListInts;
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
return rewriter.notifyMatchFailure(
op, "all size list elements must be constant ints");
if (!matchPattern(op.getStride(),
m_TorchListOfConstantInts(strideListInts)))
return rewriter.notifyMatchFailure(
op, "all stride list elements must be constant ints");
Location loc = op.getLoc();
Value opSize = op.getSize();
Value opStride = op.getStride();
// We only support the cases with default stride values.
// For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
// Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
// stride[2] == 1.
bool isDefaultStride = true;
for (unsigned i = 0; i < strideListInts.size(); i++) {
int64_t defaultStride = 1;
for (unsigned j = i + 1; j < sizeListInts.size(); j++)
defaultStride *= sizeListInts[j];
if (defaultStride != strideListInts[i]) {
isDefaultStride = false;
break;
}
}
if (!isDefaultStride)
if (failed(checkDefaultStrideHelper(op, rewriter, opSize, opStride, loc)))
return rewriter.notifyMatchFailure(
op, "only default strides supported for new_empty_strided op");
op, "Unable to determine if stride is default");
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(), op.getDevice(),
op.getPinMemory(), /*memoryFormat=*/noneVal);
op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(),
op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal);
return success();
}
};
} // namespace
@ -4993,8 +4970,8 @@ public:
auto selectGreater =
rewriter.create<AtenWhereScalarOp>(loc, outType, greater, one, zero);
rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(op, outType, greaterEqual,
selectGreater, minusOne);
rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(
op, outType, greaterEqual, selectGreater, minusOne);
return success();
}
};
@ -5407,7 +5384,8 @@ public:
addPatternIfTargetOpIsIllegal<DeomposeAtenNativeDropoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutHackedTwinOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeIndexPutHackedTwinOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeIndexPutHackedTwinOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);

View File

@ -333,3 +333,61 @@ bool Torch::isAssumingStrictSymbolicShapes(Block *block) {
}
return false;
}
LogicalResult Torch::checkDefaultStrideHelper(Operation *op,
PatternRewriter &rewriter,
Value opSize, Value opStride,
Location loc) {
SmallVector<int64_t> sizeListInts, strideListInts;
if (matchPattern(opSize, m_TorchListOfConstantInts(sizeListInts)) &&
matchPattern(opStride, m_TorchListOfConstantInts(strideListInts))) {
// We only support the cases with default stride values.
// For example: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
// Here stride[0] == size[1] * size[2], stride[1] == size[2], and
// stride[2] == 1.
bool isDefaultStride = true;
for (unsigned i = 0; i < strideListInts.size(); i++) {
int64_t defaultStride = 1;
for (unsigned j = i + 1; j < sizeListInts.size(); j++)
defaultStride *= sizeListInts[j];
if (defaultStride != strideListInts[i]) {
isDefaultStride = false;
break;
}
}
if (!isDefaultStride)
return rewriter.notifyMatchFailure(
op, "only default strides supported for empty_strided op");
return success();
} else {
SmallVector<Value> sizeListValues;
if (!getListConstructElements(opSize, sizeListValues))
return rewriter.notifyMatchFailure(op, "couldn't get size list values");
SmallVector<Value> strideListValues;
if (!getListConstructElements(opStride, strideListValues))
return rewriter.notifyMatchFailure(op,
"couldn't get stride list values");
SmallVector<Value> boolVector;
for (unsigned i = 0; i < strideListValues.size(); i++) {
Value defaultStride = rewriter.createOrFold<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
for (unsigned j = i + 1; j < sizeListValues.size(); j++) {
defaultStride = rewriter.createOrFold<Torch::AtenMulIntOp>(
loc, defaultStride, sizeListValues[j]);
}
boolVector.push_back(rewriter.createOrFold<Torch::AtenEqIntOp>(
loc, defaultStride, strideListValues[i]));
}
Value allBoolOpList = rewriter.createOrFold<PrimListConstructOp>(
loc, Torch::ListType::get(rewriter.getType<Torch::BoolType>()),
boolVector);
Value cmp = rewriter.createOrFold<Torch::AtenAllBoolOp>(loc, allBoolOpList);
rewriter.createOrFold<Torch::RuntimeAssertOp>(
loc, cmp, "not all strides are default");
return success();
}
}
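
The per-dimension comparison is the same in both paths; only where it runs differs. As a rough Python model (illustrative only, not the pass API): when every size and stride folds to a constant int, the comparison is decided at pattern-match time; otherwise the identical product comparison is built from AtenMulIntOp/AtenEqIntOp/AtenAllBoolOp ops and checked by the emitted RuntimeAssertOp.

def strides_are_default(sizes, strides):
    # Mirrors the nested loop above: the default stride for dimension i is
    # the running product of the trailing sizes.
    checks = []
    for i in range(len(strides)):
        default_stride = 1
        for j in range(i + 1, len(sizes)):
            default_stride *= sizes[j]
        checks.append(default_stride == strides[i])
    # In the dynamic path, RuntimeAssertOp fires when this is False.
    return all(checks)

assert strides_are_default([2, 3, 4], [12, 4, 1])
assert not strides_are_default([2, 3, 4], [1, 1, 1])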

View File

@ -1629,7 +1629,6 @@ class NewEmptyStridedModuleDefaultDtype(torch.nn.Module):
def NewEmptyStridedModuleDefaultDtype_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
# ==============================================================================
@ -1651,4 +1650,27 @@ class EmptyStridedModule(torch.nn.Module):
@register_test_case(module_factory=lambda: EmptyStridedModule())
def EmptyStridedModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 3, 4))
# ==============================================================================
class EmptyStridedSizeIntStrideModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 3, -1], torch.float32, True),
])
def forward(self, a):
x = torch.ops.aten.empty_strided(a.size(), stride=[12, a.size(2), 1])
y = x.copy_(a)
return y
@register_test_case(module_factory=lambda: EmptyStridedSizeIntStrideModule())
def EmptyStridedSizeIntStrideModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
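
For reference, a quick eager-mode PyTorch check (an illustration, not part of the test suite) of why stride=[12, a.size(2), 1] is the default stride for the (2, 3, 4) input this test feeds in:

import torch

a = torch.rand(2, 3, 4)
x = torch.ops.aten.empty_strided(a.size(), stride=[12, a.size(2), 1])
print(x.stride())               # (12, 4, 1)
print(a.contiguous().stride())  # (12, 4, 1) -- matches, i.e. default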