torch-mlir change for dense resource implementation (#2513)

Co-authored-by: Avinash Sharma <avinash@nod-labs.com>
pull/2544/head snapshot-20231104.1012
saienduri 2023-11-03 11:44:07 -07:00 committed by GitHub
parent 1b9fb1b51d
commit 88adf384cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 14 additions and 9 deletions

View File

@ -171,14 +171,18 @@ public:
ConversionPatternRewriter &rewriter) const override {
MLIRContext *context = op->getContext();
if (auto elements = op.getValueAttr().dyn_cast<DenseIntElementsAttr>()) {
Type elemTy = op.getValueAttr().getElementType();
unsigned bitWidth = elemTy.getIntOrFloatBitWidth();
Type builtinTensorElemTy = IntegerType::get(context, bitWidth);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
return APInt(bitWidth, v.getSExtValue());
}));
return success();
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
Type elemTy = op.getValueAttr().getElementType();
unsigned bitWidth = elemTy.getIntOrFloatBitWidth();
Type builtinTensorElemTy = IntegerType::get(context, bitWidth);
auto shapedType =
RankedTensorType::get(type.getShape(), builtinTensorElemTy);
auto rawData = elements.getRawData();
DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
shapedType, rawData);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
return success();
}
}
if (auto elements = op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) {
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
@ -190,7 +194,8 @@ public:
AsmResourceBlob *blob = elements.getRawHandle().getBlob();
assert(blob && "Expecting dense resource with a valid blob");
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, DenseElementsAttr::get(shapedType, blob->getData()));
op, DenseResourceElementsAttr::get(shapedType,
elements.getRawHandle()));
return success();
}
}