Fix illegal use of TypeRange (#2815)

TypeRange is an ArrayRef<Type> and therefore cannot be safely
instantiated from a list initializer.
pull/2812/head
Rob Suderman 2024-01-29 09:23:05 -08:00 committed by GitHub
parent 032f225fa5
commit 67cb2e7341
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 6 deletions

View File

@ -2615,11 +2615,9 @@ namespace {
LogicalResult matchAndRewrite(AtenConvTbcOp op, LogicalResult matchAndRewrite(AtenConvTbcOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Value emptyList = rewriter.create<PrimListConstructOp>( Value emptyList = rewriter.create<PrimListConstructOp>(
op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), op.getLoc(),
Torch::ListType::get(Torch::IntType::get(op.getContext())),
SmallVector<Value>()); SmallVector<Value>());
Value zeroList = rewriter.create<PrimListConstructOp>(
op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
SmallVector<Value>{rewriter.create<Torch::ConstantIntOp>(op.getLoc(), rewriter.getI64IntegerAttr(0))});
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false); Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
Value oneList = rewriter.create<PrimListConstructOp>( Value oneList = rewriter.create<PrimListConstructOp>(
op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
@ -5406,8 +5404,8 @@ public:
auto resultType2 = op->getResult(2).getType(); auto resultType2 = op->getResult(2).getType();
auto resultType3 = op->getResult(3).getType(); auto resultType3 = op->getResult(3).getType();
mlir::TypeRange returnTypes{resultType0, resultType1, resultType2, llvm::SmallVector<Type> returnTypes{resultType0, resultType1, resultType2,
resultType3}; resultType3};
rewriter.replaceOpWithNewOp<AtenEmbeddingBagPaddingIdxOp>( rewriter.replaceOpWithNewOp<AtenEmbeddingBagPaddingIdxOp>(
op, returnTypes, weight, indices, offsets, scaleGradByFreq, mode, op, returnTypes, weight, indices, offsets, scaleGradByFreq, mode,