mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Fold trivial cases of `aten.to.dtype` and `aten.view` op
- It folds `aten.to.dtype` when the input tensor type and result type are exactly same. - It folds `aten.view` when the rank of both the input tensor type and result type is unity. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/502/head snapshot-20211224.163
parent
9e1ecf2c0b
commit
a83004c806
|
@ -865,6 +865,22 @@ class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
|
|||
def ElementwiseToDtypeF32ToI64Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
class ElementwiseToDtypeIdentityModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.to(torch.float32, False, False)
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseToDtypeIdentityModule())
|
||||
def ElementwiseToDtypeIdentityModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
class ElementwiseLog2Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -9,6 +9,7 @@ from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
|||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ViewExpandModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -45,8 +46,8 @@ class ViewDynamicExpandModule(torch.nn.Module):
|
|||
def ViewDynamicExpandModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 30, 384))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ViewDynamicExpandWithAtenSizeIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -65,6 +66,7 @@ def ViewDynamicExpandWithAtenSizeIntModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand(2, 4, 384))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ViewCollapseModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -82,8 +84,8 @@ class ViewCollapseModule(torch.nn.Module):
|
|||
def ViewCollapseModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ViewCollapseDynamicWithAtenSizeIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -102,3 +104,22 @@ class ViewCollapseDynamicWithAtenSizeIntModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ViewCollapseDynamicWithAtenSizeIntModule())
|
||||
def ViewCollapseDynamicWithAtenSizeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 5, 4, 12, 32), torch.tensor(3), torch.tensor(5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class View1DFoldModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(-1)
|
||||
|
||||
@register_test_case(module_factory=lambda: View1DFoldModule())
|
||||
def View1DFoldModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(32))
|
||||
|
|
|
@ -43,4 +43,6 @@ TOSA_PASS_SET = {
|
|||
"SqueezeModule_allUnitDim",
|
||||
"TModuleRank1_basic",
|
||||
"TModuleRank0_basic",
|
||||
"ElementwiseToDtypeIdentityModule_basic",
|
||||
"View1DFoldModule_basic",
|
||||
}
|
||||
|
|
|
@ -2399,6 +2399,7 @@ def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [
|
|||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $dtype `,` $non_blocking `,` $copy `,` $memory_format attr-dict `:` type($self) `,` type($dtype) `,` type($non_blocking) `,` type($copy) `,` type($memory_format) `->` type($result)";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenToOtherOp : Torch_Op<"aten.to.other", [
|
||||
|
@ -2462,6 +2463,7 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [
|
|||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $size attr-dict `:` type($self) `,` type($size) `->` type($result)";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
|
||||
|
|
|
@ -474,6 +474,48 @@ OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenToDtypeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenToDtypeOp::fold(ArrayRef<Attribute> operands) {
|
||||
bool nonBlocking, copyArg;
|
||||
// The non_blocking arg must be `False`.
|
||||
if (!matchPattern(non_blocking(), m_TorchConstantBool(&nonBlocking)) ||
|
||||
nonBlocking)
|
||||
return nullptr;
|
||||
// The copy arg must be `False`.
|
||||
if (!matchPattern(copy(), m_TorchConstantBool(©Arg)) || copyArg)
|
||||
return nullptr;
|
||||
// The memory_format arg must be `none`.
|
||||
if (!memory_format().getType().isa<Torch::NoneType>())
|
||||
return nullptr;
|
||||
|
||||
auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>();
|
||||
if (!inputType || !inputType.hasSizes())
|
||||
return nullptr;
|
||||
auto resType = getType().dyn_cast<BaseTensorType>();
|
||||
if (!resType || !resType.hasSizes() || inputType != resType)
|
||||
return nullptr;
|
||||
// Fold when both the input tensor and result are of the same type.
|
||||
return getOperand(0);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenViewOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenViewOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>();
|
||||
if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
|
||||
return nullptr;
|
||||
auto resType = getType().dyn_cast<BaseTensorType>();
|
||||
if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1)
|
||||
return nullptr;
|
||||
// Fold when both the input tensor and result are unity rank tensors.
|
||||
return getOperand(0);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenDimOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -580,11 +580,11 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::stack : (Tensor[], int) -> (Tensor)")
|
||||
emit("aten::sum : (Tensor, int?) -> (Tensor)")
|
||||
emit("aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)")
|
||||
emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)")
|
||||
emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)
|
||||
emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)")
|
||||
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
|
||||
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::view : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
|
||||
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)")
|
||||
emit("aten::len.Tensor : (Tensor) -> (int)")
|
||||
|
|
|
@ -622,3 +622,24 @@ func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.t
|
|||
%0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.tensor<[],f32>, !torch.int -> !torch.tensor<[],f32>
|
||||
return %0 : !torch.tensor<[],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.to.dtype$same_dtype(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
|
||||
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[?,?],f32>
|
||||
func @torch.aten.to.dtype$same_dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
|
||||
%none = torch.constant.none
|
||||
%false = torch.constant.bool false
|
||||
%int6 = torch.constant.int 6
|
||||
%0 = torch.aten.to.dtype %arg0, %int6, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f32>
|
||||
return %0 : !torch.tensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.view$1D(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?],f32>) -> !torch.tensor<[?],f32> {
|
||||
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[?],f32>
|
||||
func @torch.aten.view$1D(%arg0: !torch.tensor<[?],f32>) -> !torch.tensor<[?],f32> {
|
||||
%int-1 = torch.constant.int -1
|
||||
%0 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<!torch.int>
|
||||
%1 = torch.aten.view %arg0, %0 : !torch.tensor<[?],f32>, !torch.list<!torch.int> -> !torch.tensor<[?],f32>
|
||||
return %1 : !torch.tensor<[?],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue