Make MaximizeValueSemantics a bit smarter.

This adds a pattern to MaximizeValueSemantics which does a simple
abstract interpretation within a block, which handles simple cases of
`torch.overwrite_tensor`, enough to remove all the unnecessary uses of
non-value tensors in ResNet right now.

Before/after IR:
[gist](https://gist.github.com/silvasean/a3e1ef625b19dfc63579f73cd3b543b6)

Also,
- Split `torch.copy.tensor` into `torch.copy.to_tensor` and
  `torch.copy.to_vtensor` which convert between value and non-value
  semantic tensors. This is a much cleaner factorization as they have
  very separate use cases and properties (e.g. different side effects)
- Remove the various canonicalization patterns they had, which were
  confusing because they resulted in limited forms of maximizing value
  semantics throughout the pipeline. We should structure our compilation
  pipeline such that only MaximizeValueSemantics should be maximizing
  value semantics.
- Adjust pass pipeline to only run MaximizeValueSemantics once.
- Make OverwriteTensorOp `$value` always be a value tensor and
  `$overwritten` be a non-value tensor.
pull/235/head
Sean Silva 2021-06-18 13:47:47 -07:00
parent 6dddb4d4fe
commit 79aade33da
13 changed files with 269 additions and 157 deletions

View File

@ -905,11 +905,11 @@ def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [
This op *cannot* be used to add/remove value semantics from a tensor. This op *cannot* be used to add/remove value semantics from a tensor.
For converting between the value-semantic and non-value-semantic domains, For converting between the value-semantic and non-value-semantic domains,
use `torch.copy.tensor`. The two ops are kept separate to prevent use `torch.copy.to_tensor` and `torch.copy.from_tensor`. This op is kept
canonicalizations from accidentally dropping static information. In separate to prevent canonicalizations from accidentally dropping static
most cases, after running the `torch-refine-types` pass, this op becomes information. In most cases, after running the `torch-refine-types` pass,
a no-op (the pass will incorporate the static information into other ops this op becomes a no-op (the pass will incorporate the static information
that allow type refinement). into other ops that allow type refinement).
}]; }];
let arguments = (ins let arguments = (ins
AnyTorchTensorType:$operand AnyTorchTensorType:$operand
@ -922,34 +922,66 @@ def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [
}]; }];
} }
def Torch_CopyTensorOp : Torch_Op<"copy.tensor", [ def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"operand is corresponding !torch.vtensor",
"result", "operand",
"$_self.cast<NonValueTensorType>().getWithValueSemantics()">,
]> { ]> {
let summary = "Makes a copy of a tensor."; let summary = "Create a !torch.tensor with the same contents as the operand";
let description = [{ let description = [{
Changes to the original tensor will not be reflected in the copy. This op is used to convert from !torch.vtensor to !torch.tensor.
It does so by allocating a new !torch.tensor and filling it with
the contents of the operand.
This op can be used to interconvert between value-semantic and However, this op *does not* allow adding/removing static information about
non-value-semantic tensors. However, this op *does not* allow sizes/dtype. For that, use `torch.tensor_static_info_cast`.
adding/removing static information about sizes/dtype. For that, use
`torch.tensor_static_info_cast`.
This op does not have the AllowsTypeRefinement trait because the operand This op does not have the AllowsTypeRefinement trait because the operand
and result types are coupled. Only places that know how to simultaneously and result types are coupled. Only places that know how to simultaneously
update both types should be changing the type of this op. update both types should be changing the type of this op.
}]; }];
let arguments = (ins let arguments = (ins
AnyTorchTensorType:$operand Torch_ValueTensorType:$operand
); );
let results = (outs let results = (outs
AnyTorchTensorType:$result Torch_NonValueTensorType:$result
); );
let assemblyFormat = [{ let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result) $operand attr-dict `:` type($result)
}];
let verifier = "return ::verify(*this);";
}
def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"operand is corresponding !torch.tensor",
"result", "operand",
"$_self.cast<ValueTensorType>().getWithoutValueSemantics()">,
]> {
let summary = "Create a !torch.vtensor with the same contents as the operand";
let description = [{
This op is used to convert from !torch.tensor to !torch.vtensor.
However, this op *does not* allow adding/removing static information about
sizes/dtype. For that, use `torch.tensor_static_info_cast`.
This op does not have the AllowsTypeRefinement trait because the operand
and result types are coupled. Only places that know how to simultaneously
update both types should be changing the type of this op.
}];
let arguments = (ins
Torch_NonValueTensorType:$operand
);
let results = (outs
Torch_ValueTensorType:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($result)
}]; }];
let verifier = "return ::verify(*this);"; let verifier = "return ::verify(*this);";
let hasFolder = 1;
let hasCanonicalizer = 1;
} }
def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [ def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
@ -961,7 +993,7 @@ def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
`value`. `value`.
Immediately after this op has completed, indexing `overwritten` will result Immediately after this op has completed, indexing `overwritten` will result
in identical values as indexing into `tensor`. Of course, later ops in identical values as indexing into `value`. Of course, later ops
might mutate `overwritten`, so this relationship need not hold for the might mutate `overwritten`, so this relationship need not hold for the
entire program. entire program.
@ -969,8 +1001,8 @@ def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
shapes or dtypes. shapes or dtypes.
}]; }];
let arguments = (ins let arguments = (ins
AnyTorchTensorType:$value, Torch_ValueTensorType:$value,
AnyTorchTensorType:$overwritten Torch_NonValueTensorType:$overwritten
); );
let results = (outs let results = (outs
); );

View File

@ -33,11 +33,17 @@ Value mlir::NPCOMP::Torch::copyTensorToType(OpBuilder &builder, Location loc,
tensor = builder.create<TensorStaticInfoCastOp>( tensor = builder.create<TensorStaticInfoCastOp>(
loc, originalType.getWithSizesAndDtypeFrom(newType), tensor); loc, originalType.getWithSizesAndDtypeFrom(newType), tensor);
} }
// If both the original and new types already have value semantics, a copy is
// pointless. // Unless both the original and new types are both value tensors, we end
if (originalType.isa<ValueTensorType>() && newType.isa<ValueTensorType>()) // up creating one op that converts between the value and non-value tensor
return tensor; // domains. If both the original and new types are both non-value tensors,
return builder.create<CopyTensorOp>(loc, newType, tensor); // then we do the copy by going to a value tensor and back.
if (tensor.getType().isa<NonValueTensorType>())
tensor = builder.create<CopyToValueTensorOp>(loc, tensor);
if (newType.isa<NonValueTensorType>())
tensor = builder.create<CopyToNonValueTensorOp>(loc, tensor);
return tensor;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -504,10 +510,10 @@ bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs,
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// CopyTensorOp // CopyToNonValueTensorOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult verify(CopyTensorOp op) { static LogicalResult verify(CopyToNonValueTensorOp op) {
auto resultType = op.getResult().getType().cast<BaseTensorType>(); auto resultType = op.getResult().getType().cast<BaseTensorType>();
auto operandType = op.getOperand().getType().cast<BaseTensorType>(); auto operandType = op.getOperand().getType().cast<BaseTensorType>();
if (!resultType.hasSameSizesAndDtype(operandType)) { if (!resultType.hasSameSizesAndDtype(operandType)) {
@ -517,50 +523,48 @@ static LogicalResult verify(CopyTensorOp op) {
return success(); return success();
} }
OpFoldResult CopyTensorOp::fold(ArrayRef<Attribute> operands) { LogicalResult CopyToNonValueTensorOp::inferReturnTypes(
// A copy between value semantic tensors is a no-op. MLIRContext *context, Optional<Location> location, ValueRange operands,
if (getType().isa<ValueTensorType>() && DictionaryAttr attributes, RegionRange regions,
getOperand().getType().isa<ValueTensorType>()) { SmallVectorImpl<Type> &inferredReturnTypes) {
return getOperand(); auto resultType = operands[0].getType().cast<ValueTensorType>();
} inferredReturnTypes.push_back(resultType.getWithoutValueSemantics());
return nullptr; return success();
} }
void CopyTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, void CopyToNonValueTensorOp::getEffects(
MLIRContext *context) { SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
// y = torch.copy.tensor(torch.copy.tensor(x)) -> x
// Only safe when `y` and `x` have value semantics, and
// all users of the intermediate tensor op treat the tensor as if it
// had value semantics (even if it is a NonValueTensorType).
patterns.add(+[](CopyTensorOp op, PatternRewriter &rewriter) {
auto otherCopy = op.getOperand().getDefiningOp<CopyTensorOp>();
if (!otherCopy)
return failure();
if (!otherCopy.getOperand().getType().isa<ValueTensorType>() ||
!op.getResult().getType().isa<ValueTensorType>())
return failure();
// TODO: Use a proper interface here.
// MemoryEffectOpInterface is not powerful enough because it cannot model
// aliasing. We don't just care that the user is readonly -- we care also
// whether it creates an alias. Basically, we care if the user "treats the
// tensor as if it has value semantics".
// For now, just hardcode the important case of multiple CopyTensorOp users.
if (llvm::all_of(op.getOperand().getUsers(),
[](Operation *op) { return isa<CopyTensorOp>(op); })) {
rewriter.replaceOp(op, {otherCopy.getOperand()});
return success();
}
return failure();
});
}
void CopyTensorOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>>
&effects) { &effects) {
if (getResult().getType().isa<NonValueTensorType>()) effects.emplace_back(MemoryEffects::Allocate::get(), getResult());
effects.emplace_back(MemoryEffects::Allocate::get(), getResult()); }
if (getOperand().getType().isa<NonValueTensorType>())
effects.emplace_back(MemoryEffects::Read::get(), getOperand()); //===----------------------------------------------------------------------===//
// CopyToValueTensorOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(CopyToValueTensorOp op) {
auto resultType = op.getResult().getType().cast<BaseTensorType>();
auto operandType = op.getOperand().getType().cast<BaseTensorType>();
if (!resultType.hasSameSizesAndDtype(operandType)) {
return op.emitError()
<< "operand and result must have same sizes and dtype";
}
return success();
}
LogicalResult CopyToValueTensorOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto resultType = operands[0].getType().cast<NonValueTensorType>();
inferredReturnTypes.push_back(resultType.getWithValueSemantics());
return success();
}
void CopyToValueTensorOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.emplace_back(MemoryEffects::Read::get(), getOperand());
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -216,7 +216,7 @@ static LogicalResult adjustCallingConventions(FuncOp func,
target.addDynamicallyLegalOp<ReturnOp>([&](ReturnOp op) { target.addDynamicallyLegalOp<ReturnOp>([&](ReturnOp op) {
return !opsInOriginalProgram.contains(op.getOperation()); return !opsInOriginalProgram.contains(op.getOperation());
}); });
target.addLegalOp<CopyTensorOp>(); target.addLegalOp<CopyToNonValueTensorOp, CopyToValueTensorOp>();
target.addLegalOp<TensorStaticInfoCastOp>(); target.addLegalOp<TensorStaticInfoCastOp>();
target.addLegalOp<ConstantNoneOp>(); target.addLegalOp<ConstantNoneOp>();
// We don't know how to rewrite it, so mark it as illegal. // We don't know how to rewrite it, so mark it as illegal.

View File

@ -19,6 +19,65 @@ using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch; using namespace mlir::NPCOMP::Torch;
class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
: public OpRewritePattern<CopyToNonValueTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy,
PatternRewriter &rewriter) const override {
SmallVector<Operation *> users;
// See if our limited form of analysis is even applicatble.
for (Operation *user : copy.getResult().getUsers()) {
// We can only analyze within a single basic block.
if (user->getBlock() != copy->getBlock())
return failure();
// We can only analyze these ops.
if (!isa<CopyToValueTensorOp, OverwriteTensorOp>(user))
return failure();
users.push_back(user);
}
// Sort by order in the block, so we can abstractly interpret the ops.
llvm::sort(users, [](Operation *lhs, Operation *rhs) {
return lhs->isBeforeInBlock(rhs);
});
// Do an abstract interpretation within the block.
// We track the current value tensor that holds the same contents as the
// non-value tensor at each program point as we walk forward.
Value currentlyHeldValueTensor = copy.getOperand();
for (Operation *user : users) {
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
rewriter.replaceOp(copyToValueTensor, {currentlyHeldValueTensor});
} else if (auto overwriteTensor = dyn_cast<OverwriteTensorOp>(user)) {
currentlyHeldValueTensor = overwriteTensor.value();
rewriter.eraseOp(overwriteTensor);
} else {
llvm_unreachable("only those ops supported!");
}
}
rewriter.eraseOp(copy);
return success();
}
};
class RewriteNonValueTensorNeverMutatedOrAliased
: public OpRewritePattern<CopyToNonValueTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy,
PatternRewriter &rewriter) const override {
SmallVector<Operation *> users;
// See if our limited form of analysis is even applicatble.
for (Operation *user : copy.getResult().getUsers()) {
if (!isa<CopyToValueTensorOp>(user))
return failure();
users.push_back(user);
}
for (Operation *user : users)
rewriter.replaceOp(user, copy.getOperand());
return success();
}
};
namespace { namespace {
class MaximizeValueSemanticsPass class MaximizeValueSemanticsPass
@ -28,8 +87,8 @@ class MaximizeValueSemanticsPass
auto func = getOperation(); auto func = getOperation();
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
CopyTensorOp::getCanonicalizationPatterns(patterns, context); patterns.insert<AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock,
TensorStaticInfoCastOp::getCanonicalizationPatterns(patterns, context); RewriteNonValueTensorNeverMutatedOrAliased>(context);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns)); (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
} }
}; };

