add squeeze dims op

shark
Vivek Khandelwal 2023-02-03 11:39:13 +05:30
parent e510a9b5ec
commit 102058dc70
10 changed files with 213 additions and 2 deletions

View File

@ -5706,6 +5706,30 @@ def Torch_AtenSqueezeOp : Torch_Op<"aten.squeeze", [
let hasFolder = 1;
}
def Torch_AtenSqueezeDimsOp : Torch_Op<"aten.squeeze.dims", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::squeeze.dims : (Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$dim
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSqueezeDimsOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenSqueezeDimsOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [
AllowsTypeRefinement,
ReadOnly

View File

@ -679,6 +679,18 @@ OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenSqueezeDimsOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenSqueezeDimsOp::fold(FoldAdaptor adaptor) {
if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
return getOperand(0);
}
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenRoundOp
//===----------------------------------------------------------------------===//
@ -1443,6 +1455,41 @@ void AtenSortIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}
//===----------------------------------------------------------------------===//
// AtenRemoveIntOp
//===----------------------------------------------------------------------===//
// void AtenRemoveIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// MLIRContext *context) {
// patterns.add(+[](AtenRemoveIntOp op, PatternRewriter &rewriter) {
// SmallVector<int64_t> listElements;
// if (!matchPattern(op.getSelf(), m_TorchListOfConstantInts(listElements)))
// return rewriter.notifyMatchFailure(
// op, "all input list elements must be constant ints");
// int64_t elementToBeRemoved;
// if (!matchPattern(op.getEl(), m_TorchConstantInt(&elementToBeRemoved)))
// return rewriter.notifyMatchFailure(
// op, "Expected element to be removed a constant int.");
// listElements.erase(listElements.begin(),
// std::find(listElements.begin(), listElements.end(),
// elementToBeRemoved));
// SmallVector<Value> updatedListElements;
// for (int64_t elem : listElements)
// updatedListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
// op->getLoc(), rewriter.getI64IntegerAttr(elem)));
// Value result = rewriter.create<Torch::PrimListConstructOp>(
// op->getLoc(), Torch::ListType::get(rewriter.getType<Torch::IntType>()),
// updatedListElements);
// llvm::outs()<<"*****************************\n";
// op.getSelf().replaceAllUsesWith(result);
// rewriter.eraseOp(op);
// return success();
// });
// }
//===----------------------------------------------------------------------===//
// NonValueTensorLiteralOp
//===----------------------------------------------------------------------===//
@ -2407,6 +2454,29 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
std::max(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue()));
}
//===----------------------------------------------------------------------===//
// PrimMaxSelfIntOp
//===----------------------------------------------------------------------===//
// OpFoldResult PrimMaxSelfIntOp::fold(FoldAdaptor adaptor) {
// auto list = getOperand().getDefiningOp<PrimListConstructOp>();
// if (!list)
// return nullptr;
// // TODO: What does it return for an empty list?
// if (list->getNumOperands() == 0)
// return nullptr;
// SmallVector<int64_t> values;
// for (auto operand : list->getOperands()) {
// int64_t value;
// if (!matchPattern(operand, m_TorchConstantInt(&value)))
// return nullptr;
// values.push_back(value);
// }
// return getI64IntegerAttr(getContext(),
// *std::max_element(values.begin(), values.end()));
// }
//===----------------------------------------------------------------------===//
// PrimMinSelfIntOp
//===----------------------------------------------------------------------===//

View File

@ -7075,6 +7075,24 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.squeeze(%arg0, %arg1) : (!torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.squeeze.dims\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %1 = torch.prim.Loop %0, %true, init(%arg0) {\n"
" ^bb0(%arg2: !torch.int, %arg3: !torch.list<int>):\n"
" %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %3 = torch.aten.sub.int %2, %arg2 : !torch.int, !torch.int -> !torch.int\n"
" %4 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %5 = torch.aten.__getitem__.t %arg1, %4 : !torch.list<int>, !torch.int -> !torch.int\n"
" %6 = func.call @__torch__.torch.jit._shape_functions.squeeze(%arg3, %5) : (!torch.list<int>, !torch.int) -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter(%6 : !torch.list<int>)\n"
" } : (!torch.int, !torch.bool, !torch.list<int>) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.squeeze.dims\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.int {\n"
" return %arg1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.float) -> !torch.list<int> {\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"

View File

@ -4159,6 +4159,32 @@ public:
};
} // namespace
namespace {
class DecomposeAtenSqueezeDimsOp : public OpRewritePattern<AtenSqueezeDimsOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSqueezeDimsOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
SmallVector<int64_t> dimListInts;
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dimListInts)))
return rewriter.notifyMatchFailure(
op, "all dim list elements must be constant ints");
for (unsigned i = dimListInts.size(); i > 0; i--) {
Value cstDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dimListInts[i - 1]));
BaseTensorType selfType = self.getType().cast<BaseTensorType>();
Type squeezedType = computeReductionType(rewriter, op, selfType, cstDim,
/*keepDim=*/false);
self = rewriter.create<AtenSqueezeDimOp>(loc, squeezedType, self, cstDim);
}
rewriter.replaceOp(op, self);
return success();
}
};
} // namespace
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@ -4321,6 +4347,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSqueezeDimsOp>(patterns);
GreedyRewriteConfig config;
config.useTopDownTraversal = true;

View File

@ -446,6 +446,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenNewEmptyStridedOp>();
target.addIllegalOp<AtenBucketizeTensorOp>();
target.addIllegalOp<AtenMovedimIntOp>();
target.addIllegalOp<AtenSqueezeDimsOp>();
for (std::string opName : backendLegalOps) {
target.addLegalOp(OperationName(opName, context));
}

View File

@ -186,6 +186,22 @@ public:
listLiterals.push_back(runningList);
continue;
}
// if (auto remove = dyn_cast<AtenRemoveIntOp>(user)) {
// if (!remove.use_empty())
// return rewriter.notifyMatchFailure(
// op, "Expected `AtenInsertTOp` to not have users");
// int64_t element;
// if (!matchPattern(remove.getEl(), m_TorchConstantInt(&element)))
// return rewriter.notifyMatchFailure(
// op, "Expected `element` to be removed a constant int");
// if (remove.getSelf() == op) {
// llvm::erase_value(runningList, remove.getEl());
// generatedNewLiteral = true;
// llvm::outs()<<"*****************************\n";
// }
// listLiterals.push_back(runningList);
// continue;
// }
// If this user potentially mutates the list and isn't handled above, then
// we can't abstractly interpret any further.
if (potentiallyMutatesListOperands(user))
@ -224,6 +240,14 @@ public:
rewriter.eraseOp(setItem);
continue;
}
// if (auto removeInt = dyn_cast<AtenRemoveIntOp>(user)) {
// rewriter.setInsertionPoint(removeInt);
// latestLiteral = rewriter.create<PrimListConstructOp>(
// removeInt->getLoc(), op.getType(), listLiterals[nextLiteral++]);
// if (removeInt.getSelf() == op)
// rewriter.eraseOp(removeInt);
// continue;
// }
for (OpOperand &opOperand : user->getOpOperands()) {
if (opOperand.get() == op.getResult()) {
opOperand.set(latestLiteral);

View File

@ -200,8 +200,8 @@ bool Torch::isViewLikeOp(Operation *op) {
return isa<AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp, AtenExpandAsOp,
AtenExpandOp, AtenFlattenUsingIntsOp, AtenPermuteOp, AtenReshapeOp,
Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
AtenSqueezeDimOp, AtenSqueezeOp, AtenSqueezeDimsOp, AtenTOp,
AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
AtenNarrowOp, AtenToDeviceOp, AtenMovedimIntOp>(op);
}

View File

@ -756,6 +756,29 @@ def atensqueeze〡shape(self: List[int]) -> List[int]:
def atensqueezedim〡shape(self: List[int], dim: int) -> List[int]:
return upstream_shape_functions.squeeze(self, dim)
def atensqueezedims〡shape(self: List[int], dim: List[int]) -> List[int]:
# wrapped_dims : List[int] = []
# wrapped_dims = upstream_shape_functions._copy(dim)
# out: List[int] = self
for i in range(len(dim)):
self = upstream_shape_functions.squeeze(self, dim[len(dim) - i - 1])
# sorted_dims = upstream_shape_functions._copy(dim)
# sorted_dims.sort(reverse=True)
# result_shape = upstream_shape_functions._copy(self)
# for i in range(len(wrapped_dims)):
# curr_dim = max(wrapped_dims)
# wrapped_dims.remove(curr_dim)
# # for ele in wrapped_dims:
# # wrapped_dims = [ele for ele in wrapped_dims if ele != curr_dim]
# # wrapped_dims.remove(curr_dim)
# result_shape = upstream_shape_functions.squeeze(result_shape, curr_dim)
return self
def atensqueezedims〡dtype(self_rank: int, self_dtype: int, dim: List[int]) -> int:
return self_dtype
def primNumToTensorScalar〡shape(a: float) -> List[int]:
return []

View File

@ -436,6 +436,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)")
emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
emit("aten::squeeze.dims : (Tensor, int[]) -> (Tensor)", has_folder=True)
emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
emit("aten::dim : (Tensor) -> (int)", has_folder=True)
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
@ -565,6 +566,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::cat : (Tensor[], int) -> (Tensor)", has_folder=True)
emit("aten::stack : (Tensor[], int) -> (Tensor)", has_folder=True)
emit("aten::append.t : (t[], t) -> (t[])")
# emit("aten::remove.int : (int[], int) -> ()", has_canonicalizer=True)
emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)
emit("aten::list.t : (t[]) -> (t[])")

View File

@ -184,3 +184,25 @@ class SqueezeDimUnitDimModule(torch.nn.Module):
module_factory=lambda: SqueezeDimUnitDimModule())
def SqueezeDimModule_unitDim(module, tu: TestUtils):
module.forward(tu.rand(1))
# ==============================================================================
class SqueezeDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 2, 1, 1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.squeeze(a, dim=[2, 3])
@register_test_case(
module_factory=lambda: SqueezeDimsModule())
def SqueezeDimsModule_unitDim(module, tu: TestUtils):
module.forward(tu.rand(3, 2, 1, 1))