[onnx] Add `onnx-to-torch` lowering for random ops (#3193)

This commit adds the OnnxToTorch lowering for Onnx's RandomNormal, RandomNormalLike, RandomUniform, and RandomUniformLike op.
pull/3206/head
Vivek Khandelwal 2024-04-22 22:28:07 +05:30 committed by GitHub
parent 6abc7371c8
commit 3c252cdd44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 339 additions and 83 deletions

View File

@ -87,6 +87,8 @@ m_OnnxListOfConstantInts(SmallVectorImpl<int64_t> &bind_values) {
return detail::onnx_list_of_constant_ints_op_binder(bind_values);
}
std::optional<int64_t> onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx);
} // namespace mlir::torch::onnx_c
#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H

View File

@ -9,6 +9,7 @@
#include "mlir/IR/DialectResourceBlobManager.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/Support/FormatVariadic.h"
@ -17,56 +18,6 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::onnx_c;
class Endian {
private:
static constexpr uint32_t uint32_ = 0x01020304;
static constexpr uint8_t magic_ = (const uint8_t &)uint32_;
public:
static constexpr bool little = magic_ == 0x04;
static constexpr bool big = magic_ == 0x01;
static_assert(little || big, "Cannot determine endianness!");
private:
Endian() = delete;
};
static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
// TODO: Add complete mapping.
// Where are the ONNX and PyTorch dtype enums defined?
// ONNX:
// https://github.com/shouxieai/tensorRT_Pro/blob/main/onnx/onnx-ml.proto
// PyTorch:
// https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h#L88
int64_t dtypeIntTorch = [dtypeIntOnnx]() {
switch (dtypeIntOnnx) {
case 1:
return 6; // float
case 2:
return 0; // uint8
case 3:
return 1; // int8
case 6:
return 3; // int32
case 7:
return 4; // int64
case 9:
return 11; // bool
case 10:
return 5; // half
case 11:
return 7; // double
case 16:
return 15; // bfloat16
default:
return -1; // No dtype
}
}();
return dtypeIntTorch;
}
static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
Location loc, Value input,
int64_t dimA, int64_t dimB,
@ -428,7 +379,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value input;
int64_t dtypeIntOnnx, dtypeIntTorch;
int64_t dtypeIntOnnx;
if (binder.tensorOperand(input) ||
binder.s64IntegerAttr(dtypeIntOnnx, "dtype", -1) ||
binder.tensorResultType(resultType))
@ -452,16 +403,15 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
rewriter.replaceOp(binder.op, bernoulli);
return success();
}
dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (dtypeIntTorch == -1) {
std::optional<int64_t> dtypeIntTorch =
onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (!dtypeIntTorch.has_value()) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
}
Value constDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
dtypeIntTorch));
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
@ -539,25 +489,21 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
"Cast", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
int64_t dtypeIntOnnx, dtypeIntTorch;
int64_t dtypeIntOnnx;
if (binder.tensorOperand(operand) ||
binder.s64IntegerAttr(dtypeIntOnnx, "to") ||
binder.tensorResultType(resultType))
return failure();
dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (dtypeIntTorch == -1) {
auto message = llvm::formatv("unimplemented support for the given "
"dtype conversion (onnx 'type' = {0})",
dtypeIntOnnx);
auto y = rewriter.notifyMatchFailure(binder.op, message);
return y;
std::optional<int64_t> dtypeIntTorch =
onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (!dtypeIntTorch.has_value()) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
}
Value constDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
dtypeIntTorch));
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
@ -1768,9 +1714,15 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
Value mVal = rewriter.create<Torch::AtenSizeIntOp>(binder.getLoc(),
operand, cst1);
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
int64_t dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
std::optional<int64_t> dtypeIntTorch =
onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (!dtypeIntTorch.has_value()) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
}
Value dtypeVal = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch));
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
// diagonalIndex = 0 populates the main diagonal
// diagonalIndex > 0 populates an upper diagonal

View File

