diff --git a/externals/llvm-project b/externals/llvm-project
index 70e227a40..1e5f29af8 160000
--- a/externals/llvm-project
+++ b/externals/llvm-project
@@ -1 +1 @@
-Subproject commit 70e227a404e51f9248c7ad5d79953805b2afacb4
+Subproject commit 1e5f29af81a5f6fda308074f6345b9fba4faa71c
diff --git a/externals/stablehlo b/externals/stablehlo
index ab92adeda..c44d9af8d 160000
--- a/externals/stablehlo
+++ b/externals/stablehlo
@@ -1 +1 @@
-Subproject commit ab92adeda9119a6c3914cd42367b0a2b70765e91
+Subproject commit c44d9af8d4879adccf1054cb61a53377ae5898cb
diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
index 5854b1b7d..feae36f4f 100644
--- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
+++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
@@ -101,6 +101,8 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
       rewriter.getContext(),
       /*offsetDims=*/offsetDims,
       /*collapsedSliceDims=*/collapsedSliceDims,
+      /*operandBatchingDims=*/{},
+      /*startIndicesBatchingDims=*/{},
       /*startIndexMap=*/startIndexMap,
       /*indexVecDim=*/indexVecDim);
 
@@ -584,6 +586,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
       rewriter.getContext(),
       /*offsetDims=*/{},
       /*collapsedSliceDims=*/collapsedDims,
+      /*operandBatchingDims=*/{},
+      /*startIndicesBatchingDims=*/{},
       /*startIndexMap=*/startIndexMap,
       /*indexVecDim=*/indexVecDim);
 
@@ -744,6 +748,8 @@ public:
           rewriter.getContext(),
           /*updateWindowDims=*/{},
           /*insertedWindowDims=*/insertedWindowDims,
+          /*inputBatchingDims=*/{},
+          /*scatterIndicesBatchingDims=*/{},
           /*scatterDimsToOperandDim=*/scatterDimOperandDimMap,
           /*indexVectorDim=*/indexVecDim);
 
@@ -826,6 +832,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
       rewriter.getContext(),
       /*offsetDims=*/offsetDims,
       /*collapsedSliceDims=*/collapsedDims,
+      /*operandBatchingDims=*/{},
+      /*startIndicesBatchingDims=*/{},
       /*startIndexMap=*/startIndexMap,
       /*indexVecDim=*/indexVecDim);
 
@@ -900,6 +908,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
       rewriter.getContext(),
       /*updateWindowDims=*/updateWindowDims,
       /*insertedWindowDims=*/insertedWindowDims,
+      /*inputBatchingDims=*/{},
+      /*scatterIndicesBatchingDims=*/{},
       /*scatterDimsToOperandDim=*/scatterDimOperandDimMap,
       /*indexVectorDim=*/indexVecDim);
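The paired `{}` arguments added in each hunk track an upstream StableHLO change: the bumped revision extends `stablehlo.gather` with operand/start-indices batching dimensions (and `stablehlo.scatter` with the input/scatter-indices analogues), so the dimension-numbers attribute builders take two extra dimension lists. Passing empty lists keeps the previous non-batched semantics. A minimal sketch of the post-bump call shape follows; the `makeNonBatchedGatherDims` helper is illustrative only and is not part of the patch:

```cpp
#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/MLIRContext.h"
#include "stablehlo/dialect/StablehloOps.h"

// Illustrative helper: builds a gather dimension-numbers attribute with the
// two new batching-dimension lists left empty, mirroring how every call site
// in GatherScatter.cpp is updated by this patch.
static mlir::stablehlo::GatherDimensionNumbersAttr
makeNonBatchedGatherDims(mlir::MLIRContext *ctx,
                         llvm::ArrayRef<int64_t> offsetDims,
                         llvm::ArrayRef<int64_t> collapsedSliceDims,
                         llvm::ArrayRef<int64_t> startIndexMap,
                         int64_t indexVectorDim) {
  return mlir::stablehlo::GatherDimensionNumbersAttr::get(
      ctx, offsetDims, collapsedSliceDims,
      /*operandBatchingDims=*/{},        // new parameter in this bump
      /*startIndicesBatchingDims=*/{},   // new parameter in this bump
      startIndexMap, indexVectorDim);
}
```

`ScatterDimensionNumbersAttr::get` gains the analogous `inputBatchingDims`/`scatterIndicesBatchingDims` parameters, as the scatter hunks above show.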
diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp
index 04952d843..5bb83d098 100644
--- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp
+++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp
@@ -172,6 +172,7 @@ public:
     if (!rankType)
       return op.emitError("Only ranked tensor types are currently supported");
 
+    // Collect the target dim sizes as Values.
     SmallVector<Value, 4> dimSizes;
     if (!getAtenViewOpSizes(op, adaptor, rewriter, dimSizes)) {
       return op.emitError("Dims size must be a list of Scalar");
@@ -187,6 +188,20 @@ public:
       return success();
     }
 
+    // Collect the indices of constant dim sizes that equal -1.
+    SmallVector<size_t> negOneIndex;
+    for (size_t i = 0; i < dimSizes.size(); i++) {
+      int64_t dim;
+      if (matchPattern(dimSizes[i], m_TorchConstantInt(&dim))) {
+        if (dim == -1) {
+          negOneIndex.push_back(i);
+        }
+      }
+    }
+    if (negOneIndex.size() > 1) {
+      return op.emitError("Only support at most one -1 in view target dims");
+    }
+
     std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
       dSize = rewriter.create<ToI64Op>(loc, dSize).getResult();
       return dSize;
@@ -194,16 +209,29 @@ public:
 
     Value numel = rewriter.create<shape::NumElementsOp>(
         loc, rewriter.create<shape::ShapeOfOp>(loc, adaptor.getSelf()));
+    numel =
+        rewriter.create<arith::IndexCastOp>(loc, rewriter.getI64Type(), numel);
+
+    // Note: assumes the -1 does not arise from a dynamic value.
+    if (negOneIndex.size() == 1) {
+      size_t index = negOneIndex[0];
+      Value realDim = numel;
+      for (size_t i = 0; i < dimSizes.size(); i++) {
+        if (i != index) {
+          realDim = rewriter.create<arith::DivUIOp>(loc, realDim, dimSizes[i]);
+        }
+      }
+      // Replace the -1 placeholder with the computed dim size.
+      dimSizes[index] = realDim;
+    }
 
     Value stablehloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
-    Value computedShape = rewriter.create<stablehlo::ComputeReshapeShapeOp>(
-        loc, stablehloShape.getType(), numel, stablehloShape);
     rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
         op,
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
            op.getType()),
-        adaptor.getSelf(), computedShape);
+        adaptor.getSelf(), stablehloShape);
     return success();
   }
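With `stablehlo.compute_reshape_shape` gone from the bumped StableHLO, the view lowering now materializes the `-1` target dimension itself: `numel` is index-cast to i64 and divided by each known target dim with `arith.divui`, and the final quotient fills the `-1` slot. Below is a standalone sketch of that computation in plain C++ (IR emission elided; `inferViewShape` is illustrative only and assumes, like the patch, that the `-1` is a compile-time constant and that the known dims divide `numel` evenly):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the shape arithmetic the lowering emits: one division per known
// dim, with the running quotient finally replacing the -1 placeholder.
std::vector<int64_t> inferViewShape(int64_t numel, std::vector<int64_t> dims) {
  int64_t negOnePos = -1;
  for (size_t i = 0; i < dims.size(); ++i) {
    if (dims[i] == -1) {
      assert(negOnePos == -1 && "at most one -1 is supported");
      negOnePos = static_cast<int64_t>(i);
    }
  }
  if (negOnePos == -1)
    return dims; // nothing to infer
  int64_t realDim = numel;
  for (size_t i = 0; i < dims.size(); ++i)
    if (static_cast<int64_t>(i) != negOnePos)
      realDim /= dims[i]; // corresponds to one arith.divui in the emitted IR
  dims[negOnePos] = realDim;
  return dims;
}
```

For example, `inferViewShape(896, {-1, 224})` yields `{4, 224}`, matching the single `arith.divui` in the updated `torch.aten.view$basic` test below; the in-tree comment flags the corresponding assumption that the `-1` never comes from a runtime value.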
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index ab162ab94..c1a9bc26a 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1449,6 +1449,13 @@ STABLEHLO_CRASHING_SET = {
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
+    "ElementwiseDivTensorFloatModule_basic",
+    "ElementwiseMulTensorFloatModule_basic",
+    "ElementwiseWhereScalarSelfStaticModule_basic",
+    "GroupNormModule_basic",
+    "GroupNormNoWeightAndBiasModule_basic",
+    "NativeGroupNormModule_basic",
     "AtenDotModule_basic",
     "ElementwiseFloatTensorGtIntScalarModule_basic",
     "ElementwiseLogSigmoidModule_basic",
@@ -1946,7 +1953,6 @@ MAKE_FX_TOSA_PASS_SET = (
     "Conv2dNoPaddingModule_basic",
     "Conv2dWithPaddingDilationStrideModule_basic",
     "Conv2dWithPaddingModule_basic",
-    "AtenInstanceNormModule_basic",
     # failed to legalize operation 'torch.operator'
     "ElementwisePreluModule_basic",
     "ElementwisePreluStaticModule_basic",
diff --git a/test/Conversion/TorchToStablehlo/view_like.mlir b/test/Conversion/TorchToStablehlo/view_like.mlir
index ab54d2764..3b0169036 100644
--- a/test/Conversion/TorchToStablehlo/view_like.mlir
+++ b/test/Conversion/TorchToStablehlo/view_like.mlir
@@ -310,11 +310,12 @@ func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,2
 // CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT224]]
 // CHECK: %[[T4:.*]] = shape.shape_of %[[T0]] : tensor<?x?x?x?xf32> -> tensor<4xindex>
 // CHECK: %[[T5:.*]] = shape.num_elements %[[T4]] : tensor<4xindex> -> index
-// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]] : tensor<2xi64>
-// CHECK: %[[T6:.*]] = stablehlo.compute_reshape_shape %[[T5]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64>
-// CHECK: %[[T7:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T6]] : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>
-// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32>
-// CHECK: return %[[T8]] : !torch.vtensor<[?,224],f32>
+// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64
+// CHECK: %[[T7:.*]] = arith.divui %[[T6]], %[[T3]] : i64
+// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T7]], %[[T3]] : tensor<2xi64>
+// CHECK: %[[T8:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>
+// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32>
+// CHECK: return %[[T9]] : !torch.vtensor<[?,224],f32>
 func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> {
   %int-1 = torch.constant.int -1
   %int224 = torch.constant.int 224
@@ -339,11 +340,14 @@ func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
 // CHECK: %[[T5:.*]] = torch_c.to_i64 %[[INT64]]
 // CHECK: %[[T6:.*]] = shape.shape_of %[[T0]] : tensor<?x?x?x?x?xf32> -> tensor<5xindex>
 // CHECK: %[[T7:.*]] = shape.num_elements %[[T6]] : tensor<5xindex> -> index
-// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64>
-// CHECK: %[[T8:.*]] = stablehlo.compute_reshape_shape %[[T7]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64>
-// CHECK: %[[T9:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T8]] : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32>
-// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32>
-// CHECK: return %[[T10]] : !torch.vtensor<[?,120,4,64],f32>
+// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64
+// CHECK: %[[T9:.*]] = arith.divui %[[T8]], %[[T3]] : i64
+// CHECK: %[[T10:.*]] = arith.divui %[[T9]], %[[T4]] : i64
+// CHECK: %[[T11:.*]] = arith.divui %[[T10]], %[[T5]] : i64
+// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T11]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64>
+// CHECK: %[[T12:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32>
+// CHECK: %[[T13:.*]] = torch_c.from_builtin_tensor %[[T12]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32>
+// CHECK: return %[[T13]] : !torch.vtensor<[?,120,4,64],f32>
 func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> {
   %int-1 = torch.constant.int -1
   %int120 = torch.constant.int 120