mirror of https://github.com/llvm/torch-mlir

add squeeze dims op

parent e510a9b5ec
commit 102058dc70
@@ -5706,6 +5706,30 @@ def Torch_AtenSqueezeOp : Torch_Op<"aten.squeeze", [
  let hasFolder = 1;
}

def Torch_AtenSqueezeDimsOp : Torch_Op<"aten.squeeze.dims", [
    AllowsTypeRefinement,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::squeeze.dims : (Tensor, int[]) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchListOfTorchIntType:$dim
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenSqueezeDimsOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 2, 1);
    }
    void AtenSqueezeDimsOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
  let hasFolder = 1;
}

def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [
    AllowsTypeRefinement,
    ReadOnly
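For orientation, a minimal eager-mode sketch of the semantics the new op models (illustrative only, not part of the patch; assumes a PyTorch recent enough that torch.squeeze accepts a tuple of dims, which maps to aten::squeeze.dims):

import torch

# aten::squeeze.dims removes only the listed dimensions, and only those whose
# size is 1; every other dimension is left untouched.
a = torch.rand(3, 2, 1, 1)
b = torch.squeeze(a, dim=(2, 3))
assert b.shape == (3, 2)

# A listed dimension whose size is not 1 is simply kept.
c = torch.squeeze(torch.rand(3, 1, 2), dim=(0, 1))
assert c.shape == (3, 2)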
@@ -679,6 +679,18 @@ OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenSqueezeDimsOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenSqueezeDimsOp::fold(FoldAdaptor adaptor) {
  if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
    if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
      return getOperand(0);
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenRoundOp
//===----------------------------------------------------------------------===//
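A quick note on the folder above: it fires only when the operand is statically known to be rank 0, where there is nothing left to squeeze. A hedged eager-mode analogue of that condition (illustrative only):

import torch

# Squeezing a 0-d tensor is a no-op, which is what lets the rank-0 case fold
# away to its own operand at the IR level.
x = torch.tensor(3.14)            # rank-0 tensor
assert torch.squeeze(x).dim() == 0
assert torch.squeeze(x).shape == x.shape == ()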
@@ -1443,6 +1455,41 @@ void AtenSortIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
  });
}

//===----------------------------------------------------------------------===//
// AtenRemoveIntOp
//===----------------------------------------------------------------------===//

// void AtenRemoveIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
//                                                   MLIRContext *context) {
//   patterns.add(+[](AtenRemoveIntOp op, PatternRewriter &rewriter) {
//     SmallVector<int64_t> listElements;
//     if (!matchPattern(op.getSelf(), m_TorchListOfConstantInts(listElements)))
//       return rewriter.notifyMatchFailure(
//           op, "all input list elements must be constant ints");
//     int64_t elementToBeRemoved;
//     if (!matchPattern(op.getEl(), m_TorchConstantInt(&elementToBeRemoved)))
//       return rewriter.notifyMatchFailure(
//           op, "Expected element to be removed a constant int.");

//     listElements.erase(listElements.begin(),
//                        std::find(listElements.begin(), listElements.end(),
//                                  elementToBeRemoved));

//     SmallVector<Value> updatedListElements;
//     for (int64_t elem : listElements)
//       updatedListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
//           op->getLoc(), rewriter.getI64IntegerAttr(elem)));
//     Value result = rewriter.create<Torch::PrimListConstructOp>(
//         op->getLoc(), Torch::ListType::get(rewriter.getType<Torch::IntType>()),
//         updatedListElements);

//     llvm::outs()<<"*****************************\n";
//     op.getSelf().replaceAllUsesWith(result);
//     rewriter.eraseOp(op);
//     return success();
//   });
// }

//===----------------------------------------------------------------------===//
// NonValueTensorLiteralOp
//===----------------------------------------------------------------------===//
@@ -2407,6 +2454,29 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
      std::max(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue()));
}

//===----------------------------------------------------------------------===//
// PrimMaxSelfIntOp
//===----------------------------------------------------------------------===//

// OpFoldResult PrimMaxSelfIntOp::fold(FoldAdaptor adaptor) {
//   auto list = getOperand().getDefiningOp<PrimListConstructOp>();
//   if (!list)
//     return nullptr;
//   // TODO: What does it return for an empty list?
//   if (list->getNumOperands() == 0)
//     return nullptr;

//   SmallVector<int64_t> values;
//   for (auto operand : list->getOperands()) {
//     int64_t value;
//     if (!matchPattern(operand, m_TorchConstantInt(&value)))
//       return nullptr;
//     values.push_back(value);
//   }
//   return getI64IntegerAttr(getContext(),
//                            *std::max_element(values.begin(), values.end()));
// }

//===----------------------------------------------------------------------===//
// PrimMinSelfIntOp
//===----------------------------------------------------------------------===//
@@ -7075,6 +7075,24 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    %0 = call @__torch__.torch.jit._shape_functions.squeeze(%arg0, %arg1) : (!torch.list<int>, !torch.int) -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.squeeze.dims\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
"    %true = torch.constant.bool true\n"
"    %int1 = torch.constant.int 1\n"
"    %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
"    %1 = torch.prim.Loop %0, %true, init(%arg0) {\n"
"    ^bb0(%arg2: !torch.int, %arg3: !torch.list<int>):\n"
"      %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
"      %3 = torch.aten.sub.int %2, %arg2 : !torch.int, !torch.int -> !torch.int\n"
"      %4 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int\n"
"      %5 = torch.aten.__getitem__.t %arg1, %4 : !torch.list<int>, !torch.int -> !torch.int\n"
"      %6 = func.call @__torch__.torch.jit._shape_functions.squeeze(%arg3, %5) : (!torch.list<int>, !torch.int) -> !torch.list<int>\n"
"      torch.prim.Loop.condition %true, iter(%6 : !torch.list<int>)\n"
"    } : (!torch.int, !torch.bool, !torch.list<int>) -> !torch.list<int>\n"
"    return %1 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.squeeze.dims\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.int {\n"
"    return %arg1 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.float) -> !torch.list<int> {\n"
"    %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"
@@ -4159,6 +4159,32 @@ public:
};
} // namespace

namespace {
class DecomposeAtenSqueezeDimsOp : public OpRewritePattern<AtenSqueezeDimsOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSqueezeDimsOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.getSelf();
    SmallVector<int64_t> dimListInts;
    if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dimListInts)))
      return rewriter.notifyMatchFailure(
          op, "all dim list elements must be constant ints");
    for (unsigned i = dimListInts.size(); i > 0; i--) {
      Value cstDim = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(dimListInts[i - 1]));
      BaseTensorType selfType = self.getType().cast<BaseTensorType>();
      Type squeezedType = computeReductionType(rewriter, op, selfType, cstDim,
                                               /*keepDim=*/false);
      self = rewriter.create<AtenSqueezeDimOp>(loc, squeezedType, self, cstDim);
    }
    rewriter.replaceOp(op, self);
    return success();
  }
};
} // namespace

namespace {
class DecomposeComplexOpsPass
    : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
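The pattern above rewrites one aten.squeeze.dims into a chain of aten.squeeze.dim ops, taking the constant dims from the back of the list first so that, for an ascending dim list, the indices still to be processed are not shifted by earlier rank reductions. A small eager-mode sketch of that ordering (illustrative only; assumes the dim-tuple overload of torch.squeeze):

import torch

# squeeze.dims([2, 3]) becomes squeeze.dim(3) followed by squeeze.dim(2);
# removing the larger index first keeps index 2 valid afterwards.
a = torch.rand(3, 2, 1, 1)
chained = a.squeeze(3).squeeze(2)        # what the decomposition emits, in order
direct = torch.squeeze(a, dim=(2, 3))    # the op being decomposed
assert chained.shape == direct.shape == (3, 2)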
@@ -4321,6 +4347,7 @@ public:
    addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenSqueezeDimsOp>(patterns);

    GreedyRewriteConfig config;
    config.useTopDownTraversal = true;
@@ -446,6 +446,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
  target.addIllegalOp<AtenNewEmptyStridedOp>();
  target.addIllegalOp<AtenBucketizeTensorOp>();
  target.addIllegalOp<AtenMovedimIntOp>();
  target.addIllegalOp<AtenSqueezeDimsOp>();
  for (std::string opName : backendLegalOps) {
    target.addLegalOp(OperationName(opName, context));
  }
@@ -186,6 +186,22 @@ public:
        listLiterals.push_back(runningList);
        continue;
      }
      // if (auto remove = dyn_cast<AtenRemoveIntOp>(user)) {
      //   if (!remove.use_empty())
      //     return rewriter.notifyMatchFailure(
      //         op, "Expected `AtenInsertTOp` to not have users");
      //   int64_t element;
      //   if (!matchPattern(remove.getEl(), m_TorchConstantInt(&element)))
      //     return rewriter.notifyMatchFailure(
      //         op, "Expected `element` to be removed a constant int");
      //   if (remove.getSelf() == op) {
      //     llvm::erase_value(runningList, remove.getEl());
      //     generatedNewLiteral = true;
      //     llvm::outs()<<"*****************************\n";
      //   }
      //   listLiterals.push_back(runningList);
      //   continue;
      // }
      // If this user potentially mutates the list and isn't handled above, then
      // we can't abstractly interpret any further.
      if (potentiallyMutatesListOperands(user))
@@ -224,6 +240,14 @@ public:
        rewriter.eraseOp(setItem);
        continue;
      }
      // if (auto removeInt = dyn_cast<AtenRemoveIntOp>(user)) {
      //   rewriter.setInsertionPoint(removeInt);
      //   latestLiteral = rewriter.create<PrimListConstructOp>(
      //       removeInt->getLoc(), op.getType(), listLiterals[nextLiteral++]);
      //   if (removeInt.getSelf() == op)
      //     rewriter.eraseOp(removeInt);
      //   continue;
      // }
      for (OpOperand &opOperand : user->getOpOperands()) {
        if (opOperand.get() == op.getResult()) {
          opOperand.set(latestLiteral);
@@ -200,8 +200,8 @@ bool Torch::isViewLikeOp(Operation *op) {
  return isa<AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp, AtenExpandAsOp,
             AtenExpandOp, AtenFlattenUsingIntsOp, AtenPermuteOp, AtenReshapeOp,
             Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
-            AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
-            AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
+            AtenSqueezeDimOp, AtenSqueezeOp, AtenSqueezeDimsOp, AtenTOp,
+            AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
             TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
             AtenNarrowOp, AtenToDeviceOp, AtenMovedimIntOp>(op);
}
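Squeeze variants belong in isViewLikeOp because the result aliases the input's storage rather than copying it; a quick eager-mode check of that aliasing (illustrative only, using the no-dim overload, which aliases the same way):

import torch

# The squeezed result is a view: it shares the underlying buffer with its
# input, and writes through the view are visible in the original tensor.
a = torch.rand(3, 2, 1, 1)
b = a.squeeze()
assert b.data_ptr() == a.data_ptr()
b[0, 0] = 42.0
assert a[0, 0, 0, 0].item() == 42.0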
@@ -756,6 +756,29 @@ def aten〇squeeze〡shape(self: List[int]) -> List[int]:
def aten〇squeeze〇dim〡shape(self: List[int], dim: int) -> List[int]:
    return upstream_shape_functions.squeeze(self, dim)

def aten〇squeeze〇dims〡shape(self: List[int], dim: List[int]) -> List[int]:
    # wrapped_dims : List[int] = []
    # wrapped_dims = upstream_shape_functions._copy(dim)
    # out: List[int] = self
    for i in range(len(dim)):
        self = upstream_shape_functions.squeeze(self, dim[len(dim) - i - 1])
    # sorted_dims = upstream_shape_functions._copy(dim)
    # sorted_dims.sort(reverse=True)

    # result_shape = upstream_shape_functions._copy(self)
    # for i in range(len(wrapped_dims)):
    #     curr_dim = max(wrapped_dims)
    #     wrapped_dims.remove(curr_dim)
    #     # for ele in wrapped_dims:

    #     # wrapped_dims = [ele for ele in wrapped_dims if ele != curr_dim]
    #     # wrapped_dims.remove(curr_dim)
    #     result_shape = upstream_shape_functions.squeeze(result_shape, curr_dim)
    return self

def aten〇squeeze〇dims〡dtype(self_rank: int, self_dtype: int, dim: List[int]) -> int:
    return self_dtype

def prim〇NumToTensor〇Scalar〡shape(a: float) -> List[int]:
    return []
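A small sanity check of the shape function above against eager PyTorch, with upstream_shape_functions.squeeze inlined as a local stand-in (hypothetical helper, not code from the patch; assumes the dim-tuple overload of torch.squeeze and non-negative dims):

from typing import List

import torch

def squeeze_dims_shape(self: List[int], dim: List[int]) -> List[int]:
    # Mirrors aten〇squeeze〇dims〡shape: squeeze one listed dim at a time,
    # walking the dim list from its last element to its first.
    for i in range(len(dim)):
        d = dim[len(dim) - i - 1]
        if self[d] == 1:
            self = self[:d] + self[d + 1:]
    return self

for shape, dims in [([3, 2, 1, 1], [2, 3]), ([1, 5, 1], [0, 2]), ([4, 1, 3], [0, 1])]:
    expected = list(torch.squeeze(torch.empty(shape), dim=tuple(dims)).shape)
    assert squeeze_dims_shape(list(shape), dims) == expected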
@@ -436,6 +436,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)")
    emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
    emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
    emit("aten::squeeze.dims : (Tensor, int[]) -> (Tensor)", has_folder=True)
    emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
    emit("aten::dim : (Tensor) -> (int)", has_folder=True)
    emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
@@ -565,6 +566,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit("aten::cat : (Tensor[], int) -> (Tensor)", has_folder=True)
    emit("aten::stack : (Tensor[], int) -> (Tensor)", has_folder=True)
    emit("aten::append.t : (t[], t) -> (t[])")
    # emit("aten::remove.int : (int[], int) -> ()", has_canonicalizer=True)
    emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
    emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)
    emit("aten::list.t : (t[]) -> (t[])")
@@ -184,3 +184,25 @@ class SqueezeDimUnitDimModule(torch.nn.Module):
    module_factory=lambda: SqueezeDimUnitDimModule())
def SqueezeDimModule_unitDim(module, tu: TestUtils):
    module.forward(tu.rand(1))


# ==============================================================================


class SqueezeDimsModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 2, 1, 1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.squeeze(a, dim=[2, 3])


@register_test_case(
    module_factory=lambda: SqueezeDimsModule())
def SqueezeDimsModule_unitDim(module, tu: TestUtils):
    module.forward(tu.rand(3, 2, 1, 1))