diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 1de407a8a..9c2cfcfb3 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8356,6 +8356,7 @@ def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index a7279d347..c2bedd5f6 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2317,6 +2317,25 @@ OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
   return list.getElements()[0];
 }
 
+//===----------------------------------------------------------------------===//
+// AtenBroadcastToOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
+  auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
+  auto outType = getResult().getType().dyn_cast<BaseTensorType>();
+  if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
+    return nullptr;
+  if (inType.getSizes().size() != outType.getSizes().size() ||
+      !inType.areAllSizesKnown() || !outType.areAllSizesKnown())
+    return nullptr;
+  for (size_t i = 0; i < inType.getSizes().size(); ++i) {
+    if (inType.getSizes()[i] != outType.getSizes()[i])
+      return nullptr;
+  }
+  return getOperand(0);
+}
+
 //===----------------------------------------------------------------------===//
 // AtenSliceTensorOp
 //===----------------------------------------------------------------------===//
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 0d2ae9af8..c192dacc4 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -542,7 +542,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
     emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
-    emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
+    emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_folder=True)
     emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
     emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
     emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index 88b73ed5c..21e0500f4 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -1971,6 +1971,18 @@ func.func @torch.aten.cat$fold_single_operand(%arg0: !torch.tensor) -> !torch.te
   return %1: !torch.tensor
 }
 
+// CHECK-LABEL: func.func @torch.aten.broadcast_to$fold(
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> {
+// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[3,4,2],f32>
+func.func @torch.aten.broadcast_to$fold(%arg0: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> {
+  %int3 = torch.constant.int 3
+  %int4 = torch.constant.int 4
+  %int2 = torch.constant.int 2
+  %list = torch.prim.ListConstruct %int3, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[3,4,2],f32>, !torch.list<int> -> !torch.vtensor<[3,4,2],f32>
+  return %0 : !torch.vtensor<[3,4,2],f32>
+}
+
 // CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice
 // CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32>
 // CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32>