[Torch] Emit and decompose prims.iota op (#3132)

pull/3203/head
penguin_wwy 2024-04-22 10:45:01 +08:00 committed by GitHub
parent a60e84e5ee
commit e5bdd71baf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 94 additions and 0 deletions

View File

@ -15909,6 +15909,34 @@ def Torch_PrimsViewOfOp : Torch_Op<"prims.view_of", [
let hasFolder = 1;
}
def Torch_PrimsIotaOp : Torch_Op<"prims.iota", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `prims::iota : (int, int, int, int, Device, bool) -> (Tensor)`";
let arguments = (ins
Torch_IntType:$length,
Torch_IntType:$start,
Torch_IntType:$step,
Torch_IntType:$dtype,
Torch_DeviceType:$device,
Torch_BoolType:$requires_grad
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult PrimsIotaOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void PrimsIotaOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
HasValueSemantics,
AllowsTypeRefinement,

View File

@ -8653,6 +8653,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.prims.iota\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.Device, %arg5: !torch.bool) -> !torch.list<int> {\n"
" %0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.prims.iota\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.Device, %arg5: !torch.bool) -> !torch.int {\n"
" return %arg3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.float) -> !torch.list<int> {\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"

View File

@ -4789,6 +4789,35 @@ class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
};
} // namespace
namespace {
// The `prims.iota` op is converted to `aten.arange.startStep` op.
class DecomposePrimsIotaOp : public OpRewritePattern<PrimsIotaOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimsIotaOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
int64_t length, start, step;
if (!matchPattern(op.getLength(), m_TorchConstantInt(&length)))
return rewriter.notifyMatchFailure(
op, "unimplemented: low must be a constant integer");
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
return rewriter.notifyMatchFailure(
op, "unimplemented: low must be a constant integer");
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
return rewriter.notifyMatchFailure(
op, "unimplemented: low must be a constant integer");
auto endVal = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(start + length * step));
auto none = rewriter.create<ConstantNoneOp>(loc);
rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
op, op.getType(), op.getStart(), endVal, op.getStep(), op.getDtype(),
none, op.getDevice(), none);
return success();
}
};
} // namespace
namespace {
// Decompose constant tensor full like ops.
template <typename OpTy, int fillVal>
@ -7605,6 +7634,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposePrimsIotaOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinspaceOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);

View File

@ -1228,6 +1228,7 @@ STABLEHLO_PASS_SET = {
"PrimMinIntDynamicModule_basic",
"PrimMinIntModule_basic",
"PrimsConvertElementTypeModule_basic",
"PrimsIotaModule_basic",
"PrimsSqueezeEmptyDimensionsModule_basic",
"PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic",
@ -1789,6 +1790,7 @@ TOSA_PASS_SET = {
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"PrimListUnpackNumMismatchModule_basic",
"PrimsIotaModule_basic",
"PrimsSqueezeEmptyDimensionsModule_basic",
"PrimsSqueezeModule_basic",
"PrimsViewOfModule_basic",
@ -2684,6 +2686,9 @@ ONNX_XFAIL_SET = {
"SqueezeModule_broadcast",
"SqueezeModule_static",
# RuntimeError: unsupported input type: Device
"PrimsIotaModule_basic",
# Failure - unknown
"BernoulliModule_basic",
"BucketizeTensorFloatModule_basic",

View File

@ -1319,6 +1319,12 @@ def primsview_of〡dtype(a_rank_dtype: Tuple[int, int]) -> int:
_, a_dtype = a_rank_dtype
return a_dtype
def primsiota〡shape(length: int, start: int, step: int, dtype: int, device: device, requires_grad: bool) -> List[int]:
return [length]
def primsiota〡dtype(length: int, start: int, step: int, dtype: int, device: device, requires_grad: bool) -> int:
return dtype
def primNumToTensorScalar〡shape(a: float) -> List[int]:
return []

View File

@ -897,6 +897,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("prims::split_dim : (Tensor, int, int) -> (Tensor)")
emit("prims::squeeze : (Tensor, int[]) -> (Tensor)")
emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True)
emit("prims::iota : (int, int, int, int, Device, bool) -> (Tensor)")
# ==========================================================================
# `quantized::` namespace.

View File

@ -380,3 +380,20 @@ class LinspaceTwoSizeModule(torch.nn.Module):
@register_test_case(module_factory=lambda: LinspaceTwoSizeModule())
def LinspaceTwoSizeModule_basic(module, tu: TestUtils):
module.forward()
class PrimsIotaModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ops.prims.iota(77, start=0, step=1, dtype=torch.int64, device='cpu',
requires_grad=False)
@register_test_case(module_factory=lambda: PrimsIotaModule())
def PrimsIotaModule_basic(module, tu: TestUtils):
module.forward()