mirror of https://github.com/llvm/torch-mlir
[Torch] emit upsample_nearest1d/2d/vec, and add shape/dtype functions (#3629)
parent
a4ba02eef5
commit
c5b3cf299a
|
@ -13582,6 +13582,56 @@ def Torch_AtenAsStridedScatterOp : Torch_Op<"aten.as_strided_scatter", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenUpsampleNearest1dOp : Torch_Op<"aten.upsample_nearest1d", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::upsample_nearest1d : (Tensor, int[], float?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$output_size,
|
||||
AnyTorchOptionalFloatType:$scales
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenUpsampleNearest1dOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenUpsampleNearest1dOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenUpsampleNearest1dVecOp : Torch_Op<"aten.upsample_nearest1d.vec", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::upsample_nearest1d.vec : (Tensor, int[]?, float[]?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$input,
|
||||
AnyTorchOptionalListOfTorchIntType:$output_size,
|
||||
AnyTorchOptionalListOfTorchFloatType:$scale_factors
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenUpsampleNearest1dVecOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenUpsampleNearest1dVecOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenUpsampleNearest2dOp : Torch_Op<"aten.upsample_nearest2d", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -13608,6 +13658,31 @@ def Torch_AtenUpsampleNearest2dOp : Torch_Op<"aten.upsample_nearest2d", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenUpsampleNearest2dVecOp : Torch_Op<"aten.upsample_nearest2d.vec", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$input,
|
||||
AnyTorchOptionalListOfTorchIntType:$output_size,
|
||||
AnyTorchOptionalListOfTorchFloatType:$scale_factors
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenUpsampleNearest2dVecOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenUpsampleNearest2dVecOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_attention", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -10727,6 +10727,83 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg3, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||
" return %2 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<float>) -> !torch.list<int> {\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %2 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %3 = torch.prim.ListConstruct %0, %1, %2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" return %3 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest1d.vec\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<float>>) -> !torch.list<int> {\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
" %0 = torch.prim.Uninitialized : !torch.list<float>\n"
|
||||
" %1 = torch.prim.Uninitialized : !torch.optional<list<int>>\n"
|
||||
" %2 = torch.prim.Uninitialized : !torch.optional<list<float>>\n"
|
||||
" %3 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %11 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>\n"
|
||||
" %12 = torch.aten.__is__ %arg2, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
|
||||
" torch.prim.If.yield %12 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %5:2 = torch.prim.If %4 -> (!torch.optional<list<int>>, !torch.optional<list<float>>) {\n"
|
||||
" torch.prim.If.yield %arg1, %arg2 : !torch.optional<list<int>>, !torch.optional<list<float>>\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield %1, %2 : !torch.optional<list<int>>, !torch.optional<list<float>>\n"
|
||||
" }\n"
|
||||
" %6 = torch.aten.__is__ %5#0, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
|
||||
" %7 = torch.prim.If %6 -> (!torch.bool) {\n"
|
||||
" %11 = torch.aten.__is__ %5#1, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
|
||||
" torch.prim.If.yield %11 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %11 = torch.prim.unchecked_cast %5#0 : !torch.optional<list<int>> -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %8 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %9 = torch.aten.__isnot__ %5#0, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
|
||||
" %10 = torch.prim.If %9 -> (!torch.list<int>) {\n"
|
||||
" %11 = torch.prim.unchecked_cast %5#0 : !torch.optional<list<int>> -> !torch.list<int>\n"
|
||||
" %12 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %13 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %14 = torch.aten.__getitem__.t %11, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %15 = torch.prim.ListConstruct %12, %13, %14 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %15 : !torch.list<int>\n"
|
||||
" } else {\n"
|
||||
" %11 = torch.aten.__isnot__ %5#1, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
|
||||
" %12 = torch.prim.If %11 -> (!torch.list<float>) {\n"
|
||||
" %20 = torch.prim.unchecked_cast %5#1 : !torch.optional<list<float>> -> !torch.list<float>\n"
|
||||
" torch.prim.If.yield %20 : !torch.list<float>\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield %0 : !torch.list<float>\n"
|
||||
" }\n"
|
||||
" %13 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
|
||||
" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n"
|
||||
" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n"
|
||||
" %19 = torch.prim.ListConstruct %13, %14, %18 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %19 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" return %10 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
|
@ -10737,6 +10814,80 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" return %4 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d.vec\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<float>>) -> !torch.list<int> {\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
" %int3 = torch.constant.int 3\n"
|
||||
" %0 = torch.prim.Uninitialized : !torch.list<float>\n"
|
||||
" %1 = torch.prim.Uninitialized : !torch.optional<list<int>>\n"
|
||||
" %2 = torch.prim.Uninitialized : !torch.optional<list<float>>\n"
|
||||
" %3 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %11 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>\n"
|
||||
" %12 = torch.aten.__is__ %arg2, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
|
||||
" torch.prim.If.yield %12 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %5:2 = torch.prim.If %4 -> (!torch.optional<list<int>>, !torch.optional<list<float>>) {\n"
|
||||
" torch.prim.If.yield %arg1, %arg2 : !torch.optional<list<int>>, !torch.optional<list<float>>\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield %1, %2 : !torch.optional<list<int>>, !torch.optional<list<float>>\n"
|
||||
" }\n"
|
||||
" %6 = torch.aten.__is__ %5#0, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
|
||||
" %7 = torch.prim.If %6 -> (!torch.bool) {\n"
|
||||
" %11 = torch.aten.__is__ %5#1, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
|
||||
" torch.prim.If.yield %11 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %11 = torch.prim.unchecked_cast %5#0 : !torch.optional<list<int>> -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %8 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %9 = torch.aten.__isnot__ %5#0, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
|
||||
" %10 = torch.prim.If %9 -> (!torch.list<int>) {\n"
|
||||
" %11 = torch.prim.unchecked_cast %5#0 : !torch.optional<list<int>> -> !torch.list<int>\n"
|
||||
" %12 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %13 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %14 = torch.aten.__getitem__.t %11, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %15 = torch.aten.__getitem__.t %11, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %16 = torch.prim.ListConstruct %12, %13, %14, %15 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %16 : !torch.list<int>\n"
|
||||
" } else {\n"
|
||||
" %11 = torch.aten.__isnot__ %5#1, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
|
||||
" %12 = torch.prim.If %11 -> (!torch.list<float>) {\n"
|
||||
" %24 = torch.prim.unchecked_cast %5#1 : !torch.optional<list<float>> -> !torch.list<float>\n"
|
||||
" torch.prim.If.yield %24 : !torch.list<float>\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield %0 : !torch.list<float>\n"
|
||||
" }\n"
|
||||
" %13 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
|
||||
" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n"
|
||||
" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n"
|
||||
" %19 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %20 = torch.aten.__getitem__.t %12, %int1 : !torch.list<float>, !torch.int -> !torch.float\n"
|
||||
" %21 = torch.operator \"aten.mul.int_float\"(%19, %20) : (!torch.int, !torch.float) -> !torch.float \n"
|
||||
" %22 = torch.aten.Int.float %21 : !torch.float -> !torch.int\n"
|
||||
" %23 = torch.prim.ListConstruct %13, %14, %18, %22 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %23 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" return %10 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.prims.split_dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
@ -12117,10 +12268,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.optional<float>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest1d.vec\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<float>>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest2d.vec\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<float>>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.view\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
|
|
@ -2225,9 +2225,46 @@ def aten〇norm〇Scalar〡shape(self: List[int], p: float = 2) -> List[int]:
|
|||
def aten〇norm〇ScalarOpt_dim〡shape(self: List[int], p: Optional[float], dim: List[int], keepdim: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(1, 3, 10), [11])
|
||||
])
|
||||
def aten〇upsample_nearest1d〡shape(self: List[int], output_size: List[int], scales: Optional[float] = None) -> List[int]:
|
||||
return [self[0], self[1], output_size[0]]
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(1, 3, 10), [11], None),
|
||||
Invocation(TensorOfShape(1, 3, 10), None, [2.0]),
|
||||
Invocation(TensorOfShape(1, 3, 5), None, [2.5])
|
||||
])
|
||||
def aten〇upsample_nearest1d〇vec〡shape(input: List[int], output_size: Optional[List[int]], scale_factors: Optional[List[float]]) -> List[int]:
|
||||
assert output_size is None or scale_factors is None
|
||||
assert not (output_size is None and scale_factors is None)
|
||||
if output_size is not None:
|
||||
return [input[0], input[1], output_size[0]]
|
||||
else:
|
||||
assert scale_factors is not None
|
||||
return [input[0], input[1], int(input[2] * scale_factors[0])]
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(1, 3, 10, 10), [11, 12])
|
||||
])
|
||||
def aten〇upsample_nearest2d〡shape(self: List[int], output_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]:
|
||||
return [self[0], self[1], output_size[0], output_size[1]]
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(1, 3, 10, 10), [11, 12], None),
|
||||
Invocation(TensorOfShape(1, 3, 10, 9), None, [2.0, 2.3]),
|
||||
Invocation(TensorOfShape(1, 3, 5, 6), None, [2.5, 1.0])
|
||||
])
|
||||
def aten〇upsample_nearest2d〇vec〡shape(input: List[int], output_size: Optional[List[int]], scale_factors: Optional[List[float]]) -> List[int]:
|
||||
assert output_size is None or scale_factors is None
|
||||
assert not (output_size is None and scale_factors is None)
|
||||
if output_size is not None:
|
||||
return [input[0], input[1], output_size[0], output_size[1]]
|
||||
else:
|
||||
assert scale_factors is not None
|
||||
return [input[0], input[1], int(input[2] * scale_factors[0]), int(input[3] * scale_factors[1])]
|
||||
|
||||
# ==============================================================================
|
||||
# Dtype Functions
|
||||
# ==============================================================================
|
||||
|
@ -3380,11 +3417,26 @@ def aten〇upsample_nearest2d_backward〡dtype(grad_output_rank_dtype: Tuple[int
|
|||
grad_output_rank, grad_output_dtype = grad_output_rank_dtype
|
||||
return grad_output_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5)], output_size=[11]))
|
||||
def aten〇upsample_nearest1d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], scales: Optional[float] = None) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5)], output_size=[11], scale_factors=None))
|
||||
def aten〇upsample_nearest1d〇vec〡dtype(input_rank_dtype: Tuple[int, int], output_size: Optional[List[int]], scale_factors: Optional[List[float]]) -> int:
|
||||
self_rank, self_dtype = input_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13]))
|
||||
def aten〇upsample_nearest2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13], scale_factors=None))
|
||||
def aten〇upsample_nearest2d〇vec〡dtype(input_rank_dtype: Tuple[int, int], output_size: Optional[List[int]], scale_factors: Optional[List[float]]) -> int:
|
||||
self_rank, self_dtype = input_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]))
|
||||
def aten〇view〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
|
|
@ -981,7 +981,10 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
|
||||
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
|
||||
emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)")
|
||||
emit("aten::upsample_nearest1d : (Tensor, int[], float?) -> (Tensor)")
|
||||
emit("aten::upsample_nearest1d.vec : (Tensor, int[]?, float[]?) -> (Tensor)")
|
||||
emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)")
|
||||
emit("aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)")
|
||||
emit(
|
||||
"aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)"
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue