Rework AtenEmptyStridedOp checks (#2537)

Now using Value instead of Ints. Trades compile failure for a runtime
assert
pull/2538/head snapshot-20231101.1009
Daniel Garvey 2023-10-31 22:56:54 -05:00 committed by GitHub
parent 4199feffed
commit 1d41f7b6fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 194 additions and 128 deletions

View File

@ -764,6 +764,7 @@ STABLEHLO_PASS_SET = {
"NewEmptyModuleNonDefaultIntDtype_basic", "NewEmptyModuleNonDefaultIntDtype_basic",
"NewEmptyStridedModuleDefaultDtype_basic", "NewEmptyStridedModuleDefaultDtype_basic",
"EmptyStridedModule_basic", "EmptyStridedModule_basic",
"EmptyStridedSizeIntStrideModule_basic",
"PermuteModule_basic", "PermuteModule_basic",
"PermuteNegativeIndexModule_basic", "PermuteNegativeIndexModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
@ -1440,6 +1441,7 @@ LTC_XFAIL_SET = {
"UniformStaticShapeModule_basic", "UniformStaticShapeModule_basic",
"AtenEmbeddingBagStaticModule_basic", "AtenEmbeddingBagStaticModule_basic",
"EmptyStridedModule_basic", "EmptyStridedModule_basic",
"EmptyStridedSizeIntStrideModule_basic",
"ElementwiseBitwiseAndScalarInt64Module_basic", "ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic", "ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic",

View File

@ -104,6 +104,12 @@ inline bool isAssumingStrictSymbolicShapes(OpBuilder &builder) {
return isAssumingStrictSymbolicShapes(builder.getBlock()); return isAssumingStrictSymbolicShapes(builder.getBlock());
} }
// Helper function for AtenEmptyStrided and friends that checks if the stride
// values are default or not. Throws a runtime assert if not.
LogicalResult checkDefaultStrideHelper(Operation *op, PatternRewriter &rewriter,
Value opSize, Value opStride,
Location loc);
} // namespace Torch } // namespace Torch
} // namespace torch } // namespace torch
} // namespace mlir } // namespace mlir

View File

