mirror of https://github.com/llvm/torch-mlir
Make MaximizeValueSemantics a bit smarter.
This adds a pattern to MaximizeValueSemantics which does a simple abstract interpretation within a block, which handles simple cases of `torch.overwrite.tensor`, enough to remove all the unnecessary uses of non-value tensors in ResNet right now. Before/after IR: [gist](https://gist.github.com/silvasean/a3e1ef625b19dfc63579f73cd3b543b6)
Also,
- Split `torch.copy.tensor` into `torch.copy.to_tensor` and `torch.copy.to_vtensor` which convert between value and non-value semantic tensors. This is a much cleaner factorization as they have very separate use cases and properties (e.g. different side effects)
- Remove the various canonicalization patterns they had, which were confusing because they resulted in limited forms of maximizing value semantics throughout the pipeline. We should structure our compilation pipeline such that only MaximizeValueSemantics should be maximizing value semantics.
- Adjust pass pipeline to only run MaximizeValueSemantics once.
- Make OverwriteTensorOp `$value` always be a value tensor and `$overwritten` be a non-value tensor.
(ref: pull/235/head)
parent
6dddb4d4fe
commit
79aade33da
|
@ -905,11 +905,11 @@ def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [
|
|||
|
||||
This op *cannot* be used to add/remove value semantics from a tensor.
|
||||
For converting between the value-semantic and non-value-semantic domains,
|
||||
use `torch.copy.tensor`. The two ops are kept separate to prevent
|
||||
canonicalizations from accidentally dropping static information. In
|
||||
most cases, after running the `torch-refine-types` pass, this op becomes
|
||||
a no-op (the pass will incorporate the static information into other ops
|
||||
that allow type refinement).
|
||||
use `torch.copy.to_tensor` and `torch.copy.to_vtensor`. This op is kept
|
||||
separate to prevent canonicalizations from accidentally dropping static
|
||||
information. In most cases, after running the `torch-refine-types` pass,
|
||||
this op becomes a no-op (the pass will incorporate the static information
|
||||
into other ops that allow type refinement).
|
||||
}];
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$operand
|
||||
|
@ -922,34 +922,66 @@ def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_CopyTensorOp : Torch_Op<"copy.tensor", [
|
||||
def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
TypesMatchWith<"operand is corresponding !torch.vtensor",
|
||||
"result", "operand",
|
||||
"$_self.cast<NonValueTensorType>().getWithValueSemantics()">,
|
||||
]> {
|
||||
let summary = "Makes a copy of a tensor.";
|
||||
let summary = "Create a !torch.tensor with the same contents as the operand";
|
||||
let description = [{
|
||||
Changes to the original tensor will not be reflected in the copy.
|
||||
This op is used to convert from !torch.vtensor to !torch.tensor.
|
||||
It does so by allocating a new !torch.tensor and filling it with
|
||||
the contents of the operand.
|
||||
|
||||
This op can be used to interconvert between value-semantic and
|
||||
non-value-semantic tensors. However, this op *does not* allow
|
||||
adding/removing static information about sizes/dtype. For that, use
|
||||
`torch.tensor_static_info_cast`.
|
||||
However, this op *does not* allow adding/removing static information about
|
||||
sizes/dtype. For that, use `torch.tensor_static_info_cast`.
|
||||
|
||||
This op does not have the AllowsTypeRefinement trait because the operand
|
||||
and result types are coupled. Only places that know how to simultaneously
|
||||
update both types should be changing the type of this op.
|
||||
}];
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$operand
|
||||
Torch_ValueTensorType:$operand
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
Torch_NonValueTensorType:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `->` type($result)
|
||||
$operand attr-dict `:` type($result)
|
||||
}];
|
||||
let verifier = "return ::verify(*this);";
|
||||
}
|
||||
|
||||
def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
TypesMatchWith<"operand is corresponding !torch.tensor",
|
||||
"result", "operand",
|
||||
"$_self.cast<ValueTensorType>().getWithoutValueSemantics()">,
|
||||
]> {
|
||||
let summary = "Create a !torch.vtensor with the same contents as the operand";
|
||||
let description = [{
|
||||
This op is used to convert from !torch.tensor to !torch.vtensor.
|
||||
|
||||
However, this op *does not* allow adding/removing static information about
|
||||
sizes/dtype. For that, use `torch.tensor_static_info_cast`.
|
||||
|
||||
This op does not have the AllowsTypeRefinement trait because the operand
|
||||
and result types are coupled. Only places that know how to simultaneously
|
||||
update both types should be changing the type of this op.
|
||||
}];
|
||||
let arguments = (ins
|
||||
Torch_NonValueTensorType:$operand
|
||||
);
|
||||
let results = (outs
|
||||
Torch_ValueTensorType:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($result)
|
||||
}];
|
||||
let verifier = "return ::verify(*this);";
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
|
||||
|
@ -961,7 +993,7 @@ def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
|
|||
`value`.
|
||||
|
||||
Immediately after this op has completed, indexing `overwritten` will result
|
||||
in identical values as indexing into `tensor`. Of course, later ops
|
||||
in identical values as indexing into `value`. Of course, later ops
|
||||
might mutate `overwritten`, so this relationship need not hold for the
|
||||
entire program.
|
||||
|
||||
|
@ -969,8 +1001,8 @@ def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
|
|||
shapes or dtypes.
|
||||
}];
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$value,
|
||||
AnyTorchTensorType:$overwritten
|
||||
Torch_ValueTensorType:$value,
|
||||
Torch_NonValueTensorType:$overwritten
|
||||
);
|
||||
let results = (outs
|
||||
);
|
||||
|
|
|
@ -33,11 +33,17 @@ Value mlir::NPCOMP::Torch::copyTensorToType(OpBuilder &builder, Location loc,
|
|||
tensor = builder.create<TensorStaticInfoCastOp>(
|
||||
loc, originalType.getWithSizesAndDtypeFrom(newType), tensor);
|
||||
}
|
||||
// If both the original and new types already have value semantics, a copy is
|
||||
// pointless.
|
||||
if (originalType.isa<ValueTensorType>() && newType.isa<ValueTensorType>())
|
||||
return tensor;
|
||||
return builder.create<CopyTensorOp>(loc, newType, tensor);
|
||||
|
||||
// Unless the original and new types are both value tensors, we end
|
||||
// up creating one op that converts between the value and non-value tensor
|
||||
// domains. If the original and new types are both non-value tensors,
|
||||
// then we do the copy by going to a value tensor and back.
|
||||
if (tensor.getType().isa<NonValueTensorType>())
|
||||
tensor = builder.create<CopyToValueTensorOp>(loc, tensor);
|
||||
if (newType.isa<NonValueTensorType>())
|
||||
tensor = builder.create<CopyToNonValueTensorOp>(loc, tensor);
|
||||
|
||||
return tensor;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -504,10 +510,10 @@ bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs,
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CopyTensorOp
|
||||
// CopyToNonValueTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(CopyTensorOp op) {
|
||||
static LogicalResult verify(CopyToNonValueTensorOp op) {
|
||||
auto resultType = op.getResult().getType().cast<BaseTensorType>();
|
||||
auto operandType = op.getOperand().getType().cast<BaseTensorType>();
|
||||
if (!resultType.hasSameSizesAndDtype(operandType)) {
|
||||
|
@ -517,50 +523,48 @@ static LogicalResult verify(CopyTensorOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult CopyTensorOp::fold(ArrayRef<Attribute> operands) {
|
||||
// A copy between value semantic tensors is a no-op.
|
||||
if (getType().isa<ValueTensorType>() &&
|
||||
getOperand().getType().isa<ValueTensorType>()) {
|
||||
return getOperand();
|
||||
}
|
||||
return nullptr;
|
||||
LogicalResult CopyToNonValueTensorOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
auto resultType = operands[0].getType().cast<ValueTensorType>();
|
||||
inferredReturnTypes.push_back(resultType.getWithoutValueSemantics());
|
||||
return success();
|
||||
}
|
||||
|
||||
void CopyTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
// y = torch.copy.tensor(torch.copy.tensor(x)) -> x
|
||||
// Only safe when `y` and `x` have value semantics, and
|
||||
// all users of the intermediate tensor op treat the tensor as if it
|
||||
// had value semantics (even if it is a NonValueTensorType).
|
||||
patterns.add(+[](CopyTensorOp op, PatternRewriter &rewriter) {
|
||||
auto otherCopy = op.getOperand().getDefiningOp<CopyTensorOp>();
|
||||
if (!otherCopy)
|
||||
return failure();
|
||||
if (!otherCopy.getOperand().getType().isa<ValueTensorType>() ||
|
||||
!op.getResult().getType().isa<ValueTensorType>())
|
||||
return failure();
|
||||
// TODO: Use a proper interface here.
|
||||
// MemoryEffectOpInterface is not powerful enough because it cannot model
|
||||
// aliasing. We don't just care that the user is readonly -- we care also
|
||||
// whether it creates an alias. Basically, we care if the user "treats the
|
||||
// tensor as if it has value semantics".
|
||||
// For now, just hardcode the important case of multiple CopyTensorOp users.
|
||||
if (llvm::all_of(op.getOperand().getUsers(),
|
||||
[](Operation *op) { return isa<CopyTensorOp>(op); })) {
|
||||
rewriter.replaceOp(op, {otherCopy.getOperand()});
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
});
|
||||
}
|
||||
|
||||
void CopyTensorOp::getEffects(
|
||||
SmallVectorImpl<SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>>
|
||||
void CopyToNonValueTensorOp::getEffects(
|
||||
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
||||
&effects) {
|
||||
if (getResult().getType().isa<NonValueTensorType>())
|
||||
effects.emplace_back(MemoryEffects::Allocate::get(), getResult());
|
||||
if (getOperand().getType().isa<NonValueTensorType>())
|
||||
effects.emplace_back(MemoryEffects::Read::get(), getOperand());
|
||||
effects.emplace_back(MemoryEffects::Allocate::get(), getResult());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CopyToValueTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(CopyToValueTensorOp op) {
|
||||
auto resultType = op.getResult().getType().cast<BaseTensorType>();
|
||||
auto operandType = op.getOperand().getType().cast<BaseTensorType>();
|
||||
if (!resultType.hasSameSizesAndDtype(operandType)) {
|
||||
return op.emitError()
|
||||
<< "operand and result must have same sizes and dtype";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult CopyToValueTensorOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
auto resultType = operands[0].getType().cast<NonValueTensorType>();
|
||||
inferredReturnTypes.push_back(resultType.getWithValueSemantics());
|
||||
return success();
|
||||
}
|
||||
|
||||
void CopyToValueTensorOp::getEffects(
|
||||
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
||||
&effects) {
|
||||
effects.emplace_back(MemoryEffects::Read::get(), getOperand());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -216,7 +216,7 @@ static LogicalResult adjustCallingConventions(FuncOp func,
|
|||
target.addDynamicallyLegalOp<ReturnOp>([&](ReturnOp op) {
|
||||
return !opsInOriginalProgram.contains(op.getOperation());
|
||||
});
|
||||
target.addLegalOp<CopyTensorOp>();
|
||||
target.addLegalOp<CopyToNonValueTensorOp, CopyToValueTensorOp>();
|
||||
target.addLegalOp<TensorStaticInfoCastOp>();
|
||||
target.addLegalOp<ConstantNoneOp>();
|
||||
// We don't know how to rewrite it, so mark it as illegal.
|
||||
|
|
|
@ -19,6 +19,65 @@ using namespace mlir;
|
|||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
|
||||
class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
|
||||
: public OpRewritePattern<CopyToNonValueTensorOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<Operation *> users;
|
||||
// See if our limited form of analysis is even applicable.
|
||||
for (Operation *user : copy.getResult().getUsers()) {
|
||||
// We can only analyze within a single basic block.
|
||||
if (user->getBlock() != copy->getBlock())
|
||||
return failure();
|
||||
// We can only analyze these ops.
|
||||
if (!isa<CopyToValueTensorOp, OverwriteTensorOp>(user))
|
||||
return failure();
|
||||
users.push_back(user);
|
||||
}
|
||||
// Sort by order in the block, so we can abstractly interpret the ops.
|
||||
llvm::sort(users, [](Operation *lhs, Operation *rhs) {
|
||||
return lhs->isBeforeInBlock(rhs);
|
||||
});
|
||||
// Do an abstract interpretation within the block.
|
||||
// We track the current value tensor that holds the same contents as the
|
||||
// non-value tensor at each program point as we walk forward.
|
||||
Value currentlyHeldValueTensor = copy.getOperand();
|
||||
for (Operation *user : users) {
|
||||
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
|
||||
rewriter.replaceOp(copyToValueTensor, {currentlyHeldValueTensor});
|
||||
} else if (auto overwriteTensor = dyn_cast<OverwriteTensorOp>(user)) {
|
||||
currentlyHeldValueTensor = overwriteTensor.value();
|
||||
rewriter.eraseOp(overwriteTensor);
|
||||
} else {
|
||||
llvm_unreachable("only those ops supported!");
|
||||
}
|
||||
}
|
||||
rewriter.eraseOp(copy);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class RewriteNonValueTensorNeverMutatedOrAliased
|
||||
: public OpRewritePattern<CopyToNonValueTensorOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<Operation *> users;
|
||||
// See if our limited form of analysis is even applicable.
|
||||
for (Operation *user : copy.getResult().getUsers()) {
|
||||
if (!isa<CopyToValueTensorOp>(user))
|
||||
return failure();
|
||||
users.push_back(user);
|
||||
}
|
||||
for (Operation *user : users)
|
||||
rewriter.replaceOp(user, copy.getOperand());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
class MaximizeValueSemanticsPass
|
||||
|
@ -28,8 +87,8 @@ class MaximizeValueSemanticsPass
|
|||
auto func = getOperation();
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
CopyTensorOp::getCanonicalizationPatterns(patterns, context);
|
||||
TensorStaticInfoCastOp::getCanonicalizationPatterns(patterns, context);
|
||||
patterns.insert<AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock,
|
||||
RewriteNonValueTensorNeverMutatedOrAliased>(context);
|
||||
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
|
|
@ -120,14 +120,12 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
|
|||
// Lowering to ranked !torch.vtensors of known dtype.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// Convert the bulk of non-ABI-visible arrays to tensors.
|
||||
pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());
|
||||
// Do shape and dtype refinement.
|
||||
pm.addNestedPass<FuncOp>(Torch::createRefineTypesPass());
|
||||
// Propagate to ABI return types the shape/dtype information discovered by
|
||||
// the previous pass. Doing this is ABI-compatible for our backends.
|
||||
pm.addPass(Torch::createRefinePublicReturnPass());
|
||||
// Clean up a few stray conversion remnants.
|
||||
// Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's.
|
||||
pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());
|
||||
|
||||
if (options.optimize) {
|
||||
|
|
|
@ -37,8 +37,8 @@ public:
|
|||
opOperand.get().getType().dyn_cast<NonValueTensorType>();
|
||||
if (!tensorType)
|
||||
continue;
|
||||
opOperand.set(rewriter.create<CopyTensorOp>(
|
||||
op->getLoc(), tensorType.getWithValueSemantics(), opOperand.get()));
|
||||
opOperand.set(rewriter.create<CopyToValueTensorOp>(op->getLoc(),
|
||||
opOperand.get()));
|
||||
}
|
||||
// Convert all results.
|
||||
rewriter.setInsertionPointAfter(op);
|
||||
|
@ -46,10 +46,10 @@ public:
|
|||
auto tensorType = result.getType().dyn_cast<NonValueTensorType>();
|
||||
if (!tensorType)
|
||||
continue;
|
||||
auto createArray = rewriter.create<CopyTensorOp>(
|
||||
op->getLoc(), result.getType(), result);
|
||||
result.replaceAllUsesExcept(createArray, createArray);
|
||||
result.setType(tensorType.getWithValueSemantics());
|
||||
auto nonValueTensor =
|
||||
rewriter.create<CopyToNonValueTensorOp>(op->getLoc(), result);
|
||||
result.replaceAllUsesExcept(nonValueTensor, nonValueTensor);
|
||||
}
|
||||
});
|
||||
return success();
|
||||
|
@ -85,12 +85,8 @@ public:
|
|||
"Torch JIT operators shouldn't have regions or successors");
|
||||
|
||||
Operation *newOp = rewriter.createOperation(state);
|
||||
auto tensor = rewriter.create<CopyTensorOp>(op->getLoc(),
|
||||
newOp->getResult(0)
|
||||
.getType()
|
||||
.cast<NonValueTensorType>()
|
||||
.getWithValueSemantics(),
|
||||
newOp->getResult(0));
|
||||
auto tensor =
|
||||
rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
|
||||
rewriter.create<OverwriteTensorOp>(op->getLoc(), tensor, op->getOperand(0));
|
||||
rewriter.replaceOp(op, op->getOperand(0));
|
||||
|
||||
|
|
|
@ -141,8 +141,8 @@ public:
|
|||
ChangeResult
|
||||
visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
|
||||
if (isa<TensorStaticInfoCastOp, CopyTensorOp, AtenTanhOp, AtenBatchNormOp,
|
||||
AtenReluOp>(op)) {
|
||||
if (isa<TensorStaticInfoCastOp, CopyToValueTensorOp, CopyToNonValueTensorOp,
|
||||
AtenTanhOp, AtenBatchNormOp, AtenReluOp>(op)) {
|
||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||
}
|
||||
if (isa<AtenMmOp>(op)) {
|
||||
|
@ -303,12 +303,13 @@ static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
|
|||
// which allows arbitrary refinements. But some other cases are safe too,
|
||||
// such as when an op has two types that are coupled, but we know that our
|
||||
// analysis and updating logic will correctly maintain the invariants of the op.
|
||||
// The `torch.copy.tensor` is an example of the latter case, since its
|
||||
// operand and result types must have the same shape and dtype -- we know
|
||||
// that our transfer functions and updating logic will do the right thing
|
||||
// for that op.
|
||||
// The `torch.copy.to_tensor` / `torch.copy.to_vtensor` are examples of the
|
||||
// latter case, since their operand and result types must have the same shape
|
||||
// and dtype -- we know that our transfer functions and updating logic will do
|
||||
// the right thing for those ops.
|
||||
static bool allowsTypeRefinementOrWillBeOtherwiseSafelyRefined(Operation *op) {
|
||||
return allowsTypeRefinement(op) || isa<CopyTensorOp>(op);
|
||||
return allowsTypeRefinement(op) ||
|
||||
isa<CopyToNonValueTensorOp, CopyToValueTensorOp>(op);
|
||||
}
|
||||
|
||||
void optimize(FuncOp func, TypeAnalyzer &analyzer) {
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
// CHECK-LABEL: func @basic(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
|
||||
// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.tensor %[[ERASED]] : !torch.vtensor -> !torch.tensor
|
||||
// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.to_tensor %[[ERASED]] : !torch.tensor
|
||||
// CHECK: return %[[NONVAL_TENSOR]] : !torch.tensor
|
||||
func @basic(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor {
|
||||
return %arg0 : !torch.tensor
|
||||
|
@ -19,9 +19,9 @@ func @no_type_bound(%arg0: !torch.tensor) -> !torch.tensor {
|
|||
// CHECK-LABEL: func @call(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[ARG_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
|
||||
// CHECK: %[[ARG_NONVAL:.*]] = torch.copy.tensor %[[ARG_ERASED]] : !torch.vtensor -> !torch.tensor
|
||||
// CHECK: %[[ARG_NONVAL:.*]] = torch.copy.to_tensor %[[ARG_ERASED]] : !torch.tensor
|
||||
// CHECK: %[[INFO_ADDED:.*]] = torch.tensor_static_info_cast %[[ARG_NONVAL]] : !torch.tensor to !torch.tensor<[2,3,?],f32>
|
||||
// CHECK: %[[CALL_ARG:.*]] = torch.copy.tensor %[[INFO_ADDED]] : !torch.tensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
|
||||
// CHECK: %[[CALL_ARG:.*]] = torch.copy.to_vtensor %[[INFO_ADDED]] : !torch.vtensor<[2,3,?],f32>
|
||||
// CHECK: %[[CALL_RES:.*]] = call @call(%[[CALL_ARG]]) : (!torch.vtensor<[2,3,?],f32>) -> !torch.tensor
|
||||
// CHECK: return %[[ARG_NONVAL]] : !torch.tensor
|
||||
func @call(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor {
|
||||
|
|
|
@ -78,23 +78,6 @@ func @torch.aten.len.t$of_build_list(%arg0: !torch.int) -> !torch.int {
|
|||
return %1 : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.copy.tensor$value_copy_is_noop(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
|
||||
// CHECK: return %[[ARG]] : !torch.vtensor
|
||||
func @torch.copy.tensor$value_copy_is_noop(%arg0: !torch.vtensor) -> !torch.vtensor {
|
||||
%0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.vtensor
|
||||
return %0 : !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.copy.tensor$unnecessary_intermediate_nonval_tensor(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
|
||||
// CHECK: return %[[ARG]] : !torch.vtensor
|
||||
func @torch.copy.tensor$unnecessary_intermediate_nonval_tensor(%arg0: !torch.vtensor) -> !torch.vtensor {
|
||||
%0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor
|
||||
%1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.__getitem__.t(
|
||||
// CHECK: %[[C5:.*]] = torch.constant.int 5
|
||||
// CHECK: return %[[C5]] : !torch.int
|
||||
|
@ -179,13 +162,3 @@ func @torch.prim.If$erase_dead_branch(%arg0: !torch.int) -> !torch.int {
|
|||
}
|
||||
return %0 : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.copy.tensor$untouched_nonval(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
|
||||
// CHECK-NEXT: return %[[ARG]], %[[ARG]] : !torch.vtensor, !torch.vtensor
|
||||
func @torch.copy.tensor$untouched_nonval(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
|
||||
%0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor
|
||||
%1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
|
||||
%2 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
|
||||
return %1, %2 : !torch.vtensor, !torch.vtensor
|
||||
}
|
||||
|
|
|
@ -1,22 +1,72 @@
|
|||
// RUN: npcomp-opt -split-input-file %s -torch-maximize-value-semantics | FileCheck %s
|
||||
// RUN: npcomp-opt -split-input-file -allow-unregistered-dialect %s -torch-maximize-value-semantics | FileCheck %s
|
||||
|
||||
// Basic case that can be resolved with local reasoning.
|
||||
// This pass will eventually need to learn about aliasing relationships.
|
||||
//
|
||||
// This is taken from a test case from an e2e spike, and isn't intended to be
|
||||
// particularly minimal or specifically test one thing, since the pass is
|
||||
// currently just a handful of canonicalization patterns that are already
|
||||
// tested elsewhere.
|
||||
|
||||
// CHECK-LABEL: func @local(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
|
||||
// CHECK: %[[RET:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
|
||||
// CHECK: return %[[RET]] : !torch.vtensor<[2,3,?],f32>
|
||||
func @local(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
|
||||
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor<[2,3,?],f32>
|
||||
%1 = torch.aten.tanh %0 : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
|
||||
%2 = torch.copy.tensor %1 : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
|
||||
%3 = torch.tensor_static_info_cast %2 : !torch.tensor<[2,3,?],f32> to !torch.tensor
|
||||
%4 = torch.copy.tensor %2 : !torch.tensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
|
||||
return %4 : !torch.vtensor<[2,3,?],f32>
|
||||
// CHECK-LABEL: func @torch.copy.tensor$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
|
||||
// CHECK: return %[[ARG0]], %[[ARG0]] : !torch.vtensor, !torch.vtensor
|
||||
func @torch.copy.tensor$basic(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
|
||||
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
|
||||
%1 = torch.copy.to_vtensor %0 : !torch.vtensor
|
||||
%2 = torch.copy.to_vtensor %0 : !torch.vtensor
|
||||
return %1, %2 : !torch.vtensor, !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @one_mutation_in_a_block(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
|
||||
// CHECK: return %[[ARG0]], %[[ARG1]] : !torch.vtensor, !torch.vtensor
|
||||
func @one_mutation_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
|
||||
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
|
||||
%equal_to_arg0 = torch.copy.to_vtensor %0 : !torch.vtensor
|
||||
torch.overwrite.tensor %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
|
||||
%equal_to_arg1 = torch.copy.to_vtensor %0 : !torch.vtensor
|
||||
return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @multiple_mutations_in_a_block(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, %[[ARG1:.*]]: !torch.vtensor,
|
||||
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor) {
|
||||
// CHECK: return %[[ARG0]], %[[ARG1]], %[[ARG1]], %[[ARG2]] : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
|
||||
func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %arg2: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor) {
|
||||
// The mutable tensor we are overwriting.
|
||||
%tensor = torch.copy.to_tensor %arg0 : !torch.tensor
|
||||
|
||||
// The original value.
|
||||
%equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
|
||||
|
||||
// Overwrite with %arg1
|
||||
torch.overwrite.tensor %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
|
||||
%equal_to_arg1 = torch.copy.to_vtensor %tensor : !torch.vtensor
|
||||
%equal_to_arg1_again = torch.copy.to_vtensor %tensor : !torch.vtensor
|
||||
|
||||
// Overwrite with %arg2
|
||||
torch.overwrite.tensor %arg2 overwrites %tensor : !torch.vtensor, !torch.tensor
|
||||
%equal_to_arg2 = torch.copy.to_vtensor %tensor : !torch.vtensor
|
||||
|
||||
return %equal_to_arg0, %equal_to_arg1, %equal_to_arg1_again, %equal_to_arg2 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @unmodeled_mutation(
|
||||
// CHECK: torch.overwrite.tensor
|
||||
func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
|
||||
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
|
||||
torch.overwrite.tensor %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
|
||||
"some.op"(%0) : (!torch.tensor) -> ()
|
||||
%result = torch.copy.to_vtensor %0 : !torch.vtensor
|
||||
return %result : !torch.vtensor
|
||||
}
|
||||
|
||||
// We don't yet handle nontrivial cases involving control flow.
|
||||
// CHECK-LABEL: func @unimplemented_control_flow(
|
||||
// CHECK: torch.copy.to_vtensor
|
||||
func @unimplemented_control_flow(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %cond: !torch.bool) -> (!torch.vtensor, !torch.vtensor) {
|
||||
%tensor = torch.copy.to_tensor %arg0 : !torch.tensor
|
||||
%equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
|
||||
torch.prim.If %cond -> () {
|
||||
torch.overwrite.tensor %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
|
||||
torch.prim.If.yield
|
||||
} else {
|
||||
torch.prim.If.yield
|
||||
}
|
||||
%equal_to_arg1 = torch.copy.to_vtensor %tensor : !torch.vtensor
|
||||
return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
|
||||
}
|
||||
|
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
// CHECK-LABEL: func @convert_to_value_semantic_tensors(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
||||
// CHECK: %[[OPERAND_TENSOR:.*]] = torch.copy.tensor %[[ARG]] : !torch.tensor<[],f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[OPERAND_TENSOR:.*]] = torch.copy.to_vtensor %[[ARG]] : !torch.vtensor<[],f32>
|
||||
// CHECK: %[[RESULT_TENSOR:.*]] = torch.aten.tanh %[[OPERAND_TENSOR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[RET:.*]] = torch.copy.tensor %[[RESULT_TENSOR]] : !torch.vtensor<[],f32> -> !torch.tensor<[],f32>
|
||||
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[RESULT_TENSOR]] : !torch.tensor<[],f32>
|
||||
// CHECK: return %[[RET]] : !torch.tensor<[],f32>
|
||||
func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
||||
%0 = torch.aten.tanh %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32>
|
||||
|
@ -16,14 +16,14 @@ func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !torch.
|
|||
// CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[2,2],f32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
|
||||
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[TENSOR0:.*]] = torch.copy.tensor %[[ARG0]] : !torch.tensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
|
||||
// CHECK: %[[TENSOR1:.*]] = torch.copy.tensor %[[ARG1]] : !torch.tensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
|
||||
// CHECK: %[[TENSOR0:.*]] = torch.copy.to_vtensor %[[ARG0]] : !torch.vtensor<[2,2],f32>
|
||||
// CHECK: %[[TENSOR1:.*]] = torch.copy.to_vtensor %[[ARG1]] : !torch.vtensor<[2,2],f32>
|
||||
// CHECK: %[[TENSOR_RESULT:.*]] = torch.aten.add.Tensor %[[TENSOR0]], %[[TENSOR1]], %[[C1]] : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>, !torch.int -> !torch.vtensor<[2,2],f32>
|
||||
// Note: This somewhat redundant conversion back and forth
|
||||
// (which is cleaned up by the MaximizeValueSemantics pass) is an artifact of two patterns
|
||||
// being applied in sequence.
|
||||
// CHECK: %[[ARRAY_RESULT:.*]] = torch.copy.tensor %[[TENSOR_RESULT]] : !torch.vtensor<[2,2],f32> -> !torch.tensor<[2,2],f32>
|
||||
// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.tensor %[[ARRAY_RESULT]] : !torch.tensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
|
||||
// CHECK: %[[ARRAY_RESULT:.*]] = torch.copy.to_tensor %[[TENSOR_RESULT]] : !torch.tensor<[2,2],f32>
|
||||
// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[ARRAY_RESULT]] : !torch.vtensor<[2,2],f32>
|
||||
// CHECK: torch.overwrite.tensor %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
|
||||
// CHECK: return %[[ARG0]], %[[ARG0]] : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
|
||||
func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
|
||||
|
@ -35,7 +35,7 @@ func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>
|
|||
// CHECK-LABEL: func @torch.tensor.literal() -> !torch.tensor {
|
||||
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<7xf32>) : !torch.vtensor<[7],f32>
|
||||
// CHECK: %[[SIZES_ERASED:.*]] = torch.tensor_static_info_cast %[[VTENSOR]] : !torch.vtensor<[7],f32> to !torch.vtensor
|
||||
// CHECK: %[[TENSOR:.*]] = torch.copy.tensor %[[SIZES_ERASED]] : !torch.vtensor -> !torch.tensor
|
||||
// CHECK: %[[TENSOR:.*]] = torch.copy.to_tensor %[[SIZES_ERASED]] : !torch.tensor
|
||||
// CHECK: return %[[TENSOR]] : !torch.tensor
|
||||
func @torch.tensor.literal() -> !torch.tensor {
|
||||
%0 = torch.tensor.literal(dense<0.0> : tensor<7xf32>) : !torch.tensor
|
||||
|
|
|
@ -2,12 +2,11 @@
|
|||
|
||||
// CHECK-LABEL: func @basic(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
|
||||
// CHECK: %[[COPIED_NONVAL:.*]] = torch.copy.tensor %[[ARG]] : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
|
||||
// CHECK: %[[COPIED_VALUE:.*]] = torch.copy.tensor %[[COPIED_NONVAL]] : !torch.tensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
|
||||
// CHECK: %[[COPIED_NONVAL:.*]] = torch.copy.to_tensor %[[ARG]] : !torch.tensor<[2,3,?],f32>
|
||||
// CHECK: %[[COPIED_VALUE:.*]] = torch.copy.to_vtensor %[[COPIED_NONVAL]] : !torch.vtensor<[2,3,?],f32>
|
||||
// CHECK: return %[[COPIED_VALUE]] : !torch.vtensor<[2,3,?],f32>
|
||||
// CHECK: }
|
||||
func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
||||
%1 = torch.copy.tensor %arg0 : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
|
||||
%1 = torch.copy.to_tensor %arg0 : !torch.tensor<[2,3,?],f32>
|
||||
%2 = torch.tensor_static_info_cast %1 : !torch.tensor<[2,3,?],f32> to !torch.tensor
|
||||
return %2 : !torch.tensor
|
||||
}
|
||||
|
@ -15,11 +14,11 @@ func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
|||
// No conversion on private function.
|
||||
// CHECK-LABEL: func private @basic_private(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[COPIED:.*]] = torch.copy.tensor %[[ARG]] : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
|
||||
// CHECK: %[[COPIED:.*]] = torch.copy.to_tensor %[[ARG]] : !torch.tensor<[2,3,?],f32>
|
||||
// CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[COPIED]] : !torch.tensor<[2,3,?],f32> to !torch.tensor
|
||||
// CHECK: return %[[CASTED]] : !torch.tensor
|
||||
func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
||||
%1 = torch.copy.tensor %arg0 : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
|
||||
%1 = torch.copy.to_tensor %arg0 : !torch.tensor<[2,3,?],f32>
|
||||
%2 = torch.tensor_static_info_cast %1 : !torch.tensor<[2,3,?],f32> to !torch.tensor
|
||||
return %2 : !torch.tensor
|
||||
}
|
||||
|
|
|
@ -15,12 +15,12 @@ func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
|
|||
// CHECK-LABEL: func @f(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor<[2,3,?],f32>
|
||||
// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.tensor %[[CASTED]] : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
|
||||
// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.to_tensor %[[CASTED]] : !torch.tensor<[2,3,?],f32>
|
||||
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[NONVAL_TENSOR]] : !torch.tensor<[2,3,?],f32> to !torch.tensor
|
||||
// CHECK: return %[[ERASED]] : !torch.tensor
|
||||
func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
||||
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
|
||||
%1 = torch.copy.tensor %0 : !torch.vtensor -> !torch.tensor
|
||||
%1 = torch.copy.to_tensor %0 : !torch.tensor
|
||||
return %1 : !torch.tensor
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue