//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
// Lowers `aten.size.int` to a `tensor.dim` op on the converted input tensor,
// normalizing a possibly negative dimension index and asserting that it is in
// bounds.
class ConvertAtenSizeIntOp : public OpConversionPattern<AtenSizeIntOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenSizeIntOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op->getLoc();
    Value self = adaptor.self();
    Value dim = adaptor.dim();
    auto type = self.getType().cast<RankedTensorType>();
    Value inputRank = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI64IntegerAttr(type.getRank()));
    Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank);
    assertIsValidDim(rewriter, loc, dimPositive, inputRank);
    Value size = rewriter.create<tensor::DimOp>(
        loc, adaptor.self(), castIntToIndex(rewriter, loc, dimPositive));
    rewriter.replaceOp(op, castIndexToInt(rewriter, loc, size));
    return success();
  }
};
} // namespace

namespace {
// Lowers `aten.numel` to a computation of the total number of elements of the
// converted input tensor.
class ConvertAtenNumelOp : public OpConversionPattern<AtenNumelOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenNumelOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op.getLoc();
    Value tensorSize = getTensorSize(rewriter, loc, adaptor.self());
    rewriter.replaceOp(op, tensorSize);
    return success();
  }
};
} // namespace

namespace {
// Casts a tensor of exactly one element to an elemental type.
template <typename OpTy>
class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern<OpTy> {
public:
  using OpConversionPattern<OpTy>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpTy op,
                  typename OpConversionPattern<OpTy>::OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op.getLoc();
    Value input = adaptor.a();
    SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
    int64_t inputRank = inputSizes.size();

    // The `input` tensor must contain exactly one element, i.e., either the
    // `input` is a zero rank tensor or all the dimensions of the `input`
    // tensor are unit.
    Value constantOne =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1));
    for (int64_t i = 0; i < inputRank; i++)
      checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne);

    // Extract the only element from the `input` tensor.
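    // For example (a rough sketch; the exact IR depends on the converted
    // element type and rank), for a rank-2 `i64` input this step emits
    // something like:
    //   %c0 = arith.constant 0 : index
    //   %elem = tensor.extract %input[%c0, %c0] : tensor<1x1xi64>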
    Value constantZero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    SmallVector<Value> indices(inputRank, constantZero);
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, input, indices);
    return success();
  }
};
} // namespace

namespace {
// Lowers `aten.tensor.int`/`aten.tensor.float` to a rank-0 tensor holding the
// (dtype-promoted) scalar operand.
class ConvertAtenScalarToTensorLike : public ConversionPattern {
public:
  ConvertAtenScalarToTensorLike(TypeConverter &typeConverter,
                                MLIRContext *context)
      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                          context) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isa<AtenTensorIntOp, AtenTensorFloatOp>(op))
      return rewriter.notifyMatchFailure(
          op, "not a supported Scalar to Tensor like op");
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op->getLoc();
    Value elemVal, dtype, device, requires_grad;
    if (AtenTensorIntOp tensorIntOp = dyn_cast<AtenTensorIntOp>(op)) {
      AtenTensorIntOp::Adaptor adaptor(operands);
      elemVal = adaptor.t();
      dtype = tensorIntOp.dtype();
      device = tensorIntOp.device();
      requires_grad = tensorIntOp.requires_grad();
    }
    if (AtenTensorFloatOp tensorFloatOp = dyn_cast<AtenTensorFloatOp>(op)) {
      AtenTensorFloatOp::Adaptor adaptor(operands);
      elemVal = adaptor.t();
      dtype = tensorFloatOp.dtype();
      device = tensorFloatOp.device();
      requires_grad = tensorFloatOp.requires_grad();
    }
    // TODO: Dtype conversion.
    if (!dtype.getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(op, "Unimplemented non-None dtype");

    // TODO: Device information.
    if (!device.getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "Unimplemented non-None device information");

    RankedTensorType resultType = getTypeConverter()
                                      ->convertType(op->getResult(0).getType())
                                      .cast<RankedTensorType>();
    Type outElementType = resultType.getElementType();
    Value elemValProm =
        convertScalarToDtype(rewriter, loc, elemVal, outElementType);
    Value zeroDTensor =
        createInitTensor(rewriter, loc, {}, outElementType, elemValProm);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, zeroDTensor);
    return success();
  }
};
} // namespace

namespace {
// Lowers `prim.NumToTensor.Scalar` to a `linalg.fill` of a rank-0 init tensor
// with the scalar operand.
class ConvertPrimNumToTensorScalarOp
    : public OpConversionPattern<PrimNumToTensorScalarOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(PrimNumToTensorScalarOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op.getLoc();
    Value a = adaptor.a();
    Value outTensor =
        rewriter.create<linalg::InitTensorOp>(loc, ValueRange{}, a.getType())
            ->getResult(0);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(op, a, outTensor);
    return success();
  }
};
} // namespace

void mlir::torch::torch_to_linalg::
    populateTensorScalarInteropPatternsAndLegality(
        TypeConverter &typeConverter, RewritePatternSet &patterns,
        ConversionTarget &target) {
  MLIRContext *context = patterns.getContext();
  target.addIllegalOp<AtenSizeIntOp>();
  patterns.add<ConvertAtenSizeIntOp>(typeConverter, context);
  target.addIllegalOp<AtenNumelOp>();
  patterns.add<ConvertAtenNumelOp>(typeConverter, context);
  target.addIllegalOp<AtenIntTensorOp, AtenFloatTensorOp, AtenBoolTensorOp>();
  patterns.add<ConvertAtenTensorToScalarLikeOp<AtenIntTensorOp>>(typeConverter,
                                                                 context);
  patterns.add<ConvertAtenTensorToScalarLikeOp<AtenFloatTensorOp>>(
      typeConverter, context);
  patterns.add<ConvertAtenTensorToScalarLikeOp<AtenBoolTensorOp>>(
      typeConverter, context);
  target.addIllegalOp<AtenTensorIntOp, AtenTensorFloatOp>();
  patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
  target.addIllegalOp<PrimNumToTensorScalarOp>();
  patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
}
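
// Note: a minimal sketch of how a caller typically wires these patterns up
// (hypothetical pass body; torch-mlir's actual ConvertTorchToLinalg pass
// registers many more pattern groups):
//   RewritePatternSet patterns(context);
//   ConversionTarget target(*context);
//   torch_to_linalg::populateTensorScalarInteropPatternsAndLegality(
//       typeConverter, patterns, target);
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();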