mirror of https://github.com/llvm/torch-mlir
fixup! Review changes
parent 3f603470a9
commit de85278f1f
@@ -363,10 +363,14 @@ public:
     Value hDimSize = inputShape[hDim];
     Value vDimSize = inputShape[vDim];
 
-    assert(getHPadArgument(LEFT) < hDimSize && "Left padding too large");
-    assert(getHPadArgument(RIGHT) < hDimSize && "Right padding too large");
-    assert(getVPadArgument(TOP) < vDimSize && "Top padding too large");
-    assert(getVPadArgument(BOTTOM) < vDimSize && "Bottom padding too large");
+    assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] &&
+           "Left padding too large");
+    assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] &&
+           "Right padding too large");
+    assert(getVPadArgument(TOP) < inputType.getShape()[vDim] &&
+           "Top padding too large");
+    assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] &&
+           "Bottom padding too large");
 
     Type indexType = rewriter.getIndexType();
     Value zero = getConstant(rewriter, loc, 0, indexType);
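A note on the hunk above: hDimSize and vDimSize are SSA Values, only known at
runtime, so the rewritten asserts bound the padding by the static dimension
sizes carried by inputType instead. The rule they encode is the standard
reflection-padding constraint: each pad amount must be strictly smaller than
the dimension it mirrors. A quick eager-mode illustration (plain PyTorch, not
part of the patch):

    import torch
    import torch.nn.functional as F

    x = torch.rand(2, 2, 2)

    # Reflection padding mirrors interior elements, so each pad amount must
    # be strictly less than the size of the dimension it extends.
    print(F.pad(x, [1, 1, 1, 1], mode="reflect").shape)  # torch.Size([2, 4, 4])

    try:
        F.pad(x, [2, 1, 1, 1], mode="reflect")  # left pad 2 >= dim size 2
    except RuntimeError as err:
        print("rejected:", err)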
@@ -8368,12 +8368,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.reflection_pad2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %false = torch.constant.bool false\n"
+"    %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n"
 "    %int-1 = torch.constant.int -1\n"
 "    %int-2 = torch.constant.int -2\n"
 "    %none = torch.constant.none\n"
-"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %str_0 = torch.constant.str \"AssertionError: \"\n"
 "    %int2 = torch.constant.int 2\n"
 "    %int1 = torch.constant.int 1\n"
+"    %int4 = torch.constant.int 4\n"
 "    %int0 = torch.constant.int 0\n"
 "    %int3 = torch.constant.int 3\n"
 "    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
@@ -8381,43 +8383,51 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    torch.prim.If %1 -> () {\n"
 "      torch.prim.If.yield\n"
 "    } else {\n"
-"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
 "      torch.prim.If.yield\n"
 "    }\n"
 "    %2 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
 "    %3 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
-"    %4 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
-"    %5 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
-"    %6 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
-"    %7 = torch.aten.__getitem__.t %arg1, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
-"    %8 = torch.aten.lt.int %4, %3 : !torch.int, !torch.int -> !torch.bool\n"
-"    %9 = torch.prim.If %8 -> (!torch.bool) {\n"
-"      %13 = torch.aten.lt.int %5, %3 : !torch.int, !torch.int -> !torch.bool\n"
-"      torch.prim.If.yield %13 : !torch.bool\n"
-"    } else {\n"
-"      torch.prim.If.yield %false : !torch.bool\n"
-"    }\n"
-"    torch.prim.If %9 -> () {\n"
+"    %4 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"    %5 = torch.aten.eq.int %4, %int4 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %5 -> () {\n"
 "      torch.prim.If.yield\n"
 "    } else {\n"
 "      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
 "      torch.prim.If.yield\n"
 "    }\n"
-"    %10 = torch.aten.lt.int %6, %2 : !torch.int, !torch.int -> !torch.bool\n"
+"    %6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %7 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %8 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %9 = torch.aten.__getitem__.t %arg1, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %10 = torch.aten.lt.int %6, %3 : !torch.int, !torch.int -> !torch.bool\n"
 "    %11 = torch.prim.If %10 -> (!torch.bool) {\n"
-"      %13 = torch.aten.lt.int %7, %2 : !torch.int, !torch.int -> !torch.bool\n"
-"      torch.prim.If.yield %13 : !torch.bool\n"
+"      %15 = torch.aten.lt.int %7, %3 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %15 : !torch.bool\n"
 "    } else {\n"
 "      torch.prim.If.yield %false : !torch.bool\n"
 "    }\n"
 "    torch.prim.If %11 -> () {\n"
 "      torch.prim.If.yield\n"
 "    } else {\n"
-"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
 "      torch.prim.If.yield\n"
 "    }\n"
-"    %12 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
-"    return %12 : !torch.list<int>\n"
+"    %12 = torch.aten.lt.int %8, %2 : !torch.int, !torch.int -> !torch.bool\n"
+"    %13 = torch.prim.If %12 -> (!torch.bool) {\n"
+"      %15 = torch.aten.lt.int %9, %2 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %15 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %13 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %14 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %14 : !torch.list<int>\n"
 "  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.index.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>\n"
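For readability, the new shape function above corresponds to roughly the
following Python. This is a reconstruction from the generated MLIR, not the
upstream source: the names vdim/hdim are illustrative, and pad_shape_fn is an
inline stand-in for the shared helper called at the end.

    from typing import List

    def pad_shape_fn(input: List[int], pad: List[int]) -> List[int]:
        # Stand-in for the shared helper: each (low, high) pad pair widens
        # one trailing dimension, innermost first.
        for i in range(len(pad) // 2):
            input[-(i + 1)] += pad[2 * i] + pad[2 * i + 1]
        return input

    def reflection_pad2d_shape(self: List[int], padding: List[int]) -> List[int]:
        vdim, hdim = self[-2], self[-1]
        # New in this commit: validate the arity before indexing into it.
        assert len(padding) == 4, 'padding size expected to be 4'
        # Horizontal pads (left, right) must stay below the last dimension,
        # vertical pads (top, bottom) below the second-to-last.
        assert padding[0] < hdim and padding[1] < hdim
        assert padding[2] < vdim and padding[3] < vdim
        return pad_shape_fn(self, padding)

    assert reflection_pad2d_shape([2, 2, 2], [1, 1, 1, 1]) == [2, 4, 4]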
@@ -9056,18 +9066,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
-"    %none = torch.constant.none\n"
-"    %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n"
-"    %int4 = torch.constant.int 4\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-"    %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
-"    %2 = torch.aten.eq.int %1, %int4 : !torch.int, !torch.int -> !torch.bool\n"
-"    torch.prim.If %2 -> () {\n"
-"      torch.prim.If.yield\n"
-"    } else {\n"
-"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
-"      torch.prim.If.yield\n"
-"    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.contiguous\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
@@ -1290,6 +1290,8 @@ def aten〇reflection_pad1d〡shape(self: List[int], padding: List[int]) -> List
 # Padding size must be smaller than corresponding dimension
 @check_shape_function([ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,2,1,1]),
                        ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1,1,1]),
+                       ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1,1,3]),
+                       ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1]),
                        Invocation(TensorOfShape(2, 2, 2), padding=[1,1,1,1]),
                        ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,1]),
                        ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,2])])
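The two added ErrorInvocation cases target exactly the new checks: for a
(2, 2, 2) input, padding=[2,1,1,3] violates the per-dimension bound (left pad
2 and bottom pad 3 are not smaller than their size-2 dimensions), and
padding=[2,1] now trips the explicit length-4 assertion instead of indexing a
too-short list. Eager PyTorch agrees (a sanity check, not part of the patch;
the wrong-arity case surfaces there as a schema error):

    import torch

    x = torch.rand(2, 2, 2)  # hdim = vdim = 2

    print(torch.ops.aten.reflection_pad2d(x, [1, 1, 1, 1]).shape)  # [2, 4, 4]

    for bad in ([2, 1, 1, 3],   # left pad 2 >= 2, bottom pad 3 >= 2
                [2, 1]):        # not 4 values
        try:
            torch.ops.aten.reflection_pad2d(x, bad)
        except RuntimeError:
            print("rejected:", bad)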
@@ -1853,13 +1855,7 @@ def aten〇reflection_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: L
     assert len(padding) == 2, 'padding size expected to be 2'
     return self_dtype
 
-@check_dtype_function([ErrorInvocation(TensorOfShape(2, 3, 4), padding=1),
-                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[]),
-                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2]),
-                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2,1]),
-                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[3,2,1]),
-                       Invocation(TensorOfShape(5, 5, 4), padding=[1,2,3,4]),
-                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[5,4,3,2,1])])
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4, 2, 2)], padding=[1,1,1,1]))
 def aten〇reflection_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
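With the arity check moved into the shape function, the dtype function is
left stating the one fact it owns: reflection padding preserves the input
dtype, whatever the padding is. A minimal eager-mode check (not part of the
patch):

    import torch
    import torch.nn.functional as F

    x = torch.ones(4, 2, 2, dtype=torch.float64)
    y = F.pad(x, [1, 1, 1, 1], mode="reflect")
    assert y.dtype == x.dtype  # matches the dtype fn returning self_dtype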