@ -2274,4 +2274,218 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.op, resultType, input, cstAlpha, value);
return success();
});
patterns.onOp(
"RandomNormal", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
SmallString<64> name("torch.onnx.seed");
auto seedAttr = binder.op->getAttr(name);
if (seedAttr)
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: support not present for seed attribute");
Torch::ValueTensorType resultType;
int64_t dtypeIntOnnx;
float mean, scale;
SmallVector<int64_t> shape;
if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
binder.f32FloatAttr(mean, "mean", 0.0) ||
binder.f32FloatAttr(scale, "scale", 1.0) ||
binder.s64IntegerArrayAttr(shape, "shape", {}) ||
binder.tensorResultType(resultType)) {
return failure();
}
std::optional<int64_t> dtypeIntTorch =
onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (!dtypeIntTorch.has_value()) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
}
Value constDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
Value shapeList = createConstantIntList(binder, rewriter, shape);
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value self = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
binder.op->getLoc(), resultType, shapeList,
/*dtype=*/constDtype,
/*layout=*/cstNone,
/*device=*/cstNone, /*pinMemory=*/cstNone,
/*memoryFormat=*/cstNone);
Value cstMean = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), mean));
Value cstStd = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), scale));
rewriter.replaceOpWithNewOp<Torch::AtenNormalFunctionalOp>(
binder.op, resultType, self, cstMean, cstStd,
/*generator=*/cstNone);
return success();
});
patterns.onOp(
"RandomNormalLike", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
SmallString<64> name("torch.onnx.seed");
auto seedAttr = binder.op->getAttr(name);
if (seedAttr)
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: support not present for seed attribute");
Torch::ValueTensorType resultType;
int64_t dtypeIntOnnx;
float mean, scale;
SmallVector<int64_t> shape;
Value input;
if (binder.tensorOperand(input) ||
binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
binder.f32FloatAttr(mean, "mean", 0.0) ||
binder.f32FloatAttr(scale, "scale", 1.0) ||
binder.tensorResultType(resultType)) {
return failure();
}
std::optional<int64_t> dtypeIntTorch =
onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (!dtypeIntTorch.has_value()) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
}
Value constDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
input = rewriter.create<Torch::AtenToDtypeOp>(
binder.op->getLoc(), resultType, input, constDtype,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/cstNone);
Value cstMean = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), mean));
Value cstStd = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), scale));
rewriter.replaceOpWithNewOp<Torch::AtenNormalFunctionalOp>(
binder.op, resultType, input, cstMean, cstStd,
/*generator=*/cstNone);
return success();
});
patterns.onOp(
"RandomUniform", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
SmallString<64> name("torch.onnx.seed");
auto seedAttr = binder.op->getAttr(name);
if (seedAttr)
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: support not present for seed attribute");
Torch::ValueTensorType resultType;
int64_t dtypeIntOnnx;
float high, low;
SmallVector<int64_t> shape;
if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
binder.f32FloatAttr(high, "high", 1.0) ||
binder.f32FloatAttr(low, "low", 0.0) ||
binder.s64IntegerArrayAttr(shape, "shape", {}) ||
binder.tensorResultType(resultType)) {
return failure();
}
std::optional<int64_t> dtypeIntTorch =
onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (!dtypeIntTorch.has_value()) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
}
Value constDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
Value shapeList = createConstantIntList(binder, rewriter, shape);
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value self = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
binder.op->getLoc(), resultType, shapeList,
/*dtype=*/constDtype,
/*layout=*/cstNone,
/*device=*/cstNone, /*pinMemory=*/cstNone,
/*memoryFormat=*/cstNone);
Value cstHigh = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), high));
Value cstLow = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), low));
rewriter.replaceOpWithNewOp<Torch::AtenUniformOp>(
binder.op, resultType, self, cstLow, cstHigh,
/*generator=*/cstNone);
return success();
});
patterns.onOp(
"RandomUniformLike", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
SmallString<64> name("torch.onnx.seed");
auto seedAttr = binder.op->getAttr(name);
if (seedAttr)
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: support not present for seed attribute");
Torch::ValueTensorType resultType;
int64_t dtypeIntOnnx;
float high, low;
SmallVector<int64_t> shape;
Value input;
if (binder.tensorOperand(input) ||
binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
binder.f32FloatAttr(high, "high", 1.0) ||
binder.f32FloatAttr(low, "low", 0.0) ||
binder.tensorResultType(resultType)) {
return failure();
}
std::optional<int64_t> dtypeIntTorch =
onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (!dtypeIntTorch.has_value()) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
}
Value constDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
input = rewriter.create<Torch::AtenToDtypeOp>(
binder.op->getLoc(), resultType, input, constDtype,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/cstNone);
Value cstHigh = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), high));
Value cstLow = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), low));
rewriter.replaceOpWithNewOp<Torch::AtenUniformOp>(
binder.op, resultType, input, cstLow, cstHigh,
/*generator=*/cstNone);
return success();
});
}

View File

@ -59,3 +59,41 @@ bool mlir::torch::onnx_c::areAllElementsDistinct(SmallVector<int64_t> array) {
// as array's size.
return (set.size() == array.size());
}
std::optional<int64_t>
mlir::torch::onnx_c::onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
// TODO: Add complete mapping.
// Where are the ONNX and PyTorch dtype enums defined?
// ONNX:
// https://github.com/shouxieai/tensorRT_Pro/blob/main/onnx/onnx-ml.proto
// PyTorch:
// https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h#L88
std::optional<int64_t> dtypeIntTorch =
[dtypeIntOnnx]() -> std::optional<int64_t> {
switch (dtypeIntOnnx) {
case 1:
return 6; // float
case 2:
return 0; // uint8
case 3:
return 1; // int8
case 6:
return 3; // int32
case 7:
return 4; // int64
case 9:
return 11; // bool
case 10:
return 5; // half
case 11:
return 7; // double
case 16:
return 15; // bfloat16
default:
return std::nullopt; // No dtype
}
}();
return dtypeIntTorch;
}

