torch-mlir change for dense resource implementation (#2513)

Co-authored-by: Avinash Sharma <avinash@nod-labs.com>
pull/2544/head snapshot-20231104.1012
saienduri 2023-11-03 11:44:07 -07:00 committed by GitHub
parent 1b9fb1b51d
commit 88adf384cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 14 additions and 9 deletions

View File

@ -171,15 +171,19 @@ public:
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
MLIRContext *context = op->getContext(); MLIRContext *context = op->getContext();
if (auto elements = op.getValueAttr().dyn_cast<DenseIntElementsAttr>()) { if (auto elements = op.getValueAttr().dyn_cast<DenseIntElementsAttr>()) {
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
Type elemTy = op.getValueAttr().getElementType(); Type elemTy = op.getValueAttr().getElementType();
unsigned bitWidth = elemTy.getIntOrFloatBitWidth(); unsigned bitWidth = elemTy.getIntOrFloatBitWidth();
Type builtinTensorElemTy = IntegerType::get(context, bitWidth); Type builtinTensorElemTy = IntegerType::get(context, bitWidth);
rewriter.replaceOpWithNewOp<arith::ConstantOp>( auto shapedType =
op, elements.mapValues(builtinTensorElemTy, [&](const APInt &v) { RankedTensorType::get(type.getShape(), builtinTensorElemTy);
return APInt(bitWidth, v.getSExtValue()); auto rawData = elements.getRawData();
})); DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
shapedType, rawData);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
return success(); return success();
} }
}
if (auto elements = op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) { if (auto elements = op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) {
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) { if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
if (auto intType = type.getElementType().dyn_cast<IntegerType>()) { if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
@ -190,7 +194,8 @@ public:
AsmResourceBlob *blob = elements.getRawHandle().getBlob(); AsmResourceBlob *blob = elements.getRawHandle().getBlob();
assert(blob && "Expecting dense resource with a valid blob"); assert(blob && "Expecting dense resource with a valid blob");
rewriter.replaceOpWithNewOp<arith::ConstantOp>( rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, DenseElementsAttr::get(shapedType, blob->getData())); op, DenseResourceElementsAttr::get(shapedType,
elements.getRawHandle()));
return success(); return success();
} }
} }