View File

@ -120,14 +120,12 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
// Lowering to ranked !torch.vtensors of known dtype. // Lowering to ranked !torch.vtensors of known dtype.
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Convert the bulk of non-ABI-visible arrays to tensors.
pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());
// Do shape and dtype refinement. // Do shape and dtype refinement.
pm.addNestedPass<FuncOp>(Torch::createRefineTypesPass()); pm.addNestedPass<FuncOp>(Torch::createRefineTypesPass());
// Propagate to ABI return types the shape/dtype information discovered by // Propagate to ABI return types the shape/dtype information discovered by
// the previous pass. Doing this is ABI-compatible for our backends. // the previous pass. Doing this is ABI-compatible for our backends.
pm.addPass(Torch::createRefinePublicReturnPass()); pm.addPass(Torch::createRefinePublicReturnPass());
// Clean up a few stray conversion remnants. // Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's.
pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass()); pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());
if (options.optimize) { if (options.optimize) {

View File

@ -37,8 +37,8 @@ public:
opOperand.get().getType().dyn_cast<NonValueTensorType>(); opOperand.get().getType().dyn_cast<NonValueTensorType>();
if (!tensorType) if (!tensorType)
continue; continue;
opOperand.set(rewriter.create<CopyTensorOp>( opOperand.set(rewriter.create<CopyToValueTensorOp>(op->getLoc(),
op->getLoc(), tensorType.getWithValueSemantics(), opOperand.get())); opOperand.get()));
} }
// Convert all results. // Convert all results.
rewriter.setInsertionPointAfter(op); rewriter.setInsertionPointAfter(op);
@ -46,10 +46,10 @@ public:
auto tensorType = result.getType().dyn_cast<NonValueTensorType>(); auto tensorType = result.getType().dyn_cast<NonValueTensorType>();
if (!tensorType) if (!tensorType)
continue; continue;
auto createArray = rewriter.create<CopyTensorOp>(
op->getLoc(), result.getType(), result);
result.replaceAllUsesExcept(createArray, createArray);
result.setType(tensorType.getWithValueSemantics()); result.setType(tensorType.getWithValueSemantics());
auto nonValueTensor =
rewriter.create<CopyToNonValueTensorOp>(op->getLoc(), result);
result.replaceAllUsesExcept(nonValueTensor, nonValueTensor);
} }
}); });
return success(); return success();
@ -85,12 +85,8 @@ public:
"Torch JIT operators shouldn't have regions or successors"); "Torch JIT operators shouldn't have regions or successors");
Operation *newOp = rewriter.createOperation(state); Operation *newOp = rewriter.createOperation(state);
auto tensor = rewriter.create<CopyTensorOp>(op->getLoc(), auto tensor =
newOp->getResult(0) rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
.getType()
.cast<NonValueTensorType>()
.getWithValueSemantics(),
newOp->getResult(0));
rewriter.create<OverwriteTensorOp>(op->getLoc(), tensor, op->getOperand(0)); rewriter.create<OverwriteTensorOp>(op->getLoc(), tensor, op->getOperand(0));
rewriter.replaceOp(op, op->getOperand(0)); rewriter.replaceOp(op, op->getOperand(0));

View File

@ -141,8 +141,8 @@ public:
ChangeResult ChangeResult
visitOperation(Operation *op, visitOperation(Operation *op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) final { ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
if (isa<TensorStaticInfoCastOp, CopyTensorOp, AtenTanhOp, AtenBatchNormOp, if (isa<TensorStaticInfoCastOp, CopyToValueTensorOp, CopyToNonValueTensorOp,
AtenReluOp>(op)) { AtenTanhOp, AtenBatchNormOp, AtenReluOp>(op)) {
return getLatticeElement(op->getResult(0)).join(*operands[0]); return getLatticeElement(op->getResult(0)).join(*operands[0]);
} }
if (isa<AtenMmOp>(op)) { if (isa<AtenMmOp>(op)) {
@ -303,12 +303,13 @@ static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
// which allows arbitrary refinements. But some other cases are safe too, // which allows arbitrary refinements. But some other cases are safe too,
// such as when an op has two types that are coupled, but we know that our // such as when an op has two types that are coupled, but we know that our
// analysis and updating logic will correctly maintain the invariants of the op. // analysis and updating logic will correctly maintain the invariants of the op.
// The `torch.copy.tensor` is an example of the latter case, since its // The `torch.copy.to_tensor` / `torch.copy.to_vtensor` are examples of the
// operand and result types must have the same shape and dtype -- we know // latter case, since their operand and result types must have the same shape
// that our transfer functions and updating logic will do the right thing // and dtype -- we know that our transfer functions and updating logic will do
// for that op. // the right thing forthose ops.
static bool allowsTypeRefinementOrWillBeOtherwiseSafelyRefined(Operation *op) { static bool allowsTypeRefinementOrWillBeOtherwiseSafelyRefined(Operation *op) {
return allowsTypeRefinement(op) || isa<CopyTensorOp>(op); return allowsTypeRefinement(op) ||
isa<CopyToNonValueTensorOp, CopyToValueTensorOp>(op);
} }
void optimize(FuncOp func, TypeAnalyzer &analyzer) { void optimize(FuncOp func, TypeAnalyzer &analyzer) {

View File

@ -3,7 +3,7 @@
// CHECK-LABEL: func @basic( // CHECK-LABEL: func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor // CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.tensor %[[ERASED]] : !torch.vtensor -> !torch.tensor // CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.to_tensor %[[ERASED]] : !torch.tensor
// CHECK: return %[[NONVAL_TENSOR]] : !torch.tensor // CHECK: return %[[NONVAL_TENSOR]] : !torch.tensor
func @basic(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor { func @basic(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor {
return %arg0 : !torch.tensor return %arg0 : !torch.tensor
@ -19,9 +19,9 @@ func @no_type_bound(%arg0: !torch.tensor) -> !torch.tensor {
// CHECK-LABEL: func @call( // CHECK-LABEL: func @call(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
// CHECK: %[[ARG_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor // CHECK: %[[ARG_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
// CHECK: %[[ARG_NONVAL:.*]] = torch.copy.tensor %[[ARG_ERASED]] : !torch.vtensor -> !torch.tensor // CHECK: %[[ARG_NONVAL:.*]] = torch.copy.to_tensor %[[ARG_ERASED]] : !torch.tensor
// CHECK: %[[INFO_ADDED:.*]] = torch.tensor_static_info_cast %[[ARG_NONVAL]] : !torch.tensor to !torch.tensor<[2,3,?],f32> // CHECK: %[[INFO_ADDED:.*]] = torch.tensor_static_info_cast %[[ARG_NONVAL]] : !torch.tensor to !torch.tensor<[2,3,?],f32>
// CHECK: %[[CALL_ARG:.*]] = torch.copy.tensor %[[INFO_ADDED]] : !torch.tensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32> // CHECK: %[[CALL_ARG:.*]] = torch.copy.to_vtensor %[[INFO_ADDED]] : !torch.vtensor<[2,3,?],f32>
// CHECK: %[[CALL_RES:.*]] = call @call(%[[CALL_ARG]]) : (!torch.vtensor<[2,3,?],f32>) -> !torch.tensor // CHECK: %[[CALL_RES:.*]] = call @call(%[[CALL_ARG]]) : (!torch.vtensor<[2,3,?],f32>) -> !torch.tensor
// CHECK: return %[[ARG_NONVAL]] : !torch.tensor // CHECK: return %[[ARG_NONVAL]] : !torch.tensor
func @call(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor { func @call(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor {

View File

@ -78,23 +78,6 @@ func @torch.aten.len.t$of_build_list(%arg0: !torch.int) -> !torch.int {
return %1 : !torch.int return %1 : !torch.int
} }
// CHECK-LABEL: func @torch.copy.tensor$value_copy_is_noop(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: return %[[ARG]] : !torch.vtensor
func @torch.copy.tensor$value_copy_is_noop(%arg0: !torch.vtensor) -> !torch.vtensor {
%0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.vtensor
return %0 : !torch.vtensor
}
// CHECK-LABEL: func @torch.copy.tensor$unnecessary_intermediate_nonval_tensor(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: return %[[ARG]] : !torch.vtensor
func @torch.copy.tensor$unnecessary_intermediate_nonval_tensor(%arg0: !torch.vtensor) -> !torch.vtensor {
%0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor
%1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
return %1 : !torch.vtensor
}
// CHECK-LABEL: func @torch.aten.__getitem__.t( // CHECK-LABEL: func @torch.aten.__getitem__.t(
// CHECK: %[[C5:.*]] = torch.constant.int 5 // CHECK: %[[C5:.*]] = torch.constant.int 5
// CHECK: return %[[C5]] : !torch.int // CHECK: return %[[C5]] : !torch.int
@ -179,13 +162,3 @@ func @torch.prim.If$erase_dead_branch(%arg0: !torch.int) -> !torch.int {
} }
return %0 : !torch.int return %0 : !torch.int
} }
// CHECK-LABEL: func @torch.copy.tensor$untouched_nonval(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
// CHECK-NEXT: return %[[ARG]], %[[ARG]] : !torch.vtensor, !torch.vtensor
func @torch.copy.tensor$untouched_nonval(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
%0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor
%1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
%2 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
return %1, %2 : !torch.vtensor, !torch.vtensor
}

View File

@ -1,22 +1,72 @@
// RUN: npcomp-opt -split-input-file %s -torch-maximize-value-semantics | FileCheck %s // RUN: npcomp-opt -split-input-file -allow-unregistered-dialect %s -torch-maximize-value-semantics | FileCheck %s
// Basic case that can be resolved with local reasoning. // CHECK-LABEL: func @torch.copy.tensor$basic(
// This pass will eventually need to learn about aliasing relationships. // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
// // CHECK: return %[[ARG0]], %[[ARG0]] : !torch.vtensor, !torch.vtensor
// This is taken from a test case from an e2e spike, and isn't intended to be func @torch.copy.tensor$basic(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
// particularly minimal or specifically test one thing, since the pass is %0 = torch.copy.to_tensor %arg0 : !torch.tensor
// currently just a handful of canonicalization patterns that are already %1 = torch.copy.to_vtensor %0 : !torch.vtensor
// tested elsewhere. %2 = torch.copy.to_vtensor %0 : !torch.vtensor
return %1, %2 : !torch.vtensor, !torch.vtensor
// CHECK-LABEL: func @local( }
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
// CHECK: %[[RET:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32> // CHECK-LABEL: func @one_mutation_in_a_block(
// CHECK: return %[[RET]] : !torch.vtensor<[2,3,?],f32> // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
func @local(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor<[2,3,?],f32> // CHECK: return %[[ARG0]], %[[ARG1]] : !torch.vtensor, !torch.vtensor
%1 = torch.aten.tanh %0 : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32> func @one_mutation_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
%2 = torch.copy.tensor %1 : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32> %0 = torch.copy.to_tensor %arg0 : !torch.tensor
%3 = torch.tensor_static_info_cast %2 : !torch.tensor<[2,3,?],f32> to !torch.tensor %equal_to_arg0 = torch.copy.to_vtensor %0 : !torch.vtensor
%4 = torch.copy.tensor %2 : !torch.tensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32> torch.overwrite.tensor %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
return %4 : !torch.vtensor<[2,3,?],f32> %equal_to_arg1 = torch.copy.to_vtensor %0 : !torch.vtensor
return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
}
// CHECK-LABEL: func @multiple_mutations_in_a_block(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, %[[ARG1:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor) {
// CHECK: return %[[ARG0]], %[[ARG1]], %[[ARG1]], %[[ARG2]] : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %arg2: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor) {
// The mutable tensor we are overwriting.
%tensor = torch.copy.to_tensor %arg0 : !torch.tensor
// The original value.
%equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
// Overwrite with %arg1
torch.overwrite.tensor %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
%equal_to_arg1 = torch.copy.to_vtensor %tensor : !torch.vtensor
%equal_to_arg1_again = torch.copy.to_vtensor %tensor : !torch.vtensor
// Overwrite with %arg2
torch.overwrite.tensor %arg2 overwrites %tensor : !torch.vtensor, !torch.tensor
%equal_to_arg2 = torch.copy.to_vtensor %tensor : !torch.vtensor
return %equal_to_arg0, %equal_to_arg1, %equal_to_arg1_again, %equal_to_arg2 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
}
// CHECK-LABEL: func @unmodeled_mutation(
// CHECK: torch.overwrite.tensor
func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
torch.overwrite.tensor %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
"some.op"(%0) : (!torch.tensor) -> ()
%result = torch.copy.to_vtensor %0 : !torch.vtensor
return %result : !torch.vtensor
}
// We don't yet handle nontrivial cases involving control flow.
// CHECK-LABEL: func @unimplemented_control_flow(
// CHECK: torch.copy.to_vtensor
func @unimplemented_control_flow(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %cond: !torch.bool) -> (!torch.vtensor, !torch.vtensor) {
%tensor = torch.copy.to_tensor %arg0 : !torch.tensor
%equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
torch.prim.If %cond -> () {
torch.overwrite.tensor %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%equal_to_arg1 = torch.copy.to_vtensor %tensor : !torch.vtensor
return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
} }

View File

@ -2,9 +2,9 @@
// CHECK-LABEL: func @convert_to_value_semantic_tensors( // CHECK-LABEL: func @convert_to_value_semantic_tensors(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
// CHECK: %[[OPERAND_TENSOR:.*]] = torch.copy.tensor %[[ARG]] : !torch.tensor<[],f32> -> !torch.vtensor<[],f32> // CHECK: %[[OPERAND_TENSOR:.*]] = torch.copy.to_vtensor %[[ARG]] : !torch.vtensor<[],f32>
// CHECK: %[[RESULT_TENSOR:.*]] = torch.aten.tanh %[[OPERAND_TENSOR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> // CHECK: %[[RESULT_TENSOR:.*]] = torch.aten.tanh %[[OPERAND_TENSOR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[RET:.*]] = torch.copy.tensor %[[RESULT_TENSOR]] : !torch.vtensor<[],f32> -> !torch.tensor<[],f32> // CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[RESULT_TENSOR]] : !torch.tensor<[],f32>
// CHECK: return %[[RET]] : !torch.tensor<[],f32> // CHECK: return %[[RET]] : !torch.tensor<[],f32>
func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> { func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
%0 = torch.aten.tanh %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32> %0 = torch.aten.tanh %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32>
@ -16,14 +16,14 @@ func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !torch.
// CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[2,2],f32>, // CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[2,2],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) { // CHECK-SAME: %[[ARG1:.*]]: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
// CHECK: %[[C1:.*]] = torch.constant.int 1 // CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[TENSOR0:.*]] = torch.copy.tensor %[[ARG0]] : !torch.tensor<[2,2],f32> -> !torch.vtensor<[2,2],f32> // CHECK: %[[TENSOR0:.*]] = torch.copy.to_vtensor %[[ARG0]] : !torch.vtensor<[2,2],f32>
// CHECK: %[[TENSOR1:.*]] = torch.copy.tensor %[[ARG1]] : !torch.tensor<[2,2],f32> -> !torch.vtensor<[2,2],f32> // CHECK: %[[TENSOR1:.*]] = torch.copy.to_vtensor %[[ARG1]] : !torch.vtensor<[2,2],f32>
// CHECK: %[[TENSOR_RESULT:.*]] = torch.aten.add.Tensor %[[TENSOR0]], %[[TENSOR1]], %[[C1]] : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>, !torch.int -> !torch.vtensor<[2,2],f32> // CHECK: %[[TENSOR_RESULT:.*]] = torch.aten.add.Tensor %[[TENSOR0]], %[[TENSOR1]], %[[C1]] : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>, !torch.int -> !torch.vtensor<[2,2],f32>
// Note: This somewhat redundant conversion back and forth // Note: This somewhat redundant conversion back and forth
// (which is cleaned up by canonicalization) is an artifact of two patterns // (which is cleaned up by canonicalization) is an artifact of two patterns
// being applied in sequence. // being applied in sequence.
// CHECK: %[[ARRAY_RESULT:.*]] = torch.copy.tensor %[[TENSOR_RESULT]] : !torch.vtensor<[2,2],f32> -> !torch.tensor<[2,2],f32> // CHECK: %[[ARRAY_RESULT:.*]] = torch.copy.to_tensor %[[TENSOR_RESULT]] : !torch.tensor<[2,2],f32>
// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.tensor %[[ARRAY_RESULT]] : !torch.tensor<[2,2],f32> -> !torch.vtensor<[2,2],f32> // CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[ARRAY_RESULT]] : !torch.vtensor<[2,2],f32>
// CHECK: torch.overwrite.tensor %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32> // CHECK: torch.overwrite.tensor %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
// CHECK: return %[[ARG0]], %[[ARG0]] : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32> // CHECK: return %[[ARG0]], %[[ARG0]] : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) { func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
@ -35,7 +35,7 @@ func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>
// CHECK-LABEL: func @torch.tensor.literal() -> !torch.tensor { // CHECK-LABEL: func @torch.tensor.literal() -> !torch.tensor {
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<7xf32>) : !torch.vtensor<[7],f32> // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<7xf32>) : !torch.vtensor<[7],f32>
// CHECK: %[[SIZES_ERASED:.*]] = torch.tensor_static_info_cast %[[VTENSOR]] : !torch.vtensor<[7],f32> to !torch.vtensor // CHECK: %[[SIZES_ERASED:.*]] = torch.tensor_static_info_cast %[[VTENSOR]] : !torch.vtensor<[7],f32> to !torch.vtensor
// CHECK: %[[TENSOR:.*]] = torch.copy.tensor %[[SIZES_ERASED]] : !torch.vtensor -> !torch.tensor // CHECK: %[[TENSOR:.*]] = torch.copy.to_tensor %[[SIZES_ERASED]] : !torch.tensor
// CHECK: return %[[TENSOR]] : !torch.tensor // CHECK: return %[[TENSOR]] : !torch.tensor
func @torch.tensor.literal() -> !torch.tensor { func @torch.tensor.literal() -> !torch.tensor {
%0 = torch.tensor.literal(dense<0.0> : tensor<7xf32>) : !torch.tensor %0 = torch.tensor.literal(dense<0.0> : tensor<7xf32>) : !torch.tensor

View File

@ -2,12 +2,11 @@
// CHECK-LABEL: func @basic( // CHECK-LABEL: func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
// CHECK: %[[COPIED_NONVAL:.*]] = torch.copy.tensor %[[ARG]] : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32> // CHECK: %[[COPIED_NONVAL:.*]] = torch.copy.to_tensor %[[ARG]] : !torch.tensor<[2,3,?],f32>
// CHECK: %[[COPIED_VALUE:.*]] = torch.copy.tensor %[[COPIED_NONVAL]] : !torch.tensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32> // CHECK: %[[COPIED_VALUE:.*]] = torch.copy.to_vtensor %[[COPIED_NONVAL]] : !torch.vtensor<[2,3,?],f32>
// CHECK: return %[[COPIED_VALUE]] : !torch.vtensor<[2,3,?],f32> // CHECK: return %[[COPIED_VALUE]] : !torch.vtensor<[2,3,?],f32>
// CHECK: }
func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor { func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
%1 = torch.copy.tensor %arg0 : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32> %1 = torch.copy.to_tensor %arg0 : !torch.tensor<[2,3,?],f32>
%2 = torch.tensor_static_info_cast %1 : !torch.tensor<[2,3,?],f32> to !torch.tensor %2 = torch.tensor_static_info_cast %1 : !torch.tensor<[2,3,?],f32> to !torch.tensor
return %2 : !torch.tensor return %2 : !torch.tensor
} }
@ -15,11 +14,11 @@ func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
// No conversion on private function. // No conversion on private function.
// CHECK-LABEL: func private @basic_private( // CHECK-LABEL: func private @basic_private(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
// CHECK: %[[COPIED:.*]] = torch.copy.tensor %[[ARG]] : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32> // CHECK: %[[COPIED:.*]] = torch.copy.to_tensor %[[ARG]] : !torch.tensor<[2,3,?],f32>
// CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[COPIED]] : !torch.tensor<[2,3,?],f32> to !torch.tensor // CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[COPIED]] : !torch.tensor<[2,3,?],f32> to !torch.tensor
// CHECK: return %[[CASTED]] : !torch.tensor // CHECK: return %[[CASTED]] : !torch.tensor
func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor { func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
%1 = torch.copy.tensor %arg0 : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32> %1 = torch.copy.to_tensor %arg0 : !torch.tensor<[2,3,?],f32>
%2 = torch.tensor_static_info_cast %1 : !torch.tensor<[2,3,?],f32> to !torch.tensor %2 = torch.tensor_static_info_cast %1 : !torch.tensor<[2,3,?],f32> to !torch.tensor
return %2 : !torch.tensor return %2 : !torch.tensor
} }

View File

@ -15,12 +15,12 @@ func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
// CHECK-LABEL: func @f( // CHECK-LABEL: func @f(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
// CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor<[2,3,?],f32> // CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor<[2,3,?],f32>
// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.tensor %[[CASTED]] : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32> // CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.to_tensor %[[CASTED]] : !torch.tensor<[2,3,?],f32>
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[NONVAL_TENSOR]] : !torch.tensor<[2,3,?],f32> to !torch.tensor // CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[NONVAL_TENSOR]] : !torch.tensor<[2,3,?],f32> to !torch.tensor
// CHECK: return %[[ERASED]] : !torch.tensor // CHECK: return %[[ERASED]] : !torch.tensor
func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor { func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
%1 = torch.copy.tensor %0 : !torch.vtensor -> !torch.tensor %1 = torch.copy.to_tensor %0 : !torch.tensor
return %1 : !torch.tensor return %1 : !torch.tensor
} }