Replace RefineTypes with dtype functions (#2105)

This commit adds dtype functions for all the torch ops that did not
previously have one and removes the pass `RefineTypes`, since the
abstract interpretation library now takes care of all the dtype
propagation.

All dtype functions added are tested except for
- `aten.embedding`
- `aten._embedding_bag`
- `aten.embedding_bag`

These functions need a change to the testing framework to allow
specifying the actual data inside the tensor used for testing. I will
fix this in a follow up patch.

Co-authored-by: Jiahao Li <liplus17@163.com>
pull/2124/head
Ramiro Leal-Cavazos 2023-05-12 13:40:45 -07:00 committed by GitHub
parent 28bb866260
commit de02b56e17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 5104 additions and 3084 deletions

View File

@ -242,8 +242,7 @@ The `torchscript-module-to-torch-backend-pipeline` contains the set of simplific
1. LowerToBackendContract: This pass iteratively applies a simplification
pipeline until the backend contract is reached. The simplification pipeline consists of:
- Standard canonicalization.
- Shape refinement. See [shape_lib.md](https://github.com/llvm/torch-mlir/blob/main/docs/shape_lib.md) for detail
- DType refinement. See `RefineTypes`.
- Shape and Dtype refinement. See [abstract_interp_lib.md](https://github.com/llvm/torch-mlir/blob/main/docs/abstract_interp_lib.md) for detail
- Decomposing ops into more primitive ops. See `DecomposeComplexOps`.
### Layering of the PyTorch Dependency
@ -414,8 +413,6 @@ DON'T use a unit test if your lowering pattern could be described as a trivial
like your unit test is just rewriting `b.create<...>(...)` into `CHECK: ...`
then it is probably not a useful unit test.
DON'T add a unit test for trivial changes to RefineTypes.
With the exceptions above, all changes should include appropriate unit tests, as
is standard in the LLVM and MLIR community. This includes full coverage of all
canonicalizations, pretty printing, passes, errors, and diagnostics.

View File

@ -92,8 +92,6 @@ void createTorchDtypeRefinementPipeline(
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();
std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass();
std::unique_ptr<OperationPass<func::FuncOp>>

View File

@ -126,15 +126,6 @@ def AdjustCallingConventions
}];
}
def RefineTypes : Pass<"torch-refine-types", "func::FuncOp"> {
let summary = "Refine types";
let constructor = "mlir::torch::Torch::createRefineTypesPass()";
let description = [{
Refines types of the program. Currently, this means shapes and dtypes of
tensors/arrays.
}];
}
def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
let summary = "Inlines torch.global_slot ops.";
let constructor = "mlir::torch::Torch::createInlineGlobalSlotsPass()";

File diff suppressed because it is too large Load Diff

View File

@ -12,7 +12,6 @@ add_mlir_library(TorchMLIRTorchPasses
RecomposeComplexOps.cpp
ReduceOpVariants.cpp
RefinePublicReturn.cpp
RefineTypes.cpp
ReifyShapeCalculations.cpp
ReifyDtypeCalculations.cpp
ReifyAbstractInterpCalculationsUtils.cpp

View File

@ -103,7 +103,8 @@ static LogicalResult checkType(Operation *op, Type type,
->emitError(
"unsupported by backend contract: tensor with unknown dtype")
.attachNote()
.append("this is likely due to a missing case in RefineTypes");
.append("this is likely due to a missing transfer function in "
"abstract_interp_lib_gen.py");
} else {
return failure();
}

View File

@ -119,20 +119,14 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
// Update the return op to return value tensors.
pm.addPass(Torch::createRefinePublicReturnPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Do shape refinement.
// This should be run before RefineTypes (which primarily does dtype
// inference), because Torch type promotion rules actually depend on the shape
// of the operand.
// Do shape and dtype refinement.
// Shape refinement should be run before dtype refinement because Torch type
// promotion rules actually depend on the shape of the operand.
createTorchShapeRefinementPipeline(pm, options);
createTorchDtypeRefinementPipeline(pm, options);
// Refine types in the program, which mainly means inferring dtypes of ops.
pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
// Propagate to ABI return types the shape/dtype information discovered by
// the previous pass. Doing this is ABI-compatible for our backends.
pm.addPass(Torch::createRefinePublicReturnPass());
// This can fold away some branches given the information got from
// RefineTypes before doing maximize value sematics which only works with
// basic blocks.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
if (options.decompose) {
pm.addNestedPass<func::FuncOp>(

File diff suppressed because it is too large Load Diff

View File

@ -176,20 +176,23 @@ FailureOr<Value> Torch::adjustFunctionArg(
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}
// !torch.union<int, float> is the type used for `Scalar` inputs. At
// compile time, such inputs will usually be resolved to an `int` or a `float`
// so we need to derefine to match the library function signature.
// !torch.union<int, float> or !torch.union<int, float, none> is the type used
// for (optional) `Scalar` inputs. At compile time, such inputs will usually
// be resolved to an `int` or a `float` so we need to derefine to match the
// library function signature.
if (auto unionType = desiredType.dyn_cast<Torch::UnionType>()) {
if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
return containedType.isa<Torch::IntType, Torch::FloatType>();
return containedType
.isa<Torch::IntType, Torch::FloatType, Torch::NoneType>();
}))
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}
// If the operand is NoneType, then we just need to derefine it to the
// optional type in the function signature.
// Operands with type `!torch.none` correspond to library function inputs with
// types like `!torch.optional<...>` or `!torch.union<..., none>`, so here the
// type is derefined to match the expected type of the library function.
if (operandType.isa<Torch::NoneType>()) {
assert(desiredType.isa<Torch::OptionalType>() &&
assert(!desiredType.isa<Torch::NoneType>() &&
"Don't expect library functions to have NoneType parameters");
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}

View File

@ -8,11 +8,248 @@
//===----------------------------------------------------------------------===//
#include "SimplifyAbstractInterpCalculationsUtils.h"
#include "mlir/IR/IRMapping.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
class FoldPrimUncheckedCastOp : public OpRewritePattern<PrimUncheckedCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimUncheckedCastOp op,
PatternRewriter &rewriter) const override {
if (!isValidSubtype(op.getX().getType(), op.getResult().getType())) {
return rewriter.notifyMatchFailure(
op, "input tensor type is not a valid subtype of result type");
}
rewriter.replaceOp(op, op.getX());
return success();
}
};
} // namespace
namespace {
// TODO: Only unroll inside the shape calculation region.
// Maybe do this by only applying patterns and folding greedily on the ops
// inside the region + the shape.calculate op itself?
class FullyUnrollPrimLoopOp : public OpRewritePattern<PrimLoopOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimLoopOp op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
if (!op.isForLike())
return rewriter.notifyMatchFailure(op, "Loop is not for-like");
int64_t maxTripCount;
if (!matchPattern(op.getMaxTripCount(), m_TorchConstantInt(&maxTripCount)))
return rewriter.notifyMatchFailure(
op, "Expected `maxTripCount` to be a constant int");
;
SmallVector<Value> indices;
for (int64_t i = 0; i < maxTripCount; i++) {
// TODO: Add convenience builder.
indices.push_back(rewriter.create<ConstantIntOp>(
loc, rewriter.getIntegerAttr(IntegerType::get(context, 64), i)));
}
Block *beforeBlock = op->getBlock();
Block *afterBlock = rewriter.splitBlock(op->getBlock(), op->getIterator());
SmallVector<Block *> blocksToMerge;
IRMapping bvm;
// TODO: Helper for region().front()
auto condition =
cast<PrimLoopConditionOp>(op.getRegion().front().getTerminator());
for (int64_t i = 0; i < maxTripCount; i++) {
SmallVector<Value> iterArgs;
if (i == 0) {
llvm::append_range(iterArgs, op.getIterArgsInit());
} else {
llvm::append_range(
iterArgs, llvm::map_range(condition.getIterArgs(),
[&](Value v) { return bvm.lookup(v); }));
}
bvm.clear();
bvm.map(op.getRegion().front().getArgument(0), indices[i]);
bvm.map(op.getRegion().front().getArguments().slice(1), iterArgs);
op.getRegion().cloneInto(afterBlock->getParent(),
afterBlock->getIterator(), bvm);
Block *clonedBlock = bvm.lookup(&op.getRegion().front());
rewriter.eraseOp(clonedBlock->getTerminator());
blocksToMerge.push_back(clonedBlock);
}
blocksToMerge.push_back(afterBlock);
for (Block *block : blocksToMerge)
rewriter.mergeBlocks(block, beforeBlock);
if (maxTripCount == 0) {
rewriter.replaceOp(op, op.getIterArgsInit());
} else {
rewriter.replaceOp(op, llvm::to_vector<6>(llvm::map_range(
condition.getIterArgs(),
[&](Value v) { return bvm.lookup(v); })));
}
return success();
}
};
} // namespace
namespace {
class AbstractlyInterpretListOpsWithinABlock
: public OpRewritePattern<PrimListConstructOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimListConstructOp op,
PatternRewriter &rewriter) const override {
Block *block = op->getBlock();
auto allUsers = llvm::to_vector<6>(op->getUsers());
// Sort the users into program order.
auto getParentInBlock = [&](Operation *op) {
while (op->getBlock() != block)
op = op->getParentOp();
return op;
};
// Use a stable sort for deterministic results when users are nested in two
// regions of the same parent op.
llvm::stable_sort(allUsers, [&](Operation *lhs, Operation *rhs) {
return getParentInBlock(lhs)->isBeforeInBlock(getParentInBlock(rhs));
});
// We cannot interpret all ops. So first do a check to see up until which
// point we can interpret.
int numUsersToInterpret = 0;
for (int i = 0, e = allUsers.size(); i != e; i++, numUsersToInterpret++) {
Operation *user = allUsers[i];
// If a user potentially mutates the list, then we require it to be in the
// same block for our simple abstract interpretation to work (we can't,
// for example, handle an "append" operation in a loop or other region).
// However, if the op is read-only, then from the purpose of our abstract
// interpretation, we can handle it effectively as though it was at the
// same position as the corresponding parent op in the block under
// consideration.
if (potentiallyMutatesListOperands(user)) {
if (user->getBlock() != block)
break;
}
}
// Truncate the list of users to the number of users we're going to
// interpret.
allUsers.resize(numUsersToInterpret);
auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret);
// For each mutating op (which must be in the same block), we save the
// current state of the list as a vector of Value's. These will then
// be converted to PrimListConstructOp's at the correct program points.
SmallVector<SmallVector<Value>> listLiterals;
SmallVector<Value> runningList;
llvm::append_range(runningList, op->getOperands());
bool generatedNewLiteral = false;
for (Operation *user : usersToInterpret) {
if (auto append = dyn_cast<AtenAppendTOp>(user)) {
if (!append.use_empty())
return rewriter.notifyMatchFailure(
op, "Expected `AtenAppendTOp` to not have users");
if (append.getSelf() == op) {
runningList.push_back(append.getEl());
generatedNewLiteral = true;
}
listLiterals.push_back(runningList);
continue;
}
if (auto insert = dyn_cast<AtenInsertTOp>(user)) {
if (!insert.use_empty())
return rewriter.notifyMatchFailure(
op, "Expected `AtenInsertTOp` to not have users");
int64_t index;
if (!matchPattern(insert.getIdx(), m_TorchConstantInt(&index)))
return rewriter.notifyMatchFailure(
op, "Expected `idx` of `AtenInsertTOp` to be a constant int");
// The index might be statically out of bounds.
if (index < 0 || index > static_cast<int64_t>(runningList.size()))
return rewriter.notifyMatchFailure(
op, "Index in `AtenInsertTOp` is out of bounds");
if (insert.getSelf() == op) {
runningList.insert(runningList.begin() + index, insert.getEl());
generatedNewLiteral = true;
}
listLiterals.push_back(runningList);
continue;
}
if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
if (!setItem.use_empty())
return rewriter.notifyMatchFailure(
op, "Expected `Aten_SetItemTOp` to not have users");
std::optional<int64_t> indexOpt = matchLegalConstantIndexIntoListOfSize(
setItem.getIdx(), runningList.size());
// The index might be statically out of bounds.
if (!indexOpt)
return rewriter.notifyMatchFailure(
op, "Index in `Aten_SetItemTOp` is out of bounds");
if (setItem.getL() == op) {
runningList[*indexOpt] = setItem.getEl();
generatedNewLiteral = true;
}
listLiterals.push_back(runningList);
continue;
}
// If this user potentially mutates the list and isn't handled above, then
// we can't abstractly interpret any further.
if (potentiallyMutatesListOperands(user))
break;
}
if (!generatedNewLiteral)
return rewriter.notifyMatchFailure(op, "No new literal created");
// Rewrite all users to use the appropriate list literals.
Value latestLiteral = rewriter.create<PrimListConstructOp>(
op->getLoc(), op.getType(), op->getOperands());
int nextLiteral = 0;
for (Operation *user : usersToInterpret) {
if (auto append = dyn_cast<AtenAppendTOp>(user)) {
rewriter.setInsertionPoint(append);
latestLiteral = rewriter.create<PrimListConstructOp>(
append->getLoc(), op.getType(), listLiterals[nextLiteral++]);
if (append.getSelf() == op)
rewriter.eraseOp(append);
continue;
}
if (auto insert = dyn_cast<AtenInsertTOp>(user)) {
rewriter.setInsertionPoint(insert);
latestLiteral = rewriter.create<PrimListConstructOp>(
insert->getLoc(), op.getType(), listLiterals[nextLiteral++]);
if (insert.getSelf() == op)
rewriter.eraseOp(insert);
continue;
}
if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
rewriter.setInsertionPoint(setItem);
latestLiteral = rewriter.create<PrimListConstructOp>(
setItem->getLoc(), op.getType(), listLiterals[nextLiteral++]);
if (setItem.getL() == op)
rewriter.eraseOp(setItem);
continue;
}
for (OpOperand &opOperand : user->getOpOperands()) {
if (opOperand.get() == op.getResult()) {
opOperand.set(latestLiteral);
}
}
}
// Any remaining uses should use the updated value of the latest literal.
rewriter.replaceOp(op, latestLiteral);
return success();
}
};
} // namespace
LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
int resultNum,
Type newResultType,
@ -97,3 +334,18 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
return success();
}
void mlir::torch::Torch::populateFoldPrimUncheckedCastOpPattern(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.insert<FoldPrimUncheckedCastOp>(context);
}
void mlir::torch::Torch::populateFullyUnrollPrimLoopOpPattern(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.insert<FullyUnrollPrimLoopOp>(context);
}
void mlir::torch::Torch::populateAbstractlyInterpretListOpsWithinABlockPattern(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.insert<AbstractlyInterpretListOpsWithinABlock>(context);
}

View File

@ -23,6 +23,13 @@ LogicalResult updateCalculateOpResultTypes(Operation *calculateOp,
int resultNum, Type newResultType,
PatternRewriter &rewriter);
void populateFoldPrimUncheckedCastOpPattern(RewritePatternSet &patterns,
MLIRContext *context);
void populateFullyUnrollPrimLoopOpPattern(RewritePatternSet &patterns,
MLIRContext *context);
void populateAbstractlyInterpretListOpsWithinABlockPattern(
RewritePatternSet &patterns, MLIRContext *context);
} // namespace Torch
} // namespace torch
} // namespace mlir

View File

@ -191,10 +191,17 @@ class SimplifyDtypeCalculationsPass
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
populateFullyUnrollPrimLoopOpPattern(patterns, context);
populateAbstractlyInterpretListOpsWithinABlockPattern(patterns, context);
populateFoldPrimUncheckedCastOpPattern(patterns, context);
patterns.insert<RefineDtypeCalculateOp>(context);
patterns.insert<DecomposePromoteDtypesOp>(context);
patterns.insert<RefineNumToTensorScalarOpType>(context);
PrimIfOp::getCanonicalizationPatterns(patterns, context);
Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
PrimTupleUnpackOp::getCanonicalizationPatterns(patterns, context);
// TODO: Debug visitation order to make this more efficient.
// A single linear scan should suffice.
GreedyRewriteConfig config;

View File

@ -10,7 +10,6 @@
#include "PassDetail.h"
#include "SimplifyAbstractInterpCalculationsUtils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
@ -19,225 +18,6 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
// TODO: Only unroll inside the shape calculation region.
// Maybe do this by only applying patterns and folding greedily on the ops
// inside the region + the shape.calculate op itself?
class FullyUnrollPrimLoopOp : public OpRewritePattern<PrimLoopOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimLoopOp op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
if (!op.isForLike())
return rewriter.notifyMatchFailure(op, "Loop is not for-like");
int64_t maxTripCount;
if (!matchPattern(op.getMaxTripCount(), m_TorchConstantInt(&maxTripCount)))
return rewriter.notifyMatchFailure(
op, "Expected `maxTripCount` to be a constant int");
;
SmallVector<Value> indices;
for (int64_t i = 0; i < maxTripCount; i++) {
// TODO: Add convenience builder.
indices.push_back(rewriter.create<ConstantIntOp>(
loc, rewriter.getIntegerAttr(IntegerType::get(context, 64), i)));
}
Block *beforeBlock = op->getBlock();
Block *afterBlock = rewriter.splitBlock(op->getBlock(), op->getIterator());
SmallVector<Block *> blocksToMerge;
IRMapping bvm;
// TODO: Helper for region().front()
auto condition =
cast<PrimLoopConditionOp>(op.getRegion().front().getTerminator());
for (int64_t i = 0; i < maxTripCount; i++) {
SmallVector<Value> iterArgs;
if (i == 0) {
llvm::append_range(iterArgs, op.getIterArgsInit());
} else {
llvm::append_range(
iterArgs, llvm::map_range(condition.getIterArgs(),
[&](Value v) { return bvm.lookup(v); }));
}
bvm.clear();
bvm.map(op.getRegion().front().getArgument(0), indices[i]);
bvm.map(op.getRegion().front().getArguments().slice(1), iterArgs);
op.getRegion().cloneInto(afterBlock->getParent(), afterBlock->getIterator(),
bvm);
Block *clonedBlock = bvm.lookup(&op.getRegion().front());
rewriter.eraseOp(clonedBlock->getTerminator());
blocksToMerge.push_back(clonedBlock);
}
blocksToMerge.push_back(afterBlock);
for (Block *block : blocksToMerge)
rewriter.mergeBlocks(block, beforeBlock);
if (maxTripCount == 0) {
rewriter.replaceOp(op, op.getIterArgsInit());
} else {
rewriter.replaceOp(op, llvm::to_vector<6>(llvm::map_range(
condition.getIterArgs(),
[&](Value v) { return bvm.lookup(v); })));
}
return success();
}
};
} // namespace
namespace {
class AbstractlyInterpretListOpsWithinABlock
: public OpRewritePattern<PrimListConstructOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimListConstructOp op,
PatternRewriter &rewriter) const override {
Block *block = op->getBlock();
auto allUsers = llvm::to_vector<6>(op->getUsers());
// Sort the users into program order.
auto getParentInBlock = [&](Operation *op) {
while (op->getBlock() != block)
op = op->getParentOp();
return op;
};
// Use a stable sort for deterministic results when users are nested in two
// regions of the same parent op.
llvm::stable_sort(allUsers, [&](Operation *lhs, Operation *rhs) {
return getParentInBlock(lhs)->isBeforeInBlock(getParentInBlock(rhs));
});
// We cannot interpret all ops. So first do a check to see up until which
// point we can interpret.
int numUsersToInterpret = 0;
for (int i = 0, e = allUsers.size(); i != e; i++, numUsersToInterpret++) {
Operation *user = allUsers[i];
// If a user potentially mutates the list, then we require it to be in the
// same block for our simple abstract interpretation to work (we can't,
// for example, handle an "append" operation in a loop or other region).
// However, if the op is read-only, then from the purpose of our abstract
// interpretation, we can handle it effectively as though it was at the
// same position as the corresponding parent op in the block under
// consideration.
if (potentiallyMutatesListOperands(user)) {
if (user->getBlock() != block)
break;
}
}
// Truncate the list of users to the number of users we're going to
// interpret.
allUsers.resize(numUsersToInterpret);
auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret);
// For each mutating op (which must be in the same block), we save the
// current state of the list as a vector of Value's. These will then
// be converted to PrimListConstructOp's at the correct program points.
SmallVector<SmallVector<Value>> listLiterals;
SmallVector<Value> runningList;
llvm::append_range(runningList, op->getOperands());
bool generatedNewLiteral = false;
for (Operation *user : usersToInterpret) {
if (auto append = dyn_cast<AtenAppendTOp>(user)) {
if (!append.use_empty())
return rewriter.notifyMatchFailure(
op, "Expected `AtenAppendTOp` to not have users");
if (append.getSelf() == op) {
runningList.push_back(append.getEl());
generatedNewLiteral = true;
}
listLiterals.push_back(runningList);
continue;
}
if (auto insert = dyn_cast<AtenInsertTOp>(user)) {
if (!insert.use_empty())
return rewriter.notifyMatchFailure(
op, "Expected `AtenInsertTOp` to not have users");
int64_t index;
if (!matchPattern(insert.getIdx(), m_TorchConstantInt(&index)))
return rewriter.notifyMatchFailure(
op, "Expected `idx` of `AtenInsertTOp` to be a constant int");
// The index might be statically out of bounds.
if (index < 0 || index > static_cast<int64_t>(runningList.size()))
return rewriter.notifyMatchFailure(
op, "Index in `AtenInsertTOp` is out of bounds");
if (insert.getSelf() == op) {
runningList.insert(runningList.begin() + index, insert.getEl());
generatedNewLiteral = true;
}
listLiterals.push_back(runningList);
continue;
}
if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
if (!setItem.use_empty())
return rewriter.notifyMatchFailure(
op, "Expected `Aten_SetItemTOp` to not have users");
std::optional<int64_t> indexOpt = matchLegalConstantIndexIntoListOfSize(
setItem.getIdx(), runningList.size());
// The index might be statically out of bounds.
if (!indexOpt)
return rewriter.notifyMatchFailure(
op, "Index in `Aten_SetItemTOp` is out of bounds");
if (setItem.getL() == op) {
runningList[*indexOpt] = setItem.getEl();
generatedNewLiteral = true;
}
listLiterals.push_back(runningList);
continue;
}
// If this user potentially mutates the list and isn't handled above, then
// we can't abstractly interpret any further.
if (potentiallyMutatesListOperands(user))
break;
}
if (!generatedNewLiteral)
return rewriter.notifyMatchFailure(op, "No new literal created");
// Rewrite all users to use the appropriate list literals.
Value latestLiteral = rewriter.create<PrimListConstructOp>(
op->getLoc(), op.getType(), op->getOperands());
int nextLiteral = 0;
for (Operation *user : usersToInterpret) {
if (auto append = dyn_cast<AtenAppendTOp>(user)) {
rewriter.setInsertionPoint(append);
latestLiteral = rewriter.create<PrimListConstructOp>(
append->getLoc(), op.getType(), listLiterals[nextLiteral++]);
if (append.getSelf() == op)
rewriter.eraseOp(append);
continue;
}
if (auto insert = dyn_cast<AtenInsertTOp>(user)) {
rewriter.setInsertionPoint(insert);
latestLiteral = rewriter.create<PrimListConstructOp>(
insert->getLoc(), op.getType(), listLiterals[nextLiteral++]);
if (insert.getSelf() == op)
rewriter.eraseOp(insert);
continue;
}
if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
rewriter.setInsertionPoint(setItem);
latestLiteral = rewriter.create<PrimListConstructOp>(
setItem->getLoc(), op.getType(), listLiterals[nextLiteral++]);
if (setItem.getL() == op)
rewriter.eraseOp(setItem);
continue;
}
for (OpOperand &opOperand : user->getOpOperands()) {
if (opOperand.get() == op.getResult()) {
opOperand.set(latestLiteral);
}
}
}
// Any remaining uses should use the updated value of the latest literal.
rewriter.replaceOp(op, latestLiteral);
return success();
}
};
} // namespace
namespace {
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
public:
@ -266,22 +46,6 @@ public:
};
} // namespace
namespace {
class FoldPrimUncheckedCastOp : public OpRewritePattern<PrimUncheckedCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimUncheckedCastOp op,
PatternRewriter &rewriter) const override {
if (!isValidSubtype(op.getX().getType(), op.getResult().getType())) {
return rewriter.notifyMatchFailure(
op, "input tensor type is not a valid subtype of result type");
}
rewriter.replaceOp(op, op.getX());
return success();
}
};
} // namespace
static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
int resultNum,
PatternRewriter &rewriter) {
@ -367,11 +131,11 @@ class SimplifyShapeCalculationsPass
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.insert<FullyUnrollPrimLoopOp>(context);
patterns.insert<AbstractlyInterpretListOpsWithinABlock>(context);
populateFullyUnrollPrimLoopOpPattern(patterns, context);
populateAbstractlyInterpretListOpsWithinABlockPattern(patterns, context);
populateFoldPrimUncheckedCastOpPattern(patterns, context);
patterns.insert<DecomposeAtenSizeOp>(context);
patterns.insert<RefineShapeCalculateOp>(context);
patterns.insert<FoldPrimUncheckedCastOp>(context);
PrimIfOp::getCanonicalizationPatterns(patterns, context);
Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);

