mirror of https://github.com/llvm/torch-mlir

[Stablehlo] lowering aten.randn & aten.normal_functional to mhlo.rng NORMAL (#3328)

* split lowering of uniform, randn, normal from Basic.cpp into Rng.cpp

parent 00efec0b73
commit 5f7cb9e253
@@ -1819,36 +1819,6 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
   return success();
 }
-
-template <>
-LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
-    AtenUniformOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value self = adaptor.getSelf();
-  Value generator = adaptor.getGenerator();
-  Location loc = op.getLoc();
-
-  if (!isa<Torch::NoneType>(generator.getType()))
-    return rewriter.notifyMatchFailure(
-        op, "The generator has to be None because only global default "
-            "generator is supported");
-
-  auto elements = cast<RankedTensorType>(self.getType()).getShape();
-  if (llvm::any_of(elements,
-                   [](int64_t dim) { return dim == ShapedType::kDynamic; }))
-    return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
-  auto shape_tensor = rewriter.create<stablehlo::ConstantOp>(
-      loc, rewriter.getI64TensorAttr(elements));
-  auto outTy = getTypeConverter()->convertType(op.getType());
-  auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
-  Value from =
-      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy);
-  Value to =
-      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getTo(), outElemTy);
-  rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
-      op, outTy, from, to, shape_tensor, stablehlo::RngDistribution::UNIFORM);
-  return success();
-}
 
 // Converts `aten.empty.memory_format` to `tensor.empty` op.
 template <>
 LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
@@ -2240,7 +2210,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenToDtypeOp);
   INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
   INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp);
-  INSERT_ATENOP_PATTERN(AtenUniformOp);
   INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
   INSERT_ATENOP_PATTERN(AtenFillScalarOp);
   INSERT_ATENOP_PATTERN(AtenFlipOp);
@@ -6,6 +6,7 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo
   Linear.cpp
   ViewLike.cpp
   Reduction.cpp
+  Rng.cpp
   Pooling.cpp
   Utils.cpp
@@ -62,6 +62,11 @@ void populatePoolingOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, const TorchToStablehloOptions &options);
 
+void populateRngOpPatternsAndLegality(TypeConverter &typeConverter,
+                                      RewritePatternSet &patterns,
+                                      ConversionTarget &target,
+                                      const TorchToStablehloOptions &options);
+
 } // namespace torch_to_stablehlo
 } // namespace torch
 } // namespace mlir
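Aside (not part of this diff): the ConvertAtenOp<AtenOp> specializations added in Rng.cpp plug into a shared conversion-pattern class template that already lives in PopulatePatterns.h. The sketch below reconstructs its rough shape from how INSERT_ATENOP_PATTERN instantiates it (patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)); everything except the matchAndRewrite signature is an approximation, not the verbatim upstream declaration, and the usual MLIR includes are assumed.

// Rough sketch only (assumed): shared pattern template in PopulatePatterns.h.
// Each supported op provides an explicit specialization of matchAndRewrite,
// like the three RNG ones added in Rng.cpp below.
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
  using OpAdaptor = typename AtenOpT::Adaptor;
  ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
                const TorchToStablehloOptions &options)
      : OpConversionPattern<AtenOpT>(typeConverter, context),
        options(options) {}
  // Specialized per op; getTypeConverter() used in the lowerings comes from
  // the OpConversionPattern base class.
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

private:
  TorchToStablehloOptions options; // carried through to the lowerings (assumed detail)
};

With that template in place, the populateRngOpPatternsAndLegality declaration added above mirrors the existing populate*OpPatternsAndLegality entry points, so the pass driver can register the RNG patterns the same way it already does for pooling, reduction, etc.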
@@ -0,0 +1,137 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"

#include "../PassDetail.h"
#include "./PopulatePatterns.h"

#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo;

template <>
LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
    AtenUniformOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value self = adaptor.getSelf();
  Value generator = adaptor.getGenerator();
  Location loc = op.getLoc();

  if (!isa<Torch::NoneType>(generator.getType()))
    return rewriter.notifyMatchFailure(
        op, "The generator has to be None because only global default "
            "generator is supported");

  auto elements = cast<RankedTensorType>(self.getType()).getShape();
  if (llvm::any_of(elements,
                   [](int64_t dim) { return dim == ShapedType::kDynamic; }))
    return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
  auto shape_tensor = rewriter.create<stablehlo::ConstantOp>(
      loc, rewriter.getI64TensorAttr(elements));
  auto outTy = getTypeConverter()->convertType(op.getType());
  auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
  Value from =
      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy);
  Value to =
      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getTo(), outElemTy);
  rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
      op, outTy, from, to, shape_tensor, stablehlo::RngDistribution::UNIFORM);
  return success();
}

template <>
LogicalResult ConvertAtenOp<AtenRandnGeneratorOp>::matchAndRewrite(
    AtenRandnGeneratorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value generator = adaptor.getGenerator();
  Location loc = op.getLoc();

  if (!isa<Torch::NoneType>(generator.getType())) {
    return rewriter.notifyMatchFailure(
        op, "The generator has to be None because only global default "
            "generator is supported");
  }
  llvm::SmallVector<int64_t> shape;
  if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(shape))) {
    return rewriter.notifyMatchFailure(op, "size must be constant");
  }

  auto outTy = getTypeConverter()->convertType(op.getType());
  auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
  auto scalarTy = RankedTensorType::get({}, outElemTy);
  if (!isa<mlir::FloatType>(outElemTy)) {
    return rewriter.notifyMatchFailure(op,
                                       "only support output with float type");
  }

  Value shapeTensor = rewriter.create<stablehlo::ConstantOp>(
      loc, rewriter.getI64TensorAttr(shape));
  Value mean = rewriter.create<stablehlo::ConstantOp>(
      loc, DenseFPElementsAttr::get(scalarTy, 0.0));
  Value var = rewriter.create<stablehlo::ConstantOp>(
      loc, DenseFPElementsAttr::get(scalarTy, 1.0));

  rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
      op, outTy, mean, var, shapeTensor, stablehlo::RngDistribution::NORMAL);
  return success();
}

template <>
LogicalResult ConvertAtenOp<AtenNormalFunctionalOp>::matchAndRewrite(
    AtenNormalFunctionalOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value self = adaptor.getSelf();
  Value generator = adaptor.getGenerator();
  Location loc = op.getLoc();

  if (!isa<Torch::NoneType>(generator.getType()))
    return rewriter.notifyMatchFailure(
        op, "The generator has to be None because only global default "
            "generator is supported");

  auto elements = cast<RankedTensorType>(self.getType()).getShape();
  if (llvm::any_of(elements,
                   [](int64_t dim) { return dim == ShapedType::kDynamic; }))
    return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
  auto shapeTensor = rewriter.create<stablehlo::ConstantOp>(
      loc, rewriter.getI64TensorAttr(elements));
  auto outTy = getTypeConverter()->convertType(op.getType());
  auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
  Value mean =
      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getMean(), outElemTy);
  Value std =
      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStd(), outElemTy);
  rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
      op, outTy, mean, std, shapeTensor, stablehlo::RngDistribution::NORMAL);
  return success();
}

void mlir::torch::torch_to_stablehlo::populateRngOpPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, const TorchToStablehloOptions &options) {
  MLIRContext *context = patterns.getContext();

#define INSERT_ATENOP_PATTERN(AtenOp)                                          \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)

  INSERT_ATENOP_PATTERN(AtenUniformOp);
  INSERT_ATENOP_PATTERN(AtenRandnGeneratorOp);
  INSERT_ATENOP_PATTERN(AtenNormalFunctionalOp);
#undef INSERT_ATENOP_PATTERN
}
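One helper worth calling out: all three lowerings use hlo::scalarToStablehloTensor(rewriter, op, scalarValue, elemTy) to turn a Torch scalar operand (from/to, mean/std) into the 0-d tensor operands that stablehlo.rng expects. Its real implementation lives in StablehloLegalizeUtils and is not part of this diff; the sketch below is a hypothetical stand-in (name and signature are mine, standard MLIR/Tensor/StableHLO includes assumed) that produces the same op sequence the FileCheck tests below expect: tensor.from_elements, then stablehlo.convert, then stablehlo.reshape down to tensor<f64>.

// Hypothetical stand-in for hlo::scalarToStablehloTensor (assumption, not the
// upstream helper): materialize a scalar as a 0-d tensor of `elemTy`.
static Value scalarToZeroRankTensor(OpBuilder &b, Location loc, Value scalar,
                                    Type elemTy) {
  // Wrap the scalar in a rank-1, size-1 tensor.
  Value elems = b.create<tensor::FromElementsOp>(loc, ValueRange{scalar});
  // Cast to the requested element type (e.g. the output tensor's f64).
  Value cvt = b.create<stablehlo::ConvertOp>(
      loc, RankedTensorType::get({1}, elemTy), elems);
  // Collapse tensor<1xT> to tensor<T>, the shape stablehlo.rng wants for a/b.
  return b.create<stablehlo::ReshapeOp>(
      loc, RankedTensorType::get({}, elemTy), cvt);
}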
@@ -75,6 +75,8 @@ public:
         typeConverter, patterns, target, options);
     torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
         typeConverter, patterns, target, options);
+    torch_to_stablehlo::populateRngOpPatternsAndLegality(
+        typeConverter, patterns, target, options);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
@@ -291,33 +291,6 @@ func.func @torch.runtime.assert(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.uniform(
-// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[32,64],f64>) -> !torch.vtensor<[32,64],f64> {
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[FLOAT_0:.*]] = torch.constant.float 0.000000e+00
-// CHECK: %[[VAL_0:.*]] = torch_c.to_f64 %[[FLOAT_0]]
-// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
-// CHECK: %[[VAL_1:.*]] = torch_c.to_f64 %[[FLOAT_1]]
-// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
-// CHECK: %[[ELEM_0:.*]] = tensor.from_elements %[[VAL_0]] : tensor<1xf64>
-// CHECK: %[[VAL_3:.*]] = stablehlo.convert %[[ELEM_0]] : tensor<1xf64>
-// CHECK: %[[VAL_4:.*]] = stablehlo.reshape %[[VAL_3]] : (tensor<1xf64>) -> tensor<f64>
-// CHECK: %[[ELEM_1:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xf64>
-// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[ELEM_1]] : tensor<1xf64>
-// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xf64>) -> tensor<f64>
-// CHECK: %[[VAL_7:.*]] = stablehlo.rng %[[VAL_4]], %[[VAL_6]], %[[VAL_2]], distribution = UNIFORM : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
-// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
-// CHECK: return %[[VAL_8]] : !torch.vtensor<[32,64],f64>
-func.func @torch.aten.uniform(%arg0: !torch.vtensor<[32, 64],f64>) -> !torch.vtensor<[32, 64],f64> {
-  %none = torch.constant.none
-  %float0 = torch.constant.float 0.0
-  %float1 = torch.constant.float 1.0
-  %0 = torch.aten.uniform %arg0, %float0, %float1, %none : !torch.vtensor<[32, 64],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64],f64>
-  return %0 : !torch.vtensor<[32, 64],f64>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @torch.aten.bitwise_left_shift.Tensor(
 // CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,4],si32>,
 // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[3,1],si32>) -> !torch.vtensor<[3,4],si32> {
@@ -0,0 +1,78 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s


// -----

// CHECK-LABEL: func.func @torch.aten.uniform(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[32,64],f64>) -> !torch.vtensor<[32,64],f64> {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FLOAT_0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[VAL_0:.*]] = torch_c.to_f64 %[[FLOAT_0]]
// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[VAL_1:.*]] = torch_c.to_f64 %[[FLOAT_1]]
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
// CHECK: %[[ELEM_0:.*]] = tensor.from_elements %[[VAL_0]] : tensor<1xf64>
// CHECK: %[[VAL_3:.*]] = stablehlo.convert %[[ELEM_0]] : tensor<1xf64>
// CHECK: %[[VAL_4:.*]] = stablehlo.reshape %[[VAL_3]] : (tensor<1xf64>) -> tensor<f64>
// CHECK: %[[ELEM_1:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xf64>
// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[ELEM_1]] : tensor<1xf64>
// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xf64>) -> tensor<f64>
// CHECK: %[[VAL_7:.*]] = stablehlo.rng %[[VAL_4]], %[[VAL_6]], %[[VAL_2]], distribution = UNIFORM : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[32,64],f64>
func.func @torch.aten.uniform(%arg0: !torch.vtensor<[32, 64],f64>) -> !torch.vtensor<[32, 64],f64> {
  %none = torch.constant.none
  %float0 = torch.constant.float 0.0
  %float1 = torch.constant.float 1.0
  %0 = torch.aten.uniform %arg0, %float0, %float1, %none : !torch.vtensor<[32, 64],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64],f64>
  return %0 : !torch.vtensor<[32, 64],f64>
}

// -----

// CHECK-LABEL: func.func @torch.aten.randn.generator
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT32:.*]] = torch.constant.int 32
// CHECK: %[[INT64:.*]] = torch.constant.int 64
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
// CHECK: %[[SHAPE:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f64>
// CHECK: %[[VAL_1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f64>
// CHECK: %[[RNG:.*]] = stablehlo.rng %[[VAL_0]], %[[VAL_1]], %[[SHAPE]], distribution = NORMAL : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
// CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[RNG]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
// CHECK: return %[[RET]] : !torch.vtensor<[32,64],f64>
func.func @torch.aten.randn.generator() -> !torch.vtensor<[32, 64],f64> {
  %none = torch.constant.none
  %int32 = torch.constant.int 32
  %int64 = torch.constant.int 64
  %size = torch.prim.ListConstruct %int32, %int64 : (!torch.int, !torch.int) -> !torch.list<int>
  %0 = torch.aten.randn.generator %size, %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[32, 64], f64>
  return %0 : !torch.vtensor<[32, 64],f64>
}

// -----

// CHECK-LABEL: func.func @torch.aten.normal_functional(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[32,64],f64>) -> !torch.vtensor<[32,64],f64> {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FLOAT_0:.*]] = torch.constant.float 2.000000e+00
// CHECK: %[[VAL_0:.*]] = torch_c.to_f64 %[[FLOAT_0]]
// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[VAL_1:.*]] = torch_c.to_f64 %[[FLOAT_1]]
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
// CHECK: %[[ELEM_0:.*]] = tensor.from_elements %[[VAL_0]] : tensor<1xf64>
// CHECK: %[[VAL_3:.*]] = stablehlo.convert %[[ELEM_0]] : tensor<1xf64>
// CHECK: %[[VAL_4:.*]] = stablehlo.reshape %[[VAL_3]] : (tensor<1xf64>) -> tensor<f64>
// CHECK: %[[ELEM_1:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xf64>
// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[ELEM_1]] : tensor<1xf64>
// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xf64>) -> tensor<f64>
// CHECK: %[[VAL_7:.*]] = stablehlo.rng %[[VAL_4]], %[[VAL_6]], %[[VAL_2]], distribution = NORMAL : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[32,64],f64>
func.func @torch.aten.normal_functional(%arg0: !torch.vtensor<[32, 64], f64>) -> !torch.vtensor<[32, 64], f64> {
  %none = torch.constant.none
  %mean = torch.constant.float 2.0
  %std = torch.constant.float 1.0
  %0 = torch.aten.normal_functional %arg0, %mean, %std, %none : !torch.vtensor<[32, 64], f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64], f64>
  return %0 : !torch.vtensor<[32, 64],f64>
}