fixup! Review changes

pre_fixup_20240110
Frederik Harwath 2024-01-10 01:25:52 -08:00 committed by Frederik Harwath
parent 3f603470a9
commit de85278f1f
3 changed files with 41 additions and 42 deletions

View File

@ -363,10 +363,14 @@ public:
Value hDimSize = inputShape[hDim];
Value vDimSize = inputShape[vDim];
assert(getHPadArgument(LEFT) < hDimSize && "Left padding too large");
assert(getHPadArgument(RIGHT) < hDimSize && "Right padding too large");
assert(getVPadArgument(TOP) < vDimSize && "Top padding too large");
assert(getVPadArgument(BOTTOM) < vDimSize && "Bottom padding too large");
assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] &&
"Left padding too large");
assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] &&
"Right padding too large");
assert(getVPadArgument(TOP) < inputType.getShape()[vDim] &&
"Top padding too large");
assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] &&
"Bottom padding too large");
Type indexType = rewriter.getIndexType();
Value zero = getConstant(rewriter, loc, 0, indexType);

View File

@ -8368,12 +8368,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.reflection_pad2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n"
" %int-1 = torch.constant.int -1\n"
" %int-2 = torch.constant.int -2\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %str_0 = torch.constant.str \"AssertionError: \"\n"
" %int2 = torch.constant.int 2\n"
" %int1 = torch.constant.int 1\n"
" %int4 = torch.constant.int 4\n"
" %int0 = torch.constant.int 0\n"
" %int3 = torch.constant.int 3\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
@ -8381,43 +8383,51 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %3 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %4 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %5 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %6 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %7 = torch.aten.__getitem__.t %arg1, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %8 = torch.aten.lt.int %4, %3 : !torch.int, !torch.int -> !torch.bool\n"
" %9 = torch.prim.If %8 -> (!torch.bool) {\n"
" %13 = torch.aten.lt.int %5, %3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %13 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %9 -> () {\n"
" %4 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %5 = torch.aten.eq.int %4, %int4 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %10 = torch.aten.lt.int %6, %2 : !torch.int, !torch.int -> !torch.bool\n"
" %6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %7 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %8 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %9 = torch.aten.__getitem__.t %arg1, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %10 = torch.aten.lt.int %6, %3 : !torch.int, !torch.int -> !torch.bool\n"
" %11 = torch.prim.If %10 -> (!torch.bool) {\n"
" %13 = torch.aten.lt.int %7, %2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %13 : !torch.bool\n"
" %15 = torch.aten.lt.int %7, %3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %15 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %11 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %12 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %12 : !torch.list<int>\n"
" %12 = torch.aten.lt.int %8, %2 : !torch.int, !torch.int -> !torch.bool\n"
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
" %15 = torch.aten.lt.int %9, %2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %15 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %13 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %14 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %14 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.index.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {\n"
" %0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>\n"
@ -9056,18 +9066,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n"
" %int4 = torch.constant.int 4\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %2 = torch.aten.eq.int %1, %int4 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.contiguous\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"

View File

@ -1290,6 +1290,8 @@ def atenreflection_pad1d〡shape(self: List[int], padding: List[int]) -> List
# Padding size must be smaller than corresponding dimension
@check_shape_function([ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,2,1,1]),
ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1,1,1]),
ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1,1,3]),
ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1]),
Invocation(TensorOfShape(2, 2, 2), padding=[1,1,1,1]),
ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,1]),
ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,2])])
@ -1853,13 +1855,7 @@ def atenreflection_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: L
assert len(padding) == 2, 'padding size expected to be 2'
return self_dtype
@check_dtype_function([ErrorInvocation(TensorOfShape(2, 3, 4), padding=1),
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[]),
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2]),
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2,1]),
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[3,2,1]),
Invocation(TensorOfShape(5, 5, 4), padding=[1,2,3,4]),
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[5,4,3,2,1])])
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4, 2, 2)], padding=[1,1,1,1]))
def atenreflection_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype