emit aten.upsample_bilinear2d

byteir
Yuanqiang Liu 2024-09-14 16:09:21 +08:00
parent 7ecad699a3
commit 208e5fac64
4 changed files with 101 additions and 0 deletions

View File

@ -13754,6 +13754,59 @@ def Torch_AtenUpsampleNearest2dVecOp : Torch_Op<"aten.upsample_nearest2d.vec", [
}];
}
def Torch_AtenUpsampleBilinear2dOp : Torch_Op<"aten.upsample_bilinear2d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$output_size,
Torch_BoolType:$align_corners,
AnyTorchOptionalFloatType:$scales_h,
AnyTorchOptionalFloatType:$scales_w
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUpsampleBilinear2dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenUpsampleBilinear2dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}
def Torch_AtenUpsampleBilinear2dVecOp : Torch_Op<"aten.upsample_bilinear2d.vec", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalListOfTorchIntType:$output_size,
Torch_BoolType:$align_corners,
AnyTorchOptionalListOfTorchFloatType:$scale_factors
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUpsampleBilinear2dVecOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenUpsampleBilinear2dVecOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_attention", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -10917,6 +10917,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %10 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.upsample_bilinear2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.optional<float>, %arg4: !torch.optional<float>) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %2 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %3 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.upsample_bilinear2d.vec\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<list<float>>) -> !torch.list<int> {\n"
" %0 = call @\"__torch_mlir_shape_fn.aten.upsample_nearest2d.vec\"(%arg0, %arg1, %arg3) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<float>>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.prims.split_dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
@ -12338,6 +12352,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.upsample_bilinear2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.optional<float>, %arg4: !torch.optional<float>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.upsample_bilinear2d.vec\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<list<float>>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.view\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"

View File

@ -2280,6 +2280,20 @@ def atenupsample_nearest2dvec〡shape(input: List[int], output_size: Optio
assert scale_factors is not None
return [input[0], input[1], int(input[2] * scale_factors[0]), int(input[3] * scale_factors[1])]
@check_shape_function([
Invocation(TensorOfShape(1, 3, 10, 10), [11, 12], True)
])
def atenupsample_bilinear2d〡shape(self: List[int], output_size: List[int], align_corners: bool, scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]:
return [self[0], self[1], output_size[0], output_size[1]]
@check_shape_function([
Invocation(TensorOfShape(1, 3, 10, 10), [11, 12], True, None),
Invocation(TensorOfShape(1, 3, 10, 9), None, True, [2.0, 2.3]),
Invocation(TensorOfShape(1, 3, 5, 6), None, True, [2.5, 1.0])
])
def atenupsample_bilinear2dvec〡shape(input: List[int], output_size: Optional[List[int]], align_corners: bool, scale_factors: Optional[List[float]]) -> List[int]:
return atenupsample_nearest2dvec〡shape(input, output_size, scale_factors)
# ==============================================================================
# Dtype Functions
# ==============================================================================
@ -3462,6 +3476,16 @@ def atenupsample_nearest2dvec〡dtype(input_rank_dtype: Tuple[int, int], o
self_rank, self_dtype = input_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13], align_corners=True))
def atenupsample_bilinear2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], align_corners: bool, scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13], align_corners=True, scale_factors=None))
def atenupsample_bilinear2dvec〡dtype(input_rank_dtype: Tuple[int, int], output_size: Optional[List[int]], align_corners: bool, scale_factors: Optional[List[float]]) -> int:
self_rank, self_dtype = input_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]))
def atenview〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype

View File

@ -988,6 +988,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::upsample_nearest1d.vec : (Tensor, int[]?, float[]?) -> (Tensor)")
emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)")
emit("aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)")
emit("aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)")
emit("aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) -> (Tensor)")
emit(
"aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)"
)