mirror of https://github.com/llvm/torch-mlir
[torch] Add an `aten.cat` length-0 canonicalization (#2966)
If an input is length-0 along the dimension of canonicalization we can remove the tensor from the listpull/2975/head
parent
d030bffc62
commit
61f0a5facf
|
@ -12594,6 +12594,7 @@ def Torch_AtenCatOp : Torch_Op<"aten.cat", [
|
|||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenStackOp : Torch_Op<"aten.stack", [
|
||||
|
|
|
@ -2900,13 +2900,46 @@ OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
|
|||
|
||||
OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
|
||||
auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
|
||||
if (!list || !list->hasOneUse() || list.getElements().size() != 1)
|
||||
if (!list || list.getElements().size() != 1)
|
||||
return nullptr;
|
||||
if (list.getElements()[0].getType() != getResult().getType())
|
||||
return nullptr;
|
||||
return list.getElements()[0];
|
||||
}
|
||||
|
||||
void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add(+[](AtenCatOp op, PatternRewriter &rewriter) {
|
||||
auto list = op.getTensors().getDefiningOp<PrimListConstructOp>();
|
||||
auto resultTy = dyn_cast<BaseTensorType>(op.getType());
|
||||
if (!list || !resultTy)
|
||||
return failure();
|
||||
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return failure();
|
||||
|
||||
llvm::SmallVector<Value> filtered;
|
||||
for (auto operand : list.getOperands()) {
|
||||
auto operandTy = dyn_cast<BaseTensorType>(operand.getType());
|
||||
if (!operandTy || !operandTy.hasSizes())
|
||||
return failure();
|
||||
int64_t adim = dim < 0 ? dim + operandTy.getSizes().size() : dim;
|
||||
if (operandTy.getSizes()[adim] != 0)
|
||||
filtered.push_back(operand);
|
||||
}
|
||||
|
||||
if (filtered.size() == list.getNumOperands())
|
||||
return failure();
|
||||
|
||||
auto newlist = rewriter.create<PrimListConstructOp>(
|
||||
op.getLoc(), list.getType(), filtered);
|
||||
rewriter.replaceOpWithNewOp<AtenCatOp>(op, op.getType(), newlist,
|
||||
op.getDim());
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenBroadcastToOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -2144,7 +2144,6 @@ ONNX_XFAIL_SET = {
|
|||
"ReduceMaxUnsignedIntModule_basic",
|
||||
|
||||
# Failure - torch.aten.view lower
|
||||
"ElementwiseFlattenBroadcastModule_basic",
|
||||
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
|
||||
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
|
||||
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
|
||||
|
|
|
@ -727,7 +727,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::Delete.Dict_str : (Dict(str, t), str) -> ()")
|
||||
|
||||
# List ops.
|
||||
emit("aten::cat : (Tensor[], int) -> (Tensor)", has_folder=True)
|
||||
emit("aten::cat : (Tensor[], int) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
||||
emit("aten::stack : (Tensor[], int) -> (Tensor)")
|
||||
emit("aten::append.t : (t[], t) -> (t[])")
|
||||
emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
|
||||
|
|
|
@ -2663,3 +2663,13 @@ func.func @aten_shape_to_tensor(%arg0 : !torch.vtensor<[4,5,6],f32>) -> !torch.v
|
|||
return %0 : !torch.vtensor<[3],si32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @aten_cat_zero
|
||||
func.func @aten_cat_zero(%arg0 : !torch.vtensor<[4,5,6],f32>, %arg1 : !torch.vtensor<[4,0,6],f32>) -> !torch.vtensor<[4,5,6],f32> {
|
||||
// CHECK: return %arg0 : !torch.vtensor<[4,5,6],f32>
|
||||
%list = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[4,5,6],f32>, !torch.vtensor<[4,0,6],f32>) -> !torch.list<vtensor>
|
||||
%dim = torch.constant.int -2
|
||||
%0 = torch.aten.cat %list, %dim : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,5,6],f32>
|
||||
return %0 : !torch.vtensor<[4,5,6],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue