mirror of https://github.com/llvm/torch-mlir
Add folder for torch.aten.Int.Tensor
This is to fold the common pattern from Bert inference like: ``` %111 = torch.prim.NumToTensor.Scalar %110 : !torch.int -> !torch.vtensor<[],si64> %112 = torch.aten.Int.Tensor %111 : !torch.vtensor<[],si64> -> !torch.int ```pull/441/head
parent
36afa4a4d3
commit
5d28549c2c
|
@ -2197,6 +2197,7 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
|
||||||
Torch_IntType:$result
|
Torch_IntType:$result
|
||||||
);
|
);
|
||||||
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
|
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenDropoutOp : Torch_Op<"aten.dropout", [
|
def Torch_AtenDropoutOp : Torch_Op<"aten.dropout", [
|
||||||
|
|
|
@ -1076,5 +1076,14 @@ OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
// If an scalar number is converted to a 0-d tensor and passed on to
|
||||||
|
// aten.Int.Tensor, fold to the scalar number.
|
||||||
|
if (auto numToTensorScalar = a().getDefiningOp<PrimNumToTensorScalarOp>())
|
||||||
|
return numToTensorScalar.a();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
|
||||||
|
|
|
@ -568,7 +568,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
||||||
emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
|
emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
|
||||||
emit("aten::IntImplicit : (Tensor) -> (int)")
|
emit("aten::IntImplicit : (Tensor) -> (int)")
|
||||||
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
|
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
|
||||||
emit("aten::Int.Tensor : (Tensor) -> (int)")
|
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
|
||||||
emit("aten::dropout : (Tensor, float, bool) -> (Tensor)")
|
emit("aten::dropout : (Tensor, float, bool) -> (Tensor)")
|
||||||
|
|
||||||
# Dict ops.
|
# Dict ops.
|
||||||
|
|
|
@ -594,3 +594,14 @@ func @torch.prim.TupleIndex$out_of_bound(%t0: !torch.tensor, %t1: !torch.tensor,
|
||||||
%1 = torch.prim.TupleIndex %0, %int3 : !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
|
%1 = torch.prim.TupleIndex %0, %int3 : !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
|
||||||
return %1 : !torch.tensor
|
return %1 : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.Int.Tensor(
|
||||||
|
// CHECK-SAME: %[[NUM:.*]]: !torch.int) -> !torch.int {
|
||||||
|
// CHECK: %[[T:.*]] = torch.prim.NumToTensor.Scalar %[[NUM]] : !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: return %[[NUM]] : !torch.int
|
||||||
|
func @torch.aten.Int.Tensor(%arg0: !torch.int) -> !torch.int {
|
||||||
|
%tensor = torch.prim.NumToTensor.Scalar %arg0: !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
%scalar = torch.aten.Int.Tensor %tensor : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
return %scalar : !torch.int
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue