[Torch] emit upsample_nearest1d/2d/vec, and add shape/dtype functions (#3629)

pull/3660/head
Yuanqiang Liu 2024-08-13 19:14:24 +08:00 committed by GitHub
parent a4ba02eef5
commit c5b3cf299a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 293 additions and 0 deletions

View File

@@ -13582,6 +13582,56 @@ def Torch_AtenAsStridedScatterOp : Torch_Op<"aten.as_strided_scatter", [
}];
}
// Generated ODS definition mirroring the PyTorch registry schema
//   aten::upsample_nearest1d : (Tensor, int[], float?) -> (Tensor)
// Declared with value semantics and marked ReadOnly (no in-place mutation).
def Torch_AtenUpsampleNearest1dOp : Torch_Op<"aten.upsample_nearest1d", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::upsample_nearest1d : (Tensor, int[], float?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchListOfTorchIntType:$output_size,
    AnyTorchOptionalFloatType:$scales
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    // Delegates to the default torch-op parser/printer with the op's
    // operand/result arity (3 operands, 1 result).
    ParseResult AtenUpsampleNearest1dOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenUpsampleNearest1dOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}
// Generated ODS definition mirroring the PyTorch registry schema
//   aten::upsample_nearest1d.vec : (Tensor, int[]?, float[]?) -> (Tensor)
// The "vec" overload takes an optional output size OR optional scale factors.
def Torch_AtenUpsampleNearest1dVecOp : Torch_Op<"aten.upsample_nearest1d.vec", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::upsample_nearest1d.vec : (Tensor, int[]?, float[]?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$input,
    AnyTorchOptionalListOfTorchIntType:$output_size,
    AnyTorchOptionalListOfTorchFloatType:$scale_factors
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    // Delegates to the default torch-op parser/printer with the op's
    // operand/result arity (3 operands, 1 result).
    ParseResult AtenUpsampleNearest1dVecOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenUpsampleNearest1dVecOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}
def Torch_AtenUpsampleNearest2dOp : Torch_Op<"aten.upsample_nearest2d", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -13608,6 +13658,31 @@ def Torch_AtenUpsampleNearest2dOp : Torch_Op<"aten.upsample_nearest2d", [
}];
}
// Generated ODS definition mirroring the PyTorch registry schema
//   aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)
// The "vec" overload takes an optional output size OR optional scale factors.
def Torch_AtenUpsampleNearest2dVecOp : Torch_Op<"aten.upsample_nearest2d.vec", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$input,
    AnyTorchOptionalListOfTorchIntType:$output_size,
    AnyTorchOptionalListOfTorchFloatType:$scale_factors
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    // Delegates to the default torch-op parser/printer with the op's
    // operand/result arity (3 operands, 1 result).
    ParseResult AtenUpsampleNearest2dVecOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenUpsampleNearest2dVecOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}
