Change dtype functions interface to take ints tuple for each tensor (#1965)

The original design for the dtype functions outlined in
https://github.com/llvm/torch-mlir/issues/1462 was unable to properly
handle ops that take optional tensors as an input when the optional
tensor has a value of None. By the time the op gets imported into
torch-mlir, if an optional value is None, all information about the
original type is lost from the op type signature, preventing
torch-mlir from knowing if a value of None was from an optional tensor
or not, which was crucial in the original design since each tensor
argument must be turned into two separate arguments for the dtype
function.

This commit changes the interface to dtype functions such that each
tensor turns into a tuple of two ints, the first representing the rank
of the tensor and the second the dtype of the tensor. Since now there
is a one-to-one correspondence between the operands of an op and the
operands of its dtype function, there is no ambiguity about which
operand of the op corresponds with which operand of the dtype
function.

To test the implementation, this commit defines dtype function for
convolution op, which takes one optional tensor as an argument.
pull/1984/head
Ramiro Leal-Cavazos 2023-03-23 11:05:39 -07:00 committed by GitHub
parent f2a05f2dc0
commit eae3ff7f1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 180 additions and 163 deletions

View File

@ -21,7 +21,7 @@ We will use the example of adding support for the `torch.aten.tanh` op.
function signatures are: function signatures are:
- `def atentanh〡shape(self: List[int]) -> List[int]:` - `def atentanh〡shape(self: List[int]) -> List[int]:`
- `def atentanh〡dtype(self_rank: int, self_dtype: int) -> int:` - `def atentanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:`
Note the use of `` as a separator since `.` or `::` aren't legal Note the use of `` as a separator since `.` or `::` aren't legal
in a Python identifier. in a Python identifier.

View File

@ -5979,31 +5979,32 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n" " %int6 = torch.constant.int 6\n"
" %int5 = torch.constant.int 5\n" " %int5 = torch.constant.int 5\n"
" %int15 = torch.constant.int 15\n" " %int15 = torch.constant.int 15\n"
" %true = torch.constant.bool true\n" " %true = torch.constant.bool true\n"
" %int7 = torch.constant.int 7\n" " %int7 = torch.constant.int 7\n"
" %0 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.If %0 -> (!torch.bool) {\n" " %1 = torch.aten.eq.int %0#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %4 = torch.aten.eq.int %arg1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %4 : !torch.bool\n"
" }\n"
" %2 = torch.prim.If %1 -> (!torch.bool) {\n" " %2 = torch.prim.If %1 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n" " torch.prim.If.yield %true : !torch.bool\n"
" } else {\n" " } else {\n"
" %4 = torch.aten.eq.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" " %5 = torch.aten.eq.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %4 : !torch.bool\n" " torch.prim.If.yield %5 : !torch.bool\n"
" }\n" " }\n"
" %3 = torch.prim.If %2 -> (!torch.int) {\n" " %3 = torch.prim.If %2 -> (!torch.bool) {\n"
" torch.prim.If.yield %arg1 : !torch.int\n" " torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %5 = torch.aten.eq.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %5 : !torch.bool\n"
" }\n"
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
" torch.prim.If.yield %0#1 : !torch.int\n"
" } else {\n" " } else {\n"
" torch.prim.If.yield %int6 : !torch.int\n" " torch.prim.If.yield %int6 : !torch.int\n"
" }\n" " }\n"
" return %3 : !torch.int\n" " return %4 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
@ -6236,13 +6237,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<float, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
" %0 = torch.prim.ListConstruct %arg0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union<float, int>) -> !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %2 = torch.prim.ListConstruct %arg1, %1 : (!torch.int, !torch.int) -> !torch.list<int>\n" " %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %2) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" return %3 : !torch.int\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n" " }\n"
" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%arg0: !torch.list<optional<int>>, %arg1: !torch.list<int>) -> !torch.int {\n" " func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%arg0: !torch.list<optional<int>>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n" " %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
@ -6972,11 +6974,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list<int>\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n" " %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" return %2 : !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.atan2\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.atan2\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
@ -7162,6 +7166,36 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n" " %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._convolution.deprecated\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.int {\n"
" %int10 = torch.constant.int 10\n"
" %int9 = torch.constant.int 9\n"
" %int5 = torch.constant.int 5\n"
" %int11 = torch.constant.int 11\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.prim.ListConstruct %int11, %int5, %int9, %int10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = torch.aten.__contains__.int_list %3, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %6 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %7 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %8 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.flip\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.flip\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n" " return %arg0 : !torch.list<int>\n"
" }\n" " }\n"
@ -7639,7 +7673,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n" " return %arg0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.int, %arg4: !torch.optional<str>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.int {\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" " %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n"
" %int4 = torch.constant.int 4\n" " %int4 = torch.constant.int 4\n"
@ -7654,68 +7688,69 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %true = torch.constant.bool true\n" " %true = torch.constant.bool true\n"
" %int9 = torch.constant.int 9\n" " %int9 = torch.constant.int 9\n"
" %0 = torch.prim.Uninitialized : !torch.int\n" " %0 = torch.prim.Uninitialized : !torch.int\n"
" %1 = torch.aten.eq.int %arg1, %int9 : !torch.int, !torch.int -> !torch.bool\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.If %1 -> (!torch.bool) {\n" " %2 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n" " torch.prim.If.yield %true : !torch.bool\n"
" } else {\n" " } else {\n"
" %4 = torch.aten.eq.int %arg1, %int10 : !torch.int, !torch.int -> !torch.bool\n" " %5 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %4 : !torch.bool\n" " torch.prim.If.yield %5 : !torch.bool\n"
" }\n" " }\n"
" %3 = torch.prim.If %2 -> (!torch.int) {\n" " %4 = torch.prim.If %3 -> (!torch.int) {\n"
" torch.prim.If.yield %arg1 : !torch.int\n" " torch.prim.If.yield %1#1 : !torch.int\n"
" } else {\n" " } else {\n"
" %4 = torch.aten.eq.int %arg1, %int6 : !torch.int, !torch.int -> !torch.bool\n" " %5 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.int) {\n" " %6 = torch.prim.If %5 -> (!torch.int) {\n"
" torch.prim.If.yield %int9 : !torch.int\n" " torch.prim.If.yield %int9 : !torch.int\n"
" } else {\n" " } else {\n"
" %6 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n" " %7 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
" %7 = torch.prim.If %6 -> (!torch.int) {\n" " %8 = torch.prim.If %7 -> (!torch.int) {\n"
" torch.prim.If.yield %int10 : !torch.int\n" " torch.prim.If.yield %int10 : !torch.int\n"
" } else {\n" " } else {\n"
" %8 = torch.aten.eq.int %arg1, %int11 : !torch.int, !torch.int -> !torch.bool\n" " %9 = torch.aten.eq.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
" %9 = torch.prim.If %8 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %15 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %15 : !torch.bool\n"
" }\n"
" %10 = torch.prim.If %9 -> (!torch.bool) {\n" " %10 = torch.prim.If %9 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n" " torch.prim.If.yield %true : !torch.bool\n"
" } else {\n" " } else {\n"
" %15 = torch.aten.eq.int %arg1, %int1 : !torch.int, !torch.int -> !torch.bool\n" " %16 = torch.aten.eq.int %1#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %15 : !torch.bool\n" " torch.prim.If.yield %16 : !torch.bool\n"
" }\n" " }\n"
" %11 = torch.prim.If %10 -> (!torch.bool) {\n" " %11 = torch.prim.If %10 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n" " torch.prim.If.yield %true : !torch.bool\n"
" } else {\n" " } else {\n"
" %15 = torch.aten.eq.int %arg1, %int2 : !torch.int, !torch.int -> !torch.bool\n" " %16 = torch.aten.eq.int %1#1, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %15 : !torch.bool\n" " torch.prim.If.yield %16 : !torch.bool\n"
" }\n" " }\n"
" %12 = torch.prim.If %11 -> (!torch.bool) {\n" " %12 = torch.prim.If %11 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n" " torch.prim.If.yield %true : !torch.bool\n"
" } else {\n" " } else {\n"
" %15 = torch.aten.eq.int %arg1, %int3 : !torch.int, !torch.int -> !torch.bool\n" " %16 = torch.aten.eq.int %1#1, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %15 : !torch.bool\n" " torch.prim.If.yield %16 : !torch.bool\n"
" }\n" " }\n"
" %13 = torch.prim.If %12 -> (!torch.bool) {\n" " %13 = torch.prim.If %12 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n" " torch.prim.If.yield %true : !torch.bool\n"
" } else {\n" " } else {\n"
" %15 = torch.aten.eq.int %arg1, %int4 : !torch.int, !torch.int -> !torch.bool\n" " %16 = torch.aten.eq.int %1#1, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %15 : !torch.bool\n" " torch.prim.If.yield %16 : !torch.bool\n"
" }\n" " }\n"
" %14 = torch.prim.If %13 -> (!torch.int) {\n" " %14 = torch.prim.If %13 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %16 = torch.aten.eq.int %1#1, %int4 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %16 : !torch.bool\n"
" }\n"
" %15 = torch.prim.If %14 -> (!torch.int) {\n"
" torch.prim.If.yield %int9 : !torch.int\n" " torch.prim.If.yield %int9 : !torch.int\n"
" } else {\n" " } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield %0 : !torch.int\n" " torch.prim.If.yield %0 : !torch.int\n"
" }\n" " }\n"
" torch.prim.If.yield %14 : !torch.int\n" " torch.prim.If.yield %15 : !torch.int\n"
" }\n" " }\n"
" torch.prim.If.yield %7 : !torch.int\n" " torch.prim.If.yield %8 : !torch.int\n"
" }\n" " }\n"
" torch.prim.If.yield %5 : !torch.int\n" " torch.prim.If.yield %6 : !torch.int\n"
" }\n" " }\n"
" return %3 : !torch.int\n" " return %4 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.int) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n" " %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"

View File

@ -714,9 +714,8 @@ void TypeAnalysis::visitOperation(Operation *op,
// Promote the two dtypes assuming non-zero rank. // Promote the two dtypes assuming non-zero rank.
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp, if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenMvOp, Aten_ConvolutionOp, AtenMvOp, AtenConvolutionOverrideableOp,
AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp, AtenConvTranspose2dInputOp, AtenMseLossOp>(op)) {
AtenMseLossOp>(op)) {
auto knowledge = auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank( knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(

View File

@ -19,55 +19,25 @@ using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
static bool isTensorTypeOrWrappedTensorType(Type type) {
// Allowing tuples as arguments to dtype calculation functions can cause
// issues. For example, if an argument is a tuple of tensors and ints, there
// would be no way of differentiating the original ints from the ints created
// to represent the dtype and rank of the tensors. Therefore, to avoid this
// and keep things simple, the tuple type is not allowed. This works well in
// practice, since PyTorch op signatures don't seem to take tuples as inputs.
assert(!type.isa<Torch::TupleType>() &&
"dtype calculation functions are expected to not have tuples of "
"tensors as arguments");
if (type.isa<Torch::BaseTensorType>())
return true;
if (auto optionalType = type.dyn_cast<Torch::OptionalType>()) {
return isTensorTypeOrWrappedTensorType(optionalType.getContainedType());
} else if (auto listType = type.dyn_cast<Torch::ListType>()) {
return isTensorTypeOrWrappedTensorType(listType.getContainedType());
} else {
return false;
}
}
// Massage the op operands to match the dtype function signature. // Massage the op operands to match the dtype function signature.
// The dtype function generally takes the same operands as the op, with a few // The dtype function generally takes the same operands as the op, with a few
// systematic modifications, such as replacing tensors with a rank and dtype // systematic modifications, such as replacing each tensor with a tuple of
// argument. // its rank and dtype.
static FailureOr<SmallVector<Value>> static FailureOr<SmallVector<Value>>
dtypeFunctionArgsBuilder(OpBuilder &b, Location loc, dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
ValueRange originalOperands, func::FuncOp dtypeFunc) { ValueRange originalOperands, func::FuncOp dtypeFunc) {
// Turns a tensor operand into an operand representing the rank of the tensor // Turn every tensor into a tuple of (tensor_rank, tensor_dtype)
auto rankArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
Type desiredType) -> Value {
if (desiredType.isa<Torch::IntType>() &&
operand.getType().isa<Torch::BaseTensorType>()) {
auto sizeListType =
Torch::ListType::get(Torch::IntType::get(b.getContext()));
Value size = b.create<AtenSizeOp>(loc, sizeListType, operand);
return b.create<AtenLenTOp>(loc, desiredType, size);
}
return operand;
};
// Turns a tensor operand into an operand representing the dtype of the tensor
auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand, auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
Type desiredType) -> Value { Type desiredType) -> Value {
if (desiredType.isa<Torch::IntType>() && if (desiredType.isa<Torch::TupleType>() &&
operand.getType().isa<Torch::BaseTensorType>()) { operand.getType().isa<Torch::BaseTensorType>()) {
return b.create<PrimDtypeOp>(loc, desiredType, operand); Type intType = Torch::IntType::get(b.getContext());
Type sizeListType = Torch::ListType::get(intType);
Value size = b.create<AtenSizeOp>(loc, sizeListType, operand);
Value rank = b.create<AtenLenTOp>(loc, intType, size);
Value dtype = b.create<PrimDtypeOp>(loc, intType, operand);
return b.create<PrimTupleConstructOp>(loc, desiredType,
ArrayRef{rank, dtype});
} }
return operand; return operand;
}; };
@ -79,26 +49,11 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
"`dtypeFunc` should have at least one argument for each argument in " "`dtypeFunc` should have at least one argument for each argument in "
"`originalOperands`"); "`originalOperands`");
Type desiredType = desiredTypes.front(); Type desiredType = desiredTypes.front();
if (isTensorTypeOrWrappedTensorType(operand.getType())) { FailureOr<Value> otherArg;
assert(desiredTypes.size() >= 2 && if (failed(otherArg = adjustFunctionArg(b, loc, operand, desiredType,
"`dtypeFunc` should have two arguments for each tensor argument "
"in `originalOperands`");
FailureOr<Value> rankArg, dtypeArg;
if (failed(rankArg = adjustFunctionArg(b, loc, operand, desiredType,
rankArgAdjuster)))
return failure();
desiredTypes = desiredTypes.drop_front();
desiredType = desiredTypes.front();
if (failed(dtypeArg = adjustFunctionArg(b, loc, operand, desiredType,
dtypeArgAdjuster))) dtypeArgAdjuster)))
return failure(); return failure();
dtypeFuncArgs.append({*rankArg, *dtypeArg});
} else {
FailureOr<Value> otherArg;
if (failed(otherArg = adjustFunctionArg(b, loc, operand, desiredType)))
return failure();
dtypeFuncArgs.push_back(*otherArg); dtypeFuncArgs.push_back(*otherArg);
}
desiredTypes = desiredTypes.drop_front(); desiredTypes = desiredTypes.drop_front();
} }

