From 3225f20ab19db12e532e51fe60a9ff78b48be880 Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu
Date: Sun, 7 Jul 2024 18:03:03 +0800
Subject: [PATCH] [Stablehlo] use index type as dim size, avoid generating index_cast (#3526)

For example, the original IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
  func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
    %0 = arith.index_cast %dim : index to i64
    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
    %1 = arith.index_cast %dim_0 : index to i64
    %dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
    %2 = arith.index_cast %dim_1 : index to i64
    %from_elements = tensor.from_elements %0, %1, %2 : tensor<3xi64>
    %3 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
    %4 = stablehlo.dot_general %arg0, %3, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    return %4 : tensor<?x?x?xf32>
  }
}
```

After using IndexType, the IR becomes:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
  func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
    %dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
    %from_elements = tensor.from_elements %dim, %dim_0, %dim_1 : tensor<3xindex>
    %0 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
    %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    return %1 : tensor<?x?x?xf32>
  }
}
```

The benefits of using IndexType for the shape tensor:
* it simplifies the IR and avoids generating `arith.index_cast`
* it lets the backend compiler decide the index width of the shape tensor
* it gives the StableHLO backend a chance to serialize dynamic shape IR via [shape_legalize_to_stablehlo](https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/shape_legalize_to_stablehlo.mlir)
---
 .../TorchToStablehlo/StablehloLegalizeUtils.h |  18 +++--
 lib/Conversion/TorchToStablehlo/Basic.cpp     |  17 ++---
 .../TorchToStablehlo/GatherScatter.cpp        |  23 ++----
 lib/Conversion/TorchToStablehlo/Linear.cpp    |  31 ++++----
 lib/Conversion/TorchToStablehlo/Pooling.cpp   |   8 +--
 lib/Conversion/TorchToStablehlo/Reduction.cpp |  66 +++++++----------
 .../StablehloLegalizeUtils.cpp                |  67 ++++++++++++-----
 lib/Conversion/TorchToStablehlo/ViewLike.cpp  |  18 +++--
 test/Conversion/TorchToStablehlo/linear.mlir  |  72 +++++++------------
 test/Conversion/TorchToStablehlo/pooling.mlir |  21 ++----
 test/Conversion/TorchToStablehlo/scatter.mlir |  20 +++---
 .../TorchToStablehlo/view_like.mlir           |  53 ++++----------
 12 files changed, 176 insertions(+), 238 deletions(-)

diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h index 3abe16fbf..6efa11f8b 100644 --- a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h @@ -68,21 +68,29 @@ FailureOr> getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value, size_t dimSizeIndexBits); +// Get the dimension sizes of the input tensor, given the dimension axes +FailureOr> getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value, ArrayRef 
inpDims); + +// Get the dimension sizes of the input tensor +FailureOr> +getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value); + // Get a tensor that unsqueezed the specified dimensions of the input tensor FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, - Value tensor, ArrayRef inputUnsqzDims, - size_t dimSizeIndexBits); + Value tensor, + ArrayRef inputUnsqzDims); // Get a tensor that collapse the specified dimensions of the input tensor FailureOr collapseTensor(PatternRewriter &rewriter, Operation *op, Value tensor, int64_t collapseStartDim, - int64_t collapseEndDim, - size_t dimSizeIndexBits); + int64_t collapseEndDim); // Get a tensor that splits the specified dimensions of the input tensor FailureOr splitTensor(PatternRewriter &rewriter, Operation *op, Value tensor, int64_t splitDim, - int64_t outerLength, size_t dimSizeIndexBits); + int64_t outerLength); Value getConstantOfShape(PatternRewriter &rewriter, Location loc, const APFloat &constant, Value shape, diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 644d28cc0..db7c26565 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -35,8 +35,7 @@ using namespace mlir::torch::Torch; using namespace mlir::torch::torch_to_stablehlo; LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op, - mlir::Value &self, mlir::Value &other, - size_t dimSizeIndexBits) { + mlir::Value &self, mlir::Value &other) { auto selfTy = dyn_cast(self.getType()); auto otherTy = dyn_cast(other.getType()); auto selfRank = selfTy.getRank(); @@ -46,16 +45,16 @@ LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op, if (selfRank > otherRank) { auto unsqueezeDims = llvm::to_vector<4>(llvm::seq(0, selfRank - otherRank)); - auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, other, - unsqueezeDims, dimSizeIndexBits); + auto unsqueezeInfo = + hlo::unsqueezeTensor(rewriter, op, other, unsqueezeDims); if (failed(unsqueezeInfo)) return failure(); other = *unsqueezeInfo; } else if (otherRank > selfRank) { auto unsqueezeDims = llvm::to_vector<4>(llvm::seq(0, otherRank - selfRank)); - auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, self, unsqueezeDims, - dimSizeIndexBits); + auto unsqueezeInfo = + hlo::unsqueezeTensor(rewriter, op, self, unsqueezeDims); if (failed(unsqueezeInfo)) return failure(); self = *unsqueezeInfo; @@ -740,12 +739,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( self = hlo::promoteType(rewriter, op.getLoc(), self, outType); other = hlo::promoteType(rewriter, op.getLoc(), other, outType); - if (failed( - broadcastRanks(rewriter, op, self, cond, options.dimSizeIndexBits))) + if (failed(broadcastRanks(rewriter, op, self, cond))) return op.emitError("failed broadcast self and condition ranks"); - if (failed( - broadcastRanks(rewriter, op, other, cond, options.dimSizeIndexBits))) + if (failed(broadcastRanks(rewriter, op, other, cond))) return op.emitError("failed broadcast other and condition ranks"); rewriter.replaceOpWithNewOp( diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index 05c52483c..bba8b7438 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -438,16 +438,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.create(op->getLoc(), addResult); } - auto outShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, weight, 
options.dimSizeIndexBits); + auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, weight); if (failed(outShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto outShapeVec = *outShapeInfo; auto one = rewriter.create( - op->getLoc(), rewriter.getIntegerAttr( - rewriter.getIntegerType(options.dimSizeIndexBits), 1)); + op->getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); outShapeVec[0] = one; auto outShapeTensor = rewriter.create(op->getLoc(), outShapeVec); @@ -537,16 +535,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "only constant boolean `sparse_grad` param supported"); } - auto options = getOptions(); - auto indexShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits); + auto indexShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, index); if (failed(indexShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dim sizes of `index` param"); } - auto intType = rewriter.getIntegerType(options.dimSizeIndexBits); auto one = rewriter.create( - loc, rewriter.getIntegerAttr(intType, 1)); + loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); auto toConcatIndexShapeValueVec = *indexShapeInfo; toConcatIndexShapeValueVec.push_back(one); auto toConcatIndexShape = @@ -672,24 +667,20 @@ public: return rewriter.notifyMatchFailure(op, "invalid `dim` param detected"); } - auto options = this->getOptions(); - - auto indexShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits); + auto indexShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, index); if (failed(indexShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dim sizes of `index` param"); } - auto intType = rewriter.getIntegerType(options.dimSizeIndexBits); // slice src tensor to have the same shape bound of index tensor in the // leading dimensions. PyTorch has guaranteed that src tensor size will not // be smaller than that of index tensor. 
REF: // https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_ auto zero = rewriter.create( - loc, rewriter.getIntegerAttr(intType, 0)); + loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); auto one = rewriter.create( - loc, rewriter.getIntegerAttr(intType, 1)); + loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); SmallVector sliceIndicies(srcType.getRank(), zero); SmallVector sliceStrides(srcType.getRank(), one); diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index e2c2f9a66..6237db281 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -148,10 +148,9 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, std::vector newShape(rhsShape.begin(), rhsShape.begin() + leadingRank); newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end()); - auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims, - dimSizeIndexBits); - auto lhsDimSizes = - *hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits); + auto newDimSizes = + *hlo::getDimIndexOfTensor(rewriter, op, rhs, leadingDims); + auto lhsDimSizes = *hlo::getDimIndexOfTensor(rewriter, op, lhs); newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(), lhsDimSizes.end()); lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes, @@ -160,10 +159,9 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, std::vector newShape(lhsShape.begin(), lhsShape.begin() + leadingRank); newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end()); - auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims, - dimSizeIndexBits); - auto rhsDimSizes = - *hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits); + auto newDimSizes = + *hlo::getDimIndexOfTensor(rewriter, op, lhs, leadingDims); + auto rhsDimSizes = *hlo::getDimIndexOfTensor(rewriter, op, rhs); newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(), rhsDimSizes.end()); rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes, @@ -207,10 +205,8 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, return; } - auto lhsDimSizes = - *hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits); - auto rhsDimSizes = - *hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits); + auto lhsDimSizes = *hlo::getDimIndexOfTensor(rewriter, op, lhs); + auto rhsDimSizes = *hlo::getDimIndexOfTensor(rewriter, op, rhs); if (!lhsBroadcastDims.empty()) { SmallVector lhsNewShape(newBatchShape); @@ -526,16 +522,15 @@ public: auto weightTy = cast(weight.getType()); auto weightElemTy = weightTy.getElementType(); auto rank = weightTy.getRank(); - const auto &options = getOptions(); - SmallVector weightShapeVec = *hlo::getDimSizesOfTensor( - rewriter, op, weight, options.dimSizeIndexBits); + SmallVector weightShapeVec = + *hlo::getDimIndexOfTensor(rewriter, op, weight); auto weightShape = weightTy.getShape(); SmallVector weightShapeInt(rank); std::copy(weightShape.begin(), weightShape.end(), weightShapeInt.begin()); // 1. 
[H, W, ..., OC, IC] => [H, W, ..., OC, G, IC//G] Value GValue = rewriter.create( - op->getLoc(), rewriter.getI64IntegerAttr(groups)); + op->getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), groups)); Value ICDivGValue = rewriter.create( op->getLoc(), weightShapeVec[rank - 1], GValue); Value OCMulGValue = rewriter.create( @@ -839,9 +834,7 @@ public: auto inputUnsqzDims = llvm::to_vector<4>(llvm::seq(-nSpatialDims, 0)); - const auto &options = getOptions(); - bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims, - options.dimSizeIndexBits); + bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims); bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy); DenseI64ArrayAttr bcastDimensions; diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index a52d4e719..4b6d677a5 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -146,9 +146,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.getI64Type()), stablehloPadding); - const auto &options = getOptions(); - auto inputShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + auto inputShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -536,9 +534,7 @@ public: hlo::getConstTensor(rewriter, op, {1.0}, {}).value(); windowSizeConst = hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy); - const auto &options = ConvertAtenOp::getOptions(); - auto inputShapeVec = *hlo::getDimSizesOfTensor(rewriter, op, input, - options.dimSizeIndexBits); + auto inputShapeVec = *hlo::getDimIndexOfTensor(rewriter, op, input); auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index bc77a860a..f2e8086de 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -310,12 +310,10 @@ static Value reshapeReduceResultWhenKeepDim(ConversionPatternRewriter &rewriter, Location loc, Value reduceResult, ArrayRef inputShapeVec, Type outType, - ArrayRef dims, - size_t dimSizeIndexBits) { + ArrayRef dims) { SmallVector outShapeVec(inputShapeVec); Value one = rewriter.create( - loc, - rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1)); + loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); for (auto dim : dims) { outShapeVec[dim] = one; } @@ -432,16 +430,13 @@ public: } if (keepDim) { - const auto &options = ConvertAtenReductionOp::getOptions(); - auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input, - options.dimSizeIndexBits); + auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input); if (failed(outShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } reduceResult = reshapeReduceResultWhenKeepDim( - rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, {dim}, - options.dimSizeIndexBits); + rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, {dim}); } rewriter.replaceOp(op, reduceResult); return success(); @@ -512,16 +507,13 @@ public: } if (keepDim) { - const auto &options = ConvertAtenReductionOp::getOptions(); - auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input, - options.dimSizeIndexBits); + auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input); if 
(failed(outShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } reduceResult = reshapeReduceResultWhenKeepDim( - rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims, - options.dimSizeIndexBits); + rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims); } rewriter.replaceOp(op, reduceResult); return success(); @@ -573,8 +565,7 @@ public: } const auto &options = ConvertAtenReductionOp::getOptions(); - auto inputShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + auto inputShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -592,9 +583,9 @@ public: } if (keepDim) { - reduceResult = reshapeReduceResultWhenKeepDim( - rewriter, op->getLoc(), reduceResult, inputShapeVec, valResultType, - {dim}, options.dimSizeIndexBits); + reduceResult = + reshapeReduceResultWhenKeepDim(rewriter, op->getLoc(), reduceResult, + inputShapeVec, valResultType, {dim}); } rewriter.replaceOp(op, {reduceResult, Value()}); return success(); @@ -603,16 +594,16 @@ public: createReduceOpReturnIndices(rewriter, op, input, inputShapeVec, dim, options.dimSizeIndexBits) .value(); + SmallVector reduceResults(stablehloReduceResults); if (keepDim) { - stablehloReduceResults[0] = reshapeReduceResultWhenKeepDim( - rewriter, op->getLoc(), stablehloReduceResults[0], inputShapeVec, - valResultType, {dim}, options.dimSizeIndexBits); - stablehloReduceResults[1] = reshapeReduceResultWhenKeepDim( - rewriter, op->getLoc(), stablehloReduceResults[1], inputShapeVec, - idxResultType, {dim}, options.dimSizeIndexBits); + reduceResults[0] = reshapeReduceResultWhenKeepDim( + rewriter, op->getLoc(), reduceResults[0], inputShapeVec, + valResultType, {dim}); + reduceResults[1] = reshapeReduceResultWhenKeepDim( + rewriter, op->getLoc(), reduceResults[1], inputShapeVec, + idxResultType, {dim}); } - rewriter.replaceOp( - op, {stablehloReduceResults[0], stablehloReduceResults[1]}); + rewriter.replaceOp(op, reduceResults); return success(); } }; @@ -685,16 +676,13 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } if (keepDim) { - const auto &options = getOptions(); - auto outShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input); if (failed(outShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } reduceResult = reshapeReduceResultWhenKeepDim( - rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims, - options.dimSizeIndexBits); + rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims); } rewriter.replaceOp(op, reduceResult); return success(); @@ -709,7 +697,6 @@ template <> LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenFrobeniusNormDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - const TorchToStablehloOptions &options = getOptions(); Value input = adaptor.getSelf(); auto inputType = dyn_cast(input.getType()); @@ -761,16 +748,14 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( Value output = rewriter.create(op->getLoc(), reduceResult); if (keepDim) { - auto outShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input); if (failed(outShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get 
dimension sizes of the input"); } output = reshapeReduceResultWhenKeepDim( rewriter, op->getLoc(), output, *outShapeInfo, - getTypeConverter()->convertType(op.getType()), dims, - options.dimSizeIndexBits); + getTypeConverter()->convertType(op.getType()), dims); } rewriter.replaceOp(op, output); return success(); @@ -783,7 +768,6 @@ template <> LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenLinalgVectorNormOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - const TorchToStablehloOptions &options = getOptions(); Value input = adaptor.getSelf(); auto inputType = dyn_cast(input.getType()); @@ -861,15 +845,13 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( op->getLoc(), reduceResult, reciprocalOrd, nullptr); if (keepDim) { - auto outShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input); if (failed(outShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } output = reshapeReduceResultWhenKeepDim(rewriter, op->getLoc(), output, - *outShapeInfo, outType, dims, - options.dimSizeIndexBits); + *outShapeInfo, outType, dims); } rewriter.replaceOp(op, output); return success(); diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp index 179d55194..113e94be5 100644 --- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp @@ -279,9 +279,47 @@ FailureOr> getDimSizesOfTensor(PatternRewriter &rewriter, return getDimSizesOfTensor(rewriter, op, value, dims, dimSizeIndexBits); } +// Get the dimension sizes of the input tensor, given the dimension axes +FailureOr> +getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value, + ArrayRef inpDims) { + auto valueTy = dyn_cast(value.getType()); + if (!valueTy) { + return rewriter.notifyMatchFailure( + op, "getDimIndexOfTensor(): the input is not a ranked tensor"); + } + + auto rank = valueTy.getRank(); + auto dims = toPositiveDims(inpDims, rank); + SmallVector dimSizes; + dimSizes.reserve(dims.size()); + + auto loc = op->getLoc(); + for (auto d : dims) { + dimSizes.emplace_back(rewriter.create(loc, value, d)); + } + return dimSizes; +} + +// Get the dimension sizes of the input tensor +FailureOr> +getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value) { + auto valueTy = dyn_cast(value.getType()); + if (!valueTy) { + return rewriter.notifyMatchFailure( + op, "getDimIndexOfTensor(): the input is not a ranked tensor"); + } + + auto rank = valueTy.getRank(); + // Get int vector [0, 1, ..., rank-1] + std::vector dims(rank); + std::iota(dims.begin(), dims.end(), 0); + return getDimIndexOfTensor(rewriter, op, value, dims); +} + FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, - Value tensor, ArrayRef inputUnsqzDims, - size_t dimSizeIndexBits) { + Value tensor, + ArrayRef inputUnsqzDims) { // Returns a new tensor with dims of size 1 inserted at the specified // position. // @@ -289,8 +327,7 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, // tensor) are specified with unsqzDims. Indices must be in-order, and in // range of tensor rank. Thus, unsqueeze a rank 1 tensor with {0, 2}, {0, 1, // 3}, {0, 1, 2} are all valid dimension sets, but {0, 3}, {2} are not. 
- auto dimSizesInfo = - getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits); + auto dimSizesInfo = getDimIndexOfTensor(rewriter, op, tensor); if (failed(dimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -307,9 +344,8 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, auto loc = op->getLoc(); auto rankTy = dyn_cast(tensor.getType()); auto oldShape = rankTy.getShape(); - Type intType = rewriter.getIntegerType(dimSizeIndexBits); auto one = rewriter.create( - loc, rewriter.getIntegerAttr(intType, 1)); + loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); std::vector newDimSizes; std::vector newShape; @@ -335,12 +371,9 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, FailureOr collapseTensor(PatternRewriter &rewriter, Operation *op, Value tensor, int64_t collapseStartDim, - int64_t collapseEndDim, - size_t dimSizeIndexBits) { - - auto dimSizesInfo = - getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits); + int64_t collapseEndDim) { + auto dimSizesInfo = getDimIndexOfTensor(rewriter, op, tensor); if (failed(dimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -356,7 +389,6 @@ FailureOr collapseTensor(PatternRewriter &rewriter, Operation *op, auto loc = op->getLoc(); auto rankTy = dyn_cast(tensor.getType()); auto oldShape = rankTy.getShape(); - Type intType = rewriter.getIntegerType(dimSizeIndexBits); std::vector newDimSizes; std::vector newShape; @@ -364,7 +396,7 @@ FailureOr collapseTensor(PatternRewriter &rewriter, Operation *op, newShape.reserve(newRank); Value collapseDimSize = rewriter.create( - loc, rewriter.getIntegerAttr(intType, 1)); + loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); int64_t collapseShape = 1; for (int64_t k = collapseStartDim; k <= collapseEndDim; ++k) { @@ -402,10 +434,8 @@ FailureOr collapseTensor(PatternRewriter &rewriter, Operation *op, // TODO: support splitDim & outerLength to be Value FailureOr splitTensor(PatternRewriter &rewriter, Operation *op, Value tensor, int64_t splitDim, - int64_t outerLength, size_t dimSizeIndexBits) { - auto dimSizesInfo = - getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits); - + int64_t outerLength) { + auto dimSizesInfo = getDimIndexOfTensor(rewriter, op, tensor); if (failed(dimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -417,7 +447,6 @@ FailureOr splitTensor(PatternRewriter &rewriter, Operation *op, auto loc = op->getLoc(); auto rankTy = dyn_cast(tensor.getType()); auto oldShape = rankTy.getShape(); - Type intType = rewriter.getIntegerType(dimSizeIndexBits); if (splitDim < 0 || splitDim >= rank) { return rewriter.notifyMatchFailure( @@ -426,7 +455,7 @@ FailureOr splitTensor(PatternRewriter &rewriter, Operation *op, int64_t newRank = rank + 1; auto outerLengthValue = rewriter.create( - loc, rewriter.getIntegerAttr(intType, outerLength)); + loc, rewriter.getIntegerAttr(rewriter.getIndexType(), outerLength)); auto innerLengthValue = rewriter.create( loc, dimSizes[splitDim], outerLengthValue); diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp index 46d58b8b5..541c02a07 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -323,8 +323,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } - auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims, - 
options.dimSizeIndexBits); + auto newDimSizesInfo = hlo::getDimIndexOfTensor(rewriter, op, self, dims); if (failed(newDimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -375,8 +374,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, getTypeConverter()->convertType(op.getType()), self); return success(); } - auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims, - options.dimSizeIndexBits); + auto newDimSizesInfo = hlo::getDimIndexOfTensor(rewriter, op, self, dims); if (failed(newDimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -406,8 +404,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!isValidDim(dim, inputRank + 1)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - auto unsqzTensorInfo = hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(), - {dim}, options.dimSizeIndexBits); + auto unsqzTensorInfo = + hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(), {dim}); if (failed(unsqzTensorInfo)) return rewriter.notifyMatchFailure(op, "failed to create unsqueezed tensor"); @@ -438,8 +436,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "only constant end is currently supported"); - auto collapseTensorInfo = hlo::collapseTensor( - rewriter, op, adaptor.getA(), start, end, options.dimSizeIndexBits); + auto collapseTensorInfo = + hlo::collapseTensor(rewriter, op, adaptor.getA(), start, end); if (failed(collapseTensorInfo)) return rewriter.notifyMatchFailure(op, "failed to create collapsed tensor"); @@ -469,8 +467,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "only constant outerLength is currently supported"); - auto splitTensorInfo = hlo::splitTensor( - rewriter, op, adaptor.getA(), dim, outerLength, options.dimSizeIndexBits); + auto splitTensorInfo = + hlo::splitTensor(rewriter, op, adaptor.getA(), dim, outerLength); if (failed(splitTensorInfo)) return rewriter.notifyMatchFailure(op, "failed to create split tensor"); diff --git a/test/Conversion/TorchToStablehlo/linear.mlir b/test/Conversion/TorchToStablehlo/linear.mlir index db61dc262..ec6bfee22 100644 --- a/test/Conversion/TorchToStablehlo/linear.mlir +++ b/test/Conversion/TorchToStablehlo/linear.mlir @@ -36,15 +36,12 @@ func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1: // CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[10,4,5],f32> -> tensor<10x4x5xf32> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<10x4x5xf32> -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<10x4x5xf32> -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<10x4x5xf32> -// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 -// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32> +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex> +// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<10x4x5xf32>, 
tensor<3xindex>) -> tensor<10x4x5xf32> // CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<10x3x5xf32> to tensor<10x3x5xf32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32> @@ -62,15 +59,12 @@ func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg // CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,4,?],f32> -> tensor // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor -// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 -// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex> +// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor // CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor, tensor) -> tensor // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor to tensor // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor -> !torch.vtensor<[?,?,?],f32> @@ -88,15 +82,12 @@ func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg // CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,120,256],f32> -> tensor<4x120x256xf32> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<4x120x256xf32> -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[C0_0:.*]] = arith.constant 0 : index // CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256x120xf32> -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<256x120xf32> -// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 -// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T8]], dims = [1, 2] : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32> +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex> +// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T8]], dims = [1, 2] : (tensor<256x120xf32>, tensor<3xindex>) -> tensor<4x256x120xf32> // CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T9]], %[[T1]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x256x256xf32> to tensor<4x256x256xf32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x256x256xf32> -> 
!torch.vtensor<[4,256,256],f32> @@ -114,15 +105,12 @@ func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>, // CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x?x256xf32> -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[C0_0:.*]] = arith.constant 0 : index // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256x?xf32> -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x?xf32> -// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 -// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32> +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex> +// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x?xf32>, tensor<3xindex>) -> tensor<4x256x?xf32> // CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x?x?xf32> to tensor<4x?x?xf32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32> @@ -140,12 +128,10 @@ func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>, // CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<1x?x256xf32> -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[C0_0:.*]] = arith.constant 0 : index // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256xf32> -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 -// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64> -// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: %[[T6:.*]] = tensor.from_elements %[[T2]], %[[T4]] : tensor<2xindex> +// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xindex>) -> tensor<1x256xf32> // CHECK: %[[T8:.*]] = stablehlo.dot_general %[[T0]], %[[T7]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32> // CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<1x?xf32> to tensor<1x?xf32> // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32> @@ -163,12 +149,10 @@ func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1: // CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,256,?],f32> -> tensor // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[C0_0:.*]] = arith.constant 0 : index // CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : 
tensor<256xf32> -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 -// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64> -// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor +// CHECK: %[[T6:.*]] = tensor.from_elements %[[T2]], %[[T4]] : tensor<2xindex> +// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xindex>) -> tensor // CHECK: %[[T8:.*]] = stablehlo.dot_general %[[T7]], %[[T1]], batching_dims = [0] x [0], contracting_dims = [1] x [1] : (tensor, tensor) -> tensor // CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor to tensor // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor -> !torch.vtensor<[?,?],f32> @@ -231,15 +215,12 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor // CHECK: %[[T1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<256x256xf32> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[C0_0:.*]] = arith.constant 0 : index // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256x256xf32> -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x256xf32> -// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 -// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x256xf32>, tensor<3xi64>) -> tensor +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex> +// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x256xf32>, tensor<3xindex>) -> tensor // CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor, tensor) -> tensor // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor to tensor // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor -> !torch.vtensor<[?,?,256],f32> @@ -324,10 +305,9 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: ! 
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor, tensor) -> tensor // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index // CHECK: %[[T_9:.*]] = tensor.dim %[[T_2]], %[[IDX_0]] : tensor -// CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64 -// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64 -// CHECK: %[[T_11:.*]] = tensor.from_elements %[[T_10]], %[[VAL_0]], %[[VAL_0]] : tensor<3xi64> -// CHECK: %[[T_12:.*]] = stablehlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index +// CHECK: %[[T_11:.*]] = tensor.from_elements %[[T_9]], %[[VAL_0]], %[[VAL_0]] : tensor<3xindex> +// CHECK: %[[T_12:.*]] = stablehlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor, tensor<3xindex>) -> tensor // CHECK: %[[T_13:.*]] = chlo.broadcast_add %[[T_8]], %[[T_12]] : (tensor, tensor) -> tensor // CHECK: %[[T_14:.*]] = torch_c.from_builtin_tensor %[[T_13]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T_14]] : !torch.vtensor<[?,?,?,?],f32> @@ -466,24 +446,20 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor // CHECK: %[[T_7:.*]] = stablehlo.reverse %[[T_6]], dims = [0, 1] : tensor<3x3x2x2xf32> // CHECK: %c0 = arith.constant 0 : index // CHECK: %dim = tensor.dim %[[T_7]], %c0 : tensor<3x3x2x2xf32> -// CHECK: %[[T_8:.*]] = arith.index_cast %dim : index to i64 // CHECK: %c1 = arith.constant 1 : index // CHECK: %dim_0 = tensor.dim %[[T_7]], %c1 : tensor<3x3x2x2xf32> -// CHECK: %[[T_9:.*]] = arith.index_cast %dim_0 : index to i64 // CHECK: %c2 = arith.constant 2 : index // CHECK: %dim_1 = tensor.dim %[[T_7]], %c2 : tensor<3x3x2x2xf32> -// CHECK: %[[T_10:.*]] = arith.index_cast %dim_1 : index to i64 // CHECK: %c3 = arith.constant 3 : index // CHECK: %dim_2 = tensor.dim %[[T_7]], %c3 : tensor<3x3x2x2xf32> -// CHECK: %[[T_11:.*]] = arith.index_cast %dim_2 : index to i64 -// CHECK: %[[C2:.*]] = arith.constant 2 : i64 -// CHECK: %[[T_12:.*]] = arith.divsi %[[T_11]], %[[C2]] : i64 -// CHECK: %[[T_13:.*]] = arith.muli %[[T_10]], %[[C2]] : i64 -// CHECK: %from_elements = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_10]], %[[C2]], %[[T_12]] : tensor<5xi64> -// CHECK: %[[T_14:.*]] = stablehlo.dynamic_reshape %[[T_7]], %from_elements : (tensor<3x3x2x2xf32>, tensor<5xi64>) -> tensor<3x3x2x2x1xf32> +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T_12:.*]] = arith.divsi %dim_2, %[[C2]] : index +// CHECK: %[[T_13:.*]] = arith.muli %dim_1, %[[C2]] : index +// CHECK: %from_elements = tensor.from_elements %dim, %dim_0, %dim_1, %[[C2]], %[[T_12]] : tensor<5xindex> +// CHECK: %[[T_14:.*]] = stablehlo.dynamic_reshape %[[T_7]], %from_elements : (tensor<3x3x2x2xf32>, tensor<5xindex>) -> tensor<3x3x2x2x1xf32> // CHECK: %[[T_15:.*]] = stablehlo.transpose %[[T_14]], dims = [0, 1, 3, 2, 4] : (tensor<3x3x2x2x1xf32>) -> tensor<3x3x2x2x1xf32> -// CHECK: %[[from_elements_3:.*]] = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_13]], %[[T_12]] : tensor<4xi64> -// CHECK: %[[T_16:.*]] = stablehlo.dynamic_reshape %[[T_15]], %[[from_elements_3]] : (tensor<3x3x2x2x1xf32>, tensor<4xi64>) -> tensor<3x3x4x1xf32> +// CHECK: %[[from_elements_3:.*]] = tensor.from_elements %dim, %dim_0, %[[T_13]], %[[T_12]] : tensor<4xindex> +// CHECK: %[[T_16:.*]] = stablehlo.dynamic_reshape %[[T_15]], %[[from_elements_3]] : (tensor<3x3x2x2x1xf32>, tensor<4xindex>) -> 
tensor<3x3x4x1xf32> // CHECK: %[[T_17:.*]] = stablehlo.convolution(%[[T_0]], %[[T_16]]) // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<3x3x4x1xf32>) -> tensor<1x4x15x15xf32> // CHECK: %[[T_18:.*]] = torch_c.from_builtin_tensor %[[T_17]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> diff --git a/test/Conversion/TorchToStablehlo/pooling.mlir b/test/Conversion/TorchToStablehlo/pooling.mlir index 156c3ff51..537ed9ca5 100644 --- a/test/Conversion/TorchToStablehlo/pooling.mlir +++ b/test/Conversion/TorchToStablehlo/pooling.mlir @@ -83,18 +83,15 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) - // CHECK: %[[T5:.*]] = stablehlo.constant dense<0xFF800000> : tensor // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T6:.*]] = arith.index_cast %[[DIM]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor -// CHECK: %[[T7:.*]] = arith.index_cast %[[DIM_0]] : index to i64 // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T8:.*]] = arith.index_cast %[[DIM_1]] : index to i64 -// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T6]], %[[T7]], %[[T8]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T7]] : i64 -// CHECK: %[[FROM_ELEMENTS_2:.*]] = tensor.from_elements %[[T6]], %[[T9]] : tensor<2xi64> -// CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xi64>) -> tensor -// CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]] : tensor<3xindex> +// CHECK: %[[T9:.*]] = arith.muli %[[DIM_1]], %[[DIM_0]] : index +// CHECK: %[[FROM_ELEMENTS_2:.*]] = tensor.from_elements %[[DIM]], %[[T9]] : tensor<2xindex> +// CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xindex>) -> tensor +// CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor, tensor<3xindex>) -> tensor // CHECK: %[[T12:.*]] = stablehlo.constant dense<0> : tensor // CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) <{padding = dense<0> : tensor<3x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): @@ -146,18 +143,14 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32> // CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index // CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor -// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : index to i64 // CHECK: %[[IDX_1:.*]] = arith.constant 1 : index // CHECK: %[[VAL_10:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor -// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_10]] : index to i64 // CHECK: %[[IDX_2:.*]] = arith.constant 2 : index // CHECK: %[[VAL_12:.*]] = tensor.dim %[[VAL_1]], %[[IDX_2]] : tensor -// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : index to i64 // CHECK: 
%[[IDX_3:.*]] = arith.constant 3 : index // CHECK: %[[VAL_14:.*]] = tensor.dim %[[VAL_1]], %[[IDX_3]] : tensor -// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_14]] : index to i64 -// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64> -// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_8]], %[[VAL_10]], %[[VAL_12]], %[[VAL_14]] : tensor<4xindex> +// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor, tensor<4xindex>) -> tensor // CHECK: %[[VAL_18:.*]] = stablehlo.constant dense<0.000000e+00> : tensor // CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) // CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ diff --git a/test/Conversion/TorchToStablehlo/scatter.mlir b/test/Conversion/TorchToStablehlo/scatter.mlir index 20188ca85..937c14a69 100644 --- a/test/Conversion/TorchToStablehlo/scatter.mlir +++ b/test/Conversion/TorchToStablehlo/scatter.mlir @@ -8,19 +8,17 @@ // CHECK: %int0 = torch.constant.int 0 // CHECK: %[[INDEX_0:.*]] = arith.constant 0 : index // CHECK: %[[DIM_0:.*]] = tensor.dim %[[VAR_1]], %[[INDEX_0]] : tensor -// CHECK: %[[VAR_3:.*]] = arith.index_cast %[[DIM_0]] : index to i64 // CHECK: %[[INDEX_1:.*]] = arith.constant 1 : index // CHECK: %[[DIM_1:.*]] = tensor.dim %1, %[[INDEX_1]] : tensor -// CHECK: %[[VAR_4:.*]] = arith.index_cast %[[DIM_1]] : index to i64 -// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : i64 -// CHECK: %[[CONSTANT_1:.*]] = arith.constant 1 : i64 -// CHECK: %[[FE_:.*]] = tensor.from_elements %[[CONSTANT_0]], %[[CONSTANT_0]] : tensor<2xi64> -// CHECK: %[[FE_1:.*]] = tensor.from_elements %[[CONSTANT_1]], %[[CONSTANT_1]] : tensor<2xi64> -// CHECK: %[[FE_2:.*]] = tensor.from_elements %[[VAR_3]], %[[VAR_4]] : tensor<2xi64> -// CHECK: %[[VAR_5:.*]] = stablehlo.real_dynamic_slice %[[VAR_2]], %[[FE_]], %[[FE_2]], %[[FE_1]] : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor -// CHECK: %[[FE_3:.*]] = tensor.from_elements %[[VAR_3]], %[[VAR_4]], %[[CONSTANT_1]] : tensor<3xi64> -// CHECK: %[[VAR_6:.*]] = stablehlo.dynamic_reshape %1, %[[FE_3]] : (tensor, tensor<3xi64>) -> tensor -// CHECK: %[[VAR_7:.*]] = stablehlo.dynamic_iota %[[FE_3]], dim = 1 : (tensor<3xi64>) -> tensor +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : index +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 1 : index +// CHECK: %[[FE_:.*]] = tensor.from_elements %[[CONSTANT_0]], %[[CONSTANT_0]] : tensor<2xindex> +// CHECK: %[[FE_1:.*]] = tensor.from_elements %[[CONSTANT_1]], %[[CONSTANT_1]] : tensor<2xindex> +// CHECK: %[[FE_2:.*]] = tensor.from_elements %[[DIM_0]], %[[DIM_1]] : tensor<2xindex> +// CHECK: %[[VAR_5:.*]] = stablehlo.real_dynamic_slice %[[VAR_2]], %[[FE_]], %[[FE_2]], %[[FE_1]] : (tensor, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor +// CHECK: %[[FE_3:.*]] = tensor.from_elements %[[DIM_0]], %[[DIM_1]], %[[CONSTANT_1]] : tensor<3xindex> +// CHECK: %[[VAR_6:.*]] = stablehlo.dynamic_reshape %1, %[[FE_3]] : (tensor, tensor<3xindex>) -> tensor +// CHECK: %[[VAR_7:.*]] = stablehlo.dynamic_iota %[[FE_3]], dim = 1 : (tensor<3xindex>) -> tensor // CHECK: %[[VAR_8:.*]] = stablehlo.concatenate %[[VAR_6]], %[[VAR_7]], dim = 2 : (tensor, tensor) -> tensor // CHECK: 
%[[VAR_9:.*]] = "stablehlo.scatter"(%[[VAR_0]], %[[VAR_8]], %[[VAR_5]]) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ // CHECK: ^bb0(%arg3: tensor, %[[ARG_4:.*]]: tensor): diff --git a/test/Conversion/TorchToStablehlo/view_like.mlir b/test/Conversion/TorchToStablehlo/view_like.mlir index f956c13cf..2de800804 100644 --- a/test/Conversion/TorchToStablehlo/view_like.mlir +++ b/test/Conversion/TorchToStablehlo/view_like.mlir @@ -398,18 +398,14 @@ func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32 // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64 // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64 // CHECK: %[[C3:.*]] = arith.constant 3 : index // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor -// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64 // CHECK: %[[C4:.*]] = arith.constant 4 : index // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor -// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 -// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64> -// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<4xindex> +// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<4xindex>) -> tensor // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,?,1,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,?,1,?],f32> func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> { @@ -426,18 +422,14 @@ func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> ! 
// CHECK: %[[INT:.*]]-2 = torch.constant.int -2 // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor -// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64 // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64 // CHECK: %[[C4:.*]] = arith.constant 4 : index // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor -// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 -// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64> -// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<4xindex> +// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<4xindex>) -> tensor // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,1,?,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?],f32> func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> { @@ -453,15 +445,12 @@ func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<2x1x2x1x2xf32> -// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64 // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<2x1x2x1x2xf32> -// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64 // CHECK: %[[C4:.*]] = arith.constant 4 : index // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32> -// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64 -// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]] : tensor<3xi64> -// CHECK: %[[T4:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32> +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]] : tensor<3xindex> +// CHECK: %[[T4:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xindex>) -> tensor<2x2x2xf32> // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<2x2x2xf32> -> !torch.vtensor<[2,2,2],f32> // CHECK: return %[[T5]] : !torch.vtensor<[2,2,2],f32> func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> { @@ -477,19 +466,15 @@ func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> // CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor -// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64 // 
CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64 // CHECK: %[[C3:.*]] = arith.constant 3 : index // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor -// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64> -// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xi64>) -> tensor<1x?x?x?x?xf32> +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : index +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C1_I64]], %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<5xindex> +// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xindex>) -> tensor<1x?x?x?x?xf32> // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x?x?x?x?xf32> -> !torch.vtensor<[1,?,?,?,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[1,?,?,?,?],f32> func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> { @@ -506,19 +491,15 @@ func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> ! // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor -// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64 // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64 // CHECK: %[[C3:.*]] = arith.constant 3 : index // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor -// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[C1_I64]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64> -// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xi64>) -> tensor +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : index +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[C1_I64]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<5xindex> +// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xindex>) -> tensor // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,1,?,?,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?,?],f32> func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> { @@ -535,19 +516,15 @@ func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> ! 
// CHECK: %[[INT:.*]]-2 = torch.constant.int -2 // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor -// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64 // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64 // CHECK: %[[C3:.*]] = arith.constant 3 : index // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor -// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[C1_I64]], %[[T4]] : tensor<5xi64> -// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xi64>) -> tensor +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : index +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[C1_I64]], %[[DIM_2]] : tensor<5xindex> +// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xindex>) -> tensor // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,?,?,1,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,?,?,1,?],f32> func.func @torch.aten.unsqueeze$from_end(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> {