mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for torch.arange op
This commit adds lowering of `aten.arange.start_step` op. This commit decomposes `aten.arange` and `aten.arange.start` into `aten.arange.start_step` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/502/head snapshot-20211227.170
parent
a83004c806
commit
4486de5ef3
|
@ -0,0 +1,250 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ArangeIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(5)
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeIntModule())
|
||||
def ArangeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(5.0)
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeFloatModule())
|
||||
def ArangeFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeZeroElementOutputModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(0)
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeZeroElementOutputModule())
|
||||
def ArangeZeroElementOutputModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeStartIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(0, 5)
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartIntModule())
|
||||
def ArangeStartIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeStartFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(0.0, 5.0)
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartFloatModule())
|
||||
def ArangeStartFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeNegativeStartIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(-10, 5)
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeNegativeStartIntModule())
|
||||
def ArangeNegativeStartIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeNegativeStartFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(-1.4, 5.7)
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeNegativeStartFloatModule())
|
||||
def ArangeNegativeStartFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeStartStepIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(0, 5, 1)
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartStepIntModule())
|
||||
def ArangeStartStepIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeStartStepFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(-1, 5, 1.3)
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartStepFloatModule())
|
||||
def ArangeStartStepFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeStartNegativeStepIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(10, 1, -2)
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartNegativeStepIntModule())
|
||||
def ArangeStartNegativeStepIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeStartNegativeStepFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(-1, -15, -3.4)
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartNegativeStepFloatModule())
|
||||
def ArangeStartNegativeStepFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeDtypeFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(-1, 15, dtype=torch.float32)
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeDtypeFloatModule())
|
||||
def ArangeDtypeFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeDtypeIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(0.2, 5.0, dtype=torch.int64)
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeDtypeIntModule())
|
||||
def ArangeDtypeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeFalsePinMemoryModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(5.0, dtype=torch.int64, pin_memory=False)
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeFalsePinMemoryModule())
|
||||
def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
|
@ -45,6 +45,7 @@ from . import squeeze
|
|||
from . import slice_like
|
||||
from . import nll_loss
|
||||
from . import index_select
|
||||
from . import arange
|
||||
|
||||
def _get_argparse():
|
||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
|
||||
|
|
|
@ -2028,6 +2028,26 @@ def Torch_AtenArangeStartOp : Torch_Op<"aten.arange.start", [
|
|||
let assemblyFormat = "$start `,` $end `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($start) `,` type($end) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenArangeStartStepOp : Torch_Op<"aten.arange.start_step", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchScalarType:$start,
|
||||
AnyTorchScalarType:$end,
|
||||
AnyTorchScalarType:$step,
|
||||
TorchOptionalIntType:$dtype,
|
||||
TorchOptionalIntType:$layout,
|
||||
TorchOptionalDeviceType:$device,
|
||||
TorchOptionalBoolType:$pin_memory
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$start `,` $end `,` $step `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($start) `,` type($end) `,` type($step) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -4022,6 +4022,99 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Let's say the result of the `aten.arange.start_step` is `output` which is a
|
||||
// 1-d output tensor. The approach used for generating the output tensor is as
|
||||
// follows:
|
||||
// for i in range(ceil((end-start)/step))
|
||||
// output[i] = start + (i * step)
|
||||
class ConvertAtenArangeStartStepOp
|
||||
: public OpConversionPattern<AtenArangeStartStepOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenArangeStartStepOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
// TODO: Add support for layout, pin_memory features.
|
||||
// Only `none` layout is supported.
|
||||
if (!op.layout().getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only default layout is supported");
|
||||
|
||||
// The pin_memory should be either `False` or `none`.
|
||||
bool pinMemory;
|
||||
if (!op.pin_memory().getType().isa<Torch::NoneType>() &&
|
||||
(!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
|
||||
pinMemory)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: pin_memory must be either None or false");
|
||||
}
|
||||
|
||||
Location loc = op.getLoc();
|
||||
TypeConverter *typeConverter = this->getTypeConverter();
|
||||
RankedTensorType resultType =
|
||||
typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
Type dtype = resultType.getElementType();
|
||||
Value start = convertScalarToDtype(rewriter, loc, adaptor.start(), dtype);
|
||||
Value end = convertScalarToDtype(rewriter, loc, adaptor.end(), dtype);
|
||||
Value step = convertScalarToDtype(rewriter, loc, adaptor.step(), dtype);
|
||||
|
||||
// The result will always be a 1-d tensor.
|
||||
// The size of the result is calculated as follows:
|
||||
// ceil((end - start)/step)
|
||||
Value resultShape;
|
||||
if (dtype.isa<mlir::IntegerType>()) {
|
||||
Value subOut = rewriter.create<arith::SubIOp>(loc, end, start);
|
||||
resultShape = rewriter.create<arith::CeilDivSIOp>(loc, subOut, step);
|
||||
} else {
|
||||
Value subOut = rewriter.create<arith::SubFOp>(loc, end, start);
|
||||
Value divOut = rewriter.create<arith::DivFOp>(loc, subOut, step);
|
||||
Value ceilOut = rewriter.create<math::CeilOp>(loc, divOut);
|
||||
resultShape =
|
||||
rewriter.create<arith::FPToUIOp>(loc, rewriter.getI64Type(), ceilOut);
|
||||
}
|
||||
resultShape = castIntToIndex(rewriter, loc, resultShape);
|
||||
|
||||
Value resultTensor =
|
||||
rewriter.create<linalg::InitTensorOp>(loc, resultShape, dtype);
|
||||
|
||||
StringRef iteratorType = getParallelIteratorTypeName();
|
||||
AffineMap indexingMap =
|
||||
AffineMap::getMultiDimIdentityMap(1, op->getContext());
|
||||
|
||||
Value finalRes =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, /*resultTensorTypes=*/resultTensor.getType(),
|
||||
/*inputs=*/ValueRange({}),
|
||||
/*outputs=*/resultTensor,
|
||||
/*indexingMaps=*/indexingMap,
|
||||
/*iteratorTypes=*/iteratorType,
|
||||
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
|
||||
Value index = b.create<linalg::IndexOp>(loc, 0);
|
||||
index = castIndexToInt(b, loc, index);
|
||||
index = convertScalarToDtype(b, loc, index, dtype);
|
||||
Value mulOut, result;
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
mulOut = b.create<arith::MulFOp>(loc, step, index);
|
||||
result = b.create<arith::AddFOp>(loc, start, mulOut);
|
||||
} else {
|
||||
mulOut = b.create<arith::MulIOp>(loc, step, index);
|
||||
result = b.create<arith::AddIOp>(loc, start, mulOut);
|
||||
}
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
})
|
||||
.getResult(0);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// The pass
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -4134,6 +4227,8 @@ public:
|
|||
patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
|
||||
patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
|
||||
target.addIllegalOp<AtenTensorIntOp, AtenTensorFloatOp>();
|
||||
patterns.add<ConvertAtenArangeStartStepOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenArangeStartStepOp>();
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
|
|
|
@ -563,6 +563,48 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// The `aten.arange` op is converted to `aten.arange.start_step` op.
|
||||
class DecomposeAtenArangeOp : public OpRewritePattern<AtenArangeOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenArangeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
// The AtenArangeOp doesn't have a start and step value. Therefore we set
|
||||
// them as default values 0 and 1, respectively.
|
||||
Value start, step;
|
||||
start = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
step = rewriter.create<Torch::ConstantIntOp>(loc,
|
||||
rewriter.getI64IntegerAttr(1));
|
||||
rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
|
||||
op, op.getType(), start, op.end(), step, op.dtype(), op.layout(),
|
||||
op.device(), op.pin_memory());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// The `aten.arange.start` op is converted to `aten.arange.start_step` op.
|
||||
class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenArangeStartOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
// The AtenArangeStartOp doesn't have a step value. Therefore we set it as
|
||||
// default value 1.
|
||||
Value step;
|
||||
step = rewriter.create<Torch::ConstantIntOp>(loc,
|
||||
rewriter.getI64IntegerAttr(1));
|
||||
rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
|
||||
op, op.getType(), op.start(), op.end(), step, op.dtype(), op.layout(),
|
||||
op.device(), op.pin_memory());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeComplexOpsPass
|
||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||
|
@ -612,6 +654,10 @@ class DecomposeComplexOpsPass
|
|||
target.addIllegalOp<AtenAddcdivOp>();
|
||||
target.addIllegalOp<AtenLayerNormOp>();
|
||||
patterns.add<DecomposeAtenLayerNormOp>(context);
|
||||
patterns.add<DecomposeAtenArangeOp>(context);
|
||||
target.addIllegalOp<AtenArangeOp>();
|
||||
patterns.add<DecomposeAtenArangeStartOp>(context);
|
||||
target.addIllegalOp<AtenArangeStartOp>();
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns)))) {
|
||||
|
|
|
@ -338,6 +338,8 @@ public:
|
|||
return visitAtenArangeOp(arange);
|
||||
} else if (auto arangeStart = dyn_cast<AtenArangeStartOp>(op)) {
|
||||
return visitAtenArangeStartOp(arangeStart);
|
||||
} else if (auto arangeStartStep = dyn_cast<AtenArangeStartStepOp>(op)) {
|
||||
return visitAtenArangeStartStepOp(arangeStartStep);
|
||||
} else if (auto sum = dyn_cast<AtenSumOp>(op)) {
|
||||
Type defaultDtype = operands[0]->getValue().dtype;
|
||||
Type dtype =
|
||||
|
@ -532,7 +534,10 @@ private:
|
|||
|
||||
ChangeResult visitAtenArangeLikeOpHelper(Operation *op,
|
||||
llvm::Optional<Value> start,
|
||||
Value end, Value dtype);
|
||||
Value end,
|
||||
llvm::Optional<Value> step,
|
||||
Value dtype);
|
||||
ChangeResult visitAtenArangeStartStepOp(AtenArangeStartStepOp op);
|
||||
ChangeResult visitAtenArangeStartOp(AtenArangeStartOp op);
|
||||
ChangeResult visitAtenArangeOp(AtenArangeOp op);
|
||||
ChangeResult visitReductionAlongAllDimsOp(
|
||||
|
@ -1092,7 +1097,8 @@ ChangeResult TypeAnalyzer::visitAtenUnsqueezeOp(
|
|||
|
||||
// Arange like ops returns a 1-D tensor of size ceil(end - start).
|
||||
ChangeResult TypeAnalyzer::visitAtenArangeLikeOpHelper(
|
||||
Operation *op, llvm::Optional<Value> start, Value end, Value dtype) {
|
||||
Operation *op, llvm::Optional<Value> start, Value end,
|
||||
llvm::Optional<Value> step, Value dtype) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.sizes.resize(1, kUnknownSize);
|
||||
|
@ -1103,12 +1109,13 @@ ChangeResult TypeAnalyzer::visitAtenArangeLikeOpHelper(
|
|||
} else if (dtype.getType().isa<Torch::NoneType>()) {
|
||||
// From torch/_torch_docs.py:
|
||||
// If `dtype` is not given, infer the data type from the other input
|
||||
// arguments. If any of `start`, `end`, or `stop` are floating-point, the
|
||||
// arguments. If any of `start`, `end`, or `step` are floating-point, the
|
||||
// `dtype` is inferred to be the default dtype, see
|
||||
// `torch.get_default_dtype`. Otherwise, the `dtype` is inferred to
|
||||
// be `torch.int64`
|
||||
if ((start.hasValue() && (*start).getType().isa<Torch::FloatType>()) ||
|
||||
end.getType().isa<Torch::FloatType>()) {
|
||||
end.getType().isa<Torch::FloatType>() ||
|
||||
(step.hasValue() && (*step).getType().isa<Torch::FloatType>())) {
|
||||
// TODO: Should get the dtype from torch.get_default_dtype().
|
||||
// For now, use float32 which is the initial default dtype.
|
||||
knowledge.dtype = Float32Type::get(op->getContext());
|
||||
|
@ -1119,12 +1126,18 @@ ChangeResult TypeAnalyzer::visitAtenArangeLikeOpHelper(
|
|||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
}
|
||||
|
||||
ChangeResult
|
||||
TypeAnalyzer::visitAtenArangeStartStepOp(AtenArangeStartStepOp op) {
|
||||
return visitAtenArangeLikeOpHelper(op, op.start(), op.end(), op.step(),
|
||||
op.dtype());
|
||||
}
|
||||
|
||||
ChangeResult TypeAnalyzer::visitAtenArangeStartOp(AtenArangeStartOp op) {
|
||||
return visitAtenArangeLikeOpHelper(op, op.start(), op.end(), op.dtype());
|
||||
return visitAtenArangeLikeOpHelper(op, op.start(), op.end(), {}, op.dtype());
|
||||
}
|
||||
|
||||
ChangeResult TypeAnalyzer::visitAtenArangeOp(AtenArangeOp op) {
|
||||
return visitAtenArangeLikeOpHelper(op, {}, op.end(), op.dtype());
|
||||
return visitAtenArangeLikeOpHelper(op, {}, op.end(), {}, op.dtype());
|
||||
}
|
||||
|
||||
ChangeResult TypeAnalyzer::visitReductionAlongAllDimsOp(
|
||||
|
|
|
@ -557,6 +557,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
|
||||
emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
|
||||
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
|
||||
emit("aten::contiguous : (Tensor, int) -> (Tensor)")
|
||||
|
|
|
@ -121,3 +121,36 @@ func @torch.aten.size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<!torch.in
|
|||
%0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !torch.list<!torch.int>
|
||||
return %0 : !torch.list<!torch.int>
|
||||
}
|
||||
|
||||
// ----
|
||||
// CHECK-LABEL: func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
|
||||
// CHECK: %[[CST5:.*]] = torch.constant.int 5
|
||||
// CHECK: %[[CSTN:.*]] = torch.constant.none
|
||||
// CHECK: %[[CST0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[CST1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.arange.start_step %[[CST0]], %[[CST5]], %[[CST1]], %[[CSTN]], %[[CSTN]], %[[CSTN]], %[[CSTN]] :
|
||||
// CHECK-SAME: !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],si64>
|
||||
func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
|
||||
%int5 = torch.constant.int 5
|
||||
%none = torch.constant.none
|
||||
%0 = torch.aten.arange %int5, %none, %none, %none, %none : !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
|
||||
return %0 : !torch.vtensor<[?],si64>
|
||||
}
|
||||
|
||||
// ----
|
||||
// CHECK-LABEL: func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
|
||||
// CHECK: %[[CST10:.*]] = torch.constant.int 10
|
||||
// CHECK: %[[CST0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[CSTN:.*]] = torch.constant.none
|
||||
// CHECK: %[[CST1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.arange.start_step %[[CST0]], %[[CST10]], %[[CST1]], %[[CSTN]], %[[CSTN]], %[[CSTN]], %[[CSTN]] :
|
||||
// CHECK-SAME: !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],si64>
|
||||
func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
|
||||
%int10 = torch.constant.int 10
|
||||
%int0 = torch.constant.int 0
|
||||
%none = torch.constant.none
|
||||
%0 = torch.aten.arange.start %int0, %int10, %none, %none, %none, %none : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
|
||||
return %0 : !torch.vtensor<[?],si64>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue