diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index 9e3cc2f75..32ec30b18 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -171,14 +171,18 @@ public: ConversionPatternRewriter &rewriter) const override { MLIRContext *context = op->getContext(); if (auto elements = op.getValueAttr().dyn_cast()) { - Type elemTy = op.getValueAttr().getElementType(); - unsigned bitWidth = elemTy.getIntOrFloatBitWidth(); - Type builtinTensorElemTy = IntegerType::get(context, bitWidth); - rewriter.replaceOpWithNewOp( - op, elements.mapValues(builtinTensorElemTy, [&](const APInt &v) { - return APInt(bitWidth, v.getSExtValue()); - })); - return success(); + if (auto type = elements.getType().dyn_cast()) { + Type elemTy = op.getValueAttr().getElementType(); + unsigned bitWidth = elemTy.getIntOrFloatBitWidth(); + Type builtinTensorElemTy = IntegerType::get(context, bitWidth); + auto shapedType = + RankedTensorType::get(type.getShape(), builtinTensorElemTy); + auto rawData = elements.getRawData(); + DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer( + shapedType, rawData); + rewriter.replaceOpWithNewOp(op, newAttr); + return success(); + } } if (auto elements = op.getValueAttr().dyn_cast()) { if (auto type = elements.getType().dyn_cast()) { @@ -190,7 +194,8 @@ public: AsmResourceBlob *blob = elements.getRawHandle().getBlob(); assert(blob && "Expecting dense resource with a valid blob"); rewriter.replaceOpWithNewOp( - op, DenseElementsAttr::get(shapedType, blob->getData())); + op, DenseResourceElementsAttr::get(shapedType, + elements.getRawHandle())); return success(); } }