diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 7fca15acf..4fe360077 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -3507,7 +3507,7 @@ struct ConvertAtenScalarToTensorLike : ConversionPattern {
 
 namespace {
 // Converts constant tensor allocation like ops.
-template <typename OpTy>
+template <typename OpTy, int fillVal>
 class ConvertConstantTensorAllocOp : public OpConversionPattern<OpTy> {
 public:
   using OpConversionPattern<OpTy>::OpConversionPattern;
@@ -3517,10 +3517,13 @@ public:
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
-    // Currently memory pinning and layout features are not supported.
+    // TODO: Add support for layout, pin_memory features.
+    // Only `none` layout is supported.
     if (!op.layout().getType().template isa<Torch::NoneType>())
       return rewriter.notifyMatchFailure(
           op, "unimplemented: only default layout is supported");
+
+    // The pin_memory should be either `False` or `none`.
     bool pinMemory;
     if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
         (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
          pinMemory)) {
@@ -3529,13 +3532,6 @@ public:
           op, "unimplemented: pin_memory must be either None or false");
     }
 
-    // Memory formats are not supported in the case of `AtenEmptyMemoryFormat`.
-    if constexpr (std::is_same<OpTy, AtenEmptyMemoryFormatOp>::value) {
-      if (!op.memory_format().getType().template isa<Torch::NoneType>())
-        return rewriter.notifyMatchFailure(
-            op, "unimplemented: only default memory format is supported");
-    }
-
     Location loc = op.getLoc();
     TypeConverter *typeConverter = this->getTypeConverter();
     SmallVector<Value> resultSizeTorchInt, resultSize, resultSizeIndex;
@@ -3552,27 +3548,71 @@ public:
         typeConverter->convertType(op.getType()).template cast<RankedTensorType>();
     Type outElemType = resultType.getElementType();
 
-    // Create an uninitialized tensor of `resultSize` shape. It will be returned
-    // without initialization/filling in the case of `AtenEmptyMemoryFormatOp`.
-    Value outputTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, resultSizeIndex, outElemType);
-
-    // `AtenZeros` and `AtenOnes` ops will be filled with corresponding values.
-    if (std::is_same<OpTy, AtenZerosOp>::value) {
-      Value zero = getConstant(rewriter, loc, 0, outElemType);
-      outputTensor =
-          rewriter.create<linalg::FillOp>(loc, zero, outputTensor).getResult(0);
-    } else if (std::is_same<OpTy, AtenOnesOp>::value) {
-      Value one = getConstant(rewriter, loc, 1, outElemType);
-      outputTensor =
-          rewriter.create<linalg::FillOp>(loc, one, outputTensor).getResult(0);
-    }
+    // Create an uninitialized tensor of `resultSize` shape and fill it with
+    // value `fillVal`.
+    Value constVal = getConstant(rewriter, loc, fillVal, outElemType);
+    Value outputTensor =
+        createInitTensor(rewriter, loc, resultSizeIndex, outElemType, constVal);
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
     return success();
   }
 };
 } // namespace
 
+namespace {
+// Converts `aten.empty` to `linalg.init_tensor` op.
+class ConvertAtenEmptyMemoryFormatOp
+    : public OpConversionPattern<AtenEmptyMemoryFormatOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenEmptyMemoryFormatOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    // TODO: Add support for layout, pin_memory and memory_format features.
+    // Only `none` layout is supported.
+    if (!op.layout().getType().template isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only default layout is supported");
+
+    // The pin_memory should be either `False` or `none`.
+    bool pinMemory;
+    if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
+        (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
+         pinMemory))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: pin_memory must be either None or false");
+
+    // Only `none` memory_format is supported.
+    if (!op.memory_format().getType().template isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only default memory format is supported");
+
+    Location loc = op.getLoc();
+    TypeConverter *typeConverter = this->getTypeConverter();
+    SmallVector<Value> resultSizeTorchInt, resultSize, resultSizeIndex;
+    if (!getListConstructElements(op.size(), resultSizeTorchInt)) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: size must be constructed using ListConstruct");
+    }
+    resultSize = getTypeConvertedValues(rewriter, loc, typeConverter,
+                                        resultSizeTorchInt);
+    for (auto size : resultSize)
+      resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
+
+    auto resultType = typeConverter->convertType(op.getType())
+                          .template cast<RankedTensorType>();
+    // Create an uninitialized tensor of `resultSize` shape.
+    Value initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, resultSizeIndex, resultType.getElementType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, initTensor);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertPrimNumToTensorScalarOp
     : public OpConversionPattern<PrimNumToTensorScalarOp> {
@@ -3779,14 +3819,12 @@ public:
     target.addIllegalOp();
     patterns.add(typeConverter, context);
     target.addIllegalOp<AtenEmptyMemoryFormatOp>();
-    patterns.add<ConvertConstantTensorAllocOp<AtenEmptyMemoryFormatOp>>(
-        typeConverter, context);
-    target.addIllegalOp<AtenZerosOp>();
-    patterns.add<ConvertConstantTensorAllocOp<AtenZerosOp>>(typeConverter,
-                                                            context);
-    target.addIllegalOp<AtenOnesOp>();
-    patterns.add<ConvertConstantTensorAllocOp<AtenOnesOp>>(typeConverter,
-                                                           context);
+    patterns.add<ConvertAtenEmptyMemoryFormatOp>(typeConverter, context);
+    target.addIllegalOp<AtenZerosOp, AtenOnesOp>();
+    patterns.add<ConvertConstantTensorAllocOp<AtenZerosOp, 0>>(typeConverter,
+                                                               context);
+    patterns.add<ConvertConstantTensorAllocOp<AtenOnesOp, 1>>(typeConverter,
+                                                              context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
     target.addIllegalOp();
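A note on `createInitTensor`, which the rewritten allocation pattern now calls but whose definition is not part of this diff: judging from the `linalg::InitTensorOp` + `linalg::FillOp` sequence deleted above, it presumably wraps exactly that pair of ops. A minimal sketch, assuming a free helper with this signature (the real utility in torch-mlir may differ in name, location, and signature):

// Hypothetical sketch of the createInitTensor helper assumed by this diff.
static Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
                              Type elemTy, Value initElem) {
  // Allocate an uninitialized tensor of the requested shape and element type.
  Value initTensor = b.create<linalg::InitTensorOp>(loc, sizes, elemTy);
  // Fill every element with initElem; via the fillVal template argument this
  // is 0 for aten.zeros and 1 for aten.ones.
  return b.create<linalg::FillOp>(loc, initElem, initTensor).getResult(0);
}

With the fill value lifted into the `fillVal` template parameter, `ConvertConstantTensorAllocOp<AtenZerosOp, 0>` and `ConvertConstantTensorAllocOp<AtenOnesOp, 1>` share one code path, while `aten.empty` keeps its fill-free lowering in the new dedicated pattern.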