mirror of https://github.com/llvm/torch-mlir
Replace RefineTypes with dtype functions (#2105)
This commit adds dtype functions for all the torch ops that did not previously have one and removes the pass `RefineTypes`, since the abstract interpretation library now takes care of all the dtype propagation. All dtype functions added are tested except for - `aten.embedding` - `aten._embedding_bag` - `aten.embedding_bag` These functions need a change to the testing framework to allow specifying the actual data inside the tensor used for testing. I will fix this in a follow up patch. Co-authored-by: Jiahao Li <liplus17@163.com>pull/2124/head
parent
28bb866260
commit
de02b56e17
|
@ -242,8 +242,7 @@ The `torchscript-module-to-torch-backend-pipeline` contains the set of simplific
|
|||
1. LowerToBackendContract: This pass iteratively applies a simplification
|
||||
pipeline until the backend contract is reached. The simplification pipeline consists of:
|
||||
- Standard canonicalization.
|
||||
- Shape refinement. See [shape_lib.md](https://github.com/llvm/torch-mlir/blob/main/docs/shape_lib.md) for detail
|
||||
- DType refinement. See `RefineTypes`.
|
||||
- Shape and Dtype refinement. See [abstract_interp_lib.md](https://github.com/llvm/torch-mlir/blob/main/docs/abstract_interp_lib.md) for detail
|
||||
- Decomposing ops into more primitive ops. See `DecomposeComplexOps`.
|
||||
|
||||
### Layering of the PyTorch Dependency
|
||||
|
@ -414,8 +413,6 @@ DON'T use a unit test if your lowering pattern could be described as a trivial
|
|||
like your unit test is just rewriting `b.create<...>(...)` into `CHECK: ...`
|
||||
then it is probably not a useful unit test.
|
||||
|
||||
DON'T add a unit test for trivial changes to RefineTypes.
|
||||
|
||||
With the exceptions above, all changes should include appropriate unit tests, as
|
||||
is standard in the LLVM and MLIR community. This includes full coverage of all
|
||||
canonicalizations, pretty printing, passes, errors, and diagnostics.
|
||||
|
|
|
@ -92,8 +92,6 @@ void createTorchDtypeRefinementPipeline(
|
|||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
|
|
|
@ -126,15 +126,6 @@ def AdjustCallingConventions
|
|||
}];
|
||||
}
|
||||
|
||||
def RefineTypes : Pass<"torch-refine-types", "func::FuncOp"> {
|
||||
let summary = "Refine types";
|
||||
let constructor = "mlir::torch::Torch::createRefineTypesPass()";
|
||||
let description = [{
|
||||
Refines types of the program. Currently, this means shapes and dtypes of
|
||||
tensors/arrays.
|
||||
}];
|
||||
}
|
||||
|
||||
def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
|
||||
let summary = "Inlines torch.global_slot ops.";
|
||||
let constructor = "mlir::torch::Torch::createInlineGlobalSlotsPass()";
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -12,7 +12,6 @@ add_mlir_library(TorchMLIRTorchPasses
|
|||
RecomposeComplexOps.cpp
|
||||
ReduceOpVariants.cpp
|
||||
RefinePublicReturn.cpp
|
||||
RefineTypes.cpp
|
||||
ReifyShapeCalculations.cpp
|
||||
ReifyDtypeCalculations.cpp
|
||||
ReifyAbstractInterpCalculationsUtils.cpp
|
||||
|
|
|
@ -103,7 +103,8 @@ static LogicalResult checkType(Operation *op, Type type,
|
|||
->emitError(
|
||||
"unsupported by backend contract: tensor with unknown dtype")
|
||||
.attachNote()
|
||||
.append("this is likely due to a missing case in RefineTypes");
|
||||
.append("this is likely due to a missing transfer function in "
|
||||
"abstract_interp_lib_gen.py");
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
|
|
|
@ -119,20 +119,14 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
|
|||
// Update the return op to return value tensors.
|
||||
pm.addPass(Torch::createRefinePublicReturnPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
// Do shape refinement.
|
||||
// This should be run before RefineTypes (which primarily does dtype
|
||||
// inference), because Torch type promotion rules actually depend on the shape
|
||||
// of the operand.
|
||||
// Do shape and dtype refinement.
|
||||
// Shape refinement should be run before dtype refinement because Torch type
|
||||
// promotion rules actually depend on the shape of the operand.
|
||||
createTorchShapeRefinementPipeline(pm, options);
|
||||
createTorchDtypeRefinementPipeline(pm, options);
|
||||
// Refine types in the program, which mainly means inferring dtypes of ops.
|
||||
pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
|
||||
// Propagate to ABI return types the shape/dtype information discovered by
|
||||
// the previous pass. Doing this is ABI-compatible for our backends.
|
||||
pm.addPass(Torch::createRefinePublicReturnPass());
|
||||
// This can fold away some branches given the information got from
|
||||
// RefineTypes before doing maximize value sematics which only works with
|
||||
// basic blocks.
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
if (options.decompose) {
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -176,20 +176,23 @@ FailureOr<Value> Torch::adjustFunctionArg(
|
|||
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
|
||||
}
|
||||
|
||||
// !torch.union<int, float> is the type used for `Scalar` inputs. At
|
||||
// compile time, such inputs will usually be resolved to an `int` or a `float`
|
||||
// so we need to derefine to match the library function signature.
|
||||
// !torch.union<int, float> or !torch.union<int, float, none> is the type used
|
||||
// for (optional) `Scalar` inputs. At compile time, such inputs will usually
|
||||
// be resolved to an `int` or a `float` so we need to derefine to match the
|
||||
// library function signature.
|
||||
if (auto unionType = desiredType.dyn_cast<Torch::UnionType>()) {
|
||||
if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
|
||||
return containedType.isa<Torch::IntType, Torch::FloatType>();
|
||||
return containedType
|
||||
.isa<Torch::IntType, Torch::FloatType, Torch::NoneType>();
|
||||
}))
|
||||
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
|
||||
}
|
||||
|
||||
// If the operand is NoneType, then we just need to derefine it to the
|
||||
// optional type in the function signature.
|
||||
// Operands with type `!torch.none` correspond to library function inputs with
|
||||
// types like `!torch.optional<...>` or `!torch.union<..., none>`, so here the
|
||||
// type is derefined to match the expected type of the library function.
|
||||
if (operandType.isa<Torch::NoneType>()) {
|
||||
assert(desiredType.isa<Torch::OptionalType>() &&
|
||||
assert(!desiredType.isa<Torch::NoneType>() &&
|
||||
"Don't expect library functions to have NoneType parameters");
|
||||
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
|
||||
}
|
||||
|
|
|
@ -8,11 +8,248 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "SimplifyAbstractInterpCalculationsUtils.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
namespace {
|
||||
class FoldPrimUncheckedCastOp : public OpRewritePattern<PrimUncheckedCastOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(PrimUncheckedCastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!isValidSubtype(op.getX().getType(), op.getResult().getType())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "input tensor type is not a valid subtype of result type");
|
||||
}
|
||||
rewriter.replaceOp(op, op.getX());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// TODO: Only unroll inside the shape calculation region.
|
||||
// Maybe do this by only applying patterns and folding greedily on the ops
|
||||
// inside the region + the shape.calculate op itself?
|
||||
class FullyUnrollPrimLoopOp : public OpRewritePattern<PrimLoopOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(PrimLoopOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
if (!op.isForLike())
|
||||
return rewriter.notifyMatchFailure(op, "Loop is not for-like");
|
||||
int64_t maxTripCount;
|
||||
if (!matchPattern(op.getMaxTripCount(), m_TorchConstantInt(&maxTripCount)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected `maxTripCount` to be a constant int");
|
||||
;
|
||||
SmallVector<Value> indices;
|
||||
for (int64_t i = 0; i < maxTripCount; i++) {
|
||||
// TODO: Add convenience builder.
|
||||
indices.push_back(rewriter.create<ConstantIntOp>(
|
||||
loc, rewriter.getIntegerAttr(IntegerType::get(context, 64), i)));
|
||||
}
|
||||
Block *beforeBlock = op->getBlock();
|
||||
Block *afterBlock = rewriter.splitBlock(op->getBlock(), op->getIterator());
|
||||
|
||||
SmallVector<Block *> blocksToMerge;
|
||||
IRMapping bvm;
|
||||
// TODO: Helper for region().front()
|
||||
auto condition =
|
||||
cast<PrimLoopConditionOp>(op.getRegion().front().getTerminator());
|
||||
for (int64_t i = 0; i < maxTripCount; i++) {
|
||||
SmallVector<Value> iterArgs;
|
||||
if (i == 0) {
|
||||
llvm::append_range(iterArgs, op.getIterArgsInit());
|
||||
} else {
|
||||
llvm::append_range(
|
||||
iterArgs, llvm::map_range(condition.getIterArgs(),
|
||||
[&](Value v) { return bvm.lookup(v); }));
|
||||
}
|
||||
bvm.clear();
|
||||
bvm.map(op.getRegion().front().getArgument(0), indices[i]);
|
||||
bvm.map(op.getRegion().front().getArguments().slice(1), iterArgs);
|
||||
|
||||
op.getRegion().cloneInto(afterBlock->getParent(),
|
||||
afterBlock->getIterator(), bvm);
|
||||
Block *clonedBlock = bvm.lookup(&op.getRegion().front());
|
||||
rewriter.eraseOp(clonedBlock->getTerminator());
|
||||
blocksToMerge.push_back(clonedBlock);
|
||||
}
|
||||
|
||||
blocksToMerge.push_back(afterBlock);
|
||||
for (Block *block : blocksToMerge)
|
||||
rewriter.mergeBlocks(block, beforeBlock);
|
||||
if (maxTripCount == 0) {
|
||||
rewriter.replaceOp(op, op.getIterArgsInit());
|
||||
} else {
|
||||
rewriter.replaceOp(op, llvm::to_vector<6>(llvm::map_range(
|
||||
condition.getIterArgs(),
|
||||
[&](Value v) { return bvm.lookup(v); })));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class AbstractlyInterpretListOpsWithinABlock
|
||||
: public OpRewritePattern<PrimListConstructOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(PrimListConstructOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Block *block = op->getBlock();
|
||||
auto allUsers = llvm::to_vector<6>(op->getUsers());
|
||||
|
||||
// Sort the users into program order.
|
||||
auto getParentInBlock = [&](Operation *op) {
|
||||
while (op->getBlock() != block)
|
||||
op = op->getParentOp();
|
||||
return op;
|
||||
};
|
||||
// Use a stable sort for deterministic results when users are nested in two
|
||||
// regions of the same parent op.
|
||||
llvm::stable_sort(allUsers, [&](Operation *lhs, Operation *rhs) {
|
||||
return getParentInBlock(lhs)->isBeforeInBlock(getParentInBlock(rhs));
|
||||
});
|
||||
|
||||
// We cannot interpret all ops. So first do a check to see up until which
|
||||
// point we can interpret.
|
||||
int numUsersToInterpret = 0;
|
||||
for (int i = 0, e = allUsers.size(); i != e; i++, numUsersToInterpret++) {
|
||||
Operation *user = allUsers[i];
|
||||
// If a user potentially mutates the list, then we require it to be in the
|
||||
// same block for our simple abstract interpretation to work (we can't,
|
||||
// for example, handle an "append" operation in a loop or other region).
|
||||
// However, if the op is read-only, then from the purpose of our abstract
|
||||
// interpretation, we can handle it effectively as though it was at the
|
||||
// same position as the corresponding parent op in the block under
|
||||
// consideration.
|
||||
if (potentiallyMutatesListOperands(user)) {
|
||||
if (user->getBlock() != block)
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Truncate the list of users to the number of users we're going to
|
||||
// interpret.
|
||||
allUsers.resize(numUsersToInterpret);
|
||||
auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret);
|
||||
|
||||
// For each mutating op (which must be in the same block), we save the
|
||||
// current state of the list as a vector of Value's. These will then
|
||||
// be converted to PrimListConstructOp's at the correct program points.
|
||||
SmallVector<SmallVector<Value>> listLiterals;
|
||||
SmallVector<Value> runningList;
|
||||
llvm::append_range(runningList, op->getOperands());
|
||||
bool generatedNewLiteral = false;
|
||||
for (Operation *user : usersToInterpret) {
|
||||
if (auto append = dyn_cast<AtenAppendTOp>(user)) {
|
||||
if (!append.use_empty())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected `AtenAppendTOp` to not have users");
|
||||
if (append.getSelf() == op) {
|
||||
runningList.push_back(append.getEl());
|
||||
generatedNewLiteral = true;
|
||||
}
|
||||
listLiterals.push_back(runningList);
|
||||
continue;
|
||||
}
|
||||
if (auto insert = dyn_cast<AtenInsertTOp>(user)) {
|
||||
if (!insert.use_empty())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected `AtenInsertTOp` to not have users");
|
||||
int64_t index;
|
||||
if (!matchPattern(insert.getIdx(), m_TorchConstantInt(&index)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected `idx` of `AtenInsertTOp` to be a constant int");
|
||||
// The index might be statically out of bounds.
|
||||
if (index < 0 || index > static_cast<int64_t>(runningList.size()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Index in `AtenInsertTOp` is out of bounds");
|
||||
if (insert.getSelf() == op) {
|
||||
runningList.insert(runningList.begin() + index, insert.getEl());
|
||||
generatedNewLiteral = true;
|
||||
}
|
||||
listLiterals.push_back(runningList);
|
||||
continue;
|
||||
}
|
||||
if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
|
||||
if (!setItem.use_empty())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected `Aten_SetItemTOp` to not have users");
|
||||
std::optional<int64_t> indexOpt = matchLegalConstantIndexIntoListOfSize(
|
||||
setItem.getIdx(), runningList.size());
|
||||
// The index might be statically out of bounds.
|
||||
if (!indexOpt)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Index in `Aten_SetItemTOp` is out of bounds");
|
||||
if (setItem.getL() == op) {
|
||||
runningList[*indexOpt] = setItem.getEl();
|
||||
generatedNewLiteral = true;
|
||||
}
|
||||
listLiterals.push_back(runningList);
|
||||
continue;
|
||||
}
|
||||
// If this user potentially mutates the list and isn't handled above, then
|
||||
// we can't abstractly interpret any further.
|
||||
if (potentiallyMutatesListOperands(user))
|
||||
break;
|
||||
}
|
||||
|
||||
if (!generatedNewLiteral)
|
||||
return rewriter.notifyMatchFailure(op, "No new literal created");
|
||||
|
||||
// Rewrite all users to use the appropriate list literals.
|
||||
Value latestLiteral = rewriter.create<PrimListConstructOp>(
|
||||
op->getLoc(), op.getType(), op->getOperands());
|
||||
int nextLiteral = 0;
|
||||
for (Operation *user : usersToInterpret) {
|
||||
if (auto append = dyn_cast<AtenAppendTOp>(user)) {
|
||||
rewriter.setInsertionPoint(append);
|
||||
latestLiteral = rewriter.create<PrimListConstructOp>(
|
||||
append->getLoc(), op.getType(), listLiterals[nextLiteral++]);
|
||||
if (append.getSelf() == op)
|
||||
rewriter.eraseOp(append);
|
||||
continue;
|
||||
}
|
||||
if (auto insert = dyn_cast<AtenInsertTOp>(user)) {
|
||||
rewriter.setInsertionPoint(insert);
|
||||
latestLiteral = rewriter.create<PrimListConstructOp>(
|
||||
insert->getLoc(), op.getType(), listLiterals[nextLiteral++]);
|
||||
if (insert.getSelf() == op)
|
||||
rewriter.eraseOp(insert);
|
||||
continue;
|
||||
}
|
||||
if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
|
||||
rewriter.setInsertionPoint(setItem);
|
||||
latestLiteral = rewriter.create<PrimListConstructOp>(
|
||||
setItem->getLoc(), op.getType(), listLiterals[nextLiteral++]);
|
||||
if (setItem.getL() == op)
|
||||
rewriter.eraseOp(setItem);
|
||||
continue;
|
||||
}
|
||||
for (OpOperand &opOperand : user->getOpOperands()) {
|
||||
if (opOperand.get() == op.getResult()) {
|
||||
opOperand.set(latestLiteral);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Any remaining uses should use the updated value of the latest literal.
|
||||
rewriter.replaceOp(op, latestLiteral);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
|
||||
int resultNum,
|
||||
Type newResultType,
|
||||
|
@ -97,3 +334,18 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
|
|||
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::Torch::populateFoldPrimUncheckedCastOpPattern(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
patterns.insert<FoldPrimUncheckedCastOp>(context);
|
||||
}
|
||||
|
||||
void mlir::torch::Torch::populateFullyUnrollPrimLoopOpPattern(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
patterns.insert<FullyUnrollPrimLoopOp>(context);
|
||||
}
|
||||
|
||||
void mlir::torch::Torch::populateAbstractlyInterpretListOpsWithinABlockPattern(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
patterns.insert<AbstractlyInterpretListOpsWithinABlock>(context);
|
||||
}
|
||||
|
|
|
@ -23,6 +23,13 @@ LogicalResult updateCalculateOpResultTypes(Operation *calculateOp,
|
|||
int resultNum, Type newResultType,
|
||||
PatternRewriter &rewriter);
|
||||
|
||||
void populateFoldPrimUncheckedCastOpPattern(RewritePatternSet &patterns,
|
||||
MLIRContext *context);
|
||||
void populateFullyUnrollPrimLoopOpPattern(RewritePatternSet &patterns,
|
||||
MLIRContext *context);
|
||||
void populateAbstractlyInterpretListOpsWithinABlockPattern(
|
||||
RewritePatternSet &patterns, MLIRContext *context);
|
||||
|
||||
} // namespace Torch
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -191,10 +191,17 @@ class SimplifyDtypeCalculationsPass
|
|||
MLIRContext *context = &getContext();
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
populateFullyUnrollPrimLoopOpPattern(patterns, context);
|
||||
populateAbstractlyInterpretListOpsWithinABlockPattern(patterns, context);
|
||||
populateFoldPrimUncheckedCastOpPattern(patterns, context);
|
||||
patterns.insert<RefineDtypeCalculateOp>(context);
|
||||
patterns.insert<DecomposePromoteDtypesOp>(context);
|
||||
patterns.insert<RefineNumToTensorScalarOpType>(context);
|
||||
|
||||
PrimIfOp::getCanonicalizationPatterns(patterns, context);
|
||||
Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
|
||||
PrimTupleUnpackOp::getCanonicalizationPatterns(patterns, context);
|
||||
|
||||
// TODO: Debug visitation order to make this more efficient.
|
||||
// A single linear scan should suffice.
|
||||
GreedyRewriteConfig config;
|
||||
|
|
|
@ -10,7 +10,6 @@
|
|||
#include "PassDetail.h"
|
||||
|
||||
#include "SimplifyAbstractInterpCalculationsUtils.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
|
@ -19,225 +18,6 @@ using namespace mlir;
|
|||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
namespace {
|
||||
// TODO: Only unroll inside the shape calculation region.
|
||||
// Maybe do this by only applying patterns and folding greedily on the ops
|
||||
// inside the region + the shape.calculate op itself?
|
||||
class FullyUnrollPrimLoopOp : public OpRewritePattern<PrimLoopOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(PrimLoopOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
if (!op.isForLike())
|
||||
return rewriter.notifyMatchFailure(op, "Loop is not for-like");
|
||||
int64_t maxTripCount;
|
||||
if (!matchPattern(op.getMaxTripCount(), m_TorchConstantInt(&maxTripCount)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected `maxTripCount` to be a constant int");
|
||||
;
|
||||
SmallVector<Value> indices;
|
||||
for (int64_t i = 0; i < maxTripCount; i++) {
|
||||
// TODO: Add convenience builder.
|
||||
indices.push_back(rewriter.create<ConstantIntOp>(
|
||||
loc, rewriter.getIntegerAttr(IntegerType::get(context, 64), i)));
|
||||
}
|
||||
Block *beforeBlock = op->getBlock();
|
||||
Block *afterBlock = rewriter.splitBlock(op->getBlock(), op->getIterator());
|
||||
|
||||
SmallVector<Block *> blocksToMerge;
|
||||
IRMapping bvm;
|
||||
// TODO: Helper for region().front()
|
||||
auto condition =
|
||||
cast<PrimLoopConditionOp>(op.getRegion().front().getTerminator());
|
||||
for (int64_t i = 0; i < maxTripCount; i++) {
|
||||
SmallVector<Value> iterArgs;
|
||||
if (i == 0) {
|
||||
llvm::append_range(iterArgs, op.getIterArgsInit());
|
||||
} else {
|
||||
llvm::append_range(
|
||||
iterArgs, llvm::map_range(condition.getIterArgs(),
|
||||
[&](Value v) { return bvm.lookup(v); }));
|
||||
}
|
||||
bvm.clear();
|
||||
bvm.map(op.getRegion().front().getArgument(0), indices[i]);
|
||||
bvm.map(op.getRegion().front().getArguments().slice(1), iterArgs);
|
||||
|
||||
op.getRegion().cloneInto(afterBlock->getParent(), afterBlock->getIterator(),
|
||||
bvm);
|
||||
Block *clonedBlock = bvm.lookup(&op.getRegion().front());
|
||||
rewriter.eraseOp(clonedBlock->getTerminator());
|
||||
blocksToMerge.push_back(clonedBlock);
|
||||
}
|
||||
|
||||
blocksToMerge.push_back(afterBlock);
|
||||
for (Block *block : blocksToMerge)
|
||||
rewriter.mergeBlocks(block, beforeBlock);
|
||||
if (maxTripCount == 0) {
|
||||
rewriter.replaceOp(op, op.getIterArgsInit());
|
||||
} else {
|
||||
rewriter.replaceOp(op, llvm::to_vector<6>(llvm::map_range(
|
||||
condition.getIterArgs(),
|
||||
[&](Value v) { return bvm.lookup(v); })));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class AbstractlyInterpretListOpsWithinABlock
|
||||
: public OpRewritePattern<PrimListConstructOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(PrimListConstructOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Block *block = op->getBlock();
|
||||
auto allUsers = llvm::to_vector<6>(op->getUsers());
|
||||
|
||||
// Sort the users into program order.
|
||||
auto getParentInBlock = [&](Operation *op) {
|
||||
while (op->getBlock() != block)
|
||||
op = op->getParentOp();
|
||||
return op;
|
||||
};
|
||||
// Use a stable sort for deterministic results when users are nested in two
|
||||
// regions of the same parent op.
|
||||
llvm::stable_sort(allUsers, [&](Operation *lhs, Operation *rhs) {
|
||||
return getParentInBlock(lhs)->isBeforeInBlock(getParentInBlock(rhs));
|
||||
});
|
||||
|
||||
// We cannot interpret all ops. So first do a check to see up until which
|
||||
// point we can interpret.
|
||||
int numUsersToInterpret = 0;
|
||||
for (int i = 0, e = allUsers.size(); i != e; i++, numUsersToInterpret++) {
|
||||
Operation *user = allUsers[i];
|
||||
// If a user potentially mutates the list, then we require it to be in the
|
||||
// same block for our simple abstract interpretation to work (we can't,
|
||||
// for example, handle an "append" operation in a loop or other region).
|
||||
// However, if the op is read-only, then from the purpose of our abstract
|
||||
// interpretation, we can handle it effectively as though it was at the
|
||||
// same position as the corresponding parent op in the block under
|
||||
// consideration.
|
||||
if (potentiallyMutatesListOperands(user)) {
|
||||
if (user->getBlock() != block)
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Truncate the list of users to the number of users we're going to
|
||||
// interpret.
|
||||
allUsers.resize(numUsersToInterpret);
|
||||
auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret);
|
||||
|
||||
// For each mutating op (which must be in the same block), we save the
|
||||
// current state of the list as a vector of Value's. These will then
|
||||
// be converted to PrimListConstructOp's at the correct program points.
|
||||
SmallVector<SmallVector<Value>> listLiterals;
|
||||
SmallVector<Value> runningList;
|
||||
llvm::append_range(runningList, op->getOperands());
|
||||
bool generatedNewLiteral = false;
|
||||
for (Operation *user : usersToInterpret) {
|
||||
if (auto append = dyn_cast<AtenAppendTOp>(user)) {
|
||||
if (!append.use_empty())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected `AtenAppendTOp` to not have users");
|
||||
if (append.getSelf() == op) {
|
||||
runningList.push_back(append.getEl());
|
||||
generatedNewLiteral = true;
|
||||
}
|
||||
listLiterals.push_back(runningList);
|
||||
continue;
|
||||
}
|
||||
if (auto insert = dyn_cast<AtenInsertTOp>(user)) {
|
||||
if (!insert.use_empty())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected `AtenInsertTOp` to not have users");
|
||||
int64_t index;
|
||||
if (!matchPattern(insert.getIdx(), m_TorchConstantInt(&index)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected `idx` of `AtenInsertTOp` to be a constant int");
|
||||
// The index might be statically out of bounds.
|
||||
if (index < 0 || index > static_cast<int64_t>(runningList.size()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Index in `AtenInsertTOp` is out of bounds");
|
||||
if (insert.getSelf() == op) {
|
||||
runningList.insert(runningList.begin() + index, insert.getEl());
|
||||
generatedNewLiteral = true;
|
||||
}
|
||||
listLiterals.push_back(runningList);
|
||||
continue;
|
||||
}
|
||||
if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
|
||||
if (!setItem.use_empty())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected `Aten_SetItemTOp` to not have users");
|
||||
std::optional<int64_t> indexOpt = matchLegalConstantIndexIntoListOfSize(
|
||||
setItem.getIdx(), runningList.size());
|
||||
// The index might be statically out of bounds.
|
||||
if (!indexOpt)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Index in `Aten_SetItemTOp` is out of bounds");
|
||||
if (setItem.getL() == op) {
|
||||
runningList[*indexOpt] = setItem.getEl();
|
||||
generatedNewLiteral = true;
|
||||
}
|
||||
listLiterals.push_back(runningList);
|
||||
continue;
|
||||
}
|
||||
// If this user potentially mutates the list and isn't handled above, then
|
||||
// we can't abstractly interpret any further.
|
||||
if (potentiallyMutatesListOperands(user))
|
||||
break;
|
||||
}
|
||||
|
||||
if (!generatedNewLiteral)
|
||||
return rewriter.notifyMatchFailure(op, "No new literal created");
|
||||
|
||||
// Rewrite all users to use the appropriate list literals.
|
||||
Value latestLiteral = rewriter.create<PrimListConstructOp>(
|
||||
op->getLoc(), op.getType(), op->getOperands());
|
||||
int nextLiteral = 0;
|
||||
for (Operation *user : usersToInterpret) {
|
||||
if (auto append = dyn_cast<AtenAppendTOp>(user)) {
|
||||
rewriter.setInsertionPoint(append);
|
||||
latestLiteral = rewriter.create<PrimListConstructOp>(
|
||||
append->getLoc(), op.getType(), listLiterals[nextLiteral++]);
|
||||
if (append.getSelf() == op)
|
||||
rewriter.eraseOp(append);
|
||||
continue;
|
||||
}
|
||||
if (auto insert = dyn_cast<AtenInsertTOp>(user)) {
|
||||
rewriter.setInsertionPoint(insert);
|
||||
latestLiteral = rewriter.create<PrimListConstructOp>(
|
||||
insert->getLoc(), op.getType(), listLiterals[nextLiteral++]);
|
||||
if (insert.getSelf() == op)
|
||||
rewriter.eraseOp(insert);
|
||||
continue;
|
||||
}
|
||||
if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
|
||||
rewriter.setInsertionPoint(setItem);
|
||||
latestLiteral = rewriter.create<PrimListConstructOp>(
|
||||
setItem->getLoc(), op.getType(), listLiterals[nextLiteral++]);
|
||||
if (setItem.getL() == op)
|
||||
rewriter.eraseOp(setItem);
|
||||
continue;
|
||||
}
|
||||
for (OpOperand &opOperand : user->getOpOperands()) {
|
||||
if (opOperand.get() == op.getResult()) {
|
||||
opOperand.set(latestLiteral);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Any remaining uses should use the updated value of the latest literal.
|
||||
rewriter.replaceOp(op, latestLiteral);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
|
||||
public:
|
||||
|
@ -266,22 +46,6 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class FoldPrimUncheckedCastOp : public OpRewritePattern<PrimUncheckedCastOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(PrimUncheckedCastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!isValidSubtype(op.getX().getType(), op.getResult().getType())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "input tensor type is not a valid subtype of result type");
|
||||
}
|
||||
rewriter.replaceOp(op, op.getX());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
|
||||
int resultNum,
|
||||
PatternRewriter &rewriter) {
|
||||
|
@ -367,11 +131,11 @@ class SimplifyShapeCalculationsPass
|
|||
MLIRContext *context = &getContext();
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.insert<FullyUnrollPrimLoopOp>(context);
|
||||
patterns.insert<AbstractlyInterpretListOpsWithinABlock>(context);
|
||||
populateFullyUnrollPrimLoopOpPattern(patterns, context);
|
||||
populateAbstractlyInterpretListOpsWithinABlockPattern(patterns, context);
|
||||
populateFoldPrimUncheckedCastOpPattern(patterns, context);
|
||||
patterns.insert<DecomposeAtenSizeOp>(context);
|
||||
patterns.insert<RefineShapeCalculateOp>(context);
|
||||
patterns.insert<FoldPrimUncheckedCastOp>(context);
|
||||
|
||||
PrimIfOp::getCanonicalizationPatterns(patterns, context);
|
||||
Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -14,6 +14,55 @@ from torch_mlir.passmanager import PassManager
|
|||
|
||||
from .registry import Registry
|
||||
|
||||
def all_integer_dtypes() -> List[int]:
|
||||
return [torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
|
||||
|
||||
def is_integer_dtype(dtype: int) -> bool:
|
||||
return dtype in all_integer_dtypes()
|
||||
|
||||
def all_complex_dtypes() -> List[int]:
|
||||
return [torch.complex64, torch.complex128]
|
||||
|
||||
def is_complex_dtype(dtype: int) -> bool:
|
||||
return dtype in all_complex_dtypes()
|
||||
|
||||
def all_float_dtypes() -> List[int]:
|
||||
return [torch.float16, torch.bfloat16, torch.float32, torch.float64]
|
||||
|
||||
def is_float_dtype(dtype: int) -> bool:
|
||||
return dtype in all_float_dtypes()
|
||||
|
||||
def get_priority_of_dtype(dtype: int) -> int:
|
||||
# If a loop is used to iterate over a list of sorted dtypes, TorchScript
|
||||
# produces a loop with INT64_MAX max trip count, which causes problems
|
||||
# during the loop unrolling that takes place when simplifying the dtype
|
||||
# functions. Therefore, here we resort to `if`s.
|
||||
if dtype == torch.bool:
|
||||
return 0
|
||||
if dtype == torch.uint8:
|
||||
return 1
|
||||
if dtype == torch.int8:
|
||||
return 2
|
||||
if dtype == torch.int16:
|
||||
return 3
|
||||
if dtype == torch.int32:
|
||||
return 4
|
||||
if dtype == torch.int64:
|
||||
return 5
|
||||
if dtype == torch.bfloat16:
|
||||
return 6
|
||||
if dtype == torch.float16:
|
||||
return 7
|
||||
if dtype == torch.float32:
|
||||
return 8
|
||||
if dtype == torch.float64:
|
||||
return 9
|
||||
if dtype == torch.complex64:
|
||||
return 10
|
||||
if dtype == torch.complex128:
|
||||
return 11
|
||||
assert False, "Cannot determine priority of dtype"
|
||||
|
||||
def get_dtype_of_scalar(scalar: Union[int, float]) -> int:
|
||||
# This is hacky. `NumToTensor` is the only PyTorch op for scalars
|
||||
# that when `jit.script`ed converts a float scalar to a tensor
|
||||
|
|
|
@ -60,33 +60,32 @@ class TensorOfShape:
|
|||
This class also tracks a dtype of the tensor, since some ops require a
|
||||
specific dtype.
|
||||
"""
|
||||
def __init__(self, *shape: int, dtype: torch.dtype = torch.float32):
|
||||
def __init__(self, *shape: int, dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None):
|
||||
self.shape = list(shape)
|
||||
self.dtype = dtype
|
||||
self.device = "meta" if device is None else device
|
||||
def __repr__(self):
|
||||
args_str = ", ".join(repr(x) for x in self.shape)
|
||||
if self.dtype is torch.float32:
|
||||
return f"TensorOfShape({args_str})"
|
||||
else:
|
||||
return f"TensorOfShape({args_str}, dtype={self.dtype})"
|
||||
return f"TensorOfShape({args_str}, dtype={self.dtype}, device={self.device})"
|
||||
|
||||
def LongTensorOfShape(*args, **kwargs):
|
||||
"""Helper for indicating a TensorOfShape with integer type."""
|
||||
return TensorOfShape(*args, **kwargs, dtype=torch.long)
|
||||
|
||||
def NonZeroDTensorWithDtype(dtype):
|
||||
def NonZeroDTensorWithDtype(dtype, device: Optional[torch.device] = None):
|
||||
"""Helper for indicating a non-zero dim tensor with custom type."""
|
||||
return TensorOfShape(1, dtype=dtype)
|
||||
return TensorOfShape(1, dtype=dtype, device=device)
|
||||
|
||||
def ZeroDTensorWithDtype(dtype):
|
||||
def ZeroDTensorWithDtype(dtype, device: Optional[torch.device] = None):
|
||||
"""Helper for indicating a zero dim tensor with custom type."""
|
||||
return TensorOfShape(dtype=dtype)
|
||||
return TensorOfShape(dtype=dtype, device=device)
|
||||
|
||||
def _recursively_transform_tensor_args(
|
||||
o: Any,
|
||||
tensor_transformer: Callable[[TensorOfShape], Any]) -> Any:
|
||||
"""Replace `TensorOfShape` with the result of `tensor_transformer`"""
|
||||
if o is None or isinstance(o, (float, int)):
|
||||
if o is None or isinstance(o, (float, int, str)):
|
||||
return o
|
||||
if isinstance(o, TensorOfShape):
|
||||
return tensor_transformer(o)
|
||||
|
@ -146,7 +145,7 @@ class Invocation:
|
|||
|
||||
def to_real_op_args(self):
|
||||
"""Gets positional arguments appropriate for the real op."""
|
||||
tensor_transformer = lambda o: torch.ones(o.shape, dtype=o.dtype)
|
||||
tensor_transformer = lambda o: torch.ones(o.shape, dtype=o.dtype).to(o.device)
|
||||
return _recursively_transform_tensor_args(self.args, tensor_transformer)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
@ -258,6 +257,15 @@ def check_shape_function(invocations: List[Invocation]):
|
|||
return f
|
||||
return decorator
|
||||
|
||||
@torch.jit.script
|
||||
def _convert_dtype_to_int(dtype: torch.dtype) -> int:
|
||||
"""Convert a PyTorch `dtype` into its underlying `int` representation.
|
||||
|
||||
This works because in TorchScript there is no special type for `dtypes`;
|
||||
they are simply `int`s.
|
||||
"""
|
||||
return dtype
|
||||
|
||||
def check_dtype_function(invocations: List[Invocation]):
|
||||
"""Decorator that automatically tests a dtype function.
|
||||
|
||||
|
@ -281,7 +289,12 @@ def check_dtype_function(invocations: List[Invocation]):
|
|||
golden_dtype = torch.tensor([]).to(type(golden_result)).dtype
|
||||
else:
|
||||
raise ValueError(f"Unhandled return type {type(golden_result)}")
|
||||
if result_dtype != golden_dtype:
|
||||
# Some dtype funtions have default `dtype` parameters, which are
|
||||
# represented as `int` values in the registry. In order to
|
||||
# support returning the default `int` value, the comparisons of
|
||||
# the result and golden dtypes are done using their underlying
|
||||
# `int` representation.
|
||||
if _convert_dtype_to_int(result_dtype) != _convert_dtype_to_int(golden_dtype):
|
||||
_report(f, invocation, f"Expected result dtype {golden_dtype}, got {result_dtype}")
|
||||
return f
|
||||
return decorator
|
||||
|
|
|
@ -1,153 +0,0 @@
|
|||
// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @prim.if$branch_merge_type_tensor(
|
||||
// CHECK-SAME: %[[PRED:.*]]: !torch.bool,
|
||||
// CHECK-SAME: %[[T1:.*]]: !torch.tensor,
|
||||
// CHECK-SAME: %[[T2:.*]]: !torch.tensor) -> !torch.bool {
|
||||
// CHECK: %[[MERGED:.*]] = torch.prim.If %[[PRED]] -> (!torch.optional<tensor>) {
|
||||
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[T1]] : !torch.tensor to !torch.optional<tensor>
|
||||
// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional<tensor>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[T2]] : !torch.tensor to !torch.optional<tensor>
|
||||
// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional<tensor>
|
||||
// CHECK: }
|
||||
// CHECK: %[[REFINED:.*]] = torch.prim.unchecked_cast %[[MERGED:.*]] : !torch.optional<tensor> -> !torch.tensor
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[RET:.*]] = torch.aten.__isnot__ %[[REFINED]], %[[NONE]] : !torch.tensor, !torch.none -> !torch.bool
|
||||
// CHECK: return %[[RET]] : !torch.bool
|
||||
|
||||
func.func @prim.if$branch_merge_type_tensor(%pred: !torch.bool, %t0: !torch.tensor, %t1: !torch.tensor) -> !torch.bool {
|
||||
%res = torch.prim.If %pred -> (!torch.optional<tensor>) {
|
||||
%optional0 = torch.derefine %t0: !torch.tensor to !torch.optional<tensor>
|
||||
torch.prim.If.yield %optional0: !torch.optional<tensor>
|
||||
} else {
|
||||
%optional1 = torch.derefine %t1: !torch.tensor to !torch.optional<tensor>
|
||||
torch.prim.If.yield %optional1: !torch.optional<tensor>
|
||||
}
|
||||
%none = torch.constant.none
|
||||
%cmp = torch.aten.__isnot__ %res, %none : !torch.optional<tensor>, !torch.none -> !torch.bool
|
||||
return %cmp : !torch.bool
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @prim.if$branch_merge_type_optional(
|
||||
// CHECK-SAME: %[[PRED:.*]]: !torch.bool,
|
||||
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.optional<tensor> {
|
||||
// CHECK: %[[MERGED:.*]] = torch.prim.If %[[PRED]] -> (!torch.optional<tensor>) {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor>
|
||||
// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional<tensor>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[T]] : !torch.tensor to !torch.optional<tensor>
|
||||
// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional<tensor>
|
||||
// CHECK: }
|
||||
// CHECK: return %[[MERGED:.*]] : !torch.optional<tensor>
|
||||
|
||||
func.func @prim.if$branch_merge_type_optional(%pred: !torch.bool, %t1: !torch.tensor) -> !torch.optional<tensor> {
|
||||
%res = torch.prim.If %pred -> (!torch.optional<tensor>) {
|
||||
%none = torch.constant.none
|
||||
%optional0 = torch.derefine %none: !torch.none to !torch.optional<tensor>
|
||||
torch.prim.If.yield %optional0: !torch.optional<tensor>
|
||||
} else {
|
||||
%optional1 = torch.derefine %t1: !torch.tensor to !torch.optional<tensor>
|
||||
torch.prim.If.yield %optional1: !torch.optional<tensor>
|
||||
}
|
||||
return %res: !torch.optional<tensor>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @prim.if$refined_type_conflicting(
|
||||
// CHECK-SAME: %[[NONE:.*]]: !torch.none) -> !torch.tensor {
|
||||
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor>
|
||||
// CHECK: %[[NOT_NONE:.*]] = torch.aten.__isnot__ %[[NONE]], %[[NONE]] : !torch.none, !torch.none -> !torch.bool
|
||||
// CHECK: %[[PRED:.*]] = torch.prim.If %[[NOT_NONE]] -> (!torch.tensor) {
|
||||
// CHECK: %[[T:.*]] = torch.prim.unchecked_cast %[[OPTIONAL]] : !torch.optional<tensor> -> !torch.tensor
|
||||
// CHECK: torch.prim.If.yield %[[T]] : !torch.tensor
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[LITERAL:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<3x5xf32>) : !torch.tensor
|
||||
// CHECK: torch.prim.If.yield %[[LITERAL]] : !torch.tensor
|
||||
// CHECK: }
|
||||
// CHECK: return %[[PRED:.*]] : !torch.tensor
|
||||
|
||||
func.func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.tensor {
|
||||
%optional = torch.derefine %none: !torch.none to !torch.optional<tensor>
|
||||
%pred = torch.aten.__isnot__ %optional, %none : !torch.optional<tensor>, !torch.none -> !torch.bool
|
||||
%res = torch.prim.If %pred -> (!torch.tensor) {
|
||||
%t = torch.prim.unchecked_cast %optional: !torch.optional<tensor> -> !torch.tensor
|
||||
torch.prim.If.yield %t: !torch.tensor
|
||||
} else {
|
||||
%t_cst = torch.tensor.literal(dense<0.0> : tensor<3x5xf32>) : !torch.tensor
|
||||
torch.prim.If.yield %t_cst: !torch.tensor
|
||||
}
|
||||
return %res: !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @prim.loop$region_arg_to_internal(
|
||||
// CHECK-SAME: %[[ARG_NONE:.*]]: !torch.none) -> !torch.optional<tensor> {
|
||||
// CHECK: %[[INT10:.*]] = torch.constant.int 10
|
||||
// CHECK: %[[INDV:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[ARG_NONE]] : !torch.none to !torch.optional<tensor>
|
||||
// CHECK: %[[LOOP_RET:.*]] = torch.prim.Loop %[[INT10]], %[[TRUE]], init(%[[OPTIONAL]]) {
|
||||
// CHECK: ^bb0(%[[INDV:.*]]: !torch.int, %[[IT:.*]]: !torch.optional<tensor>):
|
||||
// CHECK: %[[NONE:.*]] = torch.prim.unchecked_cast %[[IT]] : !torch.optional<tensor> -> !torch.none
|
||||
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor>
|
||||
// CHECK: %[[COND:.*]] = torch.aten.__isnot__ %[[NONE]], %[[ARG_NONE]] : !torch.none, !torch.none -> !torch.bool
|
||||
// CHECK: torch.prim.Loop.condition %[[COND]], iter(%[[OPTIONAL]] : !torch.optional<tensor>)
|
||||
// CHECK: } : (!torch.int, !torch.bool, !torch.optional<tensor>) -> !torch.optional<tensor>
|
||||
// CHECK: %[[NONE:.*]] = torch.prim.unchecked_cast %[[LOOP_RET:.*]] : !torch.optional<tensor> -> !torch.none
|
||||
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor>
|
||||
// CHECK: return %[[OPTIONAL]] : !torch.optional<tensor>
|
||||
|
||||
func.func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.optional<tensor> {
|
||||
%int10 = torch.constant.int 10
|
||||
%int0 = torch.constant.int 0
|
||||
%true = torch.constant.bool true
|
||||
%optional = torch.derefine %none: !torch.none to !torch.optional<tensor>
|
||||
%ret = torch.prim.Loop %int10, %true, init(%optional) {
|
||||
^bb0(%arg2: !torch.int, %arg3: !torch.optional<tensor>): // no predecessors
|
||||
%cond = torch.aten.__isnot__ %arg3, %none : !torch.optional<tensor>, !torch.none -> !torch.bool
|
||||
torch.prim.Loop.condition %cond, iter(%arg3: !torch.optional<tensor>)
|
||||
} : (!torch.int, !torch.bool, !torch.optional<tensor>) -> (!torch.optional<tensor>)
|
||||
return %ret: !torch.optional<tensor>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @f
|
||||
// CHECK: %[[ATEN:.*]] = torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<*,f32> to !torch.vtensor
|
||||
// CHECK: return %[[CAST]] : !torch.vtensor
|
||||
func.func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
|
||||
cf.br ^bb1(%cast: !torch.vtensor)
|
||||
^bb1(%arg1: !torch.vtensor):
|
||||
%1 = torch.aten.cos %arg1 : !torch.vtensor -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @f
|
||||
// CHECK: func.func private @callee
|
||||
// CHECK-NEXT: torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
|
||||
func.func @f() {
|
||||
builtin.module {
|
||||
func.func private @callee(%arg0: !torch.vtensor) {
|
||||
%1 = torch.aten.cos %arg0 : !torch.vtensor -> !torch.vtensor
|
||||
return
|
||||
}
|
||||
func.func @caller(%arg0: !torch.vtensor<*,f32>) {
|
||||
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
|
||||
call @callee(%cast) : (!torch.vtensor) -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
|
@ -1,364 +0,0 @@
|
|||
// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s
|
||||
|
||||
// This file is for tests for individual ops that require a new transfer
|
||||
// function (i.e. new code called from visitOperation).
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @aten.arange.start$int64_dtype(
|
||||
// CHECK-SAME: %[[START:.*]]: !torch.int,
|
||||
// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[T:.*]] = torch.aten.arange.start
|
||||
// CHECK-SAME: %[[START]], %[[END]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
|
||||
// CHECK-SAME: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none
|
||||
// CHECK-SAME: -> !torch.vtensor<*,si64>
|
||||
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,si64> to !torch.vtensor
|
||||
// CHECK: return %[[RET]] : !torch.vtensor
|
||||
func.func @aten.arange.start$int64_dtype(%start: !torch.int, %end: !torch.int) -> !torch.vtensor {
|
||||
%none = torch.constant.none
|
||||
%ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
|
||||
return %ret : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @aten.arange.start$float32_dtype(
|
||||
// CHECK-SAME: %[[START:.*]]: !torch.float,
|
||||
// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[T:.*]] = torch.aten.arange.start
|
||||
// CHECK-SAME: %[[START]], %[[END]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
|
||||
// CHECK-SAME: !torch.float, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none
|
||||
// CHECK-SAME: -> !torch.vtensor<*,f32>
|
||||
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,f32> to !torch.vtensor
|
||||
// CHECK: return %[[RET]] : !torch.vtensor
|
||||
func.func @aten.arange.start$float32_dtype(%start: !torch.float, %end: !torch.int) -> !torch.vtensor {
|
||||
%none = torch.constant.none
|
||||
%ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.float, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
|
||||
return %ret : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @aten.arange.start$specified_dtype(
|
||||
// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor {
|
||||
// CHECK: %[[CST6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[T:.*]] = torch.aten.arange
|
||||
// CHECK-SAME: %[[END]], %[[CST6]], %[[NONE]], %[[NONE]], %[[NONE]] :
|
||||
// CHECK-SAME: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none
|
||||
// CHECK-SAME: -> !torch.vtensor<*,f32>
|
||||
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,f32> to !torch.vtensor
|
||||
// CHECK: return %[[RET]] : !torch.vtensor
|
||||
func.func @aten.arange.start$specified_dtype(%end: !torch.int) -> !torch.vtensor {
|
||||
%int6 = torch.constant.int 6
|
||||
%none = torch.constant.none
|
||||
%ret = torch.aten.arange %end, %int6, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor
|
||||
return %ret : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.linear(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[5,3],f32>,
|
||||
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[5],f32>) -> !torch.vtensor {
|
||||
// CHECK: %[[LINEAR:.*]] = torch.aten.linear %[[ARG0]], %[[ARG1]], %[[ARG2]] : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<*,f32>
|
||||
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[LINEAR]] : !torch.vtensor<*,f32> to !torch.vtensor
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor
|
||||
func.func @torch.aten.linear(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor {
|
||||
%1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @aten.sum.dim_IntList(
|
||||
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,si64>) -> !torch.vtensor {
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
|
||||
// CHECK: %[[DIMLIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT_NEG1]]
|
||||
// CHECK-SAME: : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[RET:.*]] = torch.aten.sum.dim_IntList %[[T]], %[[DIMLIST]], %[[FALSE]], %[[NONE]]
|
||||
// CHECK-SAME: : !torch.vtensor<*,si64>, !torch.list<int>, !torch.bool, !torch.none
|
||||
// CHECK-SAME: -> !torch.vtensor<*,si64>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,si64> to !torch.vtensor
|
||||
// CHECK: return %[[CAST]] : !torch.vtensor
|
||||
func.func @aten.sum.dim_IntList(%t: !torch.vtensor<*,si64>) -> !torch.vtensor {
|
||||
%false = torch.constant.bool false
|
||||
%none = torch.constant.none
|
||||
%int0 = torch.constant.int 0
|
||||
%int-1 = torch.constant.int -1
|
||||
%dimList = torch.prim.ListConstruct %int0, %int-1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
%ret = torch.aten.sum.dim_IntList %t, %dimList, %false, %none : !torch.vtensor<*,si64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor
|
||||
return %ret : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @aten.any.dim(
|
||||
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,i1>) -> !torch.vtensor {
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
|
||||
// CHECK: %[[RET:.*]] = torch.aten.any.dim %[[T]], %[[INT_NEG1]], %[[FALSE]] : !torch.vtensor<*,i1>, !torch.int, !torch.bool -> !torch.vtensor<*,i1>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,i1> to !torch.vtensor
|
||||
// CHECK: return %[[CAST]] : !torch.vtensor
|
||||
func.func @aten.any.dim(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
|
||||
%false = torch.constant.bool false
|
||||
%int-1 = torch.constant.int -1
|
||||
%ret = torch.aten.any.dim %t, %int-1, %false : !torch.vtensor<*,i1>, !torch.int, !torch.bool -> !torch.vtensor
|
||||
return %ret : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @aten.any(
|
||||
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,i1>) -> !torch.vtensor {
|
||||
// CHECK: %[[RET:.*]] = torch.aten.any %[[T]] : !torch.vtensor<*,i1> -> !torch.vtensor<*,i1>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,i1> to !torch.vtensor
|
||||
// CHECK: return %[[CAST]] : !torch.vtensor
|
||||
func.func @aten.any(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
|
||||
%ret = torch.aten.any %t: !torch.vtensor<*,i1> -> !torch.vtensor
|
||||
return %ret : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.zeros(
|
||||
// CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor<*,f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ZEROS]] : !torch.tensor<*,f32> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor {
|
||||
%none = torch.constant.none
|
||||
%int2 = torch.constant.int 2
|
||||
%sizesList = torch.prim.ListConstruct %dim0, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
%ret = torch.aten.zeros %sizesList, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.type_as(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?],si64>,
|
||||
// CHECK-SAME: %[[OTHER:.*]]: !torch.tensor<[?,2],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[RET:.*]] = torch.aten.type_as %[[INPUT]], %[[OTHER]] : !torch.tensor<[?],si64>, !torch.tensor<[?,2],f32> -> !torch.tensor<*,f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch.tensor<[?,2],f32>) -> !torch.tensor {
|
||||
%ret = torch.aten.type_as %self, %other : !torch.tensor<[?], si64>, !torch.tensor<[?,2],f32> -> !torch.tensor
|
||||
return %ret: !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.cat(
|
||||
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
|
||||
// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[?,1,4],f32>, !torch.tensor<[2,3,4],f32>) -> !torch.list<tensor>
|
||||
// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4], f32>) -> !torch.tensor {
|
||||
%int1 = torch.constant.int 1
|
||||
%tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[?,1,4], f32>, !torch.tensor<[2,3,4], f32>) -> !torch.list<tensor>
|
||||
%ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.cat$promote_type(
|
||||
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[2,1,4],i1>,
|
||||
// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],si64>) -> !torch.tensor {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[2,1,4],i1>, !torch.tensor<[2,3,4],si64>) -> !torch.list<tensor>
|
||||
// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,si64>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten.cat$promote_type(%t0: !torch.tensor<[2,1,4], i1>, %t1: !torch.tensor<[2,3,4], si64>) -> !torch.tensor {
|
||||
%int1 = torch.constant.int 1
|
||||
%tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[2,1,4], i1>, !torch.tensor<[2,3,4], si64>) -> !torch.list<tensor>
|
||||
%ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten._shape_as_tensor(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor<[?,1,4],f32> -> !torch.tensor<*,si64>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten._shape_as_tensor(%input: !torch.tensor<[?,1,4], f32>) -> !torch.tensor {
|
||||
%ret= torch.aten._shape_as_tensor %input : !torch.tensor<[?,1,4], f32> -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten._shape_as_tensor$unknown_input_shape(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor) -> !torch.tensor {
|
||||
// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor -> !torch.tensor<*,si64>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.tensor) -> !torch.tensor {
|
||||
%ret= torch.aten._shape_as_tensor %input : !torch.tensor -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.embedding(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[104,512],f32>,
|
||||
// CHECK-SAME: %[[INDEXES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[PADDING_IDX:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[RET:.*]] = torch.aten.embedding %[[INPUT]], %[[INDEXES]], %[[PADDING_IDX]], %[[FALSE]], %[[FALSE]] : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor<*,f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indices: !torch.tensor<[2,3], si64>) -> !torch.tensor {
|
||||
%false = torch.constant.bool false
|
||||
%int1 = torch.constant.int 1
|
||||
%ret = torch.aten.embedding %weight, %indices, %int1, %false, %false : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor
|
||||
return %ret: !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.tensor.float(
|
||||
// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor<*,f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
|
||||
%none = torch.constant.none
|
||||
%false = torch.constant.bool false
|
||||
%ret = torch.aten.tensor.float %t, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.tensor.float$specified_dtype(
|
||||
// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[CST11:.*]] = torch.constant.int 11
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[CST11]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,i1>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,i1> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor {
|
||||
%none = torch.constant.none
|
||||
%int11 = torch.constant.int 11
|
||||
%false = torch.constant.bool false
|
||||
%ret = torch.aten.tensor.float %t, %int11, %none, %false : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.softmax.int(
|
||||
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
|
||||
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
|
||||
// CHECK: %[[DTYPE:.*]] = torch.constant.none
|
||||
// CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<*,f32>
|
||||
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,f32> to !torch.tensor
|
||||
// CHECK: return %[[RET]] : !torch.tensor
|
||||
func.func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor {
|
||||
%none = torch.constant.none
|
||||
%ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.softmax.int$specified_dtype(
|
||||
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
|
||||
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
|
||||
// CHECK: %[[DTYPE:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor<*,si64>
|
||||
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,si64> to !torch.tensor
|
||||
// CHECK: return %[[RET]] : !torch.tensor
|
||||
func.func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor {
|
||||
%int4 = torch.constant.int 4
|
||||
%ret = torch.aten.softmax.int %t, %dim, %int4: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.Matmul.Broadcast.Matrix(
|
||||
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<*,f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
|
||||
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
|
||||
return %0 : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.Matmul.Broadcast.Vector(
|
||||
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<*,f32>) -> !torch.tensor {
|
||||
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor<*,f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<*,f32>) -> !torch.tensor {
|
||||
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor
|
||||
return %0 : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.to.dtype(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor
|
||||
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype
|
||||
// CHECK-SAME: %[[ARG]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} :
|
||||
// CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
|
||||
// CHECK-SAME: -> !torch.tensor<*,si64>
|
||||
// CHECK-NEXT: %[[RES:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.tensor<*,si64> to !torch.tensor
|
||||
// CHECK-NEXT: return %[[RES]] : !torch.tensor
|
||||
func.func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
|
||||
%none = torch.constant.none
|
||||
%false = torch.constant.bool false
|
||||
%int4 = torch.constant.int 4
|
||||
%0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor
|
||||
return %0 : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar(
|
||||
// CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor {
|
||||
// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<*,si64>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<*,si64> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
|
||||
%0 = torch.prim.NumToTensor.Scalar %arg0: !torch.int -> !torch.tensor
|
||||
return %0: !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.tensor(
|
||||
// CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[NONE]], %[[NONE]], %[[FALSE]]
|
||||
// CHECK-SAME: : !torch.list<list<float>>, !torch.none, !torch.none, !torch.bool
|
||||
// CHECK-SAME: -> !torch.tensor<*,f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten.tensor(%t: !torch.list<list<float>>) -> !torch.tensor {
|
||||
%none = torch.constant.none
|
||||
%false = torch.constant.bool false
|
||||
%ret = torch.aten.tensor %t, %none, %none, %false : !torch.list<list<float>>, !torch.none, !torch.none, !torch.bool -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.tensor$specified_dtype(
|
||||
// CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[INT4:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[INT4]], %[[NONE]], %[[FALSE]] : !torch.list<list<float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,si64>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten.tensor$specified_dtype(%t: !torch.list<list<float>>) -> !torch.tensor {
|
||||
%none = torch.constant.none
|
||||
%int4 = torch.constant.int 4
|
||||
%false = torch.constant.bool false
|
||||
%ret = torch.aten.tensor %t, %int4, %none, %false : !torch.list<list<float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
|
@ -1,238 +0,0 @@
|
|||
// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s
|
||||
|
||||
// This file tests the structural logic of the pass. This is for testing logic
|
||||
// that does not scale with the number of ops supported, such as the core
|
||||
// propagation logic, rewriting, etc.
|
||||
// Code for testing transfer functions for new ops (which is most changes)
|
||||
// should go in refine-types-ops.mlir.
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||
// CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[COS]] : !torch.vtensor<*,f32> to !torch.vtensor
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor
|
||||
func.func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @keep_existing_shape_information(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
|
||||
// CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32>
|
||||
// CHECK: return %[[COS]] : !torch.vtensor<[2],f32>
|
||||
func.func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
|
||||
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor<[2], f32>
|
||||
return %1 : !torch.vtensor<[2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @propagate_through_multiple_ops(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||
// CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||
// CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||
// CHECK: %[[COS2:.*]] = torch.aten.cos %[[COS1]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||
// CHECK: %[[COS3:.*]] = torch.tensor_static_info_cast %[[COS2]] : !torch.vtensor<*,f32> to !torch.vtensor
|
||||
// CHECK: return %[[COS3]] : !torch.vtensor
|
||||
func.func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
|
||||
%2 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor
|
||||
%3 = torch.aten.cos %2 : !torch.vtensor -> !torch.vtensor
|
||||
return %3 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// Check rewriting logic in case of mixes of users that do/don't allow type
|
||||
// refinement.
|
||||
// CHECK-LABEL: func.func @mixed_allowing_not_allowing_type_refinement(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
|
||||
// CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[COS0]] : !torch.vtensor<*,f32> to !torch.vtensor
|
||||
// CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||
// CHECK: return %[[ERASED]], %[[ERASED]] : !torch.vtensor, !torch.vtensor
|
||||
func.func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
|
||||
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
|
||||
%3 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor
|
||||
return %1, %1 : !torch.vtensor, !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
|
||||
// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
|
||||
// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32> to !torch.vtensor<*,f32>
|
||||
// CHECK: %[[CAST2:.*]] = torch.tensor_static_info_cast %[[CAST:.*]] : !torch.vtensor<*,f32> to !torch.vtensor<*,f32>
|
||||
// CHECK: torch.overwrite.tensor.contents %[[CAST2]] overwrites %[[STATIC_COPY:.*]] : !torch.vtensor<*,f32>, !torch.tensor<*,f32>
|
||||
func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
|
||||
%static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
|
||||
%static_copy = torch.copy.to_tensor %static_no_type : !torch.tensor
|
||||
%dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
|
||||
torch.overwrite.tensor.contents %dynamic_no_type overwrites %static_copy : !torch.vtensor, !torch.tensor
|
||||
%static_value_copy = torch.copy.to_vtensor %static_copy : !torch.vtensor
|
||||
%result = torch.tensor_static_info_cast %static_value_copy : !torch.vtensor to !torch.vtensor<[2],f32>
|
||||
return %result : !torch.vtensor<[2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$static_overwrites_dynamic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2],f32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[2],f32> to !torch.vtensor<*,f32>
|
||||
// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor<*,f32>
|
||||
// CHECK: %[[MUTABLE_COPY:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor<*,f32>
|
||||
// CHECK: torch.overwrite.tensor.contents %[[ARG0_ERASED]] overwrites %[[MUTABLE_COPY]] : !torch.vtensor<*,f32>, !torch.tensor<*,f32>
|
||||
func.func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
%static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
|
||||
%dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
|
||||
%dynamic_copy = torch.copy.to_tensor %dynamic_no_type : !torch.tensor
|
||||
torch.overwrite.tensor.contents %static_no_type overwrites %dynamic_copy : !torch.vtensor, !torch.tensor
|
||||
%dynamic_value_copy = torch.copy.to_vtensor %dynamic_copy : !torch.vtensor
|
||||
%result = torch.tensor_static_info_cast %dynamic_value_copy : !torch.vtensor to !torch.vtensor<[?],f32>
|
||||
return %result : !torch.vtensor<[?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @bf16_result_type(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> {
|
||||
// CHECK: %[[SQRT:.*]] = torch.aten.sqrt %[[ARG0]] : !torch.vtensor<*,bf16> -> !torch.vtensor<[2],bf16>
|
||||
// CHECK: return %[[SQRT]] : !torch.vtensor<[2],bf16>
|
||||
func.func @bf16_result_type(%arg0: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> {
|
||||
%1 = torch.aten.sqrt %arg0 : !torch.vtensor<*,bf16> -> !torch.vtensor<[2], bf16>
|
||||
return %1 : !torch.vtensor<[2],bf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @propagate_scalar_type(
|
||||
// CHECK-SAME: %[[INT:.*]]: !torch.int) -> !torch.number {
|
||||
// CHECK: %[[NUM:.*]] = torch.derefine %[[INT]] : !torch.int to !torch.number
|
||||
// CHECK: %[[ABS:.*]] = torch.prim.abs.Scalar %[[INT]] : !torch.int -> !torch.int
|
||||
// CHECK: %[[RET:.*]] = torch.derefine %[[ABS]] : !torch.int to !torch.number
|
||||
// CHECK: return %[[RET]] : !torch.number
|
||||
func.func @propagate_scalar_type(%arg0: !torch.int) -> !torch.number {
|
||||
%num = torch.derefine %arg0 : !torch.int to !torch.number
|
||||
%1 = torch.prim.abs.Scalar %num: !torch.number -> !torch.number
|
||||
return %1 : !torch.number
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @prim.dtype(
|
||||
// CHECK-SAME: %[[arg:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor {
|
||||
|
||||
// CHECK: %[[zero:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[false:.*]] = torch.constant.bool false
|
||||
|
||||
// CHECK: %[[neg:.*]] = torch.aten.neg %[[arg]] : !torch.vtensor<*,bf16> -> !torch.vtensor<*,bf16>
|
||||
// CHECK: %[[dtype0:.*]] = torch.prim.dtype %[[neg]] : !torch.vtensor<*,bf16> -> !torch.int
|
||||
// CHECK: %[[device0:.*]] = torch.prim.device %[[neg]] : !torch.vtensor<*,bf16> -> !torch.Device
|
||||
// CHECK: %[[tensor:.*]] = torch.aten.tensor.int %[[zero]], %[[dtype0]], %[[device0]], %[[false]] : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,bf16>
|
||||
|
||||
// CHECK: %[[dtype1:.*]] = torch.prim.dtype %[[tensor]] : !torch.vtensor<*,bf16> -> !torch.int
|
||||
// CHECK: %[[device1:.*]] = torch.prim.device %[[tensor]] : !torch.vtensor<*,bf16> -> !torch.Device
|
||||
// CHECK: %[[result:.*]] = torch.aten.tensor.int %[[zero]], %[[dtype1]], %[[device1]], %[[false]] : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,bf16>
|
||||
|
||||
// CHECK: %[[cast:.*]] = torch.tensor_static_info_cast %[[result]] : !torch.vtensor<*,bf16> to !torch.vtensor
|
||||
// CHECK: return %[[cast]] : !torch.vtensor
|
||||
// CHECK: }
|
||||
|
||||
func.func @prim.dtype(%arg: !torch.vtensor<*,bf16>) -> !torch.vtensor<*,unk> {
|
||||
%zero = torch.constant.int 0
|
||||
%false = torch.constant.bool false
|
||||
|
||||
// Op that requires type refinement
|
||||
%neg = torch.aten.neg %arg : !torch.vtensor<*,bf16> -> !torch.vtensor<*,unk>
|
||||
|
||||
// Op whose processing requires type refinement on its source argument.
|
||||
%dtype = torch.prim.dtype %neg : !torch.vtensor<*,unk> -> !torch.int
|
||||
%device = torch.prim.device %neg : !torch.vtensor<*,unk> -> !torch.Device
|
||||
|
||||
// Another op that requires type refinement
|
||||
%result = torch.aten.tensor.int %zero, %dtype, %device, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk>
|
||||
|
||||
// Repeat the above three ops a second time to ensure that the type refinement
|
||||
// code works regardless of the number of alternating refinement+prim.dtype
|
||||
// sequences.
|
||||
%dtype2 = torch.prim.dtype %result : !torch.vtensor<*,unk> -> !torch.int
|
||||
%device2 = torch.prim.device %result : !torch.vtensor<*,unk> -> !torch.Device
|
||||
%result2 = torch.aten.tensor.int %zero, %dtype2, %device2, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk>
|
||||
|
||||
return %result2 : !torch.vtensor<*,unk>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that we don't crash on this input.
|
||||
|
||||
// CHECK-LABEL: func.func @forward
|
||||
func.func @forward() -> !torch.vtensor {
|
||||
%false = torch.constant.bool false
|
||||
%none = torch.constant.none
|
||||
%0 = torch.prim.ListConstruct : () -> !torch.list<tensor>
|
||||
// CHECK: torch.aten.tensor
|
||||
%1 = torch.aten.tensor %0, %none, %none, %false : !torch.list<tensor>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that we don't crash on this input.
|
||||
// TODO: This appears to result in aten.mul.Tensor not being visited.
|
||||
// We should investigate why that happens.
|
||||
|
||||
// CHECK-LABEL: func.func @forward
|
||||
func.func @forward(%arg0: !torch.bool, %arg1: !torch.tensor) {
|
||||
%0 = torch.prim.If %arg0 -> (!torch.tensor) {
|
||||
torch.prim.If.yield %arg1 : !torch.tensor
|
||||
} else {
|
||||
torch.prim.If.yield %arg1 : !torch.tensor
|
||||
}
|
||||
%1 = torch.copy.to_vtensor %0 : !torch.vtensor
|
||||
%2 = torch.aten.mul.Tensor %1, %1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.zeros_like(
|
||||
// CHECK-SAME: %[[arg:.*]]: !torch.vtensor) {
|
||||
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[CPU:.*]] = torch.constant.device "cpu"
|
||||
// CHECK: %[[ZEROS:.*]] = torch.aten.zeros_like %[[arg]], %[[INT6]], %[[INT0]], %[[CPU]], %[[FALSE]], %[[INT1]] : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor<*,f32>
|
||||
// CHECK: return
|
||||
func.func @torch.aten.zeros_like(%arg: !torch.vtensor) {
|
||||
%int6 = torch.constant.int 6
|
||||
%false = torch.constant.bool false
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%cpu = torch.constant.device "cpu"
|
||||
%2 = torch.aten.zeros_like %arg, %int6, %int0, %cpu, %false, %int1 : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// The data-flow analysis does not always propagate information to the entire graph.
|
||||
// This results in some lattice elements being uninitialized, which must be properly
|
||||
// handled when using the lattice elements to rewrite the graph.
|
||||
// In this particular case, the presence of the loop causes `torch.copy.to_vtensor`
|
||||
// to end up with an uninitialized lattice element. This is the simplest graph I was
|
||||
// able to come up with that reproduces such behavior.
|
||||
|
||||
// CHECK-LABEL: func.func @uninitialized_lattice_elements(
|
||||
// CHECK: %{{.*}} = torch.copy.to_vtensor %{{.*}} : !torch.vtensor<*,f32>
|
||||
|
||||
func.func @uninitialized_lattice_elements(%arg0: !torch.vtensor<*,f32>, %arg3: !torch.tensor) -> !torch.vtensor<*,f32> {
|
||||
%true = torch.constant.bool true
|
||||
%1 = torch.constant.int 0
|
||||
%2 = torch.prim.Loop %1, %true, init(%arg3) {
|
||||
^bb0(%arg1: !torch.int, %arg2: !torch.tensor):
|
||||
torch.prim.Loop.condition %true, iter(%arg2 : !torch.tensor)
|
||||
} : (!torch.int, !torch.bool, !torch.tensor) -> !torch.tensor
|
||||
%3 = torch.tensor_static_info_cast %2 : !torch.tensor to !torch.tensor<*,f32>
|
||||
%4 = torch.copy.to_vtensor %3 : !torch.vtensor<*,f32>
|
||||
return %4 : !torch.vtensor<*,f32>
|
||||
}
|
Loading…
Reference in New Issue