pull/2824/head snapshot-20240127.1096
MaheshRavishankar 2024-01-26 18:38:44 -08:00 committed by GitHub
parent 46a25d7241
commit 28c7051ceb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 21 additions and 21 deletions

@ -1 +1 @@
Subproject commit eae82ac259ee5a58bc4070a414bc53239e18bad0 Subproject commit 5fcf907b34355980f77d7665a175b05fea7a6b7b

View File

@ -81,7 +81,7 @@ public:
} }
newResultTypes.push_back(type); newResultTypes.push_back(type);
} }
rewriter.updateRootInPlace(func, [&] { rewriter.modifyOpInPlace(func, [&] {
func.setType(FunctionType::get( func.setType(FunctionType::get(
getContext(), conversion.getConvertedTypes(), newResultTypes)); getContext(), conversion.getConvertedTypes(), newResultTypes));
// Clear out the type bounds, now that the type incorporates them. // Clear out the type bounds, now that the type incorporates them.
@ -194,14 +194,12 @@ static LogicalResult adjustCallingConventions(func::FuncOp func,
TypeConverter typeConverter; TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion([](Type type) { return type; });
typeConverter.addConversion( typeConverter.addConversion(
[](Torch::TupleType type, [](Torch::TupleType type, SmallVectorImpl<Type> &types) -> LogicalResult {
SmallVectorImpl<Type> &types) -> LogicalResult {
llvm::append_range(types, type.getContainedTypes()); llvm::append_range(types, type.getContainedTypes());
return success(); return success();
}); });
typeConverter.addConversion( typeConverter.addConversion(
[](Torch::NoneType type, [](Torch::NoneType type, SmallVectorImpl<Type> &types) -> LogicalResult {
SmallVectorImpl<Type> &types) -> LogicalResult {
return success(); return success();
}); });

View File

@ -175,7 +175,7 @@ public:
// Replace return type of view-like ops with value-semantics type variant. // Replace return type of view-like ops with value-semantics type variant.
for (Operation *viewLikeOp : ops.viewLikeOps) { for (Operation *viewLikeOp : ops.viewLikeOps) {
rewriter.updateRootInPlace(viewLikeOp, [&] { rewriter.modifyOpInPlace(viewLikeOp, [&] {
Value result = viewLikeOp->getResult(0); Value result = viewLikeOp->getResult(0);
auto resultType = result.getType().dyn_cast<NonValueTensorType>(); auto resultType = result.getType().dyn_cast<NonValueTensorType>();
if (resultType) if (resultType)
@ -337,7 +337,7 @@ public:
// correctly copy them back to their mlir::func::ReturnOp's expected types. // correctly copy them back to their mlir::func::ReturnOp's expected types.
DenseMap<Value, Type> originalTypes; DenseMap<Value, Type> originalTypes;
for (Operation *op : viewLikeOps) { for (Operation *op : viewLikeOps) {
rewriter.updateRootInPlace(op, [&]() { rewriter.modifyOpInPlace(op, [&]() {
if (auto nonValueTensorType = if (auto nonValueTensorType =
op->getResult(0).getType().dyn_cast<NonValueTensorType>()) { op->getResult(0).getType().dyn_cast<NonValueTensorType>()) {
originalTypes[op->getResult(0)] = nonValueTensorType; originalTypes[op->getResult(0)] = nonValueTensorType;

View File

@ -9,10 +9,10 @@
#include "PassDetail.h" #include "PassDetail.h"
#include "ReifyAbstractInterpCalculationsUtils.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "ReifyAbstractInterpCalculationsUtils.h"
#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringExtras.h"
using namespace mlir; using namespace mlir;
@ -72,8 +72,8 @@ namespace {
// immutable tensors. // immutable tensors.
class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
public: public:
ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context, ConvertHasValueSemanticsOpsToValueTensors(
const std::optional<SymbolTable>& extraLibrary) MLIRContext *context, const std::optional<SymbolTable> &extraLibrary)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) { : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {
this->extraLibrary = extraLibrary; this->extraLibrary = extraLibrary;
} }
@ -87,7 +87,7 @@ public:
return rewriter.notifyMatchFailure(op, "does not have value semantics"); return rewriter.notifyMatchFailure(op, "does not have value semantics");
} }
rewriter.startRootUpdate(op); rewriter.startOpModification(op);
// Convert all operands. // Convert all operands.
SmallVector<Value> newOperands; SmallVector<Value> newOperands;
for (OpOperand &opOperand : op->getOpOperands()) { for (OpOperand &opOperand : op->getOpOperands()) {
@ -105,7 +105,7 @@ public:
auto listConstruct = auto listConstruct =
opOperand.get().getDefiningOp<PrimListConstructOp>(); opOperand.get().getDefiningOp<PrimListConstructOp>();
if (!listConstruct) { if (!listConstruct) {
rewriter.cancelRootUpdate(op); rewriter.cancelOpModification(op);
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "unimplemented: list of non vtensor type not constructed " op, "unimplemented: list of non vtensor type not constructed "
"from list construct"); "from list construct");
@ -120,7 +120,7 @@ public:
if (!llvm::all_of(listConstruct.getElements(), [](Value val) { if (!llvm::all_of(listConstruct.getElements(), [](Value val) {
return val.getType().isa<NonValueTensorType, Torch::NoneType>(); return val.getType().isa<NonValueTensorType, Torch::NoneType>();
})) { })) {
rewriter.cancelRootUpdate(op); rewriter.cancelOpModification(op);
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "unimplemented: list containing optional type is not " op, "unimplemented: list containing optional type is not "
"handled."); "handled.");
@ -138,7 +138,7 @@ public:
Type newListType = getContainerOrTensorTypeWithValueSemantics(listType); Type newListType = getContainerOrTensorTypeWithValueSemantics(listType);
if (!newListType) { if (!newListType) {
rewriter.cancelRootUpdate(op); rewriter.cancelOpModification(op);
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Unable to convert list type to value semantics."); op, "Unable to convert list type to value semantics.");
} }
@ -154,7 +154,7 @@ public:
// from the non value tensor of the original optional value. // from the non value tensor of the original optional value.
auto derefine = opOperand.get().getDefiningOp<DerefineOp>(); auto derefine = opOperand.get().getDefiningOp<DerefineOp>();
if (!derefine) { if (!derefine) {
rewriter.cancelRootUpdate(op); rewriter.cancelOpModification(op);
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "unimplemented: optional of non vtensor type not from " op, "unimplemented: optional of non vtensor type not from "
"derefine"); "derefine");
@ -180,9 +180,10 @@ public:
rewriter.create<CopyToNonValueTensorOp>(op->getLoc(), result); rewriter.create<CopyToNonValueTensorOp>(op->getLoc(), result);
result.replaceAllUsesExcept(nonValueTensor, nonValueTensor); result.replaceAllUsesExcept(nonValueTensor, nonValueTensor);
} }
rewriter.finalizeRootUpdate(op); rewriter.finalizeOpModification(op);
return success(); return success();
} }
private: private:
std::optional<SymbolTable> extraLibrary; std::optional<SymbolTable> extraLibrary;
}; };
@ -290,9 +291,9 @@ public:
Operation *newOp = rewriter.create(state); Operation *newOp = rewriter.create(state);
// Note: need to convert result to first input's dtype because mix precision // Note: need to convert result to first input's dtype because mix precision
// compute would result in different behaviors. // compute would result in different behaviors.
// For example: // For example:
// a = torch.randn(3, 3).half() # float16 // a = torch.randn(3, 3).half() # float16
// b = torch.randn(3, 3) # float32 // b = torch.randn(3, 3) # float32
// a += b # i.e. torch.ops.aten.add_(a, b), result is float16 // a += b # i.e. torch.ops.aten.add_(a, b), result is float16
// c = a + b # i.e. torch.ops.aten.add(a, b), result is float32 // c = a + b # i.e. torch.ops.aten.add(a, b), result is float32
Value none = rewriter.create<ConstantNoneOp>(op->getLoc()); Value none = rewriter.create<ConstantNoneOp>(op->getLoc());
@ -300,7 +301,8 @@ public:
auto aDtype = rewriter.create<PrimDtypeOp>(op->getLoc(), op->getOperand(0)); auto aDtype = rewriter.create<PrimDtypeOp>(op->getLoc(), op->getOperand(0));
auto toDtype = rewriter.create<AtenToDtypeOp>( auto toDtype = rewriter.create<AtenToDtypeOp>(
op->getLoc(), newOp->getResult(0).getType(), newOp->getResult(0), op->getLoc(), newOp->getResult(0).getType(), newOp->getResult(0),
aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
auto tensor = rewriter.create<CopyToValueTensorOp>(op->getLoc(), toDtype); auto tensor = rewriter.create<CopyToValueTensorOp>(op->getLoc(), toDtype);
createOverwriteTensorContents(rewriter, op->getLoc(), tensor, createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
op->getOperand(0)); op->getOperand(0));