From 8a17c98b74b53ae71e19a9c7f7451af62dc339d9 Mon Sep 17 00:00:00 2001
From: Sambhav Jain
Date: Wed, 31 Jan 2024 14:21:17 -0800
Subject: [PATCH] Bump stablehlo to
 openxla/stablehlo@fd52182f76cadb82f2064fe5fc49a4fb4347a826 (#2821)

With the recent LLVM integrate and changes from
https://github.com/llvm/llvm-project/pull/78260, we hit this build error in
Stablehlo (which is quite old):

```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
    rewriter.startRootUpdate(op);
    ~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
      rewriter.finalizeRootUpdate(op);
      ~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
      rewriter.cancelRootUpdate(op);
      ~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
    rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
    ~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```

I'm still puzzled as to how this didn't fail in the CMake merge-gating CI (do
we not build/test Stablehlo there?). In any case, bumping our submodule to
https://github.com/openxla/stablehlo/pull/1918 fixes it.

It does expose a newly failing lit test in TorchToStablehlo, though, which I
have looped the stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)):

```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test

...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference
%0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
       ^
LLVM ERROR: Failed to infer result type(s).
```
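For context, the build break comes from the rewriter API rename in the LLVM
change linked above, which maps the old `PatternRewriter` "root update" names
onto "op modification" names. A minimal sketch of the correspondence
(illustrative only, not part of this patch):

```
// Sketch of the rename from llvm/llvm-project#78260
// ("root update" -> "op modification"); `op` is any mlir::Operation *.
#include "mlir/IR/PatternMatch.h"

void exampleInPlaceUpdate(mlir::PatternRewriter &rewriter,
                          mlir::Operation *op) {
  // rewriter.startRootUpdate(op);           // old
  rewriter.startOpModification(op);          // new
  // rewriter.finalizeRootUpdate(op);        // old
  rewriter.finalizeOpModification(op);       // new
  // rewriter.updateRootInPlace(op, [] {});  // old
  rewriter.modifyOpInPlace(op, [] {});       // new
  // (cancelRootUpdate likewise became cancelOpModification.)
}
```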
Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228
---
 externals/stablehlo                           |  2 +-
 lib/Conversion/TorchToStablehlo/Basic.cpp     | 18 ++---
 .../TorchToStablehlo/GatherScatter.cpp        | 10 ++-
 lib/Conversion/TorchToStablehlo/Linear.cpp    | 34 +++------
 lib/Conversion/TorchToStablehlo/Pooling.cpp   | 73 +++++--------------
 lib/Conversion/TorchToStablehlo/Reduction.cpp | 14 ++--
 .../StablehloLegalizeUtils.cpp                |  7 +-
 test/Conversion/TorchToStablehlo/pooling.mlir | 12 +--
 8 files changed, 61 insertions(+), 109 deletions(-)

diff --git a/externals/stablehlo b/externals/stablehlo
index ab709fe48..fd52182f7 160000
--- a/externals/stablehlo
+++ b/externals/stablehlo
@@ -1 +1 @@
-Subproject commit ab709fe48de88c67717abfbd7ef17425eb95ddaf
+Subproject commit fd52182f76cadb82f2064fe5fc49a4fb4347a826
diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp
index 33db9ac9e..bee6c529b 100644
--- a/lib/Conversion/TorchToStablehlo/Basic.cpp
+++ b/lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -377,12 +377,12 @@ public:
     if (!skipMultiplyAlpha(op.getAlpha())) {
       Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
                                                  adaptor.getAlpha(), outElemTy);
-      DenseIntElementsAttr bcastDimensions;
+      DenseI64ArrayAttr bcastDimensions;
       rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
                                                   bcastDimensions);
     }
-    DenseIntElementsAttr bcastDimensions;
+    DenseI64ArrayAttr bcastDimensions;
    rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
                                         bcastDimensions);
     return success();
@@ -424,7 +424,7 @@ public:
       rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
                                          outElemTy);
     }
-    DenseIntElementsAttr bcastDimensions;
+    DenseI64ArrayAttr bcastDimensions;
     lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
     rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
     auto loc = op.getLoc();
@@ -542,7 +542,7 @@ public:
     } else {
       return op.emitError("operator haven't been supported");
     }
-    DenseIntElementsAttr bcastDimensions;
+    DenseI64ArrayAttr bcastDimensions;
     rewriter.replaceOpWithNewOp<chlo::BroadcastCompareOp>(
         op, outType, lhs, rhs, bcastDimensions, compareDirectionAttr,
         compareTypeAttr);
@@ -570,7 +570,7 @@ public:
     Value rhs =
         hlo::promoteType(rewriter, op.getLoc(), adaptor.getOther(), outType);
-    DenseIntElementsAttr bcastDimensions;
+    DenseI64ArrayAttr bcastDimensions;
     rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
                                          bcastDimensions);
     return success();
@@ -757,7 +757,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
         llvm::to_vector<4>(llvm::seq<int64_t>(leadingRank, totalRank));
     rewriter.replaceOpWithNewOp<stablehlo::DynamicBroadcastInDimOp>(
         op, outType, self, bcastShapeTensor,
-        rewriter.getI64TensorAttr(dimensionNumbers));
+        rewriter.getDenseI64ArrayAttr(dimensionNumbers));
   }
   return success();
 }
@@ -887,7 +887,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   if (!rhsType) {
     rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
   }
-  DenseIntElementsAttr bcastDimensions;
+  DenseI64ArrayAttr bcastDimensions;
   lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
   rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
   auto loc = op.getLoc();
@@ -1478,7 +1478,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   Value window =
       rewriter.create<stablehlo::DynamicIotaOp>(loc, outType, resultLength, 0);
-  DenseIntElementsAttr broadcastDimensions;
+  DenseI64ArrayAttr broadcastDimensions;
   Value mulOut = rewriter.create<chlo::BroadcastMulOp>(loc, window, step,
                                                        broadcastDimensions);
   rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(op, mulOut, start,
@@ -1721,7 +1721,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
       rewriter.create<shape::ShapeOfOp>(op->getLoc(), adaptor.getSelf());
   Value bcastScalar = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
       op->getLoc(), outType, scalarTensor, shapeTensor,
-      rewriter.getI64TensorAttr({}));
+      rewriter.getDenseI64ArrayAttr({}));
   rewriter.replaceOp(op, bcastScalar);
   return success();
 }
diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
index d2b0450cd..53c418da4 100644
--- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
+++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
@@ -334,7 +334,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     return failure();
   auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op.getLoc(), gatherOutput, initValue, rewriter.getI64TensorAttr({0}));
+      op.getLoc(), gatherOutput, initValue, rewriter.getDenseI64ArrayAttr({0}),
+      elementTy);
   Region &region = stablehloReduceOp.getBody();
   Block &block = region.emplaceBlock();
@@ -510,7 +511,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
       op, input, gatherIndicies, dimsAttr,
-      rewriter.getI64TensorAttr(sliceSizes));
+      rewriter.getDenseI64ArrayAttr(sliceSizes));
   return success();
 }
@@ -666,7 +667,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
       /*indexVectorDim=*/indexVecDim);
   auto stablehloScatterOp = rewriter.create<stablehlo::ScatterOp>(
-      loc, input, scatterIndicies, src, scatterDimensionNumbers, false, false);
+      loc, inputType, input, scatterIndicies, src, scatterDimensionNumbers,
+      false, false);
   // config update computation function: just return the element from src.
   Block &block = stablehloScatterOp.getUpdateComputation().emplaceBlock();
@@ -833,7 +835,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
       op, resultType, input, finalIndexTensor, dimsAttr,
-      rewriter.getI64TensorAttr(sliceSizes));
+      rewriter.getDenseI64ArrayAttr(sliceSizes));
   return success();
 }
diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp
index df9231782..b1749ee1c 100644
--- a/lib/Conversion/TorchToStablehlo/Linear.cpp
+++ b/lib/Conversion/TorchToStablehlo/Linear.cpp
@@ -39,10 +39,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
   RankedTensorType outTy =
       RankedTensorType::get(shape, tensorTy.getElementType());
-  RankedTensorType attrTy =
-      RankedTensorType::get({static_cast<int64_t>(broadcastDims.size())},
-                            rewriter.getIntegerType(64));
-  auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims);
+  auto broadcastAttr = rewriter.getDenseI64ArrayAttr(broadcastDims);
   auto broadcast = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
       loc, outTy, tensor, stablehloShape, broadcastAttr);
@@ -549,8 +546,7 @@ public:
     // Prepare for transposed convolution
     SmallVector<int64_t> stablehloStrideVec(nSpatialDims, 1);
-    DenseIntElementsAttr stablehloStride =
-        rewriter.getI64TensorAttr(stablehloStrideVec);
+    auto stablehloStride = rewriter.getDenseI64ArrayAttr(stablehloStrideVec);
     SmallVector<int64_t> stablehloPaddingVec(nSpatialDims * 2, 0);
     for (int i = 0; i < nSpatialDims; ++i) {
       int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i];
@@ -563,15 +559,15 @@
           stablehloPaddingVec);
     SmallVector<int64_t> stablehloLhsDilationVec(nSpatialDims);
     std::copy(stride.begin(), stride.end(), stablehloLhsDilationVec.begin());
-    DenseIntElementsAttr stablehloLhsDilation =
-        rewriter.getI64TensorAttr(stablehloLhsDilationVec);
+    auto stablehloLhsDilation =
+        rewriter.getDenseI64ArrayAttr(stablehloLhsDilationVec);
     SmallVector<int64_t> stablehloRhsDilationVec(nSpatialDims);
     std::copy(dilation.begin(), dilation.end(),
              stablehloRhsDilationVec.begin());
-    DenseIntElementsAttr stablehloRhsDilation =
-        rewriter.getI64TensorAttr(stablehloRhsDilationVec);
+    auto stablehloRhsDilation =
+        rewriter.getDenseI64ArrayAttr(stablehloRhsDilationVec);
-    DenseElementsAttr windowReversal;
+    DenseBoolArrayAttr windowReversal;
     ArrayAttr precisionConfig;
     SmallVector<int64_t> spatialDims;
@@ -614,10 +610,7 @@ public:
     int64_t nDims = outType.getRank();
     // Get stablehlo::ConvolutionOp attributes
-    DenseIntElementsAttr stablehloWindowStride = DenseIntElementsAttr::get(
-        RankedTensorType::get({static_cast<int64_t>(stride.size())},
-                              rewriter.getI64Type()),
-        stride);
+    auto stablehloWindowStride = rewriter.getDenseI64ArrayAttr(stride);
     std::vector<int64_t> stablehloPaddingVec;
     for (size_t i = 0; i < padding.size(); i++) {
       stablehloPaddingVec.emplace_back(padding[i]);
     }
     DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get(
         RankedTensorType::get(
             {static_cast<int64_t>(padding.size()), static_cast<int64_t>(2)},
             rewriter.getI64Type()),
         stablehloPaddingVec);
-    DenseIntElementsAttr stablehloRhsDilation = DenseIntElementsAttr::get(
-        RankedTensorType::get({static_cast<int64_t>(dilation.size())},
-                              rewriter.getI64Type()),
-        dilation);
+    auto stablehloRhsDilation = rewriter.getDenseI64ArrayAttr(dilation);
     SmallVector<int64_t> spatialDimensions;
     for (int64_t i = 2; i < nDims; i++) {
       spatialDimensions.emplace_back(i);
     }
@@ -648,8 +638,8 @@ public:
         /*outputSpatialDimensions=*/spatialDimensions);
     // stablehlo::ConvolutionOp's optional attributes, leave them as default
-    DenseIntElementsAttr stablehloLhsDilation;
-    DenseElementsAttr windowReversal;
+    DenseI64ArrayAttr stablehloLhsDilation;
+    DenseBoolArrayAttr windowReversal;
     ArrayAttr precisionConfig;
     auto stablehloConvOp = rewriter.create<stablehlo::ConvolutionOp>(
@@ -781,7 +771,7 @@ public:
                                    options.dimSizeIndexBits);
     bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy);
-    DenseIntElementsAttr bcastDimensions;
+    DenseI64ArrayAttr bcastDimensions;
     rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
         op, outTy, stablehloConvResult, bias, bcastDimensions);
     return success();
diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp
index 7ef69ae67..40b0dd691 100644
--- a/lib/Conversion/TorchToStablehlo/Pooling.cpp
+++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp
@@ -136,19 +136,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   stablehloPadding[stablehloPadding.size() - 2] = padding[1];
   stablehloPadding[stablehloPadding.size() - 1] = padding[1];
-  DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
-                            rewriter.getI64Type()),
-      stablehloKernelSize);
-  DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
-                            rewriter.getI64Type()),
-      stablehloStride);
-  DenseIntElementsAttr baseDilations;
-  DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
-                            rewriter.getI64Type()),
-      stablehloDilation);
+  auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
+  auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
+  DenseI64ArrayAttr baseDilations;
+  auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
   DenseIntElementsAttr pad = DenseIntElementsAttr::get(
       RankedTensorType::get(
           {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
@@ -242,19 +233,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   stablehloPadding[stablehloPadding.size() - 2] = padding[1];
   stablehloPadding[stablehloPadding.size() - 1] = padding[1];
-  DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
-                            rewriter.getI64Type()),
-      stablehloKernelSize);
-  DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
-                            rewriter.getI64Type()),
-      stablehloStride);
-  DenseIntElementsAttr baseDilations;
-  DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
-                            rewriter.getI64Type()),
-      stablehloDilation);
+  auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
+  auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
+  DenseI64ArrayAttr baseDilations;
+  auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
   DenseIntElementsAttr pad = DenseIntElementsAttr::get(
       RankedTensorType::get(
           {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
@@ -453,20 +435,10 @@ public:
     Value initVal =
         createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
-    DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
-        RankedTensorType::get(
-            {static_cast<int64_t>(stablehloKernelSize.size())},
-            rewriter.getI64Type()),
-        stablehloKernelSize);
-    DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
-        RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
-                              rewriter.getI64Type()),
-        stablehloStride);
-    DenseIntElementsAttr baseDilations;
-    DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
-        RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
-                              rewriter.getI64Type()),
-        stablehloDilation);
+    auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
+    auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
+    DenseI64ArrayAttr baseDilations;
+    auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
     DenseIntElementsAttr pad = DenseIntElementsAttr::get(
         RankedTensorType::get(
             {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
@@ -508,7 +480,7 @@ public:
              .value();
     }
     divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
-    DenseIntElementsAttr bcastDimensions;
+    DenseI64ArrayAttr bcastDimensions;
     rewriter.replaceOpWithNewOp<chlo::BroadcastDivOp>(
         op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
     return success();
@@ -528,7 +500,7 @@ public:
     windowSizeConst = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
         op->getLoc(),
         RankedTensorType::get(inputTy.getShape(), outTy.getElementType()),
-        windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({}));
+        windowSizeConst, inputShapeTensor, rewriter.getDenseI64ArrayAttr({}));
     Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
     auto reduceWindowSize = rewriter.create<stablehlo::ReduceWindowOp>(
@@ -599,19 +571,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
   stablehloPadding[dim * 2] = inputShape[dim] - 1;
-  DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
-                            rewriter.getI64Type()),
-      stablehloKernelSize);
-  DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
-                            rewriter.getI64Type()),
-      stablehloStride);
-  DenseIntElementsAttr baseDilations;
-  DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
-                            rewriter.getI64Type()),
-      stablehloDilation);
+  auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
+  auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
+  DenseI64ArrayAttr baseDilations;
+  auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
   DenseIntElementsAttr pad = DenseIntElementsAttr::get(
       RankedTensorType::get(
           {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp
index f495aa395..e413fe532 100644
--- a/lib/Conversion/TorchToStablehlo/Reduction.cpp
+++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp
@@ -130,7 +130,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
           initValue,
           initIndex,
       },
-      rewriter.getI64TensorAttr(dim));
+      rewriter.getDenseI64ArrayAttr(dim));
   Block &block = stablehloReduceOp.getBody().emplaceBlock();
@@ -412,7 +412,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite(
   llvm::sort(dims.begin(), dims.end());
   auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
+      op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));
   Block &block = stablehloReduceOp.getBody().emplaceBlock();
   auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
@@ -473,7 +473,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite(
     return failure();
   llvm::sort(dims.begin(), dims.end());
   auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
+      op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));
   Block &block = stablehloReduceOp.getBody().emplaceBlock();
   auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
@@ -535,7 +535,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite(
     return failure();
   llvm::sort(dims.begin(), dims.end());
   auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
+      op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));
   Block &block = stablehloReduceOp.getBody().emplaceBlock();
   auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
@@ -625,7 +625,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite(
   llvm::sort(dims.begin(), dims.end());
   auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
+      op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));
   Region &region = stablehloReduceOp.getBody();
   Block &block = region.emplaceBlock();
@@ -729,7 +729,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite(
   auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
       op->getLoc(), squareOp.getResult(), initValue,
-      rewriter.getI64TensorAttr(dims));
+      rewriter.getDenseI64ArrayAttr(dims));
   Region &region = reduceOp.getBody();
   Block &block = region.emplaceBlock();
@@ -848,7 +848,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite(
                                            ord, nullptr);
   auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op->getLoc(), powValue, initValue, rewriter.getI64TensorAttr(dims));
+      op->getLoc(), powValue, initValue, rewriter.getDenseI64ArrayAttr(dims));
   Region &region = reduceOp.getBody();
   Block &block = region.emplaceBlock();
diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
index ed203cb0f..c3f8eff22 100644
--- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
@@ -241,10 +241,7 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
   if (!do_bcast) {
     return input;
   }
-  DenseIntElementsAttr bcast_attr = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(bcastDims.size())},
-                            rewriter.getI64Type()),
-      bcastDims);
+  auto bcast_attr = rewriter.getDenseI64ArrayAttr(bcastDims);
   auto bcast_op = rewriter.create<stablehlo::BroadcastInDimOp>(
       op->getLoc(), outType, input, bcast_attr);
   return bcast_op.getResult();
@@ -360,7 +357,7 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
   auto constTensor = rewriter.create<stablehlo::ConstantOp>(loc, constAttr);
   return rewriter
       .create<stablehlo::DynamicBroadcastInDimOp>(
-          loc, outType, constTensor, shape, rewriter.getI64TensorAttr({}))
+          loc, outType, constTensor, shape, rewriter.getDenseI64ArrayAttr({}))
       .getResult();
 }
 } // namespace hlo
diff --git a/test/Conversion/TorchToStablehlo/pooling.mlir b/test/Conversion/TorchToStablehlo/pooling.mlir
index fd531006d..b8fc6cbd8 100644
--- a/test/Conversion/TorchToStablehlo/pooling.mlir
+++ b/test/Conversion/TorchToStablehlo/pooling.mlir
@@ -18,7 +18,7 @@
 // CHECK:             ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
 // CHECK:               %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
 // CHECK:               stablehlo.return %[[VAL_10]] : tensor<f32>
-// CHECK:           }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+// CHECK:           }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
 // CHECK:           %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
 // CHECK:           return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32>
 func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
@@ -51,7 +51,7 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
 // CHECK:             %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
 // CHECK:             stablehlo.return %[[VAL_10]] : tensor<f32>
 // CHECK:           })
-// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
 // CHECK:           %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
 // CHECK:           return %[[VAL_7]] : !torch.vtensor<[?,?,?,?],f32>
 func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
@@ -105,7 +105,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
 // CHECK:             %[[T20:.*]] = stablehlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
 // CHECK:             %[[T21:.*]] = stablehlo.select %[[T18]], %[[T19]], %[[T20]] : tensor<i1>, tensor<i64>
 // CHECK:             stablehlo.return %[[T17]], %[[T21]] : tensor<f32>, tensor<i64>
-// CHECK:           }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 3, 3]> : tensor<3xi64>, window_strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
+// CHECK:           }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = array<i64: 1, 1, 1>, window_dimensions = array<i64: 1, 3, 3>, window_strides = array<i64: 1, 2, 2>} : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
 // CHECK:           %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
 // CHECK:           %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor<?x?x?xi64> -> !torch.vtensor<[?,?,?],si64>
 // CHECK:           return %[[T14]], %[[T15]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>
@@ -141,7 +141,7 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
 // CHECK:           ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>):
 // CHECK:             %[[IVAL_2:.*]] = stablehlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
 // CHECK:             stablehlo.return %[[IVAL_2]] : tensor<f32>
-// CHECK{LITERAL}:  }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+// CHECK{LITERAL}:  }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
 // CHECK:           %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
 // CHECK:           %[[IDX_0:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32>
@@ -162,7 +162,7 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
 // CHECK:           ^bb0(%[[IVAL_3:.*]]: tensor<f32>, %[[IVAL_4:.*]]: tensor<f32>):
 // CHECK:             %[[IVAL_5:.*]] = stablehlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32>
 // CHECK:             stablehlo.return %[[IVAL_5]] : tensor<f32>
-// CHECK{LITERAL}:  }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+// CHECK{LITERAL}:  }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
 // CHECK:           %[[VAL_20:.*]] = stablehlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32>
 // CHECK:           %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
 // CHECK:           return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32>
@@ -198,7 +198,7 @@ func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
 // CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
 // CHECK:             %[[T10:.*]] = stablehlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
 // CHECK:             stablehlo.return %[[T10]] : tensor<f32>
-// CHECK{LITERAL}:  }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+// CHECK{LITERAL}:  }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
 // CHECK:           %[[T6:.*]] = stablehlo.constant dense<9> : tensor<i64>
 // CHECK:           %[[T7:.*]] = stablehlo.convert %[[T6]] : (tensor<i64>) -> tensor<f32>
 // CHECK:           %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
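For readers tracking the mechanical change above: a minimal sketch
(illustrative only, not part of the diff) of the attribute migration this
patch applies throughout. The newer StableHLO takes i64 dimension lists as
`DenseI64ArrayAttr` (printed `array<i64: ...>`) instead of a 1-D i64 tensor
`DenseIntElementsAttr` (printed `dense<...> : tensor<Nxi64>`):

```
// Illustrative sketch (not part of the diff): old vs. new builders for the
// i64 dimension-list attributes migrated throughout this patch.
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

void buildWindowAttrs(mlir::Builder &b, llvm::ArrayRef<int64_t> dims) {
  // Old: wrap the values in a ranked 1-D i64 tensor attribute,
  // printed as `dense<[...]> : tensor<Nxi64>`.
  auto oldAttr = mlir::DenseIntElementsAttr::get(
      mlir::RankedTensorType::get({static_cast<int64_t>(dims.size())},
                                  b.getI64Type()),
      dims);
  // New: a plain array attribute, printed as `array<i64: ...>`.
  auto newAttr = b.getDenseI64ArrayAttr(dims);
  (void)oldAttr;
  (void)newAttr;
}
```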