From 9ab2a150f20abbddcb291b9437d5b2b3506c9ace Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Wed, 30 Oct 2024 20:18:24 +0800 Subject: [PATCH] [Torch] emit upsample_bilinear2d(.vec) ops (#3834) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 53 +++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 22 ++++++++ projects/pt1/e2e_testing/xfail_sets.py | 6 +++ .../build_tools/abstract_interp_lib_gen.py | 24 +++++++++ .../build_tools/torch_ods_gen.py | 4 ++ 5 files changed, 109 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 206d70ffb..5ec6a4d1d 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -14093,6 +14093,59 @@ def Torch_AtenUpsampleNearest2dVecOp : Torch_Op<"aten.upsample_nearest2d.vec", [ }]; } +def Torch_AtenUpsampleBilinear2dOp : Torch_Op<"aten.upsample_bilinear2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size, + Torch_BoolType:$align_corners, + AnyTorchOptionalFloatType:$scales_h, + AnyTorchOptionalFloatType:$scales_w + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUpsampleBilinear2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenUpsampleBilinear2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + +def Torch_AtenUpsampleBilinear2dVecOp : Torch_Op<"aten.upsample_bilinear2d.vec", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchOptionalListOfTorchIntType:$output_size, + Torch_BoolType:$align_corners, + AnyTorchOptionalListOfTorchFloatType:$scale_factors + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUpsampleBilinear2dVecOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenUpsampleBilinear2dVecOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_attention", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 46cb3e6b7..1765786be 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -11043,6 +11043,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %10 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.upsample_bilinear2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.upsample_bilinear2d.vec\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional>) -> !torch.list {\n" +" %0 = call @\"__torch_mlir_shape_fn.aten.upsample_nearest2d.vec\"(%arg0, %arg1, %arg3) : (!torch.list, !torch.optional>, !torch.optional>) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.prims.split_dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -12576,6 +12590,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_bilinear2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_bilinear2d.vec\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.view\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 82ca24443..5686664d3 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -531,6 +531,9 @@ FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | { "_SoftmaxModule_basic", "UpSampleNearest2dDynamicFactor_basic", "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", + # torch export: RuntimeError: cannot mutate tensors with frozen storage + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", } FX_IMPORTER_STABLEHLO_XFAIL_SET = { @@ -979,6 +982,9 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = { # materialization callback produced value of incorrect type failed "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", + # torch export: RuntimeError: cannot mutate tensors with frozen storage + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", } STABLEHLO_PASS_SET = { diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 1cb9678ec..d9e57d674 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -2342,6 +2342,20 @@ def aten〇upsample_nearest2d〇vec〡shape(input: List[int], output_size: Optio assert scale_factors is not None return [input[0], input[1], int(input[2] * scale_factors[0]), int(input[3] * scale_factors[1])] +@check_shape_function([ + Invocation(TensorOfShape(1, 3, 10, 10), [11, 12], True) +]) +def aten〇upsample_bilinear2d〡shape(self: List[int], output_size: List[int], align_corners: bool, scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]: + return [self[0], self[1], output_size[0], output_size[1]] + +@check_shape_function([ + Invocation(TensorOfShape(1, 3, 10, 10), [11, 12], True, None), + Invocation(TensorOfShape(1, 3, 10, 9), None, True, [2.0, 2.3]), + Invocation(TensorOfShape(1, 3, 5, 6), None, True, [2.5, 1.0]) +]) +def aten〇upsample_bilinear2d〇vec〡shape(input: List[int], output_size: Optional[List[int]], align_corners: bool, scale_factors: Optional[List[float]]) -> List[int]: + return aten〇upsample_nearest2d〇vec〡shape(input, output_size, scale_factors) + # ============================================================================== # Dtype Functions # ============================================================================== @@ -3570,6 +3584,16 @@ def aten〇upsample_nearest2d〇vec〡dtype(input_rank_dtype: Tuple[int, int], o self_rank, self_dtype = input_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13], align_corners=True)) +def aten〇upsample_bilinear2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], align_corners: bool, scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13], align_corners=True, scale_factors=None)) +def aten〇upsample_bilinear2d〇vec〡dtype(input_rank_dtype: Tuple[int, int], output_size: Optional[List[int]], align_corners: bool, scale_factors: Optional[List[float]]) -> int: + self_rank, self_dtype = input_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1])) def aten〇view〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 17f7faa10..311636c82 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -1013,6 +1013,10 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::upsample_nearest1d.vec : (Tensor, int[]?, float[]?) -> (Tensor)") emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)") emit("aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)") + emit( + "aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)" + ) + emit("aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) -> (Tensor)") emit( "aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?, bool) -> (Tensor)" )