pull/2824/head snapshot-20240127.1096
MaheshRavishankar 2024-01-26 18:38:44 -08:00 committed by GitHub
parent 46a25d7241
commit 28c7051ceb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 21 additions and 21 deletions

@ -1 +1 @@
Subproject commit eae82ac259ee5a58bc4070a414bc53239e18bad0
Subproject commit 5fcf907b34355980f77d7665a175b05fea7a6b7b

View File

@ -81,7 +81,7 @@ public:
}
newResultTypes.push_back(type);
}
rewriter.updateRootInPlace(func, [&] {
rewriter.modifyOpInPlace(func, [&] {
func.setType(FunctionType::get(
getContext(), conversion.getConvertedTypes(), newResultTypes));
// Clear out the type bounds, now that the type incorporates them.
@ -194,14 +194,12 @@ static LogicalResult adjustCallingConventions(func::FuncOp func,
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
typeConverter.addConversion(
[](Torch::TupleType type,
SmallVectorImpl<Type> &types) -> LogicalResult {
[](Torch::TupleType type, SmallVectorImpl<Type> &types) -> LogicalResult {
llvm::append_range(types, type.getContainedTypes());
return success();
});
typeConverter.addConversion(
[](Torch::NoneType type,
SmallVectorImpl<Type> &types) -> LogicalResult {
[](Torch::NoneType type, SmallVectorImpl<Type> &types) -> LogicalResult {
return success();
});

View File

@ -175,7 +175,7 @@ public:
// Replace return type of view-like ops with value-semantics type variant.
for (Operation *viewLikeOp : ops.viewLikeOps) {
rewriter.updateRootInPlace(viewLikeOp, [&] {
rewriter.modifyOpInPlace(viewLikeOp, [&] {
Value result = viewLikeOp->getResult(0);
auto resultType = result.getType().dyn_cast<NonValueTensorType>();
if (resultType)
@ -337,7 +337,7 @@ public:
// correctly copy them back to their mlir::func::ReturnOp's expected types.
DenseMap<Value, Type> originalTypes;
for (Operation *op : viewLikeOps) {
rewriter.updateRootInPlace(op, [&]() {
rewriter.modifyOpInPlace(op, [&]() {
if (auto nonValueTensorType =
op->getResult(0).getType().dyn_cast<NonValueTensorType>()) {
originalTypes[op->getResult(0)] = nonValueTensorType;

View File

@ -9,10 +9,10 @@
#include "PassDetail.h"
#include "ReifyAbstractInterpCalculationsUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "ReifyAbstractInterpCalculationsUtils.h"
#include "llvm/ADT/StringExtras.h"
using namespace mlir;
@ -72,8 +72,8 @@ namespace {
// immutable tensors.
class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
public:
ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context,
const std::optional<SymbolTable>& extraLibrary)
ConvertHasValueSemanticsOpsToValueTensors(
MLIRContext *context, const std::optional<SymbolTable> &extraLibrary)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {
this->extraLibrary = extraLibrary;
}
@ -87,7 +87,7 @@ public:
return rewriter.notifyMatchFailure(op, "does not have value semantics");
}
rewriter.startRootUpdate(op);
rewriter.startOpModification(op);
// Convert all operands.
SmallVector<Value> newOperands;
for (OpOperand &opOperand : op->getOpOperands()) {
@ -105,7 +105,7 @@ public:
auto listConstruct =
opOperand.get().getDefiningOp<PrimListConstructOp>();
if (!listConstruct) {
rewriter.cancelRootUpdate(op);
rewriter.cancelOpModification(op);
return rewriter.notifyMatchFailure(
op, "unimplemented: list of non vtensor type not constructed "
"from list construct");
@ -120,7 +120,7 @@ public:
if (!llvm::all_of(listConstruct.getElements(), [](Value val) {
return val.getType().isa<NonValueTensorType, Torch::NoneType>();
})) {
rewriter.cancelRootUpdate(op);
rewriter.cancelOpModification(op);
return rewriter.notifyMatchFailure(
op, "unimplemented: list containing optional type is not "
"handled.");
@ -138,7 +138,7 @@ public:
Type newListType = getContainerOrTensorTypeWithValueSemantics(listType);
if (!newListType) {
rewriter.cancelRootUpdate(op);
rewriter.cancelOpModification(op);
return rewriter.notifyMatchFailure(
op, "Unable to convert list type to value semantics.");
}
@ -154,7 +154,7 @@ public:
// from the non value tensor of the original optional value.
auto derefine = opOperand.get().getDefiningOp<DerefineOp>();
if (!derefine) {
rewriter.cancelRootUpdate(op);
rewriter.cancelOpModification(op);
return rewriter.notifyMatchFailure(
op, "unimplemented: optional of non vtensor type not from "
"derefine");
@ -180,9 +180,10 @@ public:
rewriter.create<CopyToNonValueTensorOp>(op->getLoc(), result);
result.replaceAllUsesExcept(nonValueTensor, nonValueTensor);
}
rewriter.finalizeRootUpdate(op);
rewriter.finalizeOpModification(op);
return success();
}
private:
std::optional<SymbolTable> extraLibrary;
};
@ -290,9 +291,9 @@ public:
Operation *newOp = rewriter.create(state);
// Note: need to convert result to first input's dtype because mix precision
// compute would result in different behaviors.
// For example:
// a = torch.randn(3, 3).half() # float16
// b = torch.randn(3, 3) # float32
// For example:
// a = torch.randn(3, 3).half() # float16
// b = torch.randn(3, 3) # float32
// a += b # i.e. torch.ops.aten.add_(a, b), result is float16
// c = a + b # i.e. torch.ops.aten.add(a, b), result is float32
Value none = rewriter.create<ConstantNoneOp>(op->getLoc());
@ -300,7 +301,8 @@ public:
auto aDtype = rewriter.create<PrimDtypeOp>(op->getLoc(), op->getOperand(0));
auto toDtype = rewriter.create<AtenToDtypeOp>(
op->getLoc(), newOp->getResult(0).getType(), newOp->getResult(0),
aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
auto tensor = rewriter.create<CopyToValueTensorOp>(op->getLoc(), toDtype);
createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
op->getOperand(0));