diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 88e909c14..0842cff33 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -4506,7 +4506,8 @@ OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) {
     if (auto intAttr = dyn_cast<IntegerAttr>(splat)) {
       return intAttr.getType().isUnsignedInteger()
                  ? getI64IntegerAttr(getContext(), intAttr.getUInt())
-                 : getI64IntegerAttr(getContext(), intAttr.getSInt());
+                 : getI64IntegerAttr(getContext(),
+                                     intAttr.getValue().getSExtValue());
     }
     if (auto floatAttr = dyn_cast<FloatAttr>(splat)) {
       return getF64FloatAttr(getContext(), floatAttr.getValueAsDouble());
diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
index dd2f835ed..0e88bd8d6 100644
--- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
+++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
@@ -9,7 +9,9 @@
 #include "PassDetail.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Iterators.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@@ -25,7 +27,7 @@ namespace {
 
 LogicalResult materializeFolds(ImplicitLocOpBuilder b,
                                ArrayRef<OpFoldResult> fold,
-                               SmallVector<Value> &values) {
+                               SmallVectorImpl<Value> &values) {
   for (auto f : fold) {
     if (auto val = dyn_cast<Value>(f)) {
       values.push_back(val);
@@ -41,7 +43,7 @@ LogicalResult materializeFolds(ImplicitLocOpBuilder b,
 
       if (auto val = dyn_cast<IntegerAttr>(attr)) {
         values.push_back(
-            b.create<Torch::ConstantIntOp>(b.getType<Torch::IntType>(), val));
+            b.create<Torch::ConstantIntOp>(val.getValue().getSExtValue()));
         continue;
       }
     }
@@ -63,33 +65,14 @@ LogicalResult getListOperands(Value value, SmallVector<Value> &vals) {
   return success();
 }
 
-LogicalResult constructListFromLiteral(PatternRewriter &rewriter,
-                                       ValueTensorLiteralOp literalOp,
-                                       SmallVector<Value> &vals) {
-  // only supports splat ValueTensorLiterals for now. TODO: add support for
-  // small non-splat valuetensorliterals.
-  auto ty = dyn_cast<ValueTensorType>(literalOp.getType());
-  if (!ty || !ty.hasSizes())
-    return failure();
-  auto attr = dyn_cast_or_null<SplatElementsAttr>(literalOp.getValue());
-  if (!attr)
-    return failure();
-  auto attrInt = dyn_cast<IntegerAttr>(attr.getSplatValue<Attribute>());
-  if (!attrInt)
-    return failure();
-  IntegerType intty = cast<IntegerType>(attrInt.getType());
-  if (!intty.isSignedInteger())
-    return failure();
-  Value materializedVal = rewriter.create<Torch::ConstantIntOp>(
-      literalOp.getLoc(), attrInt.getSInt());
-  vals.resize(vals.size() + ty.getSizes()[0], materializedVal);
-  return success();
-}
-
-LogicalResult getListFromTensor(Value value, SmallVector<Value> &vals) {
+LogicalResult getListFromTensor(Value value, SmallVector<OpFoldResult> &vals) {
   constexpr int64_t kMaxFold = 16;
-  if (auto tensor = value.getDefiningOp<Torch::AtenTensorOp>())
-    return getListOperands(tensor.getData(), vals);
+  if (auto tensor = value.getDefiningOp<Torch::AtenTensorOp>()) {
+    SmallVector<Value> unfolded;
+    LogicalResult gotList = getListOperands(tensor.getData(), unfolded);
+    vals = getAsOpFoldResult(unfolded);
+    return gotList;
+  }
 
   if (auto full = value.getDefiningOp<Torch::AtenFullOp>()) {
     auto ty = cast<ValueTensorType>(full.getType());
@@ -99,14 +82,67 @@ LogicalResult getListFromTensor(Value value, SmallVector<Value> &vals) {
     if (ty.getSizes()[0] > kMaxFold)
       return failure();
 
-    vals.resize(vals.size() + ty.getSizes()[0], full.getFillValue());
+    vals.resize(vals.size() + ty.getSizes()[0],
+                getAsOpFoldResult(full.getFillValue()));
     return success();
   }
 
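+  // A minimal illustration of the AtenFullOp case above (names are
+  // hypothetical): given
+  //   %full = torch.aten.full %sizes, %fill, ... -> !torch.vtensor<[3],si64>
+  // this yields vals = [%fill, %fill, %fill] as OpFoldResults, without
+  // materializing any new ops.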
+  // TODO: Add a case for unsqueeze of a primnumtotensorscalarop?
 
-  return failure();
+  // Last supported case: ValueTensorLiteralOp.
+  auto literalOp = value.getDefiningOp<Torch::ValueTensorLiteralOp>();
+  if (!literalOp)
+    return failure();
+
+  // Check the type. Make sure the dtype is not unsigned before trying to
+  // materialize.
+  auto ty = cast<ValueTensorType>(literalOp.getType());
+  if (!ty.hasSizes() || ty.getSizes().size() > 1)
+    return failure();
+  int64_t listSize = ty.getSizes().size() == 1 ? ty.getSizes().front() : 1;
+  auto intTy = dyn_cast_or_null<IntegerType>(ty.getDtype());
+  if (!intTy || intTy.isUnsigned())
+    return failure();
+
+  auto splattr = dyn_cast_or_null<SplatElementsAttr>(literalOp.getValue());
+  auto denseAttr =
+      dyn_cast_or_null<DenseIntElementsAttr>(literalOp.getValue());
+
+  if (!splattr && !denseAttr)
+    return failure();
+
+  if (splattr) {
+    auto attr = splattr.getSplatValue<Attribute>();
+    vals.resize((int64_t)vals.size() + listSize, attr);
+  }
+
+  if (denseAttr && !splattr) {
+    for (auto e : denseAttr.getValues<Attribute>())
+      vals.push_back(e);
+  }
+
+  if ((int64_t)vals.size() != listSize)
+    return failure();
+
+  return success();
+}
+
+Value constructAtenTensorOpFromList(ImplicitLocOpBuilder b,
+                                    mlir::Type resultTy,
+                                    SmallVector<Value> &listValues) {
+  auto dimList = b.create<Torch::PrimListConstructOp>(
+      b.getType<Torch::ListType>(listValues.front().getType()), listValues);
+  Value cstNone = b.create<Torch::ConstantNoneOp>();
+  Value cstFalse = b.create<Torch::ConstantBoolOp>(b.getBoolAttr(false));
+  return b.create<Torch::AtenTensorOp>(resultTy, dimList, cstNone, cstNone,
+                                       cstFalse);
+}
 } // namespace
 
+/// ------ Propagation Patterns ------ ///
+// The general goal of these patterns is to convert SomeTensorOp to [scalarOps
+// -> PrimListOfInts -> AtenTensorOp]. Since these tensorized shape calculation
+// ops are chained together, sequences like OpA -> OpB will propagate OpA
+// first: [scalarOpsA -> ListA -> TensorA] -> OpB. Then OpB will be able to
+// getListFromTensor(A), and further propagate scalarization.
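+//
+// As a sketch (hypothetical IR, assuming a rank-2 input), propagation of an
+// Aten_ShapeAsTensorOp turns
+//   %shape = torch.aten._shape_as_tensor %arg0 -> !torch.vtensor<[2],si64>
+// into
+//   %sz0 = torch.aten.size.int %arg0, %int0 -> !torch.int
+//   %sz1 = torch.aten.size.int %arg0, %int1 -> !torch.int
+//   %list = torch.prim.ListConstruct %sz0, %sz1 : (...) -> !torch.list<int>
+//   %shape = torch.aten.tensor %list, %none, %none, %false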
+
 namespace {
 class PropagateAtenShapeToTensorPattern
     : public OpRewritePattern<Aten_ShapeAsTensorOp> {
@@ -115,30 +151,27 @@ public:
   LogicalResult matchAndRewrite(Aten_ShapeAsTensorOp op,
                                 PatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
     auto self = op.getSelf();
     auto selfTy = cast<BaseTensorType>(self.getType());
     if (!selfTy.hasSizes())
       return rewriter.notifyMatchFailure(op, "self has unknown rank");
 
     int64_t rank = selfTy.getSizes().size();
-    SmallVector<Value> dims;
+    SmallVector<OpFoldResult> dims;
     for (int64_t i = 0; i < rank; ++i) {
-      auto iv = rewriter.create<Torch::ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(i));
-      dims.push_back(rewriter.create<Torch::AtenSizeIntOp>(
-          loc, rewriter.getType<Torch::IntType>(), self, iv));
+      auto iv = b.create<Torch::ConstantIntOp>(i);
+      dims.push_back(b.createOrFold<Torch::AtenSizeIntOp>(
+          rewriter.getType<Torch::IntType>(), self, iv));
+    }
+    SmallVector<Value> materializedDims;
+    if (failed(materializeFolds(b, dims, materializedDims))) {
+      return failure();
     }
 
-    auto dimList = rewriter.create<Torch::PrimListConstructOp>(
-        loc,
-        rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
-        dims);
-
-    Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
-    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
-        loc, rewriter.getBoolAttr(false));
-    rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
-        op, op.getType(), dimList, cstNone, cstNone, cstFalse);
+    Value result =
+        constructAtenTensorOpFromList(b, op.getType(), materializedDims);
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -171,56 +204,20 @@ public:
 
     SmallVector<OpFoldResult> scalars;
     for (auto element : tensors) {
-      llvm::SmallVector<Value> delisted;
-      if (succeeded(getListFromTensor(element, delisted))) {
-        for (auto scalar : delisted)
-          scalars.push_back(scalar);
-        continue;
-      }
+      llvm::SmallVector<OpFoldResult> delisted;
+      if (failed(getListFromTensor(element, delisted)))
+        return rewriter.notifyMatchFailure(op, "unknown op fold type");
 
-      DenseElementsAttr attr;
-      if (matchPattern(element, m_Constant(&attr))) {
-        if (attr.isSplat()) {
-          scalars.resize(scalars.size() + attr.getNumElements(),
-                         attr.getSplatValue<Attribute>());
-          continue;
-        }
-
-        for (auto e : attr.getValues<Attribute>()) {
-          scalars.push_back(e);
-        }
-        continue;
-      }
-
-      return rewriter.notifyMatchFailure(op, "unknown op fold type");
-    }
-
-    for (auto &scalar : scalars) {
-      if (auto attr = dyn_cast<Attribute>(scalar)) {
-        if (auto iattr = dyn_cast<IntegerAttr>(attr)) {
-          auto i64 = iattr.getValue().getSExtValue();
-          scalar = rewriter.getI64IntegerAttr(i64);
-        }
-      }
+      for (auto scalar : delisted)
+        scalars.push_back(scalar);
     }
 
     SmallVector<Value> values;
-    if (failed(materializeFolds(b, scalars, values)))
+    if (failed(materializeFolds(b, scalars, values)) || values.empty())
      return rewriter.notifyMatchFailure(op, "unable to materialize constants");
 
-    Type eTy = b.getType<Torch::IntType>();
-    if (isa<mlir::FloatType>(resultTy.getDtype()))
-      eTy = rewriter.getType<Torch::FloatType>();
-
-    auto elementsList = b.create<Torch::PrimListConstructOp>(
-        rewriter.getType<Torch::ListType>(eTy), values);
-
-    Value cstNone = b.create<Torch::ConstantNoneOp>();
-    Value cstFalse =
-        b.create<Torch::ConstantBoolOp>(rewriter.getBoolAttr(false));
-    rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
-        op, op.getType(), elementsList, cstNone, cstNone, cstFalse);
-
+    Value result = constructAtenTensorOpFromList(b, resultTy, values);
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -236,7 +233,7 @@ public:
     auto loc = op.getLoc();
     ImplicitLocOpBuilder b(loc, rewriter);
 
-    SmallVector<Value> elements;
+    SmallVector<OpFoldResult> elements;
     if (failed(getListFromTensor(op.getSelf(), elements)))
       return failure();
 
@@ -244,8 +241,8 @@ public:
     if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
       return rewriter.notifyMatchFailure(op, "requires a constant dim");
 
-    DenseElementsAttr idx;
-    if (!matchPattern(op.getIndex(), m_Constant(&idx)))
+    SmallVector<OpFoldResult> idxFolds;
+    if (failed(getListFromTensor(op.getIndex(), idxFolds)))
       return rewriter.notifyMatchFailure(op, "requires a constant index");
 
     auto selfTy = cast<BaseTensorType>(op.getSelf().getType());
@@ -268,28 +265,25 @@ public:
                                          "expects unary non-dim dimension");
     }
 
-    SmallVector<Value> selected;
-    if (idx.isSplat()) {
-      int64_t indexInt = idx.getSplatValue<APInt>().getSExtValue();
+    SmallVector<OpFoldResult> selected;
+    for (auto idx : idxFolds) {
+      auto attr = dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(idx));
+      if (!attr)
+        return failure();
+      int64_t indexInt = attr.getValue().getSExtValue();
       indexInt = indexInt < 0 ? indexInt + dimLength : indexInt;
-      selected.resize(idx.getNumElements(), elements[indexInt]);
-    } else {
-      for (APInt val : idx.getValues<APInt>()) {
-        int64_t indexInt = val.getSExtValue();
-        selected.push_back(elements[indexInt]);
-      }
+      if (indexInt < 0 || indexInt >= dimLength)
+        return failure();
+      selected.push_back(elements[indexInt]);
     }
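+
+    // For illustration: with elements = [%a, %b, %c, %d] and an index literal
+    // of dense<-1>, the loop above selects [%d], since -1 wraps to
+    // dimLength - 1 (the values here are hypothetical).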
 
-    auto eTy = elements.front().getType();
+    SmallVector<Value> materializedSelected;
+    if (failed(materializeFolds(b, selected, materializedSelected)))
+      return failure();
 
-    auto dimList = rewriter.create<Torch::PrimListConstructOp>(
-        loc, rewriter.getType<Torch::ListType>(eTy), selected);
-
-    Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
-    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
-        loc, rewriter.getBoolAttr(false));
-    rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
-        op, op.getType(), dimList, cstNone, cstNone, cstFalse);
+    Value result =
+        constructAtenTensorOpFromList(b, op.getType(), materializedSelected);
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -309,7 +303,7 @@ public:
     auto loc = op.getLoc();
     ImplicitLocOpBuilder b(loc, rewriter);
 
-    SmallVector<Value> elements;
+    SmallVector<OpFoldResult> elements;
     if (failed(getListFromTensor(op.getSelf(), elements)))
       return failure();
 
@@ -356,19 +350,16 @@ public:
                                          "expects unary non-dim dimension");
     }
 
-    SmallVector<Value> selected;
+    SmallVector<OpFoldResult> selected;
     for (int i = start; i < end; i += step)
       selected.push_back(elements[i]);
 
-    auto eTy = elements.front().getType();
-    auto dimList = rewriter.create<Torch::PrimListConstructOp>(
-        loc, rewriter.getType<Torch::ListType>(eTy), selected);
+    SmallVector<Value> values;
+    if (failed(materializeFolds(b, selected, values)))
+      return failure();
 
-    Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
-    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
-        loc, rewriter.getBoolAttr(false));
-    rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
-        op, op.getType(), dimList, cstNone, cstNone, cstFalse);
+    Value result = constructAtenTensorOpFromList(b, op.getType(), values);
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -407,62 +398,39 @@ public:
       return rewriter.notifyMatchFailure(op,
                                          "arguments are dynamic or too big");
 
+    SmallVector<OpFoldResult> conditionFolds, selfFolds, otherFolds;
+    if (failed(getListFromTensor(condition, conditionFolds)) ||
+        failed(getListFromTensor(self, selfFolds)) ||
+        failed(getListFromTensor(other, otherFolds)))
+      return failure();
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+
     SmallVector<Value> conditionList, selfList, otherList;
-    if (failed(getListFromTensor(condition, conditionList)) ||
-        (int64_t)conditionList.size() != conditionSize)
+    if (failed(materializeFolds(b, conditionFolds, conditionList)) ||
+        failed(materializeFolds(b, selfFolds, selfList)) ||
+        failed(materializeFolds(b, otherFolds, otherList)))
      return failure();
 
-    // If one of these tensors is a value tensor literal op, we will need to
-    // create constant ints in the IR to form a list. Before calling
-    // constructListFromLiteral, we must be certain that the conversion can no
-    // longer fail, otherwise we will cause an infinite loop of creating a
-    // constant and removing it.
-    LogicalResult selfFromList = getListFromTensor(self, selfList);
-    LogicalResult otherFromList = getListFromTensor(other, otherList);
-
-    if (failed(selfFromList) && failed(otherFromList))
-      return rewriter.notifyMatchFailure(
-          op, "At least one operand must succeed at constructing a list");
-
-    auto selfLiteral = self.getDefiningOp<Torch::ValueTensorLiteralOp>();
-    auto otherLiteral = other.getDefiningOp<Torch::ValueTensorLiteralOp>();
-    if (succeeded(selfFromList) && otherLiteral &&
-        failed(constructListFromLiteral(rewriter, otherLiteral, otherList)))
-      return failure();
-    if (succeeded(otherFromList) && selfLiteral &&
-        failed(constructListFromLiteral(rewriter, selfLiteral, selfList)))
-      return failure();
-    if ((int64_t)selfList.size() != selfSize ||
-        (int64_t)otherList.size() != otherSize)
-      // this should only occur if we did not generate IR with
-      // constructListFromLiteral
-      return failure();
-
-    Location loc = op.getLoc();
     SmallVector<Value> whereVals;
     auto rank0IntTy = rewriter.getType<Torch::ValueTensorType>(
        ArrayRef<int64_t>({}), selfTy.getDtype());
     auto rank0BoolTy = rewriter.getType<Torch::ValueTensorType>(
        ArrayRef<int64_t>({}), conditionTy.getDtype());
     for (uint64_t i = 0; i < selfList.size(); i++) {
-      Value rank0Cond = rewriter.create<Torch::PrimNumToTensorScalarOp>(
-          loc, rank0BoolTy, conditionList[i]);
-      Value rank0Self = rewriter.create<Torch::PrimNumToTensorScalarOp>(
-          loc, rank0IntTy, selfList[i]);
-      Value rank0Other = rewriter.create<Torch::PrimNumToTensorScalarOp>(
-          loc, rank0IntTy, otherList[i]);
-      Value rank0Where = rewriter.create<AtenWhereSelfOp>(
-          loc, rank0IntTy, rank0Cond, rank0Self, rank0Other);
-      whereVals.push_back(rewriter.create<AtenItemOp>(
-          loc, rewriter.getType<Torch::IntType>(), rank0Where));
+      Value rank0Cond = b.create<Torch::PrimNumToTensorScalarOp>(
+          rank0BoolTy, conditionList[i]);
+      Value rank0Self =
+          b.create<Torch::PrimNumToTensorScalarOp>(rank0IntTy, selfList[i]);
+      Value rank0Other =
+          b.create<Torch::PrimNumToTensorScalarOp>(rank0IntTy, otherList[i]);
+      Value rank0Where = b.create<AtenWhereSelfOp>(rank0IntTy, rank0Cond,
+                                                   rank0Self, rank0Other);
+      whereVals.push_back(b.create<AtenItemOp>(
+          rewriter.getType<Torch::IntType>(), rank0Where));
     }
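+    // Illustratively, each iteration above emits IR along these lines
+    // (hypothetical SSA names):
+    //   %cond = torch.prim.NumToTensor.Scalar %condI -> !torch.vtensor<[],i1>
+    //   %a = torch.prim.NumToTensor.Scalar %selfI -> !torch.vtensor<[],si64>
+    //   %b = torch.prim.NumToTensor.Scalar %otherI -> !torch.vtensor<[],si64>
+    //   %w = torch.aten.where.self %cond, %a, %b
+    //   %i = torch.aten.item %w -> !torch.int
+    // which the rank-0 item propagation and fold patterns then reduce.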
-    Value list = rewriter.create<Torch::PrimListConstructOp>(
-        op.getLoc(), Torch::ListType::get(whereVals[0].getType()), whereVals);
-    Value cstNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
-    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
-        op.getLoc(), rewriter.getBoolAttr(false));
-    rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
-        op, op.getType(), list, cstNone, cstNone, cstFalse);
+    Value result = constructAtenTensorOpFromList(b, op.getType(), whereVals);
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -496,45 +464,34 @@ public:
       return rewriter.notifyMatchFailure(op,
                                          "self or other is dynamic or too big");
 
+    SmallVector<OpFoldResult> selfFolds, otherFolds;
+    if (failed(getListFromTensor(self, selfFolds)) ||
+        failed(getListFromTensor(other, otherFolds)))
+      return rewriter.notifyMatchFailure(op, "failed to get list from tensor");
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     SmallVector<Value> selfList, otherList;
-    // If one of these tensors is a value tensor literal op, we will need to
-    // create constant ints in the IR to form a list. Before calling
-    // constructListFromLiteral, we must be certain that the conversion can no
-    // longer fail, otherwise we will cause an infinite loop of creating a
-    // constant and removing it.
-    LogicalResult selfFromList = getListFromTensor(self, selfList);
-    LogicalResult otherFromList = getListFromTensor(other, otherList);
+    if (failed(materializeFolds(b, selfFolds, selfList)) ||
+        failed(materializeFolds(b, otherFolds, otherList)))
+      return rewriter.notifyMatchFailure(op, "failed to materialize folds");
 
-    if (failed(selfFromList) && failed(otherFromList))
-      return rewriter.notifyMatchFailure(
-          op, "At least one operand must succeed at constructing a list");
-
-    auto selfLiteral = self.getDefiningOp<Torch::ValueTensorLiteralOp>();
-    auto otherLiteral = other.getDefiningOp<Torch::ValueTensorLiteralOp>();
-    if (succeeded(selfFromList) && otherLiteral &&
-        failed(constructListFromLiteral(rewriter, otherLiteral, otherList)))
-      return failure();
-    if (succeeded(otherFromList) && selfLiteral &&
-        failed(constructListFromLiteral(rewriter, selfLiteral, selfList)))
-      return failure();
-    if ((int64_t)selfList.size() != selfSize ||
-        (int64_t)otherList.size() != otherSize)
-      // this should only occur if we did not generate IR with
-      // constructListFromLiteral
-      return failure();
-
-    SmallVector<Value> eqVals;
+    SmallVector<OpFoldResult> eqBoolFolds;
     for (uint64_t i = 0; i < selfList.size(); i++) {
-      eqVals.push_back(
-          rewriter.create<AtenEqIntOp>(op.getLoc(), selfList[i], otherList[i]));
+      OpFoldResult eqInt =
+          b.createOrFold<AtenEqIntOp>(selfList[i], otherList[i]);
+      if (auto eqIntVal = dyn_cast<Value>(eqInt))
+        eqInt = b.createOrFold<AtenIntBoolOp>(eqIntVal);
+      // If eqInt was an Attribute, it will materialize to a constant int op,
+      // which is what we want.
+      eqBoolFolds.push_back(eqInt);
     }
-    Value list = rewriter.create<Torch::PrimListConstructOp>(
-        op.getLoc(), Torch::ListType::get(eqVals[0].getType()), eqVals);
-    Value cstNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
-    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
-        op.getLoc(), rewriter.getBoolAttr(false));
-    rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
-        op, op.getType(), list, cstNone, cstNone, cstFalse);
+    SmallVector<Value> eqVals;
+    if (failed(materializeFolds(b, eqBoolFolds, eqVals))) {
+      return failure();
+    }
+
+    Value result = constructAtenTensorOpFromList(b, op.getType(), eqVals);
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -546,20 +503,47 @@ public:
   using OpRewritePattern<AtenItemOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenItemOp op,
                                 PatternRewriter &rewriter) const override {
+    SmallVector<OpFoldResult> elements;
+    Value self = op.getSelf();
+    auto selfTy = cast<ValueTensorType>(self.getType());
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    SmallVector<Value> elements;
+
+    // Rank 0 item op prop
+    if (selfTy.getSizes().size() == 0) {
+      auto numToTensor = self.getDefiningOp<Torch::PrimNumToTensorScalarOp>();
+      auto squeezeDim = self.getDefiningOp<AtenSqueezeDimOp>();
+      if (!squeezeDim && !numToTensor)
+        return rewriter.notifyMatchFailure(op,
+                                           "unhandled item of rank 0 operand");
+      if (numToTensor) {
+        rewriter.replaceOp(op, numToTensor.getA());
+        return success();
+      }
+      rewriter.replaceOpWithNewOp<AtenItemOp>(op, op.getType(),
+                                              squeezeDim.getSelf());
+      return success();
+    }
+
+    // Rank 1 item op prop
     if (failed(getListFromTensor(op.getSelf(), elements)))
       return failure();
 
     if (elements.size() != 1)
-      return rewriter.notifyMatchFailure(op, "expected no elements");
+      return rewriter.notifyMatchFailure(op, "expected one element");
 
-    rewriter.replaceOp(op, elements[0]);
+    SmallVector<Value> materialized;
+    if (failed(materializeFolds(b, elements, materialized)))
+      return failure();
+
+    rewriter.replaceOp(op, materialized.front());
     return success();
   }
 };
 } // namespace
 
+/// ------ Fold Patterns ------ ///
+// These are shape-specific folding patterns
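+//
+// For example, chains like aten.full -> aten.squeeze.dim -> aten.item reduce
+// all the way to the fill value (see @squeeze_dim_full_fold in
+// test/Dialect/Torch/scalarize-shapes.mlir).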
+
 namespace {
 class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
 public:
@@ -643,26 +627,6 @@ public:
 };
 } // namespace
 
-namespace {
-class FoldAtenSqueezeDimPattern : public OpRewritePattern<AtenSqueezeDimOp> {
-public:
-  using OpRewritePattern<AtenSqueezeDimOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenSqueezeDimOp op,
-                                PatternRewriter &rewriter) const override {
-    auto resultTy = cast<ValueTensorType>(op.getType());
-    if (!resultTy.hasSizes() || resultTy.getSizes().size() != 0)
-      return rewriter.notifyMatchFailure(op, "Unknown result shape");
-
-    if (auto atenFull = op.getSelf().getDefiningOp<AtenFullOp>()) {
-      rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(
-          op, resultTy, atenFull.getFillValue());
-      return success();
-    }
-    return failure();
-  }
-};
-} // namespace
-
 namespace {
 class FoldAtenWhereSelf : public OpRewritePattern<AtenWhereSelfOp> {
 public:
@@ -697,16 +661,19 @@ public:
 
     if (selfSize && otherSize) {
       if (selfSize.getSelf() != otherSize.getSelf())
-        return failure();
-
-      if (selfSize.getDim() != otherSize.getDim())
-        return failure();
+        return rewriter.notifyMatchFailure(op, "sizes not of same tensor");
+      int64_t dimSelf, dimOther;
+      if ((selfSize.getDim() != otherSize.getDim()) &&
+          (!matchPattern(selfSize.getDim(), m_TorchConstantInt(&dimSelf)) ||
+           !matchPattern(otherSize.getDim(), m_TorchConstantInt(&dimOther)) ||
+           (dimSelf != dimOther)))
+        return rewriter.notifyMatchFailure(op, "sizes not of same dim");
 
       rewriter.replaceOp(op, op.getSelf());
       return success();
     }
 
-    return failure();
+    return rewriter.notifyMatchFailure(op, "unable to fold");
   }
 };
 } // namespace
@@ -750,6 +717,8 @@ public:
 };
 } // namespace
 
+/// ------ Canonicalization Patterns ------ ///
+
 namespace {
 // This is a specific pattern for converting views like [?,...,?,lastDim] ->
 // [?,...,?,factor0,factor1] to unflatten, and views like
@@ -888,6 +857,58 @@ public:
 };
 } // namespace
 
+namespace {
+
+bool isSourceOpForShapeScalarization(Operation *op) {
+  return llvm::isa<AtenSizeIntOp, Torch::ConstantIntOp, Torch::ConstantBoolOp,
+                   Aten_ShapeAsTensorOp, Torch::ValueTensorLiteralOp>(op);
+}
+
+bool isPrimListOfInts(Operation *op) {
+  auto primListOp = dyn_cast<Torch::PrimListConstructOp>(op);
+  if (!primListOp)
+    return false;
+  auto listType = dyn_cast<Torch::ListType>(primListOp.getType());
+  if (!listType)
+    return false;
+  return llvm::isa<Torch::IntType>(listType.getContainedType());
+}
+
+void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
+  patterns.insert<FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
+                  FoldAtenWhereSelf, FoldAtenTensorSplatPattern>(
+      patterns.getContext());
+}
+
+void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
+  patterns.add<CanonicalizeAtenViewPattern>(patterns.getContext());
+}
+
+void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
+  patterns.insert<PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
+                  PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
+                  PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
+                  PropagateAtenWhereSelfPattern>(patterns.getContext());
+}
+
+void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
+  patterns.insert<RemoveUnusedPattern<Torch::AtenIntBoolOp>,
+                  RemoveUnusedPattern<Torch::AtenEqIntOp>,
+                  RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
+                  RemoveUnusedPattern<Torch::AtenFullOp>,
+                  RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
+                  RemoveUnusedPattern<Torch::AtenSqueezeDimOp>,
+                  RemoveUnusedPattern<Torch::AtenSizeIntOp>,
+                  RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
+                  RemoveUnusedPattern<Torch::AtenTensorOp>,
+                  RemoveUnusedPattern<Torch::ConstantBoolOp>,
+                  RemoveUnusedPattern<Torch::ConstantIntOp>,
+                  RemoveUnusedPattern<Torch::ConstantNoneOp>,
+                  RemoveUnusedPattern<Torch::PrimListConstructOp>>(
+      patterns.getContext());
+}
+
+} // namespace
+
 namespace {
 class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
 public:
@@ -898,33 +919,74 @@ public:
   void runOnOperation() override {
     MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
-    patterns.insert<PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
-                    PropagateAtenItemPattern,
-                    PropagateAtenShapeToTensorPattern,
-                    PropagateAtenSliceTensorPattern,
-                    PropagateAtenEqTensorPattern,
-                    PropagateAtenWhereSelfPattern, FoldAtenTensorSplatPattern,
-                    FoldAtenSqueezePattern, FoldAtenSqueezeDimPattern,
-                    FoldAtenUnsqueezePattern, FoldAtenWhereSelf,
-                    CanonicalizeAtenViewPattern,
-                    RemoveUnusedPattern<Torch::AtenIntBoolOp>,
-                    RemoveUnusedPattern<Torch::AtenEqIntOp>,
-                    RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
-                    RemoveUnusedPattern<Torch::AtenFullOp>,
-                    RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
-                    RemoveUnusedPattern<Torch::AtenSqueezeDimOp>,
-                    RemoveUnusedPattern<Torch::AtenSizeIntOp>,
-                    RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
-                    RemoveUnusedPattern<Torch::AtenTensorOp>,
-                    RemoveUnusedPattern<Torch::ConstantBoolOp>,
-                    RemoveUnusedPattern<Torch::ConstantIntOp>,
-                    RemoveUnusedPattern<Torch::ConstantNoneOp>>(context);
+    // populate patterns
+    populateScalarizationPropagationPatterns(patterns);
+    populateScalarizationFoldPatterns(patterns);
+    populateScalarizationCanonicalizePatterns(patterns);
+    populateScalarizationRemovePatterns(patterns);
     context->getLoadedDialect<mlir::arith::ArithDialect>()
         ->getCanonicalizationPatterns(patterns);
-    if (failed(applyPatternsAndFoldGreedily(getOperation(),
-                                            std::move(patterns)))) {
+    // Don't load torch canonicalization patterns, since these may lead to
+    // issues with propagation.
+
+    // Walk the func op bottom-up to collect a SetVector of shape-related
+    // operations. When we pass this SetVector to the pattern rewrite driver,
+    // it will process the operations top-down, thereby propagating
+    // scalarization starting from sources.
+    auto funcOp = getOperation();
+    llvm::SetVector<Operation *> shapeCalculationOps;
+    funcOp.walk<WalkOrder::PreOrder, mlir::ReverseIterator>(
+        [&](Operation *op) {
+          // Walking bottom-up, start adding ops when we reach an anchor point
+          // (a prim list of ints).
+          if (isPrimListOfInts(op)) {
+            shapeCalculationOps.insert(op);
+            return;
+          }
+          // Add view ops for now until the decompositions for flatten and
+          // unflatten are removed.
+          if (isa<AtenViewOp>(op)) {
+            shapeCalculationOps.insert(op);
+            return;
+          }
+          // Insert the op if any of its consumers have already been
+          // identified as a shape calculation op. To avoid adding the
+          // producer of something like a size.int op, don't add ops when
+          // their consumer is a source op for shape scalarization. Here is
+          // some sample IR:
+          // ------
+          // %0 = aten.matmul %arg0, %arg1 : ... -> !torch.vtensor<[?,?,?],f32>
+          // %1 = aten.size.int %0, %int0 : !torch.int
+          // %2 = prim.ListConstruct %1 : (!torch.int) -> !torch.list<int>
+          // return %2 : !torch.list<int>
+          // ------
+          // In this example, don't add the matmul (%0), or its producers, to
+          // shapeCalculationOps. Its consumer (%1) is indeed a shape
+          // calculation op, but the size.int op is an elementary unit of
+          // shape computation. No further gathering of producers is necessary
+          // to reduce this. Similarly, don't add the `self` of a view op.
+          for (OpOperand &use : op->getUses()) {
+            Operation *userOp = use.getOwner();
+            if (shapeCalculationOps.contains(userOp) &&
+                !isSourceOpForShapeScalarization(userOp) &&
+                !isa<AtenViewOp>(userOp)) {
+              shapeCalculationOps.insert(op);
+              return;
+            }
+          }
+        });
+
+    GreedyRewriteConfig config;
+    // When propagating, we need to go back and clean up aten.Tensor ops that
+    // have been further propagated. It is also necessary to add newly created
+    // ops for custom folding after scalarizing a where.self op.
+    config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
+    if (failed(applyOpPatternsAndFold(shapeCalculationOps.getArrayRef(),
+                                      std::move(patterns), config))) {
       return signalPassFailure();
     }
+
+    // TODO: Warn when failing to process operations in the worklist.
   }
 };
 } // namespace
diff --git a/test/Dialect/Torch/scalarize-shapes.mlir b/test/Dialect/Torch/scalarize-shapes.mlir
index c86844996..7f6aa8a26 100644
--- a/test/Dialect/Torch/scalarize-shapes.mlir
+++ b/test/Dialect/Torch/scalarize-shapes.mlir
@@ -12,7 +12,13 @@ func.func @shape_as_tensor(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtenso
   // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[I5]], %[[SZ1]], %[[SZ2]]
   // CHECK-DAG: %[[TENSOR:.+]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]]
   // CHECK: return %[[TENSOR]] : !torch.vtensor<[3],si32>
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %literal1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
   %0 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
+  %1 = torch.aten.index_select %0, %int0, %literal1 : !torch.vtensor<[3],si32>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si32>
+  %2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int
+  %3 = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list<int>
   return %0 : !torch.vtensor<[3],si32>
 }
 
@@ -20,17 +26,20 @@ func.func @shape_as_tensor(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtenso
 
 // CHECK-LABEL: @shape_as_tensor_dim
 func.func @shape_as_tensor_dim(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],si32> {
-  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
-  // CHECK: %[[NONE:.+]] = torch.constant.none
   // CHECK: %[[INT1:.+]] = torch.constant.int 1
   // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[INT1]]
-  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1]]
+  // CHECK: %[[INT1_0:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1_0]]
   // CHECK: %[[TENSOR:.+]] = torch.aten.full %[[LIST]], %[[SZ]], %[[NONE]], %[[NONE]], %[[NONE]], %[[FALSE]]
   // CHECK: return %[[TENSOR]] : !torch.vtensor<[],si32>
   %shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
   %dim = torch.constant.int 0
   %idx = torch.vtensor.literal(dense<1> : tensor<si32>) : !torch.vtensor<[],si32>
   %select = torch.aten.index_select %shape, %dim, %idx : !torch.vtensor<[3],si32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si32>
+  %item = torch.aten.item %select : !torch.vtensor<[],si32> -> !torch.int
+  %list = torch.prim.ListConstruct %item : (!torch.int) -> !torch.list<int>
   return %select : !torch.vtensor<[],si32>
 }
 
@@ -47,6 +56,22 @@ func.func @shape_as_tensor_dim_item(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !tor
   %idx = torch.vtensor.literal(dense<1> : tensor<si32>) : !torch.vtensor<[],si32>
   %select = torch.aten.index_select %shape, %dim, %idx : !torch.vtensor<[3],si32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si32>
   %out = torch.aten.item %select : !torch.vtensor<[],si32> -> !torch.int
+  %list = torch.prim.ListConstruct %out : (!torch.int) -> !torch.list<int>
   return %out : !torch.int
 }
 
+// -----
+
+// CHECK-LABEL: @literal_item
+func.func @literal_item() -> !torch.int {
+  // CHECK: %int2 = torch.constant.int 2
+  // CHECK: return %int2 : !torch.int
+  %shape = torch.vtensor.literal(dense<[1,2,3]> : tensor<3xsi32>) : !torch.vtensor<[3],si32>
+  %dim = torch.constant.int 0
+  %idx = torch.vtensor.literal(dense<1> : tensor<si32>) : !torch.vtensor<[],si32>
+  %select = torch.aten.index_select %shape, %dim, %idx : !torch.vtensor<[3],si32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si32>
+  %out = torch.aten.item %select : !torch.vtensor<[],si32> -> !torch.int
+  %list = torch.prim.ListConstruct %out : (!torch.int) -> !torch.list<int>
+  return %out : !torch.int
+}
 
 // -----
 
@@ -64,12 +89,16 @@ func.func @shape_as_tensor_slice(%arg0 : !torch.vtensor<[5,?,?,?],f32>) -> !torc
   // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[SZ1]], %[[SZ3]]
   // CHECK-DAG: %[[TENSOR:.+]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]]
   // CHECK: return %[[TENSOR]]
+  %idx = torch.vtensor.literal(dense<1> : tensor<si32>) : !torch.vtensor<[],si32>
   %shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?,?],f32> -> !torch.vtensor<[4],si32>
   %dim = torch.constant.int 0
   %start = torch.constant.int 1
   %end = torch.constant.int 5
   %step = torch.constant.int 2
   %slice = torch.aten.slice.Tensor %shape, %dim, %start, %end, %step : !torch.vtensor<[4], si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2], si32>
+  %select = torch.aten.index_select %slice, %dim, %idx : !torch.vtensor<[2],si32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si32>
+  %item = torch.aten.item %select : !torch.vtensor<[],si32> -> !torch.int
+  %list = torch.prim.ListConstruct %item : (!torch.int) -> !torch.list<int>
   return %slice : !torch.vtensor<[2],si32>
 }
 
@@ -158,6 +187,7 @@ func.func @unsqueeze_squeeze_combo(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !t
   %12 = torch.aten.cat %11, %int0 : !torch.list<vtensor<[1],si64>>, !torch.int -> !torch.vtensor<[3],si64>
   %13 = torch.aten.slice.Tensor %12, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
   %14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int
+  %list = torch.prim.ListConstruct %14 : (!torch.int) -> !torch.list<int>
   return %14 : !torch.int
 }
 
@@ -166,18 +196,20 @@ func.func @unsqueeze_squeeze_combo(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !t
 
 // CHECK-LABEL: @eq_tensor_and_where_self
 func.func @eq_tensor_and_where_self(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[4],si64> {
-  // CHECK-DAG: %[[false:.*]] = torch.constant.bool false
-  // CHECK-DAG: %[[none:.*]] = torch.constant.none
-  // CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
-  // CHECK-DAG: %[[I0:.*]] = torch.constant.int 0
-  // CHECK-DAG: %[[DIM1:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
-  // CHECK-DAG: %[[DIM0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
-  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[I1]], %[[DIM1]], %[[DIM1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[I1:.*]] = torch.constant.int 1
+  // CHECK: %[[I0:.*]] = torch.constant.int 0
+  // CHECK: %[[DIM1:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  // CHECK: %[[DIM0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  // CHECK: %[[I1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[I1_0]], %[[DIM1]], %[[DIM1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[none:.*]] = torch.constant.none
+  // CHECK: %[[false:.*]] = torch.constant.bool false
   // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[none]], %[[none]], %[[false]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>
   // CHECK: return %[[TENSOR]] : !torch.vtensor<[4],si64>
   %none = torch.constant.none
   %0 = torch.vtensor.literal(dense<-1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
   %1 = torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %idx = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
   %false = torch.constant.bool false
   %int1 = torch.constant.int 1
   %int0 = torch.constant.int 0
@@ -187,6 +219,9 @@ func.func @eq_tensor_and_where_self(%arg0: !torch.vtensor<[?,?],si64>) -> !torch
   %5 = torch.aten.tensor %4, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>
   %6 = torch.aten.eq.Tensor %5, %0 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1>
   %7 = torch.aten.where.self %6, %1, %5 : !torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64>
+  %select = torch.aten.index_select %7, %int0, %idx : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+  %item = torch.aten.item %select : !torch.vtensor<[],si64> -> !torch.int
+  %list = torch.prim.ListConstruct %item : (!torch.int) -> !torch.list<int>
   return %7 : !torch.vtensor<[4],si64>
 }
 
@@ -195,15 +230,20 @@ func.func @eq_tensor_and_where_self(%arg0: !torch.vtensor<[?,?],si64>) -> !torch
 
 // CHECK-LABEL: @eq_tensor_from_tensor_and_literal
 func.func @eq_tensor_from_tensor_and_literal(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[4],i1> {
-  // CHECK-DAG: %[[none:.*]] = torch.constant.none
-  // CHECK-DAG: %[[false:.*]] = torch.constant.bool false
-  // CHECK-DAG: %[[true:.*]] = torch.constant.bool true
-  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[false]], %[[true]], %[[false]], %[[false]] : (!torch.bool, !torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
-  // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[none]], %[[none]], %[[false]] : !torch.list<bool>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],i1>
+  // CHECK: %[[int1:.*]] = torch.constant.int 1
+  // CHECK: %[[int0:.*]] = torch.constant.int 0
+  // CHECK: %[[int1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[int0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[int0_2:.*]] = torch.constant.int 0
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[int0]], %[[int1_0]], %[[int0_1]], %[[int0_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[none:.*]] = torch.constant.none
+  // CHECK: %[[false:.*]] = torch.constant.bool false
+  // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[none]], %[[none]], %[[false]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],i1>
   // CHECK: return %[[TENSOR]] : !torch.vtensor<[4],i1>
   %none = torch.constant.none
   %0 = torch.vtensor.literal(dense<-1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
   %1 = torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %idx = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
   %false = torch.constant.bool false
   %int1 = torch.constant.int 1
   %int-1 = torch.constant.int -1
@@ -213,6 +253,9 @@ func.func @eq_tensor_from_tensor_and_literal(%arg0: !torch.vtensor<[?,?],si64>)
   %4 = torch.prim.ListConstruct %3, %int-1, %2, %2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
   %5 = torch.aten.tensor %4, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>
   %6 = torch.aten.eq.Tensor %5, %0 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1>
+  %select = torch.aten.index_select %6, %int0, %idx : !torch.vtensor<[4],i1>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[],i1>
+  %item = torch.aten.item %select : !torch.vtensor<[],i1> -> !torch.int
+  %list = torch.prim.ListConstruct %item : (!torch.int) -> !torch.list<int>
   return %6 : !torch.vtensor<[4],i1>
 }
 
@@ -221,10 +264,11 @@ func.func @eq_tensor_from_tensor_and_literal(%arg0: !torch.vtensor<[?,?],si64>)
 // -----
 
 // CHECK-LABEL: @squeeze_dim_full_fold
-func.func @squeeze_dim_full_fold(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.int {
+func.func @squeeze_dim_full_fold(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.list<int> {
   // CHECK: %[[I0:.*]] = torch.constant.int 0
   // CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
-  // CHECK: return %[[SZE]] : !torch.int
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[SZE]] : (!torch.int) -> !torch.list<int>
+  // CHECK: return %[[LIST]] : !torch.list<int>
   %int0 = torch.constant.int 0
   %int1 = torch.constant.int 1
   %none = torch.constant.none
@@ -234,5 +278,6 @@ func.func @squeeze_dim_full_fold(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.in
   %56 = torch.aten.full %55, %51, %none, %none, %none, %false : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>
   %57 = torch.aten.squeeze.dim %56, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
   %58 = torch.aten.item %57 : !torch.vtensor<[],si64> -> !torch.int
-  return %58 : !torch.int
+  %59 = torch.prim.ListConstruct %58 : (!torch.int) -> !torch.list<int>
+  return %59 : !torch.list<int>
 }
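+
+// -----
+
+// An additional illustrative sketch (hypothetical test, not from the original
+// patch) exercising the rank-0 item propagation on prim.NumToTensor.Scalar;
+// the trailing ListConstruct anchors the ops for the pass's worklist, as in
+// the tests above:
+
+// CHECK-LABEL: @num_to_tensor_item
+func.func @num_to_tensor_item(%arg0: !torch.int) -> !torch.int {
+  // CHECK: return %arg0 : !torch.int
+  %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.vtensor<[],si64>
+  %1 = torch.aten.item %0 : !torch.vtensor<[],si64> -> !torch.int
+  %list = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list<int>
+  return %1 : !torch.int
+}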