[TORCH]MLIR] Fix C++17 extension warning

The existing implementation of `ConvertConstantTensorAllocOp<>` requires
a C++17 feature `if constexpr ()`. This commit removes the use of that
feature to support the implementation even for lower C++ versions.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/478/head
Gaurav Shukla 2021-12-15 20:54:50 +05:30 committed by Gaurav Shukla
parent ab81f871e4
commit d13bb0e5c1
1 changed files with 70 additions and 32 deletions

View File

@ -3507,7 +3507,7 @@ struct ConvertAtenScalarToTensorLike : ConversionPattern {
namespace {
// Converts constant tensor allocation like ops.
template <typename OpTy>
template <typename OpTy, int fillVal>
class ConvertConstantTensorAllocOp : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern<OpTy>::OpConversionPattern;
@ -3517,10 +3517,13 @@ public:
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
// Currently memory pinning and layout features are not supported.
// TODO: Add support for layout, pin_memory features.
// Only `none` layout is supported.
if (!op.layout().getType().template isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "unimplemented: only default layout is supported");
// The pin_memory should be either `False` or `none`.
bool pinMemory;
if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
(!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
@ -3529,13 +3532,6 @@ public:
op, "unimplemented: pin_memory must be either None or false");
}
// Memory formats are not supported in the case of `AtenEmptyMemoryFormat`.
if constexpr (std::is_same<OpTy, AtenEmptyMemoryFormatOp>::value) {
if (!op.memory_format().getType().template isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "unimplemented: only default memory format is supported");
}
Location loc = op.getLoc();
TypeConverter *typeConverter = this->getTypeConverter();
SmallVector<Value> resultSizeTorchInt, resultSize, resultSizeIndex;
@ -3552,27 +3548,71 @@ public:
typeConverter->convertType(op.getType()).template cast<RankedTensorType>();
Type outElemType = resultType.getElementType();
// Create an uninitialized tensor of `resultSize` shape. It will be returned
// without initialization/filling in the case of `AtenEmptyMemoryFormatOp`.
Value outputTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultSizeIndex, outElemType);
// `AtenZeros` and `AtenOnes` ops will be filled with corresponding values.
if (std::is_same<OpTy, AtenZerosOp>::value) {
Value zero = getConstant(rewriter, loc, 0, outElemType);
outputTensor =
rewriter.create<linalg::FillOp>(loc, zero, outputTensor).getResult(0);
} else if (std::is_same<OpTy, AtenOnesOp>::value) {
Value one = getConstant(rewriter, loc, 1, outElemType);
outputTensor =
rewriter.create<linalg::FillOp>(loc, one, outputTensor).getResult(0);
}
// Create an uninitialized tensor of `resultSize` shape and fill it with
// value `fillVal`.
Value constVal = getConstant(rewriter, loc, fillVal, outElemType);
Value outputTensor =
createInitTensor(rewriter, loc, resultSizeIndex, outElemType, constVal);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
return success();
}
};
} // namespace
namespace {
// Converts `aten.empty` to `linalg.init_tensor` op.
class ConvertAtenEmptyMemoryFormatOp
: public OpConversionPattern<AtenEmptyMemoryFormatOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenEmptyMemoryFormatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
// TODO: Add support for layout, pin_memory and memory_format features.
// Only `none` layout is supported.
if (!op.layout().getType().template isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "unimplemented: only default layout is supported");
// The pin_memory should be either `False` or `none`.
bool pinMemory;
if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
(!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
pinMemory))
return rewriter.notifyMatchFailure(
op, "unimplemented: pin_memory must be either None or false");
// Only `none` memory_format is supported.
if (!op.memory_format().getType().template isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "unimplemented: only default memory format is supported");
Location loc = op.getLoc();
TypeConverter *typeConverter = this->getTypeConverter();
SmallVector<Value> resultSizeTorchInt, resultSize, resultSizeIndex;
if (!getListConstructElements(op.size(), resultSizeTorchInt)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: size must be constructed using ListConstruct");
}
resultSize = getTypeConvertedValues(rewriter, loc, typeConverter,
resultSizeTorchInt);
for (auto size : resultSize)
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
auto resultType = typeConverter->convertType(op.getType())
.template cast<RankedTensorType>();
// Create an uninitialized tensor of `resultSize` shape.
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultSizeIndex, resultType.getElementType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, initTensor);
return success();
}
};
} // namespace
namespace {
class ConvertPrimNumToTensorScalarOp
: public OpConversionPattern<PrimNumToTensorScalarOp> {
@ -3779,14 +3819,12 @@ public:
target.addIllegalOp<AtenEmbeddingOp>();
patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context);
target.addIllegalOp<AtenEmptyMemoryFormatOp>();
patterns.add<ConvertConstantTensorAllocOp<AtenEmptyMemoryFormatOp>>(
typeConverter, context);
target.addIllegalOp<AtenZerosOp>();
patterns.add<ConvertConstantTensorAllocOp<AtenZerosOp>>(typeConverter,
context);
target.addIllegalOp<AtenOnesOp>();
patterns.add<ConvertConstantTensorAllocOp<AtenOnesOp>>(typeConverter,
context);
patterns.add<ConvertAtenEmptyMemoryFormatOp>(typeConverter, context);
target.addIllegalOp<AtenZerosOp, AtenOnesOp>();
patterns.add<ConvertConstantTensorAllocOp<AtenZerosOp, 0>>(typeConverter,
context);
patterns.add<ConvertConstantTensorAllocOp<AtenOnesOp, 1>>(typeConverter,
context);
target.addIllegalOp<AtenContiguousOp>();
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
target.addIllegalOp<AtenIntTensorOp>();