[MLIR][TORCH] Add E2E support for aten.empty_strided decomposition op (redo PR) (#2459)

This is a redo of #2457: I mistakenly thought the review was already complete and merged that PR prematurely (it has since been reverted).

Add a decomposition for the empty_strided op.
As discussed in #1776, the decomposition supports only default (contiguous) stride values: when a tensor is accessed or indexed, the indices are interpreted through its strides, and MLIR has no implicit support for arbitrary strides; it assumes default strides while iterating over a tensor.
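
For reference, the default (contiguous) stride rule the decomposition enforces can be sketched in a few lines of Python (the helper below is illustrative, not part of this patch): each dimension's stride must equal the product of all sizes to its right.

def default_strides(size):
    # Contiguous (row-major) strides: stride[i] == product of size[i+1:].
    strides, running = [], 1
    for dim in reversed(size):
        strides.insert(0, running)
        running *= dim
    return strides

assert default_strides([2, 3, 4]) == [12, 4, 1]  # the case used in the tests
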
Bruce Kim 2023-09-13 10:04:31 -07:00 committed by GitHub
parent 4b4c38da46
commit 40913a36c2
8 changed files with 128 additions and 1 deletion

@@ -748,6 +748,7 @@ STABLEHLO_PASS_SET = {
    "NewEmptyModuleNonDefaultFloatDtype_basic",
    "NewEmptyModuleNonDefaultIntDtype_basic",
    "NewEmptyStridedModuleDefaultDtype_basic",
    "EmptyStridedModule_basic",
    "PermuteModule_basic",
    "PermuteNegativeIndexModule_basic",
    "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
@@ -1421,4 +1422,5 @@ LTC_XFAIL_SET = {
    "ScatterValueIntModule_basic",
    "UniformStaticShapeModule_basic",
    "AtenEmbeddingBagStaticModule_basic",
    "EmptyStridedModule_basic",
}

@@ -8335,6 +8335,34 @@ def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
  }];
}

def Torch_AtenEmptyStridedOp : Torch_Op<"aten.empty_strided", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::empty_strided : (int[], int[], int?, int?, Device?, bool?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchListOfTorchIntType:$size,
    AnyTorchListOfTorchIntType:$stride,
    AnyTorchOptionalIntType:$dtype,
    AnyTorchOptionalIntType:$layout,
    AnyTorchOptionalDeviceType:$device,
    AnyTorchOptionalBoolType:$pin_memory
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenEmptyStridedOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 6, 1);
    }
    void AtenEmptyStridedOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 6, 1);
    }
  }];
}

def Torch_AtenExpandOp : Torch_Op<"aten.expand", [
    AllowsTypeRefinement,
    ReadOnly
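
As a quick sanity check of the registered signature aten::empty_strided : (int[], int[], int?, int?, Device?, bool?) -> (Tensor), the op can be invoked directly from Python (a usage sketch, not part of the patch):

import torch

# size and stride are positional; dtype/layout/device/pin_memory are optional.
t = torch.ops.aten.empty_strided([2, 3], [3, 1], dtype=torch.float32)
assert t.shape == torch.Size([2, 3]) and t.stride() == (3, 1)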

@@ -7220,6 +7220,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.empty.memory_format\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.empty_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.full\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
@@ -10533,6 +10536,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.empty_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" } else {\n"
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %2 : !torch.int\n"
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.full_like\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>, %arg6: !torch.optional<int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

@@ -4416,6 +4416,53 @@ public:
};
} // namespace
namespace {
class DecomposeAtenEmptyStridedOp
    : public OpRewritePattern<AtenEmptyStridedOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenEmptyStridedOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t> sizeListInts, strideListInts;
    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
      return rewriter.notifyMatchFailure(
          op, "all size list elements must be constant ints");
    if (!matchPattern(op.getStride(),
                      m_TorchListOfConstantInts(strideListInts)))
      return rewriter.notifyMatchFailure(
          op, "all stride list elements must be constant ints");

    // We only support cases with default stride values.
    // For example: aten.empty_strided(size=[2, 3, 4], stride=[12, 4, 1]),
    // where stride[0] == size[1] * size[2], stride[1] == size[2], and
    // stride[2] == 1.
    bool isDefaultStride = true;
    for (unsigned i = 0; i < strideListInts.size(); i++) {
      int64_t defaultStride = 1;
      for (unsigned j = i + 1; j < sizeListInts.size(); j++)
        defaultStride *= sizeListInts[j];
      if (defaultStride != strideListInts[i]) {
        isDefaultStride = false;
        break;
      }
    }
    if (!isDefaultStride)
      return rewriter.notifyMatchFailure(
          op, "only default strides supported for empty_strided op");

    Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
    rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
        op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(),
        op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal);
    return success();
  }
};
} // namespace
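
The rewrite replaces the op with aten.empty.memory_format, which is sound because, with default strides, empty_strided allocates the same contiguous tensor a plain empty would (a Python-level analogy, not the actual IR):

import torch

x = torch.ops.aten.empty_strided([2, 3, 4], [12, 4, 1])
y = torch.empty(2, 3, 4)
assert x.stride() == y.stride() == (12, 4, 1)  # both contiguous
assert x.shape == y.shape
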
namespace {
class DecomposePrimsSqueezeOp : public OpRewritePattern<PrimsSqueezeOp> {
public:
@@ -5251,6 +5298,7 @@ public:
    addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyStridedOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposePrimsSqueezeOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);

@@ -480,6 +480,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
  target.addIllegalOp<AtenRandnLikeOp>();
  target.addIllegalOp<AtenVarMeanOp>();
  target.addIllegalOp<AtenNewEmptyStridedOp>();
  target.addIllegalOp<AtenEmptyStridedOp>();
  target.addIllegalOp<AtenBucketizeTensorOp>();
  target.addIllegalOp<PrimsSqueezeOp>();
  target.addIllegalOp<AtenMovedimIntOp>();

@@ -643,7 +643,8 @@ def aten〇ones〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
def aten〇empty〇memory_format〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
    return size

def aten〇empty_strided〡shape(size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
    return size

def aten〇full〡shape(size: List[int], fill_value: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
    return size
@@ -3237,6 +3238,13 @@ def aten〇empty_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype if dtype is None else dtype

@check_dtype_function(
    _check_tensors_with_the_same_dtype(num_of_tensors=0, size=[1], stride=[1]) +
    _check_tensors_with_the_same_dtype(num_of_tensors=0, size=[1], stride=[1], dtype=torch.float16) +
    _check_tensors_with_the_same_dtype(num_of_tensors=0, size=[1], stride=[1], dtype=torch.int32) +
    _check_tensors_with_the_same_dtype(num_of_tensors=0, size=[1], stride=[1], dtype=torch.complex64))
def aten〇empty_strided〡dtype(size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
    return torch.float32 if dtype is None else dtype

@check_dtype_function(
    _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0) +
    _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0) +

@@ -542,6 +542,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
    emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
    emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
    emit("aten::empty_strided : (int[], int[], int?, int?, Device?, bool?) -> (Tensor)")
    emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
    emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
    emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_folder=True)

@@ -1628,3 +1628,27 @@ class NewEmptyStridedModuleDefaultDtype(torch.nn.Module):
@register_test_case(module_factory=lambda: NewEmptyStridedModuleDefaultDtype())
def NewEmptyStridedModuleDefaultDtype_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))

# ==============================================================================

class EmptyStridedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 3, 4], torch.float32, True),
    ])
    def forward(self, a):
        x = torch.ops.aten.empty_strided(a.size(), stride=[12, 4, 1])
        y = x.copy_(a)
        return y

@register_test_case(module_factory=lambda: EmptyStridedModule())
def EmptyStridedModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))
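
The copy_ in the test above is deliberate: empty_strided returns uninitialized memory whose contents cannot be compared across backends, so the test overwrites it with deterministic data while still exercising the decomposition. The same pattern in eager PyTorch (a sketch, not from the patch):

import torch

a = torch.rand(2, 3, 4)
x = torch.ops.aten.empty_strided(a.size(), stride=[12, 4, 1])  # uninitialized
assert x.stride() == (12, 4, 1)  # default (contiguous) strides for [2, 3, 4]
y = x.copy_(a)                   # overwrite with deterministic values
assert torch.equal(y, a)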