mirror of https://github.com/llvm/torch-mlir
Add OnnxToTorch support for Compress op (#3025)
parent 90e3d69c25
commit 9cf6c45a39
@@ -706,6 +706,72 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.op, resultType, source, min, max);
         return success();
       });
+  patterns.onOp(
+      "Compress", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value operand, conditionTensor;
+        int64_t axis;
+        if (binder.tensorOperands(operand, conditionTensor) ||
+            binder.s64IntegerAttr(axis, "axis", INT64_MAX) ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        // get indices from the condition tensor
+        auto dtype = dyn_cast<Torch::ValueTensorType>(conditionTensor.getType())
+                         .getDtype();
+        auto constOp = dyn_cast<Torch::ValueTensorLiteralOp>(
+            conditionTensor.getDefiningOp());
+        auto elementsAttr =
+            dyn_cast<DenseIntElementsAttr>(constOp.getValueAttr());
+        SmallVector<APInt> apValues;
+        int64_t index = 0;
+        for (auto intAttr : elementsAttr.getValues<Attribute>()) {
+          int64_t i = dyn_cast<IntegerAttr>(intAttr).getSInt();
+          if (i)
+            apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), index));
+          index++;
+        }
+        SmallVector<int64_t> indexShape = {static_cast<long>(apValues.size())};
+        auto indexType = Torch::ValueTensorType::get(binder.op->getContext(),
+                                                     indexShape, dtype);
+        auto attr = DenseElementsAttr::get(
+            cast<ShapedType>(RankedTensorType::get(indexShape, dtype)),
+            apValues);
+        Value indexVal =
+            rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
+                constOp, indexType, attr);
+
+        auto shapeSizes =
+            dyn_cast<Torch::ValueTensorType>(operand.getType()).getSizes();
+        if (axis == INT64_MAX) {
+          // flatten input tensor if using default axis
+          Value cstZero = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(0));
+          Value cstNegOne = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(-1));
+          int64_t numElements = 1;
+          for (auto i : shapeSizes) {
+            numElements *= i;
+          }
+          SmallVector<int64_t> flattenShape = {numElements};
+          auto flattenType = Torch::ValueTensorType::get(
+              binder.op->getContext(), flattenShape, resultType.getDtype());
+          Value flattenTensor = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
+              binder.getLoc(), flattenType, operand, cstZero, cstNegOne);
+          rewriter.replaceOpWithNewOp<Torch::AtenIndexSelectOp>(
+              binder.op, resultType, flattenTensor, cstZero, indexVal);
+          return success();
+        } else {
+          if (axis < 0)
+            // Negative axis value means counting dimensions from the back
+            axis += shapeSizes.size();
+          Value dim = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(axis));
+          rewriter.replaceOpWithNewOp<Torch::AtenIndexSelectOp>(
+              binder.op, resultType, operand, dim, indexVal);
+        }
+        return success();
+      });
   patterns.onOp(
       "Concat", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
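The core of the lowering above is the mask-to-indices step: positions where the condition tensor is nonzero become the index list fed to torch.aten.index_select. A minimal standalone C++ sketch of that step, using only the standard library (maskToIndices is a hypothetical helper name, not a torch-mlir API):

#include <cstdint>
#include <iostream>
#include <vector>

// Positions with a nonzero condition value become gather indices,
// mirroring the apValues loop in the pattern above.
static std::vector<int64_t> maskToIndices(const std::vector<int64_t> &cond) {
  std::vector<int64_t> indices;
  for (int64_t i = 0; i < static_cast<int64_t>(cond.size()); ++i)
    if (cond[i] != 0)
      indices.push_back(i);
  return indices;
}

int main() {
  // Condition [0, 1, 1] keeps positions 1 and 2, matching the
  // dense<[1, 2]> index literal checked in @test_compress below.
  for (int64_t idx : maskToIndices({0, 1, 1}))
    std::cout << idx << ' '; // prints: 1 2
}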
@@ -1646,22 +1646,62 @@ func.func @test_constant_of_shape_dense_int_cst() -> !torch.vtensor<[2,3,4], si64>
   return %0 : !torch.vtensor<[2,3,4], si64>
 }
 
 // -----
 
 // CHECK-LABEL: func.func @test_celu
 func.func @test_celu(%arg0: !torch.vtensor<[3,3,3,1],f32>) -> !torch.vtensor<[3,3,3,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[ALPHA:.*]] = torch.constant.float 2.000000e+00
   // CHECK: %0 = torch.aten.div.Scalar %arg0, %[[ALPHA]] : !torch.vtensor<[3,3,3,1],f32>, !torch.float -> !torch.vtensor<[3,3,3,1],f32>
   // CHECK: %1 = torch.aten.exp %0 : !torch.vtensor<[3,3,3,1],f32> -> !torch.vtensor<[3,3,3,1],f32>
   // CHECK: %int1 = torch.constant.int 1
   // CHECK: %2 = torch.aten.sub.Scalar %1, %int1, %int1 : !torch.vtensor<[3,3,3,1],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,3,3,1],f32>
   // CHECK: %3 = torch.aten.mul.Scalar %2, %[[ALPHA]] : !torch.vtensor<[3,3,3,1],f32>, !torch.float -> !torch.vtensor<[3,3,3,1],f32>
   // CHECK: %int0 = torch.constant.int 0
   // CHECK: %4 = torch.prim.ListConstruct : () -> !torch.list<int>
   // CHECK: %none = torch.constant.none
   // CHECK: %int6 = torch.constant.int 6
   // CHECK: %[[ZERO:.*]] = torch.aten.full %4, %int0, %int6, %none, %none, %none : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
   // CHECK: %[[MIN:.*]] = torch.aten.minimum %[[ZERO]], %3 : !torch.vtensor<[],f32>, !torch.vtensor<[3,3,3,1],f32> -> !torch.vtensor<[3,3,3,1],f32>
   // CHECK: %[[MAX:.*]] = torch.aten.maximum %[[ZERO]], %arg0 : !torch.vtensor<[],f32>, !torch.vtensor<[3,3,3,1],f32> -> !torch.vtensor<[3,3,3,1],f32>
   // CHECK: %8 = torch.aten.add.Tensor %[[MAX]], %[[MIN]], %int1 : !torch.vtensor<[3,3,3,1],f32>, !torch.vtensor<[3,3,3,1],f32>, !torch.int -> !torch.vtensor<[3,3,3,1],f32>
   %0 = torch.operator "onnx.Celu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32} : (!torch.vtensor<[3,3,3,1],f32>) -> !torch.vtensor<[3,3,3,1],f32>
   return %0 : !torch.vtensor<[3,3,3,1],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_compress
+func.func @test_compress(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,2], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INDEX:.*]] = torch.vtensor.literal(dense<[1, 2]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
+  // CHECK: %[[DIM:.*]] = torch.constant.int 2
+  // CHECK: %[[ATEN_INDEX_SELECT:.*]] = torch.aten.index_select %arg0, %[[DIM]], %[[INDEX]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[2,3,2],f32>
+  %cst = torch.vtensor.literal(dense<[0,1,1]> : tensor<3xsi64>) : !torch.vtensor<[3], si64>
+  %0 = torch.operator "onnx.Compress"(%arg0, %cst) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,2],f32>
+  return %0 : !torch.vtensor<[2,3,2],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_compress_default_axis
+func.func @test_compress_default_axis(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INDEX:.*]] = torch.vtensor.literal(dense<[1, 3, 5]> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
+  // CHECK: %[[INT0:.*]] = torch.constant.int 0
+  // CHECK: %[[END_DIM:.*]] = torch.constant.int -1
+  // CHECK: %[[ATEN_FLATTEN:.*]] = torch.aten.flatten.using_ints %arg0, %[[INT0]], %[[END_DIM]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[6],f32>
+  // CHECK: %[[ATEN_INDEX_SELECT:.*]] = torch.aten.index_select %[[ATEN_FLATTEN]], %[[INT0]], %[[INDEX]] : !torch.vtensor<[6],f32>, !torch.int, !torch.vtensor<[3],si64> -> !torch.vtensor<[3],f32>
+  %cst = torch.vtensor.literal(dense<[0,1,0,1,0,1]> : tensor<6xsi64>) : !torch.vtensor<[6], si64>
+  %0 = torch.operator "onnx.Compress"(%arg0, %cst) : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[6],si64>) -> !torch.vtensor<[3],f32>
+  return %0 : !torch.vtensor<[3],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_compress_neg_axis
+func.func @test_compress_neg_axis(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,2,4], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INDEX:.*]] = torch.vtensor.literal(dense<[1, 2]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
+  // CHECK: %[[DIM:.*]] = torch.constant.int 1
+  // CHECK: %[[ATEN_INDEX_SELECT:.*]] = torch.aten.index_select %arg0, %[[DIM]], %[[INDEX]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[2,2,4],f32>
+  %cst = torch.vtensor.literal(dense<[0,1,1]> : tensor<3xsi64>) : !torch.vtensor<[3], si64>
+  %0 = torch.operator "onnx.Compress"(%arg0, %cst) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,2,4],f32>
+  return %0 : !torch.vtensor<[2,2,4],f32>
+}
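For the default-axis case exercised by @test_compress_default_axis above: with no axis attribute, the pattern flattens the input to 1-D before indexing, so the condition [0,1,0,1,0,1] over a row-major [2,3] tensor keeps flat positions 1, 3, and 5. A small self-contained C++ sketch of that semantics (illustrative only, assuming row-major element order):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // A [2,3] tensor laid out row-major, compressed along the default
  // (flattened) axis.
  std::vector<float> flat = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f};
  std::vector<int64_t> cond = {0, 1, 0, 1, 0, 1};
  for (size_t i = 0; i < cond.size(); ++i)
    if (cond[i] != 0)
      std::cout << flat[i] << ' '; // prints: 1 3 5
}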
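The negative-axis test relies on the rank-relative normalization performed in the pattern's else branch (axis += rank for axis < 0). A one-function sketch (normalizeAxis is a hypothetical name used here for illustration):

#include <cassert>
#include <cstdint>

// Negative axis values count dimensions from the back.
static int64_t normalizeAxis(int64_t axis, int64_t rank) {
  return axis < 0 ? axis + rank : axis;
}

int main() {
  // axis = -2 on the rank-3 input of @test_compress_neg_axis resolves to
  // dimension 1, matching "torch.constant.int 1" in that test's CHECK lines.
  assert(normalizeAxis(-2, 3) == 1);
}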