mirror of https://github.com/llvm/torch-mlir
[Stablehlo] lowering aten.randn & aten.normal_functional to mhlo.rng … (#3328)
…NORMAL * split lowering of uniform, randn, normal from Basic.cpp into Rng.cpppull/3329/head
parent
00efec0b73
commit
5f7cb9e253
|
@ -1819,36 +1819,6 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
|
||||||
LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
|
|
||||||
AtenUniformOp op, OpAdaptor adaptor,
|
|
||||||
ConversionPatternRewriter &rewriter) const {
|
|
||||||
Value self = adaptor.getSelf();
|
|
||||||
Value generator = adaptor.getGenerator();
|
|
||||||
Location loc = op.getLoc();
|
|
||||||
|
|
||||||
if (!isa<Torch::NoneType>(generator.getType()))
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "The generator has to be None because only global default "
|
|
||||||
"generator is supported");
|
|
||||||
|
|
||||||
auto elements = cast<RankedTensorType>(self.getType()).getShape();
|
|
||||||
if (llvm::any_of(elements,
|
|
||||||
[](int64_t dim) { return dim == ShapedType::kDynamic; }))
|
|
||||||
return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
|
|
||||||
auto shape_tensor = rewriter.create<stablehlo::ConstantOp>(
|
|
||||||
loc, rewriter.getI64TensorAttr(elements));
|
|
||||||
auto outTy = getTypeConverter()->convertType(op.getType());
|
|
||||||
auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
|
|
||||||
Value from =
|
|
||||||
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy);
|
|
||||||
Value to =
|
|
||||||
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getTo(), outElemTy);
|
|
||||||
rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
|
|
||||||
op, outTy, from, to, shape_tensor, stablehlo::RngDistribution::UNIFORM);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Converts `aten.empty.memory_format` to `tensor.empty` op.
|
// Converts `aten.empty.memory_format` to `tensor.empty` op.
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
|
||||||
|
@ -2240,7 +2210,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||||
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
|
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
|
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp);
|
INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenUniformOp);
|
|
||||||
INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
|
INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenFillScalarOp);
|
INSERT_ATENOP_PATTERN(AtenFillScalarOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenFlipOp);
|
INSERT_ATENOP_PATTERN(AtenFlipOp);
|
||||||
|
|
|
@ -6,6 +6,7 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo
|
||||||
Linear.cpp
|
Linear.cpp
|
||||||
ViewLike.cpp
|
ViewLike.cpp
|
||||||
Reduction.cpp
|
Reduction.cpp
|
||||||
|
Rng.cpp
|
||||||
Pooling.cpp
|
Pooling.cpp
|
||||||
Utils.cpp
|
Utils.cpp
|
||||||
|
|
||||||
|
|
|
@ -62,6 +62,11 @@ void populatePoolingOpPatternsAndLegality(
|
||||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
ConversionTarget &target, const TorchToStablehloOptions &options);
|
ConversionTarget &target, const TorchToStablehloOptions &options);
|
||||||
|
|
||||||
|
void populateRngOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||||
|
RewritePatternSet &patterns,
|
||||||
|
ConversionTarget &target,
|
||||||
|
const TorchToStablehloOptions &options);
|
||||||
|
|
||||||
} // namespace torch_to_stablehlo
|
} // namespace torch_to_stablehlo
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -0,0 +1,137 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||||
|
|
||||||
|
#include "../PassDetail.h"
|
||||||
|
#include "./PopulatePatterns.h"
|
||||||
|
|
||||||
|
#include "stablehlo/dialect/StablehloOps.h"
|
||||||
|
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||||
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
using namespace mlir::torch::Torch;
|
||||||
|
using namespace mlir::torch::torch_to_stablehlo;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
|
||||||
|
AtenUniformOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
Value self = adaptor.getSelf();
|
||||||
|
Value generator = adaptor.getGenerator();
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
|
if (!isa<Torch::NoneType>(generator.getType()))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "The generator has to be None because only global default "
|
||||||
|
"generator is supported");
|
||||||
|
|
||||||
|
auto elements = cast<RankedTensorType>(self.getType()).getShape();
|
||||||
|
if (llvm::any_of(elements,
|
||||||
|
[](int64_t dim) { return dim == ShapedType::kDynamic; }))
|
||||||
|
return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
|
||||||
|
auto shape_tensor = rewriter.create<stablehlo::ConstantOp>(
|
||||||
|
loc, rewriter.getI64TensorAttr(elements));
|
||||||
|
auto outTy = getTypeConverter()->convertType(op.getType());
|
||||||
|
auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
|
||||||
|
Value from =
|
||||||
|
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy);
|
||||||
|
Value to =
|
||||||
|
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getTo(), outElemTy);
|
||||||
|
rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
|
||||||
|
op, outTy, from, to, shape_tensor, stablehlo::RngDistribution::UNIFORM);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenRandnGeneratorOp>::matchAndRewrite(
|
||||||
|
AtenRandnGeneratorOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
Value generator = adaptor.getGenerator();
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
|
if (!isa<Torch::NoneType>(generator.getType())) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "The generator has to be None because only global default "
|
||||||
|
"generator is supported");
|
||||||
|
}
|
||||||
|
llvm::SmallVector<int64_t> shape;
|
||||||
|
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(shape))) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "size must be constant");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto outTy = getTypeConverter()->convertType(op.getType());
|
||||||
|
auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
|
||||||
|
auto scalarTy = RankedTensorType::get({}, outElemTy);
|
||||||
|
if (!isa<mlir::FloatType>(outElemTy)) {
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"only support output with float type");
|
||||||
|
}
|
||||||
|
|
||||||
|
Value shapeTensor = rewriter.create<stablehlo::ConstantOp>(
|
||||||
|
loc, rewriter.getI64TensorAttr(shape));
|
||||||
|
Value mean = rewriter.create<stablehlo::ConstantOp>(
|
||||||
|
loc, DenseFPElementsAttr::get(scalarTy, 0.0));
|
||||||
|
Value var = rewriter.create<stablehlo::ConstantOp>(
|
||||||
|
loc, DenseFPElementsAttr::get(scalarTy, 1.0));
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
|
||||||
|
op, outTy, mean, var, shapeTensor, stablehlo::RngDistribution::NORMAL);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenNormalFunctionalOp>::matchAndRewrite(
|
||||||
|
AtenNormalFunctionalOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
Value self = adaptor.getSelf();
|
||||||
|
Value generator = adaptor.getGenerator();
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
|
if (!isa<Torch::NoneType>(generator.getType()))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "The generator has to be None because only global default "
|
||||||
|
"generator is supported");
|
||||||
|
|
||||||
|
auto elements = cast<RankedTensorType>(self.getType()).getShape();
|
||||||
|
if (llvm::any_of(elements,
|
||||||
|
[](int64_t dim) { return dim == ShapedType::kDynamic; }))
|
||||||
|
return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
|
||||||
|
auto shapeTensor = rewriter.create<stablehlo::ConstantOp>(
|
||||||
|
loc, rewriter.getI64TensorAttr(elements));
|
||||||
|
auto outTy = getTypeConverter()->convertType(op.getType());
|
||||||
|
auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
|
||||||
|
Value mean =
|
||||||
|
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getMean(), outElemTy);
|
||||||
|
Value std =
|
||||||
|
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStd(), outElemTy);
|
||||||
|
rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
|
||||||
|
op, outTy, mean, std, shapeTensor, stablehlo::RngDistribution::NORMAL);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void mlir::torch::torch_to_stablehlo::populateRngOpPatternsAndLegality(
|
||||||
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
|
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||||
|
MLIRContext *context = patterns.getContext();
|
||||||
|
|
||||||
|
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
||||||
|
target.addIllegalOp<AtenOp>(); \
|
||||||
|
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
|
||||||
|
|
||||||
|
INSERT_ATENOP_PATTERN(AtenUniformOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenRandnGeneratorOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenNormalFunctionalOp);
|
||||||
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
}
|
|
@ -75,6 +75,8 @@ public:
|
||||||
typeConverter, patterns, target, options);
|
typeConverter, patterns, target, options);
|
||||||
torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
|
torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
|
||||||
typeConverter, patterns, target, options);
|
typeConverter, patterns, target, options);
|
||||||
|
torch_to_stablehlo::populateRngOpPatternsAndLegality(
|
||||||
|
typeConverter, patterns, target, options);
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns)))) {
|
std::move(patterns)))) {
|
||||||
|
|
|
@ -291,33 +291,6 @@ func.func @torch.runtime.assert(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.uniform(
|
|
||||||
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[32,64],f64>) -> !torch.vtensor<[32,64],f64> {
|
|
||||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
|
||||||
// CHECK: %[[FLOAT_0:.*]] = torch.constant.float 0.000000e+00
|
|
||||||
// CHECK: %[[VAL_0:.*]] = torch_c.to_f64 %[[FLOAT_0]]
|
|
||||||
// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
|
|
||||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_f64 %[[FLOAT_1]]
|
|
||||||
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
|
|
||||||
// CHECK: %[[ELEM_0:.*]] = tensor.from_elements %[[VAL_0]] : tensor<1xf64>
|
|
||||||
// CHECK: %[[VAL_3:.*]] = stablehlo.convert %[[ELEM_0]] : tensor<1xf64>
|
|
||||||
// CHECK: %[[VAL_4:.*]] = stablehlo.reshape %[[VAL_3]] : (tensor<1xf64>) -> tensor<f64>
|
|
||||||
// CHECK: %[[ELEM_1:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xf64>
|
|
||||||
// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[ELEM_1]] : tensor<1xf64>
|
|
||||||
// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xf64>) -> tensor<f64>
|
|
||||||
// CHECK: %[[VAL_7:.*]] = stablehlo.rng %[[VAL_4]], %[[VAL_6]], %[[VAL_2]], distribution = UNIFORM : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
|
|
||||||
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
|
|
||||||
// CHECK: return %[[VAL_8]] : !torch.vtensor<[32,64],f64>
|
|
||||||
func.func @torch.aten.uniform(%arg0: !torch.vtensor<[32, 64],f64>) -> !torch.vtensor<[32, 64],f64> {
|
|
||||||
%none = torch.constant.none
|
|
||||||
%float0 = torch.constant.float 0.0
|
|
||||||
%float1 = torch.constant.float 1.0
|
|
||||||
%0 = torch.aten.uniform %arg0, %float0, %float1, %none : !torch.vtensor<[32, 64],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64],f64>
|
|
||||||
return %0 : !torch.vtensor<[32, 64],f64>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.bitwise_left_shift.Tensor(
|
// CHECK-LABEL: func.func @torch.aten.bitwise_left_shift.Tensor(
|
||||||
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,4],si32>,
|
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,4],si32>,
|
||||||
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[3,1],si32>) -> !torch.vtensor<[3,4],si32> {
|
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[3,1],si32>) -> !torch.vtensor<[3,4],si32> {
|
||||||
|
|
|
@ -0,0 +1,78 @@
|
||||||
|
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||||
|
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.uniform(
|
||||||
|
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[32,64],f64>) -> !torch.vtensor<[32,64],f64> {
|
||||||
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[FLOAT_0:.*]] = torch.constant.float 0.000000e+00
|
||||||
|
// CHECK: %[[VAL_0:.*]] = torch_c.to_f64 %[[FLOAT_0]]
|
||||||
|
// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_f64 %[[FLOAT_1]]
|
||||||
|
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
|
||||||
|
// CHECK: %[[ELEM_0:.*]] = tensor.from_elements %[[VAL_0]] : tensor<1xf64>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = stablehlo.convert %[[ELEM_0]] : tensor<1xf64>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = stablehlo.reshape %[[VAL_3]] : (tensor<1xf64>) -> tensor<f64>
|
||||||
|
// CHECK: %[[ELEM_1:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xf64>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[ELEM_1]] : tensor<1xf64>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xf64>) -> tensor<f64>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = stablehlo.rng %[[VAL_4]], %[[VAL_6]], %[[VAL_2]], distribution = UNIFORM : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
|
||||||
|
// CHECK: return %[[VAL_8]] : !torch.vtensor<[32,64],f64>
|
||||||
|
func.func @torch.aten.uniform(%arg0: !torch.vtensor<[32, 64],f64>) -> !torch.vtensor<[32, 64],f64> {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%float0 = torch.constant.float 0.0
|
||||||
|
%float1 = torch.constant.float 1.0
|
||||||
|
%0 = torch.aten.uniform %arg0, %float0, %float1, %none : !torch.vtensor<[32, 64],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64],f64>
|
||||||
|
return %0 : !torch.vtensor<[32, 64],f64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.randn.generator
|
||||||
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT32:.*]] = torch.constant.int 32
|
||||||
|
// CHECK: %[[INT64:.*]] = torch.constant.int 64
|
||||||
|
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
|
||||||
|
// CHECK: %[[SHAPE:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
|
||||||
|
// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f64>
|
||||||
|
// CHECK: %[[VAL_1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f64>
|
||||||
|
// CHECK: %[[RNG:.*]] = stablehlo.rng %[[VAL_0]], %[[VAL_1]], %[[SHAPE]], distribution = NORMAL : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
|
||||||
|
// CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[RNG]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
|
||||||
|
// CHECK: return %[[RET]] : !torch.vtensor<[32,64],f64>
|
||||||
|
func.func @torch.aten.randn.generator() -> !torch.vtensor<[32, 64],f64> {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%int32 = torch.constant.int 32
|
||||||
|
%int64 = torch.constant.int 64
|
||||||
|
%size = torch.prim.ListConstruct %int32, %int64 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%0 = torch.aten.randn.generator %size, %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[32, 64], f64>
|
||||||
|
return %0 : !torch.vtensor<[32, 64],f64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.normal_functional(
|
||||||
|
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[32,64],f64>) -> !torch.vtensor<[32,64],f64> {
|
||||||
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[FLOAT_0:.*]] = torch.constant.float 2.000000e+00
|
||||||
|
// CHECK: %[[VAL_0:.*]] = torch_c.to_f64 %[[FLOAT_0]]
|
||||||
|
// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_f64 %[[FLOAT_1]]
|
||||||
|
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
|
||||||
|
// CHECK: %[[ELEM_0:.*]] = tensor.from_elements %[[VAL_0]] : tensor<1xf64>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = stablehlo.convert %[[ELEM_0]] : tensor<1xf64>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = stablehlo.reshape %[[VAL_3]] : (tensor<1xf64>) -> tensor<f64>
|
||||||
|
// CHECK: %[[ELEM_1:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xf64>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[ELEM_1]] : tensor<1xf64>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xf64>) -> tensor<f64>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = stablehlo.rng %[[VAL_4]], %[[VAL_6]], %[[VAL_2]], distribution = NORMAL : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
|
||||||
|
// CHECK: return %[[VAL_8]] : !torch.vtensor<[32,64],f64>
|
||||||
|
func.func @torch.aten.normal_functional(%arg0: !torch.vtensor<[32, 64], f64>) -> !torch.vtensor<[32, 64], f64> {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%mean = torch.constant.float 2.0
|
||||||
|
%std = torch.constant.float 1.0
|
||||||
|
%0 = torch.aten.normal_functional %arg0, %mean, %std, %none : !torch.vtensor<[32, 64], f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64], f64>
|
||||||
|
return %0 : !torch.vtensor<[32, 64],f64>
|
||||||
|
}
|
Loading…
Reference in New Issue