mirror of https://github.com/llvm/torch-mlir
parent 46a25d7241
commit 28c7051ceb
@@ -1 +1 @@
-Subproject commit eae82ac259ee5a58bc4070a414bc53239e18bad0
+Subproject commit 5fcf907b34355980f77d7665a175b05fea7a6b7b
@@ -81,7 +81,7 @@ public:
       }
       newResultTypes.push_back(type);
     }
-    rewriter.updateRootInPlace(func, [&] {
+    rewriter.modifyOpInPlace(func, [&] {
       func.setType(FunctionType::get(
           getContext(), conversion.getConvertedTypes(), newResultTypes));
       // Clear out the type bounds, now that the type incorporates them.
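
The only functional change in this hunk is the rename of the rewriter's in-place-update entry point: upstream MLIR renamed updateRootInPlace to modifyOpInPlace, and torch-mlir follows suit here. A minimal sketch of the new spelling, assuming a standard PatternRewriter; the helper, op, and attribute below are placeholders for illustration, not code from this repository:

    #include "mlir/IR/PatternMatch.h"

    using namespace mlir;

    // modifyOpInPlace wraps the callback in startOpModification /
    // finalizeOpModification, so the rewrite driver sees the mutation.
    static void tagOp(PatternRewriter &rewriter, Operation *op) {
      rewriter.modifyOpInPlace(op, [&] {
        op->setAttr("tagged", rewriter.getUnitAttr()); // placeholder mutation
      });
    }
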
@@ -194,14 +194,12 @@ static LogicalResult adjustCallingConventions(func::FuncOp func,
   TypeConverter typeConverter;
   typeConverter.addConversion([](Type type) { return type; });
   typeConverter.addConversion(
-      [](Torch::TupleType type,
-         SmallVectorImpl<Type> &types) -> LogicalResult {
+      [](Torch::TupleType type, SmallVectorImpl<Type> &types) -> LogicalResult {
         llvm::append_range(types, type.getContainedTypes());
         return success();
       });
   typeConverter.addConversion(
-      [](Torch::NoneType type,
-         SmallVectorImpl<Type> &types) -> LogicalResult {
+      [](Torch::NoneType type, SmallVectorImpl<Type> &types) -> LogicalResult {
         return success();
       });
 
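
Beyond the clang-format reflow, this hunk exercises the one-to-N overload of TypeConverter::addConversion: a callback that appends to a SmallVectorImpl<Type> may expand one type into several (Torch::TupleType flattens to its contained types) or into none at all (the Torch::NoneType callback returns success without appending, erasing the type). A hedged, generic sketch of the same mechanism using only builtin MLIR types; the helper name is invented for illustration:

    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/Transforms/DialectConversion.h"
    #include "llvm/ADT/STLExtras.h"

    using namespace mlir;

    // One-to-N type conversion: tuples flatten to their element types; a
    // callback that appends nothing drops the type entirely.
    static void populateFlattening(TypeConverter &converter) {
      converter.addConversion([](Type type) { return type; }); // fallback
      converter.addConversion(
          [](TupleType type, SmallVectorImpl<Type> &types) -> LogicalResult {
            llvm::append_range(types, type.getTypes());
            return success();
          });
    }
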
@@ -175,7 +175,7 @@ public:
 
     // Replace return type of view-like ops with value-semantics type variant.
     for (Operation *viewLikeOp : ops.viewLikeOps) {
-      rewriter.updateRootInPlace(viewLikeOp, [&] {
+      rewriter.modifyOpInPlace(viewLikeOp, [&] {
         Value result = viewLikeOp->getResult(0);
         auto resultType = result.getType().dyn_cast<NonValueTensorType>();
         if (resultType)
@@ -337,7 +337,7 @@ public:
     // correctly copy them back to their mlir::func::ReturnOp's expected types.
     DenseMap<Value, Type> originalTypes;
     for (Operation *op : viewLikeOps) {
-      rewriter.updateRootInPlace(op, [&]() {
+      rewriter.modifyOpInPlace(op, [&]() {
         if (auto nonValueTensorType =
                 op->getResult(0).getType().dyn_cast<NonValueTensorType>()) {
           originalTypes[op->getResult(0)] = nonValueTensorType;
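
The two hunks above are the same mechanical rename applied where ops are retyped in place while a side table remembers the original result types for a later fixup. A hedged sketch of that record-then-retype pattern; the helper and its exact signature are invented for illustration:

    #include "mlir/IR/PatternMatch.h"
    #include "llvm/ADT/DenseMap.h"

    using namespace mlir;

    // Retype each op's single result in place, recording the original type
    // so callers' expected types can be restored afterwards.
    static void retypeResults(PatternRewriter &rewriter,
                              ArrayRef<Operation *> ops, Type newType,
                              DenseMap<Value, Type> &originalTypes) {
      for (Operation *op : ops) {
        rewriter.modifyOpInPlace(op, [&] {
          Value result = op->getResult(0);
          originalTypes[result] = result.getType();
          result.setType(newType);
        });
      }
    }
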
@@ -9,10 +9,10 @@
 
 #include "PassDetail.h"
 
+#include "ReifyAbstractInterpCalculationsUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
-#include "ReifyAbstractInterpCalculationsUtils.h"
 #include "llvm/ADT/StringExtras.h"
 
 using namespace mlir;
@@ -72,8 +72,8 @@ namespace {
 // immutable tensors.
 class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
 public:
-  ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context,
-                                            const std::optional<SymbolTable>& extraLibrary)
+  ConvertHasValueSemanticsOpsToValueTensors(
+      MLIRContext *context, const std::optional<SymbolTable> &extraLibrary)
       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {
     this->extraLibrary = extraLibrary;
   }
@@ -87,7 +87,7 @@ public:
       return rewriter.notifyMatchFailure(op, "does not have value semantics");
     }
 
-    rewriter.startRootUpdate(op);
+    rewriter.startOpModification(op);
     // Convert all operands.
     SmallVector<Value> newOperands;
     for (OpOperand &opOperand : op->getOpOperands()) {
@@ -105,7 +105,7 @@ public:
         auto listConstruct =
             opOperand.get().getDefiningOp<PrimListConstructOp>();
         if (!listConstruct) {
-          rewriter.cancelRootUpdate(op);
+          rewriter.cancelOpModification(op);
           return rewriter.notifyMatchFailure(
               op, "unimplemented: list of non vtensor type not constructed "
                   "from list construct");
@@ -120,7 +120,7 @@ public:
         if (!llvm::all_of(listConstruct.getElements(), [](Value val) {
               return val.getType().isa<NonValueTensorType, Torch::NoneType>();
             })) {
-          rewriter.cancelRootUpdate(op);
+          rewriter.cancelOpModification(op);
           return rewriter.notifyMatchFailure(
               op, "unimplemented: list containing optional type is not "
                   "handled.");
@@ -138,7 +138,7 @@ public:
 
         Type newListType = getContainerOrTensorTypeWithValueSemantics(listType);
         if (!newListType) {
-          rewriter.cancelRootUpdate(op);
+          rewriter.cancelOpModification(op);
           return rewriter.notifyMatchFailure(
               op, "Unable to convert list type to value semantics.");
         }
@@ -154,7 +154,7 @@ public:
         // from the non value tensor of the original optional value.
         auto derefine = opOperand.get().getDefiningOp<DerefineOp>();
         if (!derefine) {
-          rewriter.cancelRootUpdate(op);
+          rewriter.cancelOpModification(op);
           return rewriter.notifyMatchFailure(
               op, "unimplemented: optional of non vtensor type not from "
                   "derefine");
@@ -180,9 +180,10 @@ public:
           rewriter.create<CopyToNonValueTensorOp>(op->getLoc(), result);
       result.replaceAllUsesExcept(nonValueTensor, nonValueTensor);
     }
-    rewriter.finalizeRootUpdate(op);
+    rewriter.finalizeOpModification(op);
     return success();
   }
+
 private:
   std::optional<SymbolTable> extraLibrary;
 };
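
Taken together, the hunks in this pattern show all three phases of the renamed explicit API: startOpModification opens the in-place update, every early-exit failure path pairs cancelOpModification with notifyMatchFailure, and the success path ends in finalizeOpModification. A condensed, hedged sketch of that protocol; the legality check and attribute are placeholders:

    #include "mlir/IR/PatternMatch.h"

    using namespace mlir;

    // Every path out of a started modification must either cancel it
    // (match failure) or finalize it (success).
    static LogicalResult rewriteGuarded(PatternRewriter &rewriter,
                                        Operation *op) {
      rewriter.startOpModification(op);
      if (op->getNumResults() != 1) { // placeholder legality check
        rewriter.cancelOpModification(op);
        return rewriter.notifyMatchFailure(op, "expected a single result");
      }
      op->setAttr("rewritten", rewriter.getUnitAttr()); // placeholder mutation
      rewriter.finalizeOpModification(op);
      return success();
    }
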
@@ -290,9 +291,9 @@ public:
     Operation *newOp = rewriter.create(state);
     // Note: need to convert result to first input's dtype because mix precision
     // compute would result in different behaviors.
     // For example:
     // a = torch.randn(3, 3).half() # float16
     // b = torch.randn(3, 3) # float32
     // a += b # i.e. torch.ops.aten.add_(a, b), result is float16
     // c = a + b # i.e. torch.ops.aten.add(a, b), result is float32
     Value none = rewriter.create<ConstantNoneOp>(op->getLoc());
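
The comment block above records the PyTorch promotion asymmetry the conversion must preserve: an in-place aten.add_ keeps the first operand's dtype, while the out-of-place aten.add promotes to the wider type, which is why the result is cast back to the first input's dtype below. The same behavior can be checked directly against LibTorch (a hedged sketch, assuming a standard LibTorch install):

    #include <torch/torch.h>

    #include <iostream>

    int main() {
      auto a = torch::randn({3, 3}).to(torch::kHalf); // float16
      auto b = torch::randn({3, 3});                  // float32
      a.add_(b);      // in-place: a stays float16
      auto c = a + b; // out-of-place: promotes to float32
      std::cout << a.dtype() << " " << c.dtype() << "\n"; // Half Float
      return 0;
    }
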
@@ -300,7 +301,8 @@ public:
     auto aDtype = rewriter.create<PrimDtypeOp>(op->getLoc(), op->getOperand(0));
     auto toDtype = rewriter.create<AtenToDtypeOp>(
         op->getLoc(), newOp->getResult(0).getType(), newOp->getResult(0),
-        aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
+        aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+        /*memory_format=*/none);
     auto tensor = rewriter.create<CopyToValueTensorOp>(op->getLoc(), toDtype);
     createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
                                   op->getOperand(0));