mirror of https://github.com/llvm/torch-mlir
parent
46a25d7241
commit
28c7051ceb
|
@ -1 +1 @@
|
|||
Subproject commit eae82ac259ee5a58bc4070a414bc53239e18bad0
|
||||
Subproject commit 5fcf907b34355980f77d7665a175b05fea7a6b7b
|
|
@ -81,7 +81,7 @@ public:
|
|||
}
|
||||
newResultTypes.push_back(type);
|
||||
}
|
||||
rewriter.updateRootInPlace(func, [&] {
|
||||
rewriter.modifyOpInPlace(func, [&] {
|
||||
func.setType(FunctionType::get(
|
||||
getContext(), conversion.getConvertedTypes(), newResultTypes));
|
||||
// Clear out the type bounds, now that the type incorporates them.
|
||||
|
@ -194,14 +194,12 @@ static LogicalResult adjustCallingConventions(func::FuncOp func,
|
|||
TypeConverter typeConverter;
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
typeConverter.addConversion(
|
||||
[](Torch::TupleType type,
|
||||
SmallVectorImpl<Type> &types) -> LogicalResult {
|
||||
[](Torch::TupleType type, SmallVectorImpl<Type> &types) -> LogicalResult {
|
||||
llvm::append_range(types, type.getContainedTypes());
|
||||
return success();
|
||||
});
|
||||
typeConverter.addConversion(
|
||||
[](Torch::NoneType type,
|
||||
SmallVectorImpl<Type> &types) -> LogicalResult {
|
||||
[](Torch::NoneType type, SmallVectorImpl<Type> &types) -> LogicalResult {
|
||||
return success();
|
||||
});
|
||||
|
||||
|
|
|
@ -175,7 +175,7 @@ public:
|
|||
|
||||
// Replace return type of view-like ops with value-semantics type variant.
|
||||
for (Operation *viewLikeOp : ops.viewLikeOps) {
|
||||
rewriter.updateRootInPlace(viewLikeOp, [&] {
|
||||
rewriter.modifyOpInPlace(viewLikeOp, [&] {
|
||||
Value result = viewLikeOp->getResult(0);
|
||||
auto resultType = result.getType().dyn_cast<NonValueTensorType>();
|
||||
if (resultType)
|
||||
|
@ -337,7 +337,7 @@ public:
|
|||
// correctly copy them back to their mlir::func::ReturnOp's expected types.
|
||||
DenseMap<Value, Type> originalTypes;
|
||||
for (Operation *op : viewLikeOps) {
|
||||
rewriter.updateRootInPlace(op, [&]() {
|
||||
rewriter.modifyOpInPlace(op, [&]() {
|
||||
if (auto nonValueTensorType =
|
||||
op->getResult(0).getType().dyn_cast<NonValueTensorType>()) {
|
||||
originalTypes[op->getResult(0)] = nonValueTensorType;
|
||||
|
|
|
@ -9,10 +9,10 @@
|
|||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "ReifyAbstractInterpCalculationsUtils.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "ReifyAbstractInterpCalculationsUtils.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -72,8 +72,8 @@ namespace {
|
|||
// immutable tensors.
|
||||
class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
|
||||
public:
|
||||
ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context,
|
||||
const std::optional<SymbolTable>& extraLibrary)
|
||||
ConvertHasValueSemanticsOpsToValueTensors(
|
||||
MLIRContext *context, const std::optional<SymbolTable> &extraLibrary)
|
||||
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {
|
||||
this->extraLibrary = extraLibrary;
|
||||
}
|
||||
|
@ -87,7 +87,7 @@ public:
|
|||
return rewriter.notifyMatchFailure(op, "does not have value semantics");
|
||||
}
|
||||
|
||||
rewriter.startRootUpdate(op);
|
||||
rewriter.startOpModification(op);
|
||||
// Convert all operands.
|
||||
SmallVector<Value> newOperands;
|
||||
for (OpOperand &opOperand : op->getOpOperands()) {
|
||||
|
@ -105,7 +105,7 @@ public:
|
|||
auto listConstruct =
|
||||
opOperand.get().getDefiningOp<PrimListConstructOp>();
|
||||
if (!listConstruct) {
|
||||
rewriter.cancelRootUpdate(op);
|
||||
rewriter.cancelOpModification(op);
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: list of non vtensor type not constructed "
|
||||
"from list construct");
|
||||
|
@ -120,7 +120,7 @@ public:
|
|||
if (!llvm::all_of(listConstruct.getElements(), [](Value val) {
|
||||
return val.getType().isa<NonValueTensorType, Torch::NoneType>();
|
||||
})) {
|
||||
rewriter.cancelRootUpdate(op);
|
||||
rewriter.cancelOpModification(op);
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: list containing optional type is not "
|
||||
"handled.");
|
||||
|
@ -138,7 +138,7 @@ public:
|
|||
|
||||
Type newListType = getContainerOrTensorTypeWithValueSemantics(listType);
|
||||
if (!newListType) {
|
||||
rewriter.cancelRootUpdate(op);
|
||||
rewriter.cancelOpModification(op);
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unable to convert list type to value semantics.");
|
||||
}
|
||||
|
@ -154,7 +154,7 @@ public:
|
|||
// from the non value tensor of the original optional value.
|
||||
auto derefine = opOperand.get().getDefiningOp<DerefineOp>();
|
||||
if (!derefine) {
|
||||
rewriter.cancelRootUpdate(op);
|
||||
rewriter.cancelOpModification(op);
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: optional of non vtensor type not from "
|
||||
"derefine");
|
||||
|
@ -180,9 +180,10 @@ public:
|
|||
rewriter.create<CopyToNonValueTensorOp>(op->getLoc(), result);
|
||||
result.replaceAllUsesExcept(nonValueTensor, nonValueTensor);
|
||||
}
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
rewriter.finalizeOpModification(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
std::optional<SymbolTable> extraLibrary;
|
||||
};
|
||||
|
@ -290,9 +291,9 @@ public:
|
|||
Operation *newOp = rewriter.create(state);
|
||||
// Note: need to convert result to first input's dtype because mix precision
|
||||
// compute would result in different behaviors.
|
||||
// For example:
|
||||
// a = torch.randn(3, 3).half() # float16
|
||||
// b = torch.randn(3, 3) # float32
|
||||
// For example:
|
||||
// a = torch.randn(3, 3).half() # float16
|
||||
// b = torch.randn(3, 3) # float32
|
||||
// a += b # i.e. torch.ops.aten.add_(a, b), result is float16
|
||||
// c = a + b # i.e. torch.ops.aten.add(a, b), result is float32
|
||||
Value none = rewriter.create<ConstantNoneOp>(op->getLoc());
|
||||
|
@ -300,7 +301,8 @@ public:
|
|||
auto aDtype = rewriter.create<PrimDtypeOp>(op->getLoc(), op->getOperand(0));
|
||||
auto toDtype = rewriter.create<AtenToDtypeOp>(
|
||||
op->getLoc(), newOp->getResult(0).getType(), newOp->getResult(0),
|
||||
aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
|
||||
aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
|
||||
/*memory_format=*/none);
|
||||
auto tensor = rewriter.create<CopyToValueTensorOp>(op->getLoc(), toDtype);
|
||||
createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
|
||||
op->getOperand(0));
|
||||
|
|
Loading…
Reference in New Issue