This commit decomposes aten._reshape_alias op into aten.view op. (#690)

pull/711/head snapshot-20220329.353
Maksim Levental 2022-03-28 23:54:28 -05:00 committed by GitHub
parent eecbf0bab6
commit 25ba51b2af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 117 additions and 6 deletions

View File

@ -4549,6 +4549,30 @@ def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
}];
}
def Torch_Aten_ReshapeAliasOp : Torch_Op<"aten._reshape_alias", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
TorchIntListType:$size,
TorchIntListType:$stride
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_ReshapeAliasOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void Aten_ReshapeAliasOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenResize_Op : Torch_Op<"aten.resize_", [
AllowsTypeRefinement
]> {

View File

@ -1418,6 +1418,23 @@ class DecomposeAten_UnsafeViewOp : public OpRewritePattern<Aten_UnsafeViewOp> {
};
} // namespace
// In PyTorch, _reshape_alias just uses an already computed stride.
// See
// https://github.com/pytorch/pytorch/blob/d8c31a819d4a65e732b5901e3b994e1869851f1a/aten/src/ATen/native/TensorShape.cpp#L1153
// Note that this is the same decomposition as in AOTAutograd
// https://github.com/pytorch/functorch/blob/a3042d94e616d4143813668b1372d9d4545be14e/functorch/_src/aot_autograd.py#L104
namespace {
class DecomposeAten_ReshapeAliasOp : public OpRewritePattern<Aten_ReshapeAliasOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_ReshapeAliasOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), op.self(),
op.size());
return success();
}
};
} // namespace
namespace {
// Decompose constant tensor like ops.
template <typename OpTy, typename NewOpTy>
@ -1596,6 +1613,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenStdOp>();
patterns.add<DecomposeAten_UnsafeViewOp>(context);
target.addIllegalOp<Aten_UnsafeViewOp>();
patterns.add<DecomposeAten_ReshapeAliasOp>(context);
target.addIllegalOp<Aten_ReshapeAliasOp>();
patterns.add<DecomposeAtenBernoulliOp>(context);
target.addIllegalOp<AtenBernoulliOp>();
patterns.add<DecomposeValsemVariantAtenBernoulliFloatOp>(context);

View File

@ -33,7 +33,7 @@ static bool isViewLikeOp(Operation *op) {
// that it does not return a view and treat those as having value
// semantics.
return isa<AtenBroadcastToOp, AtenContiguousOp, AtenExpandAsOp, AtenExpandOp,
AtenFlattenUsingIntsOp, AtenPermuteOp, AtenReshapeOp,
AtenFlattenUsingIntsOp, AtenPermuteOp, AtenReshapeOp, Aten_ReshapeAliasOp,
AtenSelectIntOp, AtenSliceTensorOp, AtenSqueezeDimOp,
AtenSqueezeOp, AtenTOp, AtenToDtypeOp, AtenTransposeIntOp,
AtenUnsqueezeOp, AtenViewOp, TensorStaticInfoCastOp>(op);

View File

@ -507,13 +507,14 @@ ChangeResult TypeAnalyzer::visitOperation(
AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulli_TensorOp,
ValsemVariantAtenBernoulliFloatOp, ValsemVariantAtenBernoulliTensorOp,
ValsemVariantAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp,
AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp,
AtenMaxPool2dOp, AtenAdaptiveAvgPool2dOp, AtenFlattenUsingIntsOp,
AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp,
Aten_UnsafeViewOp, AtenReshapeOp, AtenResize_Op, AtenTransposeIntOp,
AtenTOp, AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp,
AtenSliceTensorOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenZero_Op,
Aten_UnsafeViewOp, AtenReshapeOp, Aten_ReshapeAliasOp,
AtenResize_Op, AtenTransposeIntOp, AtenTOp, AtenPermuteOp,
AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp,
AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp,
AtenConstantPadNdOp, AtenZero_Op,
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
ValsemVariantAtenCopyOp, ValsemVariantAtenZeroOp>(op)) {
ValueKnowledge knowledge =

View File

@ -1205,6 +1205,10 @@ module {
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.view(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func @"__torch_mlir_shape_fn.aten._reshape_alias"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.view(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func @"__torch_mlir_shape_fn.aten._unsafe_view"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
return %arg1 : !torch.list<int>
}

View File

@ -551,6 +551,9 @@ def atenview(self: List[int], size: List[int]) -> List[int]:
def atenreshape(self: List[int], shape: List[int]) -> List[int]:
return upstream_shape_helpers.view(self, shape)
def aten_reshape_alias(self: List[int], size: List[int], stride: List[int]) -> List[int]:
return upstream_shape_helpers.view(self, size)
def aten_unsafe_view(self: List[int], size: List[int]) -> List[int]:
return size

View File

@ -397,6 +397,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::numel : (Tensor) -> (int)")
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
emit("aten::select.int : (Tensor, int, int) -> (Tensor)")
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)

View File

@ -388,3 +388,44 @@ class ViewNoChangeStaticModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ViewNoChangeStaticModule())
def ViewNoChangeStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 5, 6))
# ==============================================================================
class ReshapeAliasExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.reshape_alias = torch.ops.aten._reshape_alias
@export
@annotate_args([
None,
([-1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten._reshape_alias(a, size=(12, 32), stride=(32, 1))
@register_test_case(module_factory=lambda: ReshapeAliasExpandModule())
def ReshapeAliasExpandModule_basic(module, tu: TestUtils):
module.forward(tu.rand(384))
# ==============================================================================
class ReshapeAliasCollapseModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten._reshape_alias(a, (8,), (1,))
@register_test_case(module_factory=lambda: ReshapeAliasCollapseModule())
def ReshapeAliasCollapseModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4))

View File

@ -332,6 +332,24 @@ func @torch.aten._unsafe_view$static(%arg0: !torch.vtensor<[1,512,32],f32>) -> !
return %1 : !torch.vtensor<[1,2,256,32],f32>
}
// -----
// CHECK-LABEL: func @torch.aten._reshape_alias$static
// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[1],f32>)
// CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct
// CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct
// CHECK-NOT: torch.aten._reshape_alias
// CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST1]]
// CHECK-NEXT: return
func @torch.aten._reshape_alias$static(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[12,32],f32> {
%int1 = torch.constant.int 1
%int32 = torch.constant.int 32
%int12 = torch.constant.int 12
%0 = torch.prim.ListConstruct %int12, %int32 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int32, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten._reshape_alias %arg0, %0, %1 : !torch.vtensor<[1],f32>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[12,32],f32>
return %2 : !torch.vtensor<[12,32],f32>
}
// -----
// CHECK-LABEL: func @torch.aten._unsafe_view$dynamic
// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>)