[Torch-Dialect] emit aten.narrow.Tensor op and decompose it to aten.narrow op (#2297)

pull/2326/head snapshot-20230720.905
Jiawei Wu 2023-07-20 16:46:44 +08:00 committed by GitHub
parent 64d7626a52
commit 9535be7903
11 changed files with 130 additions and 2 deletions

e2e_testing/xfail_sets.py

@@ -314,6 +314,7 @@ TORCHDYNAMO_CRASHING_SET = {
STABLEHLO_PASS_SET = {
"AliasModule_basic",
"TensorIntModule_basic",
"AllBoolFalseModule_basic",
"AllBoolTrueModule_basic",
"AnyBoolFalseModule_basic",
@@ -751,6 +752,8 @@ STABLEHLO_PASS_SET = {
"NarrowHorizontalTest_basic",
"NarrowVerticalTest2_basic",
"NarrowVerticalTest_basic",
"NarrowTensorHorizontalModule_basic",
"NarrowTensorVerticalModule_basic",
"NumToTensorIntModule_basic",
"NumpyTRank0Module_basic",
"NumpyTRank1Module_basic",

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

@@ -11314,6 +11314,31 @@ def Torch_AtenNarrowOp : Torch_Op<"aten.narrow", [
}];
}
def Torch_AtenNarrowTensorOp : Torch_Op<"aten.narrow.Tensor", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::narrow.Tensor : (Tensor, int, Tensor, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$start,
Torch_IntType:$length
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNarrowTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenNarrowTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
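For orientation, this is the PyTorch-level behavior the new op mirrors: the `start` argument may be a 0-dim integral tensor instead of a plain int. A quick eager-mode illustration (not part of the commit):

import torch

x = torch.rand(6, 4)
# `start` as a 0-dim tensor; the result is the same as x[:, 0:2].
y = torch.narrow(x, 1, torch.tensor(0), 2)
assert y.shape == (6, 2)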
def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [
AllowsTypeRefinement,
HasValueSemantics,

lib/Conversion/TorchToStablehlo/Basic.cpp

@@ -762,6 +762,22 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
return success();
}
// AtenTensorIntOp: lower the wrapped scalar int to a 0-dim StableHLO tensor
// of the result element type.
template <>
LogicalResult ConvertAtenOp<AtenTensorIntOp>::matchAndRewrite(
AtenTensorIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
Type outElementType = resultType.getElementType();
Value innerValue = adaptor.getT();
Value stablehloTensor =
hlo::scalarToStablehloTensor(rewriter, op, innerValue, outElementType);
rewriter.replaceOp(op, stablehloTensor);
return success();
}
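In eager terms, `aten.tensor.int` wraps a Python int into a 0-dim tensor, and the pattern above materializes that scalar as a StableHLO tensor of the converted element type. A rough Python analogue of the semantics (the helper name is made up for illustration):

import torch

def tensor_int_analogue(value: int, dtype: torch.dtype = torch.int64) -> torch.Tensor:
    # Semantically what scalarToStablehloTensor produces: a 0-dim tensor holding `value`.
    return torch.tensor(value, dtype=dtype)

t = tensor_int_analogue(5)
assert t.dim() == 0 and int(t) == 5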
// AtenReciprocalOp
// Reciprocal(x) = Div(1, x)
template <>
@@ -1699,6 +1715,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenPermuteOp);
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
INSERT_ATENOP_PATTERN(AtenTensorIntOp);
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);

lib/Dialect/Torch/IR/TorchOps.cpp

@@ -156,6 +156,8 @@ static Value getScalarIntValue(Value input, Location loc,
} else if (auto primNumToTensorScalarOp =
input.getDefiningOp<PrimNumToTensorScalarOp>()) {
return primNumToTensorScalarOp.getA();
} else if (auto tensorIntOp = input.getDefiningOp<AtenTensorIntOp>()) {
return tensorIntOp.getT();
}
return nullptr;
}
@@ -2557,6 +2559,8 @@ OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
// aten.Int.Tensor: fold back to the scalar number the tensor wraps.
if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
return numToTensorScalar.getA();
if (auto tensorIntOp = getA().getDefiningOp<AtenTensorIntOp>())
return tensorIntOp.getT();
return nullptr;
}
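Both additions in this file rely on the same identity: wrapping an int into a 0-dim tensor and immediately reading it back is a no-op, so `aten.Int.Tensor(aten.tensor.int(%x))` folds to `%x`. The eager-mode round trip being eliminated (illustration only):

import torch

x = 7
assert int(torch.tensor(x)) == x  # int -> 0-dim tensor -> int is the identity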

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

@@ -7543,6 +7543,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %1, %2, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
" return %3 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.narrow.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.int) -> !torch.list<int> {\n"
" %0 = torch.aten._set_item.t %arg0, %arg1, %arg3 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.slice_scatter\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.int) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
@@ -8430,6 +8434,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.narrow.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.neg\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

@@ -340,6 +340,27 @@ public:
};
} // namespace
namespace {
// Decompose `aten.narrow.Tensor` to `aten.narrow` op
class DecomposeAtenNarrowTensorOp
: public OpRewritePattern<AtenNarrowTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNarrowTensorOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto *context = op.getContext();
// PyTorch guarantees that the `start` param is a 0-dim integral tensor.
// REF: https://pytorch.org/docs/stable/generated/torch.narrow.html
auto start = rewriter.create<Torch::AtenScalarImplicitOp>(
loc, Torch::IntType::get(context), op.getStart());
rewriter.replaceOpWithNewOp<Torch::AtenNarrowOp>(
op, op.getType(), op.getSelf(), op.getDim(), start, op.getLength());
return success();
}
};
} // namespace
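A minimal eager-mode sketch of what this decomposition does, assuming the documented guarantee that `start` is a 0-dim integral tensor; the function name is hypothetical, for illustration only:

import torch

def narrow_tensor_decomposed(self: torch.Tensor, dim: int, start: torch.Tensor, length: int) -> torch.Tensor:
    # Analogue of AtenScalarImplicitOp: read the scalar out of the 0-dim tensor.
    start_int = int(start.item())
    # The plain `aten.narrow` overload then applies directly.
    return torch.narrow(self, dim, start_int, length)

x = torch.rand(6, 4)
assert torch.equal(narrow_tensor_decomposed(x, 1, torch.tensor(1), 2), x[:, 1:3])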
namespace {
class DecomposeAtenZeroOp
: public OpRewritePattern<AtenZeroOp> {
@@ -4753,6 +4774,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_EmbeddingBagOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLiftFreshCopyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorHackedTwinOp>(

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

@@ -459,6 +459,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenStdDimOp>();
target.addIllegalOp<AtenStdCorrectionOp>();
target.addIllegalOp<AtenNarrowOp>();
target.addIllegalOp<AtenNarrowTensorOp>();
target.addIllegalOp<Aten_EmbeddingBagOp>();
target.addIllegalOp<AtenLiftFreshCopyOp>();
target.addIllegalOp<AtenIndexTensorHackedTwinOp>();

lib/Dialect/Torch/Utils/Utils.cpp

@@ -204,8 +204,9 @@ bool Torch::isViewLikeOp(Operation *op) {
AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
-  AtenNarrowOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp,
-  PrimsViewOfOp, AtenRealOp, AtenImagOp, AtenViewAsComplexOp>(op);
+  AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
+  AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
+  AtenViewAsComplexOp>(op);
}
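Both `narrow` overloads return a view that aliases the input's storage rather than a copy, which is why `AtenNarrowTensorOp` joins this list. A quick eager-mode check (illustrative):

import torch

x = torch.zeros(6, 4)
v = torch.narrow(x, 1, 0, 2)
v[0, 0] = 1.0
assert x[0, 0].item() == 1.0  # the narrow result aliases x, no copy is made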
Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,

abstract_interp_lib_gen.py

@@ -900,6 +900,11 @@ def aten〇sort〡dtype(self_rank_dtype: Tuple[int, int], dim: int = -1, descend
def aten〇narrow〡shape(self: List[int], dim: int, start: int, length: int) -> List[int]:
return upstream_shape_functions.slice(self, dim, start, start + length, 1)
# This shape function is a little hacky: the start index is carried by a runtime
# tensor param, so it is unknown here; only `length` determines the output size at `dim`.
def aten〇narrow〇Tensor〡shape(self: List[int], dim: int, start: List[int], length: int) -> List[int]:
self[dim] = length
return self
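Illustration of the rule above (assumed example values, not part of the generated library): only `length` determines the result size at `dim`; the runtime value of the 0-dim `start` tensor cannot influence the inferred shape:

shape, dim, length = [6, 4], 1, 2
shape[dim] = length
assert shape == [6, 2]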
def aten〇slice_scatter〡shape(self: List[int], src: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
return self
@@ -1659,6 +1664,11 @@ def aten〇narrow〡dtype(self_rank_dtype: Tuple[int, int], dim: int, start: int
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function([Invocation(TensorOfShape(3, 4, dtype=dtype, device=torch.device("cpu")), 0, ZeroDTensorWithDtype(dtype=torch.int64, device=torch.device("cpu")), 1) for dtype in _SORTED_TORCH_TYPES])
def aten〇narrow〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], dim: int, start_rank_dtype: Tuple[int, int], length: int) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
def aten〇neg〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype

torch_ods_gen.py

@@ -668,6 +668,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::eq.device : (Device, Device) -> (bool)")
emit("aten::ceil.float : (float) -> (int)", has_folder=True)
emit("aten::narrow : (Tensor, int, int, int) -> (Tensor)")
emit("aten::narrow.Tensor : (Tensor, int, Tensor, int) -> (Tensor)")
emit("aten::ScalarImplicit : (Tensor) -> (Scalar)", has_canonicalizer=True)
# backprop ops

python/torch_mlir_e2e_test/test_suite/slice_like.py

@@ -523,6 +523,42 @@ def NarrowVerticalTest2_basic(module, tu: TestUtils):
# ==============================================================================
class NarrowTensorHorizontalModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True)
])
def forward(self, x):
return torch.narrow(x, dim=1, start=torch.tensor(0), length=2)
@register_test_case(module_factory=lambda: NarrowTensorHorizontalModule())
def NarrowTensorHorizontalModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6,4))
# ==============================================================================
class NarrowTensorVerticalModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True)
])
def forward(self, x):
return torch.narrow(x, dim=1, start=torch.tensor(1), length=2)
@register_test_case(module_factory=lambda: NarrowTensorVerticalModule())
def NarrowTensorVerticalModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6,4))
# ==============================================================================
class SliceCopy_Module(torch.nn.Module):
def __init__(self):
super().__init__()