[MLIR][TORCH] Add E2E support for torch.arange op

This commit adds lowering of the `aten.arange.start_step` op, and decomposes the
`aten.arange` and `aten.arange.start` ops into `aten.arange.start_step`.

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
Vivek Khandelwal 2021-12-23 13:22:45 +00:00
parent a83004c806
commit 4486de5ef3
8 changed files with 465 additions and 6 deletions
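
For context, the decomposition relies on the defaults PyTorch itself documents for `torch.arange`: a missing `start` defaults to 0 and a missing `step` defaults to 1. A minimal PyTorch sketch of the equivalence the decomposition exploits (illustrative only, not part of the diff):

import torch

# aten.arange(end)            == aten.arange.start_step(0, end, 1)
assert torch.equal(torch.arange(5), torch.arange(0, 5, 1))
# aten.arange.start(start, end) == aten.arange.start_step(start, end, 1)
assert torch.equal(torch.arange(2, 5), torch.arange(2, 5, 1))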


@@ -0,0 +1,250 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================

class ArangeIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(5)

@register_test_case(module_factory=lambda: ArangeIntModule())
def ArangeIntModule_basic(module, tu: TestUtils):
    module.forward()

class ArangeFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(5.0)

@register_test_case(module_factory=lambda: ArangeFloatModule())
def ArangeFloatModule_basic(module, tu: TestUtils):
    module.forward()

class ArangeZeroElementOutputModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(0)

@register_test_case(module_factory=lambda: ArangeZeroElementOutputModule())
def ArangeZeroElementOutputModule_basic(module, tu: TestUtils):
    module.forward()

class ArangeStartIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(0, 5)

@register_test_case(module_factory=lambda: ArangeStartIntModule())
def ArangeStartIntModule_basic(module, tu: TestUtils):
    module.forward()

class ArangeStartFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(0.0, 5.0)

@register_test_case(module_factory=lambda: ArangeStartFloatModule())
def ArangeStartFloatModule_basic(module, tu: TestUtils):
    module.forward()

class ArangeNegativeStartIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(-10, 5)

@register_test_case(module_factory=lambda: ArangeNegativeStartIntModule())
def ArangeNegativeStartIntModule_basic(module, tu: TestUtils):
    module.forward()

class ArangeNegativeStartFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(-1.4, 5.7)

@register_test_case(module_factory=lambda: ArangeNegativeStartFloatModule())
def ArangeNegativeStartFloatModule_basic(module, tu: TestUtils):
    module.forward()

class ArangeStartStepIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(0, 5, 1)

@register_test_case(module_factory=lambda: ArangeStartStepIntModule())
def ArangeStartStepIntModule_basic(module, tu: TestUtils):
    module.forward()

class ArangeStartStepFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(-1, 5, 1.3)

@register_test_case(module_factory=lambda: ArangeStartStepFloatModule())
def ArangeStartStepFloatModule_basic(module, tu: TestUtils):
    module.forward()

class ArangeStartNegativeStepIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(10, 1, -2)

@register_test_case(module_factory=lambda: ArangeStartNegativeStepIntModule())
def ArangeStartNegativeStepIntModule_basic(module, tu: TestUtils):
    module.forward()

class ArangeStartNegativeStepFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(-1, -15, -3.4)

@register_test_case(module_factory=lambda: ArangeStartNegativeStepFloatModule())
def ArangeStartNegativeStepFloatModule_basic(module, tu: TestUtils):
    module.forward()

class ArangeDtypeFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(-1, 15, dtype=torch.float32)

@register_test_case(module_factory=lambda: ArangeDtypeFloatModule())
def ArangeDtypeFloatModule_basic(module, tu: TestUtils):
    module.forward()

class ArangeDtypeIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(0.2, 5.0, dtype=torch.int64)

@register_test_case(module_factory=lambda: ArangeDtypeIntModule())
def ArangeDtypeIntModule_basic(module, tu: TestUtils):
    module.forward()

class ArangeFalsePinMemoryModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(5.0, dtype=torch.int64, pin_memory=False)

@register_test_case(module_factory=lambda: ArangeFalsePinMemoryModule())
def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils):
    module.forward()


@@ -45,6 +45,7 @@ from . import squeeze
from . import slice_like
from . import nll_loss
from . import index_select
from . import arange

def _get_argparse():
    config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']


@@ -2028,6 +2028,26 @@ def Torch_AtenArangeStartOp : Torch_Op<"aten.arange.start", [
  let assemblyFormat = "$start `,` $end `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($start) `,` type($end) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)";
}

def Torch_AtenArangeStartStepOp : Torch_Op<"aten.arange.start_step", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchScalarType:$start,
    AnyTorchScalarType:$end,
    AnyTorchScalarType:$step,
    TorchOptionalIntType:$dtype,
    TorchOptionalIntType:$layout,
    TorchOptionalDeviceType:$device,
    TorchOptionalBoolType:$pin_memory
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$start `,` $end `,` $step `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($start) `,` type($end) `,` type($step) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)";
}

def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [
    AllowsTypeRefinement,
    HasValueSemantics


@@ -4022,6 +4022,99 @@ public:
};
} // namespace

namespace {
// Let's say the result of the `aten.arange.start_step` is `output` which is a
// 1-d output tensor. The approach used for generating the output tensor is as
// follows:
//    for i in range(ceil((end-start)/step))
//          output[i] = start + (i * step)
class ConvertAtenArangeStartStepOp
    : public OpConversionPattern<AtenArangeStartStepOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenArangeStartStepOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    // TODO: Add support for layout, pin_memory features.
    // Only `none` layout is supported.
    if (!op.layout().getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only default layout is supported");

    // The pin_memory should be either `False` or `none`.
    bool pinMemory;
    if (!op.pin_memory().getType().isa<Torch::NoneType>() &&
        (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
         pinMemory)) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: pin_memory must be either None or false");
    }

    Location loc = op.getLoc();
    TypeConverter *typeConverter = this->getTypeConverter();
    RankedTensorType resultType =
        typeConverter->convertType(op->getResult(0).getType())
            .cast<RankedTensorType>();
    Type dtype = resultType.getElementType();
    Value start = convertScalarToDtype(rewriter, loc, adaptor.start(), dtype);
    Value end = convertScalarToDtype(rewriter, loc, adaptor.end(), dtype);
    Value step = convertScalarToDtype(rewriter, loc, adaptor.step(), dtype);

    // The result will always be a 1-d tensor.
    // The size of the result is calculated as follows:
    //          ceil((end - start)/step)
    Value resultShape;
    if (dtype.isa<mlir::IntegerType>()) {
      Value subOut = rewriter.create<arith::SubIOp>(loc, end, start);
      resultShape = rewriter.create<arith::CeilDivSIOp>(loc, subOut, step);
    } else {
      Value subOut = rewriter.create<arith::SubFOp>(loc, end, start);
      Value divOut = rewriter.create<arith::DivFOp>(loc, subOut, step);
      Value ceilOut = rewriter.create<math::CeilOp>(loc, divOut);
      resultShape =
          rewriter.create<arith::FPToUIOp>(loc, rewriter.getI64Type(), ceilOut);
    }
    resultShape = castIntToIndex(rewriter, loc, resultShape);
    Value resultTensor =
        rewriter.create<linalg::InitTensorOp>(loc, resultShape, dtype);

    StringRef iteratorType = getParallelIteratorTypeName();
    AffineMap indexingMap =
        AffineMap::getMultiDimIdentityMap(1, op->getContext());

    Value finalRes =
        rewriter
            .create<linalg::GenericOp>(
                loc, /*resultTensorTypes=*/resultTensor.getType(),
                /*inputs=*/ValueRange({}),
                /*outputs=*/resultTensor,
                /*indexingMaps=*/indexingMap,
                /*iteratorTypes=*/iteratorType,
                [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
                  Value index = b.create<linalg::IndexOp>(loc, 0);
                  index = castIndexToInt(b, loc, index);
                  index = convertScalarToDtype(b, loc, index, dtype);
                  Value mulOut, result;
                  if (dtype.isa<mlir::FloatType>()) {
                    mulOut = b.create<arith::MulFOp>(loc, step, index);
                    result = b.create<arith::AddFOp>(loc, start, mulOut);
                  } else {
                    mulOut = b.create<arith::MulIOp>(loc, step, index);
                    result = b.create<arith::AddIOp>(loc, start, mulOut);
                  }
                  b.create<linalg::YieldOp>(loc, result);
                })
            .getResult(0);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
    return success();
  }
};
} // namespace

// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
@@ -4134,6 +4227,8 @@ public:
    patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
    patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
    target.addIllegalOp<AtenTensorIntOp, AtenTensorFloatOp>();
    patterns.add<ConvertAtenArangeStartStepOp>(typeConverter, context);
    target.addIllegalOp<AtenArangeStartStepOp>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))

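The comment in the pattern above is the whole recipe: the result is a 1-d tensor with ceil((end - start) / step) elements, and element i holds start + (i * step). A short Python sketch of the same recipe, checked against torch.arange (illustrative only; the helper name is invented and nothing below is part of the patch):

import math
import torch

def arange_start_step_reference(start, end, step):
    # Result size, as in the lowering: ceil((end - start) / step).
    # range() treats a non-positive count as empty, which covers the 0-element case.
    n = math.ceil((end - start) / step)
    # Element i of the 1-d result is start + (i * step).
    return torch.tensor([start + i * step for i in range(n)])

# Matches torch.arange for both positive and negative steps.
assert torch.equal(arange_start_step_reference(0, 5, 1), torch.arange(0, 5, 1))
assert torch.equal(arange_start_step_reference(10, 1, -2), torch.arange(10, 1, -2))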

@@ -563,6 +563,48 @@ public:
};
} // namespace

namespace {
// The `aten.arange` op is converted to `aten.arange.start_step` op.
class DecomposeAtenArangeOp : public OpRewritePattern<AtenArangeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenArangeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // The AtenArangeOp doesn't have a start and step value. Therefore we set
    // them as default values 0 and 1, respectively.
    Value start, step;
    start = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(0));
    step = rewriter.create<Torch::ConstantIntOp>(loc,
                                                 rewriter.getI64IntegerAttr(1));
    rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
        op, op.getType(), start, op.end(), step, op.dtype(), op.layout(),
        op.device(), op.pin_memory());
    return success();
  }
};
} // namespace

namespace {
// The `aten.arange.start` op is converted to `aten.arange.start_step` op.
class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenArangeStartOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // The AtenArangeStartOp doesn't have a step value. Therefore we set it as
    // default value 1.
    Value step;
    step = rewriter.create<Torch::ConstantIntOp>(loc,
                                                 rewriter.getI64IntegerAttr(1));
    rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
        op, op.getType(), op.start(), op.end(), step, op.dtype(), op.layout(),
        op.device(), op.pin_memory());
    return success();
  }
};
} // namespace

namespace {
class DecomposeComplexOpsPass
    : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -612,6 +654,10 @@ class DecomposeComplexOpsPass
    target.addIllegalOp<AtenAddcdivOp>();
    target.addIllegalOp<AtenLayerNormOp>();
    patterns.add<DecomposeAtenLayerNormOp>(context);
    patterns.add<DecomposeAtenArangeOp>(context);
    target.addIllegalOp<AtenArangeOp>();
    patterns.add<DecomposeAtenArangeStartOp>(context);
    target.addIllegalOp<AtenArangeStartOp>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {


@@ -338,6 +338,8 @@ public:
      return visitAtenArangeOp(arange);
    } else if (auto arangeStart = dyn_cast<AtenArangeStartOp>(op)) {
      return visitAtenArangeStartOp(arangeStart);
+   } else if (auto arangeStartStep = dyn_cast<AtenArangeStartStepOp>(op)) {
+     return visitAtenArangeStartStepOp(arangeStartStep);
    } else if (auto sum = dyn_cast<AtenSumOp>(op)) {
      Type defaultDtype = operands[0]->getValue().dtype;
      Type dtype =
@@ -532,7 +534,10 @@
  ChangeResult visitAtenArangeLikeOpHelper(Operation *op,
                                           llvm::Optional<Value> start,
-                                          Value end, Value dtype);
+                                          Value end,
+                                          llvm::Optional<Value> step,
+                                          Value dtype);
+ ChangeResult visitAtenArangeStartStepOp(AtenArangeStartStepOp op);
  ChangeResult visitAtenArangeStartOp(AtenArangeStartOp op);
  ChangeResult visitAtenArangeOp(AtenArangeOp op);
  ChangeResult visitReductionAlongAllDimsOp(
@@ -1092,7 +1097,8 @@
// Arange like ops returns a 1-D tensor of size ceil(end - start).
ChangeResult TypeAnalyzer::visitAtenArangeLikeOpHelper(
-    Operation *op, llvm::Optional<Value> start, Value end, Value dtype) {
+    Operation *op, llvm::Optional<Value> start, Value end,
+    llvm::Optional<Value> step, Value dtype) {
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
  knowledge.sizes.resize(1, kUnknownSize);
@@ -1103,12 +1109,13 @@ ChangeResult TypeAnalyzer::visitAtenArangeLikeOpHelper(
  } else if (dtype.getType().isa<Torch::NoneType>()) {
    // From torch/_torch_docs.py:
    // If `dtype` is not given, infer the data type from the other input
-   // arguments. If any of `start`, `end`, or `stop` are floating-point, the
+   // arguments. If any of `start`, `end`, or `step` are floating-point, the
    // `dtype` is inferred to be the default dtype, see
    // `torch.get_default_dtype`. Otherwise, the `dtype` is inferred to
    // be `torch.int64`
    if ((start.hasValue() && (*start).getType().isa<Torch::FloatType>()) ||
-       end.getType().isa<Torch::FloatType>()) {
+       end.getType().isa<Torch::FloatType>() ||
+       (step.hasValue() && (*step).getType().isa<Torch::FloatType>())) {
      // TODO: Should get the dtype from torch.get_default_dtype().
      // For now, use float32 which is the initial default dtype.
      knowledge.dtype = Float32Type::get(op->getContext());
@@ -1119,12 +1126,18 @@ ChangeResult TypeAnalyzer::visitAtenArangeLikeOpHelper(
  return getLatticeElement(op->getResult(0)).join(knowledge);
}

+ ChangeResult
+ TypeAnalyzer::visitAtenArangeStartStepOp(AtenArangeStartStepOp op) {
+   return visitAtenArangeLikeOpHelper(op, op.start(), op.end(), op.step(),
+                                      op.dtype());
+ }
+
ChangeResult TypeAnalyzer::visitAtenArangeStartOp(AtenArangeStartOp op) {
-   return visitAtenArangeLikeOpHelper(op, op.start(), op.end(), op.dtype());
+   return visitAtenArangeLikeOpHelper(op, op.start(), op.end(), {}, op.dtype());
}

ChangeResult TypeAnalyzer::visitAtenArangeOp(AtenArangeOp op) {
-   return visitAtenArangeLikeOpHelper(op, {}, op.end(), op.dtype());
+   return visitAtenArangeLikeOpHelper(op, {}, op.end(), {}, op.dtype());
}

ChangeResult TypeAnalyzer::visitReductionAlongAllDimsOp(
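
The dtype-inference rule quoted from torch/_torch_docs.py above can be checked directly in PyTorch: with no explicit dtype, a floating-point start, end, or step promotes the result to the default dtype, otherwise the result is int64. A quick sketch of the rule (plain PyTorch, illustrative only):

import torch

assert torch.arange(0, 5, 1).dtype == torch.int64                        # all-integer arguments
assert torch.arange(0, 5, 0.5).dtype == torch.get_default_dtype()        # float step promotes
assert torch.arange(-1, 15, dtype=torch.float32).dtype == torch.float32  # explicit dtype wins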


@@ -557,6 +557,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
    emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
    emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
    emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
    emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
    emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
    emit("aten::contiguous : (Tensor, int) -> (Tensor)")


@@ -121,3 +121,36 @@ func @torch.aten.size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<!torch.int>
  %0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !torch.list<!torch.int>
  return %0 : !torch.list<!torch.int>
}

// ----
// CHECK-LABEL: func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
// CHECK:         %[[CST5:.*]] = torch.constant.int 5
// CHECK:         %[[CSTN:.*]] = torch.constant.none
// CHECK:         %[[CST0:.*]] = torch.constant.int 0
// CHECK:         %[[CST1:.*]] = torch.constant.int 1
// CHECK:         %[[RESULT:.*]] = torch.aten.arange.start_step %[[CST0]], %[[CST5]], %[[CST1]], %[[CSTN]], %[[CSTN]], %[[CSTN]], %[[CSTN]] :
// CHECK-SAME:      !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
// CHECK:         return %[[RESULT]] : !torch.vtensor<[?],si64>
func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
  %int5 = torch.constant.int 5
  %none = torch.constant.none
  %0 = torch.aten.arange %int5, %none, %none, %none, %none : !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
  return %0 : !torch.vtensor<[?],si64>
}

// ----
// CHECK-LABEL: func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
// CHECK:         %[[CST10:.*]] = torch.constant.int 10
// CHECK:         %[[CST0:.*]] = torch.constant.int 0
// CHECK:         %[[CSTN:.*]] = torch.constant.none
// CHECK:         %[[CST1:.*]] = torch.constant.int 1
// CHECK:         %[[RESULT:.*]] = torch.aten.arange.start_step %[[CST0]], %[[CST10]], %[[CST1]], %[[CSTN]], %[[CSTN]], %[[CSTN]], %[[CSTN]] :
// CHECK-SAME:      !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
// CHECK:         return %[[RESULT]] : !torch.vtensor<[?],si64>
func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
  %int10 = torch.constant.int 10
  %int0 = torch.constant.int 0
  %none = torch.constant.none
  %0 = torch.aten.arange.start %int0, %int10, %none, %none, %none, %none : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
  return %0 : !torch.vtensor<[?],si64>
}