mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.empty_strided decomposition op (redo PR) (#2459)
Making the same PR with #2457, as I accidentally thought the review was already made and merged it (reverted). Add decompose empty_strided op. Referring to #1776, this decomposition op only supports default stride values, because accessing the tensor or indexing over that, the indices are determined by the strides. In MLIR, this is not implicitly supported but assumes that the strides are default while iterating over the tensor.sogartar-patch-1 snapshot-20230914.961
parent
4b4c38da46
commit
40913a36c2
|
@ -748,6 +748,7 @@ STABLEHLO_PASS_SET = {
|
|||
"NewEmptyModuleNonDefaultFloatDtype_basic",
|
||||
"NewEmptyModuleNonDefaultIntDtype_basic",
|
||||
"NewEmptyStridedModuleDefaultDtype_basic",
|
||||
"EmptyStridedModule_basic",
|
||||
"PermuteModule_basic",
|
||||
"PermuteNegativeIndexModule_basic",
|
||||
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
|
||||
|
@ -1421,4 +1422,5 @@ LTC_XFAIL_SET = {
|
|||
"ScatterValueIntModule_basic",
|
||||
"UniformStaticShapeModule_basic",
|
||||
"AtenEmbeddingBagStaticModule_basic",
|
||||
"EmptyStridedModule_basic",
|
||||
}
|
||||
|
|
|
@ -8335,6 +8335,34 @@ def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenEmptyStridedOp : Torch_Op<"aten.empty_strided", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::empty_strided : (int[], int[], int?, int?, Device?, bool?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchListOfTorchIntType:$size,
|
||||
AnyTorchListOfTorchIntType:$stride,
|
||||
AnyTorchOptionalIntType:$dtype,
|
||||
AnyTorchOptionalIntType:$layout,
|
||||
AnyTorchOptionalDeviceType:$device,
|
||||
AnyTorchOptionalBoolType:$pin_memory
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenEmptyStridedOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||
}
|
||||
void AtenEmptyStridedOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 6, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenExpandOp : Torch_Op<"aten.expand", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
|
|
|
@ -7220,6 +7220,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" func.func @\"__torch_mlir_shape_fn.aten.empty.memory_format\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.empty_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.full\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
|
@ -10533,6 +10536,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.empty_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
|
||||
" torch.prim.If.yield %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.full_like\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>, %arg6: !torch.optional<int>) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
|
|
|
@ -4416,6 +4416,53 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenEmptyStridedOp
|
||||
: public OpRewritePattern<AtenEmptyStridedOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenEmptyStridedOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<int64_t> sizeListInts, strideListInts;
|
||||
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "all size list elements must be constant ints");
|
||||
if (!matchPattern(op.getStride(),
|
||||
m_TorchListOfConstantInts(strideListInts)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "all stride list elements must be constant ints");
|
||||
|
||||
// We only support the cases with default stride values.
|
||||
// For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
|
||||
// Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
|
||||
// stride[2] == 1.
|
||||
bool isDefaultStride = true;
|
||||
for (unsigned i = 0; i < strideListInts.size(); i++) {
|
||||
int64_t defaultStride = 1;
|
||||
for (unsigned j = i + 1; j < sizeListInts.size(); j++)
|
||||
defaultStride *= sizeListInts[j];
|
||||
if (defaultStride != strideListInts[i]) {
|
||||
isDefaultStride = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!isDefaultStride)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only default strides supported for new_empty_strided op");
|
||||
|
||||
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
||||
|
||||
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
|
||||
op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(), op.getDevice(),
|
||||
op.getPinMemory(), /*memoryFormat=*/noneVal);
|
||||
|
||||
return success();
|
||||
|
||||
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposePrimsSqueezeOp : public OpRewritePattern<PrimsSqueezeOp> {
|
||||
public:
|
||||
|
@ -5251,6 +5298,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyStridedOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposePrimsSqueezeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
|
||||
|
|
|
@ -480,6 +480,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenRandnLikeOp>();
|
||||
target.addIllegalOp<AtenVarMeanOp>();
|
||||
target.addIllegalOp<AtenNewEmptyStridedOp>();
|
||||
target.addIllegalOp<AtenEmptyStridedOp>();
|
||||
target.addIllegalOp<AtenBucketizeTensorOp>();
|
||||
target.addIllegalOp<PrimsSqueezeOp>();
|
||||
target.addIllegalOp<AtenMovedimIntOp>();
|
||||
|
|
|
@ -643,7 +643,8 @@ def aten〇ones〡shape(size: List[int], dtype: Optional[int] = None, layout: Op
|
|||
|
||||
def aten〇empty〇memory_format〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
|
||||
return size
|
||||
|
||||
def aten〇empty_strided〡shape(size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||
return size
|
||||
def aten〇full〡shape(size: List[int], fill_value: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||
return size
|
||||
|
||||
|
@ -3237,6 +3238,13 @@ def aten〇empty_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[
|
|||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype if dtype is None else dtype
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=0, size=[1], stride=[1]) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=0, size=[1], stride=[1], dtype=torch.float16) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=0, size=[1], stride=[1], dtype=torch.int32) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=0, size=[1], stride=[1], dtype=torch.complex64))
|
||||
def aten〇empty_strided〡dtype(size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
return torch.float32 if dtype is None else dtype
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0) +
|
||||
|
|
|
@ -542,6 +542,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::empty_strided : (int[], int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
|
||||
emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_folder=True)
|
||||
|
|
|
@ -1628,3 +1628,27 @@ class NewEmptyStridedModuleDefaultDtype(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: NewEmptyStridedModuleDefaultDtype())
|
||||
def NewEmptyStridedModuleDefaultDtype_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class EmptyStridedModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([2, 3, 4], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
x = torch.ops.aten.empty_strided(a.size(), stride=[12, 4, 1])
|
||||
y = x.copy_(a)
|
||||
return y
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: EmptyStridedModule())
|
||||
def EmptyStridedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
|
Loading…
Reference in New Issue