mirror of https://github.com/llvm/torch-mlir
parent
c03aa63325
commit
cb1b8796a2
|
@ -2214,11 +2214,28 @@ template <>
|
|||
LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
|
||||
ValueTensorLiteralOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
auto outputTy = getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<RankedTensorType>();
|
||||
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputTy, adaptor.value());
|
||||
|
||||
// Tensors with integer types need to be converted to signless integer
|
||||
// element type. All tensors with element types other than integer can reuse
|
||||
// existing elements attribute.
|
||||
// TODO: what about unsigned integer?
|
||||
if (auto elements = op.valueAttr().dyn_cast<DenseIntElementsAttr>()) {
|
||||
Type builtinTensorElemTy = outputTy.getElementType();
|
||||
unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
|
||||
if (builtinTensorElemTy.isSignedInteger()) {
|
||||
DenseElementsAttr valueAttr =
|
||||
elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
|
||||
return APInt(bitWidth, v.getSExtValue());
|
||||
});
|
||||
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputTy, valueAttr);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputTy, adaptor.value());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.
|
|||
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.clamp"(%[[ARG_BUILTIN]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.relu %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
|
@ -586,7 +586,7 @@ func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor
|
|||
// CHECK: }
|
||||
func.func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,3],f32> , %arg2: !torch.vtensor<[2,2,3],f32> ) -> !torch.vtensor<[5,2,2,3],f32> {
|
||||
%float5.000000e-01 = torch.constant.float 5.000000e-01
|
||||
%int3 = torch.constant.int 3
|
||||
%int3 = torch.constant.int 3
|
||||
%int2 = torch.constant.int 2
|
||||
%0 = torch.prim.ListConstruct %int2, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
%result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %0, %arg1, %arg2, %float5.000000e-01 : !torch.vtensor<[5,2,2,3],f32>, !torch.list<int>, !torch.vtensor<[2,2,3],f32>, !torch.vtensor<[2,2,3],f32>, !torch.float -> !torch.vtensor<[5,2,2,3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>
|
||||
|
@ -752,7 +752,7 @@ func.func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
|
|||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
|
||||
%float0.000000e00 = torch.constant.float 0.000000e+00
|
||||
%float0.000000e00 = torch.constant.float 0.000000e+00
|
||||
%false = torch.constant.bool false
|
||||
%0 = torch.aten.dropout %arg0, %float0.000000e00, %false : !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
|
@ -808,7 +808,7 @@ func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> )
|
|||
// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = [3, 2, 1]} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
||||
// CEHCK: return %[[VAL_6]] : tensor<3x2x1xf32>
|
||||
// CHECK: return %[[VAL_6]] : tensor<3x2x1xf32>
|
||||
func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
|
||||
%0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||
%true = torch.constant.bool true
|
||||
|
@ -817,3 +817,14 @@ func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf3
|
|||
%1 = torch_c.to_builtin_tensor %values : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
||||
return %1 : tensor<3x2x1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @torch.vtensor.literal_si64$basic(
|
||||
// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<-1> : tensor<1x512xsi64>} : () -> tensor<1x512xi64>
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<1x512xi64> -> !torch.vtensor<[1,512],si64>
|
||||
// CHECK: return %[[VAL_1]] : !torch.vtensor<[1,512],si64>
|
||||
func.func @torch.vtensor.literal_si64$basic() -> !torch.vtensor<[1,512],si64> {
|
||||
%0 = torch.vtensor.literal(dense<-1> : tensor<1x512xsi64>) : !torch.vtensor<[1,512],si64>
|
||||
return %0 : !torch.vtensor<[1,512],si64>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue