mirror of https://github.com/llvm/torch-mlir
[MLIR][ONNX] Add OnnxToTorch support for MaxRoiPool Op (#3395)
This PR adds OnnxToTorch support for MaxRoiPool oppull/3458/head
parent
9b76a2e3eb
commit
de7f058a0e
|
@ -604,6 +604,237 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
}
|
}
|
||||||
return rewriter.notifyMatchFailure(binder.op, "No rank is matched.");
|
return rewriter.notifyMatchFailure(binder.op, "No rank is matched.");
|
||||||
});
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"MaxRoiPool", 1,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
SmallVector<int64_t> pooledShape;
|
||||||
|
float spatialScale;
|
||||||
|
if (binder.s64IntegerArrayAttr(pooledShape, "pooled_shape", {}) ||
|
||||||
|
binder.f32FloatAttr(spatialScale, "spatial_scale", 1.0f)) {
|
||||||
|
return rewriter.notifyMatchFailure(binder.op,
|
||||||
|
"Attribute bind failure");
|
||||||
|
}
|
||||||
|
Torch::ValueTensorType resultTy;
|
||||||
|
Value input, rois;
|
||||||
|
if (binder.tensorOperands(input, rois) ||
|
||||||
|
binder.tensorResultType(resultTy)) {
|
||||||
|
return rewriter.notifyMatchFailure(binder.op,
|
||||||
|
"Operand or result type mismatch");
|
||||||
|
}
|
||||||
|
|
||||||
|
Value outputShapeList =
|
||||||
|
createConstantIntList(binder, rewriter, pooledShape);
|
||||||
|
Location loc = binder.getLoc();
|
||||||
|
|
||||||
|
auto inputTy = cast<Torch::ValueTensorType>(input.getType());
|
||||||
|
auto roisTy = cast<Torch::ValueTensorType>(rois.getType());
|
||||||
|
if (!inputTy || !inputTy.hasSizes())
|
||||||
|
return failure();
|
||||||
|
if (!roisTy || !roisTy.hasSizes())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto intTy = rewriter.getIntegerType(64, true);
|
||||||
|
auto floatTy = roisTy.getDtype();
|
||||||
|
auto torchIntTy = rewriter.getType<Torch::IntType>();
|
||||||
|
|
||||||
|
Value spatialScaleValue = rewriter.create<Torch::ConstantFloatOp>(
|
||||||
|
loc, rewriter.getF64FloatAttr(spatialScale));
|
||||||
|
|
||||||
|
Value boolTrue = rewriter.create<Torch::ConstantBoolOp>(
|
||||||
|
loc, rewriter.getBoolAttr(true));
|
||||||
|
|
||||||
|
ArrayRef<int64_t> inputShape = inputTy.getSizes();
|
||||||
|
int64_t inputRank = inputShape.size();
|
||||||
|
if (inputRank < 4) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "Rank of input tensor must be >= 4");
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<int64_t> roisShape = roisTy.getSizes();
|
||||||
|
if (!roisTy.areAllSizesKnown() || roisShape.size() != 2 ||
|
||||||
|
roisShape[1] != 5) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "Expected ROIs to be statically sized tensor of shape "
|
||||||
|
"(num_rois, 5)");
|
||||||
|
}
|
||||||
|
int64_t numRois = roisShape[0];
|
||||||
|
|
||||||
|
/* The implementation is based on the following algorithm:
|
||||||
|
MaxRoiPool <pooled_shape, spatial_scale>(
|
||||||
|
input : tensor<float>, rois : tensor<?x5xfloat>) => (output)
|
||||||
|
{
|
||||||
|
* Step 1: Extract ROI specification
|
||||||
|
- Each ROI is represented as [batch_id, x1, y1, x2, y2], where
|
||||||
|
range is inclusive of x1, y1, x2, and y2
|
||||||
|
- The range values are scaled by spatial_scale
|
||||||
|
|
||||||
|
BatchIdxsFloat = Select(rois, dim=1, index=0)
|
||||||
|
BatchIdxs = CastLong(BatchIdxsFloat)
|
||||||
|
RoiBBsFloat = Slice(rois, dim=1, start=1, end=5, stride=1)
|
||||||
|
RoiBBsScaledFloat = MulScalar(RoiBBsFloat, spatial_scale)
|
||||||
|
RoiBBsScaled = CastLong(RoiBBsScaledFloat)
|
||||||
|
|
||||||
|
* Step 2: Iteratively pool ROIs
|
||||||
|
pooledROIs = []
|
||||||
|
for (roiIdx = 0; roiIdx < len(rois); roiIdx++) {
|
||||||
|
* Step 2a: For each ROI, we extract batch_id, x1, y1, x2, & y2
|
||||||
|
RoiSpec = Select(RoiBBsScaled, 0, roiIdx) : tensor<4xint>
|
||||||
|
roiValues = []
|
||||||
|
for (specIdx = 0; specIdx < 5; specIdx++) {
|
||||||
|
if (specIdx == 0)
|
||||||
|
SpecTensor = Select(BatchIdxs, 1, roiIdx) : tensor<int>
|
||||||
|
else
|
||||||
|
SpecTensor = Select(RoiSpec, 0, specIdx-1) : tensor<int>
|
||||||
|
SpecValue = Item(specTensor) : torch.int
|
||||||
|
roiValues.push(SpecValue)
|
||||||
|
}
|
||||||
|
BatchIdx, X1, Y1, X2, Y2 = roiValues
|
||||||
|
|
||||||
|
* Step 2b: extract image from input and extract region
|
||||||
|
- X2 and Y2 are incremented by 1 to make range inclusive
|
||||||
|
- width and height dimension are calculated once outside of loop
|
||||||
|
but intuition is expressed more clearly below
|
||||||
|
|
||||||
|
image = Select(input, 0, BatchIdx)
|
||||||
|
widthDim = rank(image) - 1
|
||||||
|
heightDim = rank(image) - 2
|
||||||
|
|
||||||
|
imageExtractedY = Slice(image, heightDim, Y1, Y2 + 1, 1)
|
||||||
|
region = Slice(image, widthDim, X1, X2 + 1, 1)
|
||||||
|
|
||||||
|
* Step 2c: apply adaptive max pooling to pool region of interest
|
||||||
|
into final pooled size
|
||||||
|
pooledROI = AdaptiveMaxPool2d(region, pooled_shape)
|
||||||
|
pooledROIs.push(pooledROI)
|
||||||
|
}
|
||||||
|
|
||||||
|
* Step 3: Stack pooled regions and return final output
|
||||||
|
return output = Stack(pooledRois, dim=0)
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
SmallVector<Value> constInts(6);
|
||||||
|
for (int i = 0; i <= 5; i++) {
|
||||||
|
constInts[i] = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t widthDim = inputRank - 2;
|
||||||
|
Value widthDimValue = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(widthDim));
|
||||||
|
|
||||||
|
int64_t heightDim = inputRank - 3;
|
||||||
|
Value heightDimValue = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(heightDim));
|
||||||
|
|
||||||
|
// extract indices of images within batch
|
||||||
|
auto batchIdxsShape = SmallVector<int64_t>{Torch::kUnknownSize};
|
||||||
|
auto batchIdxsFloatTy =
|
||||||
|
rewriter.getType<Torch::ValueTensorType>(batchIdxsShape, floatTy);
|
||||||
|
Value batchIdxsFloat = rewriter.create<Torch::AtenSelectIntOp>(
|
||||||
|
loc, batchIdxsFloatTy, rois, constInts[1], constInts[0]);
|
||||||
|
auto batchIdxsIntTy =
|
||||||
|
rewriter.getType<Torch::ValueTensorType>(batchIdxsShape, intTy);
|
||||||
|
Value batchIdxs = rewriter.create<Torch::Aten_CastLongOp>(
|
||||||
|
loc, batchIdxsIntTy, batchIdxsFloat, boolTrue);
|
||||||
|
|
||||||
|
// extract scaled ranges for regions of interest
|
||||||
|
auto roiBBsShape = SmallVector<int64_t>{Torch::kUnknownSize, 4};
|
||||||
|
auto roiBBsFloatTy =
|
||||||
|
rewriter.getType<Torch::ValueTensorType>(roiBBsShape, floatTy);
|
||||||
|
Value roiBBs = rewriter.create<Torch::AtenSliceTensorOp>(
|
||||||
|
loc, roiBBsFloatTy, rois, constInts[1], constInts[1], constInts[5],
|
||||||
|
constInts[1]);
|
||||||
|
Value roiBBsScaledFloat = rewriter.create<Torch::AtenMulScalarOp>(
|
||||||
|
loc, roiBBsFloatTy, roiBBs, spatialScaleValue);
|
||||||
|
auto roiBBsTy =
|
||||||
|
rewriter.getType<Torch::ValueTensorType>(roiBBsShape, intTy);
|
||||||
|
Value roiBBsScaled = rewriter.create<Torch::Aten_CastLongOp>(
|
||||||
|
loc, roiBBsTy, roiBBsScaledFloat, boolTrue);
|
||||||
|
|
||||||
|
SmallVector<Value> pooledRois;
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < numRois; i++) {
|
||||||
|
Value roiIdx = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(i));
|
||||||
|
|
||||||
|
auto roiSpecTy = rewriter.getType<Torch::ValueTensorType>(
|
||||||
|
roiBBsTy.getSizes().slice(1), intTy);
|
||||||
|
Value roiSpec = rewriter.create<Torch::AtenSelectIntOp>(
|
||||||
|
loc, roiSpecTy, roiBBsScaled, constInts[0], roiIdx);
|
||||||
|
|
||||||
|
// Load individual ROI specification values
|
||||||
|
SmallVector<Value> roiValues(5);
|
||||||
|
for (int specIdx = 0; specIdx < 5; specIdx++) {
|
||||||
|
auto intEmptyTensorTy = rewriter.getType<Torch::ValueTensorType>(
|
||||||
|
SmallVector<int64_t>{}, intTy);
|
||||||
|
Value specTensor;
|
||||||
|
if (specIdx == 0) { // batch index
|
||||||
|
specTensor = rewriter.create<Torch::AtenSelectIntOp>(
|
||||||
|
loc, intEmptyTensorTy, batchIdxs, constInts[0], roiIdx);
|
||||||
|
} else { // roi dimension
|
||||||
|
specTensor = rewriter.create<Torch::AtenSelectIntOp>(
|
||||||
|
loc, intEmptyTensorTy, roiSpec, constInts[0],
|
||||||
|
constInts[specIdx - 1]);
|
||||||
|
}
|
||||||
|
Value specValue =
|
||||||
|
rewriter.create<Torch::AtenItemOp>(loc, torchIntTy, specTensor);
|
||||||
|
roiValues[specIdx] = specValue;
|
||||||
|
}
|
||||||
|
Value batchIdx = roiValues[0], roiX1 = roiValues[1],
|
||||||
|
roiY1 = roiValues[2], roiX2 = roiValues[3],
|
||||||
|
roiY2 = roiValues[4];
|
||||||
|
|
||||||
|
// add 1 to make range ends inclusive as per ONNX implementation
|
||||||
|
roiX2 = rewriter.create<Torch::AtenAddOp>(loc, torchIntTy, roiX2,
|
||||||
|
constInts[1]);
|
||||||
|
roiY2 = rewriter.create<Torch::AtenAddOp>(loc, torchIntTy, roiY2,
|
||||||
|
constInts[1]);
|
||||||
|
|
||||||
|
auto imageTy = rewriter.getType<Torch::ValueTensorType>(
|
||||||
|
inputShape.slice(1), inputTy.getDtype());
|
||||||
|
Value image = rewriter.create<Torch::AtenSelectIntOp>(
|
||||||
|
loc, imageTy, input, constInts[0], batchIdx); // (NC x H x W)
|
||||||
|
|
||||||
|
SmallVector<int64_t> imageUnknownShape(imageTy.getSizes());
|
||||||
|
imageUnknownShape[heightDim] = Torch::kUnknownSize;
|
||||||
|
imageUnknownShape[widthDim] = Torch::kUnknownSize;
|
||||||
|
auto imageUnknownTy = rewriter.getType<Torch::ValueTensorType>(
|
||||||
|
imageUnknownShape, imageTy.getDtype());
|
||||||
|
|
||||||
|
// extract ROI from image
|
||||||
|
Value imageExtractedY = rewriter.create<Torch::AtenSliceTensorOp>(
|
||||||
|
loc, imageUnknownTy, image, heightDimValue, roiY1, roiY2,
|
||||||
|
constInts[1]);
|
||||||
|
Value region = rewriter.create<Torch::AtenSliceTensorOp>(
|
||||||
|
loc, imageUnknownTy, imageExtractedY, widthDimValue, roiX1, roiX2,
|
||||||
|
constInts[1]);
|
||||||
|
|
||||||
|
SmallVector<int64_t> pooledRegionShape(imageTy.getSizes());
|
||||||
|
pooledRegionShape[heightDim] = pooledShape[0];
|
||||||
|
pooledRegionShape[widthDim] = pooledShape[1];
|
||||||
|
auto pooledRegionTy = rewriter.getType<Torch::ValueTensorType>(
|
||||||
|
pooledRegionShape, imageTy.getDtype());
|
||||||
|
auto pooledRegionIndicesTy = rewriter.getType<Torch::ValueTensorType>(
|
||||||
|
pooledRegionShape, intTy);
|
||||||
|
|
||||||
|
// apply pooling on ROI
|
||||||
|
Value pooledRegion =
|
||||||
|
rewriter
|
||||||
|
.create<Torch::AtenAdaptiveMaxPool2dOp>(
|
||||||
|
loc, pooledRegionTy, pooledRegionIndicesTy, region,
|
||||||
|
outputShapeList)
|
||||||
|
.getResult0();
|
||||||
|
pooledRois.push_back(pooledRegion);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value pooledRoisList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
loc, Torch::ListType::get(pooledRois[0].getType()), pooledRois);
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenStackOp>(
|
||||||
|
binder.op, resultTy, pooledRoisList, constInts[0]);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
});
|
||||||
patterns.onOp("Greater", 16,
|
patterns.onOp("Greater", 16,
|
||||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
Torch::ValueTensorType resultType;
|
Torch::ValueTensorType resultType;
|
||||||
|
|
|
@ -499,6 +499,71 @@ func.func @test_maxpool_symmetric_pad(%arg0: !torch.vtensor<[1,64,112,112],f32>)
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_maxroipool
|
||||||
|
func.func @test_maxroipool(%arg0: !torch.vtensor<[8,3,32,32],f32>, %arg1: !torch.vtensor<[2,5],f32>) -> !torch.vtensor<[2,3,2,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[INT2_0:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
|
||||||
|
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||||
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[INT2_1:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[INT4:.*]] = torch.constant.int 4
|
||||||
|
// CHECK: %[[INT5:.*]] = torch.constant.int 5
|
||||||
|
// CHECK: %[[INT2_2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[INT1_3:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[SELECT1:.*]] = torch.aten.select.int %arg1, %[[INT1]], %[[INT0]] : !torch.vtensor<[2,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
|
||||||
|
// CHECK: %[[CAST1:.*]] = torch.aten._cast_Long %[[SELECT1]], %[[TRUE]] : !torch.vtensor<[?],f32>, !torch.bool -> !torch.vtensor<[?],si64>
|
||||||
|
// CHECK: %[[SLICE1:.*]] = torch.aten.slice.Tensor %arg1, %[[INT1]], %[[INT1]], %[[INT5]], %[[INT1]] : !torch.vtensor<[2,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,4],f32>
|
||||||
|
// CHECK: %[[MUL1:.*]] = torch.aten.mul.Scalar %[[SLICE1]], %[[FLOAT1]] : !torch.vtensor<[?,4],f32>, !torch.float -> !torch.vtensor<[?,4],f32>
|
||||||
|
// CHECK: %[[CAST2:.*]] = torch.aten._cast_Long %[[MUL1]], %[[TRUE]] : !torch.vtensor<[?,4],f32>, !torch.bool -> !torch.vtensor<[?,4],si64>
|
||||||
|
// CHECK: %[[INT0_4:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[SELECT2:.*]] = torch.aten.select.int %[[CAST2]], %[[INT0]], %[[INT0_4]] : !torch.vtensor<[?,4],si64>, !torch.int, !torch.int -> !torch.vtensor<[4],si64>
|
||||||
|
// CHECK: %[[SELECT3:.*]] = torch.aten.select.int %[[CAST1]], %[[INT0]], %[[INT0_4]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[ITEM1:.*]] = torch.aten.item %[[SELECT3]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[SELECT4:.*]] = torch.aten.select.int %[[SELECT2]], %[[INT0]], %[[INT0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[ITEM2:.*]] = torch.aten.item %[[SELECT4]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[SELECT5:.*]] = torch.aten.select.int %[[SELECT2]], %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[ITEM3:.*]] = torch.aten.item %[[SELECT5]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[SELECT6:.*]] = torch.aten.select.int %[[SELECT2]], %[[INT0]], %[[INT2_1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[ITEM4:.*]] = torch.aten.item %[[SELECT6]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[SELECT7:.*]] = torch.aten.select.int %[[SELECT2]], %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[ITEM5:.*]] = torch.aten.item %[[SELECT7]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[ADD1:.*]] = torch.aten.add %[[ITEM4]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[ADD2:.*]] = torch.aten.add %[[ITEM5]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SELECT8:.*]] = torch.aten.select.int %arg0, %[[INT0]], %[[ITEM1]] : !torch.vtensor<[8,3,32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,32,32],f32>
|
||||||
|
// CHECK: %[[SLICE2:.*]] = torch.aten.slice.Tensor %[[SELECT8]], %[[INT1_3]], %[[ITEM3]], %[[ADD2]], %[[INT1]] : !torch.vtensor<[3,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?,?],f32>
|
||||||
|
// CHECK: %[[SLICE3:.*]] = torch.aten.slice.Tensor %[[SLICE2]], %[[INT2_2]], %[[ITEM2]], %[[ADD1]], %[[INT1]] : !torch.vtensor<[3,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?,?],f32>
|
||||||
|
// CHECK: %[[RESULT0:.*]], %[[RESULT1:.*]] = torch.aten.adaptive_max_pool2d %[[SLICE3]], %[[LIST0]] : !torch.vtensor<[3,?,?],f32>, !torch.list<int> -> !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],si64>
|
||||||
|
// CHECK: %[[INT1_5:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[SELECT9:.*]] = torch.aten.select.int %[[CAST2]], %[[INT0]], %[[INT1_5]] : !torch.vtensor<[?,4],si64>, !torch.int, !torch.int -> !torch.vtensor<[4],si64>
|
||||||
|
// CHECK: %[[SELECT10:.*]] = torch.aten.select.int %[[CAST1]], %[[INT0]], %[[INT1_5]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[ITEM6:.*]] = torch.aten.item %[[SELECT10]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[SELECT11:.*]] = torch.aten.select.int %[[SELECT9]], %[[INT0]], %[[INT0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[ITEM7:.*]] = torch.aten.item %[[SELECT11]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[SELECT12:.*]] = torch.aten.select.int %[[SELECT9]], %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[ITEM8:.*]] = torch.aten.item %[[SELECT12]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[SELECT13:.*]] = torch.aten.select.int %[[SELECT9]], %[[INT0]], %[[INT2_1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[ITEM9:.*]] = torch.aten.item %[[SELECT13]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[SELECT14:.*]] = torch.aten.select.int %[[SELECT9]], %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[ITEM10:.*]] = torch.aten.item %[[SELECT14]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[ADD3:.*]] = torch.aten.add %[[ITEM9]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[ADD4:.*]] = torch.aten.add %[[ITEM10]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SELECT15:.*]] = torch.aten.select.int %arg0, %[[INT0]], %[[ITEM6]] : !torch.vtensor<[8,3,32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,32,32],f32>
|
||||||
|
// CHECK: %[[SLICE4:.*]] = torch.aten.slice.Tensor %[[SELECT15]], %[[INT1_3]], %[[ITEM8]], %[[ADD4]], %[[INT1]] : !torch.vtensor<[3,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?,?],f32>
|
||||||
|
// CHECK: %[[SLICE5:.*]] = torch.aten.slice.Tensor %[[SLICE4]], %[[INT2_2]], %[[ITEM7]], %[[ADD3]], %[[INT1]] : !torch.vtensor<[3,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?,?],f32>
|
||||||
|
// CHECK: %[[RESULT0_6:.*]], %[[RESULT1_7:.*]] = torch.aten.adaptive_max_pool2d %[[SLICE5]], %[[LIST0]] : !torch.vtensor<[3,?,?],f32>, !torch.list<int> -> !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],si64>
|
||||||
|
// CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[RESULT0]], %[[RESULT0_6]] : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32>) -> !torch.list<vtensor<[3,2,2],f32>>
|
||||||
|
// CHECK: %[[STACK:.*]] = torch.aten.stack %[[LIST1]], %[[INT0]] : !torch.list<vtensor<[3,2,2],f32>>, !torch.int -> !torch.vtensor<[2,3,2,2],f32>
|
||||||
|
// CHECK: return %[[STACK]] : !torch.vtensor<[2,3,2,2],f32>
|
||||||
|
%0 = torch.operator "onnx.MaxRoiPool"(%arg0, %arg1) {torch.onnx.pooled_shape = [2 : si64, 2 : si64], torch.onnx.spatial_scale = 1.000000e+00 : f32} : (!torch.vtensor<[8,3,32,32],f32>, !torch.vtensor<[2,5],f32>) -> !torch.vtensor<[2,3,2,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[2,3,2,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @test_gelu_default_1
|
// CHECK-LABEL: @test_gelu_default_1
|
||||||
func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: %[[STR1:.*]] = torch.constant.str "none"
|
// CHECK: %[[STR1:.*]] = torch.constant.str "none"
|
||||||
|
|
Loading…
Reference in New Issue