mirror of https://github.com/llvm/torch-mlir
101 lines
6.6 KiB
MLIR
101 lines
6.6 KiB
MLIR
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
|
|
|
|
|
|
// -----
|
|
|
|
// CHECK-LABEL: func.func @torch.aten.uniform(
|
|
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[32,64],f64>) -> !torch.vtensor<[32,64],f64> {
|
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
|
// CHECK: %[[FLOAT_0:.*]] = torch.constant.float 0.000000e+00
|
|
// CHECK: %[[VAL_0:.*]] = torch_c.to_f64 %[[FLOAT_0]]
|
|
// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
|
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_f64 %[[FLOAT_1]]
|
|
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
|
|
// CHECK: %[[ELEM_0:.*]] = tensor.from_elements %[[VAL_0]] : tensor<1xf64>
|
|
// CHECK: %[[VAL_3:.*]] = stablehlo.convert %[[ELEM_0]] : tensor<1xf64>
|
|
// CHECK: %[[VAL_4:.*]] = stablehlo.reshape %[[VAL_3]] : (tensor<1xf64>) -> tensor<f64>
|
|
// CHECK: %[[ELEM_1:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xf64>
|
|
// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[ELEM_1]] : tensor<1xf64>
|
|
// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xf64>) -> tensor<f64>
|
|
// CHECK: %[[VAL_7:.*]] = stablehlo.rng %[[VAL_4]], %[[VAL_6]], %[[VAL_2]], distribution = UNIFORM : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
|
|
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
|
|
// CHECK: return %[[VAL_8]] : !torch.vtensor<[32,64],f64>
|
|
func.func @torch.aten.uniform(%arg0: !torch.vtensor<[32, 64],f64>) -> !torch.vtensor<[32, 64],f64> {
|
|
%none = torch.constant.none
|
|
%float0 = torch.constant.float 0.0
|
|
%float1 = torch.constant.float 1.0
|
|
%0 = torch.aten.uniform %arg0, %float0, %float1, %none : !torch.vtensor<[32, 64],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64],f64>
|
|
return %0 : !torch.vtensor<[32, 64],f64>
|
|
}
|
|
|
|
// -----
|
|
|
|
// CHECK-LABEL: func.func @torch.aten.randn.generator
|
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
|
// CHECK: %[[INT32:.*]] = torch.constant.int 32
|
|
// CHECK: %[[INT64:.*]] = torch.constant.int 64
|
|
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
|
|
// CHECK: %[[SHAPE:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
|
|
// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f64>
|
|
// CHECK: %[[VAL_1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f64>
|
|
// CHECK: %[[RNG:.*]] = stablehlo.rng %[[VAL_0]], %[[VAL_1]], %[[SHAPE]], distribution = NORMAL : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
|
|
// CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[RNG]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
|
|
// CHECK: return %[[RET]] : !torch.vtensor<[32,64],f64>
|
|
func.func @torch.aten.randn.generator() -> !torch.vtensor<[32, 64],f64> {
|
|
%none = torch.constant.none
|
|
%int32 = torch.constant.int 32
|
|
%int64 = torch.constant.int 64
|
|
%size = torch.prim.ListConstruct %int32, %int64 : (!torch.int, !torch.int) -> !torch.list<int>
|
|
%0 = torch.aten.randn.generator %size, %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[32, 64], f64>
|
|
return %0 : !torch.vtensor<[32, 64],f64>
|
|
}
|
|
|
|
// -----
|
|
|
|
// CHECK-LABEL: func.func @torch.aten.randn.generator$f32
|
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
|
// CHECK: %[[INT32:.*]] = torch.constant.int 32
|
|
// CHECK: %[[INT64:.*]] = torch.constant.int 64
|
|
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
|
|
// CHECK: %[[SHAPE:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
|
|
// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
|
|
// CHECK: %[[VAL_1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
|
|
// CHECK: %[[RNG:.*]] = stablehlo.rng %[[VAL_0]], %[[VAL_1]], %[[SHAPE]], distribution = NORMAL : (tensor<f32>, tensor<f32>, tensor<2xi64>) -> tensor<32x64xf32>
|
|
// CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[RNG]] : tensor<32x64xf32> -> !torch.vtensor<[32,64],f32>
|
|
// CHECK: return %[[RET]] : !torch.vtensor<[32,64],f32>
|
|
func.func @torch.aten.randn.generator$f32() -> !torch.vtensor<[32, 64],f32> {
|
|
%none = torch.constant.none
|
|
%int32 = torch.constant.int 32
|
|
%int64 = torch.constant.int 64
|
|
%size = torch.prim.ListConstruct %int32, %int64 : (!torch.int, !torch.int) -> !torch.list<int>
|
|
%0 = torch.aten.randn.generator %size, %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[32, 64], f32>
|
|
return %0 : !torch.vtensor<[32, 64],f32>
|
|
}
|
|
|
|
// -----
|
|
|
|
// CHECK-LABEL: func.func @torch.aten.normal_functional(
|
|
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[32,64],f64>) -> !torch.vtensor<[32,64],f64> {
|
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
|
// CHECK: %[[FLOAT_0:.*]] = torch.constant.float 2.000000e+00
|
|
// CHECK: %[[VAL_0:.*]] = torch_c.to_f64 %[[FLOAT_0]]
|
|
// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
|
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_f64 %[[FLOAT_1]]
|
|
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
|
|
// CHECK: %[[ELEM_0:.*]] = tensor.from_elements %[[VAL_0]] : tensor<1xf64>
|
|
// CHECK: %[[VAL_3:.*]] = stablehlo.convert %[[ELEM_0]] : tensor<1xf64>
|
|
// CHECK: %[[VAL_4:.*]] = stablehlo.reshape %[[VAL_3]] : (tensor<1xf64>) -> tensor<f64>
|
|
// CHECK: %[[ELEM_1:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xf64>
|
|
// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[ELEM_1]] : tensor<1xf64>
|
|
// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xf64>) -> tensor<f64>
|
|
// CHECK: %[[VAL_7:.*]] = stablehlo.rng %[[VAL_4]], %[[VAL_6]], %[[VAL_2]], distribution = NORMAL : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
|
|
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
|
|
// CHECK: return %[[VAL_8]] : !torch.vtensor<[32,64],f64>
|
|
func.func @torch.aten.normal_functional(%arg0: !torch.vtensor<[32, 64], f64>) -> !torch.vtensor<[32, 64], f64> {
|
|
%none = torch.constant.none
|
|
%mean = torch.constant.float 2.0
|
|
%std = torch.constant.float 1.0
|
|
%0 = torch.aten.normal_functional %arg0, %mean, %std, %none : !torch.vtensor<[32, 64], f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64], f64>
|
|
return %0 : !torch.vtensor<[32, 64],f64>
|
|
}
|