mirror of https://github.com/llvm/torch-mlir
Change dtype functions interface to take ints tuple for each tensor (#1965)
The original design for the dtype functions outlined in https://github.com/llvm/torch-mlir/issues/1462 was unable to properly handle ops that take optional tensors as an input when the optional tensor has a value of None. By the time the op gets imported into torch-mlir, if an optional value is None, all information about the original type is lost from the op type signature, preventing torch-mlir from knowing if a value of None was from an optional tensor or not, which was crucial in the original design since each tensor argument must be turned into two separate arguments for the dtype function. This commit changes the interface to dtype functions such that each tensor turns into a tuple of two ints, the first representing the rank of the tensor and the second the dtype of the tensor. Since now there is a one-to-one correspondence between the operands of an op and the operands of its dtype function, there is no ambiguity about which operand of the op corresponds with which operand of the dtype function. To test the implementation, this commit defines dtype function for convolution op, which takes one optional tensor as an argument.pull/1984/head
parent
f2a05f2dc0
commit
eae3ff7f1c
|
@ -21,7 +21,7 @@ We will use the example of adding support for the `torch.aten.tanh` op.
|
|||
function signatures are:
|
||||
|
||||
- `def aten〇tanh〡shape(self: List[int]) -> List[int]:`
|
||||
- `def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:`
|
||||
- `def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:`
|
||||
|
||||
Note the use of `〇` as a separator since `.` or `::` aren't legal
|
||||
in a Python identifier.
|
||||
|
|
|
@ -5979,31 +5979,32 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %int5 = torch.constant.int 5\n"
|
||||
" %int15 = torch.constant.int 15\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %int7 = torch.constant.int 7\n"
|
||||
" %0 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %4 = torch.aten.eq.int %arg1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %4 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.aten.eq.int %0#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %4 = torch.aten.eq.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %4 : !torch.bool\n"
|
||||
" %5 = torch.aten.eq.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %5 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %3 = torch.prim.If %2 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %arg1 : !torch.int\n"
|
||||
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %5 = torch.aten.eq.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %5 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %0#1 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %3 : !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
|
@ -6236,13 +6237,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<float, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.prim.ListConstruct %arg0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
||||
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %2 = torch.prim.ListConstruct %arg1, %1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %2) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %3 : !torch.int\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%arg0: !torch.list<optional<int>>, %arg1: !torch.list<int>) -> !torch.int {\n"
|
||||
" %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
|
@ -6972,11 +6974,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
|
||||
" %0 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
|
||||
" %1 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %2 : !torch.int\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
|
||||
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.atan2\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||
|
@ -7162,6 +7166,36 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten._convolution.deprecated\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.int {\n"
|
||||
" %int10 = torch.constant.int 10\n"
|
||||
" %int9 = torch.constant.int 9\n"
|
||||
" %int5 = torch.constant.int 5\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %3 = torch.prim.ListConstruct %int11, %int5, %int9, %int10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %4 = torch.aten.__contains__.int_list %3, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
|
||||
" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %5 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %6 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
|
||||
" %7 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %8 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.flip\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
|
@ -7639,7 +7673,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.int, %arg4: !torch.optional<str>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
|
@ -7654,68 +7688,69 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %true = torch.constant.bool true\n"
|
||||
" %int9 = torch.constant.int 9\n"
|
||||
" %0 = torch.prim.Uninitialized : !torch.int\n"
|
||||
" %1 = torch.aten.eq.int %arg1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %2 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %4 = torch.aten.eq.int %arg1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %4 : !torch.bool\n"
|
||||
" %5 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %5 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %3 = torch.prim.If %2 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %arg1 : !torch.int\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %1#1 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %4 = torch.aten.eq.int %arg1, %int6 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %5 = torch.prim.If %4 -> (!torch.int) {\n"
|
||||
" %5 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %6 = torch.prim.If %5 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int9 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %6 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %7 = torch.prim.If %6 -> (!torch.int) {\n"
|
||||
" %7 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %8 = torch.prim.If %7 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int10 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %8 = torch.aten.eq.int %arg1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %9 = torch.prim.If %8 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %15 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %15 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %9 = torch.aten.eq.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %10 = torch.prim.If %9 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %15 = torch.aten.eq.int %arg1, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %15 : !torch.bool\n"
|
||||
" %16 = torch.aten.eq.int %1#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %16 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %11 = torch.prim.If %10 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %15 = torch.aten.eq.int %arg1, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %15 : !torch.bool\n"
|
||||
" %16 = torch.aten.eq.int %1#1, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %16 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %12 = torch.prim.If %11 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %15 = torch.aten.eq.int %arg1, %int3 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %15 : !torch.bool\n"
|
||||
" %16 = torch.aten.eq.int %1#1, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %16 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %15 = torch.aten.eq.int %arg1, %int4 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %15 : !torch.bool\n"
|
||||
" %16 = torch.aten.eq.int %1#1, %int3 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %16 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %14 = torch.prim.If %13 -> (!torch.int) {\n"
|
||||
" %14 = torch.prim.If %13 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %16 = torch.aten.eq.int %1#1, %int4 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %16 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %15 = torch.prim.If %14 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int9 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %14 : !torch.int\n"
|
||||
" torch.prim.If.yield %15 : !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %7 : !torch.int\n"
|
||||
" torch.prim.If.yield %8 : !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %5 : !torch.int\n"
|
||||
" torch.prim.If.yield %6 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %3 : !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.int) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
|
||||
|
|
|
@ -714,9 +714,8 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
|
||||
// Promote the two dtypes assuming non-zero rank.
|
||||
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
|
||||
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenMvOp,
|
||||
AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp,
|
||||
AtenMseLossOp>(op)) {
|
||||
Aten_ConvolutionOp, AtenMvOp, AtenConvolutionOverrideableOp,
|
||||
AtenConvTranspose2dInputOp, AtenMseLossOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
|
||||
|
|
|
@ -19,55 +19,25 @@ using namespace mlir;
|
|||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
static bool isTensorTypeOrWrappedTensorType(Type type) {
|
||||
// Allowing tuples as arguments to dtype calculation functions can cause
|
||||
// issues. For example, if an argument is a tuple of tensors and ints, there
|
||||
// would be no way of differentiating the original ints from the ints created
|
||||
// to represent the dtype and rank of the tensors. Therefore, to avoid this
|
||||
// and keep things simple, the tuple type is not allowed. This works well in
|
||||
// practice, since PyTorch op signatures don't seem to take tuples as inputs.
|
||||
assert(!type.isa<Torch::TupleType>() &&
|
||||
"dtype calculation functions are expected to not have tuples of "
|
||||
"tensors as arguments");
|
||||
|
||||
if (type.isa<Torch::BaseTensorType>())
|
||||
return true;
|
||||
|
||||
if (auto optionalType = type.dyn_cast<Torch::OptionalType>()) {
|
||||
return isTensorTypeOrWrappedTensorType(optionalType.getContainedType());
|
||||
} else if (auto listType = type.dyn_cast<Torch::ListType>()) {
|
||||
return isTensorTypeOrWrappedTensorType(listType.getContainedType());
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Massage the op operands to match the dtype function signature.
|
||||
// The dtype function generally takes the same operands as the op, with a few
|
||||
// systematic modifications, such as replacing tensors with a rank and dtype
|
||||
// argument.
|
||||
// systematic modifications, such as replacing each tensor with a tuple of
|
||||
// its rank and dtype.
|
||||
static FailureOr<SmallVector<Value>>
|
||||
dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
|
||||
ValueRange originalOperands, func::FuncOp dtypeFunc) {
|
||||
// Turns a tensor operand into an operand representing the rank of the tensor
|
||||
auto rankArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
|
||||
Type desiredType) -> Value {
|
||||
if (desiredType.isa<Torch::IntType>() &&
|
||||
operand.getType().isa<Torch::BaseTensorType>()) {
|
||||
auto sizeListType =
|
||||
Torch::ListType::get(Torch::IntType::get(b.getContext()));
|
||||
Value size = b.create<AtenSizeOp>(loc, sizeListType, operand);
|
||||
return b.create<AtenLenTOp>(loc, desiredType, size);
|
||||
}
|
||||
return operand;
|
||||
};
|
||||
|
||||
// Turns a tensor operand into an operand representing the dtype of the tensor
|
||||
// Turn every tensor into a tuple of (tensor_rank, tensor_dtype)
|
||||
auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
|
||||
Type desiredType) -> Value {
|
||||
if (desiredType.isa<Torch::IntType>() &&
|
||||
if (desiredType.isa<Torch::TupleType>() &&
|
||||
operand.getType().isa<Torch::BaseTensorType>()) {
|
||||
return b.create<PrimDtypeOp>(loc, desiredType, operand);
|
||||
Type intType = Torch::IntType::get(b.getContext());
|
||||
Type sizeListType = Torch::ListType::get(intType);
|
||||
Value size = b.create<AtenSizeOp>(loc, sizeListType, operand);
|
||||
Value rank = b.create<AtenLenTOp>(loc, intType, size);
|
||||
Value dtype = b.create<PrimDtypeOp>(loc, intType, operand);
|
||||
return b.create<PrimTupleConstructOp>(loc, desiredType,
|
||||
ArrayRef{rank, dtype});
|
||||
}
|
||||
return operand;
|
||||
};
|
||||
|
@ -79,26 +49,11 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
|
|||
"`dtypeFunc` should have at least one argument for each argument in "
|
||||
"`originalOperands`");
|
||||
Type desiredType = desiredTypes.front();
|
||||
if (isTensorTypeOrWrappedTensorType(operand.getType())) {
|
||||
assert(desiredTypes.size() >= 2 &&
|
||||
"`dtypeFunc` should have two arguments for each tensor argument "
|
||||
"in `originalOperands`");
|
||||
FailureOr<Value> rankArg, dtypeArg;
|
||||
if (failed(rankArg = adjustFunctionArg(b, loc, operand, desiredType,
|
||||
rankArgAdjuster)))
|
||||
return failure();
|
||||
desiredTypes = desiredTypes.drop_front();
|
||||
desiredType = desiredTypes.front();
|
||||
if (failed(dtypeArg = adjustFunctionArg(b, loc, operand, desiredType,
|
||||
dtypeArgAdjuster)))
|
||||
return failure();
|
||||
dtypeFuncArgs.append({*rankArg, *dtypeArg});
|
||||
} else {
|
||||
FailureOr<Value> otherArg;
|
||||
if (failed(otherArg = adjustFunctionArg(b, loc, operand, desiredType)))
|
||||
return failure();
|
||||
dtypeFuncArgs.push_back(*otherArg);
|
||||
}
|
||||
FailureOr<Value> otherArg;
|
||||
if (failed(otherArg = adjustFunctionArg(b, loc, operand, desiredType,
|
||||
dtypeArgAdjuster)))
|
||||
return failure();
|
||||
dtypeFuncArgs.push_back(*otherArg);
|
||||
desiredTypes = desiredTypes.drop_front();
|
||||
}
|
||||
|
||||
|
|
|
@ -89,7 +89,8 @@ def aten〇expm1〡shape(self: List[int]) -> List[int]:
|
|||
Invocation(ZeroDTensorWithDtype(torch.int32)),
|
||||
Invocation(ZeroDTensorWithDtype(torch.bool)),
|
||||
])
|
||||
def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||
def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
if self_dtype == torch.float64 or self_dtype == torch.bfloat16 or self_dtype == torch.float16:
|
||||
return self_dtype
|
||||
else:
|
||||
|
@ -280,7 +281,8 @@ def aten〇rsub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1
|
|||
Invocation(ZeroDTensorWithDtype(torch.int64), other=0.0),
|
||||
Invocation(ZeroDTensorWithDtype(torch.float16), other=0.0)
|
||||
])
|
||||
def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int:
|
||||
def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])
|
||||
|
||||
def aten〇leaky_relu〡shape(self: List[int], negative_slope: float = 0.01) -> List[int]:
|
||||
|
@ -679,7 +681,9 @@ def aten〇floor_divide〡shape(self: List[int], other: List[int]) -> List[int]:
|
|||
Invocation(ZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.float64)),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)),
|
||||
])
|
||||
def aten〇floor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
|
||||
def aten〇floor_divide〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
other_rank, other_dtype = other_rank_dtype
|
||||
ranks: List[Optional[int]] = [self_rank, other_rank]
|
||||
dtypes = [self_dtype, other_dtype]
|
||||
return promote_dtypes(ranks, dtypes)
|
||||
|
@ -819,6 +823,40 @@ def aten〇_convolution〡shape(input: List[int], weight: List[int], bias: Optio
|
|||
def aten〇_convolution〇deprecated〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> List[int]:
|
||||
return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
|
||||
|
||||
_convolution_deprecated_kwargs = {
|
||||
"stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0],
|
||||
"groups" : 1, "benchmark" : False, "deterministic" : False, "cudnn_enabled" : False}
|
||||
@check_dtype_function(
|
||||
[Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Same type
|
||||
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
|
||||
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.int32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Different type
|
||||
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
|
||||
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bfloat16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Different width
|
||||
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
|
||||
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bfloat16), TensorOfShape(1, 1, 1, 1, dtype=torch.int32), # Different type and width
|
||||
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
|
||||
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.complex64), TensorOfShape(1, 1, 1, 1, dtype=torch.float32),
|
||||
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
|
||||
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.complex128),
|
||||
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
|
||||
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32),
|
||||
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
|
||||
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool),
|
||||
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
|
||||
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32),
|
||||
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
|
||||
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16),
|
||||
TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs)
|
||||
])
|
||||
def aten〇_convolution〇deprecated〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> int:
|
||||
input_rank, input_dtype = input_rank_dtype
|
||||
weight_rank, weight_dtype = weight_rank_dtype
|
||||
assert input_dtype == weight_dtype
|
||||
assert input_dtype not in [torch.bool, torch.float16, torch.complex64, torch.complex128]
|
||||
ranks: List[Optional[int]] = [input_rank, weight_rank]
|
||||
dtypes = [input_dtype, weight_dtype]
|
||||
return promote_dtypes(ranks, dtypes)
|
||||
|
||||
def aten〇flip〡shape(self: List[int], dims: List[int]) -> List[int]:
|
||||
return self
|
||||
|
||||
|
@ -1035,7 +1073,8 @@ def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int =
|
|||
ErrorInvocation(NonZeroDTensorWithDtype(torch.float16)),
|
||||
ErrorInvocation(NonZeroDTensorWithDtype(torch.bfloat16)),
|
||||
])
|
||||
def aten〇fft_fft〡dtype(self_rank: int, self_dtype: int, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int:
|
||||
def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
if self_dtype == torch.complex64 or self_dtype == torch.complex128:
|
||||
return self_dtype
|
||||
elif self_dtype == torch.float:
|
||||
|
|
|
@ -86,7 +86,7 @@ def _pytype_to_dtype_fn_pytype(pytype: str) -> str:
|
|||
"""
|
||||
# Dtype functions only care about the rank and dtype of tensors.
|
||||
if "Tensor" in pytype:
|
||||
return pytype.replace("Tensor", "int")
|
||||
return pytype.replace("Tensor", "Tuple[int, int]")
|
||||
return _pytype_to_fn_pytype_common(pytype)
|
||||
|
||||
def _pytype_to_decomposition_fn_pytype(pytype: str) -> str:
|
||||
|
@ -232,8 +232,7 @@ class JitOperator:
|
|||
default = _get_default_value(arg)
|
||||
parameter_name = _rename_python_keyword_parameter_name(arg["name"])
|
||||
if "Tensor" in arg["pytype"]:
|
||||
return ", ".join([f"{parameter_name}_rank: {pytype}{default}",
|
||||
f"{parameter_name}_dtype: {pytype}{default}"])
|
||||
return f"{parameter_name}_rank_dtype: {pytype}{default}"
|
||||
return f"{parameter_name}: {pytype}{default}"
|
||||
|
||||
def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||
|
@ -241,7 +240,7 @@ class JitOperator:
|
|||
# results of type `number`. Here we handle this case because
|
||||
# `_pytype_to_dtype_fn_pytype` will replace `number` with
|
||||
# `Union[int, float]`.
|
||||
if arg["pytype"] == "number":
|
||||
if arg["pytype"] in ["number", "Tensor"]:
|
||||
return "int"
|
||||
return _pytype_to_dtype_fn_pytype(arg["pytype"])
|
||||
|
||||
|
|
|
@ -96,36 +96,6 @@ def _recursively_transform_tensor_args(
|
|||
return tuple(_recursively_transform_tensor_args(x, tensor_transformer) for x in o)
|
||||
raise Exception(f"Unhandled type {type(o)}")
|
||||
|
||||
def _convert_to_dtype_function_args(arguments: Iterable[Any]) -> List[Any]:
|
||||
"""Converts an Invocation argument to a dtype function argument.
|
||||
|
||||
TensorOfShape is replaced with two ints representing the rank
|
||||
and dtype of the tensor, respectively.
|
||||
"""
|
||||
def contains_tensor(o: Any) -> bool:
|
||||
if o is None or isinstance(o, (float, int)):
|
||||
return False
|
||||
if isinstance(o, TensorOfShape):
|
||||
return True
|
||||
if isinstance(o, (list, tuple)):
|
||||
for elem in o:
|
||||
if contains_tensor(elem):
|
||||
return True
|
||||
return False
|
||||
raise Exception(f"Unhandled type {type(o)}")
|
||||
|
||||
result = []
|
||||
for arg in arguments:
|
||||
if contains_tensor(arg):
|
||||
rank_arg = _recursively_transform_tensor_args(
|
||||
arg, lambda x: len(x.shape))
|
||||
dtype_arg = _recursively_transform_tensor_args(
|
||||
arg, lambda x: x.dtype)
|
||||
result += [rank_arg, dtype_arg]
|
||||
else:
|
||||
result.append(arg)
|
||||
return result
|
||||
|
||||
class Invocation:
|
||||
"""Representation of a single op invocation (i.e. list of args to the op).
|
||||
|
||||
|
@ -135,8 +105,8 @@ class Invocation:
|
|||
|
||||
Specifically, this class has special knowledge of `TensorOfShape` and
|
||||
translates it appropriately to either a tensor (for the real op), a
|
||||
`List[int]` for the shape function, and two `int`s representing
|
||||
the tensor rank and dtype in the case of a dtype function.
|
||||
`List[int]` for the shape function, and a tuple with two `int`s
|
||||
representing the tensor rank and dtype in the case of a dtype function.
|
||||
|
||||
This class also tracks whether the invocation is expected to raise an
|
||||
exception for greater precision when interpreting errors raised during
|
||||
|
@ -170,7 +140,9 @@ class Invocation:
|
|||
|
||||
def to_dtype_function_args(self):
|
||||
"""Gets positional arguments appropriate for a dtype function."""
|
||||
return _convert_to_dtype_function_args(self.args)
|
||||
tensor_transformer = lambda o: (len(o.shape), o.dtype)
|
||||
return _recursively_transform_tensor_args(
|
||||
self.args, tensor_transformer)
|
||||
|
||||
def to_real_op_args(self):
|
||||
"""Gets positional arguments appropriate for the real op."""
|
||||
|
|
|
@ -12,7 +12,8 @@
|
|||
// CHECK: %[[SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor -> !torch.list<int>
|
||||
// CHECK: %[[RANK:.*]] = torch.aten.len.t %[[SIZE]] : !torch.list<int> -> !torch.int
|
||||
// CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG]] : !torch.vtensor -> !torch.int
|
||||
// CHECK: %[[RESULT_DTYPE:.*]] = func.call @__torch_mlir_dtype_fn.aten.expm1(%[[RANK]], %[[DTYPE]]) : (!torch.int, !torch.int) -> !torch.int
|
||||
// CHECK: %[[RANK_DTYPE:.*]] = torch.prim.TupleConstruct %[[RANK]], %[[DTYPE]] : !torch.int, !torch.int -> !torch.tuple<int, int>
|
||||
// CHECK: %[[RESULT_DTYPE:.*]] = func.call @__torch_mlir_dtype_fn.aten.expm1(%[[RANK_DTYPE]]) : (!torch.tuple<int, int>) -> !torch.int
|
||||
// CHECK: torch.dtype.calculate.yield.dtypes %[[RESULT_DTYPE]] : !torch.int
|
||||
// CHECK: } : !torch.vtensor
|
||||
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
|
||||
|
@ -38,6 +39,21 @@ func.func @op_with_dtype_promotion(%arg0: !torch.vtensor, %arg1: !torch.vtensor)
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten._convolution.deprecated(
|
||||
|
||||
// CHECK-LABEL: func.func @op_with_optional_tensor_arg$none(
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[OPTIONAL_TUPLE:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tuple<int, int>>
|
||||
// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten._convolution.deprecated({{.*}}, %[[OPTIONAL_TUPLE]], {{.*}}) : ({{.*}}, !torch.optional<tuple<int, int>>, {{.*}}) -> !torch.int
|
||||
func.func @op_with_optional_tensor_arg$none(%input: !torch.vtensor, %weight: !torch.vtensor, %stride: !torch.list<int>, %padding: !torch.list<int>, %dilation: !torch.list<int>, %transposed: !torch.bool, %output_padding: !torch.list<int>, %groups: !torch.int) -> !torch.vtensor {
|
||||
%bias_none = torch.constant.none
|
||||
%false = torch.constant.bool false
|
||||
%0 = torch.aten._convolution.deprecated %input, %weight, %bias_none, %stride, %padding, %dilation, %transposed, %output_padding, %groups, %false, %false, %false : !torch.vtensor, !torch.vtensor, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor
|
||||
return %0 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten.floor_divide(
|
||||
|
||||
// CHECK-LABEL: func.func @turn_tensors_into_rank_and_dtype_args(
|
||||
|
@ -46,10 +62,12 @@ func.func @op_with_dtype_promotion(%arg0: !torch.vtensor, %arg1: !torch.vtensor)
|
|||
// CHECK: %[[SIZE0:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
|
||||
// CHECK: %[[RANK0:.*]] = torch.aten.len.t %[[SIZE0]] : !torch.list<int> -> !torch.int
|
||||
// CHECK: %[[DTYPE0:.*]] = torch.prim.dtype %[[ARG0]] : !torch.vtensor -> !torch.int
|
||||
// CHECK: %[[RANK_DTYPE0:.*]] = torch.prim.TupleConstruct %[[RANK0]], %[[DTYPE0]] : !torch.int, !torch.int -> !torch.tuple<int, int>
|
||||
// CHECK: %[[SIZE1:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list<int>
|
||||
// CHECK: %[[RANK1:.*]] = torch.aten.len.t %[[SIZE1]] : !torch.list<int> -> !torch.int
|
||||
// CHECK: %[[DTYPE1:.*]] = torch.prim.dtype %[[ARG1]] : !torch.vtensor -> !torch.int
|
||||
// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.floor_divide(%[[RANK0]], %[[DTYPE0]], %[[RANK1]], %[[DTYPE1]]) : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.int
|
||||
// CHECK: %[[RANK_DTYPE1:.*]] = torch.prim.TupleConstruct %[[RANK1]], %[[DTYPE1]] : !torch.int, !torch.int -> !torch.tuple<int, int>
|
||||
// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.floor_divide(%[[RANK_DTYPE0]], %[[RANK_DTYPE1]]) : (!torch.tuple<int, int>, !torch.tuple<int, int>) -> !torch.int
|
||||
func.func @turn_tensors_into_rank_and_dtype_args(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
|
||||
%0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
|
||||
return %0 : !torch.vtensor
|
||||
|
|
Loading…
Reference in New Issue