mlir: fix replacement of `OpaqueElementsAttr` (#1274)

An earlier patch (bb47c166) incorrectly replaced the now-dropped
`OpaqueElementsAttr` with `SparseElementsAttr` in one place and with
`DenseElementsAttr` in another.  This patch fixes the problem by making
both replacements use the dense-equivalent type.
pull/1280/head
Ashay Rane 2022-08-24 17:10:40 -05:00 committed by GitHub
parent e2f862cb85
commit 1d9d925f6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 2 deletions

View File

@ -16,6 +16,7 @@
#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h" #include "mlir/Dialect/Traits.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
@ -178,15 +179,17 @@ public:
})); }));
return success(); return success();
} }
if (auto elements = op.valueAttr().dyn_cast<SparseElementsAttr>()) { if (auto elements = op.valueAttr().dyn_cast<DenseResourceElementsAttr>()) {
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) { if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
if (auto intType = type.getElementType().dyn_cast<IntegerType>()) { if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
Type builtinTensorElemTy = Type builtinTensorElemTy =
IntegerType::get(context, intType.getIntOrFloatBitWidth()); IntegerType::get(context, intType.getIntOrFloatBitWidth());
auto shapedType = auto shapedType =
RankedTensorType::get(type.getShape(), builtinTensorElemTy); RankedTensorType::get(type.getShape(), builtinTensorElemTy);
AsmResourceBlob *blob = elements.getRawHandle().getBlob();
assert(blob && "Expecting dense resource with a valid blob");
rewriter.replaceOpWithNewOp<arith::ConstantOp>( rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, DenseElementsAttr::get(shapedType, elements.getValues())); op, DenseElementsAttr::get(shapedType, blob->getData()));
return success(); return success();
} }
} }