mirror of https://github.com/llvm/torch-mlir

Rework AtenEmptyStridedOp checks (#2537)

Now using Value instead of Ints. Trades compile failure for a runtime assert.

branch: pull/2538/head
tag: snapshot-20231101.1009
parent: 4199feffed
commit: 1d41f7b6fe
@@ -764,6 +764,7 @@ STABLEHLO_PASS_SET = {
     "NewEmptyModuleNonDefaultIntDtype_basic",
     "NewEmptyStridedModuleDefaultDtype_basic",
     "EmptyStridedModule_basic",
+    "EmptyStridedSizeIntStrideModule_basic",
     "PermuteModule_basic",
     "PermuteNegativeIndexModule_basic",
     "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
@@ -1440,6 +1441,7 @@ LTC_XFAIL_SET = {
     "UniformStaticShapeModule_basic",
     "AtenEmbeddingBagStaticModule_basic",
     "EmptyStridedModule_basic",
+    "EmptyStridedSizeIntStrideModule_basic",
    "ElementwiseBitwiseAndScalarInt64Module_basic",
     "ElementwiseBitwiseAndScalarInt32Module_basic",
     "ElementwiseBitwiseAndScalarInt8Module_basic",
@@ -104,6 +104,12 @@ inline bool isAssumingStrictSymbolicShapes(OpBuilder &builder) {
   return isAssumingStrictSymbolicShapes(builder.getBlock());
 }
 
+// Helper function for AtenEmptyStrided and friends that checks if the stride
+// values are default or not. Throws a runtime assert if not.
+LogicalResult checkDefaultStrideHelper(Operation *op, PatternRewriter &rewriter,
+                                       Value opSize, Value opStride,
+                                       Location loc);
+
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
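Note: the "default" strides this helper checks for are the standard contiguous (row-major) strides, where stride[i] equals the product of all sizes after dimension i. A minimal Python sketch of that rule, for illustration only (the function name is ours, not part of the patch):

def default_strides(sizes):
    # Contiguous (row-major) layout: stride[i] == product of sizes[i+1:].
    strides = []
    running = 1
    for size in reversed(sizes):
        strides.append(running)
        running *= size
    return list(reversed(strides))

# e.g. size [2, 3, 4] -> strides [12, 4, 1]
assert default_strides([2, 3, 4]) == [12, 4, 1]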
@@ -332,7 +332,8 @@ public:
         rewriter.create<AtenAddIntOp>(loc, one.getType(), start, length);
 
     rewriter.replaceOpWithNewOp<AtenSliceTensorOp>(
-        op, op.getResult().getType(), op.getSelf(), /*dim=*/dim, /*start=*/start,
+        op, op.getResult().getType(), op.getSelf(), /*dim=*/dim,
+        /*start=*/start,
         /*end=*/startPlusLength, /*step=*/one);
 
     return success();
@@ -404,16 +405,15 @@ public:
 } // namespace
 
 namespace {
-class DecomposeAtenZeroOp
-    : public OpRewritePattern<AtenZeroOp> {
+class DecomposeAtenZeroOp : public OpRewritePattern<AtenZeroOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenZeroOp op,
                                 PatternRewriter &rewriter) const override {
     Value zero = rewriter.create<ConstantIntOp>(op.getLoc(),
                                                 rewriter.getI64IntegerAttr(0));
-    rewriter.replaceOpWithNewOp<AtenFillScalarOp>(op, op.getType(), op.getSelf(),
-                                                  zero);
+    rewriter.replaceOpWithNewOp<AtenFillScalarOp>(op, op.getType(),
+                                                  op.getSelf(), zero);
     return success();
   }
 };
@@ -1139,14 +1139,21 @@ public:
     Value constantOne =
         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
     Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
-    Value maxZeroX = rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
-    Value positiveOutput = rewriter.create<AtenMulScalarOp>(loc, resType, maxZeroX, scale);
-    Value minZeroX = rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
-    Value scaledMinZeroX = rewriter.create<AtenMulScalarOp>(loc, resType, minZeroX, inputScale);
+    Value maxZeroX =
+        rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
+    Value positiveOutput =
+        rewriter.create<AtenMulScalarOp>(loc, resType, maxZeroX, scale);
+    Value minZeroX =
+        rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
+    Value scaledMinZeroX =
+        rewriter.create<AtenMulScalarOp>(loc, resType, minZeroX, inputScale);
     Value expX = rewriter.create<AtenExpOp>(loc, resType, scaledMinZeroX);
-    Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX, constantOne, constantOne);
-    Value scaledExpXM1 = rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, scale);
-    Value negativeOutput = rewriter.create<AtenMulScalarOp>(loc, resType, scaledExpXM1, alpha);
+    Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX,
+                                                    constantOne, constantOne);
+    Value scaledExpXM1 =
+        rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, scale);
+    Value negativeOutput =
+        rewriter.create<AtenMulScalarOp>(loc, resType, scaledExpXM1, alpha);
 
     Value eluOutput = rewriter.create<AtenAddTensorOp>(
         loc, resType, positiveOutput, negativeOutput, constantOne);
@@ -1419,8 +1426,8 @@ public:
         rewriter.create<PrimListConstructOp>(loc, listType, expandedSizes);
     Value reshapedDims =
         rewriter.create<PrimListConstructOp>(loc, listType, reshapedSizes);
-    auto reshaped = rewriter.create<AtenViewOp>(loc, unsqueezedType, op.getSelf(),
-                                                unsqueezedDims);
+    auto reshaped = rewriter.create<AtenViewOp>(loc, unsqueezedType,
+                                                op.getSelf(), unsqueezedDims);
     auto expanded = rewriter.create<AtenBroadcastToOp>(loc, expandedType,
                                                        reshaped, expandedDims);
 
@@ -1601,8 +1608,8 @@ public:
       return rewriter.notifyMatchFailure(
           op, "unimplemented: requires implicit to be false");
     }
-    rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.getSelf(),
-                                                   op.getSize());
+    rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(),
+                                                   op.getSelf(), op.getSize());
     return success();
   }
 };
@@ -1621,7 +1628,8 @@ public:
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
     Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf());
-    Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther());
+    Value otherTensor =
+        createRank0Tensor(rewriter, loc, resType, op.getOther());
     rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
                                                  selfTensor, otherTensor);
     return success();
@@ -1642,7 +1650,8 @@ public:
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
-    Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther());
+    Value otherTensor =
+        createRank0Tensor(rewriter, loc, resType, op.getOther());
     rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
                                                  op.getSelf(), otherTensor);
     return success();
@@ -1686,8 +1695,8 @@ public:
     }
     Value mask = op.getMask();
     Value value = createRank0Tensor(rewriter, loc, resType, op.getValue());
-    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, mask,
-                                                 value, op.getSelf());
+    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, mask, value,
+                                                 op.getSelf());
     return success();
   }
 };
@@ -1747,8 +1756,8 @@ public:
     Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true);
     rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
         op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
-        op.getStride(), op.getPadding(), op.getDilation(), /*transposed=*/cstTrue,
-        op.getOutputPadding(), op.getGroups());
+        op.getStride(), op.getPadding(), op.getDilation(),
+        /*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups());
     return success();
   }
 };
@@ -2500,9 +2509,9 @@ public:
           op, "aten.std.dim expects input tensor of floating-point type");
     }
 
-    Value varDim =
-        rewriter.create<AtenVarDimOp>(op->getLoc(), op.getType(), self,
-                                      op.getDim(), op.getUnbiased(), op.getKeepdim());
+    Value varDim = rewriter.create<AtenVarDimOp>(
+        op->getLoc(), op.getType(), self, op.getDim(), op.getUnbiased(),
+        op.getKeepdim());
     rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), varDim);
     return success();
   }
@@ -2626,8 +2635,8 @@ public:
     Value one =
         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
     Value emptyTensor = rewriter.create<AtenFullLikeOp>(
-        loc, resultType, input, zero, op.getDtype(), op.getLayout(), op.getDevice(),
-        op.getPinMemory(), op.getMemoryFormat());
+        loc, resultType, input, zero, op.getDtype(), op.getLayout(),
+        op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
     rewriter.replaceOpWithNewOp<AtenUniformOp>(op, resultType, emptyTensor,
                                                /*from=*/zero, /*to=*/one,
                                                /*generator=*/none);
@@ -2829,7 +2838,8 @@ class DecomposeAtenNativeLayerNormOp
     SmallVector<Value> normalizedShapeSizesTorchInt;
     getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
     int64_t axis = inputRank - normalizedShapeSizesTorchInt.size();
-    auto reduceDimInts = llvm::to_vector<4>(llvm::seq<int64_t>(axis, inputRank));
+    auto reduceDimInts =
+        llvm::to_vector<4>(llvm::seq<int64_t>(axis, inputRank));
     auto reducedTy = op.getResult(1).getType();
     auto sizeListType = ListType::get(IntType::get(context));
 
@@ -2905,8 +2915,8 @@ public:
     Value sizeList =
         rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getSelf());
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
-        op, op.getType(), sizeList, op.getDtype(), op.getLayout(), op.getDevice(),
-        op.getPinMemory(), op.getMemoryFormat());
+        op, op.getType(), sizeList, op.getDtype(), op.getLayout(),
+        op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
     return success();
   }
 };
@@ -2927,8 +2937,8 @@ class DecomposeAtenArangeOp : public OpRewritePattern<AtenArangeOp> {
     step = rewriter.create<Torch::ConstantIntOp>(loc,
                                                  rewriter.getI64IntegerAttr(1));
     rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
-        op, op.getType(), start, op.getEnd(), step, op.getDtype(), op.getLayout(),
-        op.getDevice(), op.getPinMemory());
+        op, op.getType(), start, op.getEnd(), step, op.getDtype(),
+        op.getLayout(), op.getDevice(), op.getPinMemory());
     return success();
   }
 };
@@ -2947,8 +2957,8 @@ class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
     step = rewriter.create<Torch::ConstantIntOp>(loc,
                                                  rewriter.getI64IntegerAttr(1));
     rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
-        op, op.getType(), op.getStart(), op.getEnd(), step, op.getDtype(), op.getLayout(),
-        op.getDevice(), op.getPinMemory());
+        op, op.getType(), op.getStart(), op.getEnd(), step, op.getDtype(),
+        op.getLayout(), op.getDevice(), op.getPinMemory());
     return success();
   }
 };
@@ -3035,7 +3045,8 @@ class DecomposeAtenNativeBatchNormOp
         loc, ListType::get(IntType::get(context)), runningStatsShape);
 
     SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
-    runningStatsShapeInt[1] = runningMean.getType().cast<BaseTensorType>().getSizes()[0];
+    runningStatsShapeInt[1] =
+        runningMean.getType().cast<BaseTensorType>().getSizes()[0];
     Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
     Type reshapeType = ValueTensorType::get(
         context, llvm::ArrayRef(runningStatsShapeInt), dtype);
@@ -3320,11 +3331,10 @@ public:
           getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
     }
     rewriter.replaceOpWithNewOp<AtenFullOp>(
-        op, op.getType(), op.getSize(), op.getFillValue(), dtype, op.getLayout(), op.getDevice(),
-        op.getPinMemory());
+        op, op.getType(), op.getSize(), op.getFillValue(), dtype,
+        op.getLayout(), op.getDevice(), op.getPinMemory());
 
     return success();
-
   }
 };
 } // namespace
@@ -3338,7 +3348,8 @@ public:
                                 PatternRewriter &rewriter) const override {
     Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
     rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
-        op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), op.getAccumulate(),
+        op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
+        op.getAccumulate(),
         /*unsafe=*/cstFalse);
     return success();
   }
@@ -3355,8 +3366,8 @@ class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
         Torch::ListType::get(Torch::IntType::get(op.getContext()));
     Value sizeList =
         rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getOther());
-    rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.getSelf(),
-                                                   sizeList);
+    rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(),
+                                                   op.getSelf(), sizeList);
     return success();
   }
 };
@@ -3378,8 +3389,9 @@ public:
     Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0,
                                                    resultDtype);
     Value emptyTensor = rewriter.create<AtenFullLikeOp>(
-        op.getLoc(), op.getType(), op.getSelf(), zero, op.getDtype(), op.getLayout(),
-        op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
+        op.getLoc(), op.getType(), op.getSelf(), zero, op.getDtype(),
+        op.getLayout(), op.getDevice(), op.getPinMemory(),
+        op.getMemoryFormat());
     rewriter.replaceOpWithNewOp<AtenCopyOp>(op, op.getType(), emptyTensor,
                                             op.getSelf(), op.getNonBlocking());
     return success();
@@ -3450,7 +3462,8 @@ public:
                                 PatternRewriter &rewriter) const override {
     Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
     rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
-        op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), op.getAccumulate(),
+        op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
+        op.getAccumulate(),
         /*unsafe=*/cstFalse);
     return success();
   }
@@ -3539,8 +3552,8 @@ public:
           op, "unimplemented: layout is expected to be strided");
     }
 
-    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
-                                               op.getDtype(), op.getNonBlocking(),
+    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
+        op, op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(),
         op.getCopy(), op.getMemoryFormat());
     return success();
   }
@@ -3557,8 +3570,8 @@ public:
 
     // Device information isn't relevant to torch-mlir, so we can drop that info
     // here.
-    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
-                                               op.getDtype(), op.getNonBlocking(),
+    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
+        op, op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(),
         op.getCopy(), op.getMemoryFormat());
 
     return success();
@@ -3798,8 +3811,8 @@ class DecomposeAtenBaddbmmOp : public OpRewritePattern<AtenBaddbmmOp> {
   LogicalResult matchAndRewrite(AtenBaddbmmOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    Value bmm =
-        rewriter.create<AtenBmmOp>(loc, op.getType(), op.getBatch1(), op.getBatch2());
+    Value bmm = rewriter.create<AtenBmmOp>(loc, op.getType(), op.getBatch1(),
+                                           op.getBatch2());
     Value alphaTimesBmm =
         rewriter.create<AtenMulScalarOp>(loc, op.getType(), bmm, op.getAlpha());
     Value input = op.getSelf();
@@ -4160,7 +4173,8 @@ public:
                                       resultType.getOptionalDtype())
                        .cast<BaseTensorType>();
 
-    Value sub = createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
+    Value sub =
+        createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
     Value result = rewriter.create<AtenSquareOp>(loc, subType, sub);
     if (reductionType == torch_upstream::Reduction::None) {
       rewriter.replaceOp(op, result);
@@ -4242,7 +4256,8 @@ public:
                                          rewriter.getF32Type())
                                 .cast<BaseTensorType>();
     Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
-        loc, floatResultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(),
+        loc, floatResultType, op.getSize(), /*dtype=*/none,
+        /*layout=*/op.getLayout(),
         /*device=*/op.getDevice(), /*pinMemory=*/op.getPinMemory(),
         /*memoryFormat=*/none);
 
@@ -4274,8 +4289,8 @@ public:
         loc, rewriter.getI64IntegerAttr(0));
 
     rewriter.replaceOpWithNewOp<AtenRandintLowOp>(
-        op, resultType, low, op.getHigh(), op.getSize(), op.getDtype(), op.getLayout(),
-        op.getDevice(), op.getPinMemory());
+        op, resultType, low, op.getHigh(), op.getSize(), op.getDtype(),
+        op.getLayout(), op.getDevice(), op.getPinMemory());
 
     return success();
   }
@@ -4294,10 +4309,11 @@ public:
     Location loc = op.getLoc();
     Value noneVal = rewriter.create<ConstantNoneOp>(loc);
     Value var = rewriter.create<AtenVarCorrectionOp>(
-        loc, op.getType(0), op.getSelf(), op.getDim(), op.getCorrection(), op.getKeepdim());
-    Value mean =
-        rewriter.create<AtenMeanDimOp>(loc, op.getType(0), op.getSelf(), op.getDim(),
-                                       op.getKeepdim(), /*dtype=*/noneVal);
+        loc, op.getType(0), op.getSelf(), op.getDim(), op.getCorrection(),
+        op.getKeepdim());
+    Value mean = rewriter.create<AtenMeanDimOp>(
+        loc, op.getType(0), op.getSelf(), op.getDim(), op.getKeepdim(),
+        /*dtype=*/noneVal);
     rewriter.replaceOp(op, {var, mean});
     return success();
   }
@@ -4394,11 +4410,13 @@ public:
         /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
         /*memory_format=*/none);
 
-    Value uOne = rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorA,
+    Value uOne =
+        rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorA,
                                        /*from=*/low,
                                        /*to=*/high,
                                        /*generator=*/op.getGenerator());
-    Value uTwo = rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorB,
+    Value uTwo =
+        rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorB,
                                        /*from=*/low,
                                        /*to=*/high,
                                        /*generator=*/op.getGenerator());
@@ -4526,37 +4544,18 @@ public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenNewEmptyStridedOp op,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<int64_t> sizeListInts, strideListInts;
-    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all size list elements must be constant ints");
-    if (!matchPattern(op.getStride(),
-                      m_TorchListOfConstantInts(strideListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all stride list elements must be constant ints");
-
-    // We only support the cases with default stride values.
-    // For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
-    // Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
-    // stride[2] == 1.
-    bool isDefaultStride = true;
-    for (unsigned i = 0; i < strideListInts.size(); i++) {
-      int64_t defaultStride = 1;
-      for (unsigned j = i + 1; j < sizeListInts.size(); j++)
-        defaultStride *= sizeListInts[j];
-      if (defaultStride != strideListInts[i]) {
-        isDefaultStride = false;
-        break;
-      }
-    }
-
-    if (!isDefaultStride)
+    Location loc = op.getLoc();
+    Value opSize = op.getSize();
+    Value opStride = op.getStride();
+
+    if (failed(checkDefaultStrideHelper(op, rewriter, opSize, opStride, loc)))
       return rewriter.notifyMatchFailure(
-          op, "only default strides supported for new_empty_strided op");
+          op, "Unable to determine if stride is default");
 
     rewriter.replaceOpWithNewOp<AtenNewEmptyOp>(
         op, op.getType(), op.getSelf(), op.getSize(), op.getDtype(),
         op.getLayout(), op.getDevice(), op.getPinMemory());
 
     return success();
   }
 };
@@ -4569,42 +4568,20 @@ public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenEmptyStridedOp op,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<int64_t> sizeListInts, strideListInts;
-    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all size list elements must be constant ints");
-    if (!matchPattern(op.getStride(),
-                      m_TorchListOfConstantInts(strideListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all stride list elements must be constant ints");
-
-    // We only support the cases with default stride values.
-    // For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
-    // Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
-    // stride[2] == 1.
-    bool isDefaultStride = true;
-    for (unsigned i = 0; i < strideListInts.size(); i++) {
-      int64_t defaultStride = 1;
-      for (unsigned j = i + 1; j < sizeListInts.size(); j++)
-        defaultStride *= sizeListInts[j];
-      if (defaultStride != strideListInts[i]) {
-        isDefaultStride = false;
-        break;
-      }
-    }
-
-    if (!isDefaultStride)
+    Location loc = op.getLoc();
+    Value opSize = op.getSize();
+    Value opStride = op.getStride();
+
+    if (failed(checkDefaultStrideHelper(op, rewriter, opSize, opStride, loc)))
       return rewriter.notifyMatchFailure(
-          op, "only default strides supported for new_empty_strided op");
+          op, "Unable to determine if stride is default");
 
     Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
 
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
-        op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(), op.getDevice(),
-        op.getPinMemory(), /*memoryFormat=*/noneVal);
+        op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(),
+        op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal);
 
     return success();
-
   }
 };
 } // namespace
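Both decompositions above now delegate to checkDefaultStrideHelper instead of insisting that size and stride be compile-time constant ints. A toy Python model of the resulting behavior, under our reading of the patch (illustrative names, not the MLIR pattern API):

def stride_check_outcome(sizes, strides):
    # Static path: all operands constant, so a non-default stride can
    # reject the match at compile time.
    if all(isinstance(v, int) for v in list(sizes) + list(strides)):
        expected = 1
        for i in reversed(range(len(strides))):
            if strides[i] != expected:
                return "notifyMatchFailure"
            expected *= sizes[i]
        return "decompose"
    # Dynamic path: decompose anyway, guarded by a runtime assert.
    return "decompose with RuntimeAssertOp guard"

assert stride_check_outcome([2, 3, 4], [12, 4, 1]) == "decompose"
assert stride_check_outcome([2, 3, 4], [1, 2, 6]) == "notifyMatchFailure"
assert stride_check_outcome([2, 3, "dyn"], [12, "dyn", 1]) == "decompose with RuntimeAssertOp guard"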
@@ -4993,8 +4970,8 @@ public:
     auto selectGreater =
         rewriter.create<AtenWhereScalarOp>(loc, outType, greater, one, zero);
 
-    rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(op, outType, greaterEqual,
-                                                        selectGreater, minusOne);
+    rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(
+        op, outType, greaterEqual, selectGreater, minusOne);
     return success();
   }
 };
@@ -5407,7 +5384,8 @@ public:
     addPatternIfTargetOpIsIllegal<DeomposeAtenNativeDropoutOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutHackedTwinOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeIndexPutHackedTwinOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeIndexPutHackedTwinOp>(
+        patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
@@ -333,3 +333,61 @@ bool Torch::isAssumingStrictSymbolicShapes(Block *block) {
   }
   return false;
 }
+
+LogicalResult Torch::checkDefaultStrideHelper(Operation *op,
+                                              PatternRewriter &rewriter,
+                                              Value opSize, Value opStride,
+                                              Location loc) {
+
+  SmallVector<int64_t> sizeListInts, strideListInts;
+  if (matchPattern(opSize, m_TorchListOfConstantInts(sizeListInts)) &&
+      matchPattern(opStride, m_TorchListOfConstantInts(strideListInts))) {
+
+    // We only support the cases with default stride values.
+    // For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
+    // Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
+    // stride[2] == 1.
+    bool isDefaultStride = true;
+    for (unsigned i = 0; i < strideListInts.size(); i++) {
+      int64_t defaultStride = 1;
+      for (unsigned j = i + 1; j < sizeListInts.size(); j++)
+        defaultStride *= sizeListInts[j];
+      if (defaultStride != strideListInts[i]) {
+        isDefaultStride = false;
+        break;
+      }
+    }
+    if (!isDefaultStride)
+      return rewriter.notifyMatchFailure(
+          op, "only default strides supported for empty_strided op");
+
+    return success();
+
+  } else {
+    SmallVector<Value> sizeListValues;
+    if (!getListConstructElements(opSize, sizeListValues))
+      return rewriter.notifyMatchFailure(op, "couldn't get size list values");
+    SmallVector<Value> strideListValues;
+    if (!getListConstructElements(opStride, strideListValues))
+      return rewriter.notifyMatchFailure(op,
+                                         "couldn't get stride list values.");
+    SmallVector<Value> boolVector;
+    for (unsigned i = 0; i < strideListValues.size(); i++) {
+      Value defaultStride = rewriter.createOrFold<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(1));
+      for (unsigned j = i + 1; j < sizeListValues.size(); j++) {
+        defaultStride = rewriter.createOrFold<Torch::AtenMulIntOp>(
+            loc, defaultStride, sizeListValues[j]);
+      }
+      boolVector.push_back(rewriter.createOrFold<Torch::AtenEqIntOp>(
+          loc, defaultStride, strideListValues[i]));
+    }
+    Value allBoolOpList = rewriter.createOrFold<PrimListConstructOp>(
+        loc, Torch::ListType::get(rewriter.getType<Torch::BoolType>()),
+        boolVector);
+    Value cmp = rewriter.createOrFold<Torch::AtenAllBoolOp>(loc, allBoolOpList);
+    rewriter.createOrFold<Torch::RuntimeAssertOp>(
+        loc, cmp, "not all strides are default");
+    return success();
+  }
+}
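In the dynamic branch above, the helper materializes the check as IR: per dimension it folds together the default stride with Torch::AtenMulIntOp, compares against the given stride with Torch::AtenEqIntOp, collects the results into a bool list, reduces them with Torch::AtenAllBoolOp, and guards the result with Torch::RuntimeAssertOp. Written as plain Python, the computation those emitted ops perform at run time is (a sketch, not the builder code):

def emitted_runtime_check(sizes, strides):
    checks = []
    for i in range(len(strides)):
        default_stride = 1                           # Torch::ConstantIntOp
        for j in range(i + 1, len(sizes)):
            default_stride *= sizes[j]               # Torch::AtenMulIntOp
        checks.append(default_stride == strides[i])  # Torch::AtenEqIntOp
    # PrimListConstructOp + AtenAllBoolOp feeding RuntimeAssertOp:
    assert all(checks), "not all strides are default"

emitted_runtime_check([2, 3, 4], [12, 4, 1])  # passes silently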
@@ -1629,7 +1629,6 @@ class NewEmptyStridedModuleDefaultDtype(torch.nn.Module):
 def NewEmptyStridedModuleDefaultDtype_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3, 4))
 
-
 # ==============================================================================
 
 
@@ -1651,4 +1650,27 @@ class EmptyStridedModule(torch.nn.Module):
 
 @register_test_case(module_factory=lambda: EmptyStridedModule())
 def EmptyStridedModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 3, 4))
+
+# ==============================================================================
+
+
+class EmptyStridedSizeIntStrideModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        x = torch.ops.aten.empty_strided(a.size(), stride=[12, a.size(2), 1])
+        y = x.copy_(a)
+        return y
+
+
+@register_test_case(module_factory=lambda: EmptyStridedSizeIntStrideModule())
+def EmptyStridedSizeIntStrideModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3, 4))
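The new EmptyStridedSizeIntStrideModule test exercises the dynamic branch: with the annotated shape [2, 3, -1], a.size(2) is not a compile-time constant, so the stride list cannot be matched as constant ints and the runtime-assert path is taken. A rough eager-mode equivalent of what the module computes, assuming an input that satisfies the annotation:

import torch

a = torch.rand(2, 3, 4)
# For size [2, 3, 4] the default (contiguous) strides are [12, 4, 1],
# so this empty_strided call yields an ordinary contiguous tensor.
x = torch.ops.aten.empty_strided(a.size(), stride=[12, a.size(2), 1])
y = x.copy_(a)
assert y.stride() == (12, 4, 1)
assert torch.equal(y, a)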