diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index b4bd102f1..8f6788620 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1379,7 +1379,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "Split", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Value self; int64_t axis; - int64_t num_outputs; + int64_t numOutputs; if (binder.tensorOperand(self)) return rewriter.notifyMatchFailure( binder.op, "Not converting to AtenSplitTensorOp due to input " @@ -1387,49 +1387,65 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (binder.s64IntegerAttr(axis, "axis", 0)) return rewriter.notifyMatchFailure(binder.op, "Failed to get axis attribute"); - if (binder.s64IntegerAttr(num_outputs, "num_outputs", 0)) + if (binder.s64IntegerAttr(numOutputs, "num_outputs", 2)) return rewriter.notifyMatchFailure( binder.op, "Failed to get num_outputs attribute"); + auto loc = binder.getLoc(); auto result0Ty = binder.op->getResult(0).getType().cast(); + auto resultNTy = binder.op->getResults() + .back() + .getType() + .cast(); auto selfTy = self.getType().cast(); int64_t dim = axis; if (dim < 0) dim += selfTy.getSizes().size(); - // set intermediate shape to the shape of the first result - // if the results are of different shapes - // set the splitted axis to variable shape - llvm::SmallVector intermediateShape(result0Ty.getSizes()); - for (auto result : binder.op->getResultTypes()) { - int64_t d = cast(result).getSizes()[dim]; - intermediateShape[dim] = d == intermediateShape[dim] ? d : -1; + Value dimValue = rewriter.create( + loc, rewriter.getType(), + rewriter.getI64IntegerAttr(dim)); + + Value vNumOutputs = rewriter.create( + loc, rewriter.getType(), + rewriter.getI64IntegerAttr(numOutputs)); + + Value one = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value zero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + + Value vDimSize = rewriter.create( + loc, rewriter.getType(), self, dimValue); + + Value addNumOutputs = + rewriter.create(loc, vDimSize, vNumOutputs); + Value subOne = + rewriter.create(loc, addNumOutputs, one); + Value splitSize = + rewriter.create(loc, subOne, vNumOutputs); + + llvm::SmallVector outputs; + Value step = one; + Value start = zero; + + for (int i = 0; i < numOutputs - 1; ++i) { + Value end = + rewriter.create(loc, start, splitSize); + Value slice = rewriter.create( + loc, result0Ty, self, dimValue, start, end, step); + start = end; + outputs.push_back(slice); } - Value dimValue = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), dim)); + Value end = vDimSize; + Value lastSlice = rewriter.create( + loc, resultNTy, self, dimValue, start, end, step); + outputs.push_back(lastSlice); - Value splitSize = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), num_outputs)); - - // TODO: Attempting to use the shape expected by the ONNX mlir as ground - // truth. For now just use dynamic shapes. - auto resultOuterType = - Torch::ListType::get(rewriter.getType( - /*std::optional>=*/intermediateShape, - result0Ty.getOptionalDtype())); - Torch::AtenSplitTensorOp new_op = - rewriter.create( - binder.getLoc(), resultOuterType, self, splitSize, dimValue); - - // the onnx op is variadic with multiple results, but AtenSplitWithSizes - // outputs a list so we need to unpack the list - rewriter.replaceOpWithNewOp( - binder.op, binder.op->getResults().getType(), new_op.getResult()); + rewriter.replaceOp(binder.op, outputs); return success(); }); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 35327f367..ec4d3a8da 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2718,7 +2718,6 @@ ONNX_XFAIL_SET = { "ElementwiseToDtypeI64ToUI8Module_basic", "ElementwiseUnaryIntModule_basic", "ElementwiseUnsqueezeNegDimsModule_basic", - "GluStaticModule_basic", "GroupNormModule_basic", "MaskedFillTensorFloatValueModule_basic", "NativeDropoutTrainModule_basic", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 0fdecd684..47497d5ea 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -1288,12 +1288,20 @@ func.func @test_split_variable_parts_2d_opset18(%arg0: !torch.vtensor<[2,6],f32> // CHECK-LABEL: func.func @test_split_2d_uneven_split_opset18( // CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { -// CHECK: %[[AXIS:.*]] = torch.constant.int 1 -// CHECK: %[[SPLIT_SIZE:.*]] = torch.constant.int 3 -// CHECK: %[[SPLIT_RESULT:.*]] = torch.aten.split.Tensor %[[INPUT_TENSOR]], %[[SPLIT_SIZE]], %[[AXIS]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int -> !torch.list> -// CHECK: %[[UNPACKED_TENSORS:.*]]:3 = torch.prim.ListUnpack %[[SPLIT_RESULT]] : !torch.list> -> !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32> -// CHECK: return %[[UNPACKED_TENSORS]]#0, %[[UNPACKED_TENSORS]]#1, %[[UNPACKED_TENSORS]]#2 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32> -// CHECK: } +// CHECK-DAG: %[[DIM:.+]] = torch.constant.int 1 +// CHECK-DAG: %[[SPLITS:.+]] = torch.constant.int 3 +// CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1 +// CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0 +// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[DIM]] +// CHECK-DAG: %[[ADD:.+]] = torch.aten.add.int %[[SZ1]], %[[SPLITS]] +// CHECK-DAG: %[[SUB:.+]] = torch.aten.sub.int %[[ADD]], %[[ONE]] +// CHECK-DAG: %[[SLICESZ:.+]] = torch.aten.floordiv.int %[[SUB]], %[[SPLITS]] +// CHECK-DAG: %[[START1:.+]] = torch.aten.add.int %[[ZERO]], %[[SLICESZ]] : !torch.int, !torch.int -> !torch.int +// CHECK-DAG: %[[SLICE0:.+]] = torch.aten.slice.Tensor %arg0, %[[DIM]], %[[ZERO]], %[[START1]], %[[ONE]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> +// CHECK-DAG: %[[START2:.+]] = torch.aten.add.int %[[START1]], %[[SLICESZ]] : !torch.int, !torch.int -> !torch.int +// CHECK-DAG: %[[SLICE1:.+]] = torch.aten.slice.Tensor %arg0, %[[DIM]], %[[START1]], %[[START2]], %[[ONE]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> +// CHECK-DAG: %[[SLICE2:.+]] = torch.aten.slice.Tensor %arg0, %[[DIM]], %[[START2]], %[[SZ1]], %[[ONE]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2],f32> +// CHECK: return %[[SLICE0]], %[[SLICE1]], %[[SLICE2]] func.func @test_split_2d_uneven_split_opset18(%arg0: !torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %0:3 = torch.operator "onnx.Split"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.num_outputs = 3 : si64} : (!torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>) return %0#0, %0#1, %0#2 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>