Make MaximizeValueSemantics a bit smarter.

This adds a pattern to MaximizeValueSemantics which does a simple
abstract interpretation within a block, which handles simple cases of
`torch.overwrite_tensor`, enough to remove all the unnecessary uses of
non-value tensors in ResNet right now.

Before/after IR:
[gist](https://gist.github.com/silvasean/a3e1ef625b19dfc63579f73cd3b543b6)

Also,
- Split `torch.copy.tensor` into `torch.copy.to_tensor` and
  `torch.copy.to_vtensor` which convert between value and non-value
  semantic tensors. This is a much cleaner factorization as they have
  very separate use cases and properties (e.g. different side effects)
- Remove the various canonicalization patterns they had, which were
  confusing because they resulted in limited forms of maximizing value
  semantics throughout the pipeline. We should structure our compilation
  pipeline such that only MaximizeValueSemantics should be maximizing
  value semantics.
- Adjust pass pipeline to only run MaximizeValueSemantics once.
- Make OverwriteTensorOp `$value` always be a value tensor and
  `$overwritten` be a non-value tensor.
pull/235/head
Sean Silva 2021-06-18 13:47:47 -07:00
parent 6dddb4d4fe
commit 79aade33da
13 changed files with 269 additions and 157 deletions

View File

@ -905,11 +905,11 @@ def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [
This op *cannot* be used to add/remove value semantics from a tensor.
For converting between the value-semantic and non-value-semantic domains,
use `torch.copy.tensor`. The two ops are kept separate to prevent
canonicalizations from accidentally dropping static information. In
most cases, after running the `torch-refine-types` pass, this op becomes
a no-op (the pass will incorporate the static information into other ops
that allow type refinement).
use `torch.copy.to_tensor` and `torch.copy.to_vtensor`. This op is kept
separate to prevent canonicalizations from accidentally dropping static
information. In most cases, after running the `torch-refine-types` pass,
this op becomes a no-op (the pass will incorporate the static information
into other ops that allow type refinement).
}];
let arguments = (ins
AnyTorchTensorType:$operand
@ -922,34 +922,66 @@ def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [
}];
}
def Torch_CopyTensorOp : Torch_Op<"copy.tensor", [
def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"operand is corresponding !torch.vtensor",
"result", "operand",
"$_self.cast<NonValueTensorType>().getWithValueSemantics()">,
]> {
let summary = "Makes a copy of a tensor.";
let summary = "Create a !torch.tensor with the same contents as the operand";
let description = [{
Changes to the original tensor will not be reflected in the copy.
This op is used to convert from !torch.vtensor to !torch.tensor.
It does so by allocating a new !torch.tensor and filling it with
the contents of the operand.
This op can be used to interconvert between value-semantic and
non-value-semantic tensors. However, this op *does not* allow
adding/removing static information about sizes/dtype. For that, use
`torch.tensor_static_info_cast`.
However, this op *does not* allow adding/removing static information about
sizes/dtype. For that, use `torch.tensor_static_info_cast`.
This op does not have the AllowsTypeRefinement trait because the operand
and result types are coupled. Only places that know how to simultaneously
update both types should be changing the type of this op.
}];
let arguments = (ins
AnyTorchTensorType:$operand
Torch_ValueTensorType:$operand
);
let results = (outs
AnyTorchTensorType:$result
Torch_NonValueTensorType:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
$operand attr-dict `:` type($result)
}];
let verifier = "return ::verify(*this);";
}
def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"operand is corresponding !torch.tensor",
"result", "operand",
"$_self.cast<ValueTensorType>().getWithoutValueSemantics()">,
]> {
let summary = "Create a !torch.vtensor with the same contents as the operand";
let description = [{
This op is used to convert from !torch.tensor to !torch.vtensor.
However, this op *does not* allow adding/removing static information about
sizes/dtype. For that, use `torch.tensor_static_info_cast`.
This op does not have the AllowsTypeRefinement trait because the operand
and result types are coupled. Only places that know how to simultaneously
update both types should be changing the type of this op.
}];
let arguments = (ins
Torch_NonValueTensorType:$operand
);
let results = (outs
Torch_ValueTensorType:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($result)
}];
let verifier = "return ::verify(*this);";
let hasFolder = 1;
let hasCanonicalizer = 1;
}
def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
@ -961,7 +993,7 @@ def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
`value`.
Immediately after this op has completed, indexing `overwritten` will result
in identical values as indexing into `tensor`. Of course, later ops
in identical values as indexing into `value`. Of course, later ops
might mutate `overwritten`, so this relationship need not hold for the
entire program.
@ -969,8 +1001,8 @@ def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
shapes or dtypes.
}];
let arguments = (ins
AnyTorchTensorType:$value,
AnyTorchTensorType:$overwritten
Torch_ValueTensorType:$value,
Torch_NonValueTensorType:$overwritten
);
let results = (outs
);

