mirror of https://github.com/llvm/torch-mlir
torch-mlir change for dense resource implementation (#2513)
Co-authored-by: Avinash Sharma <avinash@nod-labs.com>pull/2544/head snapshot-20231104.1012
parent
1b9fb1b51d
commit
88adf384cc
|
@ -171,15 +171,19 @@ public:
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
MLIRContext *context = op->getContext();
|
MLIRContext *context = op->getContext();
|
||||||
if (auto elements = op.getValueAttr().dyn_cast<DenseIntElementsAttr>()) {
|
if (auto elements = op.getValueAttr().dyn_cast<DenseIntElementsAttr>()) {
|
||||||
|
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
|
||||||
Type elemTy = op.getValueAttr().getElementType();
|
Type elemTy = op.getValueAttr().getElementType();
|
||||||
unsigned bitWidth = elemTy.getIntOrFloatBitWidth();
|
unsigned bitWidth = elemTy.getIntOrFloatBitWidth();
|
||||||
Type builtinTensorElemTy = IntegerType::get(context, bitWidth);
|
Type builtinTensorElemTy = IntegerType::get(context, bitWidth);
|
||||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
auto shapedType =
|
||||||
op, elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
|
RankedTensorType::get(type.getShape(), builtinTensorElemTy);
|
||||||
return APInt(bitWidth, v.getSExtValue());
|
auto rawData = elements.getRawData();
|
||||||
}));
|
DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
|
||||||
|
shapedType, rawData);
|
||||||
|
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
if (auto elements = op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) {
|
if (auto elements = op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) {
|
||||||
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
|
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
|
||||||
if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
|
if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
|
||||||
|
@ -190,7 +194,8 @@ public:
|
||||||
AsmResourceBlob *blob = elements.getRawHandle().getBlob();
|
AsmResourceBlob *blob = elements.getRawHandle().getBlob();
|
||||||
assert(blob && "Expecting dense resource with a valid blob");
|
assert(blob && "Expecting dense resource with a valid blob");
|
||||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||||
op, DenseElementsAttr::get(shapedType, blob->getData()));
|
op, DenseResourceElementsAttr::get(shapedType,
|
||||||
|
elements.getRawHandle()));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue