[TORCH][MLIR] Add E2E support for `aten._unsafe_view` op.

This commit adds a decomposition of the `aten._unsafe_view` op into
the `aten.view` op.
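Numerically the two ops agree; `_unsafe_view` differs from `view` only in
that autograd does not treat its result as a view of the input. A minimal
sketch (not part of this patch) of the value-level equivalence the
decomposition relies on:

    import torch

    # Illustration only: _unsafe_view returns the same values as view;
    # the ops differ only in autograd's view bookkeeping.
    a = torch.rand(6, 4)
    assert torch.equal(torch.ops.aten._unsafe_view(a, [2, 3, 4]),
                       a.view(2, 3, 4))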

Signed-off-by: Prateek Gupta <prateek@nod-labs.com>
pull/572/head
Prateek Gupta 2022-02-10 08:11:05 +00:00
parent 9b89f8eb3f
commit 318946a650
7 changed files with 203 additions and 9 deletions

@@ -126,6 +126,122 @@ def View1DFoldModule_basic(module, tu: TestUtils):

# ==============================================================================

class UnsafeViewExpandModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([6, 4], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten._unsafe_view(a, [2, 3, 4])


@register_test_case(module_factory=lambda: UnsafeViewExpandModule())
def UnsafeViewExpandModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(6, 4))

# ==============================================================================

class UnsafeViewDynamicExpandModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, 30, 384], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten._unsafe_view(a, [2, 4, 5, 6, 12, 32])


@register_test_case(module_factory=lambda: UnsafeViewDynamicExpandModule())
def UnsafeViewDynamicExpandModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 4, 30, 384))

# ==============================================================================

class UnsafeViewDynamicExpandWithAtenSizeIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten._unsafe_view(a, [a.size(0), a.size(1), 12, 32])


@register_test_case(
    module_factory=lambda: UnsafeViewDynamicExpandWithAtenSizeIntModule())
def UnsafeViewDynamicExpandWithAtenSizeIntModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 4, 384))

# ==============================================================================

class UnsafeViewCollapseModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten._unsafe_view(a, [8])


@register_test_case(module_factory=lambda: UnsafeViewCollapseModule())
def UnsafeViewCollapseModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 4))

# ==============================================================================

class UnsafeViewCollapseDynamicWithAtenSizeIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1, -1, -1], torch.float32, True),
        ([], torch.int64, True),
        ([], torch.int64, True),
    ])
    def forward(self, a, b, c):
        return torch.ops.aten._unsafe_view(
            a, [a.size(0), int(b), int(c), a.size(3), 384])


@register_test_case(
    module_factory=lambda: UnsafeViewCollapseDynamicWithAtenSizeIntModule())
def UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 5, 4, 12, 32), torch.tensor(3), torch.tensor(5))

# ==============================================================================

class UnsafeView1DFoldModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten._unsafe_view(a, [-1])


@register_test_case(module_factory=lambda: UnsafeView1DFoldModule())
def UnsafeView1DFoldModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(32))

# ==============================================================================

class ReshapeExpandModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

@@ -47,6 +47,7 @@ TOSA_PASS_SET = {
    "TModuleRank0_basic",
    "ElementwiseToDtypeIdentityModule_basic",
    "View1DFoldModule_basic",
    "UnsafeView1DFoldModule_basic",
    "SqueezeDimModule_static",
    "SqueezeDimModule_identity",
    "SqueezeDimModule_unitDim",

@@ -2751,6 +2751,21 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [
  let hasFolder = 1;
}

def Torch_Aten_UnsafeViewOp : Torch_Op<"aten._unsafe_view", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::_unsafe_view : (Tensor, int[]) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    TorchIntListType:$size
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $size attr-dict `:` qualified(type($self)) `,` qualified(type($size)) `->` qualified(type($result))";
}

def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
    AllowsTypeRefinement,
    HasValueSemantics

@@ -943,6 +943,31 @@ class DecomposeAtenNativeBatchNormOp
};
} // namespace

// Decompose `Aten_UnsafeViewOp` into `AtenViewOp`. _unsafe_view() differs from
// view() in that the returned tensor isn't treated as a view for the purposes
// of automatic differentiation. It's only safe to use if the `self` tensor is
// temporary. For example, the viewed tensor here (a + b) is discarded
// immediately after viewing:
//
//   res = _unsafe_view(a + b, size);
//
// This is a hack because in-place operations on tensors treated like views
// can be much more expensive than the same operations on non-view tensors.
// Refer to
// https://github.com/pytorch/pytorch/blob/364055b2771ecf9b54f1d67a8bf44bb5496476d4/aten/src/ATen/native/TensorShape.cpp#L2072
namespace {
class DecomposeAten_UnsafeViewOp : public OpRewritePattern<Aten_UnsafeViewOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(Aten_UnsafeViewOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), op.self(),
                                            op.size());
    return success();
  }
};
} // namespace

namespace {
class DecomposeComplexOpsPass
    : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -1014,6 +1039,8 @@ class DecomposeComplexOpsPass
    target.addIllegalOp<AtenVarOp>();
    patterns.add<DecomposeAtenStdOp>(context);
    target.addIllegalOp<AtenStdOp>();
    patterns.add<DecomposeAten_UnsafeViewOp>(context);
    target.addIllegalOp<Aten_UnsafeViewOp>();

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
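The comment above stresses that `_unsafe_view` only opts out of autograd's
view tracking. A hedged sketch (not from this patch) showing that
out-of-place gradients still match `view`, which is what makes the
one-for-one rewrite safe when the input is a temporary:

    import torch

    # Illustration only: gradients through _unsafe_view and view agree
    # for out-of-place use on a temporary input.
    x = torch.rand(2, 3, requires_grad=True)

    torch.ops.aten._unsafe_view(x * 2, [3, 2]).sum().backward()
    g_unsafe, x.grad = x.grad, None

    (x * 2).view(3, 2).sum().backward()
    assert torch.equal(g_unsafe, x.grad)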

@@ -366,6 +366,8 @@ public:
          secondResDtype, operands, /*resNum=*/1);
    } else if (auto view = dyn_cast<AtenViewOp>(op)) {
      return visitReshapeLikeOp(view, operands, view.size());
    } else if (auto unsafeView = dyn_cast<Aten_UnsafeViewOp>(op)) {
      return visitReshapeLikeOp(unsafeView, operands, unsafeView.size());
    } else if (auto reshape = dyn_cast<AtenReshapeOp>(op)) {
      return visitReshapeLikeOp(reshape, operands, reshape.shape());
    } else if (auto resize = dyn_cast<AtenResize_Op>(op)) {

@@ -608,6 +608,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
    emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
    emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
    emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
    emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
    emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)")
    emit("aten::len.Tensor : (Tensor) -> (int)")

@@ -25,7 +25,7 @@ func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vten
  return %0 : !torch.tensor
}

// ----
// -----
// CHECK-LABEL: func @torch.aten.softmax.int(
// CHECK-SAME:  %[[T:.*]]: !torch.tensor<[2,3],f32>,
// CHECK-SAME:  %[[DIM:.*]]: !torch.int) -> !torch.tensor<[2,3],f32> {
@@ -52,7 +52,7 @@ func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) ->
}

// ----
// -----
// CHECK-LABEL: func @torch.aten.softmax.int$cst_dim(
// CHECK-SAME:  %[[T:.*]]: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> {
// CHECK:       %[[DTYPE:.*]] = torch.constant.none
@@ -79,7 +79,7 @@ func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.ten
  return %ret : !torch.tensor<[2,3],f32>
}

// ----
// -----
// CHECK-LABEL: func @torch.aten.softmax.int$dyn_shape(
// CHECK-SAME:  %[[T:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
// CHECK:       %[[DTYPE:.*]] = torch.constant.none
@@ -106,7 +106,7 @@ func @torch.aten.softmax.int$dyn_shape(%t: !torch.tensor<[?,?],f32>) -> !torch.t
  return %ret : !torch.tensor<[?,?],f32>
}

// ----
// -----
// CHECK-LABEL: func @torch.aten.softmax.int$unknown_shape(
// CHECK-SAME:  %[[T:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
// CHECK:       %[[DTYPE:.*]] = torch.constant.none
@@ -133,7 +133,7 @@ func @torch.aten.softmax.int$unknown_shape(%t: !torch.tensor<*,f32>) -> !torch.t
  return %ret : !torch.tensor<*,f32>
}

// ----
// -----
// CHECK-LABEL: func @torch.aten.size(
// CHECK-SAME:  %[[T:.*]]: !torch.vtensor<[?,3],f32>) -> !torch.list<!torch.int> {
// CHECK:       %[[CST0:.*]] = torch.constant.int 0
@@ -147,7 +147,7 @@ func @torch.aten.size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<!torch.in
  return %0 : !torch.list<!torch.int>
}

// ----
// -----
// CHECK-LABEL: func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
// CHECK:       %[[CST5:.*]] = torch.constant.int 5
// CHECK:       %[[CSTN:.*]] = torch.constant.none
@@ -163,7 +163,7 @@ func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
  return %0 : !torch.vtensor<[?],si64>
}

// ----
// -----
// CHECK-LABEL: func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
// CHECK:       %[[CST10:.*]] = torch.constant.int 10
// CHECK:       %[[CST0:.*]] = torch.constant.int 0
@@ -180,7 +180,7 @@ func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
  return %0 : !torch.vtensor<[?],si64>
}

// ----
// -----
// CHECK-LABEL: func @torch.aten.argmax(
// CHECK-SAME:  %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> {
// CHECK:       %[[CST0:.*]] = torch.constant.int 0
@@ -194,7 +194,7 @@ func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?
  return %0 : !torch.vtensor<[1,?],si64>
}

// ----
// -----
// CHECK-LABEL: func @torch.aten.argmax$reduceall(
// CHECK-SAME:  %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> {
// CHECK:       %[[NONE:.*]] = torch.constant.none
@@ -310,3 +310,35 @@ func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtenso
  %0 = torch.aten.std %arg0, %false : !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
  return %0 : !torch.vtensor<[],f32>
}
// -----
// CHECK-LABEL: func @torch.aten._unsafe_view$static
// CHECK-SAME:  (%[[ARG0:.*]]: !torch.vtensor<[1,512,32],f32>)
// CHECK:       %[[LIST:.*]] = torch.prim.ListConstruct
// CHECK-NOT:   torch.aten._unsafe_view
// CHECK-NEXT:  %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST]]
// CHECK-NEXT:  return
func @torch.aten._unsafe_view$static(%arg0: !torch.vtensor<[1,512,32],f32>) -> !torch.vtensor<[1,2,256,32],f32> {
  %c1 = torch.constant.int 1
  %c2 = torch.constant.int 2
  %c256 = torch.constant.int 256
  %c32 = torch.constant.int 32
  %0 = torch.prim.ListConstruct %c1, %c2, %c256, %c32 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
  %1 = torch.aten._unsafe_view %arg0, %0 : !torch.vtensor<[1,512,32],f32>, !torch.list<!torch.int> -> !torch.vtensor<[1,2,256,32],f32>
  return %1 : !torch.vtensor<[1,2,256,32],f32>
}

// -----
// CHECK-LABEL: func @torch.aten._unsafe_view$dynamic
// CHECK-SAME:  (%[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>)
// CHECK:       %[[LIST:.*]] = torch.prim.ListConstruct
// CHECK-NOT:   torch.aten._unsafe_view
// CHECK-NEXT:  %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST]]
// CHECK-NEXT:  return
func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[512,32],f32> {
  %c512 = torch.constant.int 512
  %c32 = torch.constant.int 32
  %0 = torch.prim.ListConstruct %c512, %c32 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  %1 = torch.aten._unsafe_view %arg0, %0 : !torch.vtensor<[?,?,?],f32>, !torch.list<!torch.int> -> !torch.vtensor<[512,32],f32>
  return %1 : !torch.vtensor<[512,32],f32>
}