@ -332,7 +332,8 @@ public:
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, length); rewriter.create<AtenAddIntOp>(loc, one.getType(), start, length);
rewriter.replaceOpWithNewOp<AtenSliceTensorOp>( rewriter.replaceOpWithNewOp<AtenSliceTensorOp>(
op, op.getResult().getType(), op.getSelf(), /*dim=*/dim, /*start=*/start, op, op.getResult().getType(), op.getSelf(), /*dim=*/dim,
/*start=*/start,
/*end=*/startPlusLength, /*step=*/one); /*end=*/startPlusLength, /*step=*/one);
return success(); return success();
@ -404,16 +405,15 @@ public:
} // namespace } // namespace
namespace { namespace {
class DecomposeAtenZeroOp class DecomposeAtenZeroOp : public OpRewritePattern<AtenZeroOp> {
: public OpRewritePattern<AtenZeroOp> {
public: public:
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenZeroOp op, LogicalResult matchAndRewrite(AtenZeroOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Value zero = rewriter.create<ConstantIntOp>(op.getLoc(), Value zero = rewriter.create<ConstantIntOp>(op.getLoc(),
rewriter.getI64IntegerAttr(0)); rewriter.getI64IntegerAttr(0));
rewriter.replaceOpWithNewOp<AtenFillScalarOp>(op, op.getType(), op.getSelf(), rewriter.replaceOpWithNewOp<AtenFillScalarOp>(op, op.getType(),
zero); op.getSelf(), zero);
return success(); return success();
} }
}; };
@ -1139,14 +1139,21 @@ public:
Value constantOne = Value constantOne =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0)); rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero); Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
Value maxZeroX = rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input); Value maxZeroX =
Value positiveOutput = rewriter.create<AtenMulScalarOp>(loc, resType, maxZeroX, scale); rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
Value minZeroX = rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input); Value positiveOutput =
Value scaledMinZeroX = rewriter.create<AtenMulScalarOp>(loc, resType, minZeroX, inputScale); rewriter.create<AtenMulScalarOp>(loc, resType, maxZeroX, scale);
Value minZeroX =
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
Value scaledMinZeroX =
rewriter.create<AtenMulScalarOp>(loc, resType, minZeroX, inputScale);
Value expX = rewriter.create<AtenExpOp>(loc, resType, scaledMinZeroX); Value expX = rewriter.create<AtenExpOp>(loc, resType, scaledMinZeroX);
Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX, constantOne, constantOne); Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX,
Value scaledExpXM1 = rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, scale); constantOne, constantOne);
Value negativeOutput = rewriter.create<AtenMulScalarOp>(loc, resType, scaledExpXM1, alpha); Value scaledExpXM1 =
rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, scale);
Value negativeOutput =
rewriter.create<AtenMulScalarOp>(loc, resType, scaledExpXM1, alpha);
Value eluOutput = rewriter.create<AtenAddTensorOp>( Value eluOutput = rewriter.create<AtenAddTensorOp>(
loc, resType, positiveOutput, negativeOutput, constantOne); loc, resType, positiveOutput, negativeOutput, constantOne);
@ -1419,8 +1426,8 @@ public:
rewriter.create<PrimListConstructOp>(loc, listType, expandedSizes); rewriter.create<PrimListConstructOp>(loc, listType, expandedSizes);
Value reshapedDims = Value reshapedDims =
rewriter.create<PrimListConstructOp>(loc, listType, reshapedSizes); rewriter.create<PrimListConstructOp>(loc, listType, reshapedSizes);
auto reshaped = rewriter.create<AtenViewOp>(loc, unsqueezedType, op.getSelf(), auto reshaped = rewriter.create<AtenViewOp>(loc, unsqueezedType,
unsqueezedDims); op.getSelf(), unsqueezedDims);
auto expanded = rewriter.create<AtenBroadcastToOp>(loc, expandedType, auto expanded = rewriter.create<AtenBroadcastToOp>(loc, expandedType,
reshaped, expandedDims); reshaped, expandedDims);
@ -1601,8 +1608,8 @@ public:
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "unimplemented: requires implicit to be false"); op, "unimplemented: requires implicit to be false");
} }
rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.getSelf(), rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(),
op.getSize()); op.getSelf(), op.getSize());
return success(); return success();
} }
}; };
@ -1621,7 +1628,8 @@ public:
return rewriter.notifyMatchFailure(op, "result should have dtype"); return rewriter.notifyMatchFailure(op, "result should have dtype");
} }
Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf()); Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf());
Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther()); Value otherTensor =
createRank0Tensor(rewriter, loc, resType, op.getOther());
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(), rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
selfTensor, otherTensor); selfTensor, otherTensor);
return success(); return success();
@ -1642,7 +1650,8 @@ public:
if (!resType.hasDtype()) { if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype"); return rewriter.notifyMatchFailure(op, "result should have dtype");
} }
Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther()); Value otherTensor =
createRank0Tensor(rewriter, loc, resType, op.getOther());
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(), rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
op.getSelf(), otherTensor); op.getSelf(), otherTensor);
return success(); return success();
@ -1686,8 +1695,8 @@ public:
} }
Value mask = op.getMask(); Value mask = op.getMask();
Value value = createRank0Tensor(rewriter, loc, resType, op.getValue()); Value value = createRank0Tensor(rewriter, loc, resType, op.getValue());
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, mask, rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, mask, value,
value, op.getSelf()); op.getSelf());
return success(); return success();
} }
}; };
@ -1747,8 +1756,8 @@ public:
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true); Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true);
rewriter.replaceOpWithNewOp<AtenConvolutionOp>( rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
op.getStride(), op.getPadding(), op.getDilation(), /*transposed=*/cstTrue, op.getStride(), op.getPadding(), op.getDilation(),
op.getOutputPadding(), op.getGroups()); /*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups());
return success(); return success();
} }
}; };
@ -2500,9 +2509,9 @@ public:
op, "aten.std.dim expects input tensor of floating-point type"); op, "aten.std.dim expects input tensor of floating-point type");
} }
Value varDim = Value varDim = rewriter.create<AtenVarDimOp>(
rewriter.create<AtenVarDimOp>(op->getLoc(), op.getType(), self, op->getLoc(), op.getType(), self, op.getDim(), op.getUnbiased(),
op.getDim(), op.getUnbiased(), op.getKeepdim()); op.getKeepdim());
rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), varDim); rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), varDim);
return success(); return success();
} }
@ -2626,8 +2635,8 @@ public:
Value one = Value one =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0)); rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
Value emptyTensor = rewriter.create<AtenFullLikeOp>( Value emptyTensor = rewriter.create<AtenFullLikeOp>(
loc, resultType, input, zero, op.getDtype(), op.getLayout(), op.getDevice(), loc, resultType, input, zero, op.getDtype(), op.getLayout(),
op.getPinMemory(), op.getMemoryFormat()); op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
rewriter.replaceOpWithNewOp<AtenUniformOp>(op, resultType, emptyTensor, rewriter.replaceOpWithNewOp<AtenUniformOp>(op, resultType, emptyTensor,
/*from=*/zero, /*to=*/one, /*from=*/zero, /*to=*/one,
/*generator=*/none); /*generator=*/none);
@ -2829,7 +2838,8 @@ class DecomposeAtenNativeLayerNormOp
SmallVector<Value> normalizedShapeSizesTorchInt; SmallVector<Value> normalizedShapeSizesTorchInt;
getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt); getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
int64_t axis = inputRank - normalizedShapeSizesTorchInt.size(); int64_t axis = inputRank - normalizedShapeSizesTorchInt.size();
auto reduceDimInts = llvm::to_vector<4>(llvm::seq<int64_t>(axis, inputRank)); auto reduceDimInts =
llvm::to_vector<4>(llvm::seq<int64_t>(axis, inputRank));
auto reducedTy = op.getResult(1).getType(); auto reducedTy = op.getResult(1).getType();
auto sizeListType = ListType::get(IntType::get(context)); auto sizeListType = ListType::get(IntType::get(context));
@ -2905,8 +2915,8 @@ public:
Value sizeList = Value sizeList =
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getSelf()); rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getSelf());
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>( rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
op, op.getType(), sizeList, op.getDtype(), op.getLayout(), op.getDevice(), op, op.getType(), sizeList, op.getDtype(), op.getLayout(),
op.getPinMemory(), op.getMemoryFormat()); op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
return success(); return success();
} }
}; };
@ -2927,8 +2937,8 @@ class DecomposeAtenArangeOp : public OpRewritePattern<AtenArangeOp> {
step = rewriter.create<Torch::ConstantIntOp>(loc, step = rewriter.create<Torch::ConstantIntOp>(loc,
rewriter.getI64IntegerAttr(1)); rewriter.getI64IntegerAttr(1));
rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>( rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
op, op.getType(), start, op.getEnd(), step, op.getDtype(), op.getLayout(), op, op.getType(), start, op.getEnd(), step, op.getDtype(),
op.getDevice(), op.getPinMemory()); op.getLayout(), op.getDevice(), op.getPinMemory());
return success(); return success();
} }
}; };
@ -2947,8 +2957,8 @@ class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
step = rewriter.create<Torch::ConstantIntOp>(loc, step = rewriter.create<Torch::ConstantIntOp>(loc,
rewriter.getI64IntegerAttr(1)); rewriter.getI64IntegerAttr(1));
rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>( rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
op, op.getType(), op.getStart(), op.getEnd(), step, op.getDtype(), op.getLayout(), op, op.getType(), op.getStart(), op.getEnd(), step, op.getDtype(),
op.getDevice(), op.getPinMemory()); op.getLayout(), op.getDevice(), op.getPinMemory());
return success(); return success();
} }
}; };
@ -3035,7 +3045,8 @@ class DecomposeAtenNativeBatchNormOp
loc, ListType::get(IntType::get(context)), runningStatsShape); loc, ListType::get(IntType::get(context)), runningStatsShape);
SmallVector<int64_t> runningStatsShapeInt(inputRank, 1); SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
runningStatsShapeInt[1] = runningMean.getType().cast<BaseTensorType>().getSizes()[0]; runningStatsShapeInt[1] =
runningMean.getType().cast<BaseTensorType>().getSizes()[0];
Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype(); Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
Type reshapeType = ValueTensorType::get( Type reshapeType = ValueTensorType::get(
context, llvm::ArrayRef(runningStatsShapeInt), dtype); context, llvm::ArrayRef(runningStatsShapeInt), dtype);
@ -3320,11 +3331,10 @@ public:
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype()); getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
} }
rewriter.replaceOpWithNewOp<AtenFullOp>( rewriter.replaceOpWithNewOp<AtenFullOp>(
op, op.getType(), op.getSize(), op.getFillValue(), dtype, op.getLayout(), op.getDevice(), op, op.getType(), op.getSize(), op.getFillValue(), dtype,
op.getPinMemory()); op.getLayout(), op.getDevice(), op.getPinMemory());
return success(); return success();
} }
}; };
} // namespace } // namespace
@ -3338,7 +3348,8 @@ public:
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false); Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>( rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), op.getAccumulate(), op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
op.getAccumulate(),
/*unsafe=*/cstFalse); /*unsafe=*/cstFalse);
return success(); return success();
} }
@ -3355,8 +3366,8 @@ class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
Torch::ListType::get(Torch::IntType::get(op.getContext())); Torch::ListType::get(Torch::IntType::get(op.getContext()));
Value sizeList = Value sizeList =
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getOther()); rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getOther());
rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.getSelf(), rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(),
sizeList); op.getSelf(), sizeList);
return success(); return success();
} }
}; };
@ -3378,8 +3389,9 @@ public:
Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0, Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0,
resultDtype); resultDtype);
Value emptyTensor = rewriter.create<AtenFullLikeOp>( Value emptyTensor = rewriter.create<AtenFullLikeOp>(
op.getLoc(), op.getType(), op.getSelf(), zero, op.getDtype(), op.getLayout(), op.getLoc(), op.getType(), op.getSelf(), zero, op.getDtype(),
op.getDevice(), op.getPinMemory(), op.getMemoryFormat()); op.getLayout(), op.getDevice(), op.getPinMemory(),
op.getMemoryFormat());
rewriter.replaceOpWithNewOp<AtenCopyOp>(op, op.getType(), emptyTensor, rewriter.replaceOpWithNewOp<AtenCopyOp>(op, op.getType(), emptyTensor,
op.getSelf(), op.getNonBlocking()); op.getSelf(), op.getNonBlocking());
return success(); return success();
@ -3450,7 +3462,8 @@ public:
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false); Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>( rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), op.getAccumulate(), op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
op.getAccumulate(),
/*unsafe=*/cstFalse); /*unsafe=*/cstFalse);
return success(); return success();
} }
@ -3539,8 +3552,8 @@ public:
op, "unimplemented: layout is expected to be strided"); op, "unimplemented: layout is expected to be strided");
} }
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(), rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
op.getDtype(), op.getNonBlocking(), op, op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(),
op.getCopy(), op.getMemoryFormat()); op.getCopy(), op.getMemoryFormat());
return success(); return success();
} }
@ -3557,8 +3570,8 @@ public:
// Device information isn't relevant to torch-mlir, so we can drop that info // Device information isn't relevant to torch-mlir, so we can drop that info
// here. // here.
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(), rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
op.getDtype(), op.getNonBlocking(), op, op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(),
op.getCopy(), op.getMemoryFormat()); op.getCopy(), op.getMemoryFormat());
return success(); return success();
@ -3798,8 +3811,8 @@ class DecomposeAtenBaddbmmOp : public OpRewritePattern<AtenBaddbmmOp> {
LogicalResult matchAndRewrite(AtenBaddbmmOp op, LogicalResult matchAndRewrite(AtenBaddbmmOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Location loc = op.getLoc(); Location loc = op.getLoc();
Value bmm = Value bmm = rewriter.create<AtenBmmOp>(loc, op.getType(), op.getBatch1(),
rewriter.create<AtenBmmOp>(loc, op.getType(), op.getBatch1(), op.getBatch2()); op.getBatch2());
Value alphaTimesBmm = Value alphaTimesBmm =
rewriter.create<AtenMulScalarOp>(loc, op.getType(), bmm, op.getAlpha()); rewriter.create<AtenMulScalarOp>(loc, op.getType(), bmm, op.getAlpha());
Value input = op.getSelf(); Value input = op.getSelf();
@ -4160,7 +4173,8 @@ public:
resultType.getOptionalDtype()) resultType.getOptionalDtype())
.cast<BaseTensorType>(); .cast<BaseTensorType>();
Value sub = createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget()); Value sub =
createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
Value result = rewriter.create<AtenSquareOp>(loc, subType, sub); Value result = rewriter.create<AtenSquareOp>(loc, subType, sub);
if (reductionType == torch_upstream::Reduction::None) { if (reductionType == torch_upstream::Reduction::None) {
rewriter.replaceOp(op, result); rewriter.replaceOp(op, result);
@ -4242,7 +4256,8 @@ public:
rewriter.getF32Type()) rewriter.getF32Type())
.cast<BaseTensorType>(); .cast<BaseTensorType>();
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>( Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, floatResultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(), loc, floatResultType, op.getSize(), /*dtype=*/none,
/*layout=*/op.getLayout(),
/*device=*/op.getDevice(), /*pinMemory=*/op.getPinMemory(), /*device=*/op.getDevice(), /*pinMemory=*/op.getPinMemory(),
/*memoryFormat=*/none); /*memoryFormat=*/none);
@ -4274,8 +4289,8 @@ public:
loc, rewriter.getI64IntegerAttr(0)); loc, rewriter.getI64IntegerAttr(0));
rewriter.replaceOpWithNewOp<AtenRandintLowOp>( rewriter.replaceOpWithNewOp<AtenRandintLowOp>(
op, resultType, low, op.getHigh(), op.getSize(), op.getDtype(), op.getLayout(), op, resultType, low, op.getHigh(), op.getSize(), op.getDtype(),
op.getDevice(), op.getPinMemory()); op.getLayout(), op.getDevice(), op.getPinMemory());
return success(); return success();
} }
@ -4294,10 +4309,11 @@ public:
Location loc = op.getLoc(); Location loc = op.getLoc();
Value noneVal = rewriter.create<ConstantNoneOp>(loc); Value noneVal = rewriter.create<ConstantNoneOp>(loc);
Value var = rewriter.create<AtenVarCorrectionOp>( Value var = rewriter.create<AtenVarCorrectionOp>(
loc, op.getType(0), op.getSelf(), op.getDim(), op.getCorrection(), op.getKeepdim()); loc, op.getType(0), op.getSelf(), op.getDim(), op.getCorrection(),
Value mean = op.getKeepdim());
rewriter.create<AtenMeanDimOp>(loc, op.getType(0), op.getSelf(), op.getDim(), Value mean = rewriter.create<AtenMeanDimOp>(
op.getKeepdim(), /*dtype=*/noneVal); loc, op.getType(0), op.getSelf(), op.getDim(), op.getKeepdim(),
/*dtype=*/noneVal);
rewriter.replaceOp(op, {var, mean}); rewriter.replaceOp(op, {var, mean});
return success(); return success();
} }
@ -4394,11 +4410,13 @@ public:
/*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(), /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
/*memory_format=*/none); /*memory_format=*/none);
Value uOne = rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorA, Value uOne =
rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorA,
/*from=*/low, /*from=*/low,
/*to=*/high, /*to=*/high,
/*generator=*/op.getGenerator()); /*generator=*/op.getGenerator());
Value uTwo = rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorB, Value uTwo =
rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorB,
/*from=*/low, /*from=*/low,
/*to=*/high, /*to=*/high,
/*generator=*/op.getGenerator()); /*generator=*/op.getGenerator());
@ -4526,37 +4544,18 @@ public:
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNewEmptyStridedOp op, LogicalResult matchAndRewrite(AtenNewEmptyStridedOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
SmallVector<int64_t> sizeListInts, strideListInts; Location loc = op.getLoc();
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts))) Value opSize = op.getSize();
return rewriter.notifyMatchFailure( Value opStride = op.getStride();
op, "all size list elements must be constant ints");
if (!matchPattern(op.getStride(),
m_TorchListOfConstantInts(strideListInts)))
return rewriter.notifyMatchFailure(
op, "all stride list elements must be constant ints");
// We only support the cases with default stride values. if (failed(checkDefaultStrideHelper(op, rewriter, opSize, opStride, loc)))
// For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
// Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
// stride[2] == 1.
bool isDefaultStride = true;
for (unsigned i = 0; i < strideListInts.size(); i++) {
int64_t defaultStride = 1;
for (unsigned j = i + 1; j < sizeListInts.size(); j++)
defaultStride *= sizeListInts[j];
if (defaultStride != strideListInts[i]) {
isDefaultStride = false;
break;
}
}
if (!isDefaultStride)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only default strides supported for new_empty_strided op"); op, "Unable to determine if stride is default");
rewriter.replaceOpWithNewOp<AtenNewEmptyOp>( rewriter.replaceOpWithNewOp<AtenNewEmptyOp>(
op, op.getType(), op.getSelf(), op.getSize(), op.getDtype(), op, op.getType(), op.getSelf(), op.getSize(), op.getDtype(),
op.getLayout(), op.getDevice(), op.getPinMemory()); op.getLayout(), op.getDevice(), op.getPinMemory());
return success(); return success();
} }
}; };
@ -4569,42 +4568,20 @@ public:
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenEmptyStridedOp op, LogicalResult matchAndRewrite(AtenEmptyStridedOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
SmallVector<int64_t> sizeListInts, strideListInts; Location loc = op.getLoc();
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts))) Value opSize = op.getSize();
return rewriter.notifyMatchFailure( Value opStride = op.getStride();
op, "all size list elements must be constant ints");
if (!matchPattern(op.getStride(),
m_TorchListOfConstantInts(strideListInts)))
return rewriter.notifyMatchFailure(
op, "all stride list elements must be constant ints");
// We only support the cases with default stride values. if (failed(checkDefaultStrideHelper(op, rewriter, opSize, opStride, loc)))
// For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
// Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
// stride[2] == 1.
bool isDefaultStride = true;
for (unsigned i = 0; i < strideListInts.size(); i++) {
int64_t defaultStride = 1;
for (unsigned j = i + 1; j < sizeListInts.size(); j++)
defaultStride *= sizeListInts[j];
if (defaultStride != strideListInts[i]) {
isDefaultStride = false;
break;
}
}
if (!isDefaultStride)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only default strides supported for new_empty_strided op"); op, "Unable to determine if stride is default");
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc()); Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>( rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(), op.getDevice(), op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(),
op.getPinMemory(), /*memoryFormat=*/noneVal); op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal);
return success(); return success();
} }
}; };
} // namespace } // namespace
@ -4993,8 +4970,8 @@ public:
auto selectGreater = auto selectGreater =
rewriter.create<AtenWhereScalarOp>(loc, outType, greater, one, zero); rewriter.create<AtenWhereScalarOp>(loc, outType, greater, one, zero);
rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(op, outType, greaterEqual, rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(
selectGreater, minusOne); op, outType, greaterEqual, selectGreater, minusOne);
return success(); return success();
} }
}; };
@ -5407,7 +5384,8 @@ public:
addPatternIfTargetOpIsIllegal<DeomposeAtenNativeDropoutOp>(patterns); addPatternIfTargetOpIsIllegal<DeomposeAtenNativeDropoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutHackedTwinOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutHackedTwinOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeIndexPutHackedTwinOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeIndexPutHackedTwinOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);

View File

@ -333,3 +333,61 @@ bool Torch::isAssumingStrictSymbolicShapes(Block *block) {
} }
return false; return false;
} }
LogicalResult Torch::checkDefaultStrideHelper(Operation *op,
PatternRewriter &rewriter,
Value opSize, Value opStride,
Location loc) {
SmallVector<int64_t> sizeListInts, strideListInts;
if (matchPattern(opSize, m_TorchListOfConstantInts(sizeListInts)) &&
matchPattern(opStride, m_TorchListOfConstantInts(strideListInts))) {
// We only support the cases with default stride values.
// For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
// Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
// stride[2] == 1.
bool isDefaultStride = true;
for (unsigned i = 0; i < strideListInts.size(); i++) {
int64_t defaultStride = 1;
for (unsigned j = i + 1; j < sizeListInts.size(); j++)
defaultStride *= sizeListInts[j];
if (defaultStride != strideListInts[i]) {
isDefaultStride = false;
break;
}
}
if (!isDefaultStride)
return rewriter.notifyMatchFailure(
op, "only default strides supported for empty_strided op");
return success();
} else {
SmallVector<Value> sizeListValues;
if (!getListConstructElements(opSize, sizeListValues))
return rewriter.notifyMatchFailure(op, "couldn't get size list values");
SmallVector<Value> strideListValues;
if (!getListConstructElements(opStride, strideListValues))
return rewriter.notifyMatchFailure(op,
"couldn't get stride list values.");
SmallVector<Value> boolVector;
for (unsigned i = 0; i < strideListValues.size(); i++) {
Value defaultStride = rewriter.createOrFold<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
for (unsigned j = i + 1; j < sizeListValues.size(); j++) {
defaultStride = rewriter.createOrFold<Torch::AtenMulIntOp>(
loc, defaultStride, sizeListValues[j]);
}
boolVector.push_back(rewriter.createOrFold<Torch::AtenEqIntOp>(
loc, defaultStride, strideListValues[i]));
}
Value allBoolOpList = rewriter.createOrFold<PrimListConstructOp>(
loc, Torch::ListType::get(rewriter.getType<Torch::BoolType>()),
boolVector);
Value cmp = rewriter.createOrFold<Torch::AtenAllBoolOp>(loc, allBoolOpList);
rewriter.createOrFold<Torch::RuntimeAssertOp>(
loc, cmp, "not all strides are default");
return success();
}
}

View File

@ -1629,7 +1629,6 @@ class NewEmptyStridedModuleDefaultDtype(torch.nn.Module):
def NewEmptyStridedModuleDefaultDtype_basic(module, tu: TestUtils): def NewEmptyStridedModuleDefaultDtype_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4)) module.forward(tu.rand(2, 3, 4))
# ============================================================================== # ==============================================================================
@ -1651,4 +1650,27 @@ class EmptyStridedModule(torch.nn.Module):
@register_test_case(module_factory=lambda: EmptyStridedModule()) @register_test_case(module_factory=lambda: EmptyStridedModule())
def EmptyStridedModule_basic(module, tu: TestUtils): def EmptyStridedModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 3, 4))
# ==============================================================================
class EmptyStridedSizeIntStrideModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 3, -1], torch.float32, True),
])
def forward(self, a):
x = torch.ops.aten.empty_strided(a.size(), stride=[12, a.size(2), 1])
y = x.copy_(a)
return y
@register_test_case(module_factory=lambda: EmptyStridedSizeIntStrideModule())
def EmptyStridedSizeIntStrideModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4)) module.forward(tu.rand(2, 3, 4))