mirror of https://github.com/llvm/torch-mlir
[MLIR][ONNX] Add OnnxToTorch support for BatchNormalization and Concat op.
This commit adds the OnnxToTorch support for BatchNormalization and Concat op. Signed-Off By: vivekkhandelwal1424@gmail.compull/2691/head snapshot-20231222.1060
parent
85b86b36a2
commit
9a72c6584e
|
@ -165,6 +165,43 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
binder.op, resultType, operand);
|
binder.op, resultType, operand);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
patterns.onOp("BatchNormalization", 15,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value input, weight, bias, runningMean, runningVar;
|
||||||
|
bool training;
|
||||||
|
float momentum, eps;
|
||||||
|
if (binder.s64BoolAttr(training, "training_mode", 0))
|
||||||
|
return failure();
|
||||||
|
if (training) {
|
||||||
|
// TODO: Add support for training = true
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "unsupported conversion: training = true");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (binder.tensorOperandAtIndex(input, 0) ||
|
||||||
|
binder.tensorOperandAtIndex(weight, 1) ||
|
||||||
|
binder.tensorOperandAtIndex(bias, 2) ||
|
||||||
|
binder.tensorOperandAtIndex(runningMean, 3) ||
|
||||||
|
binder.tensorOperandAtIndex(runningVar, 4) ||
|
||||||
|
binder.f32FloatAttr(momentum, "momentum", 0.9) ||
|
||||||
|
binder.f32FloatAttr(eps, "epsilon", 1e-05) ||
|
||||||
|
binder.tensorResultType(resultType))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
|
||||||
|
binder.getLoc(), false);
|
||||||
|
Value cstMomentum = rewriter.create<Torch::ConstantFloatOp>(
|
||||||
|
binder.getLoc(), rewriter.getF64FloatAttr(momentum));
|
||||||
|
Value cstEps = rewriter.create<Torch::ConstantFloatOp>(
|
||||||
|
binder.getLoc(), rewriter.getF64FloatAttr(eps));
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenBatchNormOp>(
|
||||||
|
binder.op, resultType, input, weight, bias, runningMean,
|
||||||
|
runningVar, /*training=*/cstFalse, cstMomentum, cstEps,
|
||||||
|
/*cudnn_enabled=*/cstFalse);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
patterns.onOp(
|
patterns.onOp(
|
||||||
"AveragePool", 19,
|
"AveragePool", 19,
|
||||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
@ -426,6 +463,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
}
|
}
|
||||||
return failure();
|
return failure();
|
||||||
});
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"Concat", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
SmallVector<Value> tensors;
|
||||||
|
int64_t dim;
|
||||||
|
if (binder.tensorOperands(tensors, binder.op->getNumOperands()) ||
|
||||||
|
binder.s64IntegerAttr(dim, "axis", 0) ||
|
||||||
|
binder.tensorResultType(resultType))
|
||||||
|
return failure();
|
||||||
|
Type listElemType =
|
||||||
|
tensors[0]
|
||||||
|
.getType()
|
||||||
|
.cast<Torch::BaseTensorType>()
|
||||||
|
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
|
||||||
|
/*optionalDtype=*/nullptr);
|
||||||
|
Type listType = Torch::ListType::get(listElemType);
|
||||||
|
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
binder.op->getLoc(), listType, tensors);
|
||||||
|
Value cstDim = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(dim));
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
|
||||||
|
tensorList, cstDim);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
patterns.onOp(
|
patterns.onOp(
|
||||||
"Conv", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
"Conv", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
std::string autoPad;
|
std::string autoPad;
|
||||||
|
|
|
@ -592,3 +592,131 @@ func.func @test_convtranspose(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torc
|
||||||
%0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.pads = [1 : si64, 2 : si64, 1 : si64, 2 : si64], torch.onnx.strides = [3 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,7,3],f32>
|
%0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.pads = [1 : si64, 2 : si64, 1 : si64, 2 : si64], torch.onnx.strides = [3 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,7,3],f32>
|
||||||
return %0 : !torch.vtensor<[1,2,7,3],f32>
|
return %0 : !torch.vtensor<[1,2,7,3],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_batchnorm_epsilon
|
||||||
|
func.func @test_batchnorm_epsilon(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>, %arg4: !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[MOMENTUM:.*]] = torch.constant.float 0.89999997615814208
|
||||||
|
// CHECK: %[[EPS:.*]] = torch.constant.float 0.0099999997764825821
|
||||||
|
// CHECK: torch.aten.batch_norm %arg0, %arg1, %arg2, %arg3, %arg4, %[[FALSE]], %[[MOMENTUM]], %[[EPS]], %[[FALSE]] : !torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,4,5],f32>
|
||||||
|
%0 = torch.operator "onnx.BatchNormalization"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.epsilon = 0.00999999977 : f32} : (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32>
|
||||||
|
return %0 : !torch.vtensor<[2,3,4,5],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_batchnorm_example
|
||||||
|
func.func @test_batchnorm_example(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>, %arg4: !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[MOMENTUM:.*]] = torch.constant.float 0.89999997615814208
|
||||||
|
// CHECK: %[[EPS:.*]] = torch.constant.float 9.9999997473787516E-6
|
||||||
|
// CHECK: torch.aten.batch_norm %arg0, %arg1, %arg2, %arg3, %arg4, %[[FALSE]], %[[MOMENTUM]], %[[EPS]], %[[FALSE]] : !torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,4,5],f32>
|
||||||
|
%0 = torch.operator "onnx.BatchNormalization"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32>
|
||||||
|
return %0 : !torch.vtensor<[2,3,4,5],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_concat_1d_axis_0
|
||||||
|
func.func @test_concat_1d_axis_0(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.list<vtensor>
|
||||||
|
// CHECK: %[[DIM:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4],f32>
|
||||||
|
%0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32>
|
||||||
|
return %0 : !torch.vtensor<[4],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_concat_1d_axis_negative_1
|
||||||
|
func.func @test_concat_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.list<vtensor>
|
||||||
|
// CHECK: %[[DIM:.*]] = torch.constant.int -1
|
||||||
|
// CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4],f32>
|
||||||
|
%0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32>
|
||||||
|
return %0 : !torch.vtensor<[4],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_concat_2d_axis_0
|
||||||
|
func.func @test_concat_2d_axis_0(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list<vtensor>
|
||||||
|
// CHECK: %[[DIM:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,2],f32>
|
||||||
|
%0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[4,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[4,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_concat_2d_axis_1
|
||||||
|
func.func @test_concat_2d_axis_1(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list<vtensor>
|
||||||
|
// CHECK: %[[DIM:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,4],f32>
|
||||||
|
%0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,4],f32>
|
||||||
|
return %0 : !torch.vtensor<[2,4],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_concat_2d_axis_negative_1
|
||||||
|
func.func @test_concat_2d_axis_negative_1(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list<vtensor>
|
||||||
|
// CHECK: %[[DIM:.*]] = torch.constant.int -1
|
||||||
|
// CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,4],f32>
|
||||||
|
%0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,4],f32>
|
||||||
|
return %0 : !torch.vtensor<[2,4],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_concat_2d_axis_negative_2
|
||||||
|
func.func @test_concat_2d_axis_negative_2(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list<vtensor>
|
||||||
|
// CHECK: %[[DIM:.*]] = torch.constant.int -2
|
||||||
|
// CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,2],f32>
|
||||||
|
%0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[4,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[4,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_concat_3d_axis_0
|
||||||
|
func.func @test_concat_3d_axis_0(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list<vtensor>
|
||||||
|
// CHECK: %[[DIM:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,2,2],f32>
|
||||||
|
%0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[4,2,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_concat_3d_axis_1
|
||||||
|
func.func @test_concat_3d_axis_1(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list<vtensor>
|
||||||
|
// CHECK: %[[DIM:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,4,2],f32>
|
||||||
|
%0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,4,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[2,4,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_concat_3d_axis_2
|
||||||
|
func.func @test_concat_3d_axis_2(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list<vtensor>
|
||||||
|
// CHECK: %[[DIM:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,2,4],f32>
|
||||||
|
%0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,4],f32>
|
||||||
|
return %0 : !torch.vtensor<[2,2,4],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_concat_3d_axis_negative_1
|
||||||
|
func.func @test_concat_3d_axis_negative_1(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list<vtensor>
|
||||||
|
// CHECK: %[[DIM:.*]] = torch.constant.int -1
|
||||||
|
// CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,2,4],f32>
|
||||||
|
%0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,4],f32>
|
||||||
|
return %0 : !torch.vtensor<[2,2,4],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_concat_3d_axis_negative_2
|
||||||
|
func.func @test_concat_3d_axis_negative_2(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list<vtensor>
|
||||||
|
// CHECK: %[[DIM:.*]] = torch.constant.int -2
|
||||||
|
// CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,4,2],f32>
|
||||||
|
%0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,4,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[2,4,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_concat_3d_axis_negative_3
|
||||||
|
func.func @test_concat_3d_axis_negative_3(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list<vtensor>
|
||||||
|
// CHECK: %[[DIM:.*]] = torch.constant.int -3
|
||||||
|
// CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,2,2],f32>
|
||||||
|
%0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -3 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[4,2,2],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue