mirror of https://github.com/llvm/torch-mlir
[ONNX] Add basic support for RoiAlign (#3493)
This adds an onnx->torch conversion for onnx.RoiAlign into torchvision.roi_align or torchvision.roi_pool, and adds those two torchvision ops to torch-mlir.pull/3495/head
parent
02340408b7
commit
e346c911f7
|
@ -16660,3 +16660,60 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_TorchvisionRoiAlignOp : Torch_Op<"torchvision.roi_align", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `torchvision::roi_align : (Tensor, Tensor, float, int, int, int, bool) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$input,
|
||||||
|
AnyTorchTensorType:$rois,
|
||||||
|
Torch_FloatType:$spatial_scale,
|
||||||
|
Torch_IntType:$pooled_height,
|
||||||
|
Torch_IntType:$pooled_width,
|
||||||
|
Torch_IntType:$sampling_ratio,
|
||||||
|
Torch_BoolType:$aligned
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult TorchvisionRoiAlignOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 7, 1);
|
||||||
|
}
|
||||||
|
void TorchvisionRoiAlignOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 7, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_TorchvisionRoiPoolOp : Torch_Op<"torchvision.roi_pool", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `torchvision::roi_pool : (Tensor, Tensor, float, int, int) -> (Tensor, Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$input,
|
||||||
|
AnyTorchTensorType:$rois,
|
||||||
|
Torch_FloatType:$spatial_scale,
|
||||||
|
Torch_IntType:$pooled_height,
|
||||||
|
Torch_IntType:$pooled_width
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result0,
|
||||||
|
AnyTorchOptionalTensorType:$result1
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult TorchvisionRoiPoolOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 5, 2);
|
||||||
|
}
|
||||||
|
void TorchvisionRoiPoolOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 5, 2);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -2953,6 +2953,104 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
/*Torch_BoolType:$antialias*/ cstFalse);
|
/*Torch_BoolType:$antialias*/ cstFalse);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"RoiAlign", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
// operands = input, rois, batch_indices
|
||||||
|
SmallVector<Value> operands;
|
||||||
|
std::string coordTfMode, mode;
|
||||||
|
int64_t outHInt, outWInt, samplingRatioInt;
|
||||||
|
float spatialScaleFloat;
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
if (binder.tensorOperands(operands, 3) ||
|
||||||
|
binder.customOpNameStringAttr(
|
||||||
|
coordTfMode, "coordinate_transformation_mode", "half_pixel") ||
|
||||||
|
binder.customOpNameStringAttr(mode, "mode", "avg") ||
|
||||||
|
binder.s64IntegerAttr(outHInt, "output_height", 1) ||
|
||||||
|
binder.s64IntegerAttr(outWInt, "output_width", 1) ||
|
||||||
|
binder.s64IntegerAttr(samplingRatioInt, "sampling_ratio", 0) ||
|
||||||
|
binder.f32FloatAttr(spatialScaleFloat, "spatial_scale", 1.0f) ||
|
||||||
|
binder.tensorResultType(resultType))
|
||||||
|
return failure();
|
||||||
|
Value input = operands[0];
|
||||||
|
Value rois = operands[1];
|
||||||
|
Value batchIndices = operands[2];
|
||||||
|
|
||||||
|
// the torchvision roi_pool op does not support these features:
|
||||||
|
if (mode == "max" &&
|
||||||
|
(coordTfMode != "half_pixel" || samplingRatioInt != 0))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "unsupported: roi max pooling without default "
|
||||||
|
"coordTfMode and sampling_ratio");
|
||||||
|
|
||||||
|
Location loc = binder.getLoc();
|
||||||
|
// concatenate the batchIndices to the rois to get rois as a num_roisx5
|
||||||
|
// tensor. The batchIndices tensor is an int64 tensor, and needs to be
|
||||||
|
// converted to float before concatenation.
|
||||||
|
auto roisType = dyn_cast<Torch::ValueTensorType>(rois.getType());
|
||||||
|
if (!roisType || !roisType.hasSizes())
|
||||||
|
return failure();
|
||||||
|
Value cstDim = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(1));
|
||||||
|
FailureOr<Value> unsqueezeIndices =
|
||||||
|
Torch::unsqueezeTensor(rewriter, binder.op, batchIndices, cstDim);
|
||||||
|
if (failed(unsqueezeIndices))
|
||||||
|
return failure();
|
||||||
|
batchIndices = unsqueezeIndices.value();
|
||||||
|
auto batchIndicesType =
|
||||||
|
cast<Torch::ValueTensorType>(batchIndices.getType());
|
||||||
|
Value dTypeInt =
|
||||||
|
Torch::getDtypeIntValueForType(rewriter, loc, roisType.getDtype());
|
||||||
|
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||||
|
Value cstFalse =
|
||||||
|
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
||||||
|
Value newBatchIndices = rewriter.create<Torch::AtenToDtypeOp>(
|
||||||
|
loc,
|
||||||
|
batchIndicesType.getWithSizesAndDtype(
|
||||||
|
batchIndicesType.getOptionalSizes(),
|
||||||
|
roisType.getOptionalDtype()),
|
||||||
|
batchIndices, dTypeInt, cstFalse, cstFalse, none);
|
||||||
|
SmallVector<int64_t> roiSizes(roisType.getSizes());
|
||||||
|
roiSizes.back() = 5;
|
||||||
|
auto catType = rewriter.getType<Torch::ValueTensorType>(
|
||||||
|
roiSizes, roisType.getDtype());
|
||||||
|
Type listElemType =
|
||||||
|
roisType.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
|
||||||
|
/*optionalDtype=*/nullptr);
|
||||||
|
Type listType = Torch::ListType::get(listElemType);
|
||||||
|
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
binder.op->getLoc(), listType, ValueRange{newBatchIndices, rois});
|
||||||
|
Value newRois =
|
||||||
|
rewriter.create<Torch::AtenCatOp>(loc, catType, tensorList, cstDim);
|
||||||
|
|
||||||
|
// make constants from attributes
|
||||||
|
Value cstSpatialScale = rewriter.create<Torch::ConstantFloatOp>(
|
||||||
|
loc, rewriter.getF64FloatAttr(spatialScaleFloat));
|
||||||
|
Value pooledHeight = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(outHInt));
|
||||||
|
Value pooledWidth = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(outWInt));
|
||||||
|
// this is for consistency with the default pytorch sampling ratio value
|
||||||
|
samplingRatioInt = (samplingRatioInt == 0) ? -1 : samplingRatioInt;
|
||||||
|
Value samplingRatio = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(samplingRatioInt));
|
||||||
|
bool aligned = coordTfMode == "half_pixel";
|
||||||
|
Value cstAligned = rewriter.create<Torch::ConstantBoolOp>(loc, aligned);
|
||||||
|
|
||||||
|
if (mode == "avg") {
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::TorchvisionRoiAlignOp>(
|
||||||
|
binder.op, resultType, input, newRois, cstSpatialScale,
|
||||||
|
pooledHeight, pooledWidth, samplingRatio, cstAligned);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
// mode == "max"
|
||||||
|
auto indicesType = resultType.getWithSizesAndDtype(
|
||||||
|
resultType.getOptionalSizes(), batchIndicesType.getDtype());
|
||||||
|
auto roiPool = rewriter.create<Torch::TorchvisionRoiPoolOp>(
|
||||||
|
loc, TypeRange{resultType, indicesType}, input, newRois,
|
||||||
|
cstSpatialScale, pooledHeight, pooledWidth);
|
||||||
|
rewriter.replaceOp(binder.op, roiPool.getResult(0));
|
||||||
|
return success();
|
||||||
|
});
|
||||||
patterns.onOp(
|
patterns.onOp(
|
||||||
"SpaceToDepth", 1,
|
"SpaceToDepth", 1,
|
||||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
|
|
@ -6256,6 +6256,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.torchvision.roi_align\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.bool) -> !torch.list<int> {\n"
|
||||||
|
" %int0 = torch.constant.int 0\n"
|
||||||
|
" %int1 = torch.constant.int 1\n"
|
||||||
|
" %0 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %2 = torch.prim.ListConstruct %0, %1, %arg3, %arg4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" return %2 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.torchvision.roi_align\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.bool) -> !torch.int {\n"
|
||||||
|
" %int1 = torch.constant.int 1\n"
|
||||||
|
" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
|
||||||
|
" return %0 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.torchvision.roi_pool\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||||
|
" %int0 = torch.constant.int 0\n"
|
||||||
|
" %int1 = torch.constant.int 1\n"
|
||||||
|
" %0 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %2 = torch.prim.ListConstruct %0, %1, %arg3, %arg4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" %3 = torch.prim.TupleConstruct %2, %2 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||||
|
" return %3 : !torch.tuple<list<int>, list<int>>\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.torchvision.roi_pool\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<int, int> {\n"
|
||||||
|
" %int4 = torch.constant.int 4\n"
|
||||||
|
" %int1 = torch.constant.int 1\n"
|
||||||
|
" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
|
||||||
|
" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||||
|
" return %1 : !torch.tuple<int, int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.diagonal\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.diagonal\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
|
||||||
" %str = torch.constant.str \"AssertionError: diagonal dimensions cannot be identical\"\n"
|
" %str = torch.constant.str \"AssertionError: diagonal dimensions cannot be identical\"\n"
|
||||||
" %true = torch.constant.bool true\n"
|
" %true = torch.constant.bool true\n"
|
||||||
|
|
|
@ -8,6 +8,7 @@ import argparse
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torchvision
|
||||||
from torch import device
|
from torch import device
|
||||||
import torch.jit._shape_functions as upstream_shape_functions
|
import torch.jit._shape_functions as upstream_shape_functions
|
||||||
|
|
||||||
|
@ -85,6 +86,20 @@ def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]:
|
||||||
def aten〇tril〡shape(self: List[int], diagonal: int = 0) -> List[int]:
|
def aten〇tril〡shape(self: List[int], diagonal: int = 0) -> List[int]:
|
||||||
return upstream_shape_functions.unary(self)
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
|
|
||||||
|
def torchvision〇roi_align〡shape(input: List[int], rois: List[int], spatial_scale: float, pooled_height: int, pooled_width: int, sampling_ratio: int, aligned: bool) -> List[int]:
|
||||||
|
return [rois[0], input[1], pooled_height, pooled_width]
|
||||||
|
|
||||||
|
def torchvision〇roi_align〡dtype(input_rank_dtype: Tuple[int, int], rois_rank_dtype: Tuple[int, int], spatial_scale: float, pooled_height: int, pooled_width: int, sampling_ratio: int, aligned: bool) -> int:
|
||||||
|
return input_rank_dtype[1]
|
||||||
|
|
||||||
|
def torchvision〇roi_pool〡shape(input: List[int], rois: List[int], spatial_scale: float, pooled_height: int, pooled_width: int) -> Tuple[List[int], List[int]]:
|
||||||
|
output = [rois[0], input[1], pooled_height, pooled_width]
|
||||||
|
return (output, output)
|
||||||
|
|
||||||
|
def torchvision〇roi_pool〡dtype(input_rank_dtype: Tuple[int, int], rois_rank_dtype: Tuple[int, int], spatial_scale: float, pooled_height: int, pooled_width: int) -> Tuple[int, int]:
|
||||||
|
return (input_rank_dtype[1], torch.int64)
|
||||||
|
|
||||||
@check_shape_function([
|
@check_shape_function([
|
||||||
Invocation(TensorOfShape(2, 3, 4)), # Basic case.
|
Invocation(TensorOfShape(2, 3, 4)), # Basic case.
|
||||||
Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=2), # Test explicit `dim1` and `dim2`.
|
Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=2), # Test explicit `dim1` and `dim2`.
|
||||||
|
|
|
@ -1155,6 +1155,13 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
traits=["HasValueSemantics"],
|
traits=["HasValueSemantics"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
emit(
|
||||||
|
"torchvision::roi_align : (Tensor, Tensor, float, int, int, int, bool) -> (Tensor)"
|
||||||
|
)
|
||||||
|
emit(
|
||||||
|
"torchvision::roi_pool : (Tensor, Tensor, float, int, int) -> (Tensor, Tensor)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def dump_registered_ops(outfile: TextIO, registry: Registry):
|
def dump_registered_ops(outfile: TextIO, registry: Registry):
|
||||||
for _, v in sorted(registry.by_unique_key.items()):
|
for _, v in sorted(registry.by_unique_key.items()):
|
||||||
|
@ -1173,6 +1180,8 @@ def _maybe_import_op_extensions(args: argparse.Namespace):
|
||||||
|
|
||||||
def main(args: argparse.Namespace):
|
def main(args: argparse.Namespace):
|
||||||
_maybe_import_op_extensions(args)
|
_maybe_import_op_extensions(args)
|
||||||
|
import torchvision
|
||||||
|
|
||||||
registry = Registry.load()
|
registry = Registry.load()
|
||||||
if args.debug_registry_dump:
|
if args.debug_registry_dump:
|
||||||
with open(args.debug_registry_dump, "w") as debug_registry_dump:
|
with open(args.debug_registry_dump, "w") as debug_registry_dump:
|
||||||
|
|
|
@ -2207,6 +2207,37 @@ f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_ve
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_roialign_avg
|
||||||
|
func.func @test_roialign_avg(%arg0: !torch.vtensor<[6,2,100,100],f32>, %arg1: !torch.vtensor<[30,4],f32>, %arg2: !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[Dim:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[Unsqueeze:.*]] = torch.aten.unsqueeze %arg2, %[[Dim]]
|
||||||
|
// CHECK: %[[cst6:.*]] = torch.constant.int 6
|
||||||
|
// CHECK: %[[Cast:.*]] = torch.aten.to.dtype %[[Unsqueeze]], %[[cst6]]
|
||||||
|
// CHECK: %[[List:.*]] = torch.prim.ListConstruct %[[Cast]], %arg1
|
||||||
|
// CHECK: %[[Cat:.*]] = torch.aten.cat %[[List]], %[[Dim]]
|
||||||
|
// CHECK: %[[Align:.*]] = torch.torchvision.roi_align %arg0, %[[Cat]]
|
||||||
|
%0 = torch.operator "onnx.RoiAlign"(%arg0, %arg1, %arg2) {torch.onnx.coordinate_transformation_mode = "output_half_pixel", torch.onnx.mode = "avg", torch.onnx.output_height = 5 : si64, torch.onnx.output_width = 5 : si64, torch.onnx.sampling_ratio = 0 : si64, torch.onnx.spatial_scale = 1.000000e+00 : f32} : (!torch.vtensor<[6,2,100,100],f32>, !torch.vtensor<[30,4],f32>, !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32>
|
||||||
|
return %0 : !torch.vtensor<[30,2,5,5],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_roialign_max
|
||||||
|
func.func @test_roialign_max(%arg0: !torch.vtensor<[6,2,100,100],f32>, %arg1: !torch.vtensor<[30,4],f32>, %arg2: !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[Dim:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[Unsqueeze:.*]] = torch.aten.unsqueeze %arg2, %[[Dim]]
|
||||||
|
// CHECK: %[[cst6:.*]] = torch.constant.int 6
|
||||||
|
// CHECK: %[[Cast:.*]] = torch.aten.to.dtype %[[Unsqueeze]], %[[cst6]]
|
||||||
|
// CHECK: %[[List:.*]] = torch.prim.ListConstruct %[[Cast]], %arg1
|
||||||
|
// CHECK: %[[Cat:.*]] = torch.aten.cat %[[List]], %[[Dim]]
|
||||||
|
// CHECK: %[[Pool:.*]], %[[Indices:.*]] = torch.torchvision.roi_pool %arg0, %[[Cat]]
|
||||||
|
// CHECK: return %[[Pool]]
|
||||||
|
%0 = torch.operator "onnx.RoiAlign"(%arg0, %arg1, %arg2) {torch.onnx.coordinate_transformation_mode = "half_pixel", torch.onnx.mode = "max", torch.onnx.output_height = 5 : si64, torch.onnx.output_width = 5 : si64, torch.onnx.sampling_ratio = 0 : si64, torch.onnx.spatial_scale = 1.000000e+00 : f32} : (!torch.vtensor<[6,2,100,100],f32>, !torch.vtensor<[30,4],f32>, !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32>
|
||||||
|
return %0 : !torch.vtensor<[30,2,5,5],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @test_spacetodepth_example
|
// CHECK-LABEL: @test_spacetodepth_example
|
||||||
func.func @test_spacetodepth_example(%arg0: !torch.vtensor<[1,1,4,6],f32>) -> !torch.vtensor<[1,4,2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_spacetodepth_example(%arg0: !torch.vtensor<[1,1,4,6],f32>) -> !torch.vtensor<[1,4,2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||||
|
|
Loading…
Reference in New Issue