mirror of https://github.com/llvm/torch-mlir
This commit decomposes aten._reshape_alias op into aten.view op. (#690)
parent
eecbf0bab6
commit
25ba51b2af
|
@ -4549,6 +4549,30 @@ def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_Aten_ReshapeAliasOp : Torch_Op<"aten._reshape_alias", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$size,
|
||||
TorchIntListType:$stride
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult Aten_ReshapeAliasOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void Aten_ReshapeAliasOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenResize_Op : Torch_Op<"aten.resize_", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
|
|
|
@ -1418,6 +1418,23 @@ class DecomposeAten_UnsafeViewOp : public OpRewritePattern<Aten_UnsafeViewOp> {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// In PyTorch, _reshape_alias just uses an already computed stride.
|
||||
// See
|
||||
// https://github.com/pytorch/pytorch/blob/d8c31a819d4a65e732b5901e3b994e1869851f1a/aten/src/ATen/native/TensorShape.cpp#L1153
|
||||
// Note that this is the same decomposition as in AOTAutograd
|
||||
// https://github.com/pytorch/functorch/blob/a3042d94e616d4143813668b1372d9d4545be14e/functorch/_src/aot_autograd.py#L104
|
||||
namespace {
|
||||
class DecomposeAten_ReshapeAliasOp : public OpRewritePattern<Aten_ReshapeAliasOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(Aten_ReshapeAliasOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), op.self(),
|
||||
op.size());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose constant tensor like ops.
|
||||
template <typename OpTy, typename NewOpTy>
|
||||
|
@ -1596,6 +1613,8 @@ class DecomposeComplexOpsPass
|
|||
target.addIllegalOp<AtenStdOp>();
|
||||
patterns.add<DecomposeAten_UnsafeViewOp>(context);
|
||||
target.addIllegalOp<Aten_UnsafeViewOp>();
|
||||
patterns.add<DecomposeAten_ReshapeAliasOp>(context);
|
||||
target.addIllegalOp<Aten_ReshapeAliasOp>();
|
||||
patterns.add<DecomposeAtenBernoulliOp>(context);
|
||||
target.addIllegalOp<AtenBernoulliOp>();
|
||||
patterns.add<DecomposeValsemVariantAtenBernoulliFloatOp>(context);
|
||||
|
|
|
@ -33,7 +33,7 @@ static bool isViewLikeOp(Operation *op) {
|
|||
// that it does not return a view and treat those as having value
|
||||
// semantics.
|
||||
return isa<AtenBroadcastToOp, AtenContiguousOp, AtenExpandAsOp, AtenExpandOp,
|
||||
AtenFlattenUsingIntsOp, AtenPermuteOp, AtenReshapeOp,
|
||||
AtenFlattenUsingIntsOp, AtenPermuteOp, AtenReshapeOp, Aten_ReshapeAliasOp,
|
||||
AtenSelectIntOp, AtenSliceTensorOp, AtenSqueezeDimOp,
|
||||
AtenSqueezeOp, AtenTOp, AtenToDtypeOp, AtenTransposeIntOp,
|
||||
AtenUnsqueezeOp, AtenViewOp, TensorStaticInfoCastOp>(op);
|
||||
|
|
|
@ -507,13 +507,14 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulli_TensorOp,
|
||||
ValsemVariantAtenBernoulliFloatOp, ValsemVariantAtenBernoulliTensorOp,
|
||||
ValsemVariantAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
|
||||
AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp,
|
||||
AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp,
|
||||
AtenMaxPool2dOp, AtenAdaptiveAvgPool2dOp, AtenFlattenUsingIntsOp,
|
||||
AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp,
|
||||
Aten_UnsafeViewOp, AtenReshapeOp, AtenResize_Op, AtenTransposeIntOp,
|
||||
AtenTOp, AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp,
|
||||
AtenSliceTensorOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
|
||||
AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenZero_Op,
|
||||
Aten_UnsafeViewOp, AtenReshapeOp, Aten_ReshapeAliasOp,
|
||||
AtenResize_Op, AtenTransposeIntOp, AtenTOp, AtenPermuteOp,
|
||||
AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp,
|
||||
AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp,
|
||||
AtenConstantPadNdOp, AtenZero_Op,
|
||||
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
|
||||
ValsemVariantAtenCopyOp, ValsemVariantAtenZeroOp>(op)) {
|
||||
ValueKnowledge knowledge =
|
||||
|
|
|
@ -1205,6 +1205,10 @@ module {
|
|||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.view(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten._reshape_alias"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.view(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten._unsafe_view"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||
return %arg1 : !torch.list<int>
|
||||
}
|
||||
|
|
|
@ -551,6 +551,9 @@ def aten〇view(self: List[int], size: List[int]) -> List[int]:
|
|||
def aten〇reshape(self: List[int], shape: List[int]) -> List[int]:
|
||||
return upstream_shape_helpers.view(self, shape)
|
||||
|
||||
def aten〇_reshape_alias(self: List[int], size: List[int], stride: List[int]) -> List[int]:
|
||||
return upstream_shape_helpers.view(self, size)
|
||||
|
||||
def aten〇_unsafe_view(self: List[int], size: List[int]) -> List[int]:
|
||||
return size
|
||||
|
||||
|
|
|
@ -397,6 +397,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::numel : (Tensor) -> (int)")
|
||||
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
|
||||
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
|
||||
emit("aten::select.int : (Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
|
||||
|
|
|
@ -388,3 +388,44 @@ class ViewNoChangeStaticModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ViewNoChangeStaticModule())
|
||||
def ViewNoChangeStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5, 6))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReshapeAliasExpandModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.reshape_alias = torch.ops.aten._reshape_alias
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.ops.aten._reshape_alias(a, size=(12, 32), stride=(32, 1))
|
||||
|
||||
@register_test_case(module_factory=lambda: ReshapeAliasExpandModule())
|
||||
def ReshapeAliasExpandModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(384))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReshapeAliasCollapseModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.ops.aten._reshape_alias(a, (8,), (1,))
|
||||
|
||||
@register_test_case(module_factory=lambda: ReshapeAliasCollapseModule())
|
||||
def ReshapeAliasCollapseModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4))
|
||||
|
||||
|
|
|
@ -332,6 +332,24 @@ func @torch.aten._unsafe_view$static(%arg0: !torch.vtensor<[1,512,32],f32>) -> !
|
|||
return %1 : !torch.vtensor<[1,2,256,32],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.aten._reshape_alias$static
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[1],f32>)
|
||||
// CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct
|
||||
// CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct
|
||||
// CHECK-NOT: torch.aten._reshape_alias
|
||||
// CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST1]]
|
||||
// CHECK-NEXT: return
|
||||
func @torch.aten._reshape_alias$static(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[12,32],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%int32 = torch.constant.int 32
|
||||
%int12 = torch.constant.int 12
|
||||
%0 = torch.prim.ListConstruct %int12, %int32 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
%1 = torch.prim.ListConstruct %int32, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
%2 = torch.aten._reshape_alias %arg0, %0, %1 : !torch.vtensor<[1],f32>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[12,32],f32>
|
||||
return %2 : !torch.vtensor<[12,32],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.aten._unsafe_view$dynamic
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>)
|
||||
|
|
Loading…
Reference in New Issue