Mirror of https://github.com/llvm/torch-mlir

Add folder for torch.aten.Int.Tensor

This folds the common pattern from BERT inference like:

```
%111 = torch.prim.NumToTensor.Scalar %110 : !torch.int -> !torch.vtensor<[],si64>
%112 = torch.aten.Int.Tensor %111 : !torch.vtensor<[],si64> -> !torch.int
```

(branch: pull/441/head)
parent
36afa4a4d3
commit
5d28549c2c
|
@ -2197,6 +2197,7 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
|
|||
Torch_IntType:$result
|
||||
);
|
||||
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenDropoutOp : Torch_Op<"aten.dropout", [
|
||||
|
|
|
@ -1076,5 +1076,14 @@ OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
|
|||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
  // Fold aten.Int.Tensor(prim.NumToTensor.Scalar(x)) -> x: wrapping a scalar
  // into a 0-d tensor and immediately extracting it again is the identity.
  auto numToTensor = a().getDefiningOp<PrimNumToTensorScalarOp>();
  if (!numToTensor)
    return nullptr;
  return numToTensor.a();
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
|
||||
|
|
|
@ -568,7 +568,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
|
||||
emit("aten::IntImplicit : (Tensor) -> (int)")
|
||||
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
|
||||
emit("aten::Int.Tensor : (Tensor) -> (int)")
|
||||
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
|
||||
emit("aten::dropout : (Tensor, float, bool) -> (Tensor)")
|
||||
|
||||
# Dict ops.
|
||||
|
|
|
@ -594,3 +594,14 @@ func @torch.prim.TupleIndex$out_of_bound(%t0: !torch.tensor, %t1: !torch.tensor,
|
|||
%1 = torch.prim.TupleIndex %0, %int3 : !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
|
||||
return %1 : !torch.tensor
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL:   func @torch.aten.Int.Tensor(
// CHECK-SAME:        %[[NUM:.*]]: !torch.int) -> !torch.int {
// CHECK:           %[[T:.*]] = torch.prim.NumToTensor.Scalar %[[NUM]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK:           return %[[NUM]] : !torch.int
func @torch.aten.Int.Tensor(%arg0: !torch.int) -> !torch.int {
  %t = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.vtensor<[],si64>
  %s = torch.aten.Int.Tensor %t : !torch.vtensor<[],si64> -> !torch.int
  return %s : !torch.int
}
|
||||
|
|
Loading…
Reference in New Issue