View File

@ -33,11 +33,17 @@ Value mlir::NPCOMP::Torch::copyTensorToType(OpBuilder &builder, Location loc,
tensor = builder.create<TensorStaticInfoCastOp>(
loc, originalType.getWithSizesAndDtypeFrom(newType), tensor);
}
// If both the original and new types already have value semantics, a copy is
// pointless.
if (originalType.isa<ValueTensorType>() && newType.isa<ValueTensorType>())
return tensor;
return builder.create<CopyTensorOp>(loc, newType, tensor);
// Unless the original and new types are both value tensors, we end
// up creating one op that converts between the value and non-value tensor
// domains. If the original and new types are both non-value tensors,
// then we do the copy by going to a value tensor and back.
if (tensor.getType().isa<NonValueTensorType>())
tensor = builder.create<CopyToValueTensorOp>(loc, tensor);
if (newType.isa<NonValueTensorType>())
tensor = builder.create<CopyToNonValueTensorOp>(loc, tensor);
return tensor;
}
//===----------------------------------------------------------------------===//
@ -504,10 +510,10 @@ bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs,
}
//===----------------------------------------------------------------------===//
// CopyTensorOp
// CopyToNonValueTensorOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(CopyTensorOp op) {
static LogicalResult verify(CopyToNonValueTensorOp op) {
auto resultType = op.getResult().getType().cast<BaseTensorType>();
auto operandType = op.getOperand().getType().cast<BaseTensorType>();
if (!resultType.hasSameSizesAndDtype(operandType)) {
@ -517,50 +523,48 @@ static LogicalResult verify(CopyTensorOp op) {
return success();
}
OpFoldResult CopyTensorOp::fold(ArrayRef<Attribute> operands) {
// A copy between value semantic tensors is a no-op.
if (getType().isa<ValueTensorType>() &&
getOperand().getType().isa<ValueTensorType>()) {
return getOperand();
}
return nullptr;
LogicalResult CopyToNonValueTensorOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto resultType = operands[0].getType().cast<ValueTensorType>();
inferredReturnTypes.push_back(resultType.getWithoutValueSemantics());
return success();
}
void CopyTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
// y = torch.copy.tensor(torch.copy.tensor(x)) -> x
// Only safe when `y` and `x` have value semantics, and
// all users of the intermediate tensor op treat the tensor as if it
// had value semantics (even if it is a NonValueTensorType).
patterns.add(+[](CopyTensorOp op, PatternRewriter &rewriter) {
auto otherCopy = op.getOperand().getDefiningOp<CopyTensorOp>();
if (!otherCopy)
return failure();
if (!otherCopy.getOperand().getType().isa<ValueTensorType>() ||
!op.getResult().getType().isa<ValueTensorType>())
return failure();
// TODO: Use a proper interface here.
// MemoryEffectOpInterface is not powerful enough because it cannot model
// aliasing. We don't just care that the user is readonly -- we care also
// whether it creates an alias. Basically, we care if the user "treats the
// tensor as if it has value semantics".
// For now, just hardcode the important case of multiple CopyTensorOp users.
if (llvm::all_of(op.getOperand().getUsers(),
[](Operation *op) { return isa<CopyTensorOp>(op); })) {
rewriter.replaceOp(op, {otherCopy.getOperand()});
return success();
}
return failure();
});
}
void CopyTensorOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>>
void CopyToNonValueTensorOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
if (getResult().getType().isa<NonValueTensorType>())
effects.emplace_back(MemoryEffects::Allocate::get(), getResult());
if (getOperand().getType().isa<NonValueTensorType>())
effects.emplace_back(MemoryEffects::Read::get(), getOperand());
effects.emplace_back(MemoryEffects::Allocate::get(), getResult());
}
//===----------------------------------------------------------------------===//
// CopyToValueTensorOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(CopyToValueTensorOp op) {
auto resultType = op.getResult().getType().cast<BaseTensorType>();
auto operandType = op.getOperand().getType().cast<BaseTensorType>();
if (!resultType.hasSameSizesAndDtype(operandType)) {
return op.emitError()
<< "operand and result must have same sizes and dtype";
}
return success();
}
LogicalResult CopyToValueTensorOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto resultType = operands[0].getType().cast<NonValueTensorType>();
inferredReturnTypes.push_back(resultType.getWithValueSemantics());
return success();
}
void CopyToValueTensorOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.emplace_back(MemoryEffects::Read::get(), getOperand());
}
//===----------------------------------------------------------------------===//

