diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 785c631c1..c976b4984 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -1495,23 +1495,34 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
   patterns.onOp(
       "Expand", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         // uses ideas and code from onnx.Reshape
+        auto loc = binder.getLoc();
         Torch::ValueTensorType resultType;
         Value data, shape;
         if (binder.tensorOperands(data, shape) ||
             binder.tensorResultType(resultType))
           return failure();
-        Torch::BaseTensorType shapeType =
-            shape.getType().cast<Torch::BaseTensorType>();
+
+        auto dataType = cast<Torch::BaseTensorType>(data.getType());
+        auto shapeType = cast<Torch::BaseTensorType>(shape.getType());
+        if (!dataType.hasSizes() || !shapeType.hasSizes())
+          return failure();
+
+        auto shapeSizes = shapeType.getSizes();
+        int64_t dataRank = dataType.getSizes().size();
+        int64_t shapeRank = shapeSizes.size();
+        if (shapeRank != 1 || shapeSizes[0] == Torch::kUnknownSize)
+          return failure();
+
+        auto rankDifference = dataRank - shapeSizes[0];
+
        SmallVector<int64_t> selectSizes;
         Type selectResultType = shapeType.getWithSizesAndDtype(
             llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype());
         // Variable to store 1-D onnx shape tensor, shapeSizes[0] has the
         // dimension size
-        auto shapeSizes =
-            dyn_cast<Torch::ValueTensorType>(shape.getType()).getSizes();
         // A constant zero value
         Value zero = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(0));
+            loc, rewriter.getI64IntegerAttr(0));
         // Variable to store pytorch int list of shape (dimension)
         SmallVector<Value> dimList;
 
@@ -1520,12 +1531,21 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         // takes list of int
         for (int i = 0; i < shapeSizes[0]; i++) {
           Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(),
+              loc, rewriter.getType<Torch::IntType>(),
               rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
           Value extract = rewriter.create<Torch::AtenSelectIntOp>(
-              binder.getLoc(), selectResultType, shape, zero, selectIndex);
+              loc, selectResultType, shape, zero, selectIndex);
           Value dim = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
+              loc, rewriter.getType<Torch::IntType>(), extract);
+
+          if (i + rankDifference >= 0) {
+            Value iv =
+                rewriter.create<Torch::ConstantIntOp>(loc, i + rankDifference);
+            auto sz = rewriter.create<Torch::AtenSizeIntOp>(
+                loc, rewriter.getType<Torch::IntType>(), data, iv);
+            dim = rewriter.create<Torch::PrimMaxIntOp>(loc, dim, sz);
+          }
+
           dimList.push_back(dim);
         }
         Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 695b51c18..dd4976018 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1507,8 +1507,6 @@ ONNX_XFAIL_SET = {
     "ArangeStartOutDtypeModule_basic",
     "ArangeStartOutViewModule_basic",
     "BroadcastDynamicDimModule_basic",
-    "BroadcastToModule_basic",
-    "ExpandModule_basic",
     "MoveDimIntNegativeIndexModule_basic",
     "ViewSizeFromOtherTensor_basic",
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
index 2c013553b..1e816a38e 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -1164,15 +1164,21 @@ func.func @test_exp(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
 // CHECK-LABEL: @test_expand_dim2_shape2
 func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[INT0:.+]] = torch.constant.int 0
-  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
-  // CHECK: torch.aten.item %0 : !torch.vtensor<[],si32> -> !torch.int
-  // CHECK: %[[INT1:.+]] = torch.constant.int 1
-  // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
-  // CHECK: torch.aten.item %2 : !torch.vtensor<[],si32> -> !torch.int
-  // CHECK: torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: torch.aten.broadcast_to %arg0, %4 : !torch.vtensor<[1,4],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32>
+  // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
+  // CHECK-DAG: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[SZ0:.+]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
+  // CHECK-DAG: %[[MX0:.+]] = torch.prim.max.int %[[ITEM0]], %[[SZ0]] : !torch.int, !torch.int -> !torch.int
+  // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
+  // CHECK-DAG: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
+  // CHECK-DAG: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MX0]], %[[MX1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: torch.aten.broadcast_to %arg0, %[[LIST]] : !torch.vtensor<[1,4],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32>
   %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32>
   return %0 : !torch.vtensor<[3,4],f32>
 }
 
@@ -1181,47 +1187,31 @@ func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !tor
 // CHECK-LABEL: @test_expand_dim2_shape3
 func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[INT0:.+]] = torch.constant.int 0
-  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
-  // CHECK: torch.aten.item %0 : !torch.vtensor<[],si64> -> !torch.int
-  // CHECK: %[[INT1:.+]] = torch.constant.int 1
-  // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
-  // CHECK: torch.aten.item %2 : !torch.vtensor<[],si64> -> !torch.int
-  // CHECK: %[[INT2:.+]] = torch.constant.int 2
-  // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
-  // CHECK: torch.aten.item %4 : !torch.vtensor<[],si64> -> !torch.int
-  // CHECK: torch.prim.ListConstruct %1, %3, %5 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: torch.aten.broadcast_to %arg0, %6 : !torch.vtensor<[3,1],f32>, !torch.list<int> -> !torch.vtensor<[2,3,6],f32>
+  // CHECK: %[[I0:.+]] = torch.constant.int 0
+  // CHECK-NEXT: %[[I0_0:.+]] = torch.constant.int 0
+  // CHECK-NEXT: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I0_0]]
+  // CHECK-NEXT: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]]
+  // CHECK-NEXT: %[[I1:.+]] = torch.constant.int 1
+  // CHECK-NEXT: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I1]]
+  // CHECK-NEXT: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]]
+  // CHECK-NEXT: %[[D1:.+]] = torch.constant.int 0
+  // CHECK-NEXT: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[D1]]
+  // CHECK-NEXT: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int
+  // CHECK-NEXT: %[[I2:.+]] = torch.constant.int 2
+  // CHECK-NEXT: %[[SEL2:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I2]]
+  // CHECK-NEXT: %[[ITEM2:.+]] = torch.aten.item %[[SEL2]]
+  // CHECK-NEXT: %[[D2:.+]] = torch.constant.int 1
+  // CHECK-NEXT: %[[SZ2:.+]] = torch.aten.size.int %arg0, %[[D2]]
+  // CHECK-NEXT: %[[MX2:.+]] = torch.prim.max.int %[[ITEM2]], %[[SZ2]]
+  // CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[MX1]], %[[MX2]]
+  // CHECK-NEXT: %[[EXPAND:.+]] = torch.aten.broadcast_to %arg0, %[[LIST]]
+  // CHECK: return %[[EXPAND]]
   %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32>
   return %0 : !torch.vtensor<[2,3,6],f32>
 }
 
 // -----
 
-// CHECK-LABEL: @test_expand_dim3_shape4
-func.func @test_expand_dim3_shape4(%arg0: !torch.vtensor<[1,3,1],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[INT0:.+]] = torch.constant.int 0
-  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
-  // CHECK: torch.aten.item %0 : !torch.vtensor<[],si64> -> !torch.int
-  // CHECK: %[[INT1:.+]] = torch.constant.int 1
-  // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
-  // CHECK: torch.aten.item %2 : !torch.vtensor<[],si64> -> !torch.int
-  // CHECK: %[[INT2:.+]] = torch.constant.int 2
-  // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
-  // CHECK: torch.aten.item %4 : !torch.vtensor<[],si64> -> !torch.int
-  // CHECK: %[[INT3:.+]] = torch.constant.int 3
-  // CHECK: torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
-  // CHECK: torch.aten.item %6 : !torch.vtensor<[],si64> -> !torch.int
-  // CHECK: torch.prim.ListConstruct %1, %3, %5, %7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %9 = torch.aten.broadcast_to %arg0, %8 : !torch.vtensor<[1,3,1],f32>, !torch.list<int> -> !torch.vtensor<[3,3,3,3],f32>
-  %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,3,1],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32>
-  return %0 : !torch.vtensor<[3,3,3,3],f32>
-}
-
-// -----
-
 // CHECK-LABEL: @test_dropout
 func.func @test_dropout(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3],f32>