[Torch] emit upsample_bilinear2d(.vec) ops (#3834)

pull/3837/head
Yuanqiang Liu 2024-10-30 20:18:24 +08:00 committed by GitHub
parent 2b01f8b7f3
commit 9ab2a150f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 109 additions and 0 deletions

View File

@ -14093,6 +14093,59 @@ def Torch_AtenUpsampleNearest2dVecOp : Torch_Op<"aten.upsample_nearest2d.vec", [
}];
}
def Torch_AtenUpsampleBilinear2dOp : Torch_Op<"aten.upsample_bilinear2d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$output_size,
Torch_BoolType:$align_corners,
AnyTorchOptionalFloatType:$scales_h,
AnyTorchOptionalFloatType:$scales_w
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUpsampleBilinear2dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenUpsampleBilinear2dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}
def Torch_AtenUpsampleBilinear2dVecOp : Torch_Op<"aten.upsample_bilinear2d.vec", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalListOfTorchIntType:$output_size,
Torch_BoolType:$align_corners,
AnyTorchOptionalListOfTorchFloatType:$scale_factors
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUpsampleBilinear2dVecOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenUpsampleBilinear2dVecOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_attention", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -11043,6 +11043,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %10 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.upsample_bilinear2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.optional<float>, %arg4: !torch.optional<float>) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %2 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %3 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.upsample_bilinear2d.vec\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<list<float>>) -> !torch.list<int> {\n"
" %0 = call @\"__torch_mlir_shape_fn.aten.upsample_nearest2d.vec\"(%arg0, %arg1, %arg3) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<float>>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.prims.split_dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
@ -12576,6 +12590,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.upsample_bilinear2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.optional<float>, %arg4: !torch.optional<float>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.upsample_bilinear2d.vec\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<list<float>>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.view\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"

View File

@ -531,6 +531,9 @@ FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
"_SoftmaxModule_basic",
"UpSampleNearest2dDynamicFactor_basic",
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
# torch export: RuntimeError: cannot mutate tensors with frozen storage
"ElementwiseRreluWithNoiseTrainModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
}
FX_IMPORTER_STABLEHLO_XFAIL_SET = {
@ -979,6 +982,9 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
# materialization callback produced value of incorrect type failed
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
# torch export: RuntimeError: cannot mutate tensors with frozen storage
"ElementwiseRreluWithNoiseTrainModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
}
STABLEHLO_PASS_SET = {

View File

@ -2342,6 +2342,20 @@ def atenupsample_nearest2dvec〡shape(input: List[int], output_size: Optio
assert scale_factors is not None
return [input[0], input[1], int(input[2] * scale_factors[0]), int(input[3] * scale_factors[1])]
@check_shape_function([
Invocation(TensorOfShape(1, 3, 10, 10), [11, 12], True)
])
def atenupsample_bilinear2d〡shape(self: List[int], output_size: List[int], align_corners: bool, scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]:
return [self[0], self[1], output_size[0], output_size[1]]
@check_shape_function([
Invocation(TensorOfShape(1, 3, 10, 10), [11, 12], True, None),
Invocation(TensorOfShape(1, 3, 10, 9), None, True, [2.0, 2.3]),
Invocation(TensorOfShape(1, 3, 5, 6), None, True, [2.5, 1.0])
])
def atenupsample_bilinear2dvec〡shape(input: List[int], output_size: Optional[List[int]], align_corners: bool, scale_factors: Optional[List[float]]) -> List[int]:
return atenupsample_nearest2dvec〡shape(input, output_size, scale_factors)
# ==============================================================================
# Dtype Functions
# ==============================================================================
@ -3570,6 +3584,16 @@ def atenupsample_nearest2dvec〡dtype(input_rank_dtype: Tuple[int, int], o
self_rank, self_dtype = input_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13], align_corners=True))
def atenupsample_bilinear2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], align_corners: bool, scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13], align_corners=True, scale_factors=None))
def atenupsample_bilinear2dvec〡dtype(input_rank_dtype: Tuple[int, int], output_size: Optional[List[int]], align_corners: bool, scale_factors: Optional[List[float]]) -> int:
self_rank, self_dtype = input_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]))
def atenview〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype

View File

@ -1013,6 +1013,10 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::upsample_nearest1d.vec : (Tensor, int[]?, float[]?) -> (Tensor)")
emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)")
emit("aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)")
emit(
"aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)"
)
emit("aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) -> (Tensor)")
emit(
"aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?, bool) -> (Tensor)"
)