View File

@ -216,7 +216,7 @@ static LogicalResult adjustCallingConventions(FuncOp func,
target.addDynamicallyLegalOp<ReturnOp>([&](ReturnOp op) {
return !opsInOriginalProgram.contains(op.getOperation());
});
target.addLegalOp<CopyTensorOp>();
target.addLegalOp<CopyToNonValueTensorOp, CopyToValueTensorOp>();
target.addLegalOp<TensorStaticInfoCastOp>();
target.addLegalOp<ConstantNoneOp>();
// We don't know how to rewrite it, so mark it as illegal.

View File

@ -19,6 +19,65 @@ using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
: public OpRewritePattern<CopyToNonValueTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy,
PatternRewriter &rewriter) const override {
SmallVector<Operation *> users;
// See if our limited form of analysis is even applicable.
for (Operation *user : copy.getResult().getUsers()) {
// We can only analyze within a single basic block.
if (user->getBlock() != copy->getBlock())
return failure();
// We can only analyze these ops.
if (!isa<CopyToValueTensorOp, OverwriteTensorOp>(user))
return failure();
users.push_back(user);
}
// Sort by order in the block, so we can abstractly interpret the ops.
llvm::sort(users, [](Operation *lhs, Operation *rhs) {
return lhs->isBeforeInBlock(rhs);
});
// Do an abstract interpretation within the block.
// We track the current value tensor that holds the same contents as the
// non-value tensor at each program point as we walk forward.
Value currentlyHeldValueTensor = copy.getOperand();
for (Operation *user : users) {
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
rewriter.replaceOp(copyToValueTensor, {currentlyHeldValueTensor});
} else if (auto overwriteTensor = dyn_cast<OverwriteTensorOp>(user)) {
currentlyHeldValueTensor = overwriteTensor.value();
rewriter.eraseOp(overwriteTensor);
} else {
llvm_unreachable("only those ops supported!");
}
}
rewriter.eraseOp(copy);
return success();
}
};
class RewriteNonValueTensorNeverMutatedOrAliased
: public OpRewritePattern<CopyToNonValueTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy,
PatternRewriter &rewriter) const override {
SmallVector<Operation *> users;
// See if our limited form of analysis is even applicable.
for (Operation *user : copy.getResult().getUsers()) {
if (!isa<CopyToValueTensorOp>(user))
return failure();
users.push_back(user);
}
for (Operation *user : users)
rewriter.replaceOp(user, copy.getOperand());
return success();
}
};
namespace {
class MaximizeValueSemanticsPass
@ -28,8 +87,8 @@ class MaximizeValueSemanticsPass
auto func = getOperation();
RewritePatternSet patterns(context);
CopyTensorOp::getCanonicalizationPatterns(patterns, context);
TensorStaticInfoCastOp::getCanonicalizationPatterns(patterns, context);
patterns.insert<AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock,
RewriteNonValueTensorNeverMutatedOrAliased>(context);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
};

View File

