[MLIR][ONNX] Add OnnxToTorch support for MaxRoiPool Op (#3395)

This PR adds OnnxToTorch support for MaxRoiPool op
pull/3458/head
Surya Jasper 2024-06-12 22:16:14 -07:00 committed by GitHub
parent 9b76a2e3eb
commit de7f058a0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 296 additions and 0 deletions

View File

@ -604,6 +604,237 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
}
return rewriter.notifyMatchFailure(binder.op, "No rank is matched.");
});
patterns.onOp(
"MaxRoiPool", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
SmallVector<int64_t> pooledShape;
float spatialScale;
if (binder.s64IntegerArrayAttr(pooledShape, "pooled_shape", {}) ||
binder.f32FloatAttr(spatialScale, "spatial_scale", 1.0f)) {
return rewriter.notifyMatchFailure(binder.op,
"Attribute bind failure");
}
Torch::ValueTensorType resultTy;
Value input, rois;
if (binder.tensorOperands(input, rois) ||
binder.tensorResultType(resultTy)) {
return rewriter.notifyMatchFailure(binder.op,
"Operand or result type mismatch");
}
Value outputShapeList =
createConstantIntList(binder, rewriter, pooledShape);
Location loc = binder.getLoc();
auto inputTy = cast<Torch::ValueTensorType>(input.getType());
auto roisTy = cast<Torch::ValueTensorType>(rois.getType());
if (!inputTy || !inputTy.hasSizes())
return failure();
if (!roisTy || !roisTy.hasSizes())
return failure();
auto intTy = rewriter.getIntegerType(64, true);
auto floatTy = roisTy.getDtype();
auto torchIntTy = rewriter.getType<Torch::IntType>();
Value spatialScaleValue = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(spatialScale));
Value boolTrue = rewriter.create<Torch::ConstantBoolOp>(
loc, rewriter.getBoolAttr(true));
ArrayRef<int64_t> inputShape = inputTy.getSizes();
int64_t inputRank = inputShape.size();
if (inputRank < 4) {
return rewriter.notifyMatchFailure(
binder.op, "Rank of input tensor must be >= 4");
}
ArrayRef<int64_t> roisShape = roisTy.getSizes();
if (!roisTy.areAllSizesKnown() || roisShape.size() != 2 ||
roisShape[1] != 5) {
return rewriter.notifyMatchFailure(
binder.op, "Expected ROIs to be statically sized tensor of shape "
"(num_rois, 5)");
}
int64_t numRois = roisShape[0];
/* The implementation is based on the following algorithm:
MaxRoiPool <pooled_shape, spatial_scale>(
input : tensor<float>, rois : tensor<?x5xfloat>) => (output)
{
* Step 1: Extract ROI specification
- Each ROI is represented as [batch_id, x1, y1, x2, y2], where
range is inclusive of x1, y1, x2, and y2
- The range values are scaled by spatial_scale
BatchIdxsFloat = Select(rois, dim=1, index=0)
BatchIdxs = CastLong(BatchIdxsFloat)
RoiBBsFloat = Slice(rois, dim=1, start=1, end=5, stride=1)
RoiBBsScaledFloat = MulScalar(RoiBBsFloat, spatial_scale)
RoiBBsScaled = CastLong(RoiBBsScaledFloat)
* Step 2: Iteratively pool ROIs
pooledROIs = []
for (roiIdx = 0; roiIdx < len(rois); roiIdx++) {
* Step 2a: For each ROI, we extract batch_id, x1, y1, x2, & y2
RoiSpec = Select(RoiBBsScaled, 0, roiIdx) : tensor<4xint>
roiValues = []
for (specIdx = 0; specIdx < 5; specIdx++) {
if (specIdx == 0)
SpecTensor = Select(BatchIdxs, 1, roiIdx) : tensor<int>
else
SpecTensor = Select(RoiSpec, 0, specIdx-1) : tensor<int>
SpecValue = Item(specTensor) : torch.int
roiValues.push(SpecValue)
}
BatchIdx, X1, Y1, X2, Y2 = roiValues
* Step 2b: extract image from input and extract region
- X2 and Y2 are incremented by 1 to make range inclusive
- width and height dimension are calculated once outside of loop
but intuition is expressed more clearly below
image = Select(input, 0, BatchIdx)
widthDim = rank(image) - 1
heightDim = rank(image) - 2
imageExtractedY = Slice(image, heightDim, Y1, Y2 + 1, 1)
region = Slice(image, widthDim, X1, X2 + 1, 1)
* Step 2c: apply adaptive max pooling to pool region of interest
into final pooled size
pooledROI = AdaptiveMaxPool2d(region, pooled_shape)
pooledROIs.push(pooledROI)
}
* Step 3: Stack pooled regions and return final output
return output = Stack(pooledRois, dim=0)
}
*/
SmallVector<Value> constInts(6);
for (int i = 0; i <= 5; i++) {
constInts[i] = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
}
int64_t widthDim = inputRank - 2;
Value widthDimValue = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(widthDim));
int64_t heightDim = inputRank - 3;
Value heightDimValue = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(heightDim));
// extract indices of images within batch
auto batchIdxsShape = SmallVector<int64_t>{Torch::kUnknownSize};
auto batchIdxsFloatTy =
rewriter.getType<Torch::ValueTensorType>(batchIdxsShape, floatTy);
Value batchIdxsFloat = rewriter.create<Torch::AtenSelectIntOp>(
loc, batchIdxsFloatTy, rois, constInts[1], constInts[0]);
auto batchIdxsIntTy =
rewriter.getType<Torch::ValueTensorType>(batchIdxsShape, intTy);
Value batchIdxs = rewriter.create<Torch::Aten_CastLongOp>(
loc, batchIdxsIntTy, batchIdxsFloat, boolTrue);
// extract scaled ranges for regions of interest
auto roiBBsShape = SmallVector<int64_t>{Torch::kUnknownSize, 4};
auto roiBBsFloatTy =
rewriter.getType<Torch::ValueTensorType>(roiBBsShape, floatTy);
Value roiBBs = rewriter.create<Torch::AtenSliceTensorOp>(
loc, roiBBsFloatTy, rois, constInts[1], constInts[1], constInts[5],
constInts[1]);
Value roiBBsScaledFloat = rewriter.create<Torch::AtenMulScalarOp>(
loc, roiBBsFloatTy, roiBBs, spatialScaleValue);
auto roiBBsTy =
rewriter.getType<Torch::ValueTensorType>(roiBBsShape, intTy);
Value roiBBsScaled = rewriter.create<Torch::Aten_CastLongOp>(
loc, roiBBsTy, roiBBsScaledFloat, boolTrue);
SmallVector<Value> pooledRois;
for (int64_t i = 0; i < numRois; i++) {
Value roiIdx = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
auto roiSpecTy = rewriter.getType<Torch::ValueTensorType>(
roiBBsTy.getSizes().slice(1), intTy);
Value roiSpec = rewriter.create<Torch::AtenSelectIntOp>(
loc, roiSpecTy, roiBBsScaled, constInts[0], roiIdx);
// Load individual ROI specification values
SmallVector<Value> roiValues(5);
for (int specIdx = 0; specIdx < 5; specIdx++) {
auto intEmptyTensorTy = rewriter.getType<Torch::ValueTensorType>(
SmallVector<int64_t>{}, intTy);
Value specTensor;
if (specIdx == 0) { // batch index
specTensor = rewriter.create<Torch::AtenSelectIntOp>(
loc, intEmptyTensorTy, batchIdxs, constInts[0], roiIdx);
} else { // roi dimension
specTensor = rewriter.create<Torch::AtenSelectIntOp>(
loc, intEmptyTensorTy, roiSpec, constInts[0],
constInts[specIdx - 1]);
}
Value specValue =
rewriter.create<Torch::AtenItemOp>(loc, torchIntTy, specTensor);
roiValues[specIdx] = specValue;
}
Value batchIdx = roiValues[0], roiX1 = roiValues[1],
roiY1 = roiValues[2], roiX2 = roiValues[3],
roiY2 = roiValues[4];
// add 1 to make range ends inclusive as per ONNX implementation
roiX2 = rewriter.create<Torch::AtenAddOp>(loc, torchIntTy, roiX2,
constInts[1]);
roiY2 = rewriter.create<Torch::AtenAddOp>(loc, torchIntTy, roiY2,
constInts[1]);
auto imageTy = rewriter.getType<Torch::ValueTensorType>(
inputShape.slice(1), inputTy.getDtype());
Value image = rewriter.create<Torch::AtenSelectIntOp>(
loc, imageTy, input, constInts[0], batchIdx); // (NC x H x W)
SmallVector<int64_t> imageUnknownShape(imageTy.getSizes());
imageUnknownShape[heightDim] = Torch::kUnknownSize;
imageUnknownShape[widthDim] = Torch::kUnknownSize;
auto imageUnknownTy = rewriter.getType<Torch::ValueTensorType>(
imageUnknownShape, imageTy.getDtype());
// extract ROI from image
Value imageExtractedY = rewriter.create<Torch::AtenSliceTensorOp>(
loc, imageUnknownTy, image, heightDimValue, roiY1, roiY2,
constInts[1]);
Value region = rewriter.create<Torch::AtenSliceTensorOp>(
loc, imageUnknownTy, imageExtractedY, widthDimValue, roiX1, roiX2,
constInts[1]);
SmallVector<int64_t> pooledRegionShape(imageTy.getSizes());
pooledRegionShape[heightDim] = pooledShape[0];
pooledRegionShape[widthDim] = pooledShape[1];
auto pooledRegionTy = rewriter.getType<Torch::ValueTensorType>(
pooledRegionShape, imageTy.getDtype());
auto pooledRegionIndicesTy = rewriter.getType<Torch::ValueTensorType>(
pooledRegionShape, intTy);
// apply pooling on ROI
Value pooledRegion =
rewriter
.create<Torch::AtenAdaptiveMaxPool2dOp>(
loc, pooledRegionTy, pooledRegionIndicesTy, region,
outputShapeList)
.getResult0();
pooledRois.push_back(pooledRegion);
}
Value pooledRoisList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(pooledRois[0].getType()), pooledRois);
rewriter.replaceOpWithNewOp<Torch::AtenStackOp>(
binder.op, resultTy, pooledRoisList, constInts[0]);
return success();
});
patterns.onOp("Greater", 16,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;

View File

@ -499,6 +499,71 @@ func.func @test_maxpool_symmetric_pad(%arg0: !torch.vtensor<[1,64,112,112],f32>)
// -----
// CHECK-LABEL: func.func @test_maxroipool
func.func @test_maxroipool(%arg0: !torch.vtensor<[8,3,32,32],f32>, %arg1: !torch.vtensor<[2,5],f32>) -> !torch.vtensor<[2,3,2,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[INT2_0:.*]] = torch.constant.int 2
// CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT2_1:.*]] = torch.constant.int 2
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT4:.*]] = torch.constant.int 4
// CHECK: %[[INT5:.*]] = torch.constant.int 5
// CHECK: %[[INT2_2:.*]] = torch.constant.int 2
// CHECK: %[[INT1_3:.*]] = torch.constant.int 1
// CHECK: %[[SELECT1:.*]] = torch.aten.select.int %arg1, %[[INT1]], %[[INT0]] : !torch.vtensor<[2,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: %[[CAST1:.*]] = torch.aten._cast_Long %[[SELECT1]], %[[TRUE]] : !torch.vtensor<[?],f32>, !torch.bool -> !torch.vtensor<[?],si64>
// CHECK: %[[SLICE1:.*]] = torch.aten.slice.Tensor %arg1, %[[INT1]], %[[INT1]], %[[INT5]], %[[INT1]] : !torch.vtensor<[2,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,4],f32>
// CHECK: %[[MUL1:.*]] = torch.aten.mul.Scalar %[[SLICE1]], %[[FLOAT1]] : !torch.vtensor<[?,4],f32>, !torch.float -> !torch.vtensor<[?,4],f32>
// CHECK: %[[CAST2:.*]] = torch.aten._cast_Long %[[MUL1]], %[[TRUE]] : !torch.vtensor<[?,4],f32>, !torch.bool -> !torch.vtensor<[?,4],si64>
// CHECK: %[[INT0_4:.*]] = torch.constant.int 0
// CHECK: %[[SELECT2:.*]] = torch.aten.select.int %[[CAST2]], %[[INT0]], %[[INT0_4]] : !torch.vtensor<[?,4],si64>, !torch.int, !torch.int -> !torch.vtensor<[4],si64>
// CHECK: %[[SELECT3:.*]] = torch.aten.select.int %[[CAST1]], %[[INT0]], %[[INT0_4]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM1:.*]] = torch.aten.item %[[SELECT3]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[SELECT4:.*]] = torch.aten.select.int %[[SELECT2]], %[[INT0]], %[[INT0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM2:.*]] = torch.aten.item %[[SELECT4]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[SELECT5:.*]] = torch.aten.select.int %[[SELECT2]], %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM3:.*]] = torch.aten.item %[[SELECT5]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[SELECT6:.*]] = torch.aten.select.int %[[SELECT2]], %[[INT0]], %[[INT2_1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM4:.*]] = torch.aten.item %[[SELECT6]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[SELECT7:.*]] = torch.aten.select.int %[[SELECT2]], %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM5:.*]] = torch.aten.item %[[SELECT7]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[ADD1:.*]] = torch.aten.add %[[ITEM4]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[ADD2:.*]] = torch.aten.add %[[ITEM5]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SELECT8:.*]] = torch.aten.select.int %arg0, %[[INT0]], %[[ITEM1]] : !torch.vtensor<[8,3,32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,32,32],f32>
// CHECK: %[[SLICE2:.*]] = torch.aten.slice.Tensor %[[SELECT8]], %[[INT1_3]], %[[ITEM3]], %[[ADD2]], %[[INT1]] : !torch.vtensor<[3,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?,?],f32>
// CHECK: %[[SLICE3:.*]] = torch.aten.slice.Tensor %[[SLICE2]], %[[INT2_2]], %[[ITEM2]], %[[ADD1]], %[[INT1]] : !torch.vtensor<[3,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?,?],f32>
// CHECK: %[[RESULT0:.*]], %[[RESULT1:.*]] = torch.aten.adaptive_max_pool2d %[[SLICE3]], %[[LIST0]] : !torch.vtensor<[3,?,?],f32>, !torch.list<int> -> !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],si64>
// CHECK: %[[INT1_5:.*]] = torch.constant.int 1
// CHECK: %[[SELECT9:.*]] = torch.aten.select.int %[[CAST2]], %[[INT0]], %[[INT1_5]] : !torch.vtensor<[?,4],si64>, !torch.int, !torch.int -> !torch.vtensor<[4],si64>
// CHECK: %[[SELECT10:.*]] = torch.aten.select.int %[[CAST1]], %[[INT0]], %[[INT1_5]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM6:.*]] = torch.aten.item %[[SELECT10]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[SELECT11:.*]] = torch.aten.select.int %[[SELECT9]], %[[INT0]], %[[INT0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM7:.*]] = torch.aten.item %[[SELECT11]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[SELECT12:.*]] = torch.aten.select.int %[[SELECT9]], %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM8:.*]] = torch.aten.item %[[SELECT12]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[SELECT13:.*]] = torch.aten.select.int %[[SELECT9]], %[[INT0]], %[[INT2_1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM9:.*]] = torch.aten.item %[[SELECT13]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[SELECT14:.*]] = torch.aten.select.int %[[SELECT9]], %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM10:.*]] = torch.aten.item %[[SELECT14]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[ADD3:.*]] = torch.aten.add %[[ITEM9]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[ADD4:.*]] = torch.aten.add %[[ITEM10]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SELECT15:.*]] = torch.aten.select.int %arg0, %[[INT0]], %[[ITEM6]] : !torch.vtensor<[8,3,32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,32,32],f32>
// CHECK: %[[SLICE4:.*]] = torch.aten.slice.Tensor %[[SELECT15]], %[[INT1_3]], %[[ITEM8]], %[[ADD4]], %[[INT1]] : !torch.vtensor<[3,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?,?],f32>
// CHECK: %[[SLICE5:.*]] = torch.aten.slice.Tensor %[[SLICE4]], %[[INT2_2]], %[[ITEM7]], %[[ADD3]], %[[INT1]] : !torch.vtensor<[3,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?,?],f32>
// CHECK: %[[RESULT0_6:.*]], %[[RESULT1_7:.*]] = torch.aten.adaptive_max_pool2d %[[SLICE5]], %[[LIST0]] : !torch.vtensor<[3,?,?],f32>, !torch.list<int> -> !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],si64>
// CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[RESULT0]], %[[RESULT0_6]] : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32>) -> !torch.list<vtensor<[3,2,2],f32>>
// CHECK: %[[STACK:.*]] = torch.aten.stack %[[LIST1]], %[[INT0]] : !torch.list<vtensor<[3,2,2],f32>>, !torch.int -> !torch.vtensor<[2,3,2,2],f32>
// CHECK: return %[[STACK]] : !torch.vtensor<[2,3,2,2],f32>
%0 = torch.operator "onnx.MaxRoiPool"(%arg0, %arg1) {torch.onnx.pooled_shape = [2 : si64, 2 : si64], torch.onnx.spatial_scale = 1.000000e+00 : f32} : (!torch.vtensor<[8,3,32,32],f32>, !torch.vtensor<[2,5],f32>) -> !torch.vtensor<[2,3,2,2],f32>
return %0 : !torch.vtensor<[2,3,2,2],f32>
}
// -----
// CHECK-LABEL: @test_gelu_default_1
func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[STR1:.*]] = torch.constant.str "none"