mirror of https://github.com/llvm/torch-mlir
Add more patterns to scalarize-shapes pass (#3781)
-Adds patterns for propagating shapes through AtenWhereSelf and AtenEqTensor -Adds fold pattern for a rank0 squeezeDim of a full op -Adds support for getting a list from a splat ValueTensorLiteralOp for materializing scalar comparisons in where.self and eq.tensor With a bit of hammering, these changes should unblock several IREE inference failures.memory_effect
parent
7b11dfc0ee
commit
ab62f35373
|
@ -63,6 +63,29 @@ LogicalResult getListOperands(Value value, SmallVector<Value> &vals) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult constructListFromLiteral(PatternRewriter &rewriter,
|
||||
ValueTensorLiteralOp literalOp,
|
||||
SmallVector<Value> &vals) {
|
||||
// only supports splat ValueTensorLiterals for now. TODO: add support for
|
||||
// small non-splat valuetensorliterals.
|
||||
auto ty = dyn_cast<ValueTensorType>(literalOp.getType());
|
||||
if (!ty || !ty.hasSizes())
|
||||
return failure();
|
||||
auto attr = dyn_cast_or_null<SplatElementsAttr>(literalOp.getValue());
|
||||
if (!attr)
|
||||
return failure();
|
||||
auto attrInt = dyn_cast<IntegerAttr>(attr.getSplatValue<Attribute>());
|
||||
if (!attrInt)
|
||||
return failure();
|
||||
IntegerType intty = cast<IntegerType>(attrInt.getType());
|
||||
if (!intty.isSignedInteger())
|
||||
return failure();
|
||||
Value materializedVal = rewriter.create<Torch::ConstantIntOp>(
|
||||
literalOp.getLoc(), attrInt.getSInt());
|
||||
vals.resize(vals.size() + ty.getSizes()[0], materializedVal);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult getListFromTensor(Value value, SmallVector<Value> &vals) {
|
||||
constexpr int64_t kMaxFold = 16;
|
||||
if (auto tensor = value.getDefiningOp<Torch::AtenTensorOp>())
|
||||
|
@ -351,6 +374,172 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class PropagateAtenWhereSelfPattern : public OpRewritePattern<AtenWhereSelfOp> {
|
||||
public:
|
||||
using OpRewritePattern<AtenWhereSelfOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenWhereSelfOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value condition = op.getCondition();
|
||||
Value self = op.getSelf();
|
||||
Value other = op.getOther();
|
||||
auto conditionTy = dyn_cast<Torch::ValueTensorType>(condition.getType());
|
||||
if (!conditionTy || !conditionTy.hasSizes() ||
|
||||
conditionTy.getSizes().size() != 1)
|
||||
return rewriter.notifyMatchFailure(op, "bad condition type");
|
||||
auto selfTy = dyn_cast<Torch::ValueTensorType>(self.getType());
|
||||
if (!selfTy || !selfTy.hasSizes() || selfTy.getSizes().size() != 1)
|
||||
return rewriter.notifyMatchFailure(op, "bad self type");
|
||||
auto otherTy = dyn_cast<Torch::ValueTensorType>(other.getType());
|
||||
if (!otherTy || !otherTy.hasSizes() || otherTy.getSizes().size() != 1)
|
||||
return rewriter.notifyMatchFailure(op, "bad other type");
|
||||
int64_t conditionSize = selfTy.getSizes()[0];
|
||||
int64_t selfSize = selfTy.getSizes()[0];
|
||||
int64_t otherSize = otherTy.getSizes()[0];
|
||||
|
||||
if (selfSize != otherSize || selfSize != conditionSize)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"unimplemented: support for propogating with implicit broadcasting.");
|
||||
|
||||
constexpr int64_t kMaxFold = 16;
|
||||
if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"arguments are dynamic or too big");
|
||||
|
||||
SmallVector<Value> conditionList, selfList, otherList;
|
||||
if (failed(getListFromTensor(condition, conditionList)) ||
|
||||
(int64_t)conditionList.size() != conditionSize)
|
||||
return failure();
|
||||
|
||||
// If one of these tensors is a value tensor literal op, we will need to
|
||||
// create constant ints in the IR to form a list. Before calling
|
||||
// constructListFromLiteral, we must be certain that the conversion can no
|
||||
// longer fail, otherwise we will cause an infinite loop of creating a
|
||||
// constant and removing it.
|
||||
LogicalResult selfFromList = getListFromTensor(self, selfList);
|
||||
LogicalResult otherFromList = getListFromTensor(other, otherList);
|
||||
|
||||
if (failed(selfFromList) && failed(otherFromList))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "At least one operand must succeed at constructing a list");
|
||||
|
||||
auto selfLiteral = self.getDefiningOp<Torch::ValueTensorLiteralOp>();
|
||||
auto otherLiteral = other.getDefiningOp<Torch::ValueTensorLiteralOp>();
|
||||
if (succeeded(selfFromList) && otherLiteral &&
|
||||
failed(constructListFromLiteral(rewriter, otherLiteral, otherList)))
|
||||
return failure();
|
||||
if (succeeded(otherFromList) && selfLiteral &&
|
||||
failed(constructListFromLiteral(rewriter, selfLiteral, selfList)))
|
||||
return failure();
|
||||
if ((int64_t)selfList.size() != selfSize ||
|
||||
(int64_t)otherList.size() != otherSize)
|
||||
// this should only occur if we did not generate IR with
|
||||
// constructListFromLiteral
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
SmallVector<Value> whereVals;
|
||||
auto rank0IntTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
ArrayRef<int64_t>({}), selfTy.getDtype());
|
||||
auto rank0BoolTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
ArrayRef<int64_t>({}), conditionTy.getDtype());
|
||||
for (uint64_t i = 0; i < selfList.size(); i++) {
|
||||
Value rank0Cond = rewriter.create<Torch::PrimNumToTensorScalarOp>(
|
||||
loc, rank0BoolTy, conditionList[i]);
|
||||
Value rank0Self = rewriter.create<Torch::PrimNumToTensorScalarOp>(
|
||||
loc, rank0IntTy, selfList[i]);
|
||||
Value rank0Other = rewriter.create<Torch::PrimNumToTensorScalarOp>(
|
||||
loc, rank0IntTy, otherList[i]);
|
||||
Value rank0Where = rewriter.create<AtenWhereSelfOp>(
|
||||
loc, rank0IntTy, rank0Cond, rank0Self, rank0Other);
|
||||
whereVals.push_back(rewriter.create<AtenItemOp>(
|
||||
loc, rewriter.getType<Torch::IntType>(), rank0Where));
|
||||
}
|
||||
Value list = rewriter.create<Torch::PrimListConstructOp>(
|
||||
op.getLoc(), Torch::ListType::get(whereVals[0].getType()), whereVals);
|
||||
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
|
||||
op.getLoc(), rewriter.getBoolAttr(false));
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
|
||||
op, op.getType(), list, cstNone, cstNone, cstFalse);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class PropagateAtenEqTensorPattern : public OpRewritePattern<AtenEqTensorOp> {
|
||||
public:
|
||||
using OpRewritePattern<AtenEqTensorOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenEqTensorOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value self = op.getSelf();
|
||||
Value other = op.getOther();
|
||||
auto selfTy = dyn_cast<Torch::ValueTensorType>(self.getType());
|
||||
if (!selfTy || !selfTy.hasSizes() || selfTy.getSizes().size() != 1)
|
||||
return rewriter.notifyMatchFailure(op, "bad self type");
|
||||
auto otherTy = dyn_cast<Torch::ValueTensorType>(other.getType());
|
||||
if (!otherTy || !otherTy.hasSizes() || otherTy.getSizes().size() != 1)
|
||||
return rewriter.notifyMatchFailure(op, "bad other type");
|
||||
int64_t selfSize = selfTy.getSizes()[0];
|
||||
int64_t otherSize = otherTy.getSizes()[0];
|
||||
|
||||
if (selfSize != otherSize)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"unimplemented: support for propogating with implicit broadcasting.");
|
||||
|
||||
constexpr int64_t kMaxFold = 16;
|
||||
if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold ||
|
||||
otherSize == Torch::kUnknownSize || otherSize > kMaxFold)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"self or other is dynamic or too big");
|
||||
|
||||
SmallVector<Value> selfList, otherList;
|
||||
// If one of these tensors is a value tensor literal op, we will need to
|
||||
// create constant ints in the IR to form a list. Before calling
|
||||
// constructListFromLiteral, we must be certain that the conversion can no
|
||||
// longer fail, otherwise we will cause an infinite loop of creating a
|
||||
// constant and removing it.
|
||||
LogicalResult selfFromList = getListFromTensor(self, selfList);
|
||||
LogicalResult otherFromList = getListFromTensor(other, otherList);
|
||||
|
||||
if (failed(selfFromList) && failed(otherFromList))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "At least one operand must succeed at constructing a list");
|
||||
|
||||
auto selfLiteral = self.getDefiningOp<Torch::ValueTensorLiteralOp>();
|
||||
auto otherLiteral = other.getDefiningOp<Torch::ValueTensorLiteralOp>();
|
||||
if (succeeded(selfFromList) && otherLiteral &&
|
||||
failed(constructListFromLiteral(rewriter, otherLiteral, otherList)))
|
||||
return failure();
|
||||
if (succeeded(otherFromList) && selfLiteral &&
|
||||
failed(constructListFromLiteral(rewriter, selfLiteral, selfList)))
|
||||
return failure();
|
||||
if ((int64_t)selfList.size() != selfSize ||
|
||||
(int64_t)otherList.size() != otherSize)
|
||||
// this should only occur if we did not generate IR with
|
||||
// constructListFromLiteral
|
||||
return failure();
|
||||
|
||||
SmallVector<Value> eqVals;
|
||||
for (uint64_t i = 0; i < selfList.size(); i++) {
|
||||
eqVals.push_back(
|
||||
rewriter.create<AtenEqIntOp>(op.getLoc(), selfList[i], otherList[i]));
|
||||
}
|
||||
Value list = rewriter.create<Torch::PrimListConstructOp>(
|
||||
op.getLoc(), Torch::ListType::get(eqVals[0].getType()), eqVals);
|
||||
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
|
||||
op.getLoc(), rewriter.getBoolAttr(false));
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
|
||||
op, op.getType(), list, cstNone, cstNone, cstFalse);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
|
||||
public:
|
||||
|
@ -454,6 +643,26 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class FoldAtenSqueezeDimPattern : public OpRewritePattern<AtenSqueezeDimOp> {
|
||||
public:
|
||||
using OpRewritePattern<AtenSqueezeDimOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenSqueezeDimOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto resultTy = cast<ValueTensorType>(op.getType());
|
||||
if (!resultTy.hasSizes() || resultTy.getSizes().size() != 0)
|
||||
return rewriter.notifyMatchFailure(op, "Unknown result shape");
|
||||
|
||||
if (auto atenFull = op.getSelf().getDefiningOp<AtenFullOp>()) {
|
||||
rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(
|
||||
op, resultTy, atenFull.getFillValue());
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class FoldAtenWhereSelf : public OpRewritePattern<AtenWhereSelfOp> {
|
||||
public:
|
||||
|
@ -694,6 +903,8 @@ public:
|
|||
PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
|
||||
FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
|
||||
FoldAtenWhereSelf, CanonicalizeAtenViewPattern,
|
||||
PropagateAtenEqTensorPattern, PropagateAtenWhereSelfPattern,
|
||||
FoldAtenSqueezeDimPattern,
|
||||
RemoveUnusedPattern<Torch::AtenIntBoolOp>,
|
||||
RemoveUnusedPattern<Torch::AtenEqIntOp>,
|
||||
RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
|
||||
|
|
|
@ -160,3 +160,79 @@ func.func @unsqueeze_squeeze_combo(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !t
|
|||
%14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
return %14 : !torch.int
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @eq_tensor_and_where_self
|
||||
func.func @eq_tensor_and_where_self(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[4],si64> {
|
||||
// CHECK-DAG: %[[false:.*]] = torch.constant.bool false
|
||||
// CHECK-DAG: %[[none:.*]] = torch.constant.none
|
||||
// CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[I0:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[DIM1:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
|
||||
// CHECK-DAG: %[[DIM0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[I1]], %[[DIM1]], %[[DIM1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[none]], %[[none]], %[[false]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>
|
||||
// CHECK: return %[[TENSOR]] : !torch.vtensor<[4],si64>
|
||||
%none = torch.constant.none
|
||||
%0 = torch.vtensor.literal(dense<-1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
|
||||
%1 = torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
|
||||
%false = torch.constant.bool false
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%2 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
|
||||
%3 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
|
||||
%4 = torch.prim.ListConstruct %3, %int1, %2, %2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
%5 = torch.aten.tensor %4, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>
|
||||
%6 = torch.aten.eq.Tensor %5, %0 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1>
|
||||
%7 = torch.aten.where.self %6, %1, %5 : !torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64>
|
||||
return %7 : !torch.vtensor<[4],si64>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @eq_tensor_from_tensor_and_literal
|
||||
func.func @eq_tensor_from_tensor_and_literal(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[4],i1> {
|
||||
// CHECK-DAG: %[[none:.*]] = torch.constant.none
|
||||
// CHECK-DAG: %[[false:.*]] = torch.constant.bool false
|
||||
// CHECK-DAG: %[[true:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[false]], %[[true]], %[[false]], %[[false]] : (!torch.bool, !torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
|
||||
// CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[none]], %[[none]], %[[false]] : !torch.list<bool>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],i1>
|
||||
// CHECK: return %[[TENSOR]] : !torch.vtensor<[4],i1>
|
||||
%none = torch.constant.none
|
||||
%0 = torch.vtensor.literal(dense<-1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
|
||||
%1 = torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
|
||||
%false = torch.constant.bool false
|
||||
%int1 = torch.constant.int 1
|
||||
%int-1 = torch.constant.int -1
|
||||
%int0 = torch.constant.int 0
|
||||
%2 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
|
||||
%3 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
|
||||
%4 = torch.prim.ListConstruct %3, %int-1, %2, %2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
%5 = torch.aten.tensor %4, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>
|
||||
%6 = torch.aten.eq.Tensor %5, %0 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1>
|
||||
return %6 : !torch.vtensor<[4],i1>
|
||||
}
|
||||
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @squeeze_dim_full_fold
|
||||
func.func @squeeze_dim_full_fold(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.int {
|
||||
// CHECK: %[[I0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
|
||||
// CHECK: return %[[SZE]] : !torch.int
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%none = torch.constant.none
|
||||
%false = torch.constant.bool false
|
||||
%51 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
|
||||
%55 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
|
||||
%56 = torch.aten.full %55, %51, %none, %none, %none, %false : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>
|
||||
%57 = torch.aten.squeeze.dim %56, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
|
||||
%58 = torch.aten.item %57 : !torch.vtensor<[],si64> -> !torch.int
|
||||
return %58 : !torch.int
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue