From 1cb14f6879914f84c4e9fcae9a6af550f77be953 Mon Sep 17 00:00:00 2001
From: Dave Liddell <44620210+daveliddell@users.noreply.github.com>
Date: Mon, 5 Feb 2024 18:10:42 -0700
Subject: [PATCH] Rob's atenTensor folder (#2867)

If a tensor is initialized by a list with a single constant integer,
this folder turns it into a torch.vtensor.literal

---------

Co-authored-by: Dave Liddell
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td |  1 +
 lib/Dialect/Torch/IR/TorchOps.cpp         | 21 +++++++++++++++++++
 .../build_tools/torch_ods_gen.py          |  2 +-
 test/Dialect/Torch/canonicalize.mlir      | 11 ++++++++++
 4 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index a0ec9663b..fad589576 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8582,6 +8582,7 @@ def Torch_AtenTensorOp : Torch_Op<"aten.tensor", [
       printDefaultTorchOp(printer, *this, 4, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenTensorBoolOp : Torch_Op<"aten.tensor.bool", [
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 98de4f85b..c557d2595 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2758,6 +2758,27 @@ void AtenDeviceWithIndexOp::getCanonicalizationPatterns(
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenTensorOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
+  // If a torch.aten.tensor op is initialized by a list with a constant, single
+  // element, fold it into a torch.vtensor.literal
+  auto resultTy = dyn_cast<ValueTensorType>(getType());
+  Type eTy = resultTy.getDtype();
+  ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy);
+
+  SmallVector<int64_t> data;
+  if (matchPattern(getData(), m_TorchListOfConstantInts(data)) &&
+      data.size() == 1) {
+    Attribute attribute = IntegerAttr::get(eTy, data[0]);
+    return DenseElementsAttr::get(shapedTy, attribute);
+  }
+
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // AtenIntTensorOp
 //===----------------------------------------------------------------------===//
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 41a297ba6..cb9c484b7 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -570,7 +570,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)")
-    emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)")
+    emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)", has_folder=True)
     emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)")
     emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)")
     emit("aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index cb2ec2d14..83055f3be 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -1461,6 +1461,17 @@ func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !to
   return %0 : !torch.tensor<[],f32>
 }
 
+// CHECK-LABEL:   func.func @torch.aten.tensor$one_elem(
+// CHECK-NEXT:      torch.vtensor.literal(dense<42> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+func.func @torch.aten.tensor$one_elem() -> (!torch.vtensor<[1],si64>) {
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %int42 = torch.constant.int 42
+  %66 = torch.prim.ListConstruct %int42 : (!torch.int) -> !torch.list<int>
+  %67 = torch.aten.tensor %66, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>
+  return %67 : !torch.vtensor<[1],si64>
+}
+
 // CHECK-LABEL:   func.func @torch.aten.to.dtype$same_dtype(
 // CHECK-SAME:      %[[ARG:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
 // CHECK-NEXT:      return %[[ARG]] : !torch.tensor<*,f32>