@ -120,14 +120,12 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
// Lowering to ranked !torch.vtensors of known dtype.
//===--------------------------------------------------------------------===//
// Convert the bulk of non-ABI-visible arrays to tensors.
pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());
// Do shape and dtype refinement.
pm.addNestedPass<FuncOp>(Torch::createRefineTypesPass());
// Propagate to ABI return types the shape/dtype information discovered by
// the previous pass. Doing this is ABI-compatible for our backends.
pm.addPass(Torch::createRefinePublicReturnPass());
// Clean up a few stray conversion remnants.
// Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's.
pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());
if (options.optimize) {

View File

@ -37,8 +37,8 @@ public:
opOperand.get().getType().dyn_cast<NonValueTensorType>();
if (!tensorType)
continue;
opOperand.set(rewriter.create<CopyTensorOp>(
op->getLoc(), tensorType.getWithValueSemantics(), opOperand.get()));
opOperand.set(rewriter.create<CopyToValueTensorOp>(op->getLoc(),
opOperand.get()));
}
// Convert all results.
rewriter.setInsertionPointAfter(op);
@ -46,10 +46,10 @@ public:
auto tensorType = result.getType().dyn_cast<NonValueTensorType>();
if (!tensorType)
continue;
auto createArray = rewriter.create<CopyTensorOp>(
op->getLoc(), result.getType(), result);
result.replaceAllUsesExcept(createArray, createArray);
result.setType(tensorType.getWithValueSemantics());
auto nonValueTensor =
rewriter.create<CopyToNonValueTensorOp>(op->getLoc(), result);
result.replaceAllUsesExcept(nonValueTensor, nonValueTensor);
}
});
return success();
@ -85,12 +85,8 @@ public:
"Torch JIT operators shouldn't have regions or successors");
Operation *newOp = rewriter.createOperation(state);
auto tensor = rewriter.create<CopyTensorOp>(op->getLoc(),
newOp->getResult(0)
.getType()
.cast<NonValueTensorType>()
.getWithValueSemantics(),
newOp->getResult(0));
auto tensor =
rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
rewriter.create<OverwriteTensorOp>(op->getLoc(), tensor, op->getOperand(0));
rewriter.replaceOp(op, op->getOperand(0));

View File

@ -141,8 +141,8 @@ public:
ChangeResult
visitOperation(Operation *op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
if (isa<TensorStaticInfoCastOp, CopyTensorOp, AtenTanhOp, AtenBatchNormOp,
AtenReluOp>(op)) {
if (isa<TensorStaticInfoCastOp, CopyToValueTensorOp, CopyToNonValueTensorOp,
AtenTanhOp, AtenBatchNormOp, AtenReluOp>(op)) {
return getLatticeElement(op->getResult(0)).join(*operands[0]);
}
if (isa<AtenMmOp>(op)) {
@ -303,12 +303,13 @@ static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
// which allows arbitrary refinements. But some other cases are safe too,
// such as when an op has two types that are coupled, but we know that our
// analysis and updating logic will correctly maintain the invariants of the op.
// The `torch.copy.tensor` is an example of the latter case, since its
// operand and result types must have the same shape and dtype -- we know
// that our transfer functions and updating logic will do the right thing
// for that op.
// The `torch.copy.to_tensor` / `torch.copy.to_vtensor` are examples of the
// latter case, since their operand and result types must have the same shape
// and dtype -- we know that our transfer functions and updating logic will do
// the right thing for those ops.
static bool allowsTypeRefinementOrWillBeOtherwiseSafelyRefined(Operation *op) {
return allowsTypeRefinement(op) || isa<CopyTensorOp>(op);
return allowsTypeRefinement(op) ||
isa<CopyToNonValueTensorOp, CopyToValueTensorOp>(op);
}
void optimize(FuncOp func, TypeAnalyzer &analyzer) {

View File

@ -3,7 +3,7 @@
// CHECK-LABEL: func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.tensor %[[ERASED]] : !torch.vtensor -> !torch.tensor
// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.to_tensor %[[ERASED]] : !torch.tensor
// CHECK: return %[[NONVAL_TENSOR]] : !torch.tensor
func @basic(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor {
return %arg0 : !torch.tensor
@ -19,9 +19,9 @@ func @no_type_bound(%arg0: !torch.tensor) -> !torch.tensor {
// CHECK-LABEL: func @call(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
// CHECK: %[[ARG_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
// CHECK: %[[ARG_NONVAL:.*]] = torch.copy.tensor %[[ARG_ERASED]] : !torch.vtensor -> !torch.tensor
// CHECK: %[[ARG_NONVAL:.*]] = torch.copy.to_tensor %[[ARG_ERASED]] : !torch.tensor
// CHECK: %[[INFO_ADDED:.*]] = torch.tensor_static_info_cast %[[ARG_NONVAL]] : !torch.tensor to !torch.tensor<[2,3,?],f32>
// CHECK: %[[CALL_ARG:.*]] = torch.copy.tensor %[[INFO_ADDED]] : !torch.tensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
// CHECK: %[[CALL_ARG:.*]] = torch.copy.to_vtensor %[[INFO_ADDED]] : !torch.vtensor<[2,3,?],f32>
// CHECK: %[[CALL_RES:.*]] = call @call(%[[CALL_ARG]]) : (!torch.vtensor<[2,3,?],f32>) -> !torch.tensor
// CHECK: return %[[ARG_NONVAL]] : !torch.tensor
func @call(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor {

View File

@ -78,23 +78,6 @@ func @torch.aten.len.t$of_build_list(%arg0: !torch.int) -> !torch.int {
return %1 : !torch.int
}
// CHECK-LABEL: func @torch.copy.tensor$value_copy_is_noop(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: return %[[ARG]] : !torch.vtensor
func @torch.copy.tensor$value_copy_is_noop(%arg0: !torch.vtensor) -> !torch.vtensor {
%0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.vtensor
return %0 : !torch.vtensor
}
// CHECK-LABEL: func @torch.copy.tensor$unnecessary_intermediate_nonval_tensor(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: return %[[ARG]] : !torch.vtensor
func @torch.copy.tensor$unnecessary_intermediate_nonval_tensor(%arg0: !torch.vtensor) -> !torch.vtensor {
%0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor
%1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
return %1 : !torch.vtensor
}
// CHECK-LABEL: func @torch.aten.__getitem__.t(
// CHECK: %[[C5:.*]] = torch.constant.int 5
// CHECK: return %[[C5]] : !torch.int
@ -179,13 +162,3 @@ func @torch.prim.If$erase_dead_branch(%arg0: !torch.int) -> !torch.int {
}
return %0 : !torch.int
}
// CHECK-LABEL: func @torch.copy.tensor$untouched_nonval(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
// CHECK-NEXT: return %[[ARG]], %[[ARG]] : !torch.vtensor, !torch.vtensor
func @torch.copy.tensor$untouched_nonval(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
%0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor
%1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
%2 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
return %1, %2 : !torch.vtensor, !torch.vtensor
}

View File

@ -1,22 +1,72 @@
// RUN: npcomp-opt -split-input-file %s -torch-maximize-value-semantics | FileCheck %s
// RUN: npcomp-opt -split-input-file -allow-unregistered-dialect %s -torch-maximize-value-semantics | FileCheck %s
// Basic case that can be resolved with local reasoning.
// This pass will eventually need to learn about aliasing relationships.
//
// This is taken from a test case from an e2e spike, and isn't intended to be
// particularly minimal or specifically test one thing, since the pass is
// currently just a handful of canonicalization patterns that are already
// tested elsewhere.
// CHECK-LABEL: func @local(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
// CHECK: %[[RET:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
// CHECK: return %[[RET]] : !torch.vtensor<[2,3,?],f32>
func @local(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor<[2,3,?],f32>
%1 = torch.aten.tanh %0 : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
%2 = torch.copy.tensor %1 : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
%3 = torch.tensor_static_info_cast %2 : !torch.tensor<[2,3,?],f32> to !torch.tensor
%4 = torch.copy.tensor %2 : !torch.tensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
return %4 : !torch.vtensor<[2,3,?],f32>
// CHECK-LABEL: func @torch.copy.tensor$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
// CHECK: return %[[ARG0]], %[[ARG0]] : !torch.vtensor, !torch.vtensor
func @torch.copy.tensor$basic(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
%1 = torch.copy.to_vtensor %0 : !torch.vtensor
%2 = torch.copy.to_vtensor %0 : !torch.vtensor
return %1, %2 : !torch.vtensor, !torch.vtensor
}
// CHECK-LABEL: func @one_mutation_in_a_block(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
// CHECK: return %[[ARG0]], %[[ARG1]] : !torch.vtensor, !torch.vtensor
func @one_mutation_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
%equal_to_arg0 = torch.copy.to_vtensor %0 : !torch.vtensor
torch.overwrite.tensor %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
%equal_to_arg1 = torch.copy.to_vtensor %0 : !torch.vtensor
return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
}
// CHECK-LABEL: func @multiple_mutations_in_a_block(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, %[[ARG1:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor) {
// CHECK: return %[[ARG0]], %[[ARG1]], %[[ARG1]], %[[ARG2]] : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %arg2: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor) {
// The mutable tensor we are overwriting.
%tensor = torch.copy.to_tensor %arg0 : !torch.tensor
// The original value.
%equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
// Overwrite with %arg1
torch.overwrite.tensor %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
%equal_to_arg1 = torch.copy.to_vtensor %tensor : !torch.vtensor
%equal_to_arg1_again = torch.copy.to_vtensor %tensor : !torch.vtensor
// Overwrite with %arg2
torch.overwrite.tensor %arg2 overwrites %tensor : !torch.vtensor, !torch.tensor
%equal_to_arg2 = torch.copy.to_vtensor %tensor : !torch.vtensor
return %equal_to_arg0, %equal_to_arg1, %equal_to_arg1_again, %equal_to_arg2 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
}
// CHECK-LABEL: func @unmodeled_mutation(
// CHECK: torch.overwrite.tensor
func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
torch.overwrite.tensor %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
"some.op"(%0) : (!torch.tensor) -> ()
%result = torch.copy.to_vtensor %0 : !torch.vtensor
return %result : !torch.vtensor
}
// We don't yet handle nontrivial cases involving control flow.
// CHECK-LABEL: func @unimplemented_control_flow(
// CHECK: torch.copy.to_vtensor
func @unimplemented_control_flow(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %cond: !torch.bool) -> (!torch.vtensor, !torch.vtensor) {
%tensor = torch.copy.to_tensor %arg0 : !torch.tensor
%equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
torch.prim.If %cond -> () {
torch.overwrite.tensor %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%equal_to_arg1 = torch.copy.to_vtensor %tensor : !torch.vtensor
return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
}

View File

@ -2,9 +2,9 @@
// CHECK-LABEL: func @convert_to_value_semantic_tensors(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
// CHECK: %[[OPERAND_TENSOR:.*]] = torch.copy.tensor %[[ARG]] : !torch.tensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[OPERAND_TENSOR:.*]] = torch.copy.to_vtensor %[[ARG]] : !torch.vtensor<[],f32>
// CHECK: %[[RESULT_TENSOR:.*]] = torch.aten.tanh %[[OPERAND_TENSOR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[RET:.*]] = torch.copy.tensor %[[RESULT_TENSOR]] : !torch.vtensor<[],f32> -> !torch.tensor<[],f32>
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[RESULT_TENSOR]] : !torch.tensor<[],f32>
// CHECK: return %[[RET]] : !torch.tensor<[],f32>
func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
%0 = torch.aten.tanh %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32>
@ -16,14 +16,14 @@ func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !torch.
// CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[2,2],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[TENSOR0:.*]] = torch.copy.tensor %[[ARG0]] : !torch.tensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
// CHECK: %[[TENSOR1:.*]] = torch.copy.tensor %[[ARG1]] : !torch.tensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
// CHECK: %[[TENSOR0:.*]] = torch.copy.to_vtensor %[[ARG0]] : !torch.vtensor<[2,2],f32>
// CHECK: %[[TENSOR1:.*]] = torch.copy.to_vtensor %[[ARG1]] : !torch.vtensor<[2,2],f32>
// CHECK: %[[TENSOR_RESULT:.*]] = torch.aten.add.Tensor %[[TENSOR0]], %[[TENSOR1]], %[[C1]] : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>, !torch.int -> !torch.vtensor<[2,2],f32>
// Note: This somewhat redundant conversion back and forth
// (which is cleaned up by canonicalization) is an artifact of two patterns
// being applied in sequence.
// CHECK: %[[ARRAY_RESULT:.*]] = torch.copy.tensor %[[TENSOR_RESULT]] : !torch.vtensor<[2,2],f32> -> !torch.tensor<[2,2],f32>
// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.tensor %[[ARRAY_RESULT]] : !torch.tensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
// CHECK: %[[ARRAY_RESULT:.*]] = torch.copy.to_tensor %[[TENSOR_RESULT]] : !torch.tensor<[2,2],f32>
// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[ARRAY_RESULT]] : !torch.vtensor<[2,2],f32>
// CHECK: torch.overwrite.tensor %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
// CHECK: return %[[ARG0]], %[[ARG0]] : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
@ -35,7 +35,7 @@ func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>
// CHECK-LABEL: func @torch.tensor.literal() -> !torch.tensor {
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<7xf32>) : !torch.vtensor<[7],f32>
// CHECK: %[[SIZES_ERASED:.*]] = torch.tensor_static_info_cast %[[VTENSOR]] : !torch.vtensor<[7],f32> to !torch.vtensor
// CHECK: %[[TENSOR:.*]] = torch.copy.tensor %[[SIZES_ERASED]] : !torch.vtensor -> !torch.tensor
// CHECK: %[[TENSOR:.*]] = torch.copy.to_tensor %[[SIZES_ERASED]] : !torch.tensor
// CHECK: return %[[TENSOR]] : !torch.tensor
func @torch.tensor.literal() -> !torch.tensor {
%0 = torch.tensor.literal(dense<0.0> : tensor<7xf32>) : !torch.tensor

View File

@ -2,12 +2,11 @@
// CHECK-LABEL: func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
// CHECK: %[[COPIED_NONVAL:.*]] = torch.copy.tensor %[[ARG]] : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
// CHECK: %[[COPIED_VALUE:.*]] = torch.copy.tensor %[[COPIED_NONVAL]] : !torch.tensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
// CHECK: %[[COPIED_NONVAL:.*]] = torch.copy.to_tensor %[[ARG]] : !torch.tensor<[2,3,?],f32>
// CHECK: %[[COPIED_VALUE:.*]] = torch.copy.to_vtensor %[[COPIED_NONVAL]] : !torch.vtensor<[2,3,?],f32>
// CHECK: return %[[COPIED_VALUE]] : !torch.vtensor<[2,3,?],f32>
// CHECK: }
func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
%1 = torch.copy.tensor %arg0 : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
%1 = torch.copy.to_tensor %arg0 : !torch.tensor<[2,3,?],f32>
%2 = torch.tensor_static_info_cast %1 : !torch.tensor<[2,3,?],f32> to !torch.tensor
return %2 : !torch.tensor
}
@ -15,11 +14,11 @@ func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
// No conversion on private function.
// CHECK-LABEL: func private @basic_private(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
// CHECK: %[[COPIED:.*]] = torch.copy.tensor %[[ARG]] : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
// CHECK: %[[COPIED:.*]] = torch.copy.to_tensor %[[ARG]] : !torch.tensor<[2,3,?],f32>
// CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[COPIED]] : !torch.tensor<[2,3,?],f32> to !torch.tensor
// CHECK: return %[[CASTED]] : !torch.tensor
func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
%1 = torch.copy.tensor %arg0 : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
%1 = torch.copy.to_tensor %arg0 : !torch.tensor<[2,3,?],f32>
%2 = torch.tensor_static_info_cast %1 : !torch.tensor<[2,3,?],f32> to !torch.tensor
return %2 : !torch.tensor
}

View File

@ -15,12 +15,12 @@ func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
// CHECK-LABEL: func @f(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
// CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor<[2,3,?],f32>
// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.tensor %[[CASTED]] : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.to_tensor %[[CASTED]] : !torch.tensor<[2,3,?],f32>
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[NONVAL_TENSOR]] : !torch.tensor<[2,3,?],f32> to !torch.tensor
// CHECK: return %[[ERASED]] : !torch.tensor
func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
%1 = torch.copy.tensor %0 : !torch.vtensor -> !torch.tensor
%1 = torch.copy.to_tensor %0 : !torch.tensor
return %1 : !torch.tensor
}