mirror of https://github.com/llvm/torch-mlir
Cleanup unused code.
parent
0c5f209fb5
commit
52374dbc74
|
@ -1705,40 +1705,21 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// For AtenUnbindIntOp that are immediately followed by a call to
|
||||||
|
// Aten__Getitem__TOp, we can simplify with a direct call to
|
||||||
|
// AtenTensorIndexOp.
|
||||||
patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) {
|
patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) {
|
||||||
auto potentialUnbindOp = op.getOperand(0);
|
auto potentialUnbindOp = op.getOperand(0);
|
||||||
auto unbindOp = potentialUnbindOp.getDefiningOp<AtenUnbindIntOp>();
|
auto unbindOp = potentialUnbindOp.getDefiningOp<AtenUnbindIntOp>();
|
||||||
if (!unbindOp)
|
if (!unbindOp)
|
||||||
return failure();
|
return failure();
|
||||||
std::cout << "XXX: Found unbind op!" << std::endl;
|
|
||||||
|
|
||||||
|
// Create an AtenTensorIntOp, representing the index into the tensor
|
||||||
auto tensor = unbindOp.getOperand(0);
|
auto tensor = unbindOp.getOperand(0);
|
||||||
////rewriter.replaceOpWithNewOp<Aten__Getitem__TOp>(op, tensor, op.getOperand(1));
|
|
||||||
//SmallVector<Value> listElements;
|
|
||||||
//// TODO: Conver this integer (op.getOperand(1)) to a tensor
|
|
||||||
//int64_t items[] = {};
|
|
||||||
//llvm::MutableArrayRef<long> arrayRefItems;
|
|
||||||
//auto adjusted = rewriter.create<Torch::TensorStaticInfoCastOp>(
|
|
||||||
// op->getLoc(), rewriter.getType<Torch::NonValueTensorType>(
|
|
||||||
// arrayRefItems, op.getOperand(1).getType()), op.getOperand(1));
|
|
||||||
//listElements.push_back(adjusted);
|
|
||||||
|
|
||||||
// Try AtenTensorIntOp
|
|
||||||
//auto boolConst = rewriter.create<Torch::ConstantBoolOp>(op->getLoc(), false);
|
|
||||||
//SmallVector<Value> listElements;
|
|
||||||
//auto indexAsTensor = rewriter.create<Torch::AtenTensorIntOp>(
|
|
||||||
// op->getLoc(), op.getOperand(1), op.getOperand(1).getType(), boolConst);
|
|
||||||
//auto indexAsTensor = rewriter.create<Torch::AtenTensorIntOp>(
|
|
||||||
// op->getLoc(), mlir::ValueRange{op.getOperand(1), boolConst});
|
|
||||||
//listElements.push_back(indexAsTensor);
|
|
||||||
auto boolConst = rewriter.create<Torch::ConstantBoolOp>(op->getLoc(), false);
|
auto boolConst = rewriter.create<Torch::ConstantBoolOp>(op->getLoc(), false);
|
||||||
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
||||||
SmallVector<Value> listElements;
|
|
||||||
auto indexAsTensorType = rewriter.getType<Torch::NonValueTensorType>(
|
|
||||||
llvm::makeArrayRef({1L}), op.getOperand(1).getType());
|
|
||||||
auto indexAsTensorMLIRType = rewriter.getType<Torch::NonValueTensorType>(
|
auto indexAsTensorMLIRType = rewriter.getType<Torch::NonValueTensorType>(
|
||||||
ArrayRef<int64_t>(), IntegerType::get(op->getContext(), 64, IntegerType::Signed));
|
ArrayRef<int64_t>(), IntegerType::get(op->getContext(), 64, IntegerType::Signed));
|
||||||
|
|
||||||
auto indexAsTensor = rewriter.create<Torch::AtenTensorIntOp>(
|
auto indexAsTensor = rewriter.create<Torch::AtenTensorIntOp>(
|
||||||
op->getLoc(),
|
op->getLoc(),
|
||||||
indexAsTensorMLIRType,
|
indexAsTensorMLIRType,
|
||||||
|
@ -1746,34 +1727,17 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
|
||||||
/*dtype=*/noneVal,
|
/*dtype=*/noneVal,
|
||||||
/*device=*/noneVal,
|
/*device=*/noneVal,
|
||||||
/*requiresGrad=*/boolConst);
|
/*requiresGrad=*/boolConst);
|
||||||
|
|
||||||
|
// Create arguments for AtenIndexTensorOp
|
||||||
|
SmallVector<Value> listElements;
|
||||||
listElements.push_back(indexAsTensor);
|
listElements.push_back(indexAsTensor);
|
||||||
|
|
||||||
//listElements.push_back(op.getOperand(1).cast<Torch::NonValueTensorType>());
|
|
||||||
//listElements.push_back(op.getOperand(1));
|
|
||||||
//for (int64_t size : type->getSizes()) {
|
|
||||||
// listElements.push_back(rewriter.create<Torch::ConstantIntOp>(
|
|
||||||
// op->getLoc(), rewriter.getI64IntegerAttr(size)));
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//rewriter.
|
|
||||||
//rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>(
|
|
||||||
// op, Torch::ListType::get(rewriter.getType<Torch::ValueTensorType>()),
|
|
||||||
// listElements);
|
|
||||||
|
|
||||||
auto list = rewriter.create<Torch::PrimListConstructOp>(
|
auto list = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
op->getLoc(), Torch::ListType::get(indexAsTensor.getType()),
|
op->getLoc(), Torch::ListType::get(indexAsTensor.getType()),
|
||||||
listElements);
|
listElements);
|
||||||
//auto list = rewriter.create<Torch::PrimListConstructOp>(
|
|
||||||
// op->getLoc(), Torch::ListType::get(Torch::OptionalType::get(
|
// Create AtenIndexTensorOp
|
||||||
// rewriter.getType<Torch::NonValueTensorType>(llvm::makeArrayRef({1L}), tensor.getType()))),
|
rewriter.replaceOpWithNewOp<AtenIndexTensorOp>(op, op.getType(),
|
||||||
// listElements);
|
tensor, list);
|
||||||
//Value latestLiteral = rewriter.create<PrimListConstructOp>(
|
|
||||||
// op->getLoc(), op.getType(), op->getOperands());
|
|
||||||
rewriter.replaceOpWithNewOp<AtenIndexTensorOp>(
|
|
||||||
op, op.getType(), tensor, list);
|
|
||||||
//rewriter.replaceOpWithNewOp<AtenIndexTensorOp>(
|
|
||||||
// op, op.getType(), tensor, op.getOperand(1));
|
|
||||||
//rewriter.replaceOp(op, {tensor, op.getOperand(1)});
|
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue