mirror of https://github.com/llvm/torch-mlir
[Torch-Dialect] emit aten.narrow.Tensor op and decompose it to aten.narrow op (#2297)
parent 64d7626a52
commit 9535be7903
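For orientation, here is a minimal Python sketch (not part of the patch) of the op this commit wires up: `aten::narrow.Tensor` is `torch.narrow` called with a 0-dim integral tensor as `start`, and it selects the same slice as the scalar-start `aten::narrow`.

```python
# Sketch only: torch.narrow accepts `start` either as an int (aten::narrow)
# or as a 0-dim integral tensor (aten::narrow.Tensor); both select the same
# `length` elements along `dim`.
import torch

x = torch.arange(24, dtype=torch.float32).reshape(6, 4)

tensor_start = torch.narrow(x, dim=1, start=torch.tensor(1), length=2)  # narrow.Tensor
scalar_start = torch.narrow(x, dim=1, start=1, length=2)                # narrow

assert torch.equal(tensor_start, scalar_start)
```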
@@ -314,6 +314,7 @@ TORCHDYNAMO_CRASHING_SET = {
STABLEHLO_PASS_SET = {
    "AliasModule_basic",
    "TensorIntModule_basic",
    "AllBoolFalseModule_basic",
    "AllBoolTrueModule_basic",
    "AnyBoolFalseModule_basic",
@@ -751,6 +752,8 @@ STABLEHLO_PASS_SET = {
    "NarrowHorizontalTest_basic",
    "NarrowVerticalTest2_basic",
    "NarrowVerticalTest_basic",
    "NarrowTensorHorizontalModule_basic",
    "NarrowTensorVerticalModule_basic",
    "NumToTensorIntModule_basic",
    "NumpyTRank0Module_basic",
    "NumpyTRank1Module_basic",
@@ -11314,6 +11314,31 @@ def Torch_AtenNarrowOp : Torch_Op<"aten.narrow", [
  }];
}

def Torch_AtenNarrowTensorOp : Torch_Op<"aten.narrow.Tensor", [
    AllowsTypeRefinement,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::narrow.Tensor : (Tensor, int, Tensor, int) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    Torch_IntType:$dim,
    AnyTorchTensorType:$start,
    Torch_IntType:$length
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenNarrowTensorOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 4, 1);
    }
    void AtenNarrowTensorOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 4, 1);
    }
  }];
}

def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [
    AllowsTypeRefinement,
    HasValueSemantics,
@@ -762,6 +762,22 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
  return success();
}

// AtenTensorIntOp
template <>
LogicalResult ConvertAtenOp<AtenTensorIntOp>::matchAndRewrite(
    AtenTensorIntOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  RankedTensorType resultType = getTypeConverter()
                                    ->convertType(op->getResult(0).getType())
                                    .cast<RankedTensorType>();
  Type outElementType = resultType.getElementType();
  Value innerValue = adaptor.getT();
  Value stablehloTensor =
      hlo::scalarToStablehloTensor(rewriter, op, innerValue, outElementType);
  rewriter.replaceOp(op, stablehloTensor);
  return success();
}

// AtenReciprocalOp
// Reciprocal(x) = Div(1, x)
template <>
@@ -1699,6 +1715,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
  INSERT_ATENOP_PATTERN(AtenPermuteOp);

  INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
  INSERT_ATENOP_PATTERN(AtenTensorIntOp);
  INSERT_ATENOP_PATTERN(AtenReciprocalOp);
  INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
  INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
@@ -156,6 +156,8 @@ static Value getScalarIntValue(Value input, Location loc,
  } else if (auto primNumToTensorScalarOp =
                 input.getDefiningOp<PrimNumToTensorScalarOp>()) {
    return primNumToTensorScalarOp.getA();
  } else if (auto tensorIntOp = input.getDefiningOp<AtenTensorIntOp>()) {
    return tensorIntOp.getT();
  }
  return nullptr;
}
@@ -2557,6 +2559,8 @@ OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
  // aten.Int.Tensor, fold to the scalar number.
  if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
    return numToTensorScalar.getA();
  if (auto tensorIntOp = getA().getDefiningOp<AtenTensorIntOp>())
    return tensorIntOp.getT();
  return nullptr;
}
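These two hunks teach `getScalarIntValue` and the `aten.Int.Tensor` folder to look through `aten.tensor.int`, so a scalar that was wrapped into a 0-dim tensor can be read back directly. A Python-level restatement of that round trip, just to illustrate the semantics:

```python
import torch

n = 5
t = torch.tensor(n)   # aten.tensor.int: wraps the Python int in a 0-dim tensor
assert int(t) == n    # aten.Int.Tensor on that value folds back to the scalar n
```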
@@ -7543,6 +7543,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    %3 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %1, %2, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
"    return %3 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.narrow.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.int) -> !torch.list<int> {\n"
"    %0 = torch.aten._set_item.t %arg0, %arg1, %arg3 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
"    return %arg0 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.slice_scatter\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.int) -> !torch.list<int> {\n"
"    return %arg0 : !torch.list<int>\n"
"  }\n"
@@ -8430,6 +8434,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
"    return %0#1 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.narrow.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.int) -> !torch.int {\n"
"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
"    return %0#1 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.neg\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
"    %none = torch.constant.none\n"
"    %str = torch.constant.str \"AssertionError: \"\n"
@@ -340,6 +340,27 @@ public:
};
} // namespace

namespace {
// Decompose `aten.narrow.Tensor` to `aten.narrow` op
class DecomposeAtenNarrowTensorOp
    : public OpRewritePattern<AtenNarrowTensorOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenNarrowTensorOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto *context = op.getContext();
    // PyTorch makes sure that the `start` param is a 0-dim integral tensor.
    // REF: https://pytorch.org/docs/stable/generated/torch.narrow.html.
    auto start = rewriter.create<Torch::AtenScalarImplicitOp>(
        loc, Torch::IntType::get(context), op.getStart());
    rewriter.replaceOpWithNewOp<Torch::AtenNarrowOp>(
        op, op.getType(), op.getSelf(), op.getDim(), start, op.getLength());
    return success();
  }
};
} // namespace

namespace {
class DecomposeAtenZeroOp
    : public OpRewritePattern<AtenZeroOp> {
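The `DecomposeAtenNarrowTensorOp` pattern above rewrites `aten.narrow.Tensor` into `aten.ScalarImplicit` followed by the scalar-start `aten.narrow`. A rough Python-level sketch of the same rewrite (the helper name is made up for illustration, not part of the patch):

```python
import torch

def narrow_tensor_decomposed(self: torch.Tensor, dim: int,
                             start: torch.Tensor, length: int) -> torch.Tensor:
    # aten.ScalarImplicit: pull the int out of the 0-dim `start` tensor.
    start_int = int(start)
    # aten.narrow with a scalar start.
    return torch.narrow(self, dim, start_int, length)

x = torch.rand(6, 4)
assert torch.equal(narrow_tensor_decomposed(x, 1, torch.tensor(1), 2),
                   torch.narrow(x, 1, torch.tensor(1), 2))
```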
@@ -4753,6 +4774,7 @@ public:
    addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowTensorOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAten_EmbeddingBagOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLiftFreshCopyOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorHackedTwinOp>(
@@ -459,6 +459,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
  target.addIllegalOp<AtenStdDimOp>();
  target.addIllegalOp<AtenStdCorrectionOp>();
  target.addIllegalOp<AtenNarrowOp>();
  target.addIllegalOp<AtenNarrowTensorOp>();
  target.addIllegalOp<Aten_EmbeddingBagOp>();
  target.addIllegalOp<AtenLiftFreshCopyOp>();
  target.addIllegalOp<AtenIndexTensorHackedTwinOp>();
@@ -204,8 +204,9 @@ bool Torch::isViewLikeOp(Operation *op) {
      AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
      AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
      TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
-     AtenNarrowOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp,
-     PrimsViewOfOp, AtenRealOp, AtenImagOp, AtenViewAsComplexOp>(op);
+     AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
+     AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
+     AtenViewAsComplexOp>(op);
}

Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
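`AtenNarrowTensorOp` joins the view-like op list above because, like `aten.narrow`, its result aliases the input's storage rather than copying it. A quick sketch of the aliasing this flag describes:

```python
import torch

x = torch.zeros(6, 4)
v = torch.narrow(x, 1, torch.tensor(0), 2)  # narrow.Tensor result
v[0, 0] = 42.0
assert x[0, 0].item() == 42.0  # the narrowed result is a view of x, not a copy
```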
@@ -900,6 +900,11 @@ def aten〇sort〡dtype(self_rank_dtype: Tuple[int, int], dim: int = -1, descend
def aten〇narrow〡shape(self: List[int], dim: int, start: int, length: int) -> List[int]:
    return upstream_shape_functions.slice(self, dim, start, start + length, 1)

# This shape function is a little hacky because we don't know the start index, which is determined by a tensor param.
def aten〇narrow〇Tensor〡shape(self: List[int], dim: int, start: List[int], length: int) -> List[int]:
    self[dim] = length
    return self

def aten〇slice_scatter〡shape(self: List[int], src: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
    return self
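A quick check of what the `aten〇narrow〇Tensor〡shape` function above encodes: the result shape is the input shape with the size at `dim` replaced by `length`; the tensor-valued `start` (unknown at shape-inference time) does not affect the shape.

```python
import torch

x = torch.rand(6, 4)
y = torch.narrow(x, dim=1, start=torch.tensor(2), length=2)
assert list(y.shape) == [6, 2]  # only `length` matters for the result shape
```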
@@ -1659,6 +1664,11 @@ def aten〇narrow〡dtype(self_rank_dtype: Tuple[int, int], dim: int, start: int
    self_rank, self_dtype = self_rank_dtype
    return self_dtype

@check_dtype_function([Invocation(TensorOfShape(3, 4, dtype=dtype, device=torch.device("cpu")), 0, ZeroDTensorWithDtype(dtype=torch.int64, device=torch.device("cpu")), 1) for dtype in _SORTED_TORCH_TYPES])
def aten〇narrow〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], dim: int, start_rank_dtype: Tuple[int, int], length: int) -> int:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
def aten〇neg〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
@@ -668,6 +668,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit("aten::eq.device : (Device, Device) -> (bool)")
    emit("aten::ceil.float : (float) -> (int)", has_folder=True)
    emit("aten::narrow : (Tensor, int, int, int) -> (Tensor)")
    emit("aten::narrow.Tensor : (Tensor, int, Tensor, int) -> (Tensor)")
    emit("aten::ScalarImplicit : (Tensor) -> (Scalar)", has_canonicalizer=True)

    # backprop ops
@@ -523,6 +523,42 @@ def NarrowVerticalTest2_basic(module, tu: TestUtils):

# ==============================================================================

class NarrowTensorHorizontalModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True)
    ])
    def forward(self, x):
        return torch.narrow(x, dim=1, start=torch.tensor(0), length=2)

@register_test_case(module_factory=lambda: NarrowTensorHorizontalModule())
def NarrowTensorHorizontalModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(6, 4))

# ==============================================================================

class NarrowTensorVerticalModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True)
    ])
    def forward(self, x):
        return torch.narrow(x, dim=1, start=torch.tensor(1), length=2)

@register_test_case(module_factory=lambda: NarrowTensorVerticalModule())
def NarrowTensorVerticalModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(6, 4))

# ==============================================================================

class SliceCopy_Module(torch.nn.Module):
    def __init__(self):
        super().__init__()