[Stablehlo] lowering aten.randn & aten.normal_functional to mhlo.rng … (#3328)

…NORMAL

* split lowering of uniform, randn, normal from Basic.cpp into Rng.cpp
pull/3329/head
Yuanqiang Liu 2024-05-11 15:33:37 +08:00 committed by GitHub
parent 00efec0b73
commit 5f7cb9e253
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 224 additions and 58 deletions

View File

@ -1819,36 +1819,6 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
return success(); return success();
} }
template <>
LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
AtenUniformOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
Value generator = adaptor.getGenerator();
Location loc = op.getLoc();
if (!isa<Torch::NoneType>(generator.getType()))
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");
auto elements = cast<RankedTensorType>(self.getType()).getShape();
if (llvm::any_of(elements,
[](int64_t dim) { return dim == ShapedType::kDynamic; }))
return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
auto shape_tensor = rewriter.create<stablehlo::ConstantOp>(
loc, rewriter.getI64TensorAttr(elements));
auto outTy = getTypeConverter()->convertType(op.getType());
auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
Value from =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy);
Value to =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getTo(), outElemTy);
rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
op, outTy, from, to, shape_tensor, stablehlo::RngDistribution::UNIFORM);
return success();
}
// Converts `aten.empty.memory_format` to `tensor.empty` op. // Converts `aten.empty.memory_format` to `tensor.empty` op.
template <> template <>
LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
@ -2240,7 +2210,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenToDtypeOp); INSERT_ATENOP_PATTERN(AtenToDtypeOp);
INSERT_ATENOP_PATTERN(AtenWhereSelfOp); INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp); INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp);
INSERT_ATENOP_PATTERN(AtenUniformOp);
INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp); INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
INSERT_ATENOP_PATTERN(AtenFillScalarOp); INSERT_ATENOP_PATTERN(AtenFillScalarOp);
INSERT_ATENOP_PATTERN(AtenFlipOp); INSERT_ATENOP_PATTERN(AtenFlipOp);

View File

@ -6,6 +6,7 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo
Linear.cpp Linear.cpp
ViewLike.cpp ViewLike.cpp
Reduction.cpp Reduction.cpp
Rng.cpp
Pooling.cpp Pooling.cpp
Utils.cpp Utils.cpp

View File

@ -62,6 +62,11 @@ void populatePoolingOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options); ConversionTarget &target, const TorchToStablehloOptions &options);
void populateRngOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToStablehloOptions &options);
} // namespace torch_to_stablehlo } // namespace torch_to_stablehlo
} // namespace torch } // namespace torch
} // namespace mlir } // namespace mlir

View File

@ -0,0 +1,137 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h"
#include "./PopulatePatterns.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo;
template <>
LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
AtenUniformOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
Value generator = adaptor.getGenerator();
Location loc = op.getLoc();
if (!isa<Torch::NoneType>(generator.getType()))
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");
auto elements = cast<RankedTensorType>(self.getType()).getShape();
if (llvm::any_of(elements,
[](int64_t dim) { return dim == ShapedType::kDynamic; }))
return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
auto shape_tensor = rewriter.create<stablehlo::ConstantOp>(
loc, rewriter.getI64TensorAttr(elements));
auto outTy = getTypeConverter()->convertType(op.getType());
auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
Value from =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy);
Value to =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getTo(), outElemTy);
rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
op, outTy, from, to, shape_tensor, stablehlo::RngDistribution::UNIFORM);
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenRandnGeneratorOp>::matchAndRewrite(
AtenRandnGeneratorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value generator = adaptor.getGenerator();
Location loc = op.getLoc();
if (!isa<Torch::NoneType>(generator.getType())) {
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");
}
llvm::SmallVector<int64_t> shape;
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(shape))) {
return rewriter.notifyMatchFailure(op, "size must be constant");
}
auto outTy = getTypeConverter()->convertType(op.getType());
auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
auto scalarTy = RankedTensorType::get({}, outElemTy);
if (!isa<mlir::FloatType>(outElemTy)) {
return rewriter.notifyMatchFailure(op,
"only support output with float type");
}
Value shapeTensor = rewriter.create<stablehlo::ConstantOp>(
loc, rewriter.getI64TensorAttr(shape));
Value mean = rewriter.create<stablehlo::ConstantOp>(
loc, DenseFPElementsAttr::get(scalarTy, 0.0));
Value var = rewriter.create<stablehlo::ConstantOp>(
loc, DenseFPElementsAttr::get(scalarTy, 1.0));
rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
op, outTy, mean, var, shapeTensor, stablehlo::RngDistribution::NORMAL);
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenNormalFunctionalOp>::matchAndRewrite(
AtenNormalFunctionalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
Value generator = adaptor.getGenerator();
Location loc = op.getLoc();
if (!isa<Torch::NoneType>(generator.getType()))
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");
auto elements = cast<RankedTensorType>(self.getType()).getShape();
if (llvm::any_of(elements,
[](int64_t dim) { return dim == ShapedType::kDynamic; }))
return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
auto shapeTensor = rewriter.create<stablehlo::ConstantOp>(
loc, rewriter.getI64TensorAttr(elements));
auto outTy = getTypeConverter()->convertType(op.getType());
auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
Value mean =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getMean(), outElemTy);
Value std =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStd(), outElemTy);
rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
op, outTy, mean, std, shapeTensor, stablehlo::RngDistribution::NORMAL);
return success();
}
void mlir::torch::torch_to_stablehlo::populateRngOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext();
#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
INSERT_ATENOP_PATTERN(AtenUniformOp);
INSERT_ATENOP_PATTERN(AtenRandnGeneratorOp);
INSERT_ATENOP_PATTERN(AtenNormalFunctionalOp);
#undef INSERT_ATENOP_PATTERN
}

View File

@ -75,6 +75,8 @@ public:
typeConverter, patterns, target, options); typeConverter, patterns, target, options);
torch_to_stablehlo::populatePoolingOpPatternsAndLegality( torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
typeConverter, patterns, target, options); typeConverter, patterns, target, options);
torch_to_stablehlo::populateRngOpPatternsAndLegality(
typeConverter, patterns, target, options);
if (failed(applyPartialConversion(getOperation(), target, if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) { std::move(patterns)))) {

View File

@ -291,33 +291,6 @@ func.func @torch.runtime.assert(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
// ----- // -----
// CHECK-LABEL: func.func @torch.aten.uniform(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[32,64],f64>) -> !torch.vtensor<[32,64],f64> {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FLOAT_0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[VAL_0:.*]] = torch_c.to_f64 %[[FLOAT_0]]
// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[VAL_1:.*]] = torch_c.to_f64 %[[FLOAT_1]]
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
// CHECK: %[[ELEM_0:.*]] = tensor.from_elements %[[VAL_0]] : tensor<1xf64>
// CHECK: %[[VAL_3:.*]] = stablehlo.convert %[[ELEM_0]] : tensor<1xf64>
// CHECK: %[[VAL_4:.*]] = stablehlo.reshape %[[VAL_3]] : (tensor<1xf64>) -> tensor<f64>
// CHECK: %[[ELEM_1:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xf64>
// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[ELEM_1]] : tensor<1xf64>
// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xf64>) -> tensor<f64>
// CHECK: %[[VAL_7:.*]] = stablehlo.rng %[[VAL_4]], %[[VAL_6]], %[[VAL_2]], distribution = UNIFORM : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[32,64],f64>
func.func @torch.aten.uniform(%arg0: !torch.vtensor<[32, 64],f64>) -> !torch.vtensor<[32, 64],f64> {
%none = torch.constant.none
%float0 = torch.constant.float 0.0
%float1 = torch.constant.float 1.0
%0 = torch.aten.uniform %arg0, %float0, %float1, %none : !torch.vtensor<[32, 64],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64],f64>
return %0 : !torch.vtensor<[32, 64],f64>
}
// -----
// CHECK-LABEL: func.func @torch.aten.bitwise_left_shift.Tensor( // CHECK-LABEL: func.func @torch.aten.bitwise_left_shift.Tensor(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,4],si32>, // CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,4],si32>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[3,1],si32>) -> !torch.vtensor<[3,4],si32> { // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[3,1],si32>) -> !torch.vtensor<[3,4],si32> {

View File

@ -0,0 +1,78 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
// -----
// CHECK-LABEL: func.func @torch.aten.uniform(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[32,64],f64>) -> !torch.vtensor<[32,64],f64> {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FLOAT_0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[VAL_0:.*]] = torch_c.to_f64 %[[FLOAT_0]]
// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[VAL_1:.*]] = torch_c.to_f64 %[[FLOAT_1]]
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
// CHECK: %[[ELEM_0:.*]] = tensor.from_elements %[[VAL_0]] : tensor<1xf64>
// CHECK: %[[VAL_3:.*]] = stablehlo.convert %[[ELEM_0]] : tensor<1xf64>
// CHECK: %[[VAL_4:.*]] = stablehlo.reshape %[[VAL_3]] : (tensor<1xf64>) -> tensor<f64>
// CHECK: %[[ELEM_1:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xf64>
// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[ELEM_1]] : tensor<1xf64>
// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xf64>) -> tensor<f64>
// CHECK: %[[VAL_7:.*]] = stablehlo.rng %[[VAL_4]], %[[VAL_6]], %[[VAL_2]], distribution = UNIFORM : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[32,64],f64>
func.func @torch.aten.uniform(%arg0: !torch.vtensor<[32, 64],f64>) -> !torch.vtensor<[32, 64],f64> {
%none = torch.constant.none
%float0 = torch.constant.float 0.0
%float1 = torch.constant.float 1.0
%0 = torch.aten.uniform %arg0, %float0, %float1, %none : !torch.vtensor<[32, 64],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64],f64>
return %0 : !torch.vtensor<[32, 64],f64>
}
// -----
// CHECK-LABEL: func.func @torch.aten.randn.generator
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT32:.*]] = torch.constant.int 32
// CHECK: %[[INT64:.*]] = torch.constant.int 64
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
// CHECK: %[[SHAPE:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f64>
// CHECK: %[[VAL_1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f64>
// CHECK: %[[RNG:.*]] = stablehlo.rng %[[VAL_0]], %[[VAL_1]], %[[SHAPE]], distribution = NORMAL : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
// CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[RNG]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
// CHECK: return %[[RET]] : !torch.vtensor<[32,64],f64>
func.func @torch.aten.randn.generator() -> !torch.vtensor<[32, 64],f64> {
%none = torch.constant.none
%int32 = torch.constant.int 32
%int64 = torch.constant.int 64
%size = torch.prim.ListConstruct %int32, %int64 : (!torch.int, !torch.int) -> !torch.list<int>
%0 = torch.aten.randn.generator %size, %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[32, 64], f64>
return %0 : !torch.vtensor<[32, 64],f64>
}
// -----
// CHECK-LABEL: func.func @torch.aten.normal_functional(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[32,64],f64>) -> !torch.vtensor<[32,64],f64> {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FLOAT_0:.*]] = torch.constant.float 2.000000e+00
// CHECK: %[[VAL_0:.*]] = torch_c.to_f64 %[[FLOAT_0]]
// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[VAL_1:.*]] = torch_c.to_f64 %[[FLOAT_1]]
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
// CHECK: %[[ELEM_0:.*]] = tensor.from_elements %[[VAL_0]] : tensor<1xf64>
// CHECK: %[[VAL_3:.*]] = stablehlo.convert %[[ELEM_0]] : tensor<1xf64>
// CHECK: %[[VAL_4:.*]] = stablehlo.reshape %[[VAL_3]] : (tensor<1xf64>) -> tensor<f64>
// CHECK: %[[ELEM_1:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xf64>
// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[ELEM_1]] : tensor<1xf64>
// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xf64>) -> tensor<f64>
// CHECK: %[[VAL_7:.*]] = stablehlo.rng %[[VAL_4]], %[[VAL_6]], %[[VAL_2]], distribution = NORMAL : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[32,64],f64>
func.func @torch.aten.normal_functional(%arg0: !torch.vtensor<[32, 64], f64>) -> !torch.vtensor<[32, 64], f64> {
%none = torch.constant.none
%mean = torch.constant.float 2.0
%std = torch.constant.float 1.0
%0 = torch.aten.normal_functional %arg0, %mean, %std, %none : !torch.vtensor<[32, 64], f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64], f64>
return %0 : !torch.vtensor<[32, 64],f64>
}