mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Fix C++17 extension warning
The existing implementation of `ConvertConstantTensorAllocOp<>` requires the C++17 feature `if constexpr ()`. This commit removes the use of that feature so that the implementation is supported even on lower C++ versions. Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com> (pull/478/head)
parent
ab81f871e4
commit
d13bb0e5c1
|
@ -3507,7 +3507,7 @@ struct ConvertAtenScalarToTensorLike : ConversionPattern {
|
|||
|
||||
namespace {
|
||||
// Converts constant tensor allocation like ops.
|
||||
template <typename OpTy>
|
||||
template <typename OpTy, int fillVal>
|
||||
class ConvertConstantTensorAllocOp : public OpConversionPattern<OpTy> {
|
||||
public:
|
||||
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||
|
@ -3517,10 +3517,13 @@ public:
|
|||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
// Currently memory pinning and layout features are not supported.
|
||||
// TODO: Add support for layout, pin_memory features.
|
||||
// Only `none` layout is supported.
|
||||
if (!op.layout().getType().template isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only default layout is supported");
|
||||
|
||||
// The pin_memory should be either `False` or `none`.
|
||||
bool pinMemory;
|
||||
if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
|
||||
(!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
|
||||
|
@ -3529,13 +3532,6 @@ public:
|
|||
op, "unimplemented: pin_memory must be either None or false");
|
||||
}
|
||||
|
||||
// Memory formats are not supported in the case of `AtenEmptyMemoryFormat`.
|
||||
if constexpr (std::is_same<OpTy, AtenEmptyMemoryFormatOp>::value) {
|
||||
if (!op.memory_format().getType().template isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only default memory format is supported");
|
||||
}
|
||||
|
||||
Location loc = op.getLoc();
|
||||
TypeConverter *typeConverter = this->getTypeConverter();
|
||||
SmallVector<Value> resultSizeTorchInt, resultSize, resultSizeIndex;
|
||||
|
@ -3552,27 +3548,71 @@ public:
|
|||
typeConverter->convertType(op.getType()).template cast<RankedTensorType>();
|
||||
Type outElemType = resultType.getElementType();
|
||||
|
||||
// Create an uninitialized tensor of `resultSize` shape. It will be returned
|
||||
// without initialization/filling in the case of `AtenEmptyMemoryFormatOp`.
|
||||
Value outputTensor = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, resultSizeIndex, outElemType);
|
||||
|
||||
// `AtenZeros` and `AtenOnes` ops will be filled with corresponding values.
|
||||
if (std::is_same<OpTy, AtenZerosOp>::value) {
|
||||
Value zero = getConstant(rewriter, loc, 0, outElemType);
|
||||
outputTensor =
|
||||
rewriter.create<linalg::FillOp>(loc, zero, outputTensor).getResult(0);
|
||||
} else if (std::is_same<OpTy, AtenOnesOp>::value) {
|
||||
Value one = getConstant(rewriter, loc, 1, outElemType);
|
||||
outputTensor =
|
||||
rewriter.create<linalg::FillOp>(loc, one, outputTensor).getResult(0);
|
||||
}
|
||||
// Create an uninitialized tensor of `resultSize` shape and fill it with
|
||||
// value `fillVal`.
|
||||
Value constVal = getConstant(rewriter, loc, fillVal, outElemType);
|
||||
Value outputTensor =
|
||||
createInitTensor(rewriter, loc, resultSizeIndex, outElemType, constVal);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Converts `aten.empty` to `linalg.init_tensor` op.
|
||||
class ConvertAtenEmptyMemoryFormatOp
|
||||
: public OpConversionPattern<AtenEmptyMemoryFormatOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenEmptyMemoryFormatOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
// TODO: Add support for layout, pin_memory and memory_format features.
|
||||
// Only `none` layout is supported.
|
||||
if (!op.layout().getType().template isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only default layout is supported");
|
||||
|
||||
// The pin_memory should be either `False` or `none`.
|
||||
bool pinMemory;
|
||||
if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
|
||||
(!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
|
||||
pinMemory))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: pin_memory must be either None or false");
|
||||
|
||||
// Only `none` memory_format is supported.
|
||||
if (!op.memory_format().getType().template isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only default memory format is supported");
|
||||
|
||||
Location loc = op.getLoc();
|
||||
TypeConverter *typeConverter = this->getTypeConverter();
|
||||
SmallVector<Value> resultSizeTorchInt, resultSize, resultSizeIndex;
|
||||
if (!getListConstructElements(op.size(), resultSizeTorchInt)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: size must be constructed using ListConstruct");
|
||||
}
|
||||
resultSize = getTypeConvertedValues(rewriter, loc, typeConverter,
|
||||
resultSizeTorchInt);
|
||||
for (auto size : resultSize)
|
||||
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
|
||||
|
||||
auto resultType = typeConverter->convertType(op.getType())
|
||||
.template cast<RankedTensorType>();
|
||||
// Create an uninitialized tensor of `resultSize` shape.
|
||||
Value initTensor = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, resultSizeIndex, resultType.getElementType());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, initTensor);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertPrimNumToTensorScalarOp
|
||||
: public OpConversionPattern<PrimNumToTensorScalarOp> {
|
||||
|
@ -3779,14 +3819,12 @@ public:
|
|||
target.addIllegalOp<AtenEmbeddingOp>();
|
||||
patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenEmptyMemoryFormatOp>();
|
||||
patterns.add<ConvertConstantTensorAllocOp<AtenEmptyMemoryFormatOp>>(
|
||||
typeConverter, context);
|
||||
target.addIllegalOp<AtenZerosOp>();
|
||||
patterns.add<ConvertConstantTensorAllocOp<AtenZerosOp>>(typeConverter,
|
||||
context);
|
||||
target.addIllegalOp<AtenOnesOp>();
|
||||
patterns.add<ConvertConstantTensorAllocOp<AtenOnesOp>>(typeConverter,
|
||||
context);
|
||||
patterns.add<ConvertAtenEmptyMemoryFormatOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenZerosOp, AtenOnesOp>();
|
||||
patterns.add<ConvertConstantTensorAllocOp<AtenZerosOp, 0>>(typeConverter,
|
||||
context);
|
||||
patterns.add<ConvertConstantTensorAllocOp<AtenOnesOp, 1>>(typeConverter,
|
||||
context);
|
||||
target.addIllegalOp<AtenContiguousOp>();
|
||||
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenIntTensorOp>();
|
||||
|
|
Loading…
Reference in New Issue