View File

@ -89,7 +89,8 @@ def atenexpm1〡shape(self: List[int]) -> List[int]:
Invocation(ZeroDTensorWithDtype(torch.int32)), Invocation(ZeroDTensorWithDtype(torch.int32)),
Invocation(ZeroDTensorWithDtype(torch.bool)), Invocation(ZeroDTensorWithDtype(torch.bool)),
]) ])
def atenexpm1〡dtype(self_rank: int, self_dtype: int) -> int: def atenexpm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
if self_dtype == torch.float64 or self_dtype == torch.bfloat16 or self_dtype == torch.float16: if self_dtype == torch.float64 or self_dtype == torch.bfloat16 or self_dtype == torch.float16:
return self_dtype return self_dtype
else: else:
@ -280,7 +281,8 @@ def atenrsubScalar〡shape(self: List[int], other: float, alpha: float = 1
Invocation(ZeroDTensorWithDtype(torch.int64), other=0.0), Invocation(ZeroDTensorWithDtype(torch.int64), other=0.0),
Invocation(ZeroDTensorWithDtype(torch.float16), other=0.0) Invocation(ZeroDTensorWithDtype(torch.float16), other=0.0)
]) ])
def atenrsubScalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int: def atenrsubScalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int:
self_rank, self_dtype = self_rank_dtype
return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])
def atenleaky_relu〡shape(self: List[int], negative_slope: float = 0.01) -> List[int]: def atenleaky_relu〡shape(self: List[int], negative_slope: float = 0.01) -> List[int]:
@ -679,7 +681,9 @@ def atenfloor_divide〡shape(self: List[int], other: List[int]) -> List[int]:
Invocation(ZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.float64)), Invocation(ZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.float64)),
Invocation(NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)), Invocation(NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)),
]) ])
def atenfloor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: def atenfloor_divide〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
other_rank, other_dtype = other_rank_dtype
ranks: List[Optional[int]] = [self_rank, other_rank] ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype] dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes) return promote_dtypes(ranks, dtypes)
@ -819,6 +823,40 @@ def aten_convolution〡shape(input: List[int], weight: List[int], bias: Optio
def aten_convolutiondeprecated〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> List[int]: def aten_convolutiondeprecated〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> List[int]:
return atenconvolution〡shape(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) return atenconvolution〡shape(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
_convolution_deprecated_kwargs = {
"stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0],
"groups" : 1, "benchmark" : False, "deterministic" : False, "cudnn_enabled" : False}
@check_dtype_function(
[Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Same type
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.int32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Different type
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bfloat16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Different width
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bfloat16), TensorOfShape(1, 1, 1, 1, dtype=torch.int32), # Different type and width
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.complex64), TensorOfShape(1, 1, 1, 1, dtype=torch.float32),
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.complex128),
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32),
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool),
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32),
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16),
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs)
])
def aten_convolutiondeprecated〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> int:
input_rank, input_dtype = input_rank_dtype
weight_rank, weight_dtype = weight_rank_dtype
assert input_dtype == weight_dtype
assert input_dtype not in [torch.bool, torch.float16, torch.complex64, torch.complex128]
ranks: List[Optional[int]] = [input_rank, weight_rank]
dtypes = [input_dtype, weight_dtype]
return promote_dtypes(ranks, dtypes)
def atenflip〡shape(self: List[int], dims: List[int]) -> List[int]: def atenflip〡shape(self: List[int], dims: List[int]) -> List[int]:
return self return self
@ -1035,7 +1073,8 @@ def atenfft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int =
ErrorInvocation(NonZeroDTensorWithDtype(torch.float16)), ErrorInvocation(NonZeroDTensorWithDtype(torch.float16)),
ErrorInvocation(NonZeroDTensorWithDtype(torch.bfloat16)), ErrorInvocation(NonZeroDTensorWithDtype(torch.bfloat16)),
]) ])
def atenfft_fft〡dtype(self_rank: int, self_dtype: int, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: def atenfft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int:
self_rank, self_dtype = self_rank_dtype
if self_dtype == torch.complex64 or self_dtype == torch.complex128: if self_dtype == torch.complex64 or self_dtype == torch.complex128:
return self_dtype return self_dtype
elif self_dtype == torch.float: elif self_dtype == torch.float:

View File

@ -86,7 +86,7 @@ def _pytype_to_dtype_fn_pytype(pytype: str) -> str:
""" """
# Dtype functions only care about the rank and dtype of tensors. # Dtype functions only care about the rank and dtype of tensors.
if "Tensor" in pytype: if "Tensor" in pytype:
return pytype.replace("Tensor", "int") return pytype.replace("Tensor", "Tuple[int, int]")
return _pytype_to_fn_pytype_common(pytype) return _pytype_to_fn_pytype_common(pytype)
def _pytype_to_decomposition_fn_pytype(pytype: str) -> str: def _pytype_to_decomposition_fn_pytype(pytype: str) -> str:
@ -232,8 +232,7 @@ class JitOperator:
default = _get_default_value(arg) default = _get_default_value(arg)
parameter_name = _rename_python_keyword_parameter_name(arg["name"]) parameter_name = _rename_python_keyword_parameter_name(arg["name"])
if "Tensor" in arg["pytype"]: if "Tensor" in arg["pytype"]:
return ", ".join([f"{parameter_name}_rank: {pytype}{default}", return f"{parameter_name}_rank_dtype: {pytype}{default}"
f"{parameter_name}_dtype: {pytype}{default}"])
return f"{parameter_name}: {pytype}{default}" return f"{parameter_name}: {pytype}{default}"
def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str: def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
@ -241,7 +240,7 @@ class JitOperator:
# results of type `number`. Here we handle this case because # results of type `number`. Here we handle this case because
# `_pytype_to_dtype_fn_pytype` will replace `number` with # `_pytype_to_dtype_fn_pytype` will replace `number` with
# `Union[int, float]`. # `Union[int, float]`.
if arg["pytype"] == "number": if arg["pytype"] in ["number", "Tensor"]:
return "int" return "int"
return _pytype_to_dtype_fn_pytype(arg["pytype"]) return _pytype_to_dtype_fn_pytype(arg["pytype"])

View File

@ -96,36 +96,6 @@ def _recursively_transform_tensor_args(
return tuple(_recursively_transform_tensor_args(x, tensor_transformer) for x in o) return tuple(_recursively_transform_tensor_args(x, tensor_transformer) for x in o)
raise Exception(f"Unhandled type {type(o)}") raise Exception(f"Unhandled type {type(o)}")
def _convert_to_dtype_function_args(arguments: Iterable[Any]) -> List[Any]:
"""Converts an Invocation argument to a dtype function argument.
TensorOfShape is replaced with two ints representing the rank
and dtype of the tensor, respectively.
"""
def contains_tensor(o: Any) -> bool:
if o is None or isinstance(o, (float, int)):
return False
if isinstance(o, TensorOfShape):
return True
if isinstance(o, (list, tuple)):
for elem in o:
if contains_tensor(elem):
return True
return False
raise Exception(f"Unhandled type {type(o)}")
result = []
for arg in arguments:
if contains_tensor(arg):
rank_arg = _recursively_transform_tensor_args(
arg, lambda x: len(x.shape))
dtype_arg = _recursively_transform_tensor_args(
arg, lambda x: x.dtype)
result += [rank_arg, dtype_arg]
else:
result.append(arg)
return result
class Invocation: class Invocation:
"""Representation of a single op invocation (i.e. list of args to the op). """Representation of a single op invocation (i.e. list of args to the op).
@ -135,8 +105,8 @@ class Invocation:
Specifically, this class has special knowledge of `TensorOfShape` and Specifically, this class has special knowledge of `TensorOfShape` and
translates it appropriately to either a tensor (for the real op), a translates it appropriately to either a tensor (for the real op), a
`List[int]` for the shape function, and two `int`s representing `List[int]` for the shape function, and a tuple with two `int`s
the tensor rank and dtype in the case of a dtype function. representing the tensor rank and dtype in the case of a dtype function.
This class also tracks whether the invocation is expected to raise an This class also tracks whether the invocation is expected to raise an
exception for greater precision when interpreting errors raised during exception for greater precision when interpreting errors raised during
@ -170,7 +140,9 @@ class Invocation:
def to_dtype_function_args(self): def to_dtype_function_args(self):
"""Gets positional arguments appropriate for a dtype function.""" """Gets positional arguments appropriate for a dtype function."""
return _convert_to_dtype_function_args(self.args) tensor_transformer = lambda o: (len(o.shape), o.dtype)
return _recursively_transform_tensor_args(
self.args, tensor_transformer)
def to_real_op_args(self): def to_real_op_args(self):
"""Gets positional arguments appropriate for the real op.""" """Gets positional arguments appropriate for the real op."""

View File

@ -12,7 +12,8 @@
// CHECK: %[[SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor -> !torch.list<int> // CHECK: %[[SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[RANK:.*]] = torch.aten.len.t %[[SIZE]] : !torch.list<int> -> !torch.int // CHECK: %[[RANK:.*]] = torch.aten.len.t %[[SIZE]] : !torch.list<int> -> !torch.int
// CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG]] : !torch.vtensor -> !torch.int // CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG]] : !torch.vtensor -> !torch.int
// CHECK: %[[RESULT_DTYPE:.*]] = func.call @__torch_mlir_dtype_fn.aten.expm1(%[[RANK]], %[[DTYPE]]) : (!torch.int, !torch.int) -> !torch.int // CHECK: %[[RANK_DTYPE:.*]] = torch.prim.TupleConstruct %[[RANK]], %[[DTYPE]] : !torch.int, !torch.int -> !torch.tuple<int, int>
// CHECK: %[[RESULT_DTYPE:.*]] = func.call @__torch_mlir_dtype_fn.aten.expm1(%[[RANK_DTYPE]]) : (!torch.tuple<int, int>) -> !torch.int
// CHECK: torch.dtype.calculate.yield.dtypes %[[RESULT_DTYPE]] : !torch.int // CHECK: torch.dtype.calculate.yield.dtypes %[[RESULT_DTYPE]] : !torch.int
// CHECK: } : !torch.vtensor // CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor // CHECK: return %[[RESULT:.*]] : !torch.vtensor
@ -38,6 +39,21 @@ func.func @op_with_dtype_promotion(%arg0: !torch.vtensor, %arg1: !torch.vtensor)
// ----- // -----
// CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten._convolution.deprecated(
// CHECK-LABEL: func.func @op_with_optional_tensor_arg$none(
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[OPTIONAL_TUPLE:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tuple<int, int>>
// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten._convolution.deprecated({{.*}}, %[[OPTIONAL_TUPLE]], {{.*}}) : ({{.*}}, !torch.optional<tuple<int, int>>, {{.*}}) -> !torch.int
func.func @op_with_optional_tensor_arg$none(%input: !torch.vtensor, %weight: !torch.vtensor, %stride: !torch.list<int>, %padding: !torch.list<int>, %dilation: !torch.list<int>, %transposed: !torch.bool, %output_padding: !torch.list<int>, %groups: !torch.int) -> !torch.vtensor {
%bias_none = torch.constant.none
%false = torch.constant.bool false
%0 = torch.aten._convolution.deprecated %input, %weight, %bias_none, %stride, %padding, %dilation, %transposed, %output_padding, %groups, %false, %false, %false : !torch.vtensor, !torch.vtensor, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor
return %0 : !torch.vtensor
}
// -----
// CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten.floor_divide( // CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten.floor_divide(
// CHECK-LABEL: func.func @turn_tensors_into_rank_and_dtype_args( // CHECK-LABEL: func.func @turn_tensors_into_rank_and_dtype_args(
@ -46,10 +62,12 @@ func.func @op_with_dtype_promotion(%arg0: !torch.vtensor, %arg1: !torch.vtensor)
// CHECK: %[[SIZE0:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int> // CHECK: %[[SIZE0:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[RANK0:.*]] = torch.aten.len.t %[[SIZE0]] : !torch.list<int> -> !torch.int // CHECK: %[[RANK0:.*]] = torch.aten.len.t %[[SIZE0]] : !torch.list<int> -> !torch.int
// CHECK: %[[DTYPE0:.*]] = torch.prim.dtype %[[ARG0]] : !torch.vtensor -> !torch.int // CHECK: %[[DTYPE0:.*]] = torch.prim.dtype %[[ARG0]] : !torch.vtensor -> !torch.int
// CHECK: %[[RANK_DTYPE0:.*]] = torch.prim.TupleConstruct %[[RANK0]], %[[DTYPE0]] : !torch.int, !torch.int -> !torch.tuple<int, int>
// CHECK: %[[SIZE1:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list<int> // CHECK: %[[SIZE1:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[RANK1:.*]] = torch.aten.len.t %[[SIZE1]] : !torch.list<int> -> !torch.int // CHECK: %[[RANK1:.*]] = torch.aten.len.t %[[SIZE1]] : !torch.list<int> -> !torch.int
// CHECK: %[[DTYPE1:.*]] = torch.prim.dtype %[[ARG1]] : !torch.vtensor -> !torch.int // CHECK: %[[DTYPE1:.*]] = torch.prim.dtype %[[ARG1]] : !torch.vtensor -> !torch.int
// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.floor_divide(%[[RANK0]], %[[DTYPE0]], %[[RANK1]], %[[DTYPE1]]) : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.int // CHECK: %[[RANK_DTYPE1:.*]] = torch.prim.TupleConstruct %[[RANK1]], %[[DTYPE1]] : !torch.int, !torch.int -> !torch.tuple<int, int>
// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.floor_divide(%[[RANK_DTYPE0]], %[[RANK_DTYPE1]]) : (!torch.tuple<int, int>, !torch.tuple<int, int>) -> !torch.int
func.func @turn_tensors_into_rank_and_dtype_args(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor { func.func @turn_tensors_into_rank_and_dtype_args(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor %0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
return %0 : !torch.vtensor return %0 : !torch.vtensor