Fold brodcast ops for llama2

updated_bcast
Vivek Khandelwal 2023-08-28 12:42:11 +00:00 committed by dan
parent 14e6da8588
commit 5e5d51e4b5
1 changed files with 20 additions and 0 deletions

View File

@ -2361,6 +2361,26 @@ OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
return list.getElements()[0];
}
//===----------------------------------------------------------------------===//
// AtenBroadcastToOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
auto outType = getResult().getType().dyn_cast<BaseTensorType>();
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
return nullptr;
if (inType.getSizes().size() != outType.getSizes().size())
return nullptr;
for (size_t i = 0; i < inType.getSizes().size(); ++i) {
if (inType.getSizes()[i] != outType.getSizes()[i])
return nullptr;
}
return getOperand(0);
}
//===----------------------------------------------------------------------===//
// AtenBroadcastToOp
//===----------------------------------------------------------------------===//