View File

@ -14,6 +14,55 @@ from torch_mlir.passmanager import PassManager
from .registry import Registry
def all_integer_dtypes() -> List[int]:
return [torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
def is_integer_dtype(dtype: int) -> bool:
return dtype in all_integer_dtypes()
def all_complex_dtypes() -> List[int]:
return [torch.complex64, torch.complex128]
def is_complex_dtype(dtype: int) -> bool:
return dtype in all_complex_dtypes()
def all_float_dtypes() -> List[int]:
return [torch.float16, torch.bfloat16, torch.float32, torch.float64]
def is_float_dtype(dtype: int) -> bool:
return dtype in all_float_dtypes()
def get_priority_of_dtype(dtype: int) -> int:
# If a loop is used to iterate over a list of sorted dtypes, TorchScript
# produces a loop with INT64_MAX max trip count, which causes problems
# during the loop unrolling that takes place when simplifying the dtype
# functions. Therefore, here we resort to `if`s.
if dtype == torch.bool:
return 0
if dtype == torch.uint8:
return 1
if dtype == torch.int8:
return 2
if dtype == torch.int16:
return 3
if dtype == torch.int32:
return 4
if dtype == torch.int64:
return 5
if dtype == torch.bfloat16:
return 6
if dtype == torch.float16:
return 7
if dtype == torch.float32:
return 8
if dtype == torch.float64:
return 9
if dtype == torch.complex64:
return 10
if dtype == torch.complex128:
return 11
assert False, "Cannot determine priority of dtype"
def get_dtype_of_scalar(scalar: Union[int, float]) -> int:
# This is hacky. `NumToTensor` is the only PyTorch op for scalars
# that when `jit.script`ed converts a float scalar to a tensor

View File

@ -60,33 +60,32 @@ class TensorOfShape:
This class also tracks a dtype of the tensor, since some ops require a
specific dtype.
"""
def __init__(self, *shape: int, dtype: torch.dtype = torch.float32):
def __init__(self, *shape: int, dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None):
self.shape = list(shape)
self.dtype = dtype
self.device = "meta" if device is None else device
def __repr__(self):
args_str = ", ".join(repr(x) for x in self.shape)
if self.dtype is torch.float32:
return f"TensorOfShape({args_str})"
else:
return f"TensorOfShape({args_str}, dtype={self.dtype})"
return f"TensorOfShape({args_str}, dtype={self.dtype}, device={self.device})"
def LongTensorOfShape(*args, **kwargs):
"""Helper for indicating a TensorOfShape with integer type."""
return TensorOfShape(*args, **kwargs, dtype=torch.long)
def NonZeroDTensorWithDtype(dtype):
def NonZeroDTensorWithDtype(dtype, device: Optional[torch.device] = None):
"""Helper for indicating a non-zero dim tensor with custom type."""
return TensorOfShape(1, dtype=dtype)
return TensorOfShape(1, dtype=dtype, device=device)
def ZeroDTensorWithDtype(dtype):
def ZeroDTensorWithDtype(dtype, device: Optional[torch.device] = None):
"""Helper for indicating a zero dim tensor with custom type."""
return TensorOfShape(dtype=dtype)
return TensorOfShape(dtype=dtype, device=device)
def _recursively_transform_tensor_args(
o: Any,
tensor_transformer: Callable[[TensorOfShape], Any]) -> Any:
"""Replace `TensorOfShape` with the result of `tensor_transformer`"""
if o is None or isinstance(o, (float, int)):
if o is None or isinstance(o, (float, int, str)):
return o
if isinstance(o, TensorOfShape):
return tensor_transformer(o)
@ -146,7 +145,7 @@ class Invocation:
def to_real_op_args(self):
"""Gets positional arguments appropriate for the real op."""
tensor_transformer = lambda o: torch.ones(o.shape, dtype=o.dtype)
tensor_transformer = lambda o: torch.ones(o.shape, dtype=o.dtype).to(o.device)
return _recursively_transform_tensor_args(self.args, tensor_transformer)
def __repr__(self) -> str:
@ -258,6 +257,15 @@ def check_shape_function(invocations: List[Invocation]):
return f
return decorator
@torch.jit.script
def _convert_dtype_to_int(dtype: torch.dtype) -> int:
"""Convert a PyTorch `dtype` into its underlying `int` representation.
This works because in TorchScript there is no special type for `dtypes`;
they are simply `int`s.
"""
return dtype
def check_dtype_function(invocations: List[Invocation]):
"""Decorator that automatically tests a dtype function.
@ -281,7 +289,12 @@ def check_dtype_function(invocations: List[Invocation]):
golden_dtype = torch.tensor([]).to(type(golden_result)).dtype
else:
raise ValueError(f"Unhandled return type {type(golden_result)}")
if result_dtype != golden_dtype:
# Some dtype funtions have default `dtype` parameters, which are
# represented as `int` values in the registry. In order to
# support returning the default `int` value, the comparisons of
# the result and golden dtypes are done using their underlying
# `int` representation.
if _convert_dtype_to_int(result_dtype) != _convert_dtype_to_int(golden_dtype):
_report(f, invocation, f"Expected result dtype {golden_dtype}, got {result_dtype}")
return f
return decorator

View File

@ -1,153 +0,0 @@
// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s
// -----
// CHECK-LABEL: func.func @prim.if$branch_merge_type_tensor(
// CHECK-SAME: %[[PRED:.*]]: !torch.bool,
// CHECK-SAME: %[[T1:.*]]: !torch.tensor,
// CHECK-SAME: %[[T2:.*]]: !torch.tensor) -> !torch.bool {
// CHECK: %[[MERGED:.*]] = torch.prim.If %[[PRED]] -> (!torch.optional<tensor>) {
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[T1]] : !torch.tensor to !torch.optional<tensor>
// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional<tensor>
// CHECK: } else {
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[T2]] : !torch.tensor to !torch.optional<tensor>
// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional<tensor>
// CHECK: }
// CHECK: %[[REFINED:.*]] = torch.prim.unchecked_cast %[[MERGED:.*]] : !torch.optional<tensor> -> !torch.tensor
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[RET:.*]] = torch.aten.__isnot__ %[[REFINED]], %[[NONE]] : !torch.tensor, !torch.none -> !torch.bool
// CHECK: return %[[RET]] : !torch.bool
func.func @prim.if$branch_merge_type_tensor(%pred: !torch.bool, %t0: !torch.tensor, %t1: !torch.tensor) -> !torch.bool {
%res = torch.prim.If %pred -> (!torch.optional<tensor>) {
%optional0 = torch.derefine %t0: !torch.tensor to !torch.optional<tensor>
torch.prim.If.yield %optional0: !torch.optional<tensor>
} else {
%optional1 = torch.derefine %t1: !torch.tensor to !torch.optional<tensor>
torch.prim.If.yield %optional1: !torch.optional<tensor>
}
%none = torch.constant.none
%cmp = torch.aten.__isnot__ %res, %none : !torch.optional<tensor>, !torch.none -> !torch.bool
return %cmp : !torch.bool
}
// -----
// CHECK-LABEL: func.func @prim.if$branch_merge_type_optional(
// CHECK-SAME: %[[PRED:.*]]: !torch.bool,
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.optional<tensor> {
// CHECK: %[[MERGED:.*]] = torch.prim.If %[[PRED]] -> (!torch.optional<tensor>) {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor>
// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional<tensor>
// CHECK: } else {
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[T]] : !torch.tensor to !torch.optional<tensor>
// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional<tensor>
// CHECK: }
// CHECK: return %[[MERGED:.*]] : !torch.optional<tensor>
func.func @prim.if$branch_merge_type_optional(%pred: !torch.bool, %t1: !torch.tensor) -> !torch.optional<tensor> {
%res = torch.prim.If %pred -> (!torch.optional<tensor>) {
%none = torch.constant.none
%optional0 = torch.derefine %none: !torch.none to !torch.optional<tensor>
torch.prim.If.yield %optional0: !torch.optional<tensor>
} else {
%optional1 = torch.derefine %t1: !torch.tensor to !torch.optional<tensor>
torch.prim.If.yield %optional1: !torch.optional<tensor>
}
return %res: !torch.optional<tensor>
}
// -----
// CHECK-LABEL: func.func @prim.if$refined_type_conflicting(
// CHECK-SAME: %[[NONE:.*]]: !torch.none) -> !torch.tensor {
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor>
// CHECK: %[[NOT_NONE:.*]] = torch.aten.__isnot__ %[[NONE]], %[[NONE]] : !torch.none, !torch.none -> !torch.bool
// CHECK: %[[PRED:.*]] = torch.prim.If %[[NOT_NONE]] -> (!torch.tensor) {
// CHECK: %[[T:.*]] = torch.prim.unchecked_cast %[[OPTIONAL]] : !torch.optional<tensor> -> !torch.tensor
// CHECK: torch.prim.If.yield %[[T]] : !torch.tensor
// CHECK: } else {
// CHECK: %[[LITERAL:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<3x5xf32>) : !torch.tensor
// CHECK: torch.prim.If.yield %[[LITERAL]] : !torch.tensor
// CHECK: }
// CHECK: return %[[PRED:.*]] : !torch.tensor
func.func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.tensor {
%optional = torch.derefine %none: !torch.none to !torch.optional<tensor>
%pred = torch.aten.__isnot__ %optional, %none : !torch.optional<tensor>, !torch.none -> !torch.bool
%res = torch.prim.If %pred -> (!torch.tensor) {
%t = torch.prim.unchecked_cast %optional: !torch.optional<tensor> -> !torch.tensor
torch.prim.If.yield %t: !torch.tensor
} else {
%t_cst = torch.tensor.literal(dense<0.0> : tensor<3x5xf32>) : !torch.tensor
torch.prim.If.yield %t_cst: !torch.tensor
}
return %res: !torch.tensor
}
// -----
// CHECK-LABEL: func.func @prim.loop$region_arg_to_internal(
// CHECK-SAME: %[[ARG_NONE:.*]]: !torch.none) -> !torch.optional<tensor> {
// CHECK: %[[INT10:.*]] = torch.constant.int 10
// CHECK: %[[INDV:.*]] = torch.constant.int 0
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[ARG_NONE]] : !torch.none to !torch.optional<tensor>
// CHECK: %[[LOOP_RET:.*]] = torch.prim.Loop %[[INT10]], %[[TRUE]], init(%[[OPTIONAL]]) {
// CHECK: ^bb0(%[[INDV:.*]]: !torch.int, %[[IT:.*]]: !torch.optional<tensor>):
// CHECK: %[[NONE:.*]] = torch.prim.unchecked_cast %[[IT]] : !torch.optional<tensor> -> !torch.none
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor>
// CHECK: %[[COND:.*]] = torch.aten.__isnot__ %[[NONE]], %[[ARG_NONE]] : !torch.none, !torch.none -> !torch.bool
// CHECK: torch.prim.Loop.condition %[[COND]], iter(%[[OPTIONAL]] : !torch.optional<tensor>)
// CHECK: } : (!torch.int, !torch.bool, !torch.optional<tensor>) -> !torch.optional<tensor>
// CHECK: %[[NONE:.*]] = torch.prim.unchecked_cast %[[LOOP_RET:.*]] : !torch.optional<tensor> -> !torch.none
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor>
// CHECK: return %[[OPTIONAL]] : !torch.optional<tensor>
func.func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.optional<tensor> {
%int10 = torch.constant.int 10
%int0 = torch.constant.int 0
%true = torch.constant.bool true
%optional = torch.derefine %none: !torch.none to !torch.optional<tensor>
%ret = torch.prim.Loop %int10, %true, init(%optional) {
^bb0(%arg2: !torch.int, %arg3: !torch.optional<tensor>): // no predecessors
%cond = torch.aten.__isnot__ %arg3, %none : !torch.optional<tensor>, !torch.none -> !torch.bool
torch.prim.Loop.condition %cond, iter(%arg3: !torch.optional<tensor>)
} : (!torch.int, !torch.bool, !torch.optional<tensor>) -> (!torch.optional<tensor>)
return %ret: !torch.optional<tensor>
}
// -----
// CHECK-LABEL: func.func @f
// CHECK: %[[ATEN:.*]] = torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func.func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
cf.br ^bb1(%cast: !torch.vtensor)
^bb1(%arg1: !torch.vtensor):
%1 = torch.aten.cos %arg1 : !torch.vtensor -> !torch.vtensor
return %1 : !torch.vtensor
}
// -----
// CHECK-LABEL: func.func @f
// CHECK: func.func private @callee
// CHECK-NEXT: torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
func.func @f() {
builtin.module {
func.func private @callee(%arg0: !torch.vtensor) {
%1 = torch.aten.cos %arg0 : !torch.vtensor -> !torch.vtensor
return
}
func.func @caller(%arg0: !torch.vtensor<*,f32>) {
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
call @callee(%cast) : (!torch.vtensor) -> ()
return
}
}
return
}

