mirror of https://github.com/llvm/torch-mlir
Fold brodcast ops for llama2
parent
14e6da8588
commit
5e5d51e4b5
|
@ -2361,6 +2361,26 @@ OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
|
||||||
return list.getElements()[0];
|
return list.getElements()[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenBroadcastToOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
|
||||||
|
auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
|
||||||
|
auto outType = getResult().getType().dyn_cast<BaseTensorType>();
|
||||||
|
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
|
||||||
|
return nullptr;
|
||||||
|
if (inType.getSizes().size() != outType.getSizes().size())
|
||||||
|
return nullptr;
|
||||||
|
for (size_t i = 0; i < inType.getSizes().size(); ++i) {
|
||||||
|
if (inType.getSizes()[i] != outType.getSizes()[i])
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return getOperand(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenBroadcastToOp
|
// AtenBroadcastToOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
Loading…
Reference in New Issue