[Torch Dialect] add fold pattern for aten.clone (#2804)

Yuanqiang Liu 2024-01-31 09:43:21 +08:00 committed by GitHub
parent 25a5a22cbd
commit d778950f45
7 changed files with 36 additions and 17 deletions

@@ -9101,6 +9101,7 @@ def Torch_AtenCloneOp : Torch_Op<"aten.clone", [
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenLiftFreshCopyOp : Torch_Op<"aten.lift_fresh_copy", [

@@ -1763,7 +1763,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
 #define INSERT_UNARY_PATTERN(AtenOp, StablehloOp)                              \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenUnaryOp<AtenOp, StablehloOp>>(typeConverter, context)
-  INSERT_UNARY_PATTERN(AtenCloneOp, stablehlo::ConvertOp);
   INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp);
   INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp);
   INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp);

@@ -1662,6 +1662,19 @@ void AtenMaskedFillTensorOp::getCanonicalizationPatterns(
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenCloneOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenCloneOp::fold(FoldAdaptor adaptor) {
+  // Note: memory_format is ignored.
+  if (llvm::dyn_cast<ValueTensorType>(getSelf().getType())) {
+    // self has value semantics, so the clone is a no-op.
+    return getSelf();
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // AtenSortIntOp
 //===----------------------------------------------------------------------===//
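
The pieces above work together: `has_folder=True` in torch_ods_gen.py (see the hunk further below) emits `let hasFolder = 1;` into the generated ODS definition, which makes ODS declare the `fold` hook that this C++ implementation fills in. The fold returns `self` unchanged whenever it is a value tensor (`!torch.vtensor`), since cloning a value-semantic tensor is a no-op apart from the ignored `memory_format`. A minimal sketch of the effect under `torch-mlir-opt -canonicalize` (the function name is illustrative, not from this commit):

// Before canonicalization: a clone of a value-semantic tensor.
func.func @clone_fold_example(%arg0: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],f32> {
  %none = torch.constant.none
  %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[5,5],f32>, !torch.none -> !torch.vtensor<[5,5],f32>
  return %0 : !torch.vtensor<[5,5],f32>
}

// After canonicalization: the fold rewires uses to the operand, and the
// now-dead torch.constant.none is erased.
func.func @clone_fold_example(%arg0: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],f32> {
  return %arg0 : !torch.vtensor<[5,5],f32>
}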

@@ -1021,6 +1021,7 @@ TOSA_PASS_SET = {
     "BroadcastZeroRankInputStaticModule_basic",
     "BucketizeTensorStaticFloatModule_basic",
     "BucketizeTensorStaticModule_basic",
+    "CloneModule_basic",
     "ChunkListUnpackUneven_Module_basic",
     "ChunkListUnpack_Module_basic",
     "ConstantBoolParameterModule_basic",

@@ -592,7 +592,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::one_hot : (Tensor, int) -> (Tensor)")
     emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)")
     emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
-    emit("aten::clone : (Tensor, int?) -> (Tensor)")
+    emit("aten::clone : (Tensor, int?) -> (Tensor)", has_folder=True)
     emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)")
     emit("aten::contiguous : (Tensor, int) -> (Tensor)")
     emit_with_mutating_variants("aten::copy : (Tensor, Tensor, bool) -> (Tensor)")

@@ -4994,3 +4994,23 @@ class IscloseStaticModuleTrue(torch.nn.Module):
 @register_test_case(module_factory=lambda: IscloseStaticModuleTrue())
 def IscloseStaticModuleTrue_basic(module, tu: TestUtils):
     module.forward(torch.ones(5, 5))
+
+
+# ==============================================================================
+
+
+class CloneModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([5, 5], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.clone(x)
+
+
+@register_test_case(module_factory=lambda: CloneModule())
+def CloneModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 5))

@@ -1,21 +1,6 @@
 // RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
 
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.clone$basic(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[T1:.*]] = stablehlo.convert %[[T0]] : tensor<?x?xf32>
-// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %none = torch.constant.none
-  %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[?,?],f32>, !torch.none -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
 // -----
 
 // CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {
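
With the fold in place, `torch.aten.clone` on value tensors is erased during canonicalization and never reaches the TorchToStablehlo conversion, so the dedicated conversion pattern and the test above are removed. A hedged sketch of a Torch-dialect canonicalization test that would pin down the new fold directly (the RUN line follows the repository's usual style; the label and the test itself are hypothetical, not part of this commit):

// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.clone$fold(
// CHECK-NEXT:    return %arg0 : !torch.vtensor<[?,?],f32>
func.func @torch.aten.clone$fold(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
  %none = torch.constant.none
  %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[?,?],f32>, !torch.none -> !torch.vtensor<[?,?],f32>
  return %0 : !torch.vtensor<[?,?],f32>
}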