//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"

#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
// Lowers `aten.size.int` to a `tensor.dim` op on the converted input tensor,
// normalizing a possibly-negative dimension index first.
class ConvertAtenSizeIntOp : public OpConversionPattern<AtenSizeIntOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenSizeIntOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op->getLoc();
    Value self = adaptor.getSelf();
    Value dim = adaptor.getDim();
    auto type = cast<RankedTensorType>(self.getType());
    Value inputRank = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI64IntegerAttr(type.getRank()));
    Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank);
    if (!isAssumingStrictSymbolicShapes(rewriter)) {
      assertIsValidDim(rewriter, loc, dimPositive, inputRank);
    }
    Value size = rewriter.create<tensor::DimOp>(
        loc, adaptor.getSelf(), castIntToIndex(rewriter, loc, dimPositive));
    rewriter.replaceOp(op, castIndexToInt64(rewriter, loc, size));
    return success();
  }
};
} // namespace

namespace {
// Lowers `aten.numel` to the runtime element count of the converted tensor.
class ConvertAtenNumelOp : public OpConversionPattern<AtenNumelOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenNumelOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op.getLoc();
    Value tensorSize = getTensorSize(rewriter, loc, adaptor.getSelf());
    rewriter.replaceOp(op, tensorSize);
    return success();
  }
};
} // namespace
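// For reference, an illustrative sketch of the rewrite that
// ConvertAtenSizeIntOp performs (approximate IR, not verbatim output; the
// dim-validity assert is elided and types depend on the input):
//
//   %size = torch.aten.size.int %t, %dim
//
// becomes, roughly:
//
//   %rank = arith.constant 2 : i64        // static rank of %t
//   // ... %dim normalized to a non-negative %dimPos ...
//   %idx  = arith.index_cast %dimPos : i64 to index
//   %d    = tensor.dim %t_builtin, %idx : tensor<?x4xf32>
//   %size = arith.index_cast %d : index to i64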
namespace {
// Casts a tensor of exactly one element to an elemental type.
template <typename OpTy>
class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern<OpTy> {
public:
  using OpConversionPattern<OpTy>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpTy op,
                  typename OpConversionPattern<OpTy>::OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op.getLoc();
    Value input = adaptor.getA();
    SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
    int64_t inputRank = inputSizes.size();
    Type inputDtype = cast<BaseTensorType>(op.getA().getType()).getDtype();

    // The `input` tensor must contain exactly one element, i.e., either the
    // `input` is a zero rank tensor or all the dimensions of the `input`
    // tensor are unit.
    Value constantOne =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1));
    for (int64_t i = 0; i < inputRank; i++)
      checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne);

    // Extract the only element from the `input` tensor.
    Value constantZero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    SmallVector<Value> indices(inputRank, constantZero);
    Value result = rewriter.create<tensor::ExtractOp>(loc, input, indices);
    Type resultType =
        this->getTypeConverter()->convertType(op->getResult(0).getType());
    rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result,
                                                resultType, inputDtype));
    return success();
  }
};
} // namespace

namespace {
// Converts `aten.tensor.int` and `aten.tensor.float` into a 0-d tensor
// holding the (promoted) scalar operand. Non-None dtype and device arguments
// are not yet supported.
class ConvertAtenScalarToTensorLike : public ConversionPattern {
public:
  ConvertAtenScalarToTensorLike(TypeConverter &typeConverter,
                                MLIRContext *context)
      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                          context) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isa<AtenTensorIntOp, AtenTensorFloatOp>(op))
      return rewriter.notifyMatchFailure(
          op, "not a supported Scalar to Tensor like op");
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op->getLoc();
    Value elemVal, dtype, device, requires_grad;
    if (AtenTensorIntOp tensorIntOp = dyn_cast<AtenTensorIntOp>(op)) {
      AtenTensorIntOp::Adaptor adaptor(operands);
      elemVal = adaptor.getT();
      dtype = tensorIntOp.getDtype();
      device = tensorIntOp.getDevice();
      requires_grad = tensorIntOp.getRequiresGrad();
    }
    if (AtenTensorFloatOp tensorFloatOp = dyn_cast<AtenTensorFloatOp>(op)) {
      AtenTensorFloatOp::Adaptor adaptor(operands);
      elemVal = adaptor.getT();
      dtype = tensorFloatOp.getDtype();
      device = tensorFloatOp.getDevice();
      requires_grad = tensorFloatOp.getRequiresGrad();
    }
    // TODO: Dtype conversion.
    if (!isa<Torch::NoneType>(dtype.getType()))
      return rewriter.notifyMatchFailure(op, "Unimplemented non-None dtype");

    // TODO: Device information.
    if (!isa<Torch::NoneType>(device.getType()))
      return rewriter.notifyMatchFailure(
          op, "Unimplemented non-None device information");

    RankedTensorType resultType = cast<RankedTensorType>(
        getTypeConverter()->convertType(op->getResult(0).getType()));
    Type outElementType = resultType.getElementType();
    Value elemValProm =
        convertScalarToDtype(rewriter, loc, elemVal, outElementType);
    Value zeroDTensor =
        createInitTensor(rewriter, loc, {}, outElementType, elemValProm);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, zeroDTensor);
    return success();
  }
};
} // namespace

namespace {
// Materializes `prim.NumToTensor.Scalar` as a 0-d tensor initialized with the
// promoted scalar operand.
class ConvertPrimNumToTensorScalarOp
    : public OpConversionPattern<PrimNumToTensorScalarOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(PrimNumToTensorScalarOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op.getLoc();
    RankedTensorType resultType = cast<RankedTensorType>(
        getTypeConverter()->convertType(op->getResult(0).getType()));
    Type outElementType = resultType.getElementType();
    Value elemVal = adaptor.getA();
    Value elemValProm =
        convertScalarToDtype(rewriter, loc, elemVal, outElementType);
    Value zeroDTensor =
        createInitTensor(rewriter, loc, {}, outElementType, elemValProm);
    rewriter.replaceOp(op, zeroDTensor);
    return success();
  }
};
} // namespace
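// For reference, an illustrative sketch of the scalar-to-tensor rewrites
// above. Assuming the createInitTensor utility materializes an empty tensor
// filled with the given element (as its use here suggests), converting an f64
// scalar %v via prim.NumToTensor.Scalar lowers roughly to:
//
//   %e = tensor.empty() : tensor<f64>
//   %t = linalg.fill ins(%v : f64) outs(%e : tensor<f64>) -> tensor<f64>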
namespace {
// Lowers `aten.full` to a `linalg.fill` of the fill value into an empty
// tensor of the requested shape.
class ConvertAtenFullOp : public OpConversionPattern<AtenFullOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenFullOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op.getLoc();
    SmallVector<Value> inShape;
    if (!getListConstructElements(adaptor.getSize(), inShape)) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: the size list is not from list construct");
    }
    auto resultTy = cast<RankedTensorType>(
        this->getTypeConverter()->convertType(op.getResult().getType()));
    if (resultTy.getRank() != static_cast<int64_t>(inShape.size()))
      return rewriter.notifyMatchFailure(
          op, "rank of shape and result shape do not match");

    // Prefer the static sizes from the result type; fall back to the dynamic
    // size values from the size list.
    SmallVector<OpFoldResult> filteredShape;
    for (int i = 0, s = resultTy.getRank(); i < s; ++i) {
      if (resultTy.isDynamicDim(i)) {
        filteredShape.push_back(inShape[i]);
        continue;
      }
      filteredShape.push_back(rewriter.getIndexAttr(resultTy.getDimSize(i)));
    }

    Value full = adaptor.getFillValue();
    if (full.getType() != resultTy.getElementType()) {
      if (isa<mlir::FloatType>(full.getType())) {
        full = rewriter.create<arith::TruncFOp>(loc, resultTy.getElementType(),
                                                full);
      } else if (isa<mlir::IntegerType>(full.getType())) {
        full = rewriter.create<arith::TruncIOp>(loc, resultTy.getElementType(),
                                                full);
      }
    }

    Value outTensor = rewriter.create<tensor::EmptyOp>(
        loc, filteredShape, resultTy.getElementType());

    rewriter.replaceOpWithNewOp<linalg::FillOp>(op, full, outTensor);
    return success();
  }
};
} // namespace

namespace {
// Converts a tensor with one element to a scalar value.
template <typename OpTy>
class ConvertAtenImplicitLikeOp : public OpConversionPattern<OpTy> {
public:
  using OpConversionPattern<OpTy>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpTy op,
                  typename OpConversionPattern<OpTy>::OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, adaptor.getA());
    return success();
  }
};
} // namespace

void mlir::torch::torch_to_linalg::
    populateTensorScalarInteropPatternsAndLegality(TypeConverter &typeConverter,
                                                   RewritePatternSet &patterns,
                                                   ConversionTarget &target) {
  MLIRContext *context = patterns.getContext();
  target.addIllegalOp<AtenSizeIntOp>();
  patterns.add<ConvertAtenSizeIntOp>(typeConverter, context);
  target.addIllegalOp<AtenNumelOp>();
  patterns.add<ConvertAtenNumelOp>(typeConverter, context);
  target.addIllegalOp<AtenIntTensorOp, AtenFloatTensorOp, AtenBoolTensorOp>();
  patterns.add<ConvertAtenTensorToScalarLikeOp<AtenIntTensorOp>>(typeConverter,
                                                                 context);
  patterns.add<ConvertAtenTensorToScalarLikeOp<AtenFloatTensorOp>>(
      typeConverter, context);
  patterns.add<ConvertAtenTensorToScalarLikeOp<AtenBoolTensorOp>>(typeConverter,
                                                                  context);
  target.addIllegalOp<AtenTensorIntOp, AtenTensorFloatOp>();
  patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
  target.addIllegalOp<PrimNumToTensorScalarOp>();
  patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
  target.addIllegalOp<AtenFullOp>();
  patterns.add<ConvertAtenFullOp>(typeConverter, context);
  patterns.add<ConvertAtenImplicitLikeOp<AtenScalarImplicitOp>>(typeConverter,
                                                                context);
  patterns.add<ConvertAtenImplicitLikeOp<AtenFloatImplicitOp>>(typeConverter,
                                                               context);
  patterns.add<ConvertAtenImplicitLikeOp<AtenIntImplicitOp>>(typeConverter,
                                                             context);
  target.addIllegalOp<AtenScalarImplicitOp, AtenFloatImplicitOp,
                      AtenIntImplicitOp>();
}
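// Usage note (illustrative): like the other populate functions declared in
// PopulatePatterns.h, this entry point is intended to be called from the
// top-level TorchToLinalg conversion pass when it assembles its pattern set,
// roughly as:
//
//   RewritePatternSet patterns(context);
//   torch_to_linalg::populateTensorScalarInteropPatternsAndLegality(
//       typeConverter, patterns, target);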