mirror of https://github.com/llvm/torch-mlir
[tosa] Fix torch.vtensor.literal lowering
Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>pull/1431/head
parent
53e76b8ab6
commit
bce00c8ed1
|
@ -2224,9 +2224,9 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
|
||||||
// existing elements attribute.
|
// existing elements attribute.
|
||||||
// TODO: what about unsigned integer?
|
// TODO: what about unsigned integer?
|
||||||
if (auto elements = op.valueAttr().dyn_cast<DenseIntElementsAttr>()) {
|
if (auto elements = op.valueAttr().dyn_cast<DenseIntElementsAttr>()) {
|
||||||
Type builtinTensorElemTy = outputTy.getElementType();
|
if (elements.getElementType().isSignedInteger()) {
|
||||||
unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
|
Type builtinTensorElemTy = outputTy.getElementType();
|
||||||
if (builtinTensorElemTy.isSignedInteger()) {
|
unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
|
||||||
DenseElementsAttr valueAttr =
|
DenseElementsAttr valueAttr =
|
||||||
elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
|
elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
|
||||||
return APInt(bitWidth, v.getSExtValue());
|
return APInt(bitWidth, v.getSExtValue());
|
||||||
|
|
|
@ -821,7 +821,7 @@ func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf3
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @torch.vtensor.literal_si64$basic(
|
// CHECK-LABEL: @torch.vtensor.literal_si64$basic(
|
||||||
// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<-1> : tensor<1x512xsi64>} : () -> tensor<1x512xi64>
|
// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<-1> : tensor<1x512xi64>} : () -> tensor<1x512xi64>
|
||||||
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<1x512xi64> -> !torch.vtensor<[1,512],si64>
|
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<1x512xi64> -> !torch.vtensor<[1,512],si64>
|
||||||
// CHECK: return %[[VAL_1]] : !torch.vtensor<[1,512],si64>
|
// CHECK: return %[[VAL_1]] : !torch.vtensor<[1,512],si64>
|
||||||
func.func @torch.vtensor.literal_si64$basic() -> !torch.vtensor<[1,512],si64> {
|
func.func @torch.vtensor.literal_si64$basic() -> !torch.vtensor<[1,512],si64> {
|
||||||
|
|
Loading…
Reference in New Issue