mirror of https://github.com/llvm/torch-mlir
Make MaximizeValueSemantics a bit smarter.

This adds a pattern to MaximizeValueSemantics that does a simple abstract
interpretation within a block, handling simple cases of
`torch.overwrite.tensor` -- enough to remove all the unnecessary uses of
non-value tensors in ResNet right now. Before/after IR:
[gist](https://gist.github.com/silvasean/a3e1ef625b19dfc63579f73cd3b543b6)

Also:
- Split `torch.copy.tensor` into `torch.copy.to_tensor` and
  `torch.copy.to_vtensor`, which convert between value and non-value
  semantic tensors. This is a much cleaner factorization, as the two ops
  have very separate use cases and properties (e.g. different side effects).
- Remove the various canonicalization patterns they had, which were
  confusing because they resulted in limited forms of maximizing value
  semantics throughout the pipeline. We should structure our compilation
  pipeline such that only MaximizeValueSemantics maximizes value semantics.
- Adjust the pass pipeline to only run MaximizeValueSemantics once.
- Make OverwriteTensorOp's `$value` always be a value tensor and
  `$overwritten` a non-value tensor.

parent 6dddb4d4fe
commit 79aade33da
@@ -905,11 +905,11 @@ def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [
     This op *cannot* be used to add/remove value semantics from a tensor.

     For converting between the value-semantic and non-value-semantic domains,
-    use `torch.copy.tensor`. The two ops are kept separate to prevent
-    canonicalizations from accidentally dropping static information. In
-    most cases, after running the `torch-refine-types` pass, this op becomes
-    a no-op (the pass will incorporate the static information into other ops
-    that allow type refinement).
+    use `torch.copy.to_tensor` and `torch.copy.to_vtensor`. This op is kept
+    separate to prevent canonicalizations from accidentally dropping static
+    information. In most cases, after running the `torch-refine-types` pass,
+    this op becomes a no-op (the pass will incorporate the static information
+    into other ops that allow type refinement).
   }];
   let arguments = (ins
     AnyTorchTensorType:$operand
@@ -922,34 +922,66 @@ def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [
   }];
 }

-def Torch_CopyTensorOp : Torch_Op<"copy.tensor", [
+def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [
   DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  DeclareOpInterfaceMethods<InferTypeOpInterface>,
+  TypesMatchWith<"operand is corresponding !torch.vtensor",
+                 "result", "operand",
+                 "$_self.cast<NonValueTensorType>().getWithValueSemantics()">,
 ]> {
-  let summary = "Makes a copy of a tensor.";
+  let summary = "Create a !torch.tensor with the same contents as the operand";
   let description = [{
-    Changes to the original tensor will not be reflected in the copy.
-
-    This op can be used to interconvert between value-semantic and
-    non-value-semantic tensors. However, this op *does not* allow
-    adding/removing static information about sizes/dtype. For that, use
-    `torch.tensor_static_info_cast`.
+    This op is used to convert from !torch.vtensor to !torch.tensor.
+    It does so by allocating a new !torch.tensor and filling it with
+    the contents of the operand.
+
+    However, this op *does not* allow adding/removing static information about
+    sizes/dtype. For that, use `torch.tensor_static_info_cast`.

     This op does not have the AllowsTypeRefinement trait because the operand
     and result types are coupled. Only places that know how to simultaneously
     update both types should be changing the type of this op.
   }];
   let arguments = (ins
-    AnyTorchTensorType:$operand
+    Torch_ValueTensorType:$operand
   );
   let results = (outs
-    AnyTorchTensorType:$result
+    Torch_NonValueTensorType:$result
   );
   let assemblyFormat = [{
-    $operand attr-dict `:` type($operand) `->` type($result)
+    $operand attr-dict `:` type($result)
+  }];
+  let verifier = "return ::verify(*this);";
+}
+
+def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  DeclareOpInterfaceMethods<InferTypeOpInterface>,
+  TypesMatchWith<"operand is corresponding !torch.tensor",
+                 "result", "operand",
+                 "$_self.cast<ValueTensorType>().getWithoutValueSemantics()">,
+]> {
+  let summary = "Create a !torch.vtensor with the same contents as the operand";
+  let description = [{
+    This op is used to convert from !torch.tensor to !torch.vtensor.
+
+    However, this op *does not* allow adding/removing static information about
+    sizes/dtype. For that, use `torch.tensor_static_info_cast`.
+
+    This op does not have the AllowsTypeRefinement trait because the operand
+    and result types are coupled. Only places that know how to simultaneously
+    update both types should be changing the type of this op.
+  }];
+  let arguments = (ins
+    Torch_NonValueTensorType:$operand
+  );
+  let results = (outs
+    Torch_ValueTensorType:$result
+  );
+  let assemblyFormat = [{
+    $operand attr-dict `:` type($result)
   }];
   let verifier = "return ::verify(*this);";
-  let hasFolder = 1;
-  let hasCanonicalizer = 1;
 }

 def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [

@@ -961,7 +993,7 @@ def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
     `value`.

     Immediately after this op has completed, indexing `overwritten` will result
-    in identical values as indexing into `tensor`. Of course, later ops
+    in identical values as indexing into `value`. Of course, later ops
     might mutate `overwritten`, so this relationship need not hold for the
     entire program.

@@ -969,8 +1001,8 @@ def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
     shapes or dtypes.
   }];
   let arguments = (ins
-    AnyTorchTensorType:$value,
-    AnyTorchTensorType:$overwritten
+    Torch_ValueTensorType:$value,
+    Torch_NonValueTensorType:$overwritten
   );
   let results = (outs
   );
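To make the distinction between the two new ops concrete: a `!torch.vtensor` behaves like an immutable value, while a `!torch.tensor` is a mutable buffer that `copy.to_tensor` allocates and `copy.to_vtensor` snapshots. The following Python model is purely illustrative (the class and function names are invented for this sketch, not npcomp API):

```python
class ValueTensor:
    """Models !torch.vtensor: immutable value semantics."""
    def __init__(self, data):
        self.data = tuple(data)  # frozen snapshot

class NonValueTensor:
    """Models !torch.tensor: a mutable buffer."""
    def __init__(self, data):
        self.data = list(data)

def copy_to_tensor(v):
    # torch.copy.to_tensor: allocate a fresh mutable tensor holding
    # the contents of the value tensor (hence the Allocate effect).
    return NonValueTensor(v.data)

def copy_to_vtensor(t):
    # torch.copy.to_vtensor: snapshot the current contents
    # (hence the Read effect).
    return ValueTensor(t.data)

def overwrite_tensor(value, overwritten):
    # torch.overwrite.tensor: replace the buffer's contents in place.
    overwritten.data[:] = value.data

# A mutation through the non-value tensor is visible to later
# snapshots, but not to snapshots taken earlier.
t = copy_to_tensor(ValueTensor([1, 2]))
before = copy_to_vtensor(t)
overwrite_tensor(ValueTensor([3, 4]), t)
after = copy_to_vtensor(t)
```

This is why the two ops get different side effects in the ODS above: the conversion into the mutable domain allocates, and the conversion out of it reads.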
@@ -33,11 +33,17 @@ Value mlir::NPCOMP::Torch::copyTensorToType(OpBuilder &builder, Location loc,
     tensor = builder.create<TensorStaticInfoCastOp>(
         loc, originalType.getWithSizesAndDtypeFrom(newType), tensor);
   }
-  // If both the original and new types already have value semantics, a copy is
-  // pointless.
-  if (originalType.isa<ValueTensorType>() && newType.isa<ValueTensorType>())
-    return tensor;
-  return builder.create<CopyTensorOp>(loc, newType, tensor);
+  // Unless the original and new types are both value tensors, we end
+  // up creating one op that converts between the value and non-value tensor
+  // domains. If the original and new types are both non-value tensors,
+  // then we do the copy by going to a value tensor and back.
+  if (tensor.getType().isa<NonValueTensorType>())
+    tensor = builder.create<CopyToValueTensorOp>(loc, tensor);
+  if (newType.isa<NonValueTensorType>())
+    tensor = builder.create<CopyToNonValueTensorOp>(loc, tensor);
+
+  return tensor;
 }

 //===----------------------------------------------------------------------===//
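The new `copyTensorToType` logic can be paraphrased as: hop into the value-tensor domain if the source is a non-value tensor, then hop back out if the destination type is a non-value tensor, so a tensor-to-tensor copy round-trips through a value tensor. A Python paraphrase (the function is hypothetical, for illustration only):

```python
def copy_ops_for(src_is_value: bool, dest_is_value: bool):
    """Which copy ops copyTensorToType emits after the static-info cast."""
    ops = []
    # Hop into the value-tensor domain if the source is a non-value tensor.
    if not src_is_value:
        ops.append("torch.copy.to_vtensor")
    # Hop back out if the destination type is a non-value tensor.
    if not dest_is_value:
        ops.append("torch.copy.to_tensor")
    return ops
```

So a vtensor-to-vtensor conversion needs no copy op at all, and the single-op cases each get exactly the one conversion op they need.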
@@ -504,10 +510,10 @@ bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs,
 }

 //===----------------------------------------------------------------------===//
-// CopyTensorOp
+// CopyToNonValueTensorOp
 //===----------------------------------------------------------------------===//

-static LogicalResult verify(CopyTensorOp op) {
+static LogicalResult verify(CopyToNonValueTensorOp op) {
   auto resultType = op.getResult().getType().cast<BaseTensorType>();
   auto operandType = op.getOperand().getType().cast<BaseTensorType>();
   if (!resultType.hasSameSizesAndDtype(operandType)) {

@@ -517,50 +523,48 @@ static LogicalResult verify(CopyTensorOp op) {
   return success();
 }

-OpFoldResult CopyTensorOp::fold(ArrayRef<Attribute> operands) {
-  // A copy between value semantic tensors is a no-op.
-  if (getType().isa<ValueTensorType>() &&
-      getOperand().getType().isa<ValueTensorType>()) {
-    return getOperand();
-  }
-  return nullptr;
-}
-
-void CopyTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
-                                               MLIRContext *context) {
-  // y = torch.copy.tensor(torch.copy.tensor(x)) -> x
-  // Only safe when `y` and `x` have value semantics, and
-  // all users of the intermediate tensor op treat the tensor as if it
-  // had value semantics (even if it is a NonValueTensorType).
-  patterns.add(+[](CopyTensorOp op, PatternRewriter &rewriter) {
-    auto otherCopy = op.getOperand().getDefiningOp<CopyTensorOp>();
-    if (!otherCopy)
-      return failure();
-    if (!otherCopy.getOperand().getType().isa<ValueTensorType>() ||
-        !op.getResult().getType().isa<ValueTensorType>())
-      return failure();
-    // TODO: Use a proper interface here.
-    // MemoryEffectOpInterface is not powerful enough because it cannot model
-    // aliasing. We don't just care that the user is readonly -- we care also
-    // whether it creates an alias. Basically, we care if the user "treats the
-    // tensor as if it has value semantics".
-    // For now, just hardcode the important case of multiple CopyTensorOp users.
-    if (llvm::all_of(op.getOperand().getUsers(),
-                     [](Operation *op) { return isa<CopyTensorOp>(op); })) {
-      rewriter.replaceOp(op, {otherCopy.getOperand()});
-      return success();
-    }
-    return failure();
-  });
-}
-
-void CopyTensorOp::getEffects(
-    SmallVectorImpl<SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>>
-        &effects) {
-  if (getResult().getType().isa<NonValueTensorType>())
-    effects.emplace_back(MemoryEffects::Allocate::get(), getResult());
-  if (getOperand().getType().isa<NonValueTensorType>())
-    effects.emplace_back(MemoryEffects::Read::get(), getOperand());
-}
+LogicalResult CopyToNonValueTensorOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto resultType = operands[0].getType().cast<ValueTensorType>();
+  inferredReturnTypes.push_back(resultType.getWithoutValueSemantics());
+  return success();
+}
+
+void CopyToNonValueTensorOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Allocate::get(), getResult());
+}
+
+//===----------------------------------------------------------------------===//
+// CopyToValueTensorOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(CopyToValueTensorOp op) {
+  auto resultType = op.getResult().getType().cast<BaseTensorType>();
+  auto operandType = op.getOperand().getType().cast<BaseTensorType>();
+  if (!resultType.hasSameSizesAndDtype(operandType)) {
+    return op.emitError()
+           << "operand and result must have same sizes and dtype";
+  }
+  return success();
+}
+
+LogicalResult CopyToValueTensorOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto resultType = operands[0].getType().cast<NonValueTensorType>();
+  inferredReturnTypes.push_back(resultType.getWithValueSemantics());
+  return success();
+}
+
+void CopyToValueTensorOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), getOperand());
+}

 //===----------------------------------------------------------------------===//
@@ -216,7 +216,7 @@ static LogicalResult adjustCallingConventions(FuncOp func,
   target.addDynamicallyLegalOp<ReturnOp>([&](ReturnOp op) {
     return !opsInOriginalProgram.contains(op.getOperation());
   });
-  target.addLegalOp<CopyTensorOp>();
+  target.addLegalOp<CopyToNonValueTensorOp, CopyToValueTensorOp>();
   target.addLegalOp<TensorStaticInfoCastOp>();
   target.addLegalOp<ConstantNoneOp>();
   // We don't know how to rewrite it, so mark it as illegal.
@@ -19,6 +19,65 @@ using namespace mlir;
 using namespace mlir::NPCOMP;
 using namespace mlir::NPCOMP::Torch;

+class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
+    : public OpRewritePattern<CopyToNonValueTensorOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Operation *> users;
+    // See if our limited form of analysis is even applicable.
+    for (Operation *user : copy.getResult().getUsers()) {
+      // We can only analyze within a single basic block.
+      if (user->getBlock() != copy->getBlock())
+        return failure();
+      // We can only analyze these ops.
+      if (!isa<CopyToValueTensorOp, OverwriteTensorOp>(user))
+        return failure();
+      users.push_back(user);
+    }
+    // Sort by order in the block, so we can abstractly interpret the ops.
+    llvm::sort(users, [](Operation *lhs, Operation *rhs) {
+      return lhs->isBeforeInBlock(rhs);
+    });
+    // Do an abstract interpretation within the block.
+    // We track the current value tensor that holds the same contents as the
+    // non-value tensor at each program point as we walk forward.
+    Value currentlyHeldValueTensor = copy.getOperand();
+    for (Operation *user : users) {
+      if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
+        rewriter.replaceOp(copyToValueTensor, {currentlyHeldValueTensor});
+      } else if (auto overwriteTensor = dyn_cast<OverwriteTensorOp>(user)) {
+        currentlyHeldValueTensor = overwriteTensor.value();
+        rewriter.eraseOp(overwriteTensor);
+      } else {
+        llvm_unreachable("only those ops supported!");
+      }
+    }
+    rewriter.eraseOp(copy);
+    return success();
+  }
+};
+
+class RewriteNonValueTensorNeverMutatedOrAliased
+    : public OpRewritePattern<CopyToNonValueTensorOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Operation *> users;
+    // See if our limited form of analysis is even applicable.
+    for (Operation *user : copy.getResult().getUsers()) {
+      if (!isa<CopyToValueTensorOp>(user))
+        return failure();
+      users.push_back(user);
+    }
+    for (Operation *user : users)
+      rewriter.replaceOp(user, copy.getOperand());
+    return success();
+  }
+};
+
 namespace {

 class MaximizeValueSemanticsPass

@@ -28,8 +87,8 @@ class MaximizeValueSemanticsPass
     auto func = getOperation();

     RewritePatternSet patterns(context);
-    CopyTensorOp::getCanonicalizationPatterns(patterns, context);
-    TensorStaticInfoCastOp::getCanonicalizationPatterns(patterns, context);
+    patterns.insert<AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock,
+                    RewriteNonValueTensorNeverMutatedOrAliased>(context);
    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
   }
 };
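The core of the new pattern is a single forward walk: track which value tensor currently holds the contents of the non-value tensor, forward each `copy.to_vtensor` to it, and let each `overwrite.tensor` update it. A Python sketch of that walk over a toy op list (the representation is invented for illustration; real users are MLIR `Operation`s):

```python
def interpret_copy_users(initial_value, users):
    """users: (op_name, operand) pairs in block order, all users of one
    copy.to_tensor result. Returns the value tensor that each
    copy.to_vtensor user resolves to."""
    currently_held = initial_value
    resolved = []
    for op, operand in users:
        if op == "copy.to_vtensor":
            # A read of the non-value tensor sees the currently held value.
            resolved.append(currently_held)
        elif op == "overwrite.tensor":
            # An overwrite changes which value tensor the contents match.
            currently_held = operand
        else:
            raise ValueError("only those ops supported!")
    return resolved

# Mirrors @multiple_mutations_in_a_block from the tests: the four reads
# resolve to %arg0, %arg1, %arg1, %arg2.
reads = interpret_copy_users("%arg0", [
    ("copy.to_vtensor", None),
    ("overwrite.tensor", "%arg1"),
    ("copy.to_vtensor", None),
    ("copy.to_vtensor", None),
    ("overwrite.tensor", "%arg2"),
    ("copy.to_vtensor", None),
])
```

The single-block restriction in the real pattern exists precisely so this linear walk is sound: with control flow, "the currently held value" would need a join over predecessors.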
@@ -120,14 +120,12 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
   // Lowering to ranked !torch.vtensors of known dtype.
   //===--------------------------------------------------------------------===//

-  // Convert the bulk of non-ABI-visible arrays to tensors.
-  pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());
   // Do shape and dtype refinement.
   pm.addNestedPass<FuncOp>(Torch::createRefineTypesPass());
   // Propagate to ABI return types the shape/dtype information discovered by
   // the previous pass. Doing this is ABI-compatible for our backends.
   pm.addPass(Torch::createRefinePublicReturnPass());
-  // Clean up a few stray conversion remnants.
+  // Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's.
   pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());

   if (options.optimize) {
@@ -37,8 +37,8 @@ public:
           opOperand.get().getType().dyn_cast<NonValueTensorType>();
       if (!tensorType)
         continue;
-      opOperand.set(rewriter.create<CopyTensorOp>(
-          op->getLoc(), tensorType.getWithValueSemantics(), opOperand.get()));
+      opOperand.set(rewriter.create<CopyToValueTensorOp>(op->getLoc(),
+                                                         opOperand.get()));
     }
     // Convert all results.
     rewriter.setInsertionPointAfter(op);

@@ -46,10 +46,10 @@ public:
       auto tensorType = result.getType().dyn_cast<NonValueTensorType>();
       if (!tensorType)
         continue;
-      auto createArray = rewriter.create<CopyTensorOp>(
-          op->getLoc(), result.getType(), result);
-      result.replaceAllUsesExcept(createArray, createArray);
       result.setType(tensorType.getWithValueSemantics());
+      auto nonValueTensor =
+          rewriter.create<CopyToNonValueTensorOp>(op->getLoc(), result);
+      result.replaceAllUsesExcept(nonValueTensor, nonValueTensor);
     }
   });
   return success();
@@ -85,12 +85,8 @@ public:
            "Torch JIT operators shouldn't have regions or successors");

     Operation *newOp = rewriter.createOperation(state);
-    auto tensor = rewriter.create<CopyTensorOp>(op->getLoc(),
-                                                newOp->getResult(0)
-                                                    .getType()
-                                                    .cast<NonValueTensorType>()
-                                                    .getWithValueSemantics(),
-                                                newOp->getResult(0));
+    auto tensor =
+        rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
     rewriter.create<OverwriteTensorOp>(op->getLoc(), tensor, op->getOperand(0));
     rewriter.replaceOp(op, op->getOperand(0));
@@ -141,8 +141,8 @@ public:
   ChangeResult
   visitOperation(Operation *op,
                  ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
-    if (isa<TensorStaticInfoCastOp, CopyTensorOp, AtenTanhOp, AtenBatchNormOp,
-            AtenReluOp>(op)) {
+    if (isa<TensorStaticInfoCastOp, CopyToValueTensorOp, CopyToNonValueTensorOp,
+            AtenTanhOp, AtenBatchNormOp, AtenReluOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
     if (isa<AtenMmOp>(op)) {

@@ -303,12 +303,13 @@ static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
 // which allows arbitrary refinements. But some other cases are safe too,
 // such as when an op has two types that are coupled, but we know that our
 // analysis and updating logic will correctly maintain the invariants of the op.
-// The `torch.copy.tensor` is an example of the latter case, since its
-// operand and result types must have the same shape and dtype -- we know
-// that our transfer functions and updating logic will do the right thing
-// for that op.
+// The `torch.copy.to_tensor` / `torch.copy.to_vtensor` ops are examples of the
+// latter case, since their operand and result types must have the same shape
+// and dtype -- we know that our transfer functions and updating logic will do
+// the right thing for those ops.
 static bool allowsTypeRefinementOrWillBeOtherwiseSafelyRefined(Operation *op) {
-  return allowsTypeRefinement(op) || isa<CopyTensorOp>(op);
+  return allowsTypeRefinement(op) ||
+         isa<CopyToNonValueTensorOp, CopyToValueTensorOp>(op);
 }

 void optimize(FuncOp func, TypeAnalyzer &analyzer) {
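In RefineTypes, the copy ops are treated as type-information-preserving: the analysis simply joins the operand's knowledge into the result's lattice element. A loose Python model of that transfer rule (this is not the actual ValueKnowledge lattice; knowledge here is just a (sizes, dtype) pair with None meaning unknown):

```python
def join(lhs, rhs):
    """Combine two pieces of knowledge, component-wise. A known fact
    wins over unknown; conflicting facts degrade back to unknown."""
    out = []
    for a, b in zip(lhs, rhs):
        if a is None:
            out.append(b)
        elif b is None or a == b:
            out.append(a)
        else:
            out.append(None)  # conflicting facts: give up on this component
    return tuple(out)

def visit_type_preserving_op(result_knowledge, operand_knowledge):
    # For type-preserving ops (tensor_static_info_cast, copy.to_vtensor,
    # copy.to_tensor, tanh, relu, ...), the result's knowledge is the
    # join of what we already knew with what the operand provides.
    return join(result_knowledge, operand_knowledge)

# Learning the dtype from the operand while the sizes already agree:
refined = visit_type_preserving_op(((2, 3), None), ((2, 3), "f32"))
```

This is why the copy ops can safely be listed in `allowsTypeRefinementOrWillBeOtherwiseSafelyRefined`: refining one side's type is always matched by refining the other.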
@@ -3,7 +3,7 @@
 // CHECK-LABEL: func @basic(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
 // CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
-// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.tensor %[[ERASED]] : !torch.vtensor -> !torch.tensor
+// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.to_tensor %[[ERASED]] : !torch.tensor
 // CHECK: return %[[NONVAL_TENSOR]] : !torch.tensor
 func @basic(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor {
   return %arg0 : !torch.tensor

@@ -19,9 +19,9 @@ func @no_type_bound(%arg0: !torch.tensor) -> !torch.tensor {
 // CHECK-LABEL: func @call(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
 // CHECK: %[[ARG_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
-// CHECK: %[[ARG_NONVAL:.*]] = torch.copy.tensor %[[ARG_ERASED]] : !torch.vtensor -> !torch.tensor
+// CHECK: %[[ARG_NONVAL:.*]] = torch.copy.to_tensor %[[ARG_ERASED]] : !torch.tensor
 // CHECK: %[[INFO_ADDED:.*]] = torch.tensor_static_info_cast %[[ARG_NONVAL]] : !torch.tensor to !torch.tensor<[2,3,?],f32>
-// CHECK: %[[CALL_ARG:.*]] = torch.copy.tensor %[[INFO_ADDED]] : !torch.tensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
+// CHECK: %[[CALL_ARG:.*]] = torch.copy.to_vtensor %[[INFO_ADDED]] : !torch.vtensor<[2,3,?],f32>
 // CHECK: %[[CALL_RES:.*]] = call @call(%[[CALL_ARG]]) : (!torch.vtensor<[2,3,?],f32>) -> !torch.tensor
 // CHECK: return %[[ARG_NONVAL]] : !torch.tensor
 func @call(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor {
@@ -78,23 +78,6 @@ func @torch.aten.len.t$of_build_list(%arg0: !torch.int) -> !torch.int {
   return %1 : !torch.int
 }

-// CHECK-LABEL: func @torch.copy.tensor$value_copy_is_noop(
-// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
-// CHECK: return %[[ARG]] : !torch.vtensor
-func @torch.copy.tensor$value_copy_is_noop(%arg0: !torch.vtensor) -> !torch.vtensor {
-  %0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.vtensor
-  return %0 : !torch.vtensor
-}
-
-// CHECK-LABEL: func @torch.copy.tensor$unnecessary_intermediate_nonval_tensor(
-// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
-// CHECK: return %[[ARG]] : !torch.vtensor
-func @torch.copy.tensor$unnecessary_intermediate_nonval_tensor(%arg0: !torch.vtensor) -> !torch.vtensor {
-  %0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor
-  %1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
-  return %1 : !torch.vtensor
-}
-
 // CHECK-LABEL: func @torch.aten.__getitem__.t(
 // CHECK: %[[C5:.*]] = torch.constant.int 5
 // CHECK: return %[[C5]] : !torch.int

@@ -179,13 +162,3 @@ func @torch.prim.If$erase_dead_branch(%arg0: !torch.int) -> !torch.int {
   }
   return %0 : !torch.int
 }
-
-// CHECK-LABEL: func @torch.copy.tensor$untouched_nonval(
-// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
-// CHECK-NEXT: return %[[ARG]], %[[ARG]] : !torch.vtensor, !torch.vtensor
-func @torch.copy.tensor$untouched_nonval(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
-  %0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor
-  %1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
-  %2 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
-  return %1, %2 : !torch.vtensor, !torch.vtensor
-}
@@ -1,22 +1,72 @@
-// RUN: npcomp-opt -split-input-file %s -torch-maximize-value-semantics | FileCheck %s
+// RUN: npcomp-opt -split-input-file -allow-unregistered-dialect %s -torch-maximize-value-semantics | FileCheck %s
 
-// Basic case that can be resolved with local reasoning.
-// This pass will eventually need to learn about aliasing relationships.
-//
-// This is taken from a test case from an e2e spike, and isn't intended to be
-// particularly minimal or specifically test one thing, since the pass is
-// currently just a handful of canonicalization patterns that are already
-// tested elsewhere.
-
-// CHECK-LABEL: func @local(
-// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
-// CHECK: %[[RET:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
-// CHECK: return %[[RET]] : !torch.vtensor<[2,3,?],f32>
-func @local(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
-  %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor<[2,3,?],f32>
-  %1 = torch.aten.tanh %0 : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
-  %2 = torch.copy.tensor %1 : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
-  %3 = torch.tensor_static_info_cast %2 : !torch.tensor<[2,3,?],f32> to !torch.tensor
-  %4 = torch.copy.tensor %2 : !torch.tensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
-  return %4 : !torch.vtensor<[2,3,?],f32>
+// CHECK-LABEL: func @torch.copy.tensor$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
+// CHECK: return %[[ARG0]], %[[ARG0]] : !torch.vtensor, !torch.vtensor
+func @torch.copy.tensor$basic(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
+  %0 = torch.copy.to_tensor %arg0 : !torch.tensor
+  %1 = torch.copy.to_vtensor %0 : !torch.vtensor
+  %2 = torch.copy.to_vtensor %0 : !torch.vtensor
+  return %1, %2 : !torch.vtensor, !torch.vtensor
+}
+
+// CHECK-LABEL: func @one_mutation_in_a_block(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
+// CHECK: return %[[ARG0]], %[[ARG1]] : !torch.vtensor, !torch.vtensor
+func @one_mutation_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
+  %0 = torch.copy.to_tensor %arg0 : !torch.tensor
+  %equal_to_arg0 = torch.copy.to_vtensor %0 : !torch.vtensor
+  torch.overwrite.tensor %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
+  %equal_to_arg1 = torch.copy.to_vtensor %0 : !torch.vtensor
+  return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
+}
+
+// CHECK-LABEL: func @multiple_mutations_in_a_block(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, %[[ARG1:.*]]: !torch.vtensor,
+// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor) {
+// CHECK: return %[[ARG0]], %[[ARG1]], %[[ARG1]], %[[ARG2]] : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
+func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %arg2: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor) {
+  // The mutable tensor we are overwriting.
+  %tensor = torch.copy.to_tensor %arg0 : !torch.tensor
+
+  // The original value.
+  %equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
+
+  // Overwrite with %arg1
+  torch.overwrite.tensor %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
+  %equal_to_arg1 = torch.copy.to_vtensor %tensor : !torch.vtensor
+  %equal_to_arg1_again = torch.copy.to_vtensor %tensor : !torch.vtensor
+
+  // Overwrite with %arg2
+  torch.overwrite.tensor %arg2 overwrites %tensor : !torch.vtensor, !torch.tensor
+  %equal_to_arg2 = torch.copy.to_vtensor %tensor : !torch.vtensor
+
+  return %equal_to_arg0, %equal_to_arg1, %equal_to_arg1_again, %equal_to_arg2 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
+}
+
+// CHECK-LABEL: func @unmodeled_mutation(
+// CHECK: torch.overwrite.tensor
+func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
+  %0 = torch.copy.to_tensor %arg0 : !torch.tensor
+  torch.overwrite.tensor %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
+  "some.op"(%0) : (!torch.tensor) -> ()
+  %result = torch.copy.to_vtensor %0 : !torch.vtensor
+  return %result : !torch.vtensor
+}
+
+// We don't yet handle nontrivial cases involving control flow.
+// CHECK-LABEL: func @unimplemented_control_flow(
+// CHECK: torch.copy.to_vtensor
+func @unimplemented_control_flow(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %cond: !torch.bool) -> (!torch.vtensor, !torch.vtensor) {
+  %tensor = torch.copy.to_tensor %arg0 : !torch.tensor
+  %equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
+  torch.prim.If %cond -> () {
+    torch.overwrite.tensor %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
+    torch.prim.If.yield
+  } else {
+    torch.prim.If.yield
+  }
+  %equal_to_arg1 = torch.copy.to_vtensor %tensor : !torch.vtensor
+  return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
 }
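The maximize-value-semantics tests above all follow one rule: within a straight-line block, each `torch.copy.to_vtensor` of a tracked non-value tensor resolves to the value most recently written to it (the `torch.copy.to_tensor` operand, or the `$value` of the latest `torch.overwrite.tensor`), and tracking is abandoned when an unmodeled op touches the tensor. A rough Python sketch of that per-block abstract interpretation follows; the tiny `Op` representation and op-name strings are illustrative guesses, not the actual C++ pattern:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Op:
    """Toy stand-in for an MLIR operation: a name, operand names, one result."""
    name: str
    operands: List[str] = field(default_factory=list)
    result: Optional[str] = None


def maximize_value_semantics(block: List[Op]) -> Dict[str, str]:
    """Abstractly interpret one block, mapping each copy.to_vtensor result
    to the value tensor the mutable tensor held at that point."""
    current: Dict[str, str] = {}       # mutable tensor -> value it currently holds
    replacements: Dict[str, str] = {}  # to_vtensor result -> resolved value
    for op in block:
        if op.name == "copy.to_tensor":
            # A new mutable tensor, initialized from a value tensor.
            current[op.result] = op.operands[0]
        elif op.name == "overwrite.tensor":
            value, overwritten = op.operands
            if overwritten in current:
                current[overwritten] = value
        elif op.name == "copy.to_vtensor":
            src = op.operands[0]
            if src in current:
                replacements[op.result] = current[src]
        else:
            # Unmodeled op (e.g. "some.op") touching a tracked tensor:
            # conservatively stop tracking it.
            for operand in op.operands:
                current.pop(operand, None)
    return replacements
```

Running this over an op list mirroring `@multiple_mutations_in_a_block` resolves the successive `copy.to_vtensor` results to `arg0`, `arg1`, `arg1`, `arg2`, matching the CHECK lines, while a `"some.op"` user blocks the rewrite as in `@unmodeled_mutation`.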
@@ -2,9 +2,9 @@
 
 // CHECK-LABEL: func @convert_to_value_semantic_tensors(
 // CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
-// CHECK: %[[OPERAND_TENSOR:.*]] = torch.copy.tensor %[[ARG]] : !torch.tensor<[],f32> -> !torch.vtensor<[],f32>
+// CHECK: %[[OPERAND_TENSOR:.*]] = torch.copy.to_vtensor %[[ARG]] : !torch.vtensor<[],f32>
 // CHECK: %[[RESULT_TENSOR:.*]] = torch.aten.tanh %[[OPERAND_TENSOR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
-// CHECK: %[[RET:.*]] = torch.copy.tensor %[[RESULT_TENSOR]] : !torch.vtensor<[],f32> -> !torch.tensor<[],f32>
+// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[RESULT_TENSOR]] : !torch.tensor<[],f32>
 // CHECK: return %[[RET]] : !torch.tensor<[],f32>
 func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
   %0 = torch.aten.tanh %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32>
@@ -16,14 +16,14 @@ func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !torch.
 // CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[2,2],f32>,
 // CHECK-SAME: %[[ARG1:.*]]: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
 // CHECK: %[[C1:.*]] = torch.constant.int 1
-// CHECK: %[[TENSOR0:.*]] = torch.copy.tensor %[[ARG0]] : !torch.tensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
-// CHECK: %[[TENSOR1:.*]] = torch.copy.tensor %[[ARG1]] : !torch.tensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
+// CHECK: %[[TENSOR0:.*]] = torch.copy.to_vtensor %[[ARG0]] : !torch.vtensor<[2,2],f32>
+// CHECK: %[[TENSOR1:.*]] = torch.copy.to_vtensor %[[ARG1]] : !torch.vtensor<[2,2],f32>
 // CHECK: %[[TENSOR_RESULT:.*]] = torch.aten.add.Tensor %[[TENSOR0]], %[[TENSOR1]], %[[C1]] : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>, !torch.int -> !torch.vtensor<[2,2],f32>
 // Note: This somewhat redundant conversion back and forth
 // (which is cleaned up by canonicalization) is an artifact of two patterns
 // being applied in sequence.
-// CHECK: %[[ARRAY_RESULT:.*]] = torch.copy.tensor %[[TENSOR_RESULT]] : !torch.vtensor<[2,2],f32> -> !torch.tensor<[2,2],f32>
-// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.tensor %[[ARRAY_RESULT]] : !torch.tensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
+// CHECK: %[[ARRAY_RESULT:.*]] = torch.copy.to_tensor %[[TENSOR_RESULT]] : !torch.tensor<[2,2],f32>
+// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[ARRAY_RESULT]] : !torch.vtensor<[2,2],f32>
 // CHECK: torch.overwrite.tensor %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
 // CHECK: return %[[ARG0]], %[[ARG0]] : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
 func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
@@ -35,7 +35,7 @@ func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>
 // CHECK-LABEL: func @torch.tensor.literal() -> !torch.tensor {
 // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<7xf32>) : !torch.vtensor<[7],f32>
 // CHECK: %[[SIZES_ERASED:.*]] = torch.tensor_static_info_cast %[[VTENSOR]] : !torch.vtensor<[7],f32> to !torch.vtensor
-// CHECK: %[[TENSOR:.*]] = torch.copy.tensor %[[SIZES_ERASED]] : !torch.vtensor -> !torch.tensor
+// CHECK: %[[TENSOR:.*]] = torch.copy.to_tensor %[[SIZES_ERASED]] : !torch.tensor
 // CHECK: return %[[TENSOR]] : !torch.tensor
 func @torch.tensor.literal() -> !torch.tensor {
   %0 = torch.tensor.literal(dense<0.0> : tensor<7xf32>) : !torch.tensor
@@ -2,12 +2,11 @@
 
 // CHECK-LABEL: func @basic(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
-// CHECK: %[[COPIED_NONVAL:.*]] = torch.copy.tensor %[[ARG]] : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
-// CHECK: %[[COPIED_VALUE:.*]] = torch.copy.tensor %[[COPIED_NONVAL]] : !torch.tensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
+// CHECK: %[[COPIED_NONVAL:.*]] = torch.copy.to_tensor %[[ARG]] : !torch.tensor<[2,3,?],f32>
+// CHECK: %[[COPIED_VALUE:.*]] = torch.copy.to_vtensor %[[COPIED_NONVAL]] : !torch.vtensor<[2,3,?],f32>
 // CHECK: return %[[COPIED_VALUE]] : !torch.vtensor<[2,3,?],f32>
-// CHECK: }
 func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
-  %1 = torch.copy.tensor %arg0 : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
+  %1 = torch.copy.to_tensor %arg0 : !torch.tensor<[2,3,?],f32>
   %2 = torch.tensor_static_info_cast %1 : !torch.tensor<[2,3,?],f32> to !torch.tensor
   return %2 : !torch.tensor
 }
@@ -15,11 +14,11 @@ func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
 // No conversion on private function.
 // CHECK-LABEL: func private @basic_private(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
-// CHECK: %[[COPIED:.*]] = torch.copy.tensor %[[ARG]] : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
+// CHECK: %[[COPIED:.*]] = torch.copy.to_tensor %[[ARG]] : !torch.tensor<[2,3,?],f32>
 // CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[COPIED]] : !torch.tensor<[2,3,?],f32> to !torch.tensor
 // CHECK: return %[[CASTED]] : !torch.tensor
 func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
-  %1 = torch.copy.tensor %arg0 : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
+  %1 = torch.copy.to_tensor %arg0 : !torch.tensor<[2,3,?],f32>
   %2 = torch.tensor_static_info_cast %1 : !torch.tensor<[2,3,?],f32> to !torch.tensor
   return %2 : !torch.tensor
 }
@@ -15,12 +15,12 @@ func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
 // CHECK-LABEL: func @f(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
 // CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor<[2,3,?],f32>
-// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.tensor %[[CASTED]] : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
+// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.to_tensor %[[CASTED]] : !torch.tensor<[2,3,?],f32>
 // CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[NONVAL_TENSOR]] : !torch.tensor<[2,3,?],f32> to !torch.tensor
 // CHECK: return %[[ERASED]] : !torch.tensor
 func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
   %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
-  %1 = torch.copy.tensor %0 : !torch.vtensor -> !torch.tensor
+  %1 = torch.copy.to_tensor %0 : !torch.tensor
   return %1 : !torch.tensor
 }