diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index aed453a62..c5b96269b 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2361,6 +2361,26 @@ OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) { return list.getElements()[0]; } + +//===----------------------------------------------------------------------===// +// AtenBroadcastToOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) { + auto inType = getOperand(0).getType().dyn_cast(); + auto outType = getResult().getType().dyn_cast(); + if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes()) + return nullptr; + if (inType.getSizes().size() != outType.getSizes().size()) + return nullptr; + for (size_t i = 0; i < inType.getSizes().size(); ++i) { + if (inType.getSizes()[i] != outType.getSizes()[i]) + return nullptr; + } + return getOperand(0); +} + + //===----------------------------------------------------------------------===// // AtenBroadcastToOp //===----------------------------------------------------------------------===//