Rob Suderman 2024-04-29 12:01:40 -07:00 committed by GitHub
parent 087fea0608
commit db6721084a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 4 additions and 4 deletions

@ -1 +1 @@
Subproject commit a952c123880eb1168f1021b116485e27170d48ca
Subproject commit 593f6fdcb4bb3ff81ba4e6f89d7b16540c4b9eaf

View File

@ -30,8 +30,6 @@ namespace detail {
LogicalResult verifyTMTensorOpInterface(Operation *op);
}
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export
/// Include the generated interface declarations.
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOpInterfaces.h.inc" // IWYU pragma: export
@ -39,4 +37,6 @@ LogicalResult verifyTMTensorOpInterface(Operation *op);
} // namespace torch
} // namespace mlir
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_

View File

@ -936,7 +936,7 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<TMTensorOp> {
// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
if (opOperand->get().isa<BlockArgument>())
if (isa<BlockArgument>(opOperand->get()))
return false;
auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);