diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 740eaadb9..29cadf7d1 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -706,6 +706,72 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, source, min, max); return success(); }); + patterns.onOp( + "Compress", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand, conditionTensor; + int64_t axis; + if (binder.tensorOperands(operand, conditionTensor) || + binder.s64IntegerAttr(axis, "axis", INT64_MAX) || + binder.tensorResultType(resultType)) + return failure(); + + // get indexs from the condition tensor + auto dtype = dyn_cast(conditionTensor.getType()) + .getDtype(); + auto constOp = dyn_cast( + conditionTensor.getDefiningOp()); + auto elementsAttr = + dyn_cast(constOp.getValueAttr()); + SmallVector apValues; + int64_t index = 0; + for (auto intAttr : elementsAttr.getValues()) { + int64_t i = dyn_cast(intAttr).getSInt(); + if (i) + apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), index)); + index++; + } + SmallVector indexShape = {static_cast(apValues.size())}; + auto indexType = Torch::ValueTensorType::get(binder.op->getContext(), + indexShape, dtype); + auto attr = DenseElementsAttr::get( + cast(RankedTensorType::get(indexShape, dtype)), + apValues); + Value indexVal = + rewriter.replaceOpWithNewOp( + constOp, indexType, attr); + + auto shapeSizes = + dyn_cast(operand.getType()).getSizes(); + if (axis == INT64_MAX) { + // flatten input tensor if using default axis + Value cstZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstNegOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(-1)); + int64_t numElements = 1; + for (auto i : shapeSizes) { + numElements *= i; + } + SmallVector flattenShape = {numElements}; + auto flattenType = Torch::ValueTensorType::get( + binder.op->getContext(), flattenShape, resultType.getDtype()); + Value flattenTensor = rewriter.create( + binder.getLoc(), flattenType, operand, cstZero, cstNegOne); + rewriter.replaceOpWithNewOp( + binder.op, resultType, flattenTensor, cstZero, indexVal); + return success(); + } else { + if (axis < 0) + // Negative axis value means counting dimensions from the back + axis += shapeSizes.size(); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(axis)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, dim, indexVal); + } + return success(); + }); patterns.onOp( "Concat", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 8cd8bab00..d11c8cf6e 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1646,22 +1646,62 @@ func.func @test_constant_of_shape_dense_int_cst() -> !torch.vtensor<[2,3,4], si6 return %0 : !torch.vtensor<[2,3,4], si64> } +// ----- + // CHECK-LABEL: func.func @test_celu func.func @test_celu(%arg0: !torch.vtensor<[3,3,3,1],f32>) -> !torch.vtensor<[3,3,3,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { -// CHECK: %[[ALPHA:.*]] = torch.constant.float 2.000000e+00 -// CHECK: %0 = torch.aten.div.Scalar %arg0, %[[ALPHA]] : !torch.vtensor<[3,3,3,1],f32>, !torch.float -> !torch.vtensor<[3,3,3,1],f32> -// CHECK: %1 = torch.aten.exp %0 : !torch.vtensor<[3,3,3,1],f32> -> !torch.vtensor<[3,3,3,1],f32> -// CHECK: %int1 = torch.constant.int 1 -// CHECK: %2 = torch.aten.sub.Scalar %1, %int1, %int1 : !torch.vtensor<[3,3,3,1],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,3,3,1],f32> -// CHECK: %3 = torch.aten.mul.Scalar %2, %[[ALPHA]] : !torch.vtensor<[3,3,3,1],f32>, !torch.float -> !torch.vtensor<[3,3,3,1],f32> -// CHECK: %int0 = torch.constant.int 0 -// CHECK: %4 = torch.prim.ListConstruct : () -> !torch.list -// CHECK: %none = torch.constant.none -// CHECK: %int6 = torch.constant.int 6 -// CHECK: %[[ZERO:.*]] = torch.aten.full %4, %int0, %int6, %none, %none, %none : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> -// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[ZERO]], %3 : !torch.vtensor<[],f32>, !torch.vtensor<[3,3,3,1],f32> -> !torch.vtensor<[3,3,3,1],f32> -// CHECK: %[[MAX:.*]] = torch.aten.maximum %[[ZERO]], %arg0 : !torch.vtensor<[],f32>, !torch.vtensor<[3,3,3,1],f32> -> !torch.vtensor<[3,3,3,1],f32> -// CHECK: %8 = torch.aten.add.Tensor %[[MAX]], %[[MIN]], %int1 : !torch.vtensor<[3,3,3,1],f32>, !torch.vtensor<[3,3,3,1],f32>, !torch.int -> !torch.vtensor<[3,3,3,1],f32> + // CHECK: %[[ALPHA:.*]] = torch.constant.float 2.000000e+00 + // CHECK: %0 = torch.aten.div.Scalar %arg0, %[[ALPHA]] : !torch.vtensor<[3,3,3,1],f32>, !torch.float -> !torch.vtensor<[3,3,3,1],f32> + // CHECK: %1 = torch.aten.exp %0 : !torch.vtensor<[3,3,3,1],f32> -> !torch.vtensor<[3,3,3,1],f32> + // CHECK: %int1 = torch.constant.int 1 + // CHECK: %2 = torch.aten.sub.Scalar %1, %int1, %int1 : !torch.vtensor<[3,3,3,1],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,3,3,1],f32> + // CHECK: %3 = torch.aten.mul.Scalar %2, %[[ALPHA]] : !torch.vtensor<[3,3,3,1],f32>, !torch.float -> !torch.vtensor<[3,3,3,1],f32> + // CHECK: %int0 = torch.constant.int 0 + // CHECK: %4 = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %none = torch.constant.none + // CHECK: %int6 = torch.constant.int 6 + // CHECK: %[[ZERO:.*]] = torch.aten.full %4, %int0, %int6, %none, %none, %none : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[MIN:.*]] = torch.aten.minimum %[[ZERO]], %3 : !torch.vtensor<[],f32>, !torch.vtensor<[3,3,3,1],f32> -> !torch.vtensor<[3,3,3,1],f32> + // CHECK: %[[MAX:.*]] = torch.aten.maximum %[[ZERO]], %arg0 : !torch.vtensor<[],f32>, !torch.vtensor<[3,3,3,1],f32> -> !torch.vtensor<[3,3,3,1],f32> + // CHECK: %8 = torch.aten.add.Tensor %[[MAX]], %[[MIN]], %int1 : !torch.vtensor<[3,3,3,1],f32>, !torch.vtensor<[3,3,3,1],f32>, !torch.int -> !torch.vtensor<[3,3,3,1],f32> %0 = torch.operator "onnx.Celu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32} : (!torch.vtensor<[3,3,3,1],f32>) -> !torch.vtensor<[3,3,3,1],f32> return %0 : !torch.vtensor<[3,3,3,1],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_compress +func.func @test_compress(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,2], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INDEX:.*]] = torch.vtensor.literal(dense<[1, 2]> : tensor<2xsi64>) : !torch.vtensor<[2],si64> + // CHECK: %[[DIM:.*]] = torch.constant.int 2 + // CHECK: %[[ATEN_INDEX_SELECT:.*]] = torch.aten.index_select %arg0, %[[DIM]], %[[INDEX]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[2,3,2],f32> + %cst = torch.vtensor.literal(dense<[0,1,1]> : tensor<3xsi64>) : !torch.vtensor<[3], si64> + %0 = torch.operator "onnx.Compress"(%arg0, %cst) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,2],f32> + return %0 : !torch.vtensor<[2,3,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_compress_default_axis +func.func @test_compress_default_axis(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INDEX:.*]] = torch.vtensor.literal(dense<[1, 3, 5]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[END_DIM:.*]] = torch.constant.int -1 + // CHECK: %[[ATEN_FLATTEN:.*]] = torch.aten.flatten.using_ints %arg0, %[[INT0]], %[[END_DIM]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[6],f32> + // CHECK: %[[ATEN_INDEX_SELECT:.*]] = torch.aten.index_select %[[ATEN_FLATTEN]], %[[INT0]], %[[INDEX]] : !torch.vtensor<[6],f32>, !torch.int, !torch.vtensor<[3],si64> -> !torch.vtensor<[3],f32> + %cst = torch.vtensor.literal(dense<[0,1,0,1,0,1]> : tensor<6xsi64>) : !torch.vtensor<[6], si64> + %0 = torch.operator "onnx.Compress"(%arg0, %cst) : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[6],si64>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_compress_neg_axis +func.func @test_compress_neg_axis(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,2,4], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INDEX:.*]] = torch.vtensor.literal(dense<[1, 2]> : tensor<2xsi64>) : !torch.vtensor<[2],si64> + // CHECK: %[[DIM:.*]] = torch.constant.int 1 + // CHECK: %[[ATEN_INDEX_SELECT:.*]] = torch.aten.index_select %arg0, %[[DIM]], %[[INDEX]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[2,2,4],f32> + %cst = torch.vtensor.literal(dense<[0,1,1]> : tensor<3xsi64>) : !torch.vtensor<[3], si64> + %0 = torch.operator "onnx.Compress"(%arg0, %cst) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,2,4],f32> + return %0 : !torch.vtensor<[2,2,4],f32> +}