Add folder for torch.aten.Int.Tensor

This folds a common pattern seen in BERT inference, such as:
```
%111 = torch.prim.NumToTensor.Scalar %110 : !torch.int ->
    !torch.vtensor<[],si64>
%112 = torch.aten.Int.Tensor %111 : !torch.vtensor<[],si64> ->
    !torch.int
```
pull/441/head
Yi Zhang 2021-11-29 13:39:37 -05:00 committed by Prashant Kumar
parent 36afa4a4d3
commit 5d28549c2c
4 changed files with 22 additions and 1 deletions

View File

@ -2197,6 +2197,7 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
    Torch_IntType:$result
  );
  let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
  let hasFolder = 1;
}
def Torch_AtenDropoutOp : Torch_Op<"aten.dropout", [

View File

@ -1076,5 +1076,14 @@ OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
  }
  return nullptr;
}
// Folds `aten.Int.Tensor(prim.NumToTensor.Scalar %x)` to `%x`, eliding the
// scalar -> 0-d tensor -> scalar round-trip (a common pattern in BERT-style
// models where sizes get boxed into tensors).
//
// Returns the original scalar Value on success, or nullptr when the fold
// does not apply.
OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
  auto numToTensor = a().getDefiningOp<PrimNumToTensorScalarOp>();
  if (!numToTensor)
    return nullptr;
  // The MLIR fold contract requires the replacement value to have exactly
  // the result type (!torch.int here). prim.NumToTensor.Scalar accepts any
  // scalar, so bail out if the wrapped scalar is not already a !torch.int
  // (e.g. a !torch.float would produce type-mismatched IR).
  if (numToTensor.a().getType() != getType())
    return nullptr;
  return numToTensor.a();
}
#define GET_OP_CLASSES
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"

View File

@ -568,7 +568,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
    emit("aten::IntImplicit : (Tensor) -> (int)")
    emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
    emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
    emit("aten::dropout : (Tensor, float, bool) -> (Tensor)")
    # Dict ops.

View File

@ -594,3 +594,14 @@ func @torch.prim.TupleIndex$out_of_bound(%t0: !torch.tensor, %t1: !torch.tensor,
  %1 = torch.prim.TupleIndex %0, %int3 : !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
  return %1 : !torch.tensor
}
// Verify the AtenIntTensorOp folder: aten.Int.Tensor applied to a
// prim.NumToTensor.Scalar folds back to the original !torch.int scalar,
// so the function returns %arg0 directly. The CHECK lines expect the
// NumToTensor op to still be present in the output; only the Int.Tensor
// round-trip is elided by the fold.
// CHECK-LABEL: func @torch.aten.Int.Tensor(
// CHECK-SAME:       %[[NUM:.*]]: !torch.int) -> !torch.int {
// CHECK:            %[[T:.*]] = torch.prim.NumToTensor.Scalar %[[NUM]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK:            return %[[NUM]] : !torch.int
func @torch.aten.Int.Tensor(%arg0: !torch.int) -> !torch.int {
  %tensor = torch.prim.NumToTensor.Scalar %arg0: !torch.int -> !torch.vtensor<[],si64>
  %scalar = torch.aten.Int.Tensor %tensor : !torch.vtensor<[],si64> -> !torch.int
  return %scalar : !torch.int
}