From 96f90efd1606cc6695d5fb514ee23817bb733174 Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos
Date: Thu, 12 May 2022 16:00:59 -0700
Subject: [PATCH] Add shape info to `rand_like` + support for `dtype` flag
 (#851)

The op `aten.rand_like` was missing a shape function and unit tests, and
the `dtype` argument was being ignored in its decomposition. This commit
fixes all three things.
---
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 36 +++-------
 lib/Dialect/Torch/Transforms/ShapeLibrary.cpp |  3 +
 .../jit_ir/build_tools/shape_lib_gen.py       |  3 +
 python/torch_mlir_e2e_test/test_suite/rng.py  | 42 ++++++++++++
 test/Dialect/Torch/decompose-complex-ops.mlir | 68 +++++++++++++++++--
 5 files changed, 121 insertions(+), 31 deletions(-)
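As a quick illustration of the behavior this patch enables (a minimal
eager-mode sketch, not part of the original commit; the tensor `x` and its
shape are chosen arbitrarily), the new `RandLikeDtypeModule` e2e test below
exercises exactly this call:

    import torch

    x = torch.rand(4, 4, dtype=torch.float64)
    # aten.rand_like samples uniformly from [0, 1); when `dtype` is given,
    # the result takes that element type instead of inheriting float64 from x.
    a = torch.ops.aten.rand_like(x, dtype=torch.float32)
    assert a.shape == x.shape and a.dtype == torch.float32

When compiled through torch-mlir, the updated decomposition below produces an
empty tensor of the requested dtype and fills it with `valsem.aten.uniform`,
which is what the new FileCheck tests verify.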
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index a38bcc73c..6418c4e9a 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1116,42 +1116,26 @@ public:
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value input = op.self();
+    Type resultType = op.getType();
     auto inputType = input.getType().cast<BaseTensorType>();
-    if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
+    if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
       return rewriter.notifyMatchFailure(op,
                                          "only support floating-point type");
-
-    // TODO: Add support for layout, pin_memory and memory_format features.
-    // Only `none` layout is supported.
-    if (!op.layout().getType().isa<Torch::NoneType>())
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: only default layout is supported");
-
-    // The pin_memory should be either `none` or constant `False`.
-    if (!op.pin_memory().getType().isa<Torch::NoneType>()) {
-      bool pinMemory;
-      if (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)))
-        return rewriter.notifyMatchFailure(
-            op, "unimplemented: pin_memory must be a constant");
-      else if (pinMemory)
-        return rewriter.notifyMatchFailure(
-            op, "unimplemented: pin_memory is expected to be false");
     }
 
-    // Only `none` memory_format is supported.
-    if (!op.memory_format().getType().isa<Torch::NoneType>())
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: only default memory format is supported");
-
-    // Create a uniform random op with low and high set to 0.0 and 1.0
+    // Create a uniform random op with low and high set to 0.0 and 1.0,
     // respectively.
     Value none = rewriter.create<ConstantNoneOp>(loc);
-    Value lb =
+    Value zero =
         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
-    Value ub =
+    Value one =
         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+    Value emptyTensor = rewriter.create<AtenEmptyLikeOp>(
+        loc, resultType, input, op.dtype(), op.layout(), op.device(),
+        op.pin_memory(), op.memory_format());
     rewriter.replaceOpWithNewOp<ValsemVariantAtenUniformOp>(
-        op, op.getType(), input, lb, ub, /*generator=*/none);
+        op, resultType, emptyTensor, /*from=*/zero, /*to=*/one,
+        /*generator=*/none);
     return success();
   }
 };
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index f75d7b2b8..f1ecc727c 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -2018,6 +2018,9 @@ module {
   func @"__torch_mlir_shape_fn.aten.bernoulli"(%arg0: !torch.list<int>, %arg1: !torch.any) -> !torch.list<int> {
     return %arg0 : !torch.list<int>
   }
+  func @"__torch_mlir_shape_fn.aten.rand_like"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {
+    return %arg0 : !torch.list<int>
+  }
   func @"__torch_mlir_shape_fn.aten.arange.start_step"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {
     %0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
     %1 = torch.derefine %arg4 : !torch.optional<int> to !torch.any
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index ec1a3a853..479535fe2 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -665,6 +665,9 @@ def aten〇index_put_impl(self: List[int], indices: List[Optional[List[int]]], v
 def aten〇bernoulli(self: List[int], generator: Any = None) -> List[int]:
     return self
 
+def aten〇rand_like(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
+    return self
+
 def aten〇arange〇start_step(start: float, end: float, step: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
     return upstream_shape_helpers.arange_start_step(start, end, step, dtype, layout, device, pin_memory)
diff --git a/python/torch_mlir_e2e_test/test_suite/rng.py b/python/torch_mlir_e2e_test/test_suite/rng.py
index edaf43ce1..ea13eadaf 100644
--- a/python/torch_mlir_e2e_test/test_suite/rng.py
+++ b/python/torch_mlir_e2e_test/test_suite/rng.py
@@ -234,3 +234,45 @@ def BernoulliTensorModule_basic(module, tu: TestUtils):
         tu.rand(1024, 2048, 8).double(),
         tu.rand(1024, 512, 8).double(),
         tu.rand(1024, 512, 8).double())
+
+# ==============================================================================
+
+class RandLikeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float64, True),
+    ])
+    def forward(self, x):
+        a = torch.ops.aten.rand_like(x)
+        mean = torch.mean(a)
+        return mean
+
+
+@register_test_case(module_factory=lambda: RandLikeModule())
+def RandLikeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1024, 1024).double())
+
+# ==============================================================================
+
+class RandLikeDtypeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float64, True),
+    ])
+    def forward(self, x):
+        a = torch.ops.aten.rand_like(x, dtype=torch.float32)
+        mean = torch.mean(a)
+        return mean
+
+
+@register_test_case(module_factory=lambda: RandLikeDtypeModule())
+def RandLikeDtypeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1024, 1024).double())
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 2aba19401..c4a494ec8 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -406,7 +416,16 @@ func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -
 // CHECK: %[[NONE_2:.*]] = torch.constant.none
 // CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
 // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
-// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
+// CHECK: %[[INT0:.*]] = torch.constant.int 0
+// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT0]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT1]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
+// CHECK: %[[INT2:.*]] = torch.constant.int 2
+// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT2]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
+// CHECK: %[[TENSOR_SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[TENSOR_SIZE]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f64>
+// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[EMPTY]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
+
 // CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[INP]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1>
 // CHECK: %[[INT7_2:.*]] = torch.constant.int 7
 // CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
@@ -437,7 +446,15 @@ func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
 // CHECK: %[[NONE_2:.*]] = torch.constant.none
 // CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
 // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
-// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
+// CHECK: %[[INT0:.*]] = torch.constant.int 0
+// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT0]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT1]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
+// CHECK: %[[INT2:.*]] = torch.constant.int 2
+// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT2]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
+// CHECK: %[[TENSOR_SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[TENSOR_SIZE]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f64>
+// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[EMPTY]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
 // CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB_TENSOR]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?,?],i1>
 // CHECK: %[[INT7_2:.*]] = torch.constant.int 7
 // CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
@@ -468,7 +485,15 @@ func @torch.valsem.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !
 // CHECK: %[[NONE_2:.*]] = torch.constant.none
 // CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
 // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
-// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
+// CHECK: %[[INT0:.*]] = torch.constant.int 0
+// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT0]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT1]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
+// CHECK: %[[INT2:.*]] = torch.constant.int 2
+// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT2]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
+// CHECK: %[[TENSOR_SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[TENSOR_SIZE]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f64>
+// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[EMPTY]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
 // CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1>
 // CHECK: %[[INT7_2:.*]] = torch.constant.int 7
 // CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
@@ -484,6 +509,33 @@ func @torch.valsem.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %ar
   return %1 : !torch.vtensor
 }
 
+// -----
+// CHECK-LABEL: func @torch.aten.rand_like(
+// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
+// CHECK: %[[INT6:.*]] = torch.constant.int 6
+// CHECK: %[[NONE_0:.*]] = torch.constant.none
+// CHECK: %[[NONE_1:.*]] = torch.constant.none
+// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
+// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
+// CHECK: %[[INT0:.*]] = torch.constant.int 0
+// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[INT0]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[INT1]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
+// CHECK: %[[INT2:.*]] = torch.constant.int 2
+// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[INPUT]], %[[INT2]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
+// CHECK: %[[TENSOR_SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[TENSOR_SIZE]], %[[INT6]], %[[NONE_0]], %[[NONE_0]], %[[NONE_0]], %[[NONE_0]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f32>
+// CHECK: %[[UNIFORM:.*]] = torch.valsem.aten.uniform %[[EMPTY]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_1]] : !torch.vtensor<[?,?,?],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f32>
+// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[UNIFORM]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor
+// CHECK: return %[[CAST]] : !torch.vtensor
+func @torch.aten.rand_like(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
+  %int6 = torch.constant.int 6
+  %none = torch.constant.none
+  %0 = torch.aten.rand_like %arg0, %int6, %none, %none, %none, %none : !torch.vtensor<[?,?,?],f64>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f32>
+  %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f32> to !torch.vtensor
+  return %1 : !torch.vtensor
+}
+
 // -----
 // CHECK-LABEL: func @torch.aten.select.int(
 // CHECK-SAME: %[[T:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?],si64> {
@@ -737,7 +789,7 @@ func @torch.aten.dropout$eval(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtenso
 
 // -----
 // CHECK-LABEL: func @torch.aten.dropout$train(
-// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[PROB:.*]] = torch.constant.float 3.000000e-01
 // CHECK: %[[TRAIN:.*]] = torch.constant.bool true
 // CHECK: %[[NONE:.*]] = torch.constant.none
@@ -753,7 +805,13 @@ func @torch.aten.dropout$eval(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtenso
 // CHECK: %[[NONE_2:.*]] = torch.constant.none
 // CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
 // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
-// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?],f64>
+// CHECK: %[[INT0:.*]] = torch.constant.int 0
+// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
+// CHECK: %[[TENSOR_SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[TENSOR_SIZE]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f64>
+// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[EMPTY]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?],f64>
 // CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB_TENSOR]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?],i1>
 // CHECK: %[[INT6:.*]] = torch.constant.int 6
 // CHECK: %[[FALSE_2:.*]] = torch.constant.bool false