mirror of https://github.com/llvm/torch-mlir
Cleanup unused code.
parent
0c5f209fb5
commit
52374dbc74
|
@ -1705,40 +1705,21 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
|
|||
return success();
|
||||
});
|
||||
|
||||
// For AtenUnbindIntOp that are immediately followed by a call to
|
||||
// Aten__Getitem__TOp, we can simplify with a direct call to
|
||||
// AtenTensorIndexOp.
|
||||
patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) {
|
||||
auto potentialUnbindOp = op.getOperand(0);
|
||||
auto unbindOp = potentialUnbindOp.getDefiningOp<AtenUnbindIntOp>();
|
||||
if (!unbindOp)
|
||||
return failure();
|
||||
std::cout << "XXX: Found unbind op!" << std::endl;
|
||||
|
||||
// Create an AtenTensorIntOp, representing the index into the tensor
|
||||
auto tensor = unbindOp.getOperand(0);
|
||||
////rewriter.replaceOpWithNewOp<Aten__Getitem__TOp>(op, tensor, op.getOperand(1));
|
||||
//SmallVector<Value> listElements;
|
||||
//// TODO: Conver this integer (op.getOperand(1)) to a tensor
|
||||
//int64_t items[] = {};
|
||||
//llvm::MutableArrayRef<long> arrayRefItems;
|
||||
//auto adjusted = rewriter.create<Torch::TensorStaticInfoCastOp>(
|
||||
// op->getLoc(), rewriter.getType<Torch::NonValueTensorType>(
|
||||
// arrayRefItems, op.getOperand(1).getType()), op.getOperand(1));
|
||||
//listElements.push_back(adjusted);
|
||||
|
||||
// Try AtenTensorIntOp
|
||||
//auto boolConst = rewriter.create<Torch::ConstantBoolOp>(op->getLoc(), false);
|
||||
//SmallVector<Value> listElements;
|
||||
//auto indexAsTensor = rewriter.create<Torch::AtenTensorIntOp>(
|
||||
// op->getLoc(), op.getOperand(1), op.getOperand(1).getType(), boolConst);
|
||||
//auto indexAsTensor = rewriter.create<Torch::AtenTensorIntOp>(
|
||||
// op->getLoc(), mlir::ValueRange{op.getOperand(1), boolConst});
|
||||
//listElements.push_back(indexAsTensor);
|
||||
auto boolConst = rewriter.create<Torch::ConstantBoolOp>(op->getLoc(), false);
|
||||
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
||||
SmallVector<Value> listElements;
|
||||
auto indexAsTensorType = rewriter.getType<Torch::NonValueTensorType>(
|
||||
llvm::makeArrayRef({1L}), op.getOperand(1).getType());
|
||||
auto indexAsTensorMLIRType = rewriter.getType<Torch::NonValueTensorType>(
|
||||
ArrayRef<int64_t>(), IntegerType::get(op->getContext(), 64, IntegerType::Signed));
|
||||
|
||||
auto indexAsTensor = rewriter.create<Torch::AtenTensorIntOp>(
|
||||
op->getLoc(),
|
||||
indexAsTensorMLIRType,
|
||||
|
@ -1746,34 +1727,17 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
|
|||
/*dtype=*/noneVal,
|
||||
/*device=*/noneVal,
|
||||
/*requiresGrad=*/boolConst);
|
||||
listElements.push_back(indexAsTensor);
|
||||
|
||||
//listElements.push_back(op.getOperand(1).cast<Torch::NonValueTensorType>());
|
||||
//listElements.push_back(op.getOperand(1));
|
||||
//for (int64_t size : type->getSizes()) {
|
||||
// listElements.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
// op->getLoc(), rewriter.getI64IntegerAttr(size)));
|
||||
//}
|
||||
//
|
||||
//rewriter.
|
||||
//rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>(
|
||||
// op, Torch::ListType::get(rewriter.getType<Torch::ValueTensorType>()),
|
||||
// listElements);
|
||||
|
||||
// Create arguments for AtenIndexTensorOp
|
||||
SmallVector<Value> listElements;
|
||||
listElements.push_back(indexAsTensor);
|
||||
auto list = rewriter.create<Torch::PrimListConstructOp>(
|
||||
op->getLoc(), Torch::ListType::get(indexAsTensor.getType()),
|
||||
listElements);
|
||||
//auto list = rewriter.create<Torch::PrimListConstructOp>(
|
||||
// op->getLoc(), Torch::ListType::get(Torch::OptionalType::get(
|
||||
// rewriter.getType<Torch::NonValueTensorType>(llvm::makeArrayRef({1L}), tensor.getType()))),
|
||||
// listElements);
|
||||
//Value latestLiteral = rewriter.create<PrimListConstructOp>(
|
||||
// op->getLoc(), op.getType(), op->getOperands());
|
||||
rewriter.replaceOpWithNewOp<AtenIndexTensorOp>(
|
||||
op, op.getType(), tensor, list);
|
||||
//rewriter.replaceOpWithNewOp<AtenIndexTensorOp>(
|
||||
// op, op.getType(), tensor, op.getOperand(1));
|
||||
//rewriter.replaceOp(op, {tensor, op.getOperand(1)});
|
||||
|
||||
// Create AtenIndexTensorOp
|
||||
rewriter.replaceOpWithNewOp<AtenIndexTensorOp>(op, op.getType(),
|
||||
tensor, list);
|
||||
return success();
|
||||
});
|
||||
|
||||
|
|
Loading…
Reference in New Issue