Generalize getting index for onnx compress op (#3150)

pull/3158/head
jinchen 2024-04-12 15:18:22 -07:00 committed by GitHub
parent ed163f49e8
commit 859f5d280f
2 changed files with 37 additions and 44 deletions
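Previously this lowering only handled a constant condition tensor: it walked the Torch::ValueTensorLiteralOp feeding onnx.Compress and folded the positions of its nonzero entries into a new literal index tensor. The patch replaces that with a torch.aten.nonzero op on the condition value itself, so the condition no longer has to be a compile-time constant. As a plain-C++ sketch of the semantics being lowered (illustrative only; nonzeroIndices and compressFlat are not names from this patch):

// Illustrative plain-C++ model of onnx.Compress semantics (not code from
// this patch): keep the entries of the input where the condition is nonzero.
#include <cstdint>
#include <vector>

// Indices of the nonzero entries of the condition tensor -- the values the
// lowering now computes at runtime with torch.aten.nonzero instead of
// folding from a constant.
std::vector<int64_t> nonzeroIndices(const std::vector<int64_t> &condition) {
  std::vector<int64_t> indices;
  for (int64_t i = 0; i < static_cast<int64_t>(condition.size()); ++i)
    if (condition[i])
      indices.push_back(i);
  return indices;
}

// Default-axis Compress: the input is flattened first, then the selected
// elements are gathered (flatten.using_ints + index_select in the lowering).
std::vector<float> compressFlat(const std::vector<float> &flatInput,
                                const std::vector<int64_t> &condition) {
  std::vector<float> out;
  for (int64_t i : nonzeroIndices(condition))
    out.push_back(flatInput[i]);
  return out;
}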

lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

@@ -721,35 +721,20 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.tensorResultType(resultType))
           return failure();
-        // get indexs from the condition tensor
-        auto dtype = dyn_cast<Torch::ValueTensorType>(conditionTensor.getType())
-                         .getDtype();
-        auto constOp = dyn_cast<Torch::ValueTensorLiteralOp>(
-            conditionTensor.getDefiningOp());
-        auto elementsAttr =
-            dyn_cast<DenseIntElementsAttr>(constOp.getValueAttr());
-        SmallVector<APInt> apValues;
-        int64_t index = 0;
-        for (auto intAttr : elementsAttr.getValues<Attribute>()) {
-          int64_t i = dyn_cast<IntegerAttr>(intAttr).getSInt();
-          if (i)
-            apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), index));
-          index++;
-        }
-        SmallVector<int64_t> indexShape = {static_cast<long>(apValues.size())};
-        auto indexType = Torch::ValueTensorType::get(binder.op->getContext(),
-                                                     indexShape, dtype);
-        auto attr = DenseElementsAttr::get(
-            cast<ShapedType>(RankedTensorType::get(indexShape, dtype)),
-            apValues);
-        Value indexVal =
-            rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
-                constOp, indexType, attr);
         auto shapeSizes =
             dyn_cast<Torch::ValueTensorType>(operand.getType()).getSizes();
+        auto resultSizes = resultType.getSizes();
+        // flatten input tensor if using default axis
         if (axis == INT64_MAX) {
-          // flatten input tensor if using default axis
+          SmallVector<int64_t> nonzeroShape = {resultSizes[0]};
+          auto dtype =
+              dyn_cast<Torch::ValueTensorType>(conditionTensor.getType())
+                  .getDtype();
+          auto nonzeroType =
+              rewriter.getType<Torch::ValueTensorType>(nonzeroShape, dtype);
+          Value indexVal = rewriter.create<Torch::AtenNonzeroOp>(
+              binder.getLoc(), nonzeroType, conditionTensor);
           Value cstZero = rewriter.create<Torch::ConstantIntOp>(
               binder.getLoc(), rewriter.getI64IntegerAttr(0));
           Value cstNegOne = rewriter.create<Torch::ConstantIntOp>(
@@ -759,22 +744,29 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             numElements *= i;
           }
           SmallVector<int64_t> flattenShape = {numElements};
-          auto flattenType = Torch::ValueTensorType::get(
-              binder.op->getContext(), flattenShape, resultType.getDtype());
+          auto flattenType = rewriter.getType<Torch::ValueTensorType>(
+              flattenShape, resultType.getDtype());
           Value flattenTensor = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
               binder.getLoc(), flattenType, operand, cstZero, cstNegOne);
           rewriter.replaceOpWithNewOp<Torch::AtenIndexSelectOp>(
               binder.op, resultType, flattenTensor, cstZero, indexVal);
           return success();
-        } else {
-          if (axis < 0)
-            // Negative axis value means counting dimensions from the back
-            axis += shapeSizes.size();
-          Value dim = rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(axis));
-          rewriter.replaceOpWithNewOp<Torch::AtenIndexSelectOp>(
-              binder.op, resultType, operand, dim, indexVal);
         }
+        // Negative axis value means counting dimensions from the back
+        if (axis < 0)
+          axis += shapeSizes.size();
+        SmallVector<int64_t> nonzeroShape = {resultSizes[axis]};
+        auto dtype = dyn_cast<Torch::ValueTensorType>(conditionTensor.getType())
+                         .getDtype();
+        auto nonzeroType =
+            rewriter.getType<Torch::ValueTensorType>(nonzeroShape, dtype);
+        Value indexVal = rewriter.create<Torch::AtenNonzeroOp>(
+            binder.getLoc(), nonzeroType, conditionTensor);
+        Value dimVal = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(axis));
+        rewriter.replaceOpWithNewOp<Torch::AtenIndexSelectOp>(
+            binder.op, resultType, operand, dimVal, indexVal);
+        return success();
       });
   patterns.onOp(

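For reference, torch.aten.index_select, which both branches of the lowering emit, gathers slices along one dimension. Below is a minimal sketch of its semantics on a row-major flat buffer, with a usage example matching @test_compress in the test diff that follows (indexSelect and the concrete values are illustrative, not part of the patch):

// Illustrative only: what torch.aten.index_select computes for a row-major
// tensor stored in a flat buffer. The lowering emits this op along `axis`
// with the indices produced by torch.aten.nonzero.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<float> indexSelect(const std::vector<float> &input,
                               const std::vector<int64_t> &shape, int64_t axis,
                               const std::vector<int64_t> &indices) {
  // Negative axis counts dimensions from the back, as in the lowering.
  if (axis < 0)
    axis += shape.size();
  int64_t outer = 1, inner = 1;
  for (int64_t d = 0; d < axis; ++d)
    outer *= shape[d];
  for (int64_t d = axis + 1; d < static_cast<int64_t>(shape.size()); ++d)
    inner *= shape[d];
  int64_t dim = shape[axis];
  std::vector<float> out;
  out.reserve(outer * indices.size() * inner);
  for (int64_t o = 0; o < outer; ++o)
    for (int64_t i : indices)
      for (int64_t k = 0; k < inner; ++k)
        out.push_back(input[(o * dim + i) * inner + k]);
  return out;
}

int main() {
  // Mirrors @test_compress below: a [2,3,4] input, axis = 2, and a condition
  // of [0,1,1] whose nonzero indices are {1, 2}; the result shape is [2,3,2].
  std::vector<float> input(24);
  for (int i = 0; i < 24; ++i)
    input[i] = i;
  std::vector<float> out = indexSelect(input, {2, 3, 4}, /*axis=*/2, {1, 2});
  assert(out.size() == 2 * 3 * 2);
  assert(out[0] == 1 && out[1] == 2 && out[2] == 5);
  return 0;
}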
test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir

@@ -1709,12 +1709,11 @@ func.func @test_celu(%arg0: !torch.vtensor<[3,3,3,1],f32>) -> !torch.vtensor<[3,
 // -----
 // CHECK-LABEL: func.func @test_compress
-func.func @test_compress(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,2], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-// CHECK: %[[INDEX:.*]] = torch.vtensor.literal(dense<[1, 2]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
+func.func @test_compress(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3], si64>) -> !torch.vtensor<[2,3,2], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK: %[[INDEX:.*]] = torch.aten.nonzero %arg1 : !torch.vtensor<[3],si64> -> !torch.vtensor<[2],si64>
 // CHECK: %[[DIM:.*]] = torch.constant.int 2
-// CHECK: %[[ATEN_INDEX_SELECT:.*]] = torch.aten.index_select %arg0, %[[DIM]], %[[INDEX]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[2,3,2],f32>
-  %cst = torch.vtensor.literal(dense<[0,1,1]> : tensor<3xsi64>) : !torch.vtensor<[3], si64>
-  %0 = torch.operator "onnx.Compress"(%arg0, %cst) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,2],f32>
+// CHECK: torch.aten.index_select %arg0, %[[DIM]], %[[INDEX]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[2,3,2],f32>
+  %0 = torch.operator "onnx.Compress"(%arg0, %arg1) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,2],f32>
   return %0 : !torch.vtensor<[2,3,2],f32>
 }
@@ -1722,11 +1721,12 @@ func.func @test_compress(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[
 // CHECK-LABEL: func.func @test_compress_default_axis
 func.func @test_compress_default_axis(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-// CHECK: %[[INDEX:.*]] = torch.vtensor.literal(dense<[1, 3, 5]> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
+// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<[0, 1, 0, 1, 0, 1]> : tensor<6xsi64>) : !torch.vtensor<[6],si64>
+// CHECK: %[[INDEX:.*]] = torch.aten.nonzero %[[CST]] : !torch.vtensor<[6],si64> -> !torch.vtensor<[3],si64>
 // CHECK: %[[INT0:.*]] = torch.constant.int 0
 // CHECK: %[[END_DIM:.*]] = torch.constant.int -1
 // CHECK: %[[ATEN_FLATTEN:.*]] = torch.aten.flatten.using_ints %arg0, %[[INT0]], %[[END_DIM]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[6],f32>
-// CHECK: %[[ATEN_INDEX_SELECT:.*]] = torch.aten.index_select %[[ATEN_FLATTEN]], %[[INT0]], %[[INDEX]] : !torch.vtensor<[6],f32>, !torch.int, !torch.vtensor<[3],si64> -> !torch.vtensor<[3],f32>
+// CHECK: torch.aten.index_select %[[ATEN_FLATTEN]], %[[INT0]], %[[INDEX]] : !torch.vtensor<[6],f32>, !torch.int, !torch.vtensor<[3],si64> -> !torch.vtensor<[3],f32>
   %cst = torch.vtensor.literal(dense<[0,1,0,1,0,1]> : tensor<6xsi64>) : !torch.vtensor<[6], si64>
   %0 = torch.operator "onnx.Compress"(%arg0, %cst) : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[6],si64>) -> !torch.vtensor<[3],f32>
   return %0 : !torch.vtensor<[3],f32>
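Concretely, in the default-axis test above the [2,3] input is flattened to 6 elements, and the condition [0,1,0,1,0,1] keeps flat indices 1, 3, and 5, exactly the indices the nonzero CHECK line captures. A self-contained worked example in plain C++ (values illustrative):

// Worked example for the default-axis case above: a [2,3] input flattened
// to 6 elements; condition [0,1,0,1,0,1] keeps flat indices 1, 3, 5.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<float> flat = {10, 11, 12, 13, 14, 15}; // flattened [2,3]
  std::vector<int64_t> cond = {0, 1, 0, 1, 0, 1};
  std::vector<float> out;
  for (size_t i = 0; i < cond.size(); ++i)
    if (cond[i])
      out.push_back(flat[i]); // keeps flat indices 1, 3, 5
  assert((out == std::vector<float>{11, 13, 15}));
  return 0;
}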
@@ -1736,7 +1736,8 @@ func.func @test_compress_default_axis(%arg0: !torch.vtensor<[2,3],f32>) -> !torc
 // CHECK-LABEL: func.func @test_compress_neg_axis
 func.func @test_compress_neg_axis(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,2,4], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-// CHECK: %[[INDEX:.*]] = torch.vtensor.literal(dense<[1, 2]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
+// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<[0, 1, 1]> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
+// CHECK: %[[INDEX:.*]] = torch.aten.nonzero %[[CST]] : !torch.vtensor<[3],si64> -> !torch.vtensor<[2],si64>
 // CHECK: %[[DIM:.*]] = torch.constant.int 1
 // CHECK: %[[ATEN_INDEX_SELECT:.*]] = torch.aten.index_select %arg0, %[[DIM]], %[[INDEX]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[2,2,4],f32>
   %cst = torch.vtensor.literal(dense<[0,1,1]> : tensor<3xsi64>) : !torch.vtensor<[3], si64>
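This last test exercises the axis normalization in the tail of the lowering. Its operator line falls outside the hunk, but with a rank-3 input a negative axis of -2 would normalize to dimension -2 + 3 = 1, matching the torch.constant.int 1 CHECK above. A one-line sketch of that rule (normalizeAxis is an illustrative name, not from the patch):

#include <cassert>
#include <cstdint>

int64_t normalizeAxis(int64_t axis, int64_t rank) {
  return axis < 0 ? axis + rank : axis; // count from the back when negative
}

int main() {
  assert(normalizeAxis(-2, 3) == 1); // rank-3 input, axis -2 -> dim 1
  return 0;
}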