mirror of https://github.com/llvm/torch-mlir
[onnx] Add `onnx-to-torch` lowering for random ops (#3193)
This commit adds the OnnxToTorch lowering for ONNX's RandomNormal, RandomNormalLike, RandomUniform, and RandomUniformLike ops.

parent 6abc7371c8
commit 3c252cdd44
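
In short, the lowering strategy (mirroring the test cases added at the end of this diff) is: materialize a tensor of the requested dtype and shape, then sample into it in place. A minimal sketch for onnx.RandomNormal with an illustrative shape of [4] (the SSA value names here are hypothetical):

  %0 = torch.operator "onnx.RandomNormal"() {torch.onnx.dtype = 1 : si64, torch.onnx.mean = 0.000000e+00 : f32, torch.onnx.scale = 1.000000e+00 : f32, torch.onnx.shape = [4 : si64]} : () -> !torch.vtensor<[4],f32>

lowers roughly to:

  // Illustrative lowering; value names are hypothetical.
  %dtype = torch.constant.int 6
  %dim = torch.constant.int 4
  %shape = torch.prim.ListConstruct %dim : (!torch.int) -> !torch.list<int>
  %none = torch.constant.none
  %empty = torch.aten.empty.memory_format %shape, %dtype, %none, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[4],f32>
  %mean = torch.constant.float 0.000000e+00
  %std = torch.constant.float 1.000000e+00
  %out = torch.aten.normal_functional %empty, %mean, %std, %none : !torch.vtensor<[4],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[4],f32>

The *Like variants skip the empty-tensor step and instead cast the input operand with torch.aten.to.dtype; the uniform ops sample with torch.aten.uniform (low, high) instead of torch.aten.normal_functional (mean, std). Ops carrying a torch.onnx.seed attribute are rejected with a match failure, since seeding is not yet supported.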
@@ -87,6 +87,8 @@ m_OnnxListOfConstantInts(SmallVectorImpl<int64_t> &bind_values) {
  return detail::onnx_list_of_constant_ints_op_binder(bind_values);
}

std::optional<int64_t> onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx);

} // namespace mlir::torch::onnx_c

#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H

@@ -9,6 +9,7 @@

#include "mlir/IR/DialectResourceBlobManager.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/Support/FormatVariadic.h"

@@ -17,56 +18,6 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::onnx_c;

class Endian {
private:
  static constexpr uint32_t uint32_ = 0x01020304;
  static constexpr uint8_t magic_ = (const uint8_t &)uint32_;

public:
  static constexpr bool little = magic_ == 0x04;
  static constexpr bool big = magic_ == 0x01;
  static_assert(little || big, "Cannot determine endianness!");

private:
  Endian() = delete;
};

static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
  // TODO: Add complete mapping.
  // Where are the ONNX and PyTorch dtype enums defined?
  // ONNX:
  //  https://github.com/shouxieai/tensorRT_Pro/blob/main/onnx/onnx-ml.proto
  // PyTorch:
  //  https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h#L88

  int64_t dtypeIntTorch = [dtypeIntOnnx]() {
    switch (dtypeIntOnnx) {
    case 1:
      return 6; // float
    case 2:
      return 0; // uint8
    case 3:
      return 1; // int8
    case 6:
      return 3; // int32
    case 7:
      return 4; // int64
    case 9:
      return 11; // bool
    case 10:
      return 5; // half
    case 11:
      return 7; // double
    case 16:
      return 15; // bfloat16
    default:
      return -1; // No dtype
    }
  }();

  return dtypeIntTorch;
}

static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
                                            Location loc, Value input,
                                            int64_t dimA, int64_t dimB,

@@ -428,7 +379,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value input;
        int64_t dtypeIntOnnx, dtypeIntTorch;
        int64_t dtypeIntOnnx;
        if (binder.tensorOperand(input) ||
            binder.s64IntegerAttr(dtypeIntOnnx, "dtype", -1) ||
            binder.tensorResultType(resultType))

@@ -452,16 +403,15 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
          rewriter.replaceOp(binder.op, bernoulli);
          return success();
        }
        dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
        if (dtypeIntTorch == -1) {
        std::optional<int64_t> dtypeIntTorch =
            onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
        if (!dtypeIntTorch.has_value()) {
          return rewriter.notifyMatchFailure(
              binder.op,
              "unimplemented support for the given dtype conversion");
        }
        Value constDtype = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                    dtypeIntTorch));
            binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
        Value cstFalse =
            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
        rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(

@@ -539,25 +489,21 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
      "Cast", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value operand;
        int64_t dtypeIntOnnx, dtypeIntTorch;
        int64_t dtypeIntOnnx;
        if (binder.tensorOperand(operand) ||
            binder.s64IntegerAttr(dtypeIntOnnx, "to") ||
            binder.tensorResultType(resultType))
          return failure();

        dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
        if (dtypeIntTorch == -1) {
          auto message = llvm::formatv("unimplemented support for the given "
                                       "dtype conversion (onnx 'type' = {0})",
                                       dtypeIntOnnx);
          auto y = rewriter.notifyMatchFailure(binder.op, message);

          return y;
        std::optional<int64_t> dtypeIntTorch =
            onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
        if (!dtypeIntTorch.has_value()) {
          return rewriter.notifyMatchFailure(
              binder.op,
              "unimplemented support for the given dtype conversion");
        }
        Value constDtype = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                    dtypeIntTorch));
            binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        Value cstFalse =
            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);

|
@ -1768,9 +1714,15 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
Value mVal = rewriter.create<Torch::AtenSizeIntOp>(binder.getLoc(),
|
||||
operand, cst1);
|
||||
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
int64_t dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
|
||||
std::optional<int64_t> dtypeIntTorch =
|
||||
onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
|
||||
if (!dtypeIntTorch.has_value()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"unimplemented support for the given dtype conversion");
|
||||
}
|
||||
Value dtypeVal = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch));
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
|
||||
|
||||
// diagonalIndex = 0 populates the main diagonal
|
||||
// diagonalIndex > 0 populates an upper diagonal
|
||||
|
|
|
@@ -2274,4 +2274,218 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
            binder.op, resultType, input, cstAlpha, value);
        return success();
      });
  patterns.onOp(
      "RandomNormal", 1,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        SmallString<64> name("torch.onnx.seed");
        auto seedAttr = binder.op->getAttr(name);
        if (seedAttr)
          return rewriter.notifyMatchFailure(
              binder.op,
              "unimplemented: support not present for seed attribute");

        Torch::ValueTensorType resultType;
        int64_t dtypeIntOnnx;
        float mean, scale;
        SmallVector<int64_t> shape;
        if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
            binder.f32FloatAttr(mean, "mean", 0.0) ||
            binder.f32FloatAttr(scale, "scale", 1.0) ||
            binder.s64IntegerArrayAttr(shape, "shape", {}) ||
            binder.tensorResultType(resultType)) {
          return failure();
        }

        std::optional<int64_t> dtypeIntTorch =
            onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
        if (!dtypeIntTorch.has_value()) {
          return rewriter.notifyMatchFailure(
              binder.op,
              "unimplemented support for the given dtype conversion");
        }
        Value constDtype = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));

        Value shapeList = createConstantIntList(binder, rewriter, shape);
        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

        Value self = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
            binder.op->getLoc(), resultType, shapeList,
            /*dtype=*/constDtype,
            /*layout=*/cstNone,
            /*device=*/cstNone, /*pinMemory=*/cstNone,
            /*memoryFormat=*/cstNone);

        Value cstMean = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getFloatAttr(rewriter.getF64Type(), mean));
        Value cstStd = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getFloatAttr(rewriter.getF64Type(), scale));

        rewriter.replaceOpWithNewOp<Torch::AtenNormalFunctionalOp>(
            binder.op, resultType, self, cstMean, cstStd,
            /*generator=*/cstNone);
        return success();
      });
  patterns.onOp(
      "RandomNormalLike", 1,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        SmallString<64> name("torch.onnx.seed");
        auto seedAttr = binder.op->getAttr(name);
        if (seedAttr)
          return rewriter.notifyMatchFailure(
              binder.op,
              "unimplemented: support not present for seed attribute");

        Torch::ValueTensorType resultType;
        int64_t dtypeIntOnnx;
        float mean, scale;
        SmallVector<int64_t> shape;
        Value input;
        if (binder.tensorOperand(input) ||
            binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
            binder.f32FloatAttr(mean, "mean", 0.0) ||
            binder.f32FloatAttr(scale, "scale", 1.0) ||
            binder.tensorResultType(resultType)) {
          return failure();
        }

        std::optional<int64_t> dtypeIntTorch =
            onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
        if (!dtypeIntTorch.has_value()) {
          return rewriter.notifyMatchFailure(
              binder.op,
              "unimplemented support for the given dtype conversion");
        }
        Value constDtype = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));

        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        Value cstFalse =
            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
        input = rewriter.create<Torch::AtenToDtypeOp>(
            binder.op->getLoc(), resultType, input, constDtype,
            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
            /*memory_format=*/cstNone);

        Value cstMean = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getFloatAttr(rewriter.getF64Type(), mean));
        Value cstStd = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getFloatAttr(rewriter.getF64Type(), scale));

        rewriter.replaceOpWithNewOp<Torch::AtenNormalFunctionalOp>(
            binder.op, resultType, input, cstMean, cstStd,
            /*generator=*/cstNone);
        return success();
      });
  patterns.onOp(
      "RandomUniform", 1,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        SmallString<64> name("torch.onnx.seed");
        auto seedAttr = binder.op->getAttr(name);
        if (seedAttr)
          return rewriter.notifyMatchFailure(
              binder.op,
              "unimplemented: support not present for seed attribute");

        Torch::ValueTensorType resultType;
        int64_t dtypeIntOnnx;
        float high, low;
        SmallVector<int64_t> shape;
        if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
            binder.f32FloatAttr(high, "high", 1.0) ||
            binder.f32FloatAttr(low, "low", 0.0) ||
            binder.s64IntegerArrayAttr(shape, "shape", {}) ||
            binder.tensorResultType(resultType)) {
          return failure();
        }

        std::optional<int64_t> dtypeIntTorch =
            onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
        if (!dtypeIntTorch.has_value()) {
          return rewriter.notifyMatchFailure(
              binder.op,
              "unimplemented support for the given dtype conversion");
        }
        Value constDtype = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));

        Value shapeList = createConstantIntList(binder, rewriter, shape);
        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

        Value self = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
            binder.op->getLoc(), resultType, shapeList,
            /*dtype=*/constDtype,
            /*layout=*/cstNone,
            /*device=*/cstNone, /*pinMemory=*/cstNone,
            /*memoryFormat=*/cstNone);

        Value cstHigh = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getFloatAttr(rewriter.getF64Type(), high));
        Value cstLow = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getFloatAttr(rewriter.getF64Type(), low));

        rewriter.replaceOpWithNewOp<Torch::AtenUniformOp>(
            binder.op, resultType, self, cstLow, cstHigh,
            /*generator=*/cstNone);
        return success();
      });
  patterns.onOp(
      "RandomUniformLike", 1,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        SmallString<64> name("torch.onnx.seed");
        auto seedAttr = binder.op->getAttr(name);
        if (seedAttr)
          return rewriter.notifyMatchFailure(
              binder.op,
              "unimplemented: support not present for seed attribute");

        Torch::ValueTensorType resultType;
        int64_t dtypeIntOnnx;
        float high, low;
        SmallVector<int64_t> shape;
        Value input;
        if (binder.tensorOperand(input) ||
            binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
            binder.f32FloatAttr(high, "high", 1.0) ||
            binder.f32FloatAttr(low, "low", 0.0) ||
            binder.tensorResultType(resultType)) {
          return failure();
        }

        std::optional<int64_t> dtypeIntTorch =
            onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
        if (!dtypeIntTorch.has_value()) {
          return rewriter.notifyMatchFailure(
              binder.op,
              "unimplemented support for the given dtype conversion");
        }
        Value constDtype = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));

        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        Value cstFalse =
            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
        input = rewriter.create<Torch::AtenToDtypeOp>(
            binder.op->getLoc(), resultType, input, constDtype,
            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
            /*memory_format=*/cstNone);

        Value cstHigh = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getFloatAttr(rewriter.getF64Type(), high));
        Value cstLow = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getFloatAttr(rewriter.getF64Type(), low));

        rewriter.replaceOpWithNewOp<Torch::AtenUniformOp>(
            binder.op, resultType, input, cstLow, cstHigh,
            /*generator=*/cstNone);
        return success();
      });
}

@@ -59,3 +59,41 @@ bool mlir::torch::onnx_c::areAllElementsDistinct(SmallVector<int64_t> array) {
  // as array's size.
  return (set.size() == array.size());
}

std::optional<int64_t>
mlir::torch::onnx_c::onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
  // TODO: Add complete mapping.
  // Where are the ONNX and PyTorch dtype enums defined?
  // ONNX:
  //  https://github.com/shouxieai/tensorRT_Pro/blob/main/onnx/onnx-ml.proto
  // PyTorch:
  //  https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h#L88

  std::optional<int64_t> dtypeIntTorch =
      [dtypeIntOnnx]() -> std::optional<int64_t> {
    switch (dtypeIntOnnx) {
    case 1:
      return 6; // float
    case 2:
      return 0; // uint8
    case 3:
      return 1; // int8
    case 6:
      return 3; // int32
    case 7:
      return 4; // int64
    case 9:
      return 11; // bool
    case 10:
      return 5; // half
    case 11:
      return 7; // double
    case 16:
      return 15; // bfloat16
    default:
      return std::nullopt; // No dtype
    }
  }();

  return dtypeIntTorch;
}

@@ -2605,27 +2605,15 @@ ONNX_XFAIL_SET = {
    # Failure - onnx_lowering: onnx.OneHot
    "OneHotModule_basic",

    # Failure - onnx_lowering: onnx.RandomNormal
    # ERROR: dtype (torch.float32) is not equal to golden dtype (torch.float64)
    "RandnDtypeDeviceModule_basic",
    "RandnGeneratorF64Module_basic",
    "RandnGeneratorModule_basic",
    "RandnModule_basic",

    # Failure - onnx_lowering: onnx.RandomNormalLike
    "RandnLikeDtypeModule_basic",
    "RandnLikeModule_basic",

    # Failure - onnx_lowering: onnx.RandomUniform
    "RandIntLowDtypeModule_basic",
    "RandIntLowModule_basic",

    # Failure - onnx_lowering: onnx.RandomUniformLike
    "BernoulliFloatModule_basic",
    "BernoulliPModule_basic",
    "BernoulliTensorModule_basic",
    "RandLikeDtypeModule_basic",
    "RandLikeModule_basic",
    "RandModule_basic",

    # Failure - onnx_lowering: onnx.ReduceL2
    "LinalgNormKeepDimModule_basic",

@@ -1679,3 +1679,65 @@ func.func @test_triu_zero(%arg0: !torch.vtensor<[0,5],si64>, %arg1: !torch.vtens
  %0 = torch.operator "onnx.Trilu"(%arg0, %arg1) : (!torch.vtensor<[0,5],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[0,5],si64>
  return %0 : !torch.vtensor<[0,5],si64>
}

// -----

// CHECK-LABEL: func.func @test_random_normal
func.func @test_random_normal() -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[I6:.+]] = torch.constant.int 6
  // CHECK-DAG: %[[I10:.+]] = torch.constant.int 10
  // CHECK: %[[SHAPE:.+]] = torch.prim.ListConstruct %[[I10]] : (!torch.int) -> !torch.list<int>
  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
  // CHECK: %[[EMPTY_TENSOR:.+]] = torch.aten.empty.memory_format %[[SHAPE]], %[[I6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
  // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
  // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1.000000e+00
  // CHECK: torch.aten.normal_functional %[[EMPTY_TENSOR]], %[[F0]], %[[F1]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[10],f32>
  %0 = torch.operator "onnx.RandomNormal"() {torch.onnx.dtype = 1 : si64, torch.onnx.mean = 0.000000e+00 : f32, torch.onnx.scale = 1.000000e+00 : f32, torch.onnx.shape = [10 : si64]} : () -> !torch.vtensor<[10],f32>
  return %0 : !torch.vtensor<[10],f32>
}

// -----

// CHECK-LABEL: func.func @test_random_normal_like
func.func @test_random_normal_like(%arg0: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[I6:.+]] = torch.constant.int 6
  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[I6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
  // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
  // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1.000000e+00
  // CHECK: torch.aten.normal_functional %[[CAST]], %[[F0]], %[[F1]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[10],f32>
  %0 = torch.operator "onnx.RandomNormalLike"(%arg0) {torch.onnx.dtype = 1 : si64, torch.onnx.mean = 0.000000e+00 : f32, torch.onnx.scale = 1.000000e+00 : f32} : (!torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32>
  return %0 : !torch.vtensor<[10],f32>
}

// -----

// CHECK-LABEL: func.func @test_random_uniform
func.func @test_random_uniform() -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[I6:.+]] = torch.constant.int 6
  // CHECK-DAG: %[[I10:.+]] = torch.constant.int 10
  // CHECK: %[[SHAPE:.+]] = torch.prim.ListConstruct %[[I10]] : (!torch.int) -> !torch.list<int>
  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
  // CHECK: %[[EMPTY_TENSOR:.+]] = torch.aten.empty.memory_format %[[SHAPE]], %[[I6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
  // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1.000000e+00
  // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
  // CHECK: torch.aten.uniform %[[EMPTY_TENSOR]], %[[F0]], %[[F1]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[10],f32>
  %0 = torch.operator "onnx.RandomUniform"() {torch.onnx.dtype = 1 : si64, torch.onnx.high = 1.000000e+00 : f32, torch.onnx.low = 0.000000e+00 : f32, torch.onnx.shape = [10 : si64]} : () -> !torch.vtensor<[10],f32>
  return %0 : !torch.vtensor<[10],f32>
}

// -----

// CHECK-LABEL: func.func @test_random_uniform_like
func.func @test_random_uniform_like(%arg0: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[I6:.+]] = torch.constant.int 6
  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[I6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
  // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
  // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1.000000e+00
  // CHECK: torch.aten.uniform %[[CAST]], %[[F0]], %[[F1]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[10],f32>
  %0 = torch.operator "onnx.RandomUniformLike"(%arg0) {torch.onnx.dtype = 1 : si64, torch.onnx.high = 1.000000e+00 : f32, torch.onnx.low = 0.000000e+00 : f32} : (!torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32>
  return %0 : !torch.vtensor<[10],f32>
}