mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR]Add lowering for control flow operations.
1. This commit adds lowering of "while-like" prim loop to scf.while operation. 2. Adds lowering of "for-like" prim loops to scf.for operation. Signed-Off-By: Prateek Gupta <prateek@nod-labs.com>pull/816/head snapshot-20220429.421
parent
ef546e1137
commit
e1db318a3c
|
@ -10,10 +10,13 @@
|
|||
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||
|
||||
|
@ -61,18 +64,272 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
|
||||
// Converts the Torch::PrimLoopOp which is ``While-like`` into scf::WhileOp.
|
||||
class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern<PrimLoopOp> {
|
||||
public:
|
||||
using OpConversionPattern<PrimLoopOp>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(PrimLoopOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Return failure on for-like loops.
|
||||
if (op.isForLike())
|
||||
return failure();
|
||||
|
||||
TypeConverter *typeConverter = getTypeConverter();
|
||||
SmallVector<Type, 1> newResultTypes;
|
||||
if (failed(
|
||||
typeConverter->convertTypes(op.getResultTypes(), newResultTypes)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "could not convert PrimLoopOp outputs");
|
||||
|
||||
// Create scf.while operation using the operands of torch::primloop. The
|
||||
// first argument of the primloop correspond to `maxTripCount` which
|
||||
// can be omitted in the `scf.while` operation.
|
||||
Value condition = adaptor.initialCondition();
|
||||
ValueRange iterArgsInit = adaptor.iterArgsInit();
|
||||
SmallVector<Value> scfWhileOpOperands{condition};
|
||||
scfWhileOpOperands.append(iterArgsInit.begin(), iterArgsInit.end());
|
||||
auto scfWhileOp = rewriter.create<scf::WhileOp>(
|
||||
op->getLoc(), newResultTypes, scfWhileOpOperands);
|
||||
|
||||
// Populate the before region of the scf.while operation. The `before`
|
||||
// region will have only one block and the arguments of the block must match
|
||||
// the arguments of `scf.while` operation.
|
||||
SmallVector<Type> beforeRegionArgTypes;
|
||||
SmallVector<Location> beforeRegionArgLocs;
|
||||
for (Value value : scfWhileOp->getOperands()) {
|
||||
beforeRegionArgTypes.push_back(value.getType());
|
||||
beforeRegionArgLocs.push_back(value.getLoc());
|
||||
}
|
||||
auto *beforeBlock = rewriter.createBlock(
|
||||
&scfWhileOp.getBefore(), scfWhileOp.getBefore().begin(),
|
||||
beforeRegionArgTypes, beforeRegionArgLocs);
|
||||
|
||||
rewriter.setInsertionPointToEnd(beforeBlock);
|
||||
// Fetch the condition passed as the iter argument. Pass rest of the
|
||||
// arguments to the after block.
|
||||
auto scfConditionOp = rewriter.create<scf::ConditionOp>(
|
||||
op.getLoc(), beforeBlock->getArgument(0),
|
||||
beforeBlock->getArguments().drop_front());
|
||||
|
||||
// Populate the after region.
|
||||
if (!scfWhileOp.getAfter().empty())
|
||||
rewriter.eraseBlock(&scfWhileOp.getAfter().back());
|
||||
|
||||
SmallVector<Type> afterRegionArgTypes;
|
||||
SmallVector<Location> afterRegionArgLocs;
|
||||
for (Value value : scfConditionOp.getArgs()) {
|
||||
afterRegionArgTypes.push_back(value.getType());
|
||||
afterRegionArgLocs.push_back(value.getLoc());
|
||||
}
|
||||
auto *afterBlock = rewriter.createBlock(
|
||||
&scfWhileOp.getAfter(), scfWhileOp.getAfter().begin(),
|
||||
afterRegionArgTypes, afterRegionArgLocs);
|
||||
|
||||
// Rewrite uses of the torch loop block arguments to the new while-loop
|
||||
// "after" arguments. Leave the induction variable of prim loop(first
|
||||
// argument) because while like prim loops does not use the induction
|
||||
// variable.
|
||||
for (const auto &barg :
|
||||
enumerate(op.region().front().getArguments().drop_front())) {
|
||||
Value to = afterBlock->getArgument(barg.index());
|
||||
Type targetType = to.getType();
|
||||
Value torchArg = to;
|
||||
|
||||
// If the target type is non-torch type, then use TypeConverter to convert
|
||||
// the type of the source.
|
||||
if (targetType.isa<mlir::FloatType>()) {
|
||||
targetType = Torch::FloatType::get(op->getContext());
|
||||
torchArg = typeConverter->materializeSourceConversion(
|
||||
rewriter, scfWhileOp.getLoc(), targetType, {to});
|
||||
} else if (targetType.isa<mlir::IntegerType>()) {
|
||||
unsigned bitWidth = targetType.getIntOrFloatBitWidth();
|
||||
if (bitWidth == 1)
|
||||
targetType = Torch::BoolType::get(op->getContext());
|
||||
else
|
||||
targetType = Torch::IntType::get(op->getContext());
|
||||
torchArg = typeConverter->materializeSourceConversion(
|
||||
rewriter, scfWhileOp.getLoc(), targetType, {to});
|
||||
}
|
||||
if (!torchArg)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"unsupported type of the operand");
|
||||
barg.value().replaceAllUsesWith(torchArg);
|
||||
}
|
||||
// Inline torch loop body operations into 'after' region.
|
||||
PatternRewriter::InsertionGuard guard(rewriter);
|
||||
for (auto &operation :
|
||||
llvm::make_early_inc_range(op.region().front().getOperations())) {
|
||||
if (auto primLoopConditionOp = dyn_cast<PrimLoopConditionOp>(operation)) {
|
||||
// Fix up the terminator.
|
||||
SmallVector<Value> loopConditionIterArgs;
|
||||
Value torchShouldContinue = primLoopConditionOp.shouldContinue();
|
||||
Value shouldContinue = typeConverter->materializeTargetConversion(
|
||||
rewriter, scfWhileOp->getLoc(),
|
||||
typeConverter->convertType(torchShouldContinue.getType()),
|
||||
{torchShouldContinue});
|
||||
if (!shouldContinue)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"unsupported type of the operand");
|
||||
loopConditionIterArgs.push_back(shouldContinue);
|
||||
for (auto torchArg : primLoopConditionOp.iterArgs()) {
|
||||
Type torchType = torchArg.getType();
|
||||
|
||||
// If the argument is a torch tensor, directly add it in the list of
|
||||
// iter args.
|
||||
if (torchType.isa<Torch::BaseTensorType>()) {
|
||||
loopConditionIterArgs.push_back(torchArg);
|
||||
continue;
|
||||
}
|
||||
Value arg = typeConverter->materializeTargetConversion(
|
||||
rewriter, scfWhileOp->getLoc(),
|
||||
typeConverter->convertType(torchArg.getType()), {torchArg});
|
||||
if (!arg)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unsupported type of the operand");
|
||||
loopConditionIterArgs.push_back(arg);
|
||||
}
|
||||
rewriter.create<scf::YieldOp>(scfWhileOp.getLoc(),
|
||||
loopConditionIterArgs);
|
||||
|
||||
} else {
|
||||
operation.moveBefore(afterBlock, afterBlock->end());
|
||||
}
|
||||
}
|
||||
rewriter.replaceOp(op, scfWhileOp->getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Converts the Torch::PrimLoopOp which is ``For-like`` into scf::ForOp.
|
||||
class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern<PrimLoopOp> {
|
||||
public:
|
||||
using OpConversionPattern<PrimLoopOp>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(PrimLoopOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
// Return failure on while-like loops.
|
||||
if (!op.isForLike())
|
||||
return failure();
|
||||
|
||||
TypeConverter *typeConverter = getTypeConverter();
|
||||
SmallVector<Type, 1> newResultTypes;
|
||||
if (failed(
|
||||
typeConverter->convertTypes(op.getResultTypes(), newResultTypes)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "could not convert PrimLoopOp outputs");
|
||||
|
||||
// Calculate the lower bound, upper bound and step indices. Currently only
|
||||
// lower-bound = 0 and step = 1 is supported.
|
||||
Location loc = op.getLoc();
|
||||
Value lowerBoundIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
Value stepIndex = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
Value upperBoundIndex = rewriter.create<arith::IndexCastOp>(
|
||||
loc, rewriter.getIndexType(), adaptor.maxTripCount());
|
||||
auto scfForOp =
|
||||
rewriter.create<scf::ForOp>(loc, lowerBoundIndex, upperBoundIndex,
|
||||
stepIndex, adaptor.iterArgsInit());
|
||||
|
||||
SmallVector<Type> regionArgTypes;
|
||||
SmallVector<Location> regionArgLocs;
|
||||
for (Value value : scfForOp.getLoopBody().front().getArguments()) {
|
||||
regionArgTypes.push_back(value.getType());
|
||||
regionArgLocs.push_back(value.getLoc());
|
||||
}
|
||||
|
||||
// Populate the loop body region.
|
||||
if (!scfForOp.getLoopBody().empty())
|
||||
rewriter.eraseBlock(&scfForOp.getLoopBody().back());
|
||||
|
||||
auto *block = rewriter.createBlock(&scfForOp.getLoopBody(),
|
||||
scfForOp.getLoopBody().begin(),
|
||||
regionArgTypes, regionArgLocs);
|
||||
|
||||
// Rewrite uses of the torch loop block arguments to the new for-loop
|
||||
// "block" arguments
|
||||
for (const auto &barg : enumerate(op.region().front().getArguments())) {
|
||||
Value to = block->getArgument(barg.index());
|
||||
if (to.getType().isa<mlir::IndexType>())
|
||||
to =
|
||||
rewriter.create<arith::IndexCastOp>(loc, rewriter.getI64Type(), to);
|
||||
Type targetType = to.getType();
|
||||
Value torchArg = to;
|
||||
|
||||
// If the target type is non-torch type, then use TypeConverter to convert
|
||||
// the type of the source.
|
||||
if (targetType.isa<mlir::FloatType>()) {
|
||||
targetType = Torch::FloatType::get(op->getContext());
|
||||
torchArg = typeConverter->materializeSourceConversion(
|
||||
rewriter, scfForOp.getLoc(), targetType, {to});
|
||||
} else if (targetType.isa<mlir::IntegerType>()) {
|
||||
unsigned bitWidth = targetType.getIntOrFloatBitWidth();
|
||||
if (bitWidth == 1)
|
||||
targetType = Torch::BoolType::get(op->getContext());
|
||||
else
|
||||
targetType = Torch::IntType::get(op->getContext());
|
||||
torchArg = typeConverter->materializeSourceConversion(
|
||||
rewriter, scfForOp.getLoc(), targetType, {to});
|
||||
}
|
||||
if (!torchArg)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"unsupported type of the operand");
|
||||
barg.value().replaceAllUsesWith(torchArg);
|
||||
}
|
||||
|
||||
// Inline torch loop body operations into 'after' region.
|
||||
PatternRewriter::InsertionGuard guard(rewriter);
|
||||
for (auto &operation :
|
||||
llvm::make_early_inc_range(op.region().front().getOperations())) {
|
||||
if (auto primLoopConditionOp = dyn_cast<PrimLoopConditionOp>(operation)) {
|
||||
// Fix up the terminator.
|
||||
SmallVector<Value> loopConditionIterArgs;
|
||||
for (auto torchArg : primLoopConditionOp.iterArgs()) {
|
||||
Type torchType = torchArg.getType();
|
||||
|
||||
// If the argument is a torch tensor, directly add it in the list of
|
||||
// iter args.
|
||||
if (torchType.isa<Torch::BaseTensorType>()) {
|
||||
loopConditionIterArgs.push_back(torchArg);
|
||||
continue;
|
||||
}
|
||||
Value arg = typeConverter->materializeTargetConversion(
|
||||
rewriter, scfForOp.getLoc(),
|
||||
typeConverter->convertType(torchArg.getType()), {torchArg});
|
||||
if (!arg)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unsupported type of the operand");
|
||||
loopConditionIterArgs.push_back(arg);
|
||||
}
|
||||
rewriter.create<scf::YieldOp>(scfForOp.getLoc(), loopConditionIterArgs);
|
||||
} else {
|
||||
operation.moveBefore(block, block->end());
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, scfForOp->getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertTorchToSCF : public ConvertTorchToSCFBase<ConvertTorchToSCF> {
|
||||
public:
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<scf::SCFDialect>();
|
||||
registry.insert<scf::SCFDialect, arith::ArithmeticDialect>();
|
||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<Torch::TorchDialect, scf::SCFDialect>();
|
||||
target.addLegalDialect<Torch::TorchDialect, scf::SCFDialect,
|
||||
arith::ArithmeticDialect>();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
|
@ -83,6 +340,9 @@ public:
|
|||
patterns.add<ConvertTorchPrimIfOp>(typeConverter, context);
|
||||
target.addIllegalOp<PrimIfYieldOp>();
|
||||
patterns.add<ConvertTorchPrimIfYieldOp>(typeConverter, context);
|
||||
target.addIllegalOp<PrimLoopOp>();
|
||||
patterns.add<ConvertTorchPrimLoopWhileLikeOp>(typeConverter, context);
|
||||
patterns.add<ConvertTorchPrimLoopForLikeOp>(typeConverter, context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
|
|
|
@ -61,8 +61,8 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
|||
// and those constants get somewhat obscured by TorchToStd.
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToTMTensorPass());
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToStdPass());
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToStdPass());
|
||||
pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
|
||||
|
||||
if (options.optimize) {
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
@ -79,6 +80,7 @@ class VerifyLinalgOnTensorsBackendContractPass
|
|||
target.addDynamicallyLegalDialect<AffineDialect>(opHasLegalTypes);
|
||||
target.addDynamicallyLegalDialect<cf::ControlFlowDialect>(opHasLegalTypes);
|
||||
target.addDynamicallyLegalDialect<TMTensorDialect>(opHasLegalTypes);
|
||||
target.addDynamicallyLegalDialect<scf::SCFDialect>(opHasLegalTypes);
|
||||
|
||||
// ConstantOp is used for tensors and for scalars.
|
||||
target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);
|
||||
|
|
|
@ -52,3 +52,4 @@ def register_all_tests():
|
|||
from . import index_put
|
||||
from . import pooling
|
||||
from . import return_types
|
||||
from . import control_flow
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
from numpy import int64
|
||||
import torch
|
||||
import random
|
||||
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class TorchPrimLoopForLikeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
x_val = x.size(0)
|
||||
sum = 0
|
||||
for i in range(x_val):
|
||||
sum += i
|
||||
return sum
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TorchPrimLoopForLikeModule())
|
||||
def TorchPrimLoopForLikeModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(0, 10, (6, 8)))
|
||||
|
||||
# ==============================================================================
|
||||
class TorchPrimLoopWhileLikeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
x_val = x.size(0)
|
||||
sum = 0
|
||||
while(x_val > sum):
|
||||
sum += 1
|
||||
return sum
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TorchPrimLoopWhileLikeModule())
|
||||
def TorchPrimLoopWhileLikeModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(0, 10, (6, 8)))
|
|
@ -64,3 +64,153 @@ func @aten.prim.if$nested(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.int
|
|||
}
|
||||
return %0 : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.prim.loop$while
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: !torch.int) -> !torch.float {
|
||||
// CHECK: %[[TORCH_FLOAT_VAL:.*]] = torch.constant.float
|
||||
// CHECK-NEXT: %[[FLOAT_VAL:.*]] = torch_c.to_f64 %[[TORCH_FLOAT_VAL]]
|
||||
// CHECK-NEXT: %[[MAX_TRIP_COUNT:.*]] = torch.constant.int 9223372036854775807
|
||||
// CHECK-NEXT: %[[TORCH_CONDITION:.*]] = torch.aten.lt.float_int %[[TORCH_FLOAT_VAL]], %[[ARG0]]
|
||||
// CHECK-NEXT: %[[CONDITION:.*]] = torch_c.to_i1 %[[TORCH_CONDITION]]
|
||||
// CHECK-NEXT: %[[LOOP:.*]] = scf.while
|
||||
// CHECK-SAME: (%[[LOOP_CONDITION:.*]] = %[[CONDITION]], %[[LOOP_ARG:.*]] = %[[FLOAT_VAL]]) : (i1, f64) -> f64 {
|
||||
// CHECK-NEXT: scf.condition(%[[LOOP_CONDITION]]) %[[LOOP_ARG]]
|
||||
// CHECK-NEXT: } do {
|
||||
// CHECK-NEXT: ^bb0(%[[BLOCK_ARG:.*]]: f64):
|
||||
// CHECK-NEXT: %[[TORCH_BLOCK_ARG:.*]] = torch_c.from_f64 %[[BLOCK_ARG]]
|
||||
// CHECK-NEXT: %[[TORCH_VAL:.*]] = torch.aten.mul.float %[[TORCH_BLOCK_ARG]], %[[TORCH_BLOCK_ARG]]
|
||||
// CHECK-NEXT: %[[TORCH_BLOCK_CONDITION:.*]] = torch.aten.lt.float_int %[[TORCH_VAL]], %[[ARG0]]
|
||||
// CHECK-NEXT: %[[BLOCK_CONDITION:.*]] = torch_c.to_i1 %[[TORCH_BLOCK_CONDITION]]
|
||||
// CHECK-NEXT: %[[VAL:.*]] = torch_c.to_f64 %[[TORCH_VAL]]
|
||||
// CHECK-NEXT: scf.yield %[[BLOCK_CONDITION]], %[[VAL]] : i1, f64
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %[[TORCH_LOOP:.*]] = torch_c.from_f64 %[[LOOP]]
|
||||
// CHECK-NEXT: return %[[TORCH_LOOP]] : !torch.float
|
||||
func @torch.prim.loop$while(%arg0: !torch.int) -> !torch.float {
|
||||
%float3.200000e00 = torch.constant.float 3.200000e+00
|
||||
%int9223372036854775807 = torch.constant.int 9223372036854775807
|
||||
%0 = torch.aten.lt.float_int %float3.200000e00, %arg0 : !torch.float, !torch.int -> !torch.bool
|
||||
%1 = torch.prim.Loop %int9223372036854775807, %0, init(%float3.200000e00) {
|
||||
^bb0(%arg1: !torch.int, %arg2: !torch.float):
|
||||
%2 = torch.aten.mul.float %arg2, %arg2 : !torch.float, !torch.float -> !torch.float
|
||||
%3 = torch.aten.lt.float_int %2, %arg0 : !torch.float, !torch.int -> !torch.bool
|
||||
torch.prim.Loop.condition %3, iter(%2 : !torch.float)
|
||||
} : (!torch.int, !torch.bool, !torch.float) -> !torch.float
|
||||
return %1 : !torch.float
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.prim.loop$while_with_multiple_values
|
||||
// CHECK-SAME: () -> (!torch.float, !torch.float) {
|
||||
// CHECK: %[[TORCH_FLOAT_VAL_0:.*]] = torch.constant.float
|
||||
// CHECK-NEXT: %[[FLOAT_VAL_0:.*]] = torch_c.to_f64 %[[TORCH_FLOAT_VAL_0]]
|
||||
// CHECK-NEXT: %[[MAX_TRIP_COUNT:.*]] = torch.constant.int 9223372036854775807
|
||||
// CHECK-NEXT: %[[TORCH_FLOAT_VAL_1:.*]] = torch.constant.float
|
||||
// CHECK-NEXT: %[[FLOAT_VAL_1:.*]] = torch_c.to_f64 %[[TORCH_FLOAT_VAL_1]]
|
||||
// CHECK-NEXT: %[[TORCH_CONDITION:.*]] = torch.aten.lt.float %[[TORCH_FLOAT_VAL_0]], %[[TORCH_FLOAT_VAL_1]]
|
||||
// CHECK-NEXT: %[[CONDITION:.*]] = torch_c.to_i1 %[[TORCH_CONDITION]]
|
||||
// CHECK-NEXT: %[[LOOP:.*]]:2 = scf.while
|
||||
// CHECK-SAME: (%[[LOOP_CONDITION:.*]] = %[[CONDITION]], %[[LOOP_ARG_0:.*]] = %[[FLOAT_VAL_0]], %[[LOOP_ARG_1:.*]] = %[[FLOAT_VAL_1]]) : (i1, f64, f64) -> (f64, f64) {
|
||||
// CHECK-NEXT: scf.condition(%[[LOOP_CONDITION]]) %[[LOOP_ARG_0]], %[[LOOP_ARG_1]]
|
||||
// CHECK-NEXT: } do {
|
||||
// CHECK-NEXT: ^bb0(%[[BLOCK_ARG_0:.*]]: f64, %[[BLOCK_ARG_1:.*]]: f64):
|
||||
// CHECK-NEXT: %[[TORCH_BLOCK_ARG_0:.*]] = torch_c.from_f64 %[[BLOCK_ARG_0]]
|
||||
// CHECK-NEXT: %[[TORCH_BLOCK_ARG_1:.*]] = torch_c.from_f64 %[[BLOCK_ARG_1]]
|
||||
// CHECK-NEXT: %[[TORCH_VAL_0:.*]] = torch.aten.mul.float %[[TORCH_BLOCK_ARG_0]], %[[TORCH_BLOCK_ARG_0]]
|
||||
// CHECK-NEXT: %[[TORCH_BLOCK_CONDITION:.*]] = torch.aten.lt.float %[[TORCH_VAL_0]], %[[TORCH_BLOCK_ARG_1]]
|
||||
// CHECK-NEXT: %[[CONSTANT:.*]] = torch.constant.int -2
|
||||
// CHECK-NEXT: %[[TORCH_VAL_1:.*]] = torch.aten.add.float_int %[[TORCH_BLOCK_ARG_1]], %[[CONSTANT]]
|
||||
// CHECK-NEXT: %[[BLOCK_CONDITION:.*]] = torch_c.to_i1 %[[TORCH_BLOCK_CONDITION]]
|
||||
// CHECK-NEXT: %[[VAL_0:.*]] = torch_c.to_f64 %[[TORCH_VAL_0]]
|
||||
// CHECK-NEXT: %[[VAL_1:.*]] = torch_c.to_f64 %[[TORCH_VAL_1]]
|
||||
// CHECK-NEXT: scf.yield %[[BLOCK_CONDITION]], %[[VAL_0]], %[[VAL_1]] : i1, f64, f64
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %[[TORCH_LOOP_0:.*]] = torch_c.from_f64 %[[LOOP]]#0
|
||||
// CHECK-NEXT: %[[TORCH_LOOP_1:.*]] = torch_c.from_f64 %[[LOOP]]#1
|
||||
// CHECK-NEXT: return %[[TORCH_LOOP_0]], %[[TORCH_LOOP_1]] : !torch.float, !torch.float
|
||||
func @torch.prim.loop$while_with_multiple_values() -> (!torch.float, !torch.float) {
|
||||
%float3.200000e00 = torch.constant.float 3.200000e+00
|
||||
%int9223372036854775807 = torch.constant.int 9223372036854775807
|
||||
%float9.0 = torch.constant.float 9.0
|
||||
%0 = torch.aten.lt.float %float3.200000e00, %float9.0 : !torch.float, !torch.float -> !torch.bool
|
||||
%1:2 = torch.prim.Loop %int9223372036854775807, %0, init(%float3.200000e00, %float9.0) {
|
||||
^bb0(%arg1: !torch.int, %arg2: !torch.float, %arg3: !torch.float):
|
||||
%2 = torch.aten.mul.float %arg2, %arg2 : !torch.float, !torch.float -> !torch.float
|
||||
%3 = torch.aten.lt.float %2, %arg3 : !torch.float, !torch.float -> !torch.bool
|
||||
%4 = torch.constant.int -2
|
||||
%5 = torch.aten.add.float_int %arg3, %4 : !torch.float, !torch.int -> !torch.float
|
||||
torch.prim.Loop.condition %3, iter(%2, %5 : !torch.float, !torch.float)
|
||||
} : (!torch.int, !torch.bool, !torch.float, !torch.float) -> (!torch.float, !torch.float)
|
||||
return %1#0, %1#1 : !torch.float, !torch.float
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.prim.Loop$for
|
||||
// CHECK-SAME: (%[[TORCH_ARG0:.*]]: !torch.int) -> !torch.float {
|
||||
// CHECK: %[[ARG0:.*]] = torch_c.to_i64 %[[TORCH_ARG0]]
|
||||
// CHECK-NEXT: %{{.*}} = torch.constant.bool true
|
||||
// CHECK-NEXT: %[[TORCH_FLOAT:.*]] = torch.constant.float 0.000000e+00
|
||||
// CHECK-NEXT: %[[FLOAT:.*]] = torch_c.to_f64 %[[TORCH_FLOAT]]
|
||||
// CHECK-NEXT: %[[LOWER_BOUND:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[STEP:.*]] = arith.constant 1 : index
|
||||
// CHECK-NEXT: %[[UPPER_BOUND:.*]] = arith.index_cast %[[ARG0]] : i64 to index
|
||||
// CHECK-NEXT: %[[LOOP:.*]] = scf.for %[[IV:.*]] = %[[LOWER_BOUND]] to %[[UPPER_BOUND]] step %[[STEP]]
|
||||
// CHECK-SAME: iter_args(%[[ITER_ARG:.*]] = %[[FLOAT]]) -> (f64) {
|
||||
// CHECK-NEXT: %[[IV_I64:.*]] = arith.index_cast %[[IV]] : index to i64
|
||||
// CHECK-NEXT: %[[TORCH_IV:.*]] = torch_c.from_i64 %[[IV_I64]]
|
||||
// CHECK-NEXT: %[[TORCH_ITER_ARG:.*]] = torch_c.from_f64 %[[ITER_ARG]]
|
||||
// CHECK-NEXT: %[[TORCH_VAL:.*]] = torch.aten.add.float_int %[[TORCH_ITER_ARG]], %[[TORCH_IV]]
|
||||
// CHECK-NEXT: %[[VAL:.*]] = torch_c.to_f64 %[[TORCH_VAL]]
|
||||
// CHECK-NEXT: scf.yield %[[VAL]] : f64
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %[[RETURN:.*]] = torch_c.from_f64 %[[LOOP]]
|
||||
// CHECK-NEXT: return %[[RETURN]] : !torch.float
|
||||
// CHECK-NEXT: }
|
||||
func @torch.prim.Loop$for(%arg0: !torch.int) -> !torch.float {
|
||||
%true = torch.constant.bool true
|
||||
%float0.000000e00 = torch.constant.float 0.000000e+00
|
||||
%0 = torch.prim.Loop %arg0, %true, init(%float0.000000e00) {
|
||||
^bb0(%arg1: !torch.int, %arg2: !torch.float):
|
||||
%1 = torch.aten.add.float_int %arg2, %arg1 : !torch.float, !torch.int -> !torch.float
|
||||
torch.prim.Loop.condition %true, iter(%1 : !torch.float)
|
||||
} : (!torch.int, !torch.bool, !torch.float) -> !torch.float
|
||||
return %0 : !torch.float
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.prim.Loop$for_with_multiple_results
|
||||
// CHECK-SAME: (%[[TORCH_ARG0:.*]]: !torch.int) -> (!torch.float, !torch.float) {
|
||||
// CHECK: %[[ARG0:.*]] = torch_c.to_i64 %[[TORCH_ARG0]]
|
||||
// CHECK-NEXT: %{{.*}} = torch.constant.bool true
|
||||
// CHECK-NEXT: %[[TORCH_FLOAT_0:.*]] = torch.constant.float 0.000000e+00
|
||||
// CHECK-NEXT: %[[FLOAT_0:.*]] = torch_c.to_f64 %[[TORCH_FLOAT_0]]
|
||||
// CHECK-NEXT: %[[TORCH_FLOAT_1:.*]] = torch.constant.float 9.000000e+00
|
||||
// CHECK-NEXT: %[[FLOAT_1:.*]] = torch_c.to_f64 %[[TORCH_FLOAT_1]]
|
||||
// CHECK-NEXT: %[[LOWER_BOUND:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[STEP:.*]] = arith.constant 1 : index
|
||||
// CHECK-NEXT: %[[UPPER_BOUND:.*]] = arith.index_cast %[[ARG0]] : i64 to index
|
||||
// CHECK-NEXT: %[[LOOP:.*]]:2 = scf.for %[[IV:.*]] = %[[LOWER_BOUND]] to %[[UPPER_BOUND]] step %[[STEP]]
|
||||
// CHECK-SAME: iter_args(%[[ITER_ARG_0:.*]] = %[[FLOAT_0]], %[[ITER_ARG_1:.*]] = %[[FLOAT_1]]) -> (f64, f64) {
|
||||
// CHECK-NEXT: %[[IV_I64:.*]] = arith.index_cast %[[IV]] : index to i64
|
||||
// CHECK-NEXT: %[[TORCH_IV:.*]] = torch_c.from_i64 %[[IV_I64]]
|
||||
// CHECK-NEXT: %[[TORCH_ITER_ARG_0:.*]] = torch_c.from_f64 %[[ITER_ARG_0]]
|
||||
// CHECK-NEXT: %[[TORCH_ITER_ARG_1:.*]] = torch_c.from_f64 %[[ITER_ARG_1]]
|
||||
// CHECK-NEXT: %[[TORCH_VAL_0:.*]] = torch.aten.add.float_int %[[TORCH_ITER_ARG_0]], %[[TORCH_IV]]
|
||||
// CHECK-NEXT: %[[TORCH_VAL_1:.*]] = torch.aten.mul.float %[[TORCH_ITER_ARG_1]], %[[TORCH_VAL_0]]
|
||||
// CHECK-NEXT: %[[VAL_0:.*]] = torch_c.to_f64 %[[TORCH_VAL_0]]
|
||||
// CHECK-NEXT: %[[VAL_1:.*]] = torch_c.to_f64 %[[TORCH_VAL_1]]
|
||||
// CHECK-NEXT: scf.yield %[[VAL_0]], %[[VAL_1]] : f64, f64
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %[[RETURN_0:.*]] = torch_c.from_f64 %[[LOOP]]#0
|
||||
// CHECK-NEXT: %[[RETURN_1:.*]] = torch_c.from_f64 %[[LOOP]]#1
|
||||
// CHECK-NEXT: return %[[RETURN_0]], %[[RETURN_1]] : !torch.float, !torch.float
|
||||
// CHECK-NEXT: }
|
||||
func @torch.prim.Loop$for_with_multiple_results(%arg0: !torch.int) -> (!torch.float, !torch.float) {
|
||||
%true = torch.constant.bool true
|
||||
%float0.000000e00 = torch.constant.float 0.000000e+00
|
||||
%float9.0 = torch.constant.float 9.0
|
||||
%0:2 = torch.prim.Loop %arg0, %true, init(%float0.000000e00, %float9.0) {
|
||||
^bb0(%arg1: !torch.int, %arg2: !torch.float, %arg3: !torch.float):
|
||||
%1 = torch.aten.add.float_int %arg2, %arg1 : !torch.float, !torch.int -> !torch.float
|
||||
%2 = torch.aten.mul.float %arg3, %1 : !torch.float, !torch.float -> !torch.float
|
||||
torch.prim.Loop.condition %true, iter(%1, %2 : !torch.float, !torch.float)
|
||||
} : (!torch.int, !torch.bool, !torch.float, !torch.float) -> (!torch.float, !torch.float)
|
||||
return %0#0, %0#1 : !torch.float, !torch.float
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue