diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 530f3154d..72cdfd61b 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -363,10 +363,14 @@ public:
     Value hDimSize = inputShape[hDim];
     Value vDimSize = inputShape[vDim];
 
-    assert(getHPadArgument(LEFT) < hDimSize && "Left padding too large");
-    assert(getHPadArgument(RIGHT) < hDimSize && "Right padding too large");
-    assert(getVPadArgument(TOP) < vDimSize && "Top padding too large");
-    assert(getVPadArgument(BOTTOM) < vDimSize && "Bottom padding too large");
+    assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] &&
+           "Left padding too large");
+    assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] &&
+           "Right padding too large");
+    assert(getVPadArgument(TOP) < inputType.getShape()[vDim] &&
+           "Top padding too large");
+    assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] &&
+           "Bottom padding too large");
 
     Type indexType = rewriter.getIndexType();
     Value zero = getConstant(rewriter, loc, 0, indexType);
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 820e352c3..55b9638dd 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8368,12 +8368,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.reflection_pad2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %false = torch.constant.bool false\n"
+"    %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n"
 "    %int-1 = torch.constant.int -1\n"
 "    %int-2 = torch.constant.int -2\n"
 "    %none = torch.constant.none\n"
-"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %str_0 = torch.constant.str \"AssertionError: \"\n"
 "    %int2 = torch.constant.int 2\n"
 "    %int1 = torch.constant.int 1\n"
+"    %int4 = torch.constant.int 4\n"
 "    %int0 = torch.constant.int 0\n"
 "    %int3 = torch.constant.int 3\n"
 "    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
@@ -8381,43 +8383,51 @@
 "    torch.prim.If %1 -> () {\n"
 "      torch.prim.If.yield\n"
 "    } else {\n"
-"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
 "      torch.prim.If.yield\n"
 "    }\n"
 "    %2 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
 "    %3 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
-"    %4 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
-"    %5 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
-"    %6 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
-"    %7 = torch.aten.__getitem__.t %arg1, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
-"    %8 = torch.aten.lt.int %4, %3 : !torch.int, !torch.int -> !torch.bool\n"
-"    %9 = torch.prim.If %8 -> (!torch.bool) {\n"
-"      %13 = torch.aten.lt.int %5, %3 : !torch.int, !torch.int -> !torch.bool\n"
-"      torch.prim.If.yield %13 : !torch.bool\n"
-"    } else {\n"
-"      torch.prim.If.yield %false : !torch.bool\n"
-"    }\n"
-"    torch.prim.If %9 -> () {\n"
+"    %4 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"    %5 = torch.aten.eq.int %4, %int4 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %5 -> () {\n"
 "      torch.prim.If.yield\n"
 "    } else {\n"
 "      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
 "      torch.prim.If.yield\n"
 "    }\n"
-"    %10 = torch.aten.lt.int %6, %2 : !torch.int, !torch.int -> !torch.bool\n"
+"    %6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %7 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %8 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %9 = torch.aten.__getitem__.t %arg1, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %10 = torch.aten.lt.int %6, %3 : !torch.int, !torch.int -> !torch.bool\n"
 "    %11 = torch.prim.If %10 -> (!torch.bool) {\n"
-"      %13 = torch.aten.lt.int %7, %2 : !torch.int, !torch.int -> !torch.bool\n"
-"      torch.prim.If.yield %13 : !torch.bool\n"
+"      %15 = torch.aten.lt.int %7, %3 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %15 : !torch.bool\n"
 "    } else {\n"
 "      torch.prim.If.yield %false : !torch.bool\n"
 "    }\n"
 "    torch.prim.If %11 -> () {\n"
 "      torch.prim.If.yield\n"
 "    } else {\n"
-"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
 "      torch.prim.If.yield\n"
 "    }\n"
-"    %12 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
-"    return %12 : !torch.list<int>\n"
+"    %12 = torch.aten.lt.int %8, %2 : !torch.int, !torch.int -> !torch.bool\n"
+"    %13 = torch.prim.If %12 -> (!torch.bool) {\n"
+"      %15 = torch.aten.lt.int %9, %2 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %15 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %13 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %14 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %14 : !torch.list<int>\n"
 "  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.index.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>\n"
@@ -9056,18 +9066,7 @@
 "    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
-"    %none = torch.constant.none\n"
-"    %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n"
-"    %int4 = torch.constant.int 4\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-"    %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
-"    %2 = torch.aten.eq.int %1, %int4 : !torch.int, !torch.int -> !torch.bool\n"
-"    torch.prim.If %2 -> () {\n"
-"      torch.prim.If.yield\n"
-"    } else {\n"
-"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
-"      torch.prim.If.yield\n"
-"    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.contiguous\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index ae190ea20..a16d778c7 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1290,6 +1290,8 @@ def aten〇reflection_pad1d〡shape(self: List[int], padding: List[int]) -> List
 # Padding size must be smaller than corresponding dimension
 @check_shape_function([ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,2,1,1]),
                        ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1,1,1]),
+                       ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1,1,3]),
+                       ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1]),
                        Invocation(TensorOfShape(2, 2, 2), padding=[1,1,1,1]),
                        ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,1]),
                        ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,2])])
@@ -1853,13 +1855,7 @@ def aten〇reflection_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: L
     assert len(padding) == 2, 'padding size expected to be 2'
     return self_dtype
 
-@check_dtype_function([ErrorInvocation(TensorOfShape(2, 3, 4), padding=1),
-                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[]),
-                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2]),
-                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2,1]),
-                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[3,2,1]),
-                       Invocation(TensorOfShape(5, 5, 4), padding=[1,2,3,4]),
-                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[5,4,3,2,1])])
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4, 2, 2)], padding=[1,1,1,1]))
 def aten〇reflection_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
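Note on the new shape semantics: reconstructed from the generated
__torch_mlir_shape_fn.aten.reflection_pad2d IR above, the shape function now
asserts len(padding) == 4 before indexing into the padding list, then checks
each pad amount against the corresponding input dimension before deferring to
pad_shape_fn. The following is a minimal standalone Python sketch of that
logic, not the exact source in abstract_interp_lib_gen.py; the names
reflection_pad2d_shape, hdim, and vdim are illustrative, and pad_shape_fn is
inlined in simplified form so the snippet runs on its own.

    from typing import List

    def pad_shape_fn(input: List[int], padding: List[int]) -> List[int]:
        # Simplified stand-in for the helper called via @__torch__.pad_shape_fn:
        # padding pairs apply to the trailing dimensions, innermost first.
        assert len(padding) % 2 == 0
        assert len(padding) // 2 <= len(input)
        for i in range(len(padding) // 2):
            input[-(i + 1)] += padding[2 * i] + padding[2 * i + 1]
        return input

    def reflection_pad2d_shape(self: List[int], padding: List[int]) -> List[int]:
        assert len(self) >= 2
        vdim = self[-2]  # height
        hdim = self[-1]  # width
        # New check: reject malformed padding lists before indexing them.
        assert len(padding) == 4, 'padding size expected to be 4'
        # Reflection padding must be strictly smaller than the mirrored dim.
        assert padding[0] < hdim and padding[1] < hdim  # left, right
        assert padding[2] < vdim and padding[3] < vdim  # top, bottom
        return pad_shape_fn(self, padding)

    # padding=[2,1] now trips the length assert instead of failing on
    # padding[2] (cf. the new ErrorInvocation with padding=[2,1] above).
    # A valid call grows W by left+right and H by top+bottom:
    print(reflection_pad2d_shape([1, 3, 4, 5], [1, 2, 1, 1]))  # [1, 3, 6, 8]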