mirror of https://github.com/llvm/torch-mlir
Add More Scalarize Shapes Patterns (#3810)
### new patterns:

1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op.
2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX-generated IR long enough.
3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations.
4. Adds `AtenSqueezeDimOp` to the existing `FoldAtenSqueezePattern`.

### other changes:

1. Adds two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and they are always shape related. The Arange op only takes int inputs, and these are all shape related. Also adds a size check when getting a list from literal ops.
2. Improves folders for int arithmetic ops to fold some common patterns.
3. Adds the ability to get some values from scalar-tensor ops to `getListFromTensor`.
4. Further cleans up `getListFromTensor` for readability.

### points to scrutinize:

1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart.
2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.

pull/3811/head
parent
a83e106f92
commit
140cad5659
|
@ -3700,6 +3700,12 @@ OpFoldResult AtenRemainderScalarOp::fold(FoldAdaptor adaptor) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) {
|
||||||
|
auto intLhs = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
|
||||||
|
auto intRhs = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
|
||||||
|
if (intRhs && intRhs.getValue().getSExtValue() == 0)
|
||||||
|
return getA();
|
||||||
|
if (intLhs && intLhs.getValue().getSExtValue() == 0)
|
||||||
|
return getB();
|
||||||
return atenBinaryIntOperatorFoldHelper(
|
return atenBinaryIntOperatorFoldHelper(
|
||||||
adaptor.getOperands(), [](int64_t a, int64_t b) { return a + b; });
|
adaptor.getOperands(), [](int64_t a, int64_t b) { return a + b; });
|
||||||
}
|
}
|
||||||
|
@ -3709,6 +3715,9 @@ OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
|
||||||
|
if (getA() == getB())
|
||||||
|
return IntegerAttr::get(
|
||||||
|
IntegerType::get(getContext(), 64, IntegerType::Signless), 0);
|
||||||
return atenBinaryIntOperatorFoldHelper(
|
return atenBinaryIntOperatorFoldHelper(
|
||||||
adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; });
|
adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; });
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,42 +86,62 @@ LogicalResult getListFromTensor(Value value, SmallVector<OpFoldResult> &vals) {
|
||||||
getAsOpFoldResult(full.getFillValue()));
|
getAsOpFoldResult(full.getFillValue()));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
// TODO: Add a case for unsqueeze of a primnumtotensorscalarop?
|
|
||||||
|
if (auto unsqueeze = value.getDefiningOp<Torch::AtenUnsqueezeOp>()) {
|
||||||
|
Value usqSelf = unsqueeze.getSelf();
|
||||||
|
if (auto numToTensor =
|
||||||
|
usqSelf.getDefiningOp<Torch::PrimNumToTensorScalarOp>()) {
|
||||||
|
vals.push_back(getAsOpFoldResult(numToTensor.getA()));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A common rank 0 tensor producer
|
||||||
|
if (auto numToTensor =
|
||||||
|
value.getDefiningOp<Torch::PrimNumToTensorScalarOp>()) {
|
||||||
|
vals.push_back(getAsOpFoldResult(numToTensor.getA()));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
// Last supported case: ValueTensorLiteralOp
|
// Last supported case: ValueTensorLiteralOp
|
||||||
auto literalOp = value.getDefiningOp<Torch::ValueTensorLiteralOp>();
|
auto literalOp = value.getDefiningOp<Torch::ValueTensorLiteralOp>();
|
||||||
if (!literalOp)
|
if (!literalOp)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Check the type. We make sure the type is not unsigned here before trying to
|
// Check the type.
|
||||||
// materialize
|
|
||||||
auto ty = cast<ValueTensorType>(literalOp.getType());
|
auto ty = cast<ValueTensorType>(literalOp.getType());
|
||||||
if (!ty.hasSizes() || ty.getSizes().size() > 1)
|
if (!ty.hasSizes() || ty.getSizes().size() > 1)
|
||||||
return failure();
|
return failure();
|
||||||
int64_t listSize = ty.getSizes().size() == 1 ? ty.getSizes().front() : 1;
|
// make sure the type is not unsigned here before trying to materialize
|
||||||
auto intTy = dyn_cast_or_null<IntegerType>(ty.getDtype());
|
auto intTy = dyn_cast_or_null<IntegerType>(ty.getDtype());
|
||||||
if (!intTy || intTy.isUnsigned())
|
if (!intTy || intTy.isUnsigned())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
// if we have a rank 0 literal, we will be adding one element to the list
|
||||||
|
int64_t listSize = ty.getSizes().size() == 1 ? ty.getSizes().front() : 1;
|
||||||
|
|
||||||
|
if (listSize > kMaxFold)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// check for a splat or dense attr
|
||||||
auto splattr = dyn_cast_or_null<SplatElementsAttr>(literalOp.getValue());
|
auto splattr = dyn_cast_or_null<SplatElementsAttr>(literalOp.getValue());
|
||||||
auto denseAttr = dyn_cast_or_null<DenseIntElementsAttr>(literalOp.getValue());
|
auto denseAttr = dyn_cast_or_null<DenseIntElementsAttr>(literalOp.getValue());
|
||||||
|
|
||||||
if (!splattr && !denseAttr)
|
if (!splattr && !denseAttr)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
// These are not mutually exclusive, so try splat first.
|
||||||
if (splattr) {
|
if (splattr) {
|
||||||
auto attr = splattr.getSplatValue<Attribute>();
|
auto attr = splattr.getSplatValue<Attribute>();
|
||||||
vals.resize((int64_t)vals.size() + listSize, attr);
|
vals.resize((int64_t)vals.size() + listSize, attr);
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (denseAttr && !splattr) {
|
// remaining case: denseAttr
|
||||||
for (auto e : denseAttr.getValues<Attribute>())
|
if ((int64_t)denseAttr.getValues<Attribute>().size() != listSize)
|
||||||
vals.push_back(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((int64_t)vals.size() != listSize)
|
|
||||||
return failure();
|
return failure();
|
||||||
|
for (auto e : denseAttr.getValues<Attribute>())
|
||||||
|
vals.push_back(e);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,6 +163,45 @@ Value constructAtenTensorOpFromList(ImplicitLocOpBuilder b, mlir::Type resultTy,
|
||||||
// [scalarOpsA -> ListA -> TensorA] -> OpB. Then OpB will be able to
|
// [scalarOpsA -> ListA -> TensorA] -> OpB. Then OpB will be able to
|
||||||
// getListFromTensor(A), and further propagate scalarization.
|
// getListFromTensor(A), and further propagate scalarization.
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class PropagateAtenBroadcastToPattern
|
||||||
|
: public OpRewritePattern<AtenBroadcastToOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<AtenBroadcastToOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenBroadcastToOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
constexpr int64_t kMaxFold = 16;
|
||||||
|
// for tensor<si64>, or tensor<1xsi64>, broadcasted to tensor<nxsi64>, grab
|
||||||
|
// the element and convert to a full op.
|
||||||
|
auto ty = cast<ValueTensorType>(op.getType());
|
||||||
|
if (!ty.areAllSizesKnown() || ty.getSizes().size() != 1)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (ty.getSizes()[0] > kMaxFold)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> fillFold;
|
||||||
|
if (failed(getListFromTensor(op.getSelf(), fillFold)) ||
|
||||||
|
fillFold.size() != 1)
|
||||||
|
return failure();
|
||||||
|
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
||||||
|
SmallVector<Value, 1> fillVals;
|
||||||
|
if (failed(materializeFolds(b, fillFold, fillVals)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value size = b.create<Torch::ConstantIntOp>(ty.getSizes().front());
|
||||||
|
Value sizeList = b.create<Torch::PrimListConstructOp>(
|
||||||
|
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
|
||||||
|
size);
|
||||||
|
Value none = b.create<Torch::ConstantNoneOp>();
|
||||||
|
Value cstFalse = b.create<Torch::ConstantBoolOp>(false);
|
||||||
|
rewriter.replaceOpWithNewOp<AtenFullOp>(op, ty, sizeList, fillVals.front(),
|
||||||
|
none, none, none, cstFalse);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class PropagateAtenShapeToTensorPattern
|
class PropagateAtenShapeToTensorPattern
|
||||||
: public OpRewritePattern<Aten_ShapeAsTensorOp> {
|
: public OpRewritePattern<Aten_ShapeAsTensorOp> {
|
||||||
|
@ -541,9 +600,128 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename OpTy> struct ArithmeticHelper {
|
||||||
|
static LogicalResult getAlphaAndVerify(OpTy &op, int64_t &alpha) {
|
||||||
|
alpha = 1;
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <> struct ArithmeticHelper<AtenAddTensorOp> {
|
||||||
|
static LogicalResult getAlphaAndVerify(AtenAddTensorOp &op, int64_t &alpha) {
|
||||||
|
if (!matchPattern(op.getAlpha(), m_TorchConstantInt(&alpha)) || alpha != 1)
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <> struct ArithmeticHelper<AtenSubTensorOp> {
|
||||||
|
static LogicalResult getAlphaAndVerify(AtenSubTensorOp &op, int64_t &alpha) {
|
||||||
|
if (!matchPattern(op.getAlpha(), m_TorchConstantInt(&alpha)) || alpha != 1)
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename OpTy, typename ScalarOpTy>
|
||||||
|
class PropagateAtenArithmeticPattern : public OpRewritePattern<OpTy> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(OpTy op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// Check type
|
||||||
|
auto resultTy = cast<ValueTensorType>(op.getType());
|
||||||
|
if (resultTy.getSizes().size() > 1)
|
||||||
|
return rewriter.notifyMatchFailure(op, "unsupported: rank > 1");
|
||||||
|
if (!resultTy.hasDtype() || !isa<mlir::IntegerType>(resultTy.getDtype()))
|
||||||
|
return rewriter.notifyMatchFailure(op, "not an int type");
|
||||||
|
|
||||||
|
int64_t alpha;
|
||||||
|
if (failed(ArithmeticHelper<OpTy>::getAlphaAndVerify(op, alpha)))
|
||||||
|
return rewriter.notifyMatchFailure(op, "alpha must be 1");
|
||||||
|
|
||||||
|
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
||||||
|
SmallVector<OpFoldResult> selfFold, otherFold;
|
||||||
|
if (failed(getListFromTensor(op.getSelf(), selfFold)) ||
|
||||||
|
failed(getListFromTensor(op.getOther(), otherFold)) ||
|
||||||
|
selfFold.size() != otherFold.size())
|
||||||
|
return failure();
|
||||||
|
SmallVector<Value> selfVals, otherVals;
|
||||||
|
if (failed(materializeFolds(b, selfFold, selfVals)) ||
|
||||||
|
failed(materializeFolds(b, otherFold, otherVals)))
|
||||||
|
return failure();
|
||||||
|
SmallVector<OpFoldResult> resultFolds;
|
||||||
|
for (uint64_t i = 0; i < selfVals.size(); i++) {
|
||||||
|
resultFolds.push_back(b.createOrFold<ScalarOpTy>(
|
||||||
|
selfVals[i].getType(), selfVals[i], otherVals[i]));
|
||||||
|
}
|
||||||
|
SmallVector<Value> resultVals;
|
||||||
|
if (failed(materializeFolds(b, resultFolds, resultVals)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (resultTy.getSizes().size() == 0) {
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
|
||||||
|
op, resultTy, resultVals.front());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value result = constructAtenTensorOpFromList(b, resultTy, resultVals);
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
/// ------ Fold Patterns ------ ///
|
/// ------ Fold Patterns ------ ///
|
||||||
// These are shape-specific folding patterns
|
// These are shape-specific folding patterns
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class FoldAtenEqIntPattern : public OpRewritePattern<AtenEqIntOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<AtenEqIntOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenEqIntOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// replaces (size.int == 0) with false and adds an assert
|
||||||
|
// these comparisons are getting generated because onnx.Reshape considers 0
|
||||||
|
// to mean "don't change this dim". However, if the size we are passing to
|
||||||
|
// onnx.Reshape is a tensor dim, this is definitely never supposed to be
|
||||||
|
// interpreted as "don't change this dim".
|
||||||
|
int64_t otherInt;
|
||||||
|
if (!matchPattern(op.getB(), m_TorchConstantInt(&otherInt)) ||
|
||||||
|
otherInt != 0)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// in case the shape is a product of two ints, check each
|
||||||
|
if (auto mulOp = op.getA().getDefiningOp<AtenMulIntOp>()) {
|
||||||
|
Value self = mulOp.getA();
|
||||||
|
Value other = mulOp.getB();
|
||||||
|
Value selfEq = rewriter.create<AtenEqIntOp>(op.getLoc(), self, op.getB());
|
||||||
|
Value otherEq =
|
||||||
|
rewriter.create<AtenEqIntOp>(op.getLoc(), other, op.getB());
|
||||||
|
rewriter.replaceOpWithNewOp<Aten__Or__BoolOp>(op, selfEq, otherEq);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// if lhs is size.int op, assert size > 0 and replace with false.
|
||||||
|
if (auto sizeOp = op.getA().getDefiningOp<AtenSizeIntOp>()) {
|
||||||
|
Value selfGtOther = rewriter.create<AtenGtIntOp>(
|
||||||
|
op.getLoc(), op.getType(), op.getA(), op.getB());
|
||||||
|
rewriter.create<Torch::RuntimeAssertOp>(
|
||||||
|
op.getLoc(), selfGtOther,
|
||||||
|
rewriter.getStringAttr("Expected dim size > 0."));
|
||||||
|
Value cstFalse =
|
||||||
|
rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
||||||
|
rewriter.replaceOp(op, cstFalse);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
|
class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -594,16 +772,24 @@ public:
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class FoldAtenSqueezePattern : public OpRewritePattern<AtenSqueezeOp> {
|
template <typename SqueezeOp>
|
||||||
|
class FoldAtenSqueezePattern : public OpRewritePattern<SqueezeOp> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern<AtenSqueezeOp>::OpRewritePattern;
|
using OpRewritePattern<SqueezeOp>::OpRewritePattern;
|
||||||
LogicalResult matchAndRewrite(AtenSqueezeOp op,
|
LogicalResult matchAndRewrite(SqueezeOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto resultTy = cast<ValueTensorType>(op.getType());
|
auto resultTy = cast<ValueTensorType>(op.getType());
|
||||||
if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
|
if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
|
||||||
return rewriter.notifyMatchFailure(op, "Unknown result shape");
|
return rewriter.notifyMatchFailure(op, "Unknown result shape");
|
||||||
|
|
||||||
if (auto atenFull = op.getSelf().getDefiningOp<AtenFullOp>()) {
|
Value self = op.getSelf();
|
||||||
|
if (auto atenFull = self.getDefiningOp<AtenFullOp>()) {
|
||||||
|
// in the rank 0 case, just return the rank 0 scalar
|
||||||
|
if (resultTy.getSizes().size() == 0) {
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
|
||||||
|
op, resultTy, atenFull.getFillValue());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
SmallVector<Value> sizes;
|
SmallVector<Value> sizes;
|
||||||
for (int i = 0, s = resultTy.getSizes().size(); i < s; ++i)
|
for (int i = 0, s = resultTy.getSizes().size(); i < s; ++i)
|
||||||
sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
|
sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
@ -874,9 +1060,16 @@ bool isPrimListOfInts(Operation *op) {
|
||||||
return llvm::isa<Torch::IntType>(listType.getContainedType());
|
return llvm::isa<Torch::IntType>(listType.getContainedType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool isAnchorOp(Operation *op) {
|
||||||
|
return isa<Torch::RuntimeAssertOp>(op) || isa<AtenArangeStartStepOp>(op) ||
|
||||||
|
isPrimListOfInts(op);
|
||||||
|
}
|
||||||
|
|
||||||
void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
|
void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
|
||||||
patterns.insert<FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
|
patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
|
||||||
FoldAtenWhereSelf, FoldAtenTensorSplatPattern>(
|
FoldAtenSqueezePattern<AtenSqueezeDimOp>,
|
||||||
|
FoldAtenUnsqueezePattern, FoldAtenWhereSelf,
|
||||||
|
FoldAtenTensorSplatPattern, FoldAtenEqIntPattern>(
|
||||||
patterns.getContext());
|
patterns.getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -885,10 +1078,21 @@ void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
|
void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
|
||||||
patterns.insert<PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
|
// A note on division: onnx.Div from int, int -> int types rounds towards
|
||||||
PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
|
// zero. The torch DivTensorOp actually doesn't allow returning an int dtype,
|
||||||
PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
|
// but this was artificially plumbed through. Unfortunately, there is no
|
||||||
PropagateAtenWhereSelfPattern>(patterns.getContext());
|
// scalar trunc div op in torch; however, we can safely assume all operands
|
||||||
|
// are positive so floor divide should be a sufficient scalar replacement.
|
||||||
|
patterns.insert<
|
||||||
|
PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
|
||||||
|
PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
|
||||||
|
PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
|
||||||
|
PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
|
||||||
|
PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
|
||||||
|
PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
|
||||||
|
PropagateAtenArithmeticPattern<AtenMulTensorOp, AtenMulIntOp>,
|
||||||
|
PropagateAtenArithmeticPattern<AtenDivTensorOp, AtenFloordivIntOp>>(
|
||||||
|
patterns.getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
|
void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
|
||||||
|
@ -940,7 +1144,7 @@ public:
|
||||||
[&](Operation *op) {
|
[&](Operation *op) {
|
||||||
// Walking bottom-up, start adding ops when we reach an anchor point
|
// Walking bottom-up, start adding ops when we reach an anchor point
|
||||||
// (a prim list of ints)
|
// (a prim list of ints, a runtime assert, or an arange.start_step op)
|
||||||
if (isPrimListOfInts(op)) {
|
if (isAnchorOp(op)) {
|
||||||
shapeCalculationOps.insert(op);
|
shapeCalculationOps.insert(op);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -75,6 +75,99 @@ func.func @literal_item() -> !torch.int {
|
||||||
return %out : !torch.int
|
return %out : !torch.int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @arith_prop
|
||||||
|
func.func @arith_prop(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
// CHECK: %[[float0:.*]] = torch.constant.float 0.000000e+00
|
||||||
|
// CHECK: %[[int0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[x0:.*]] = torch.aten.size.int %arg0, %[[int0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[int1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[x1:.*]] = torch.aten.size.int %arg0, %[[int1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[int12:.*]] = torch.constant.int 12
|
||||||
|
// CHECK: %[[int1_0:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[x2:.*]] = torch.aten.floordiv.int %[[x0]], %[[int12]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[x3:.*]] = torch.aten.floordiv.int %[[x1]], %[[int1_0]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[int12_1:.*]] = torch.constant.int 12
|
||||||
|
// CHECK: %[[int1_2:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[x4:.*]] = torch.aten.mul.int %[[x2]], %[[int12_1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[x5:.*]] = torch.aten.mul.int %[[x3]], %[[int1_2]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[x6:.*]] = torch.aten.sub.int %[[x0]], %[[x4]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[x7:.*]] = torch.aten.sub.int %[[x1]], %[[x5]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[x8:.*]] = torch.prim.ListConstruct %[[x7]], %[[x6]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[x9:.*]] = torch.aten.constant_pad_nd %arg0, %[[x8]], %[[float0]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: return %[[x9]] : !torch.vtensor<[?,?],f32>
|
||||||
|
%0 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||||
|
%1 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||||
|
%float0.000000e00 = torch.constant.float 0.000000e+00
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%2 = torch.vtensor.literal(dense<[12, 1]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[2],si64>
|
||||||
|
%4 = torch.aten.div.Tensor %3, %2 : !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64> -> !torch.vtensor<[2],si64>
|
||||||
|
%5 = torch.aten.mul.Tensor %4, %2 : !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64> -> !torch.vtensor<[2],si64>
|
||||||
|
%6 = torch.aten.sub.Tensor %3, %5, %int1 : !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>, !torch.int -> !torch.vtensor<[2],si64>
|
||||||
|
%7 = torch.aten.index_select %6, %int0, %1 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
|
||||||
|
%8 = torch.aten.index_select %6, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
|
||||||
|
%9 = torch.aten.item %7 : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
%10 = torch.aten.item %8 : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
%11 = torch.prim.ListConstruct %10, %9 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%12 = torch.aten.constant_pad_nd %arg0, %11, %float0.000000e00 : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,?],f32>
|
||||||
|
return %12 : !torch.vtensor<[?,?],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @broadcast_prop
|
||||||
|
func.func @broadcast_prop(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.int {
|
||||||
|
// CHECK: %[[I0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: return %[[SZE]] : !torch.int
|
||||||
|
%dim = torch.constant.int 0
|
||||||
|
%size = torch.aten.size.int %arg0, %dim : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||||
|
%shape = torch.prim.NumToTensor.Scalar %size : !torch.int -> !torch.vtensor<[],si32>
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%idx = torch.vtensor.literal(dense<-1> : tensor<si32>) : !torch.vtensor<[],si32>
|
||||||
|
%bcastlist = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
|
||||||
|
%bcast = torch.aten.broadcast_to %shape, %bcastlist : !torch.vtensor<[],si32>, !torch.list<int> -> !torch.vtensor<[3],si32>
|
||||||
|
%select = torch.aten.index_select %bcast, %dim, %idx : !torch.vtensor<[3],si32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si32>
|
||||||
|
%out = torch.aten.item %select : !torch.vtensor<[],si32> -> !torch.int
|
||||||
|
%list = torch.prim.ListConstruct %out : (!torch.int) -> !torch.list<int>
|
||||||
|
return %out : !torch.int
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @eq_int_fold
|
||||||
|
func.func @eq_int_fold(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,1],f32> {
|
||||||
|
// CHECK: %[[int1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[int0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[sze0:.*]] = torch.aten.size.int %arg0, %[[int0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[sze1:.*]] = torch.aten.size.int %arg0, %[[int1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[mul:.*]] = torch.aten.mul.int %[[sze0]], %[[sze1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[gt0:.*]] = torch.aten.gt.int %[[sze0]], %[[int0]] : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: torch.runtime.assert %[[gt0]], "Expected dim size > 0."
|
||||||
|
// CHECK: %[[gt1:.*]] = torch.aten.gt.int %[[sze1]], %[[int0]] : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: torch.runtime.assert %[[gt1]], "Expected dim size > 0."
|
||||||
|
// CHECK: %[[list:.*]] = torch.prim.ListConstruct %[[mul]], %[[int1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[view:.*]] = torch.aten.view %arg0, %[[list]] : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,1],f32>
|
||||||
|
// CHECK: return %[[view:.*]] : !torch.vtensor<[?,1],f32>
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||||
|
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||||
|
%2 = torch.aten.mul.int %0, %1 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%3 = torch.aten.eq.int %2, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
%4 = torch.aten.Int.bool %3 : !torch.bool -> !torch.int
|
||||||
|
%5 = torch.prim.NumToTensor.Scalar %4 : !torch.int -> !torch.vtensor<[],i1>
|
||||||
|
%6 = torch.prim.NumToTensor.Scalar %0 : !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
%7 = torch.prim.NumToTensor.Scalar %2 : !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
%8 = torch.aten.where.self %5, %6, %7 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
|
||||||
|
%9 = torch.aten.item %8 : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
%10 = torch.prim.ListConstruct %9, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%11 = torch.aten.view %arg0, %10 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,1],f32>
|
||||||
|
return %11 : !torch.vtensor<[?,1],f32>
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
|
|
@ -36,8 +36,8 @@ func.func @test_triu_decompose(%arg0: !torch.vtensor<[4,5],si64>) -> !torch.vten
|
||||||
module {
|
module {
|
||||||
// CHECK-LABEL: func.func @test_scalarize
|
// CHECK-LABEL: func.func @test_scalarize
|
||||||
func.func @test_scalarize(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.11.0"} {
|
func.func @test_scalarize(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.11.0"} {
|
||||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
// CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
|
||||||
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
// CHECK-DAG: %[[INT3:.+]] = torch.constant.int 3
|
||||||
// CHECK: %[[ADD:.+]] = torch.aten.flatten.using_ints %arg0, %[[INT2]], %[[INT3]] : !torch.vtensor<[?,?,16,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,1024],f32>
|
// CHECK: %[[ADD:.+]] = torch.aten.flatten.using_ints %arg0, %[[INT2]], %[[INT3]] : !torch.vtensor<[?,?,16,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,1024],f32>
|
||||||
%0 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[4],si64>
|
%0 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[4],si64>
|
||||||
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__21> : tensor<si64>} : () -> !torch.vtensor<[],si64>
|
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__21> : tensor<si64>} : () -> !torch.vtensor<[],si64>
|
||||||
|
|
Loading…
Reference in New Issue