mirror of https://github.com/llvm/torch-mlir
[Torch] Emit and decompose prims.iota op (#3132)
parent
a60e84e5ee
commit
e5bdd71baf
|
@ -15909,6 +15909,34 @@ def Torch_PrimsViewOfOp : Torch_Op<"prims.view_of", [
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_PrimsIotaOp : Torch_Op<"prims.iota", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `prims::iota : (int, int, int, int, Device, bool) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_IntType:$length,
|
||||||
|
Torch_IntType:$start,
|
||||||
|
Torch_IntType:$step,
|
||||||
|
Torch_IntType:$dtype,
|
||||||
|
Torch_DeviceType:$device,
|
||||||
|
Torch_BoolType:$requires_grad
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult PrimsIotaOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||||
|
}
|
||||||
|
void PrimsIotaOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 6, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
|
def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
|
|
|
@ -8653,6 +8653,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
" return %0#1 : !torch.int\n"
|
" return %0#1 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.prims.iota\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.Device, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
||||||
|
" %0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<int>\n"
|
||||||
|
" return %0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.prims.iota\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.Device, %arg5: !torch.bool) -> !torch.int {\n"
|
||||||
|
" return %arg3 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.float) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.float) -> !torch.list<int> {\n"
|
||||||
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
|
|
@ -4789,6 +4789,35 @@ class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// The `prims.iota` op is converted to `aten.arange.startStep` op.
|
||||||
|
class DecomposePrimsIotaOp : public OpRewritePattern<PrimsIotaOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(PrimsIotaOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
int64_t length, start, step;
|
||||||
|
if (!matchPattern(op.getLength(), m_TorchConstantInt(&length)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: low must be a constant integer");
|
||||||
|
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: low must be a constant integer");
|
||||||
|
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: low must be a constant integer");
|
||||||
|
auto endVal = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(start + length * step));
|
||||||
|
auto none = rewriter.create<ConstantNoneOp>(loc);
|
||||||
|
rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
|
||||||
|
op, op.getType(), op.getStart(), endVal, op.getStep(), op.getDtype(),
|
||||||
|
none, op.getDevice(), none);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Decompose constant tensor full like ops.
|
// Decompose constant tensor full like ops.
|
||||||
template <typename OpTy, int fillVal>
|
template <typename OpTy, int fillVal>
|
||||||
|
@ -7605,6 +7634,7 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposePrimsIotaOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLinspaceOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLinspaceOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<
|
addPatternIfTargetOpIsIllegal<
|
||||||
DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
|
DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
|
||||||
|
|
|
@ -1228,6 +1228,7 @@ STABLEHLO_PASS_SET = {
|
||||||
"PrimMinIntDynamicModule_basic",
|
"PrimMinIntDynamicModule_basic",
|
||||||
"PrimMinIntModule_basic",
|
"PrimMinIntModule_basic",
|
||||||
"PrimsConvertElementTypeModule_basic",
|
"PrimsConvertElementTypeModule_basic",
|
||||||
|
"PrimsIotaModule_basic",
|
||||||
"PrimsSqueezeEmptyDimensionsModule_basic",
|
"PrimsSqueezeEmptyDimensionsModule_basic",
|
||||||
"PrimsViewOfModule_basic",
|
"PrimsViewOfModule_basic",
|
||||||
"PrimsViewOfZeroRankModule_basic",
|
"PrimsViewOfZeroRankModule_basic",
|
||||||
|
@ -1789,6 +1790,7 @@ TOSA_PASS_SET = {
|
||||||
"PermuteModule_basic",
|
"PermuteModule_basic",
|
||||||
"PermuteNegativeIndexModule_basic",
|
"PermuteNegativeIndexModule_basic",
|
||||||
"PrimListUnpackNumMismatchModule_basic",
|
"PrimListUnpackNumMismatchModule_basic",
|
||||||
|
"PrimsIotaModule_basic",
|
||||||
"PrimsSqueezeEmptyDimensionsModule_basic",
|
"PrimsSqueezeEmptyDimensionsModule_basic",
|
||||||
"PrimsSqueezeModule_basic",
|
"PrimsSqueezeModule_basic",
|
||||||
"PrimsViewOfModule_basic",
|
"PrimsViewOfModule_basic",
|
||||||
|
@ -2684,6 +2686,9 @@ ONNX_XFAIL_SET = {
|
||||||
"SqueezeModule_broadcast",
|
"SqueezeModule_broadcast",
|
||||||
"SqueezeModule_static",
|
"SqueezeModule_static",
|
||||||
|
|
||||||
|
# RuntimeError: unsupported input type: Device
|
||||||
|
"PrimsIotaModule_basic",
|
||||||
|
|
||||||
# Failure - unknown
|
# Failure - unknown
|
||||||
"BernoulliModule_basic",
|
"BernoulliModule_basic",
|
||||||
"BucketizeTensorFloatModule_basic",
|
"BucketizeTensorFloatModule_basic",
|
||||||
|
|
|
@ -1319,6 +1319,12 @@ def prims〇view_of〡dtype(a_rank_dtype: Tuple[int, int]) -> int:
|
||||||
_, a_dtype = a_rank_dtype
|
_, a_dtype = a_rank_dtype
|
||||||
return a_dtype
|
return a_dtype
|
||||||
|
|
||||||
|
def prims〇iota〡shape(length: int, start: int, step: int, dtype: int, device: device, requires_grad: bool) -> List[int]:
|
||||||
|
return [length]
|
||||||
|
|
||||||
|
def prims〇iota〡dtype(length: int, start: int, step: int, dtype: int, device: device, requires_grad: bool) -> int:
|
||||||
|
return dtype
|
||||||
|
|
||||||
def prim〇NumToTensor〇Scalar〡shape(a: float) -> List[int]:
|
def prim〇NumToTensor〇Scalar〡shape(a: float) -> List[int]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
|
@ -897,6 +897,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("prims::split_dim : (Tensor, int, int) -> (Tensor)")
|
emit("prims::split_dim : (Tensor, int, int) -> (Tensor)")
|
||||||
emit("prims::squeeze : (Tensor, int[]) -> (Tensor)")
|
emit("prims::squeeze : (Tensor, int[]) -> (Tensor)")
|
||||||
emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True)
|
emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True)
|
||||||
|
emit("prims::iota : (int, int, int, int, Device, bool) -> (Tensor)")
|
||||||
|
|
||||||
# ==========================================================================
|
# ==========================================================================
|
||||||
# `quantized::` namespace.
|
# `quantized::` namespace.
|
||||||
|
|
|
@ -380,3 +380,20 @@ class LinspaceTwoSizeModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: LinspaceTwoSizeModule())
|
@register_test_case(module_factory=lambda: LinspaceTwoSizeModule())
|
||||||
def LinspaceTwoSizeModule_basic(module, tu: TestUtils):
|
def LinspaceTwoSizeModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
class PrimsIotaModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.ops.prims.iota(77, start=0, step=1, dtype=torch.int64, device='cpu',
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: PrimsIotaModule())
|
||||||
|
def PrimsIotaModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
Loading…
Reference in New Issue