//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v3.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-1.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h" #include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; namespace { class ConvertAtenItemOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenItemOp::Adaptor; LogicalResult matchAndRewrite(AtenItemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto operand = adaptor.getOperands()[0]; auto operandTy = cast(operand.getType()); auto torchDTy = cast(op.getOperand().getType()).getDtype(); if (operandTy.getNumElements() != 1) return rewriter.notifyMatchFailure(op, "expected only one item"); auto zeroIdx = rewriter.create(op.getLoc(), 0); auto rank = operandTy.getRank(); llvm::SmallVector indices(rank, zeroIdx); Value extract = rewriter.create( op.getLoc(), operandTy.getElementType(), operand, indices); auto extractTy = extract.getType(); if (isa(extractTy) && !extractTy.isInteger(64)) { if (torchDTy.isUnsignedInteger()) { extract = rewriter.create( op.getLoc(), rewriter.getIntegerType(64), extract); } else { extract = rewriter.create( op.getLoc(), rewriter.getIntegerType(64), extract); } } if (isa(extractTy) && !extractTy.isF64()) { extract = rewriter.create(op.getLoc(), rewriter.getF64Type(), extract); } rewriter.replaceOp(op, extract); return success(); } }; class ConvertAtenShapeToTensorPatternOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename Aten_ShapeAsTensorOp::Adaptor; LogicalResult matchAndRewrite(Aten_ShapeAsTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto operand = adaptor.getOperands()[0]; auto operandTy = cast(operand.getType()); auto resultTy = cast(getTypeConverter()->convertType(op.getType())); int64_t rank = operandTy.getRank(); if (rank == 0) { rewriter.replaceOpWithNewOp(op, resultTy.getShape(), resultTy.getElementType()); return success(); } SmallVector dims; for (int i = 0; i < rank; ++i) { Value dim = rewriter.createOrFold(loc, operand, i); dim = rewriter.createOrFold( loc, resultTy.getElementType(), dim); dims.push_back(dim); } Value tensor = rewriter.createOrFold(op.getLoc(), dims); rewriter.replaceOp(op, tensor); return success(); } }; class ConvertAtenTensorOpPattern : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenTensorOp::Adaptor; LogicalResult matchAndRewrite(AtenTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto list = op.getData().getDefiningOp(); if (!list) return failure(); auto typeConverter = getTypeConverter(); auto resultTy = cast(typeConverter->convertType(op.getType())); auto resultETy = resultTy.getElementType(); SmallVector values; for (Value operand : list.getOperands()) { Value value = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(operand.getType()), operand); if (isa(resultETy) && value.getType() != resultETy) value = rewriter.create(loc, resultETy, value); if (isa(resultETy) && value.getType() != resultETy) value = rewriter.create(loc, resultETy, value); values.push_back(value); } rewriter.replaceOpWithNewOp(op, resultTy, values); return success(); } }; class ConvertTorchToTensor : public ConvertTorchToTensorBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); TorchConversion::getBackendTypeConversionDependentDialects(registry); } void runOnOperation() override { MLIRContext *context = &getContext(); ConversionTarget target(*context); target.addLegalDialect(); target.addLegalDialect(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); RewritePatternSet patterns(context); patterns.add(typeConverter, context); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); } }; } // namespace std::unique_ptr> mlir::torch::createConvertTorchToTensorPass() { return std::make_unique(); }