From 4339c00f1b818f3cc6fd52fd5610ea973317a7b0 Mon Sep 17 00:00:00 2001 From: Jiawei Wu Date: Sun, 27 Aug 2023 21:56:36 +0800 Subject: [PATCH] [Torch Dialect][stablehlo] emit aten.rand op and add converter to stablehlo (#2413) * [Torch Dialect] emit aten.rand op and add converter to stablehlo * add failed tests for torchdynamo backend * add failed test for linalg backend --- e2e_testing/xfail_sets.py | 5 ++- .../Dialect/Torch/IR/GeneratedTorchOps.td | 27 ++++++++++++++++ lib/Conversion/TorchToStablehlo/Basic.cpp | 31 +++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 15 +++++++++ .../build_tools/abstract_interp_lib_gen.py | 9 ++++++ .../jit_ir/build_tools/torch_ods_gen.py | 1 + python/torch_mlir_e2e_test/test_suite/rng.py | 22 +++++++++++++ 7 files changed, 109 insertions(+), 1 deletion(-) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 8224a6632..7eca06c02 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -16,7 +16,8 @@ from torch_mlir._version import torch_version_for_comparison, version LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | { # Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 - "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier" + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + "RandModule_basic" } TORCHDYNAMO_XFAIL_SET = { @@ -62,6 +63,7 @@ TORCHDYNAMO_XFAIL_SET = { # error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal "ElementwiseFlattenBroadcastModule_basic", "FlattenRank0Module_basic", + "RandModule_basic", "UniformModule_basic", "UniformStaticShapeModule_basic", # error: unsupported by backend contract: tensor with unknown rank @@ -868,6 +870,7 @@ STABLEHLO_PASS_SET = { "RandIntLowModule_basic", "RandIntModule_basic", "RandIntPinMemoryModule_basic", + "RandModule_basic", "UniformStaticShapeModule_basic", "UniformNoCorrelationModule_basic", "TupleModule_basic", diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index b124cef10..9f122b085 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -4154,6 +4154,33 @@ def Torch_AtenRandLikeOp : Torch_Op<"aten.rand_like", [ }]; } +def Torch_AtenRandOp : Torch_Op<"aten.rand", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::rand : (int[], int?, int?, Device?, bool?) -> (Tensor)`"; + let arguments = (ins + AnyTorchListOfTorchIntType:$size, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRandOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenRandOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + def Torch_AtenBernoulliOp : Torch_Op<"aten.bernoulli", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 269bda449..be26ea478 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1553,6 +1553,36 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenRandOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); + SmallVector size; + if (!matchPattern(adaptor.getSize(), m_TorchListOfConstantInts(size))) { + return rewriter.notifyMatchFailure(op, + "only constant integer size supported"); + } + auto shapeTensor = rewriter.create( + loc, rewriter.getI64TensorAttr(size)); + auto outTy = getTypeConverter()->convertType(op.getType()); + auto outElemTy = outTy.cast().getElementType(); + + if (!outElemTy.isa()) { + return rewriter.notifyMatchFailure(op, "only float type supported"); + } + + Value from = rewriter.create( + loc, rewriter.getFloatAttr(outElemTy, 0.0)); + from = hlo::scalarToStablehloTensor(rewriter, op, from, outElemTy); + Value to = rewriter.create( + loc, rewriter.getFloatAttr(outElemTy, 1.0)); + to = hlo::scalarToStablehloTensor(rewriter, op, to, outElemTy); + rewriter.replaceOpWithNewOp( + op, outTy, from, to, shapeTensor, stablehlo::RngDistribution::UNIFORM); + return success(); +} + // Converts `aten.empty.memory_format` to `tensor.empty` op. template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -1844,6 +1874,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenWhereSelfOp); INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp); INSERT_ATENOP_PATTERN(AtenUniformOp); + INSERT_ATENOP_PATTERN(AtenRandOp); INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp); INSERT_ATENOP_PATTERN(AtenFillScalarOp); INSERT_ATENOP_PATTERN(AtenFlipOp); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 416b74854..11302e86a 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7251,6 +7251,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.uniform\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.rand\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.bernoulli.float\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.any) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" @@ -8840,6 +8843,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rand\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._unsafe_view\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 8674f1621..c13711789 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -671,6 +671,9 @@ def aten〇copy〡shape(self: List[int], src: List[int], non_blocking: bool = Fa def aten〇uniform〡shape(self: List[int], from_: float = 0., to: float = 1., generator: Any = None) -> List[int]: return self +def aten〇rand〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: + return size + @not_present_in_registry def aten〇bernoulli〇float〡shape(self: List[int], p: float = 0.5, generator: Any = None) -> List[int]: return self @@ -1965,6 +1968,12 @@ def aten〇uniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0., self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function([Invocation([1]), + Invocation([1], dtype=torch.float16), + Invocation([1], dtype=torch.complex64)]) +def aten〇rand〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + return torch.float32 if dtype is None else dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1])) def aten〇_unsafe_view〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 08671dd90..877057ece 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -347,6 +347,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): # Random number generation emit_with_mutating_variants("aten::uniform : (Tensor, float, float, Generator?) -> (Tensor)") emit("aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") + emit("aten::rand : (int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)") emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)") emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)") diff --git a/python/torch_mlir_e2e_test/test_suite/rng.py b/python/torch_mlir_e2e_test/test_suite/rng.py index 2575090db..1baa46246 100644 --- a/python/torch_mlir_e2e_test/test_suite/rng.py +++ b/python/torch_mlir_e2e_test/test_suite/rng.py @@ -6,6 +6,28 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== +class RandModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1024, 512], torch.float, True) + ]) + def forward(self, x): + size = x.size() + a = torch.rand(size) + return torch.std(a), torch.mean(a) + + +@register_test_case(module_factory=lambda: RandModule()) +def RandModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1024, 512)) + +# ============================================================================== + class UniformModule(torch.nn.Module): def __init__(self):