[TORCH][MLIR] Add lowering for control flow operations.

1. This commit adds lowering of "while-like" prim loops to the scf.while
operation.
2. It also adds lowering of "for-like" prim loops to the scf.for operation.

Signed-off-by: Prateek Gupta <prateek@nod-labs.com>
pull/816/head snapshot-20220429.421
Prateek Gupta 2022-03-31 14:24:44 +00:00
parent ef546e1137
commit e1db318a3c
6 changed files with 473 additions and 3 deletions
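For context, torch.prim.Loop models both loop kinds: the op is treated as for-like when its initial condition is the constant true (see PrimLoopOp::isForLike), and as while-like otherwise, in which case the frontend typically passes INT64_MAX as the trip count. A minimal sketch of the two input forms (value names illustrative), mirroring the tests added below:

  // For-like: the trip count drives the loop; lowered to scf.for.
  %0 = torch.prim.Loop %trip_count, %true, init(%init) {
  ^bb0(%iv: !torch.int, %iter: !torch.float):
    ...
    torch.prim.Loop.condition %true, iter(%next : !torch.float)
  } : (!torch.int, !torch.bool, !torch.float) -> !torch.float

  // While-like: a computed condition drives the loop; lowered to scf.while.
  %1 = torch.prim.Loop %int_max, %cond, init(%init) {
  ^bb0(%iv: !torch.int, %iter: !torch.float):
    ...
    torch.prim.Loop.condition %new_cond, iter(%next : !torch.float)
  } : (!torch.int, !torch.bool, !torch.float) -> !torch.float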


@@ -10,10 +10,13 @@
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
@@ -61,18 +64,272 @@ public:
};
} // namespace
namespace {
// Converts a "while-like" Torch::PrimLoopOp into an scf::WhileOp.
class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern<PrimLoopOp> {
public:
using OpConversionPattern<PrimLoopOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(PrimLoopOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Return failure on for-like loops.
if (op.isForLike())
return failure();
TypeConverter *typeConverter = getTypeConverter();
SmallVector<Type, 1> newResultTypes;
if (failed(
typeConverter->convertTypes(op.getResultTypes(), newResultTypes)))
return rewriter.notifyMatchFailure(
op, "could not convert PrimLoopOp outputs");
// Create the scf.while operation using the operands of torch.prim.Loop. The
// first operand of the prim loop corresponds to `maxTripCount`, which can be
// omitted in the `scf.while` operation (while-like loops are typically
// emitted with INT64_MAX there, as in the tests below).
Value condition = adaptor.initialCondition();
ValueRange iterArgsInit = adaptor.iterArgsInit();
SmallVector<Value> scfWhileOpOperands{condition};
scfWhileOpOperands.append(iterArgsInit.begin(), iterArgsInit.end());
auto scfWhileOp = rewriter.create<scf::WhileOp>(
op->getLoc(), newResultTypes, scfWhileOpOperands);
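    // (The condition is threaded through as the first operand so the `before`
    // block can test it via scf.condition; the scf.while results exclude it,
    // matching newResultTypes.)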
// Populate the before region of the scf.while operation. The `before`
// region will have only one block, and the arguments of that block must
// match the operands of the `scf.while` operation.
SmallVector<Type> beforeRegionArgTypes;
SmallVector<Location> beforeRegionArgLocs;
for (Value value : scfWhileOp->getOperands()) {
beforeRegionArgTypes.push_back(value.getType());
beforeRegionArgLocs.push_back(value.getLoc());
}
auto *beforeBlock = rewriter.createBlock(
&scfWhileOp.getBefore(), scfWhileOp.getBefore().begin(),
beforeRegionArgTypes, beforeRegionArgLocs);
rewriter.setInsertionPointToEnd(beforeBlock);
// Fetch the condition passed as the iter argument. Pass the rest of the
// arguments to the after block.
auto scfConditionOp = rewriter.create<scf::ConditionOp>(
op.getLoc(), beforeBlock->getArgument(0),
beforeBlock->getArguments().drop_front());
// Populate the after region.
if (!scfWhileOp.getAfter().empty())
rewriter.eraseBlock(&scfWhileOp.getAfter().back());
SmallVector<Type> afterRegionArgTypes;
SmallVector<Location> afterRegionArgLocs;
for (Value value : scfConditionOp.getArgs()) {
afterRegionArgTypes.push_back(value.getType());
afterRegionArgLocs.push_back(value.getLoc());
}
auto *afterBlock = rewriter.createBlock(
&scfWhileOp.getAfter(), scfWhileOp.getAfter().begin(),
afterRegionArgTypes, afterRegionArgLocs);
// Rewrite uses of the torch loop block arguments to the new while-loop
// "after" arguments. Skip the induction variable of the prim loop (the
// first argument) because while-like prim loops do not use it.
for (const auto &barg :
enumerate(op.region().front().getArguments().drop_front())) {
Value to = afterBlock->getArgument(barg.index());
Type targetType = to.getType();
Value torchArg = to;
// If the argument has a builtin (non-torch) type, use the type converter
// to materialize a conversion back to the corresponding torch type.
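    // (materializeSourceConversion invokes the source materializations that
    // the backend type conversion registers, i.e. it emits torch_c.from_f64 /
    // torch_c.from_i64 / torch_c.from_i1 casts from the builtin value back to
    // the torch type; see the torch_c.from_f64 ops in the tests below.)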
if (targetType.isa<mlir::FloatType>()) {
targetType = Torch::FloatType::get(op->getContext());
torchArg = typeConverter->materializeSourceConversion(
rewriter, scfWhileOp.getLoc(), targetType, {to});
} else if (targetType.isa<mlir::IntegerType>()) {
unsigned bitWidth = targetType.getIntOrFloatBitWidth();
if (bitWidth == 1)
targetType = Torch::BoolType::get(op->getContext());
else
targetType = Torch::IntType::get(op->getContext());
torchArg = typeConverter->materializeSourceConversion(
rewriter, scfWhileOp.getLoc(), targetType, {to});
}
if (!torchArg)
return rewriter.notifyMatchFailure(op,
"unsupported type of the operand");
barg.value().replaceAllUsesWith(torchArg);
}
// Inline torch loop body operations into 'after' region.
PatternRewriter::InsertionGuard guard(rewriter);
for (auto &operation :
llvm::make_early_inc_range(op.region().front().getOperations())) {
if (auto primLoopConditionOp = dyn_cast<PrimLoopConditionOp>(operation)) {
// Fix up the terminator.
SmallVector<Value> loopConditionIterArgs;
Value torchShouldContinue = primLoopConditionOp.shouldContinue();
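        // (materializeTargetConversion emits the torch_c.to_* casts, e.g.
        // torch_c.to_i1 here, so the operands of the scf.yield created below
        // match the builtin types of the scf.while results.)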
Value shouldContinue = typeConverter->materializeTargetConversion(
rewriter, scfWhileOp->getLoc(),
typeConverter->convertType(torchShouldContinue.getType()),
{torchShouldContinue});
if (!shouldContinue)
return rewriter.notifyMatchFailure(op,
"unsupported type of the operand");
loopConditionIterArgs.push_back(shouldContinue);
for (auto torchArg : primLoopConditionOp.iterArgs()) {
Type torchType = torchArg.getType();
// If the argument is a torch tensor, directly add it to the list of
// iter args.
if (torchType.isa<Torch::BaseTensorType>()) {
loopConditionIterArgs.push_back(torchArg);
continue;
}
Value arg = typeConverter->materializeTargetConversion(
rewriter, scfWhileOp->getLoc(),
typeConverter->convertType(torchArg.getType()), {torchArg});
if (!arg)
return rewriter.notifyMatchFailure(
op, "unsupported type of the operand");
loopConditionIterArgs.push_back(arg);
}
rewriter.create<scf::YieldOp>(scfWhileOp.getLoc(),
loopConditionIterArgs);
} else {
operation.moveBefore(afterBlock, afterBlock->end());
}
}
rewriter.replaceOp(op, scfWhileOp->getResults());
return success();
}
};
} // namespace
namespace {
// Converts a "for-like" Torch::PrimLoopOp into an scf::ForOp.
class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern<PrimLoopOp> {
public:
using OpConversionPattern<PrimLoopOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(PrimLoopOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Return failure on while-like loops.
if (!op.isForLike())
return failure();
TypeConverter *typeConverter = getTypeConverter();
SmallVector<Type, 1> newResultTypes;
if (failed(
typeConverter->convertTypes(op.getResultTypes(), newResultTypes)))
return rewriter.notifyMatchFailure(
op, "could not convert PrimLoopOp outputs");
// Calculate the lower-bound, upper-bound and step indices. Currently,
// only lower-bound = 0 and step = 1 are supported.
Location loc = op.getLoc();
Value lowerBoundIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value stepIndex = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value upperBoundIndex = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), adaptor.maxTripCount());
auto scfForOp =
rewriter.create<scf::ForOp>(loc, lowerBoundIndex, upperBoundIndex,
stepIndex, adaptor.iterArgsInit());
SmallVector<Type> regionArgTypes;
SmallVector<Location> regionArgLocs;
for (Value value : scfForOp.getLoopBody().front().getArguments()) {
regionArgTypes.push_back(value.getType());
regionArgLocs.push_back(value.getLoc());
}
// Populate the loop body region.
if (!scfForOp.getLoopBody().empty())
rewriter.eraseBlock(&scfForOp.getLoopBody().back());
auto *block = rewriter.createBlock(&scfForOp.getLoopBody(),
scfForOp.getLoopBody().begin(),
regionArgTypes, regionArgLocs);
// Rewrite uses of the torch loop block arguments to the new for-loop
// block arguments.
for (const auto &barg : enumerate(op.region().front().getArguments())) {
Value to = block->getArgument(barg.index());
if (to.getType().isa<mlir::IndexType>())
to =
rewriter.create<arith::IndexCastOp>(loc, rewriter.getI64Type(), to);
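      // (The scf.for induction variable has index type, while the torch
      // loop's induction variable is a !torch.int backed by i64, so cast
      // index -> i64 here; the i64 -> !torch.int materialization below
      // completes the round trip, matching the arith.index_cast +
      // torch_c.from_i64 pair in the tests below.)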
Type targetType = to.getType();
Value torchArg = to;
// If the argument has a builtin (non-torch) type, use the type converter
// to materialize a conversion back to the corresponding torch type.
if (targetType.isa<mlir::FloatType>()) {
targetType = Torch::FloatType::get(op->getContext());
torchArg = typeConverter->materializeSourceConversion(
rewriter, scfForOp.getLoc(), targetType, {to});
} else if (targetType.isa<mlir::IntegerType>()) {
unsigned bitWidth = targetType.getIntOrFloatBitWidth();
if (bitWidth == 1)
targetType = Torch::BoolType::get(op->getContext());
else
targetType = Torch::IntType::get(op->getContext());
torchArg = typeConverter->materializeSourceConversion(
rewriter, scfForOp.getLoc(), targetType, {to});
}
if (!torchArg)
return rewriter.notifyMatchFailure(op,
"unsupported type of the operand");
barg.value().replaceAllUsesWith(torchArg);
}
// Inline torch loop body operations into the scf.for loop body.
PatternRewriter::InsertionGuard guard(rewriter);
for (auto &operation :
llvm::make_early_inc_range(op.region().front().getOperations())) {
if (auto primLoopConditionOp = dyn_cast<PrimLoopConditionOp>(operation)) {
// Fix up the terminator.
SmallVector<Value> loopConditionIterArgs;
for (auto torchArg : primLoopConditionOp.iterArgs()) {
Type torchType = torchArg.getType();
// If the argument is a torch tensor, directly add it to the list of
// iter args.
if (torchType.isa<Torch::BaseTensorType>()) {
loopConditionIterArgs.push_back(torchArg);
continue;
}
Value arg = typeConverter->materializeTargetConversion(
rewriter, scfForOp.getLoc(),
typeConverter->convertType(torchArg.getType()), {torchArg});
if (!arg)
return rewriter.notifyMatchFailure(
op, "unsupported type of the operand");
loopConditionIterArgs.push_back(arg);
}
rewriter.create<scf::YieldOp>(scfForOp.getLoc(), loopConditionIterArgs);
} else {
operation.moveBefore(block, block->end());
}
}
rewriter.replaceOp(op, scfForOp->getResults());
return success();
}
};
} // namespace
namespace {
class ConvertTorchToSCF : public ConvertTorchToSCFBase<ConvertTorchToSCF> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<scf::SCFDialect>();
+ registry.insert<scf::SCFDialect, arith::ArithmeticDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
- target.addLegalDialect<Torch::TorchDialect, scf::SCFDialect>();
+ target.addLegalDialect<Torch::TorchDialect, scf::SCFDialect,
+                        arith::ArithmeticDialect>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
@@ -83,6 +340,9 @@ public:
patterns.add<ConvertTorchPrimIfOp>(typeConverter, context);
target.addIllegalOp<PrimIfYieldOp>();
patterns.add<ConvertTorchPrimIfYieldOp>(typeConverter, context);
target.addIllegalOp<PrimLoopOp>();
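// Note: both loop patterns match PrimLoopOp; each returns failure() for
// the other loop kind, so together they cover every PrimLoopOp.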
patterns.add<ConvertTorchPrimLoopWhileLikeOp>(typeConverter, context);
patterns.add<ConvertTorchPrimLoopForLikeOp>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))


@@ -61,8 +61,8 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
// and those constants get somewhat obscured by TorchToStd.
pm.addNestedPass<func::FuncOp>(createConvertTorchToTMTensorPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
- pm.addNestedPass<func::FuncOp>(createConvertTorchToStdPass());
  pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
+ pm.addNestedPass<func::FuncOp>(createConvertTorchToStdPass());
pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
if (options.optimize) {


@@ -15,6 +15,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -79,6 +80,7 @@ class VerifyLinalgOnTensorsBackendContractPass
target.addDynamicallyLegalDialect<AffineDialect>(opHasLegalTypes);
target.addDynamicallyLegalDialect<cf::ControlFlowDialect>(opHasLegalTypes);
target.addDynamicallyLegalDialect<TMTensorDialect>(opHasLegalTypes);
target.addDynamicallyLegalDialect<scf::SCFDialect>(opHasLegalTypes);
// ConstantOp is used for tensors and for scalars.
target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);


@@ -52,3 +52,4 @@ def register_all_tests():
from . import index_put
from . import pooling
from . import return_types
from . import control_flow


@@ -0,0 +1,57 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class TorchPrimLoopForLikeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True)
])
def forward(self, x):
x_val = x.size(0)
sum = 0
for i in range(x_val):
sum += i
return sum
@register_test_case(module_factory=lambda: TorchPrimLoopForLikeModule())
def TorchPrimLoopForLikeModule_basic(module, tu: TestUtils):
module.forward(torch.randint(0, 10, (6, 8)))
# ==============================================================================
class TorchPrimLoopWhileLikeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True)
])
def forward(self, x):
x_val = x.size(0)
sum = 0
while x_val > sum:
sum += 1
return sum
@register_test_case(module_factory=lambda: TorchPrimLoopWhileLikeModule())
def TorchPrimLoopWhileLikeModule_basic(module, tu: TestUtils):
module.forward(torch.randint(0, 10, (6, 8)))


@@ -64,3 +64,153 @@ func @aten.prim.if$nested(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.int
}
return %0 : !torch.int
}
// CHECK-LABEL: func @torch.prim.loop$while
// CHECK-SAME: (%[[ARG0:.*]]: !torch.int) -> !torch.float {
// CHECK: %[[TORCH_FLOAT_VAL:.*]] = torch.constant.float
// CHECK-NEXT: %[[FLOAT_VAL:.*]] = torch_c.to_f64 %[[TORCH_FLOAT_VAL]]
// CHECK-NEXT: %[[MAX_TRIP_COUNT:.*]] = torch.constant.int 9223372036854775807
// CHECK-NEXT: %[[TORCH_CONDITION:.*]] = torch.aten.lt.float_int %[[TORCH_FLOAT_VAL]], %[[ARG0]]
// CHECK-NEXT: %[[CONDITION:.*]] = torch_c.to_i1 %[[TORCH_CONDITION]]
// CHECK-NEXT: %[[LOOP:.*]] = scf.while
// CHECK-SAME: (%[[LOOP_CONDITION:.*]] = %[[CONDITION]], %[[LOOP_ARG:.*]] = %[[FLOAT_VAL]]) : (i1, f64) -> f64 {
// CHECK-NEXT: scf.condition(%[[LOOP_CONDITION]]) %[[LOOP_ARG]]
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%[[BLOCK_ARG:.*]]: f64):
// CHECK-NEXT: %[[TORCH_BLOCK_ARG:.*]] = torch_c.from_f64 %[[BLOCK_ARG]]
// CHECK-NEXT: %[[TORCH_VAL:.*]] = torch.aten.mul.float %[[TORCH_BLOCK_ARG]], %[[TORCH_BLOCK_ARG]]
// CHECK-NEXT: %[[TORCH_BLOCK_CONDITION:.*]] = torch.aten.lt.float_int %[[TORCH_VAL]], %[[ARG0]]
// CHECK-NEXT: %[[BLOCK_CONDITION:.*]] = torch_c.to_i1 %[[TORCH_BLOCK_CONDITION]]
// CHECK-NEXT: %[[VAL:.*]] = torch_c.to_f64 %[[TORCH_VAL]]
// CHECK-NEXT: scf.yield %[[BLOCK_CONDITION]], %[[VAL]] : i1, f64
// CHECK-NEXT: }
// CHECK-NEXT: %[[TORCH_LOOP:.*]] = torch_c.from_f64 %[[LOOP]]
// CHECK-NEXT: return %[[TORCH_LOOP]] : !torch.float
func @torch.prim.loop$while(%arg0: !torch.int) -> !torch.float {
%float3.200000e00 = torch.constant.float 3.200000e+00
%int9223372036854775807 = torch.constant.int 9223372036854775807
%0 = torch.aten.lt.float_int %float3.200000e00, %arg0 : !torch.float, !torch.int -> !torch.bool
%1 = torch.prim.Loop %int9223372036854775807, %0, init(%float3.200000e00) {
^bb0(%arg1: !torch.int, %arg2: !torch.float):
%2 = torch.aten.mul.float %arg2, %arg2 : !torch.float, !torch.float -> !torch.float
%3 = torch.aten.lt.float_int %2, %arg0 : !torch.float, !torch.int -> !torch.bool
torch.prim.Loop.condition %3, iter(%2 : !torch.float)
} : (!torch.int, !torch.bool, !torch.float) -> !torch.float
return %1 : !torch.float
}
// CHECK-LABEL: func @torch.prim.loop$while_with_multiple_values
// CHECK-SAME: () -> (!torch.float, !torch.float) {
// CHECK: %[[TORCH_FLOAT_VAL_0:.*]] = torch.constant.float
// CHECK-NEXT: %[[FLOAT_VAL_0:.*]] = torch_c.to_f64 %[[TORCH_FLOAT_VAL_0]]
// CHECK-NEXT: %[[MAX_TRIP_COUNT:.*]] = torch.constant.int 9223372036854775807
// CHECK-NEXT: %[[TORCH_FLOAT_VAL_1:.*]] = torch.constant.float
// CHECK-NEXT: %[[FLOAT_VAL_1:.*]] = torch_c.to_f64 %[[TORCH_FLOAT_VAL_1]]
// CHECK-NEXT: %[[TORCH_CONDITION:.*]] = torch.aten.lt.float %[[TORCH_FLOAT_VAL_0]], %[[TORCH_FLOAT_VAL_1]]
// CHECK-NEXT: %[[CONDITION:.*]] = torch_c.to_i1 %[[TORCH_CONDITION]]
// CHECK-NEXT: %[[LOOP:.*]]:2 = scf.while
// CHECK-SAME: (%[[LOOP_CONDITION:.*]] = %[[CONDITION]], %[[LOOP_ARG_0:.*]] = %[[FLOAT_VAL_0]], %[[LOOP_ARG_1:.*]] = %[[FLOAT_VAL_1]]) : (i1, f64, f64) -> (f64, f64) {
// CHECK-NEXT: scf.condition(%[[LOOP_CONDITION]]) %[[LOOP_ARG_0]], %[[LOOP_ARG_1]]
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%[[BLOCK_ARG_0:.*]]: f64, %[[BLOCK_ARG_1:.*]]: f64):
// CHECK-NEXT: %[[TORCH_BLOCK_ARG_0:.*]] = torch_c.from_f64 %[[BLOCK_ARG_0]]
// CHECK-NEXT: %[[TORCH_BLOCK_ARG_1:.*]] = torch_c.from_f64 %[[BLOCK_ARG_1]]
// CHECK-NEXT: %[[TORCH_VAL_0:.*]] = torch.aten.mul.float %[[TORCH_BLOCK_ARG_0]], %[[TORCH_BLOCK_ARG_0]]
// CHECK-NEXT: %[[TORCH_BLOCK_CONDITION:.*]] = torch.aten.lt.float %[[TORCH_VAL_0]], %[[TORCH_BLOCK_ARG_1]]
// CHECK-NEXT: %[[CONSTANT:.*]] = torch.constant.int -2
// CHECK-NEXT: %[[TORCH_VAL_1:.*]] = torch.aten.add.float_int %[[TORCH_BLOCK_ARG_1]], %[[CONSTANT]]
// CHECK-NEXT: %[[BLOCK_CONDITION:.*]] = torch_c.to_i1 %[[TORCH_BLOCK_CONDITION]]
// CHECK-NEXT: %[[VAL_0:.*]] = torch_c.to_f64 %[[TORCH_VAL_0]]
// CHECK-NEXT: %[[VAL_1:.*]] = torch_c.to_f64 %[[TORCH_VAL_1]]
// CHECK-NEXT: scf.yield %[[BLOCK_CONDITION]], %[[VAL_0]], %[[VAL_1]] : i1, f64, f64
// CHECK-NEXT: }
// CHECK-NEXT: %[[TORCH_LOOP_0:.*]] = torch_c.from_f64 %[[LOOP]]#0
// CHECK-NEXT: %[[TORCH_LOOP_1:.*]] = torch_c.from_f64 %[[LOOP]]#1
// CHECK-NEXT: return %[[TORCH_LOOP_0]], %[[TORCH_LOOP_1]] : !torch.float, !torch.float
func @torch.prim.loop$while_with_multiple_values() -> (!torch.float, !torch.float) {
%float3.200000e00 = torch.constant.float 3.200000e+00
%int9223372036854775807 = torch.constant.int 9223372036854775807
%float9.0 = torch.constant.float 9.0
%0 = torch.aten.lt.float %float3.200000e00, %float9.0 : !torch.float, !torch.float -> !torch.bool
%1:2 = torch.prim.Loop %int9223372036854775807, %0, init(%float3.200000e00, %float9.0) {
^bb0(%arg1: !torch.int, %arg2: !torch.float, %arg3: !torch.float):
%2 = torch.aten.mul.float %arg2, %arg2 : !torch.float, !torch.float -> !torch.float
%3 = torch.aten.lt.float %2, %arg3 : !torch.float, !torch.float -> !torch.bool
%4 = torch.constant.int -2
%5 = torch.aten.add.float_int %arg3, %4 : !torch.float, !torch.int -> !torch.float
torch.prim.Loop.condition %3, iter(%2, %5 : !torch.float, !torch.float)
} : (!torch.int, !torch.bool, !torch.float, !torch.float) -> (!torch.float, !torch.float)
return %1#0, %1#1 : !torch.float, !torch.float
}
// CHECK-LABEL: func @torch.prim.Loop$for
// CHECK-SAME: (%[[TORCH_ARG0:.*]]: !torch.int) -> !torch.float {
// CHECK: %[[ARG0:.*]] = torch_c.to_i64 %[[TORCH_ARG0]]
// CHECK-NEXT: %{{.*}} = torch.constant.bool true
// CHECK-NEXT: %[[TORCH_FLOAT:.*]] = torch.constant.float 0.000000e+00
// CHECK-NEXT: %[[FLOAT:.*]] = torch_c.to_f64 %[[TORCH_FLOAT]]
// CHECK-NEXT: %[[LOWER_BOUND:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[STEP:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[UPPER_BOUND:.*]] = arith.index_cast %[[ARG0]] : i64 to index
// CHECK-NEXT: %[[LOOP:.*]] = scf.for %[[IV:.*]] = %[[LOWER_BOUND]] to %[[UPPER_BOUND]] step %[[STEP]]
// CHECK-SAME: iter_args(%[[ITER_ARG:.*]] = %[[FLOAT]]) -> (f64) {
// CHECK-NEXT: %[[IV_I64:.*]] = arith.index_cast %[[IV]] : index to i64
// CHECK-NEXT: %[[TORCH_IV:.*]] = torch_c.from_i64 %[[IV_I64]]
// CHECK-NEXT: %[[TORCH_ITER_ARG:.*]] = torch_c.from_f64 %[[ITER_ARG]]
// CHECK-NEXT: %[[TORCH_VAL:.*]] = torch.aten.add.float_int %[[TORCH_ITER_ARG]], %[[TORCH_IV]]
// CHECK-NEXT: %[[VAL:.*]] = torch_c.to_f64 %[[TORCH_VAL]]
// CHECK-NEXT: scf.yield %[[VAL]] : f64
// CHECK-NEXT: }
// CHECK-NEXT: %[[RETURN:.*]] = torch_c.from_f64 %[[LOOP]]
// CHECK-NEXT: return %[[RETURN]] : !torch.float
// CHECK-NEXT: }
func @torch.prim.Loop$for(%arg0: !torch.int) -> !torch.float {
%true = torch.constant.bool true
%float0.000000e00 = torch.constant.float 0.000000e+00
%0 = torch.prim.Loop %arg0, %true, init(%float0.000000e00) {
^bb0(%arg1: !torch.int, %arg2: !torch.float):
%1 = torch.aten.add.float_int %arg2, %arg1 : !torch.float, !torch.int -> !torch.float
torch.prim.Loop.condition %true, iter(%1 : !torch.float)
} : (!torch.int, !torch.bool, !torch.float) -> !torch.float
return %0 : !torch.float
}
// CHECK-LABEL: func @torch.prim.Loop$for_with_multiple_results
// CHECK-SAME: (%[[TORCH_ARG0:.*]]: !torch.int) -> (!torch.float, !torch.float) {
// CHECK: %[[ARG0:.*]] = torch_c.to_i64 %[[TORCH_ARG0]]
// CHECK-NEXT: %{{.*}} = torch.constant.bool true
// CHECK-NEXT: %[[TORCH_FLOAT_0:.*]] = torch.constant.float 0.000000e+00
// CHECK-NEXT: %[[FLOAT_0:.*]] = torch_c.to_f64 %[[TORCH_FLOAT_0]]
// CHECK-NEXT: %[[TORCH_FLOAT_1:.*]] = torch.constant.float 9.000000e+00
// CHECK-NEXT: %[[FLOAT_1:.*]] = torch_c.to_f64 %[[TORCH_FLOAT_1]]
// CHECK-NEXT: %[[LOWER_BOUND:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[STEP:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[UPPER_BOUND:.*]] = arith.index_cast %[[ARG0]] : i64 to index
// CHECK-NEXT: %[[LOOP:.*]]:2 = scf.for %[[IV:.*]] = %[[LOWER_BOUND]] to %[[UPPER_BOUND]] step %[[STEP]]
// CHECK-SAME: iter_args(%[[ITER_ARG_0:.*]] = %[[FLOAT_0]], %[[ITER_ARG_1:.*]] = %[[FLOAT_1]]) -> (f64, f64) {
// CHECK-NEXT: %[[IV_I64:.*]] = arith.index_cast %[[IV]] : index to i64
// CHECK-NEXT: %[[TORCH_IV:.*]] = torch_c.from_i64 %[[IV_I64]]
// CHECK-NEXT: %[[TORCH_ITER_ARG_0:.*]] = torch_c.from_f64 %[[ITER_ARG_0]]
// CHECK-NEXT: %[[TORCH_ITER_ARG_1:.*]] = torch_c.from_f64 %[[ITER_ARG_1]]
// CHECK-NEXT: %[[TORCH_VAL_0:.*]] = torch.aten.add.float_int %[[TORCH_ITER_ARG_0]], %[[TORCH_IV]]
// CHECK-NEXT: %[[TORCH_VAL_1:.*]] = torch.aten.mul.float %[[TORCH_ITER_ARG_1]], %[[TORCH_VAL_0]]
// CHECK-NEXT: %[[VAL_0:.*]] = torch_c.to_f64 %[[TORCH_VAL_0]]
// CHECK-NEXT: %[[VAL_1:.*]] = torch_c.to_f64 %[[TORCH_VAL_1]]
// CHECK-NEXT: scf.yield %[[VAL_0]], %[[VAL_1]] : f64, f64
// CHECK-NEXT: }
// CHECK-NEXT: %[[RETURN_0:.*]] = torch_c.from_f64 %[[LOOP]]#0
// CHECK-NEXT: %[[RETURN_1:.*]] = torch_c.from_f64 %[[LOOP]]#1
// CHECK-NEXT: return %[[RETURN_0]], %[[RETURN_1]] : !torch.float, !torch.float
// CHECK-NEXT: }
func @torch.prim.Loop$for_with_multiple_results(%arg0: !torch.int) -> (!torch.float, !torch.float) {
%true = torch.constant.bool true
%float0.000000e00 = torch.constant.float 0.000000e+00
%float9.0 = torch.constant.float 9.0
%0:2 = torch.prim.Loop %arg0, %true, init(%float0.000000e00, %float9.0) {
^bb0(%arg1: !torch.int, %arg2: !torch.float, %arg3: !torch.float):
%1 = torch.aten.add.float_int %arg2, %arg1 : !torch.float, !torch.int -> !torch.float
%2 = torch.aten.mul.float %arg3, %1 : !torch.float, !torch.float -> !torch.float
torch.prim.Loop.condition %true, iter(%1, %2 : !torch.float, !torch.float)
} : (!torch.int, !torch.bool, !torch.float, !torch.float) -> (!torch.float, !torch.float)
return %0#0, %0#1 : !torch.float, !torch.float
}