Convert torch si literals into signless for TOSA (#1421)

Branches / tags: pull/1422/head, snapshot-20220927.609
Eric Kunze 2022-09-26 16:54:27 -07:00 committed by GitHub
parent c03aa63325
commit cb1b8796a2
2 changed files with 33 additions and 5 deletions


@@ -2214,11 +2214,28 @@ template <>
LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
    ValueTensorLiteralOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto outputTy = getTypeConverter()
                      ->convertType(op.getType())
                      .template cast<RankedTensorType>();
  // Tensors with integer types need to be converted to signless integer
  // element type. All tensors with element types other than integer can reuse
  // existing elements attribute.
  // TODO: what about unsigned integer?
  if (auto elements = op.valueAttr().dyn_cast<DenseIntElementsAttr>()) {
    Type builtinTensorElemTy = outputTy.getElementType();
    unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
    if (builtinTensorElemTy.isSignedInteger()) {
      DenseElementsAttr valueAttr =
          elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
            return APInt(bitWidth, v.getSExtValue());
          });
      rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputTy, valueAttr);
      return success();
    }
  }
  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputTy, adaptor.value());
  return success();
}
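For context, the re-typing above leans on DenseElementsAttr::mapValues, which rebuilds the literal's elements attribute with a new integer element type while preserving each value's bit pattern. A minimal standalone sketch of just that step, assuming a signless target type; the helper name convertToSignless is illustrative and not something this commit defines:

#include "llvm/ADT/APInt.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Rebuild an integer elements attribute with a signless element type of the
// same bit width. Only the attribute's element type changes (e.g. si64 to
// i64); the stored values keep their bit patterns.
static DenseElementsAttr convertToSignless(DenseIntElementsAttr elements,
                                           IntegerType signlessElemTy) {
  unsigned bitWidth = signlessElemTy.getWidth();
  return elements.mapValues(signlessElemTy, [&](const APInt &v) {
    return APInt(bitWidth, v.getSExtValue());
  });
}

As in the pattern itself, getSExtValue assumes each element fits in 64 bits, which holds for the si64 literals this lowering targets.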


@@ -31,7 +31,7 @@ func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.clamp"(%[[ARG_BUILTIN]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.relu %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
@@ -586,7 +586,7 @@ func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor
// CHECK: }
func.func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,3],f32> , %arg2: !torch.vtensor<[2,2,3],f32> ) -> !torch.vtensor<[5,2,2,3],f32> {
%float5.000000e-01 = torch.constant.float 5.000000e-01
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%0 = torch.prim.ListConstruct %int2, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %0, %arg1, %arg2, %float5.000000e-01 : !torch.vtensor<[5,2,2,3],f32>, !torch.list<int>, !torch.vtensor<[2,2,3],f32>, !torch.vtensor<[2,2,3],f32>, !torch.float -> !torch.vtensor<[5,2,2,3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>
@@ -752,7 +752,7 @@ func.func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
%float0.000000e00 = torch.constant.float 0.000000e+00
%false = torch.constant.bool false
%0 = torch.aten.dropout %arg0, %float0.000000e00, %false : !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
@@ -808,7 +808,7 @@ func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> )
// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = [3, 2, 1]} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
// CHECK: return %[[VAL_6]] : tensor<3x2x1xf32>
func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
%true = torch.constant.bool true
@@ -817,3 +817,14 @@ func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf3
%1 = torch_c.to_builtin_tensor %values : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
return %1 : tensor<3x2x1xf32>
}
// -----
// CHECK-LABEL: @torch.vtensor.literal_si64$basic(
// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<-1> : tensor<1x512xsi64>} : () -> tensor<1x512xi64>
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<1x512xi64> -> !torch.vtensor<[1,512],si64>
// CHECK: return %[[VAL_1]] : !torch.vtensor<[1,512],si64>
func.func @torch.vtensor.literal_si64$basic() -> !torch.vtensor<[1,512],si64> {
%0 = torch.vtensor.literal(dense<-1> : tensor<1x512xsi64>) : !torch.vtensor<[1,512],si64>
return %0 : !torch.vtensor<[1,512],si64>
}