mirror of https://github.com/llvm/torch-mlir
[TORCH]MLIR] Fix C++17 extension warning
The existing implementation of `ConvertConstantTensorAllocOp<>` requires a C++17 feature `if constexpr ()`. This commit removes the use of that feature to support the implementation even for lower C++ versions. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/478/head
parent
ab81f871e4
commit
d13bb0e5c1
|
@ -3507,7 +3507,7 @@ struct ConvertAtenScalarToTensorLike : ConversionPattern {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Converts constant tensor allocation like ops.
|
// Converts constant tensor allocation like ops.
|
||||||
template <typename OpTy>
|
template <typename OpTy, int fillVal>
|
||||||
class ConvertConstantTensorAllocOp : public OpConversionPattern<OpTy> {
|
class ConvertConstantTensorAllocOp : public OpConversionPattern<OpTy> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern<OpTy>::OpConversionPattern;
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||||
|
@ -3517,10 +3517,13 @@ public:
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Currently memory pinning and layout features are not supported.
|
// TODO: Add support for layout, pin_memory features.
|
||||||
|
// Only `none` layout is supported.
|
||||||
if (!op.layout().getType().template isa<Torch::NoneType>())
|
if (!op.layout().getType().template isa<Torch::NoneType>())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "unimplemented: only default layout is supported");
|
op, "unimplemented: only default layout is supported");
|
||||||
|
|
||||||
|
// The pin_memory should be either `False` or `none`.
|
||||||
bool pinMemory;
|
bool pinMemory;
|
||||||
if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
|
if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
|
||||||
(!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
|
(!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
|
||||||
|
@ -3529,13 +3532,6 @@ public:
|
||||||
op, "unimplemented: pin_memory must be either None or false");
|
op, "unimplemented: pin_memory must be either None or false");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Memory formats are not supported in the case of `AtenEmptyMemoryFormat`.
|
|
||||||
if constexpr (std::is_same<OpTy, AtenEmptyMemoryFormatOp>::value) {
|
|
||||||
if (!op.memory_format().getType().template isa<Torch::NoneType>())
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "unimplemented: only default memory format is supported");
|
|
||||||
}
|
|
||||||
|
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
TypeConverter *typeConverter = this->getTypeConverter();
|
TypeConverter *typeConverter = this->getTypeConverter();
|
||||||
SmallVector<Value> resultSizeTorchInt, resultSize, resultSizeIndex;
|
SmallVector<Value> resultSizeTorchInt, resultSize, resultSizeIndex;
|
||||||
|
@ -3552,27 +3548,71 @@ public:
|
||||||
typeConverter->convertType(op.getType()).template cast<RankedTensorType>();
|
typeConverter->convertType(op.getType()).template cast<RankedTensorType>();
|
||||||
Type outElemType = resultType.getElementType();
|
Type outElemType = resultType.getElementType();
|
||||||
|
|
||||||
// Create an uninitialized tensor of `resultSize` shape. It will be returned
|
// Create an uninitialized tensor of `resultSize` shape and fill it with
|
||||||
// without initialization/filling in the case of `AtenEmptyMemoryFormatOp`.
|
// value `fillVal`.
|
||||||
Value outputTensor = rewriter.create<linalg::InitTensorOp>(
|
Value constVal = getConstant(rewriter, loc, fillVal, outElemType);
|
||||||
loc, resultSizeIndex, outElemType);
|
Value outputTensor =
|
||||||
|
createInitTensor(rewriter, loc, resultSizeIndex, outElemType, constVal);
|
||||||
// `AtenZeros` and `AtenOnes` ops will be filled with corresponding values.
|
|
||||||
if (std::is_same<OpTy, AtenZerosOp>::value) {
|
|
||||||
Value zero = getConstant(rewriter, loc, 0, outElemType);
|
|
||||||
outputTensor =
|
|
||||||
rewriter.create<linalg::FillOp>(loc, zero, outputTensor).getResult(0);
|
|
||||||
} else if (std::is_same<OpTy, AtenOnesOp>::value) {
|
|
||||||
Value one = getConstant(rewriter, loc, 1, outElemType);
|
|
||||||
outputTensor =
|
|
||||||
rewriter.create<linalg::FillOp>(loc, one, outputTensor).getResult(0);
|
|
||||||
}
|
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Converts `aten.empty` to `linalg.init_tensor` op.
|
||||||
|
class ConvertAtenEmptyMemoryFormatOp
|
||||||
|
: public OpConversionPattern<AtenEmptyMemoryFormatOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenEmptyMemoryFormatOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// TODO: Add support for layout, pin_memory and memory_format features.
|
||||||
|
// Only `none` layout is supported.
|
||||||
|
if (!op.layout().getType().template isa<Torch::NoneType>())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: only default layout is supported");
|
||||||
|
|
||||||
|
// The pin_memory should be either `False` or `none`.
|
||||||
|
bool pinMemory;
|
||||||
|
if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
|
||||||
|
(!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
|
||||||
|
pinMemory))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: pin_memory must be either None or false");
|
||||||
|
|
||||||
|
// Only `none` memory_format is supported.
|
||||||
|
if (!op.memory_format().getType().template isa<Torch::NoneType>())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: only default memory format is supported");
|
||||||
|
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
TypeConverter *typeConverter = this->getTypeConverter();
|
||||||
|
SmallVector<Value> resultSizeTorchInt, resultSize, resultSizeIndex;
|
||||||
|
if (!getListConstructElements(op.size(), resultSizeTorchInt)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: size must be constructed using ListConstruct");
|
||||||
|
}
|
||||||
|
resultSize = getTypeConvertedValues(rewriter, loc, typeConverter,
|
||||||
|
resultSizeTorchInt);
|
||||||
|
for (auto size : resultSize)
|
||||||
|
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
|
||||||
|
|
||||||
|
auto resultType = typeConverter->convertType(op.getType())
|
||||||
|
.template cast<RankedTensorType>();
|
||||||
|
// Create an uninitialized tensor of `resultSize` shape.
|
||||||
|
Value initTensor = rewriter.create<linalg::InitTensorOp>(
|
||||||
|
loc, resultSizeIndex, resultType.getElementType());
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, initTensor);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ConvertPrimNumToTensorScalarOp
|
class ConvertPrimNumToTensorScalarOp
|
||||||
: public OpConversionPattern<PrimNumToTensorScalarOp> {
|
: public OpConversionPattern<PrimNumToTensorScalarOp> {
|
||||||
|
@ -3779,13 +3819,11 @@ public:
|
||||||
target.addIllegalOp<AtenEmbeddingOp>();
|
target.addIllegalOp<AtenEmbeddingOp>();
|
||||||
patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context);
|
patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenEmptyMemoryFormatOp>();
|
target.addIllegalOp<AtenEmptyMemoryFormatOp>();
|
||||||
patterns.add<ConvertConstantTensorAllocOp<AtenEmptyMemoryFormatOp>>(
|
patterns.add<ConvertAtenEmptyMemoryFormatOp>(typeConverter, context);
|
||||||
typeConverter, context);
|
target.addIllegalOp<AtenZerosOp, AtenOnesOp>();
|
||||||
target.addIllegalOp<AtenZerosOp>();
|
patterns.add<ConvertConstantTensorAllocOp<AtenZerosOp, 0>>(typeConverter,
|
||||||
patterns.add<ConvertConstantTensorAllocOp<AtenZerosOp>>(typeConverter,
|
|
||||||
context);
|
context);
|
||||||
target.addIllegalOp<AtenOnesOp>();
|
patterns.add<ConvertConstantTensorAllocOp<AtenOnesOp, 1>>(typeConverter,
|
||||||
patterns.add<ConvertConstantTensorAllocOp<AtenOnesOp>>(typeConverter,
|
|
||||||
context);
|
context);
|
||||||
target.addIllegalOp<AtenContiguousOp>();
|
target.addIllegalOp<AtenContiguousOp>();
|
||||||
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
||||||
|
|
Loading…
Reference in New Issue