View File

@ -2605,27 +2605,15 @@ ONNX_XFAIL_SET = {
# Failure - onnx_lowering: onnx.OneHot
"OneHotModule_basic",
# Failure - onnx_lowering: onnx.RandomNormal
# ERROR: dtype (torch.float32) is not equal to golden dtype (torch.float64)
"RandnDtypeDeviceModule_basic",
"RandnGeneratorF64Module_basic",
"RandnGeneratorModule_basic",
"RandnModule_basic",
# Failure - onnx_lowering: onnx.RandomNormalLike
"RandnLikeDtypeModule_basic",
"RandnLikeModule_basic",
# Failure - onnx_lowering: onnx.RandomUniform
"RandIntLowDtypeModule_basic",
"RandIntLowModule_basic",
# Failure - onnx_lowering: onnx.RandomUniformLike
"BernoulliFloatModule_basic",
"BernoulliPModule_basic",
"BernoulliTensorModule_basic",
"RandLikeDtypeModule_basic",
"RandLikeModule_basic",
"RandModule_basic",
# Failure - onnx_lowering: onnx.ReduceL2
"LinalgNormKeepDimModule_basic",

View File

@ -1679,3 +1679,65 @@ func.func @test_triu_zero(%arg0: !torch.vtensor<[0,5],si64>, %arg1: !torch.vtens
%0 = torch.operator "onnx.Trilu"(%arg0, %arg1) : (!torch.vtensor<[0,5],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[0,5],si64>
return %0 : !torch.vtensor<[0,5],si64>
}
// -----
// CHECK-LABEL: func.func @test_random_normal
func.func @test_random_normal() -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[I6:.+]] = torch.constant.int 6
// CHECK-DAG: %[[I10:.+]] = torch.constant.int 10
// CHECK: %[[SHAPE:.+]] = torch.prim.ListConstruct %[[I10]] : (!torch.int) -> !torch.list<int>
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[EMPTY_TENSOR:.+]] = torch.aten.empty.memory_format %[[SHAPE]], %[[I6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[F1:.+]] = torch.constant.float 1.000000e+00
// CHECK: torch.aten.normal_functional %[[EMPTY_TENSOR]], %[[F0]], %[[F1]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[10],f32>
%0 = torch.operator "onnx.RandomNormal"() {torch.onnx.dtype = 1 : si64, torch.onnx.mean = 0.000000e+00 : f32, torch.onnx.scale = 1.000000e+00 : f32, torch.onnx.shape = [10 : si64]} : () -> !torch.vtensor<[10],f32>
return %0 : !torch.vtensor<[10],f32>
}
// -----
// CHECK-LABEL: func.func @test_random_normal_like
func.func @test_random_normal_like(%arg0: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[I6:.+]] = torch.constant.int 6
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[I6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[F1:.+]] = torch.constant.float 1.000000e+00
// CHECK: torch.aten.normal_functional %[[CAST]], %[[F0]], %[[F1]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[10],f32>
%0 = torch.operator "onnx.RandomNormalLike"(%arg0) {torch.onnx.dtype = 1 : si64, torch.onnx.mean = 0.000000e+00 : f32, torch.onnx.scale = 1.000000e+00 : f32} : (!torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32>
return %0 : !torch.vtensor<[10],f32>
}
// -----
// CHECK-LABEL: func.func @test_random_uniform
func.func @test_random_uniform() -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[I6:.+]] = torch.constant.int 6
// CHECK-DAG: %[[I10:.+]] = torch.constant.int 10
// CHECK: %[[SHAPE:.+]] = torch.prim.ListConstruct %[[I10]] : (!torch.int) -> !torch.list<int>
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[EMPTY_TENSOR:.+]] = torch.aten.empty.memory_format %[[SHAPE]], %[[I6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[F1:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
// CHECK: torch.aten.uniform %[[EMPTY_TENSOR]], %[[F0]], %[[F1]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[10],f32>
%0 = torch.operator "onnx.RandomUniform"() {torch.onnx.dtype = 1 : si64, torch.onnx.high = 1.000000e+00 : f32, torch.onnx.low = 0.000000e+00 : f32, torch.onnx.shape = [10 : si64]} : () -> !torch.vtensor<[10],f32>
return %0 : !torch.vtensor<[10],f32>
}
// -----
// CHECK-LABEL: func.func @test_random_uniform_like
func.func @test_random_uniform_like(%arg0: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[I6:.+]] = torch.constant.int 6
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[I6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[F1:.+]] = torch.constant.float 1.000000e+00
// CHECK: torch.aten.uniform %[[CAST]], %[[F0]], %[[F1]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[10],f32>
%0 = torch.operator "onnx.RandomUniformLike"(%arg0) {torch.onnx.dtype = 1 : si64, torch.onnx.high = 1.000000e+00 : f32, torch.onnx.low = 0.000000e+00 : f32} : (!torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32>
return %0 : !torch.vtensor<[10],f32>
}