mirror of https://github.com/llvm/torch-mlir
Fold brodcast ops for llama2
parent
14e6da8588
commit
5e5d51e4b5
|
@ -2361,6 +2361,26 @@ OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
|
|||
return list.getElements()[0];
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenBroadcastToOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
|
||||
auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
|
||||
auto outType = getResult().getType().dyn_cast<BaseTensorType>();
|
||||
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
|
||||
return nullptr;
|
||||
if (inType.getSizes().size() != outType.getSizes().size())
|
||||
return nullptr;
|
||||
for (size_t i = 0; i < inType.getSizes().size(); ++i) {
|
||||
if (inType.getSizes()[i] != outType.getSizes()[i])
|
||||
return nullptr;
|
||||
}
|
||||
return getOperand(0);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenBroadcastToOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue