From 5d7a6c2976278c9ec2fe3f2ddd3e02c5a65de897 Mon Sep 17 00:00:00 2001
From: Anup Gangwar
Date: Fri, 25 Mar 2022 16:15:07 -0500
Subject: [PATCH] [tosa] Support for Aten[Unsqueeze|Contiguous|Dropout|Reshape|View] ops (#700)

---
 e2e_testing/torchscript/xfail_sets.py      |   8 ++
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 114 +++++++++++++++++++++
 test/Conversion/TorchToTosa/basic.mlir     |  56 +++++++++-
 3 files changed, 175 insertions(+), 3 deletions(-)

diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index 72fe66470..a6895e069 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -144,4 +144,12 @@ TOSA_PASS_SET = {
     "SiluModule_basic",
     "DropoutEvalIntModule_basic",
     "DropoutEvalFloatModule_basic",
+    "ContiguousModule_basic",
+    "DropoutModule_basic",
+    "ViewExpandModule_basic",
+    "ViewCollapseInferredDimModule_basic",
+    "ViewExpandInferredDimModule_basic",
+    "ViewNoChangeStaticModule_basic",
+    "UnsafeViewExpandModule_basic",
+    "ReshapeCollapseModule_basic",
 }
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index fe7d155b7..197b3ae45 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -2329,6 +2329,116 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
+template <>
+LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
+    AtenUnsqueezeOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a tensor type.
+  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
+  if (!selfType) {
+    return op.emitError("Only tensor types are currently supported");
+  }
+
+  auto selfRank = selfType.getRank();
+  auto selfElemTy = selfType.getElementType();
+  if (!selfElemTy.isIntOrFloat()) {
+    return op.emitError(
+        "Only floating-point or integer datatype legalization supported");
+  }
+
+  int64_t dim;
+  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
+    return op->emitError("dim must be a Scalar constant");
+
+  dim = toPositiveDim(dim, selfRank);
+  if (!isValidDim(dim, selfRank))
+    return op.emitError("dim is statically invalid");
+
+  SmallVector<int64_t> outShape;
+  for (auto en : llvm::enumerate(selfType.getShape())) {
+    if (static_cast<int64_t>(en.index()) == dim)
+      outShape.push_back(1);
+
+    outShape.push_back(en.value());
+  }
+
+  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+      op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
+      rewriter.getI64ArrayAttr(outShape));
+
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
+    AtenContiguousOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a tensor type.
+  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
+  if (!selfType)
+    return op.emitError("Only tensor types are currently supported");
+
+  // FIXME: memory_format is not handled.
+
+  rewriter.replaceOp(op, adaptor.self());
+
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<AtenDropoutOp>::matchAndRewrite(
+    AtenDropoutOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a tensor type.
+  auto selfType = adaptor.input().getType().dyn_cast<TensorType>();
+  if (!selfType)
+    return op.emitError("Only tensor types are currently supported");
+
+  // FIXME: train and p are not handled.
+
+  bool train;
+  if (!matchPattern(op.train(), m_TorchConstantBool(&train)))
+    return op.emitError("train must be a Scalar constant");
+
+  if (train)
+    return op.emitError("train must be false");
+
+  rewriter.replaceOpWithNewOp<tosa::CastOp>(
+      op, getTypeConverter()->convertType(op.getType()), adaptor.input());
+
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite(
+    AtenViewOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a tensor type.
+  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
+  if (!selfType)
+    return op.emitError("Only tensor types are currently supported");
+
+  auto selfElemTy = selfType.getElementType();
+  if (!selfElemTy.isIntOrFloat()) {
+    return op.emitError(
+        "Only floating-point or integer datatype legalization supported");
+  }
+
+  SmallVector<int64_t> outShape;
+  if (!matchPattern(op.size(), m_TorchConstantIntList(outShape)))
+    return op.emitError("size must consist of Scalar constants");
+
+  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+      op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
+      rewriter.getI64ArrayAttr(outShape));
+
+  return success();
+}
+
 template <typename AtenOpT, typename TosaOpT>
 class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
 public:
@@ -2879,6 +2989,10 @@ public:
     INSERT_ATENOP_PATTERN(AtenPermuteOp);
     INSERT_ATENOP_PATTERN(AtenLog2Op);
     INSERT_ATENOP_PATTERN(AtenThresholdOp);
+    INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
+    INSERT_ATENOP_PATTERN(AtenContiguousOp);
+    INSERT_ATENOP_PATTERN(AtenDropoutOp);
+    INSERT_ATENOP_PATTERN(AtenViewOp);
 #undef INSERT_ATENOP_PATTERN
 
     if (failed(applyPartialConversion(getOperation(), target,
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index cc80ae7ad..5ffc5cc91 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -492,8 +492,8 @@ func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.v
 
 // -----
 
-// CHECK-LABEL: func @forward(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,4,3],f32>) -> !torch.vtensor<[10,4,3],f32> {
+// CHECK-LABEL: func @torch.aten.native_batch_norm$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,4,3],f32>) -> !torch.vtensor<[10,4,3],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,4,3],f32> -> tensor<10x4x3xf32>
 // CHECK: %[[VAL_2:.*]] = "tosa.const"() {value = dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>} : () -> tensor<4xf32>
 // CHECK: %[[VAL_3:.*]] = "tosa.const"() {value = dense<[3.000000e+00, 2.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<4xf32>} : () -> tensor<4xf32>
@@ -515,7 +515,7 @@ func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.v
 // CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<10x4x3xf32> -> !torch.vtensor<[10,4,3],f32>
 // CHECK: return %[[VAL_19]] : !torch.vtensor<[10,4,3],f32>
 // CHECK: }
-func @forward(%arg0: !torch.vtensor<[10,4,3],f32> ) -> !torch.vtensor<[10,4,3],f32> {
+func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32> ) -> !torch.vtensor<[10,4,3],f32> {
   %0 = torch.vtensor.literal(dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
   %1 = torch.vtensor.literal(dense<[3.000000e+00, 2.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
   %float1.000000e-01 = torch.constant.float 1.000000e-01
@@ -689,6 +689,38 @@ func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
 
 // -----
 
+// CHECK-LABEL: func @torch.aten.unsqueeze$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,1,3],si32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [4, 1, 3]} : (tensor<4x3xi32>) -> tensor<4x1x3xi32>
+// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x1x3xi32> -> !torch.vtensor<[4,1,3],si32>
+// CHECK: return %[[VAL_4]] : !torch.vtensor<[4,1,3],si32>
+// CHECK: }
+
+func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !torch.vtensor<[4,1,3],si32> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.unsqueeze %arg0, %int1 : !torch.vtensor<[4,3],si32>, !torch.int -> !torch.vtensor<[4,1,3],si32>
+  return %0 : !torch.vtensor<[4,1,3],si32>
+}
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.contiguous$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
+// CHECK: }
+func @torch.aten.contiguous$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.contiguous %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
 // CHECK: %[[VAL_0:.*]] = torch.constant.int 4
 // CHECK: %[[VAL_1:.*]] = torch.constant.int 3
@@ -707,3 +739,21 @@ func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
   %1 = torch.aten.ones %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],f32>
   return %1 : !torch.vtensor<[3,4],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.dropout$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.float 0.000000e+00
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
+// CHECK: }
+func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
+  %float0.000000e00 = torch.constant.float 0.000000e+00
+  %false = torch.constant.bool false
+  %0 = torch.aten.dropout %arg0, %float0.000000e00, %false : !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
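
Note (illustration only, not part of the commit above): basic.mlir gains no dedicated test for the new AtenViewOp lowering. Since that pattern rewrites torch.aten.view to a single tosa::ReshapeOp whenever the size list is built from constant ints, a test in the style of torch.aten.unsqueeze$basic above would be expected to look roughly like the following sketch; the function name and concrete shapes are invented for illustration, and the list type syntax follows the torch.aten.ones$basic test.

// CHECK-LABEL: func @torch.aten.view$basic(
// CHECK: %{{.*}} = "tosa.reshape"(%{{.*}}) {new_shape = [3, 4]} : (tensor<2x6xf32>) -> tensor<3x4xf32>
func @torch.aten.view$basic(%arg0: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,4],f32> {
  %int3 = torch.constant.int 3
  %int4 = torch.constant.int 4
  %0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[2,6],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32>
  return %1 : !torch.vtensor<[3,4],f32>
}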