Cleanup unused code.

rewrite-getitem
Daniel Ellis 2022-10-25 18:58:49 +00:00
parent 0c5f209fb5
commit 52374dbc74
1 changed files with 11 additions and 47 deletions

View File

@ -1705,40 +1705,21 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
return success();
});
// For AtenUnbindIntOp that are immediately followed by a call to
// Aten__Getitem__TOp, we can simplify with a direct call to
// AtenTensorIndexOp.
patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) {
auto potentialUnbindOp = op.getOperand(0);
auto unbindOp = potentialUnbindOp.getDefiningOp<AtenUnbindIntOp>();
if (!unbindOp)
return failure();
std::cout << "XXX: Found unbind op!" << std::endl;
// Create an AtenTensorIntOp, representing the index into the tensor
auto tensor = unbindOp.getOperand(0);
////rewriter.replaceOpWithNewOp<Aten__Getitem__TOp>(op, tensor, op.getOperand(1));
//SmallVector<Value> listElements;
//// TODO: Conver this integer (op.getOperand(1)) to a tensor
//int64_t items[] = {};
//llvm::MutableArrayRef<long> arrayRefItems;
//auto adjusted = rewriter.create<Torch::TensorStaticInfoCastOp>(
// op->getLoc(), rewriter.getType<Torch::NonValueTensorType>(
// arrayRefItems, op.getOperand(1).getType()), op.getOperand(1));
//listElements.push_back(adjusted);
// Try AtenTensorIntOp
//auto boolConst = rewriter.create<Torch::ConstantBoolOp>(op->getLoc(), false);
//SmallVector<Value> listElements;
//auto indexAsTensor = rewriter.create<Torch::AtenTensorIntOp>(
// op->getLoc(), op.getOperand(1), op.getOperand(1).getType(), boolConst);
//auto indexAsTensor = rewriter.create<Torch::AtenTensorIntOp>(
// op->getLoc(), mlir::ValueRange{op.getOperand(1), boolConst});
//listElements.push_back(indexAsTensor);
auto boolConst = rewriter.create<Torch::ConstantBoolOp>(op->getLoc(), false);
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
SmallVector<Value> listElements;
auto indexAsTensorType = rewriter.getType<Torch::NonValueTensorType>(
llvm::makeArrayRef({1L}), op.getOperand(1).getType());
auto indexAsTensorMLIRType = rewriter.getType<Torch::NonValueTensorType>(
ArrayRef<int64_t>(), IntegerType::get(op->getContext(), 64, IntegerType::Signed));
auto indexAsTensor = rewriter.create<Torch::AtenTensorIntOp>(
op->getLoc(),
indexAsTensorMLIRType,
@ -1746,34 +1727,17 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
/*dtype=*/noneVal,
/*device=*/noneVal,
/*requiresGrad=*/boolConst);
listElements.push_back(indexAsTensor);
//listElements.push_back(op.getOperand(1).cast<Torch::NonValueTensorType>());
//listElements.push_back(op.getOperand(1));
//for (int64_t size : type->getSizes()) {
// listElements.push_back(rewriter.create<Torch::ConstantIntOp>(
// op->getLoc(), rewriter.getI64IntegerAttr(size)));
//}
//
//rewriter.
//rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>(
// op, Torch::ListType::get(rewriter.getType<Torch::ValueTensorType>()),
// listElements);
// Create arguments for AtenIndexTensorOp
SmallVector<Value> listElements;
listElements.push_back(indexAsTensor);
auto list = rewriter.create<Torch::PrimListConstructOp>(
op->getLoc(), Torch::ListType::get(indexAsTensor.getType()),
listElements);
//auto list = rewriter.create<Torch::PrimListConstructOp>(
// op->getLoc(), Torch::ListType::get(Torch::OptionalType::get(
// rewriter.getType<Torch::NonValueTensorType>(llvm::makeArrayRef({1L}), tensor.getType()))),
// listElements);
//Value latestLiteral = rewriter.create<PrimListConstructOp>(
// op->getLoc(), op.getType(), op->getOperands());
rewriter.replaceOpWithNewOp<AtenIndexTensorOp>(
op, op.getType(), tensor, list);
//rewriter.replaceOpWithNewOp<AtenIndexTensorOp>(
// op, op.getType(), tensor, op.getOperand(1));
//rewriter.replaceOp(op, {tensor, op.getOperand(1)});
// Create AtenIndexTensorOp
rewriter.replaceOpWithNewOp<AtenIndexTensorOp>(op, op.getType(),
tensor, list);
return success();
});