[torch] Add an `aten.cat` length-0 canonicalization (#2966)

If an input tensor has size 0 along the concatenation dimension, it contributes
nothing to the result, so we can drop that tensor from the list.
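
For example, a minimal sketch of the rewrite (the shapes mirror the new test
below; `%int1` is an illustrative constant for the concatenation dimension):

%int1 = torch.constant.int 1
%list = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[4,5,6],f32>, !torch.vtensor<[4,0,6],f32>) -> !torch.list<vtensor>
%0 = torch.aten.cat %list, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,5,6],f32>

Since %arg1 contributes no elements along dim 1, the canonicalizer rebuilds the
list as just %arg0, and the resulting one-element cat folds away to %arg0.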
Rob Suderman 2024-03-01 21:41:12 -08:00 committed by GitHub
parent d030bffc62
commit 61f0a5facf
5 changed files with 46 additions and 3 deletions

@@ -12594,6 +12594,7 @@ def Torch_AtenCatOp : Torch_Op<"aten.cat", [
     }
   }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 def Torch_AtenStackOp : Torch_Op<"aten.stack", [
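
Setting `let hasCanonicalizer = 1;` in ODS makes TableGen declare a
`getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *)` hook on
`AtenCatOp`; the next hunk supplies its definition.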

@@ -2900,13 +2900,46 @@ OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
 OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
   auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
-  if (!list || !list->hasOneUse() || list.getElements().size() != 1)
+  if (!list || list.getElements().size() != 1)
     return nullptr;
   if (list.getElements()[0].getType() != getResult().getType())
     return nullptr;
   return list.getElements()[0];
 }
 
+void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                            MLIRContext *context) {
+  patterns.add(+[](AtenCatOp op, PatternRewriter &rewriter) {
+    auto list = op.getTensors().getDefiningOp<PrimListConstructOp>();
+    auto resultTy = dyn_cast<BaseTensorType>(op.getType());
+    if (!list || !resultTy)
+      return failure();
+
+    int64_t dim;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+      return failure();
+
+    llvm::SmallVector<Value> filtered;
+    for (auto operand : list.getOperands()) {
+      auto operandTy = dyn_cast<BaseTensorType>(operand.getType());
+      if (!operandTy || !operandTy.hasSizes())
+        return failure();
+      int64_t adim = dim < 0 ? dim + operandTy.getSizes().size() : dim;
+      if (operandTy.getSizes()[adim] != 0)
+        filtered.push_back(operand);
+    }
+
+    if (filtered.size() == list.getNumOperands())
+      return failure();
+
+    auto newlist = rewriter.create<PrimListConstructOp>(
+        op.getLoc(), list.getType(), filtered);
+    rewriter.replaceOpWithNewOp<AtenCatOp>(op, op.getType(), newlist,
+                                           op.getDim());
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenBroadcastToOp
 //===----------------------------------------------------------------------===//
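
Two behavioral notes on this hunk. The fold is relaxed: it no longer requires
the `PrimListConstructOp` to have a single use, so a one-element `aten.cat`
folds to its operand even when the list feeds other ops. A minimal sketch of
hypothetical IR (the value names and `%int0` are illustrative), assuming the
operand type matches the result type:

%int0 = torch.constant.int 0
%list = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[4,5,6],f32>) -> !torch.list<vtensor>
// %0 folds to %arg0; folding only rewires the result, so %list is
// untouched and may still have other users.
%0 = torch.aten.cat %list, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,5,6],f32>

The new pattern, meanwhile, bails out unless the dim is a constant and every
operand has known sizes; it rebuilds the list without operands that are
statically size-0 along that dim, and returns failure() when nothing was
removed so the rewrite cannot loop.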

@@ -2144,7 +2144,6 @@ ONNX_XFAIL_SET = {
     "ReduceMaxUnsignedIntModule_basic",
     # Failure - torch.aten.view lower
     "ElementwiseFlattenBroadcastModule_basic",
-    "IndexTensorDyanmicInputContiguousWithNoneModule_basic",
     "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
     "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",

@@ -727,7 +727,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::Delete.Dict_str : (Dict(str, t), str) -> ()")
     # List ops.
-    emit("aten::cat : (Tensor[], int) -> (Tensor)", has_folder=True)
+    emit("aten::cat : (Tensor[], int) -> (Tensor)", has_canonicalizer=True, has_folder=True)
     emit("aten::stack : (Tensor[], int) -> (Tensor)")
     emit("aten::append.t : (t[], t) -> (t[])")
     emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)

@@ -2663,3 +2663,13 @@ func.func @aten_shape_to_tensor(%arg0 : !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[3],si32> {
   return %0 : !torch.vtensor<[3],si32>
 }
 
+// -----
+
+// CHECK-LABEL: @aten_cat_zero
+func.func @aten_cat_zero(%arg0 : !torch.vtensor<[4,5,6],f32>, %arg1 : !torch.vtensor<[4,0,6],f32>) -> !torch.vtensor<[4,5,6],f32> {
+  // CHECK: return %arg0 : !torch.vtensor<[4,5,6],f32>
+  %list = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[4,5,6],f32>, !torch.vtensor<[4,0,6],f32>) -> !torch.list<vtensor>
+  %dim = torch.constant.int -2
+  %0 = torch.aten.cat %list, %dim : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,5,6],f32>
+  return %0 : !torch.vtensor<[4,5,6],f32>
+}
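
One property worth calling out with a hypothetical example (not part of this
commit's tests): only operands that are statically size-0 along the cat
dimension are removed. An operand with a dynamic size there is conservatively
kept, so `filtered` matches the list length and the pattern returns failure(),
leaving the cat intact:

// Hypothetical: not rewritten, because %arg1's size along dim 1 is unknown.
func.func @aten_cat_dynamic(%arg0 : !torch.vtensor<[4,5,6],f32>, %arg1 : !torch.vtensor<[4,?,6],f32>) -> !torch.vtensor<[4,?,6],f32> {
  %list = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[4,5,6],f32>, !torch.vtensor<[4,?,6],f32>) -> !torch.list<vtensor>
  %dim = torch.constant.int 1
  %0 = torch.aten.cat %list, %dim : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,?,6],f32>
  return %0 : !torch.vtensor<[4,?,6],f32>
}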