mirror of https://github.com/llvm/torch-mlir
[Torch] Emit and decompose prims.iota op (#3132)
parent
a60e84e5ee
commit
e5bdd71baf
|
@ -15909,6 +15909,34 @@ def Torch_PrimsViewOfOp : Torch_Op<"prims.view_of", [
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_PrimsIotaOp : Torch_Op<"prims.iota", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `prims::iota : (int, int, int, int, Device, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
Torch_IntType:$length,
|
||||
Torch_IntType:$start,
|
||||
Torch_IntType:$step,
|
||||
Torch_IntType:$dtype,
|
||||
Torch_DeviceType:$device,
|
||||
Torch_BoolType:$requires_grad
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult PrimsIotaOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||
}
|
||||
void PrimsIotaOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 6, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
|
||||
HasValueSemantics,
|
||||
AllowsTypeRefinement,
|
||||
|
|
|
@ -8653,6 +8653,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.prims.iota\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.Device, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.prims.iota\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.Device, %arg5: !torch.bool) -> !torch.int {\n"
|
||||
" return %arg3 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.float) -> !torch.list<int> {\n"
|
||||
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
|
|
@ -4789,6 +4789,35 @@ class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// The `prims.iota` op is converted to `aten.arange.startStep` op.
|
||||
class DecomposePrimsIotaOp : public OpRewritePattern<PrimsIotaOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(PrimsIotaOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
int64_t length, start, step;
|
||||
if (!matchPattern(op.getLength(), m_TorchConstantInt(&length)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: low must be a constant integer");
|
||||
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: low must be a constant integer");
|
||||
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: low must be a constant integer");
|
||||
auto endVal = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(start + length * step));
|
||||
auto none = rewriter.create<ConstantNoneOp>(loc);
|
||||
rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
|
||||
op, op.getType(), op.getStart(), endVal, op.getStep(), op.getDtype(),
|
||||
none, op.getDevice(), none);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose constant tensor full like ops.
|
||||
template <typename OpTy, int fillVal>
|
||||
|
@ -7605,6 +7634,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposePrimsIotaOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLinspaceOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
|
||||
|
|
|
@ -1228,6 +1228,7 @@ STABLEHLO_PASS_SET = {
|
|||
"PrimMinIntDynamicModule_basic",
|
||||
"PrimMinIntModule_basic",
|
||||
"PrimsConvertElementTypeModule_basic",
|
||||
"PrimsIotaModule_basic",
|
||||
"PrimsSqueezeEmptyDimensionsModule_basic",
|
||||
"PrimsViewOfModule_basic",
|
||||
"PrimsViewOfZeroRankModule_basic",
|
||||
|
@ -1789,6 +1790,7 @@ TOSA_PASS_SET = {
|
|||
"PermuteModule_basic",
|
||||
"PermuteNegativeIndexModule_basic",
|
||||
"PrimListUnpackNumMismatchModule_basic",
|
||||
"PrimsIotaModule_basic",
|
||||
"PrimsSqueezeEmptyDimensionsModule_basic",
|
||||
"PrimsSqueezeModule_basic",
|
||||
"PrimsViewOfModule_basic",
|
||||
|
@ -2683,6 +2685,9 @@ ONNX_XFAIL_SET = {
|
|||
"SqueezeModule_allUnitDim",
|
||||
"SqueezeModule_broadcast",
|
||||
"SqueezeModule_static",
|
||||
|
||||
# RuntimeError: unsupported input type: Device
|
||||
"PrimsIotaModule_basic",
|
||||
|
||||
# Failure - unknown
|
||||
"BernoulliModule_basic",
|
||||
|
|
|
@ -1319,6 +1319,12 @@ def prims〇view_of〡dtype(a_rank_dtype: Tuple[int, int]) -> int:
|
|||
_, a_dtype = a_rank_dtype
|
||||
return a_dtype
|
||||
|
||||
def prims〇iota〡shape(length: int, start: int, step: int, dtype: int, device: device, requires_grad: bool) -> List[int]:
|
||||
return [length]
|
||||
|
||||
def prims〇iota〡dtype(length: int, start: int, step: int, dtype: int, device: device, requires_grad: bool) -> int:
|
||||
return dtype
|
||||
|
||||
def prim〇NumToTensor〇Scalar〡shape(a: float) -> List[int]:
|
||||
return []
|
||||
|
||||
|
|
|
@ -897,6 +897,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("prims::split_dim : (Tensor, int, int) -> (Tensor)")
|
||||
emit("prims::squeeze : (Tensor, int[]) -> (Tensor)")
|
||||
emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit("prims::iota : (int, int, int, int, Device, bool) -> (Tensor)")
|
||||
|
||||
# ==========================================================================
|
||||
# `quantized::` namespace.
|
||||
|
|
|
@ -380,3 +380,20 @@ class LinspaceTwoSizeModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: LinspaceTwoSizeModule())
|
||||
def LinspaceTwoSizeModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class PrimsIotaModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.ops.prims.iota(77, start=0, step=1, dtype=torch.int64, device='cpu',
|
||||
requires_grad=False)
|
||||
|
||||
@register_test_case(module_factory=lambda: PrimsIotaModule())
|
||||
def PrimsIotaModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
|
Loading…
Reference in New Issue