[TORCH][MLIR] Fold trivial cases of `aten.to.dtype` and `aten.view` op

- It folds `aten.to.dtype` when the input tensor type and the result type
  are exactly the same.
- It folds `aten.view` when both the input tensor type and the result type
  have rank one (see the illustrative sketch below).
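
The following sketch is illustrative only and is not part of this patch; it
shows that both patterns are already identity operations at the PyTorch level,
which is why returning the input operand unchanged is a safe fold:

import torch

x = torch.rand(3, 5)  # already float32
# Same dtype, non_blocking=False, copy=False: PyTorch returns the input as-is.
assert x.to(torch.float32, False, False) is x

a = torch.rand(32)  # rank-1 tensor
# view(-1) on a rank-1 tensor cannot change its shape, so it is a no-op.
assert a.view(-1).shape == a.shape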

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Gaurav Shukla 2021-12-23 17:34:29 +05:30
parent 9e1ecf2c0b
commit a83004c806
7 changed files with 108 additions and 4 deletions


@@ -865,6 +865,22 @@ class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
def ElementwiseToDtypeF32ToI64Module_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5))

class ElementwiseToDtypeIdentityModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True)
    ])
    def forward(self, x):
        return x.to(torch.float32, False, False)

@register_test_case(module_factory=lambda: ElementwiseToDtypeIdentityModule())
def ElementwiseToDtypeIdentityModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5))

class ElementwiseLog2Module(torch.nn.Module):
    def __init__(self):
        super().__init__()


@@ -9,6 +9,7 @@ from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

# ==============================================================================

class ViewExpandModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -45,8 +46,8 @@ class ViewDynamicExpandModule(torch.nn.Module):
def ViewDynamicExpandModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 4, 30, 384))

# ==============================================================================

class ViewDynamicExpandWithAtenSizeIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -65,6 +66,7 @@ def ViewDynamicExpandWithAtenSizeIntModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 4, 384))

# ==============================================================================

class ViewCollapseModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -82,8 +84,8 @@ class ViewCollapseModule(torch.nn.Module):
def ViewCollapseModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 4))

# ==============================================================================

class ViewCollapseDynamicWithAtenSizeIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -102,3 +104,22 @@ class ViewCollapseDynamicWithAtenSizeIntModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ViewCollapseDynamicWithAtenSizeIntModule())
def ViewCollapseDynamicWithAtenSizeIntModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 5, 4, 12, 32), torch.tensor(3), torch.tensor(5))

# ==============================================================================

class View1DFoldModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
    ])
    def forward(self, a):
        return a.view(-1)

@register_test_case(module_factory=lambda: View1DFoldModule())
def View1DFoldModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(32))


@@ -43,4 +43,6 @@ TOSA_PASS_SET = {
    "SqueezeModule_allUnitDim",
    "TModuleRank1_basic",
    "TModuleRank0_basic",
    "ElementwiseToDtypeIdentityModule_basic",
    "View1DFoldModule_basic",
}


@@ -2399,6 +2399,7 @@ def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $dtype `,` $non_blocking `,` $copy `,` $memory_format attr-dict `:` type($self) `,` type($dtype) `,` type($non_blocking) `,` type($copy) `,` type($memory_format) `->` type($result)";
  let hasFolder = 1;
}

def Torch_AtenToOtherOp : Torch_Op<"aten.to.other", [
@@ -2462,6 +2463,7 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $size attr-dict `:` type($self) `,` type($size) `->` type($result)";
  let hasFolder = 1;
}

def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [


@@ -474,6 +474,48 @@ OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenToDtypeOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenToDtypeOp::fold(ArrayRef<Attribute> operands) {
  bool nonBlocking, copyArg;
  // The non_blocking arg must be `False`.
  if (!matchPattern(non_blocking(), m_TorchConstantBool(&nonBlocking)) ||
      nonBlocking)
    return nullptr;
  // The copy arg must be `False`.
  if (!matchPattern(copy(), m_TorchConstantBool(&copyArg)) || copyArg)
    return nullptr;
  // The memory_format arg must be `none`.
  if (!memory_format().getType().isa<Torch::NoneType>())
    return nullptr;
  auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>();
  if (!inputType || !inputType.hasSizes())
    return nullptr;
  auto resType = getType().dyn_cast<BaseTensorType>();
  if (!resType || !resType.hasSizes() || inputType != resType)
    return nullptr;
  // Fold when both the input tensor and result are of the same type.
  return getOperand(0);
}

//===----------------------------------------------------------------------===//
// AtenViewOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenViewOp::fold(ArrayRef<Attribute> operands) {
  auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>();
  if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
    return nullptr;
  auto resType = getType().dyn_cast<BaseTensorType>();
  if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1)
    return nullptr;
  // Fold when both the input tensor and result are unity rank tensors.
  return getOperand(0);
}

//===----------------------------------------------------------------------===//
// AtenDimOp
//===----------------------------------------------------------------------===//


@@ -580,11 +580,11 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit("aten::stack : (Tensor[], int) -> (Tensor)")
    emit("aten::sum : (Tensor, int?) -> (Tensor)")
    emit("aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)")
    emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)")
    emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)
    emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)")
    emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
    emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
    emit("aten::view : (Tensor, int[]) -> (Tensor)")
    emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
    emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
    emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)")
    emit("aten::len.Tensor : (Tensor) -> (int)")


@@ -622,3 +622,24 @@ func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.t
  %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.tensor<[],f32>, !torch.int -> !torch.tensor<[],f32>
  return %0 : !torch.tensor<[],f32>
}

// CHECK-LABEL: func @torch.aten.to.dtype$same_dtype(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[?,?],f32>
func @torch.aten.to.dtype$same_dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
  %none = torch.constant.none
  %false = torch.constant.bool false
  %int6 = torch.constant.int 6
  %0 = torch.aten.to.dtype %arg0, %int6, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f32>
  return %0 : !torch.tensor<[?,?],f32>
}

// CHECK-LABEL: func @torch.aten.view$1D(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?],f32>) -> !torch.tensor<[?],f32> {
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[?],f32>
func @torch.aten.view$1D(%arg0: !torch.tensor<[?],f32>) -> !torch.tensor<[?],f32> {
  %int-1 = torch.constant.int -1
  %0 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<!torch.int>
  %1 = torch.aten.view %arg0, %0 : !torch.tensor<[?],f32>, !torch.list<!torch.int> -> !torch.tensor<[?],f32>
  return %1 : !torch.tensor<[?],f32>
}