mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Add E2E support for `aten._unsafe_view` op.
This commit adds decomposition of `aten._unsafe_view` op into `aten.view` op. Signed-Off-By: Prateek Gupta<prateek@nod-labs.com>pull/572/head
parent
9b89f8eb3f
commit
318946a650
|
@ -126,6 +126,122 @@ def View1DFoldModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class UnsafeViewExpandModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([6, 4], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten._unsafe_view(a, [2, 3, 4])
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: UnsafeViewExpandModule())
|
||||||
|
def UnsafeViewExpandModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(6, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class UnsafeViewDynamicExpandModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, 30, 384], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten._unsafe_view(a,[2, 4, 5, 6, 12, 32])
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: UnsafeViewDynamicExpandModule())
|
||||||
|
def UnsafeViewDynamicExpandModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 4, 30, 384))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class UnsafeViewDynamicExpandWithAtenSizeIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten._unsafe_view(a, [a.size(0), a.size(1), 12, 32])
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: UnsafeViewDynamicExpandWithAtenSizeIntModule())
|
||||||
|
def UnsafeViewDynamicExpandWithAtenSizeIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 4, 384))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class UnsafeViewCollapseModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten._unsafe_view(a,[8])
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: UnsafeViewCollapseModule())
|
||||||
|
def UnsafeViewCollapseModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class UnsafeViewCollapseDynamicWithAtenSizeIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1, -1, -1], torch.float32, True),
|
||||||
|
([], torch.int64, True),
|
||||||
|
([], torch.int64, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a, b, c):
|
||||||
|
return torch.ops.aten._unsafe_view(a, [a.size(0), int(b), int(c), a.size(3), 384])
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: UnsafeViewCollapseDynamicWithAtenSizeIntModule())
|
||||||
|
def UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 3, 5, 4, 12, 32), torch.tensor(3), torch.tensor(5))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class UnsafeView1DFoldModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten._unsafe_view(a, [-1])
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: UnsafeView1DFoldModule())
|
||||||
|
def UnsafeView1DFoldModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(32))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ReshapeExpandModule(torch.nn.Module):
|
class ReshapeExpandModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -47,6 +47,7 @@ TOSA_PASS_SET = {
|
||||||
"TModuleRank0_basic",
|
"TModuleRank0_basic",
|
||||||
"ElementwiseToDtypeIdentityModule_basic",
|
"ElementwiseToDtypeIdentityModule_basic",
|
||||||
"View1DFoldModule_basic",
|
"View1DFoldModule_basic",
|
||||||
|
"UnsafeView1DFoldModule_basic",
|
||||||
"SqueezeDimModule_static",
|
"SqueezeDimModule_static",
|
||||||
"SqueezeDimModule_identity",
|
"SqueezeDimModule_identity",
|
||||||
"SqueezeDimModule_unitDim",
|
"SqueezeDimModule_unitDim",
|
||||||
|
|
|
@ -2751,6 +2751,21 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_Aten_UnsafeViewOp : Torch_Op<"aten._unsafe_view", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::_unsafe_view : (Tensor, int[]) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
TorchIntListType:$size
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = "$self `,` $size attr-dict `:` qualified(type($self)) `,` qualified(type($size)) `->` qualified(type($result))";
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
|
def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics
|
HasValueSemantics
|
||||||
|
|
|
@ -943,6 +943,31 @@ class DecomposeAtenNativeBatchNormOp
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// Decompse `Aten_UnsafeViewOp` into `AtenViewOp`. _unsafe_view() differs from
|
||||||
|
// view() in that the returned tensor isn't treated as a view for the purposes
|
||||||
|
// of automatic differentiation. It's only safe to use if the `self` tensor is
|
||||||
|
// temporary. For example, the viewed tensor here (a + b) is discarded
|
||||||
|
// immediately after viewing:
|
||||||
|
//
|
||||||
|
// res = _unsafe_view(a + b, size);
|
||||||
|
//
|
||||||
|
// This is a hack because in-place operations on tensors treated like views
|
||||||
|
// can be much more expensive than the same operations on non-view tensors.
|
||||||
|
|
||||||
|
// Refer to
|
||||||
|
// https://github.com/pytorch/pytorch/blob/364055b2771ecf9b54f1d67a8bf44bb5496476d4/aten/src/ATen/native/TensorShape.cpp#L2072
|
||||||
|
namespace {
|
||||||
|
class DecomposeAten_UnsafeViewOp : public OpRewritePattern<Aten_UnsafeViewOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(Aten_UnsafeViewOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), op.self(),
|
||||||
|
op.size());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeComplexOpsPass
|
class DecomposeComplexOpsPass
|
||||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||||
|
@ -1014,6 +1039,8 @@ class DecomposeComplexOpsPass
|
||||||
target.addIllegalOp<AtenVarOp>();
|
target.addIllegalOp<AtenVarOp>();
|
||||||
patterns.add<DecomposeAtenStdOp>(context);
|
patterns.add<DecomposeAtenStdOp>(context);
|
||||||
target.addIllegalOp<AtenStdOp>();
|
target.addIllegalOp<AtenStdOp>();
|
||||||
|
patterns.add<DecomposeAten_UnsafeViewOp>(context);
|
||||||
|
target.addIllegalOp<Aten_UnsafeViewOp>();
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns)))) {
|
std::move(patterns)))) {
|
||||||
|
|
|
@ -366,6 +366,8 @@ public:
|
||||||
secondResDtype, operands, /*resNum=*/1);
|
secondResDtype, operands, /*resNum=*/1);
|
||||||
} else if (auto view = dyn_cast<AtenViewOp>(op)) {
|
} else if (auto view = dyn_cast<AtenViewOp>(op)) {
|
||||||
return visitReshapeLikeOp(view, operands, view.size());
|
return visitReshapeLikeOp(view, operands, view.size());
|
||||||
|
} else if (auto unsafeView = dyn_cast<Aten_UnsafeViewOp>(op)) {
|
||||||
|
return visitReshapeLikeOp(unsafeView, operands, unsafeView.size());
|
||||||
} else if (auto reshape = dyn_cast<AtenReshapeOp>(op)) {
|
} else if (auto reshape = dyn_cast<AtenReshapeOp>(op)) {
|
||||||
return visitReshapeLikeOp(reshape, operands, reshape.shape());
|
return visitReshapeLikeOp(reshape, operands, reshape.shape());
|
||||||
} else if (auto resize = dyn_cast<AtenResize_Op>(op)) {
|
} else if (auto resize = dyn_cast<AtenResize_Op>(op)) {
|
||||||
|
|
|
@ -608,6 +608,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
||||||
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
|
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
|
||||||
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
|
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
|
||||||
|
emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
|
||||||
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
|
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)")
|
emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)")
|
||||||
emit("aten::len.Tensor : (Tensor) -> (int)")
|
emit("aten::len.Tensor : (Tensor) -> (int)")
|
||||||
|
|
|
@ -25,7 +25,7 @@ func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vten
|
||||||
return %0 : !torch.tensor
|
return %0 : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.softmax.int(
|
// CHECK-LABEL: func @torch.aten.softmax.int(
|
||||||
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
|
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
|
||||||
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor<[2,3],f32> {
|
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor<[2,3],f32> {
|
||||||
|
@ -52,7 +52,7 @@ func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) ->
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.softmax.int$cst_dim(
|
// CHECK-LABEL: func @torch.aten.softmax.int$cst_dim(
|
||||||
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> {
|
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> {
|
||||||
// CHECK: %[[DTYPE:.*]] = torch.constant.none
|
// CHECK: %[[DTYPE:.*]] = torch.constant.none
|
||||||
|
@ -79,7 +79,7 @@ func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.ten
|
||||||
return %ret : !torch.tensor<[2,3],f32>
|
return %ret : !torch.tensor<[2,3],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.softmax.int$dyn_shape(
|
// CHECK-LABEL: func @torch.aten.softmax.int$dyn_shape(
|
||||||
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
|
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
|
||||||
// CHECK: %[[DTYPE:.*]] = torch.constant.none
|
// CHECK: %[[DTYPE:.*]] = torch.constant.none
|
||||||
|
@ -106,7 +106,7 @@ func @torch.aten.softmax.int$dyn_shape(%t: !torch.tensor<[?,?],f32>) -> !torch.t
|
||||||
return %ret : !torch.tensor<[?,?],f32>
|
return %ret : !torch.tensor<[?,?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.softmax.int$unknown_shape(
|
// CHECK-LABEL: func @torch.aten.softmax.int$unknown_shape(
|
||||||
// CHECK-SAME: %[[T:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
|
// CHECK-SAME: %[[T:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
|
||||||
// CHECK: %[[DTYPE:.*]] = torch.constant.none
|
// CHECK: %[[DTYPE:.*]] = torch.constant.none
|
||||||
|
@ -133,7 +133,7 @@ func @torch.aten.softmax.int$unknown_shape(%t: !torch.tensor<*,f32>) -> !torch.t
|
||||||
return %ret : !torch.tensor<*,f32>
|
return %ret : !torch.tensor<*,f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.size(
|
// CHECK-LABEL: func @torch.aten.size(
|
||||||
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[?,3],f32>) -> !torch.list<!torch.int> {
|
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[?,3],f32>) -> !torch.list<!torch.int> {
|
||||||
// CHECK: %[[CST0:.*]] = torch.constant.int 0
|
// CHECK: %[[CST0:.*]] = torch.constant.int 0
|
||||||
|
@ -147,7 +147,7 @@ func @torch.aten.size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<!torch.in
|
||||||
return %0 : !torch.list<!torch.int>
|
return %0 : !torch.list<!torch.int>
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
|
// CHECK-LABEL: func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
|
||||||
// CHECK: %[[CST5:.*]] = torch.constant.int 5
|
// CHECK: %[[CST5:.*]] = torch.constant.int 5
|
||||||
// CHECK: %[[CSTN:.*]] = torch.constant.none
|
// CHECK: %[[CSTN:.*]] = torch.constant.none
|
||||||
|
@ -163,7 +163,7 @@ func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
|
||||||
return %0 : !torch.vtensor<[?],si64>
|
return %0 : !torch.vtensor<[?],si64>
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
|
// CHECK-LABEL: func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
|
||||||
// CHECK: %[[CST10:.*]] = torch.constant.int 10
|
// CHECK: %[[CST10:.*]] = torch.constant.int 10
|
||||||
// CHECK: %[[CST0:.*]] = torch.constant.int 0
|
// CHECK: %[[CST0:.*]] = torch.constant.int 0
|
||||||
|
@ -180,7 +180,7 @@ func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
|
||||||
return %0 : !torch.vtensor<[?],si64>
|
return %0 : !torch.vtensor<[?],si64>
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.argmax(
|
// CHECK-LABEL: func @torch.aten.argmax(
|
||||||
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> {
|
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> {
|
||||||
// CHECK: %[[CST0:.*]] = torch.constant.int 0
|
// CHECK: %[[CST0:.*]] = torch.constant.int 0
|
||||||
|
@ -194,7 +194,7 @@ func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?
|
||||||
return %0 : !torch.vtensor<[1,?],si64>
|
return %0 : !torch.vtensor<[1,?],si64>
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.argmax$reduceall(
|
// CHECK-LABEL: func @torch.aten.argmax$reduceall(
|
||||||
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> {
|
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> {
|
||||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
@ -310,3 +310,35 @@ func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtenso
|
||||||
%0 = torch.aten.std %arg0, %false: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
|
%0 = torch.aten.std %arg0, %false: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
|
||||||
return %0 : !torch.vtensor<[],f32>
|
return %0 : !torch.vtensor<[],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: func @torch.aten._unsafe_view$static
|
||||||
|
// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[1,512,32],f32>)
|
||||||
|
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
|
||||||
|
// CHECK-NOT: torch.aten._unsafe_view
|
||||||
|
// CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST]]
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
func @torch.aten._unsafe_view$static(%arg0: !torch.vtensor<[1,512,32],f32>) -> !torch.vtensor<[1,2,256,32],f32> {
|
||||||
|
%c1 = torch.constant.int 1
|
||||||
|
%c2 = torch.constant.int 2
|
||||||
|
%c256 = torch.constant.int 256
|
||||||
|
%c32 = torch.constant.int 32
|
||||||
|
%0 = torch.prim.ListConstruct %c1, %c2, %c256, %c32 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||||
|
%1 = torch.aten._unsafe_view %arg0, %0 : !torch.vtensor<[1,512,32],f32>, !torch.list<!torch.int> -> !torch.vtensor<[1,2,256,32],f32>
|
||||||
|
return %1 : !torch.vtensor<[1,2,256,32],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: func @torch.aten._unsafe_view$dynamic
|
||||||
|
// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>)
|
||||||
|
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
|
||||||
|
// CHECK-NOT: torch.aten._unsafe_view
|
||||||
|
// CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST]]
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[512,32],f32> {
|
||||||
|
%c256 = torch.constant.int 512
|
||||||
|
%c32 = torch.constant.int 32
|
||||||
|
%0 = torch.prim.ListConstruct %c256, %c32 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||||
|
%1 = torch.aten._unsafe_view %arg0, %0 : !torch.vtensor<[?,?,?],f32>, !torch.list<!torch.int> -> !torch.vtensor<[512,32],f32>
|
||||||
|
return %1 : !torch.vtensor<[512,32],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue