From 5e5d51e4b5c0d44921ef93d26add6a6bc67ad02d Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Mon, 28 Aug 2023 12:42:11 +0000 Subject: [PATCH] Fold brodcast ops for llama2 --- lib/Dialect/Torch/IR/TorchOps.cpp | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index aed453a62..c5b96269b 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2361,6 +2361,26 @@ OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) { return list.getElements()[0]; } + +//===----------------------------------------------------------------------===// +// AtenBroadcastToOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) { + auto inType = getOperand(0).getType().dyn_cast(); + auto outType = getResult().getType().dyn_cast(); + if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes()) + return nullptr; + if (inType.getSizes().size() != outType.getSizes().size()) + return nullptr; + for (size_t i = 0; i < inType.getSizes().size(); ++i) { + if (inType.getSizes()[i] != outType.getSizes()[i]) + return nullptr; + } + return getOperand(0); +} + + //===----------------------------------------------------------------------===// // AtenBroadcastToOp //===----------------------------------------------------------------------===//