View File

@ -1,364 +0,0 @@
// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s
// This file is for tests for individual ops that require a new transfer
// function (i.e. new code called from visitOperation).
// -----
// CHECK-LABEL: func.func @aten.arange.start$int64_dtype(
// CHECK-SAME: %[[START:.*]]: !torch.int,
// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[T:.*]] = torch.aten.arange.start
// CHECK-SAME: %[[START]], %[[END]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
// CHECK-SAME: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none
// CHECK-SAME: -> !torch.vtensor<*,si64>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,si64> to !torch.vtensor
// CHECK: return %[[RET]] : !torch.vtensor
func.func @aten.arange.start$int64_dtype(%start: !torch.int, %end: !torch.int) -> !torch.vtensor {
%none = torch.constant.none
%ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
return %ret : !torch.vtensor
}
// -----
// CHECK-LABEL: func.func @aten.arange.start$float32_dtype(
// CHECK-SAME: %[[START:.*]]: !torch.float,
// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[T:.*]] = torch.aten.arange.start
// CHECK-SAME: %[[START]], %[[END]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
// CHECK-SAME: !torch.float, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none
// CHECK-SAME: -> !torch.vtensor<*,f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[RET]] : !torch.vtensor
func.func @aten.arange.start$float32_dtype(%start: !torch.float, %end: !torch.int) -> !torch.vtensor {
%none = torch.constant.none
%ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.float, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
return %ret : !torch.vtensor
}
// -----
// CHECK-LABEL: func.func @aten.arange.start$specified_dtype(
// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[T:.*]] = torch.aten.arange
// CHECK-SAME: %[[END]], %[[CST6]], %[[NONE]], %[[NONE]], %[[NONE]] :
// CHECK-SAME: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none
// CHECK-SAME: -> !torch.vtensor<*,f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[RET]] : !torch.vtensor
func.func @aten.arange.start$specified_dtype(%end: !torch.int) -> !torch.vtensor {
%int6 = torch.constant.int 6
%none = torch.constant.none
%ret = torch.aten.arange %end, %int6, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor
return %ret : !torch.vtensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.linear(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[5,3],f32>,
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[5],f32>) -> !torch.vtensor {
// CHECK: %[[LINEAR:.*]] = torch.aten.linear %[[ARG0]], %[[ARG1]], %[[ARG2]] : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<*,f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[LINEAR]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor
func.func @torch.aten.linear(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor {
%1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor
return %1 : !torch.vtensor
}
// -----
// CHECK-LABEL: func.func @aten.sum.dim_IntList(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,si64>) -> !torch.vtensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
// CHECK: %[[DIMLIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT_NEG1]]
// CHECK-SAME: : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RET:.*]] = torch.aten.sum.dim_IntList %[[T]], %[[DIMLIST]], %[[FALSE]], %[[NONE]]
// CHECK-SAME: : !torch.vtensor<*,si64>, !torch.list<int>, !torch.bool, !torch.none
// CHECK-SAME: -> !torch.vtensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,si64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func.func @aten.sum.dim_IntList(%t: !torch.vtensor<*,si64>) -> !torch.vtensor {
%false = torch.constant.bool false
%none = torch.constant.none
%int0 = torch.constant.int 0
%int-1 = torch.constant.int -1
%dimList = torch.prim.ListConstruct %int0, %int-1 : (!torch.int, !torch.int) -> !torch.list<int>
%ret = torch.aten.sum.dim_IntList %t, %dimList, %false, %none : !torch.vtensor<*,si64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor
return %ret : !torch.vtensor
}
// -----
// CHECK-LABEL: func.func @aten.any.dim(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,i1>) -> !torch.vtensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
// CHECK: %[[RET:.*]] = torch.aten.any.dim %[[T]], %[[INT_NEG1]], %[[FALSE]] : !torch.vtensor<*,i1>, !torch.int, !torch.bool -> !torch.vtensor<*,i1>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,i1> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func.func @aten.any.dim(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
%false = torch.constant.bool false
%int-1 = torch.constant.int -1
%ret = torch.aten.any.dim %t, %int-1, %false : !torch.vtensor<*,i1>, !torch.int, !torch.bool -> !torch.vtensor
return %ret : !torch.vtensor
}
// -----
// CHECK-LABEL: func.func @aten.any(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,i1>) -> !torch.vtensor {
// CHECK: %[[RET:.*]] = torch.aten.any %[[T]] : !torch.vtensor<*,i1> -> !torch.vtensor<*,i1>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,i1> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func.func @aten.any(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
%ret = torch.aten.any %t: !torch.vtensor<*,i1> -> !torch.vtensor
return %ret : !torch.vtensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.zeros(
// CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ZEROS]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor {
%none = torch.constant.none
%int2 = torch.constant.int 2
%sizesList = torch.prim.ListConstruct %dim0, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%ret = torch.aten.zeros %sizesList, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.type_as(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?],si64>,
// CHECK-SAME: %[[OTHER:.*]]: !torch.tensor<[?,2],f32>) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten.type_as %[[INPUT]], %[[OTHER]] : !torch.tensor<[?],si64>, !torch.tensor<[?,2],f32> -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch.tensor<[?,2],f32>) -> !torch.tensor {
%ret = torch.aten.type_as %self, %other : !torch.tensor<[?], si64>, !torch.tensor<[?,2],f32> -> !torch.tensor
return %ret: !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.cat(
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],f32>) -> !torch.tensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[?,1,4],f32>, !torch.tensor<[2,3,4],f32>) -> !torch.list<tensor>
// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4], f32>) -> !torch.tensor {
%int1 = torch.constant.int 1
%tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[?,1,4], f32>, !torch.tensor<[2,3,4], f32>) -> !torch.list<tensor>
%ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.cat$promote_type(
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[2,1,4],i1>,
// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],si64>) -> !torch.tensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[2,1,4],i1>, !torch.tensor<[2,3,4],si64>) -> !torch.list<tensor>
// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.cat$promote_type(%t0: !torch.tensor<[2,1,4], i1>, %t1: !torch.tensor<[2,3,4], si64>) -> !torch.tensor {
%int1 = torch.constant.int 1
%tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[2,1,4], i1>, !torch.tensor<[2,3,4], si64>) -> !torch.list<tensor>
%ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten._shape_as_tensor(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor<[?,1,4],f32> -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten._shape_as_tensor(%input: !torch.tensor<[?,1,4], f32>) -> !torch.tensor {
%ret= torch.aten._shape_as_tensor %input : !torch.tensor<[?,1,4], f32> -> !torch.tensor
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten._shape_as_tensor$unknown_input_shape(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.tensor) -> !torch.tensor {
%ret= torch.aten._shape_as_tensor %input : !torch.tensor -> !torch.tensor
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.embedding(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[104,512],f32>,
// CHECK-SAME: %[[INDEXES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[PADDING_IDX:.*]] = torch.constant.int 1
// CHECK: %[[RET:.*]] = torch.aten.embedding %[[INPUT]], %[[INDEXES]], %[[PADDING_IDX]], %[[FALSE]], %[[FALSE]] : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indices: !torch.tensor<[2,3], si64>) -> !torch.tensor {
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%ret = torch.aten.embedding %weight, %indices, %int1, %false, %false : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor
return %ret: !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.tensor.float(
// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
%none = torch.constant.none
%false = torch.constant.bool false
%ret = torch.aten.tensor.float %t, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.tensor.float$specified_dtype(
// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[CST11:.*]] = torch.constant.int 11
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[CST11]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,i1>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,i1> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor {
%none = torch.constant.none
%int11 = torch.constant.int 11
%false = torch.constant.bool false
%ret = torch.aten.tensor.float %t, %int11, %none, %false : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.softmax.int(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<*,f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor
func.func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor {
%none = torch.constant.none
%ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.softmax.int$specified_dtype(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[DTYPE:.*]] = torch.constant.int 4
// CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor<*,si64>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor
func.func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor {
%int4 = torch.constant.int 4
%ret = torch.aten.softmax.int %t, %dim, %int4: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.Matmul.Broadcast.Matrix(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.Matmul.Broadcast.Vector(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<*,f32>) -> !torch.tensor {
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<*,f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.to.dtype(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype
// CHECK-SAME: %[[ARG]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} :
// CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
// CHECK-SAME: -> !torch.tensor<*,si64>
// CHECK-NEXT: %[[RES:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK-NEXT: return %[[RES]] : !torch.tensor
func.func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
%none = torch.constant.none
%false = torch.constant.bool false
%int4 = torch.constant.int 4
%0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar(
// CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
%0 = torch.prim.NumToTensor.Scalar %arg0: !torch.int -> !torch.tensor
return %0: !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.tensor(
// CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[NONE]], %[[NONE]], %[[FALSE]]
// CHECK-SAME: : !torch.list<list<float>>, !torch.none, !torch.none, !torch.bool
// CHECK-SAME: -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.tensor(%t: !torch.list<list<float>>) -> !torch.tensor {
%none = torch.constant.none
%false = torch.constant.bool false
%ret = torch.aten.tensor %t, %none, %none, %false : !torch.list<list<float>>, !torch.none, !torch.none, !torch.bool -> !torch.tensor
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.tensor$specified_dtype(
// CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT4:.*]] = torch.constant.int 4
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[INT4]], %[[NONE]], %[[FALSE]] : !torch.list<list<float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.tensor$specified_dtype(%t: !torch.list<list<float>>) -> !torch.tensor {
%none = torch.constant.none
%int4 = torch.constant.int 4
%false = torch.constant.bool false
%ret = torch.aten.tensor %t, %int4, %none, %false : !torch.list<list<float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor
return %ret : !torch.tensor
}

View File

@ -1,238 +0,0 @@
// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s
// This file tests the structural logic of the pass. This is for testing logic
// that does not scale with the number of ops supported, such as the core
// propagation logic, rewriting, etc.
// Code for testing transfer functions for new ops (which is most changes)
// should go in refine-types-ops.mlir.
// -----
// CHECK-LABEL: func.func @basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
// CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[COS]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor
func.func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
return %1 : !torch.vtensor
}
// -----
// CHECK-LABEL: func.func @keep_existing_shape_information(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
// CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32>
// CHECK: return %[[COS]] : !torch.vtensor<[2],f32>
func.func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor<[2], f32>
return %1 : !torch.vtensor<[2],f32>
}
// -----
// CHECK-LABEL: func.func @propagate_through_multiple_ops(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
// CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[COS2:.*]] = torch.aten.cos %[[COS1]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[COS3:.*]] = torch.tensor_static_info_cast %[[COS2]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[COS3]] : !torch.vtensor
func.func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
%2 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor
%3 = torch.aten.cos %2 : !torch.vtensor -> !torch.vtensor
return %3 : !torch.vtensor
}
// -----
// Check rewriting logic in case of mixes of users that do/don't allow type
// refinement.
// CHECK-LABEL: func.func @mixed_allowing_not_allowing_type_refinement(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
// CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[COS0]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: return %[[ERASED]], %[[ERASED]] : !torch.vtensor, !torch.vtensor
func.func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
%3 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor
return %1, %1 : !torch.vtensor, !torch.vtensor
}
// -----
// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32> to !torch.vtensor<*,f32>
// CHECK: %[[CAST2:.*]] = torch.tensor_static_info_cast %[[CAST:.*]] : !torch.vtensor<*,f32> to !torch.vtensor<*,f32>
// CHECK: torch.overwrite.tensor.contents %[[CAST2]] overwrites %[[STATIC_COPY:.*]] : !torch.vtensor<*,f32>, !torch.tensor<*,f32>
func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
%static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
%static_copy = torch.copy.to_tensor %static_no_type : !torch.tensor
%dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
torch.overwrite.tensor.contents %dynamic_no_type overwrites %static_copy : !torch.vtensor, !torch.tensor
%static_value_copy = torch.copy.to_vtensor %static_copy : !torch.vtensor
%result = torch.tensor_static_info_cast %static_value_copy : !torch.vtensor to !torch.vtensor<[2],f32>
return %result : !torch.vtensor<[2],f32>
}
// -----
// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$static_overwrites_dynamic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[2],f32> to !torch.vtensor<*,f32>
// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor<*,f32>
// CHECK: %[[MUTABLE_COPY:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor<*,f32>
// CHECK: torch.overwrite.tensor.contents %[[ARG0_ERASED]] overwrites %[[MUTABLE_COPY]] : !torch.vtensor<*,f32>, !torch.tensor<*,f32>
func.func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
%static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
%dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
%dynamic_copy = torch.copy.to_tensor %dynamic_no_type : !torch.tensor
torch.overwrite.tensor.contents %static_no_type overwrites %dynamic_copy : !torch.vtensor, !torch.tensor
%dynamic_value_copy = torch.copy.to_vtensor %dynamic_copy : !torch.vtensor
%result = torch.tensor_static_info_cast %dynamic_value_copy : !torch.vtensor to !torch.vtensor<[?],f32>
return %result : !torch.vtensor<[?],f32>
}
// -----
// CHECK-LABEL: func.func @bf16_result_type(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> {
// CHECK: %[[SQRT:.*]] = torch.aten.sqrt %[[ARG0]] : !torch.vtensor<*,bf16> -> !torch.vtensor<[2],bf16>
// CHECK: return %[[SQRT]] : !torch.vtensor<[2],bf16>
func.func @bf16_result_type(%arg0: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> {
%1 = torch.aten.sqrt %arg0 : !torch.vtensor<*,bf16> -> !torch.vtensor<[2], bf16>
return %1 : !torch.vtensor<[2],bf16>
}
// -----
// CHECK-LABEL: func.func @propagate_scalar_type(
// CHECK-SAME: %[[INT:.*]]: !torch.int) -> !torch.number {
// CHECK: %[[NUM:.*]] = torch.derefine %[[INT]] : !torch.int to !torch.number
// CHECK: %[[ABS:.*]] = torch.prim.abs.Scalar %[[INT]] : !torch.int -> !torch.int
// CHECK: %[[RET:.*]] = torch.derefine %[[ABS]] : !torch.int to !torch.number
// CHECK: return %[[RET]] : !torch.number
func.func @propagate_scalar_type(%arg0: !torch.int) -> !torch.number {
%num = torch.derefine %arg0 : !torch.int to !torch.number
%1 = torch.prim.abs.Scalar %num: !torch.number -> !torch.number
return %1 : !torch.number
}
// -----
// CHECK-LABEL: func.func @prim.dtype(
// CHECK-SAME: %[[arg:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor {
// CHECK: %[[zero:.*]] = torch.constant.int 0
// CHECK: %[[false:.*]] = torch.constant.bool false
// CHECK: %[[neg:.*]] = torch.aten.neg %[[arg]] : !torch.vtensor<*,bf16> -> !torch.vtensor<*,bf16>
// CHECK: %[[dtype0:.*]] = torch.prim.dtype %[[neg]] : !torch.vtensor<*,bf16> -> !torch.int
// CHECK: %[[device0:.*]] = torch.prim.device %[[neg]] : !torch.vtensor<*,bf16> -> !torch.Device
// CHECK: %[[tensor:.*]] = torch.aten.tensor.int %[[zero]], %[[dtype0]], %[[device0]], %[[false]] : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,bf16>
// CHECK: %[[dtype1:.*]] = torch.prim.dtype %[[tensor]] : !torch.vtensor<*,bf16> -> !torch.int
// CHECK: %[[device1:.*]] = torch.prim.device %[[tensor]] : !torch.vtensor<*,bf16> -> !torch.Device
// CHECK: %[[result:.*]] = torch.aten.tensor.int %[[zero]], %[[dtype1]], %[[device1]], %[[false]] : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,bf16>
// CHECK: %[[cast:.*]] = torch.tensor_static_info_cast %[[result]] : !torch.vtensor<*,bf16> to !torch.vtensor
// CHECK: return %[[cast]] : !torch.vtensor
// CHECK: }
func.func @prim.dtype(%arg: !torch.vtensor<*,bf16>) -> !torch.vtensor<*,unk> {
%zero = torch.constant.int 0
%false = torch.constant.bool false
// Op that requires type refinement
%neg = torch.aten.neg %arg : !torch.vtensor<*,bf16> -> !torch.vtensor<*,unk>
// Op whose processing requires type refinement on its source argument.
%dtype = torch.prim.dtype %neg : !torch.vtensor<*,unk> -> !torch.int
%device = torch.prim.device %neg : !torch.vtensor<*,unk> -> !torch.Device
// Another op that requires type refinement
%result = torch.aten.tensor.int %zero, %dtype, %device, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk>
// Repeat the above three ops a second time to ensure that the type refinement
// code works regardless of the number of alternating refinement+prim.dtype
// sequences.
%dtype2 = torch.prim.dtype %result : !torch.vtensor<*,unk> -> !torch.int
%device2 = torch.prim.device %result : !torch.vtensor<*,unk> -> !torch.Device
%result2 = torch.aten.tensor.int %zero, %dtype2, %device2, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk>
return %result2 : !torch.vtensor<*,unk>
}
// -----
// Check that we don't crash on this input.
// CHECK-LABEL: func.func @forward
func.func @forward() -> !torch.vtensor {
%false = torch.constant.bool false
%none = torch.constant.none
%0 = torch.prim.ListConstruct : () -> !torch.list<tensor>
// CHECK: torch.aten.tensor
%1 = torch.aten.tensor %0, %none, %none, %false : !torch.list<tensor>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor
return %1 : !torch.vtensor
}
// -----
// Check that we don't crash on this input.
// TODO: This appears to result in aten.mul.Tensor not being visited.
// We should investigate why that happens.
// CHECK-LABEL: func.func @forward
func.func @forward(%arg0: !torch.bool, %arg1: !torch.tensor) {
%0 = torch.prim.If %arg0 -> (!torch.tensor) {
torch.prim.If.yield %arg1 : !torch.tensor
} else {
torch.prim.If.yield %arg1 : !torch.tensor
}
%1 = torch.copy.to_vtensor %0 : !torch.vtensor
%2 = torch.aten.mul.Tensor %1, %1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
return
}
// -----
// CHECK-LABEL: func.func @torch.aten.zeros_like(
// CHECK-SAME: %[[arg:.*]]: !torch.vtensor) {
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[CPU:.*]] = torch.constant.device "cpu"
// CHECK: %[[ZEROS:.*]] = torch.aten.zeros_like %[[arg]], %[[INT6]], %[[INT0]], %[[CPU]], %[[FALSE]], %[[INT1]] : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor<*,f32>
// CHECK: return
func.func @torch.aten.zeros_like(%arg: !torch.vtensor) {
%int6 = torch.constant.int 6
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%cpu = torch.constant.device "cpu"
%2 = torch.aten.zeros_like %arg, %int6, %int0, %cpu, %false, %int1 : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor
return
}
// -----
// The data-flow analysis does not always propagate information to the entire graph.
// This results in some lattice elements being uninitialized, which must be properly
// handled when using the lattice elements to rewrite the graph.
// In this particular case, the presence of the loop causes `torch.copy.to_vtensor`
// to end up with an uninitialized lattice element. This is the simplest graph I was
// able to come up with that reproduces such behavior.
// CHECK-LABEL: func.func @uninitialized_lattice_elements(
// CHECK: %{{.*}} = torch.copy.to_vtensor %{{.*}} : !torch.vtensor<*,f32>
func.func @uninitialized_lattice_elements(%arg0: !torch.vtensor<*,f32>, %arg3: !torch.tensor) -> !torch.vtensor<*,f32> {
%true = torch.constant.bool true
%1 = torch.constant.int 0
%2 = torch.prim.Loop %1, %true, init(%arg3) {
^bb0(%arg1: !torch.int, %arg2: !torch.tensor):
torch.prim.Loop.condition %true, iter(%arg2 : !torch.tensor)
} : (!torch.int, !torch.bool, !torch.tensor) -> !torch.tensor
%3 = torch.tensor_static_info_cast %2 : !torch.tensor to !torch.tensor<*,f32>
%4 = torch.copy.to_vtensor %3 : !torch.vtensor<*,f32>
return %4 : !torch.vtensor<*,f32>
}