def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_attention", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@@ -10727,6 +10727,83 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg3, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<float>) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %2 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %3 = torch.prim.ListConstruct %0, %1, %2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %3 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest1d.vec\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<float>>) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %int2 = torch.constant.int 2\n"
" %0 = torch.prim.Uninitialized : !torch.list<float>\n"
" %1 = torch.prim.Uninitialized : !torch.optional<list<int>>\n"
" %2 = torch.prim.Uninitialized : !torch.optional<list<float>>\n"
" %3 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %11 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>\n"
" %12 = torch.aten.__is__ %arg2, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
" torch.prim.If.yield %12 : !torch.bool\n"
" }\n"
" %5:2 = torch.prim.If %4 -> (!torch.optional<list<int>>, !torch.optional<list<float>>) {\n"
" torch.prim.If.yield %arg1, %arg2 : !torch.optional<list<int>>, !torch.optional<list<float>>\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield %1, %2 : !torch.optional<list<int>>, !torch.optional<list<float>>\n"
" }\n"
" %6 = torch.aten.__is__ %5#0, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
" %7 = torch.prim.If %6 -> (!torch.bool) {\n"
" %11 = torch.aten.__is__ %5#1, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
" torch.prim.If.yield %11 : !torch.bool\n"
" } else {\n"
" %11 = torch.prim.unchecked_cast %5#0 : !torch.optional<list<int>> -> !torch.list<int>\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n"
" torch.prim.If %8 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %9 = torch.aten.__isnot__ %5#0, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
" %10 = torch.prim.If %9 -> (!torch.list<int>) {\n"
" %11 = torch.prim.unchecked_cast %5#0 : !torch.optional<list<int>> -> !torch.list<int>\n"
" %12 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %13 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %14 = torch.aten.__getitem__.t %11, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %15 = torch.prim.ListConstruct %12, %13, %14 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %15 : !torch.list<int>\n"
" } else {\n"
" %11 = torch.aten.__isnot__ %5#1, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
" %12 = torch.prim.If %11 -> (!torch.list<float>) {\n"
" %20 = torch.prim.unchecked_cast %5#1 : !torch.optional<list<float>> -> !torch.list<float>\n"
" torch.prim.If.yield %20 : !torch.list<float>\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield %0 : !torch.list<float>\n"
" }\n"
" %13 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n"
" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n"
" %19 = torch.prim.ListConstruct %13, %14, %18 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %19 : !torch.list<int>\n"
" }\n"
" return %10 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
@@ -10737,6 +10814,80 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d.vec\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<float>>) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %int2 = torch.constant.int 2\n"
" %int3 = torch.constant.int 3\n"
" %0 = torch.prim.Uninitialized : !torch.list<float>\n"
" %1 = torch.prim.Uninitialized : !torch.optional<list<int>>\n"
" %2 = torch.prim.Uninitialized : !torch.optional<list<float>>\n"
" %3 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %11 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>\n"
" %12 = torch.aten.__is__ %arg2, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
" torch.prim.If.yield %12 : !torch.bool\n"
" }\n"
" %5:2 = torch.prim.If %4 -> (!torch.optional<list<int>>, !torch.optional<list<float>>) {\n"
" torch.prim.If.yield %arg1, %arg2 : !torch.optional<list<int>>, !torch.optional<list<float>>\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield %1, %2 : !torch.optional<list<int>>, !torch.optional<list<float>>\n"
" }\n"
" %6 = torch.aten.__is__ %5#0, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
" %7 = torch.prim.If %6 -> (!torch.bool) {\n"
" %11 = torch.aten.__is__ %5#1, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
" torch.prim.If.yield %11 : !torch.bool\n"
" } else {\n"
" %11 = torch.prim.unchecked_cast %5#0 : !torch.optional<list<int>> -> !torch.list<int>\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n"
" torch.prim.If %8 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %9 = torch.aten.__isnot__ %5#0, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
" %10 = torch.prim.If %9 -> (!torch.list<int>) {\n"
" %11 = torch.prim.unchecked_cast %5#0 : !torch.optional<list<int>> -> !torch.list<int>\n"
" %12 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %13 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %14 = torch.aten.__getitem__.t %11, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %15 = torch.aten.__getitem__.t %11, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %16 = torch.prim.ListConstruct %12, %13, %14, %15 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %16 : !torch.list<int>\n"
" } else {\n"
" %11 = torch.aten.__isnot__ %5#1, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
" %12 = torch.prim.If %11 -> (!torch.list<float>) {\n"
" %24 = torch.prim.unchecked_cast %5#1 : !torch.optional<list<float>> -> !torch.list<float>\n"
" torch.prim.If.yield %24 : !torch.list<float>\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield %0 : !torch.list<float>\n"
" }\n"
" %13 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n"
" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n"
" %19 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %20 = torch.aten.__getitem__.t %12, %int1 : !torch.list<float>, !torch.int -> !torch.float\n"
" %21 = torch.operator \"aten.mul.int_float\"(%19, %20) : (!torch.int, !torch.float) -> !torch.float \n"
" %22 = torch.aten.Int.float %21 : !torch.float -> !torch.int\n"
" %23 = torch.prim.ListConstruct %13, %14, %18, %22 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %23 : !torch.list<int>\n"
" }\n"
" return %10 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.prims.split_dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
@@ -12117,10 +12268,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.optional<float>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest1d.vec\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<float>>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest2d.vec\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<float>>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.view\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"

View File

@@ -2225,9 +2225,46 @@ def atennormScalar〡shape(self: List[int], p: float = 2) -> List[int]:
def atennormScalarOpt_dim〡shape(self: List[int], p: Optional[float], dim: List[int], keepdim: bool = False) -> List[int]:
    """Shape function for `aten.norm.ScalarOpt_dim`.

    Delegates to the upstream `sum_mean_dim` shape helper: the result shape
    is `self` reduced over `dim` (dims kept as size-1 when `keepdim`).
    `p` does not influence the result shape.
    """
    return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)
@check_shape_function([
    Invocation(TensorOfShape(1, 3, 10), [11])
])
def atenupsample_nearest1d〡shape(self: List[int], output_size: List[int], scales: Optional[float] = None) -> List[int]:
    """Shape function for `aten.upsample_nearest1d`.

    Keeps the first two dims of `self` (indexed as [N, C, L]) and takes the
    new length from `output_size[0]`; `scales` does not affect the shape.
    """
    return [self[0], self[1], output_size[0]]
@check_shape_function([
    Invocation(TensorOfShape(1, 3, 10), [11], None),
    Invocation(TensorOfShape(1, 3, 10), None, [2.0]),
    Invocation(TensorOfShape(1, 3, 5), None, [2.5])
])
def atenupsample_nearest1dvec〡shape(input: List[int], output_size: Optional[List[int]], scale_factors: Optional[List[float]]) -> List[int]:
    """Shape function for `aten.upsample_nearest1d.vec`.

    Exactly one of `output_size` / `scale_factors` must be given. The result
    keeps the first two dims of `input` and derives the last dim either
    directly from `output_size[0]` or as `int(input[2] * scale_factors[0])`.
    """
    # Exactly one of the two sizing modes may be specified.
    assert output_size is None or scale_factors is None
    assert not (output_size is None and scale_factors is None)
    if output_size is not None:
        return [input[0], input[1], output_size[0]]
    else:
        assert scale_factors is not None
        # int() truncates the scaled length toward zero.
        return [input[0], input[1], int(input[2] * scale_factors[0])]
@check_shape_function([
    Invocation(TensorOfShape(1, 3, 10, 10), [11, 12])
])
def atenupsample_nearest2d〡shape(self: List[int], output_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]:
    """Shape function for `aten.upsample_nearest2d`.

    Keeps the first two dims of `self` (indexed as [N, C, H, W]) and takes
    the new spatial dims from `output_size`; the optional scale arguments
    do not affect the shape.
    """
    return [self[0], self[1], output_size[0], output_size[1]]
@check_shape_function([
    Invocation(TensorOfShape(1, 3, 10, 10), [11, 12], None),
    Invocation(TensorOfShape(1, 3, 10, 9), None, [2.0, 2.3]),
    Invocation(TensorOfShape(1, 3, 5, 6), None, [2.5, 1.0])
])
def atenupsample_nearest2dvec〡shape(input: List[int], output_size: Optional[List[int]], scale_factors: Optional[List[float]]) -> List[int]:
    """Shape function for `aten.upsample_nearest2d.vec`.

    Exactly one of `output_size` / `scale_factors` must be given. The result
    keeps the first two dims of `input` and derives the spatial dims either
    directly from `output_size` or as `int(dim * scale)` per dimension.
    """
    # Exactly one of the two sizing modes may be specified.
    assert output_size is None or scale_factors is None
    assert not (output_size is None and scale_factors is None)
    if output_size is not None:
        return [input[0], input[1], output_size[0], output_size[1]]
    else:
        assert scale_factors is not None
        # int() truncates each scaled spatial dim toward zero.
        return [input[0], input[1], int(input[2] * scale_factors[0]), int(input[3] * scale_factors[1])]
# ==============================================================================
# Dtype Functions
# ==============================================================================
@@ -3380,11 +3417,26 @@ def atenupsample_nearest2d_backward〡dtype(grad_output_rank_dtype: Tuple[int
grad_output_rank, grad_output_dtype = grad_output_rank_dtype
return grad_output_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5)], output_size=[11]))
def atenupsample_nearest1d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], scales: Optional[float] = None) -> int:
    """Result dtype of `aten.upsample_nearest1d` is the input tensor's dtype."""
    return self_rank_dtype[1]
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5)], output_size=[11], scale_factors=None))
def atenupsample_nearest1dvec〡dtype(input_rank_dtype: Tuple[int, int], output_size: Optional[List[int]], scale_factors: Optional[List[float]]) -> int:
    """Result dtype of `aten.upsample_nearest1d.vec` is the input tensor's dtype."""
    _, input_dtype = input_rank_dtype
    return input_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13]))
def atenupsample_nearest2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> int:
    """Result dtype of `aten.upsample_nearest2d` is the input tensor's dtype."""
    return self_rank_dtype[1]
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13], scale_factors=None))
def atenupsample_nearest2dvec〡dtype(input_rank_dtype: Tuple[int, int], output_size: Optional[List[int]], scale_factors: Optional[List[float]]) -> int:
    """Result dtype of `aten.upsample_nearest2d.vec` is the input tensor's dtype."""
    _, input_dtype = input_rank_dtype
    return input_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]))
def atenview〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype

View File

@@ -981,7 +981,10 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)")
emit("aten::upsample_nearest1d : (Tensor, int[], float?) -> (Tensor)")
emit("aten::upsample_nearest1d.vec : (Tensor, int[]?, float[]?) -> (Tensor)")
emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)")
emit("aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)")
emit(
"aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)"
)