[tosa] Fix torch.vtensor.literal lowering

Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>
pull/1431/head
Vivek Khandelwal 2022-09-29 13:17:22 +05:30
parent 53e76b8ab6
commit bce00c8ed1
2 changed files with 4 additions and 4 deletions

View File

@ -2224,9 +2224,9 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
// existing elements attribute.
// TODO: what about unsigned integer?
if (auto elements = op.valueAttr().dyn_cast<DenseIntElementsAttr>()) {
Type builtinTensorElemTy = outputTy.getElementType();
unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
if (builtinTensorElemTy.isSignedInteger()) {
if (elements.getElementType().isSignedInteger()) {
Type builtinTensorElemTy = outputTy.getElementType();
unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
DenseElementsAttr valueAttr =
elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
return APInt(bitWidth, v.getSExtValue());

View File

@ -821,7 +821,7 @@ func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf3
// -----
// CHECK-LABEL: @torch.vtensor.literal_si64$basic(
// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<-1> : tensor<1x512xsi64>} : () -> tensor<1x512xi64>
// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<-1> : tensor<1x512xi64>} : () -> tensor<1x512xi64>
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<1x512xi64> -> !torch.vtensor<[1,512],si64>
// CHECK: return %[[VAL_1]] : !torch.vtensor<[1,512],si64>
func.func @torch.vtensor.literal_si64$basic() -> !torch.vtensor<[1,512],si64> {