//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "../PassDetail.h" #include "./PopulatePatterns.h" #include "mlir/IR/BuiltinTypes.h" #include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::torch_to_stablehlo; template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenUniformOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); Value generator = adaptor.getGenerator(); Location loc = op.getLoc(); if (!isa(generator.getType())) return rewriter.notifyMatchFailure( op, "The generator has to be None because only global default " "generator is supported"); auto elements = cast(self.getType()).getShape(); if (llvm::any_of(elements, [](int64_t dim) { return dim == ShapedType::kDynamic; })) return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD"); auto shape_tensor = rewriter.create( loc, rewriter.getI64TensorAttr(elements)); auto outTy = getTypeConverter()->convertType(op.getType()); auto outElemTy = cast(outTy).getElementType(); Value from = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy); Value to = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getTo(), outElemTy); rewriter.replaceOpWithNewOp( op, outTy, from, to, shape_tensor, stablehlo::RngDistribution::UNIFORM); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenRandnGeneratorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value generator = adaptor.getGenerator(); Location loc = op.getLoc(); if (!isa(generator.getType())) { return rewriter.notifyMatchFailure( op, "The generator has to be None because only global default " "generator is supported"); } llvm::SmallVector shape; if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(shape))) { return rewriter.notifyMatchFailure(op, "size must be constant"); } auto outTy = getTypeConverter()->convertType(op.getType()); auto outElemTy = cast(outTy).getElementType(); if (!isa(outElemTy)) { return rewriter.notifyMatchFailure(op, "only support output with float type"); } auto scalarTy = RankedTensorType::get({}, outElemTy); Value shapeTensor = rewriter.create( loc, rewriter.getI64TensorAttr(shape)); Value mean = rewriter.create( loc, DenseElementsAttr::get(scalarTy, rewriter.getFloatAttr(outElemTy, 0.0))); Value var = rewriter.create( loc, DenseElementsAttr::get(scalarTy, rewriter.getFloatAttr(outElemTy, 1.0))); rewriter.replaceOpWithNewOp( op, outTy, mean, var, shapeTensor, stablehlo::RngDistribution::NORMAL); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenNormalFunctionalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); Value generator = adaptor.getGenerator(); Location loc = op.getLoc(); if (!isa(generator.getType())) return rewriter.notifyMatchFailure( op, "The generator has to be None because only global default " "generator is supported"); auto elements = cast(self.getType()).getShape(); if (llvm::any_of(elements, [](int64_t dim) { return dim == ShapedType::kDynamic; })) return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD"); auto shapeTensor = rewriter.create( loc, rewriter.getI64TensorAttr(elements)); auto outTy = getTypeConverter()->convertType(op.getType()); auto outElemTy = cast(outTy).getElementType(); Value mean = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getMean(), outElemTy); Value std = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStd(), outElemTy); rewriter.replaceOpWithNewOp( op, outTy, mean, std, shapeTensor, stablehlo::RngDistribution::NORMAL); return success(); } void mlir::torch::torch_to_stablehlo::populateRngOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, options) INSERT_ATENOP_PATTERN(AtenUniformOp); INSERT_ATENOP_PATTERN(AtenRandnGeneratorOp); INSERT_ATENOP_PATTERN(AtenNormalFunctionalOp); #undef INSERT_ATENOP_PATTERN }