mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] add fold pattern for aten.clone (#2804)
parent
25a5a22cbd
commit
d778950f45
|
@ -9101,6 +9101,7 @@ def Torch_AtenCloneOp : Torch_Op<"aten.clone", [
|
|||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenLiftFreshCopyOp : Torch_Op<"aten.lift_fresh_copy", [
|
||||
|
|
|
@ -1763,7 +1763,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
|||
#define INSERT_UNARY_PATTERN(AtenOp, StablehloOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenUnaryOp<AtenOp, StablehloOp>>(typeConverter, context)
|
||||
INSERT_UNARY_PATTERN(AtenCloneOp, stablehlo::ConvertOp);
|
||||
INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp);
|
||||
INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp);
|
||||
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp);
|
||||
|
|
|
@ -1662,6 +1662,19 @@ void AtenMaskedFillTensorOp::getCanonicalizationPatterns(
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenCloneOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenCloneOp::fold(FoldAdaptor adaptor) {
|
||||
// note: memory_format would be ignored
|
||||
if (llvm::dyn_cast<ValueTensorType>(getSelf().getType())) {
|
||||
// self should have value semantics
|
||||
return getSelf();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSortIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1021,6 +1021,7 @@ TOSA_PASS_SET = {
|
|||
"BroadcastZeroRankInputStaticModule_basic",
|
||||
"BucketizeTensorStaticFloatModule_basic",
|
||||
"BucketizeTensorStaticModule_basic",
|
||||
"CloneModule_basic",
|
||||
"ChunkListUnpackUneven_Module_basic",
|
||||
"ChunkListUnpack_Module_basic",
|
||||
"ConstantBoolParameterModule_basic",
|
||||
|
|
|
@ -592,7 +592,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::one_hot : (Tensor, int) -> (Tensor)")
|
||||
emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)")
|
||||
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
|
||||
emit("aten::clone : (Tensor, int?) -> (Tensor)")
|
||||
emit("aten::clone : (Tensor, int?) -> (Tensor)", has_folder=True)
|
||||
emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)")
|
||||
emit("aten::contiguous : (Tensor, int) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::copy : (Tensor, Tensor, bool) -> (Tensor)")
|
||||
|
|
|
@ -4994,3 +4994,23 @@ class IscloseStaticModuleTrue(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: IscloseStaticModuleTrue())
|
||||
def IscloseStaticModuleTrue_basic(module, tu: TestUtils):
|
||||
module.forward(torch.ones(5, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class CloneModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([5, 5], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.clone(x)
|
||||
|
||||
@register_test_case(module_factory=lambda: CloneModule())
|
||||
def CloneModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 5))
|
||||
|
|
|
@ -1,21 +1,6 @@
|
|||
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.clone$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[T1:.*]] = stablehlo.convert %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%none = torch.constant.none
|
||||
%0 = torch.aten.clone %arg0, %none : !torch.vtensor<[?,?],f32>, !torch.none -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {
|
||||
|
|
Loading…
Reference in New Issue