[Stablehlo] use index type as dim size, avoid to generate index_cast (#3526)

For example, the original IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
  func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
    %0 = arith.index_cast %dim : index to i64
    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
    %1 = arith.index_cast %dim_0 : index to i64
    %dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
    %2 = arith.index_cast %dim_1 : index to i64
    %from_elements = tensor.from_elements %0, %1, %2 : tensor<3xi64>
    %3 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
    %4 = stablehlo.dot_general %arg0, %3, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    return %4 : tensor<?x?x?xf32>
  }
}
```
After using IndexType, the IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
  func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
    %dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
    %from_elements = tensor.from_elements %dim, %dim_0, %dim_1 : tensor<3xindex>
    %0 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
    %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    return %1 : tensor<?x?x?xf32>
  }
}
```

The benefits of using IndexType on shape tensor:
* simplify the IR, avoid to generate `arith.index_cast`
* let backend compiler have a chance to decide the index width of shape
tensor
* let stablehlo backend have a chance to serialize dynamic shape IR by
[shape_legalize_to_stablehlo](https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/shape_legalize_to_stablehlo.mlir)
pull/3525/head
Yuanqiang Liu 2024-07-07 18:03:03 +08:00 committed by GitHub
parent d466d5b809
commit 3225f20ab1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 176 additions and 238 deletions

View File

@ -68,21 +68,29 @@ FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
Operation *op, Value value,
size_t dimSizeIndexBits);
// Get the dimension sizes of the input tensor, given the dimension axes
FailureOr<SmallVector<Value, 4>> getDimIndexOfTensor(PatternRewriter &rewriter,
Operation *op, Value value,
ArrayRef<int64_t> inpDims);
// Get the dimension sizes of the input tensor
FailureOr<SmallVector<Value, 4>>
getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value);
// Get a tensor that unsqueezed the specified dimensions of the input tensor
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, ArrayRef<int64_t> inputUnsqzDims,
size_t dimSizeIndexBits);
Value tensor,
ArrayRef<int64_t> inputUnsqzDims);
// Get a tensor that collapse the specified dimensions of the input tensor
FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, int64_t collapseStartDim,
int64_t collapseEndDim,
size_t dimSizeIndexBits);
int64_t collapseEndDim);
// Get a tensor that splits the specified dimensions of the input tensor
FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, int64_t splitDim,
int64_t outerLength, size_t dimSizeIndexBits);
int64_t outerLength);
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape,

View File

@ -35,8 +35,7 @@ using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo;
LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
mlir::Value &self, mlir::Value &other,
size_t dimSizeIndexBits) {
mlir::Value &self, mlir::Value &other) {
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
auto otherTy = dyn_cast<RankedTensorType>(other.getType());
auto selfRank = selfTy.getRank();
@ -46,16 +45,16 @@ LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
if (selfRank > otherRank) {
auto unsqueezeDims =
llvm::to_vector<4>(llvm::seq<int64_t>(0, selfRank - otherRank));
auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, other,
unsqueezeDims, dimSizeIndexBits);
auto unsqueezeInfo =
hlo::unsqueezeTensor(rewriter, op, other, unsqueezeDims);
if (failed(unsqueezeInfo))
return failure();
other = *unsqueezeInfo;
} else if (otherRank > selfRank) {
auto unsqueezeDims =
llvm::to_vector<4>(llvm::seq<int64_t>(0, otherRank - selfRank));
auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, self, unsqueezeDims,
dimSizeIndexBits);
auto unsqueezeInfo =
hlo::unsqueezeTensor(rewriter, op, self, unsqueezeDims);
if (failed(unsqueezeInfo))
return failure();
self = *unsqueezeInfo;
@ -740,12 +739,10 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
other = hlo::promoteType(rewriter, op.getLoc(), other, outType);
if (failed(
broadcastRanks(rewriter, op, self, cond, options.dimSizeIndexBits)))
if (failed(broadcastRanks(rewriter, op, self, cond)))
return op.emitError("failed broadcast self and condition ranks");
if (failed(
broadcastRanks(rewriter, op, other, cond, options.dimSizeIndexBits)))
if (failed(broadcastRanks(rewriter, op, other, cond)))
return op.emitError("failed broadcast other and condition ranks");
rewriter.replaceOpWithNewOp<chlo::BroadcastSelectOp>(

View File

@ -438,16 +438,14 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
}
auto outShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, weight, options.dimSizeIndexBits);
auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, weight);
if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto outShapeVec = *outShapeInfo;
auto one = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
op->getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
outShapeVec[0] = one;
auto outShapeTensor =
rewriter.create<mlir::tensor::FromElementsOp>(op->getLoc(), outShapeVec);
@ -537,16 +535,13 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
op, "only constant boolean `sparse_grad` param supported");
}
auto options = getOptions();
auto indexShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
auto indexShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, index);
if (failed(indexShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dim sizes of `index` param");
}
auto intType = rewriter.getIntegerType(options.dimSizeIndexBits);
auto one = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1));
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
auto toConcatIndexShapeValueVec = *indexShapeInfo;
toConcatIndexShapeValueVec.push_back(one);
auto toConcatIndexShape =
@ -672,24 +667,20 @@ public:
return rewriter.notifyMatchFailure(op, "invalid `dim` param detected");
}
auto options = this->getOptions();
auto indexShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
auto indexShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, index);
if (failed(indexShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dim sizes of `index` param");
}
auto intType = rewriter.getIntegerType(options.dimSizeIndexBits);
// slice src tensor to have the same shape bound of index tensor in the
// leading dimensions. PyTorch has guaranteed that src tensor size will not
// be smaller than that of index tensor. REF:
// https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_
auto zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 0));
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
auto one = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1));
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
SmallVector<Value> sliceIndicies(srcType.getRank(), zero);
SmallVector<Value> sliceStrides(srcType.getRank(), one);

View File

@ -148,10 +148,9 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
std::vector<int64_t> newShape(rhsShape.begin(),
rhsShape.begin() + leadingRank);
newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end());
auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims,
dimSizeIndexBits);
auto lhsDimSizes =
*hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
auto newDimSizes =
*hlo::getDimIndexOfTensor(rewriter, op, rhs, leadingDims);
auto lhsDimSizes = *hlo::getDimIndexOfTensor(rewriter, op, lhs);
newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(),
lhsDimSizes.end());
lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes,
@ -160,10 +159,9 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
std::vector<int64_t> newShape(lhsShape.begin(),
lhsShape.begin() + leadingRank);
newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end());
auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims,
dimSizeIndexBits);
auto rhsDimSizes =
*hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
auto newDimSizes =
*hlo::getDimIndexOfTensor(rewriter, op, lhs, leadingDims);
auto rhsDimSizes = *hlo::getDimIndexOfTensor(rewriter, op, rhs);
newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(),
rhsDimSizes.end());
rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes,
@ -207,10 +205,8 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
return;
}
auto lhsDimSizes =
*hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
auto rhsDimSizes =
*hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
auto lhsDimSizes = *hlo::getDimIndexOfTensor(rewriter, op, lhs);
auto rhsDimSizes = *hlo::getDimIndexOfTensor(rewriter, op, rhs);
if (!lhsBroadcastDims.empty()) {
SmallVector<int64_t> lhsNewShape(newBatchShape);
@ -526,16 +522,15 @@ public:
auto weightTy = cast<RankedTensorType>(weight.getType());
auto weightElemTy = weightTy.getElementType();
auto rank = weightTy.getRank();
const auto &options = getOptions();
SmallVector<Value> weightShapeVec = *hlo::getDimSizesOfTensor(
rewriter, op, weight, options.dimSizeIndexBits);
SmallVector<Value> weightShapeVec =
*hlo::getDimIndexOfTensor(rewriter, op, weight);
auto weightShape = weightTy.getShape();
SmallVector<int64_t> weightShapeInt(rank);
std::copy(weightShape.begin(), weightShape.end(), weightShapeInt.begin());
// 1. [H, W, ..., OC, IC] => [H, W, ..., OC, G, IC//G]
Value GValue = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), rewriter.getI64IntegerAttr(groups));
op->getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), groups));
Value ICDivGValue = rewriter.create<mlir::arith::DivSIOp>(
op->getLoc(), weightShapeVec[rank - 1], GValue);
Value OCMulGValue = rewriter.create<mlir::arith::MulIOp>(
@ -839,9 +834,7 @@ public:
auto inputUnsqzDims =
llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));
const auto &options = getOptions();
bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
options.dimSizeIndexBits);
bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims);
bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy);
DenseI64ArrayAttr bcastDimensions;

View File

@ -146,9 +146,7 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
rewriter.getI64Type()),
stablehloPadding);
const auto &options = getOptions();
auto inputShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
auto inputShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
@ -536,9 +534,7 @@ public:
hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
windowSizeConst =
hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy);
const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
auto inputShapeVec = *hlo::getDimSizesOfTensor(rewriter, op, input,
options.dimSizeIndexBits);
auto inputShapeVec = *hlo::getDimIndexOfTensor(rewriter, op, input);
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec);

View File

@ -310,12 +310,10 @@ static Value reshapeReduceResultWhenKeepDim(ConversionPatternRewriter &rewriter,
Location loc, Value reduceResult,
ArrayRef<Value> inputShapeVec,
Type outType,
ArrayRef<int64_t> dims,
size_t dimSizeIndexBits) {
ArrayRef<int64_t> dims) {
SmallVector<Value> outShapeVec(inputShapeVec);
Value one = rewriter.create<arith::ConstantOp>(
loc,
rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1));
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
for (auto dim : dims) {
outShapeVec[dim] = one;
}
@ -432,16 +430,13 @@ public:
}
if (keepDim) {
const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input,
options.dimSizeIndexBits);
auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
reduceResult = reshapeReduceResultWhenKeepDim(
rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, {dim},
options.dimSizeIndexBits);
rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, {dim});
}
rewriter.replaceOp(op, reduceResult);
return success();
@ -512,16 +507,13 @@ public:
}
if (keepDim) {
const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input,
options.dimSizeIndexBits);
auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
reduceResult = reshapeReduceResultWhenKeepDim(
rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims,
options.dimSizeIndexBits);
rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims);
}
rewriter.replaceOp(op, reduceResult);
return success();
@ -573,8 +565,7 @@ public:
}
const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
auto inputShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
auto inputShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
@ -592,9 +583,9 @@ public:
}
if (keepDim) {
reduceResult = reshapeReduceResultWhenKeepDim(
rewriter, op->getLoc(), reduceResult, inputShapeVec, valResultType,
{dim}, options.dimSizeIndexBits);
reduceResult =
reshapeReduceResultWhenKeepDim(rewriter, op->getLoc(), reduceResult,
inputShapeVec, valResultType, {dim});
}
rewriter.replaceOp(op, {reduceResult, Value()});
return success();
@ -603,16 +594,16 @@ public:
createReduceOpReturnIndices(rewriter, op, input, inputShapeVec, dim,
options.dimSizeIndexBits)
.value();
SmallVector<Value> reduceResults(stablehloReduceResults);
if (keepDim) {
stablehloReduceResults[0] = reshapeReduceResultWhenKeepDim(
rewriter, op->getLoc(), stablehloReduceResults[0], inputShapeVec,
valResultType, {dim}, options.dimSizeIndexBits);
stablehloReduceResults[1] = reshapeReduceResultWhenKeepDim(
rewriter, op->getLoc(), stablehloReduceResults[1], inputShapeVec,
idxResultType, {dim}, options.dimSizeIndexBits);
reduceResults[0] = reshapeReduceResultWhenKeepDim(
rewriter, op->getLoc(), reduceResults[0], inputShapeVec,
valResultType, {dim});
reduceResults[1] = reshapeReduceResultWhenKeepDim(
rewriter, op->getLoc(), reduceResults[1], inputShapeVec,
idxResultType, {dim});
}
rewriter.replaceOp(
op, {stablehloReduceResults[0], stablehloReduceResults[1]});
rewriter.replaceOp(op, reduceResults);
return success();
}
};
@ -685,16 +676,13 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
}
if (keepDim) {
const auto &options = getOptions();
auto outShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
reduceResult = reshapeReduceResultWhenKeepDim(
rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims,
options.dimSizeIndexBits);
rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims);
}
rewriter.replaceOp(op, reduceResult);
return success();
@ -709,7 +697,6 @@ template <>
LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
AtenFrobeniusNormDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
const TorchToStablehloOptions &options = getOptions();
Value input = adaptor.getSelf();
auto inputType = dyn_cast<RankedTensorType>(input.getType());
@ -761,16 +748,14 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
Value output = rewriter.create<stablehlo::SqrtOp>(op->getLoc(), reduceResult);
if (keepDim) {
auto outShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
output = reshapeReduceResultWhenKeepDim(
rewriter, op->getLoc(), output, *outShapeInfo,
getTypeConverter()->convertType(op.getType()), dims,
options.dimSizeIndexBits);
getTypeConverter()->convertType(op.getType()), dims);
}
rewriter.replaceOp(op, output);
return success();
@ -783,7 +768,6 @@ template <>
LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
AtenLinalgVectorNormOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
const TorchToStablehloOptions &options = getOptions();
Value input = adaptor.getSelf();
auto inputType = dyn_cast<RankedTensorType>(input.getType());
@ -861,15 +845,13 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
op->getLoc(), reduceResult, reciprocalOrd, nullptr);
if (keepDim) {
auto outShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
output = reshapeReduceResultWhenKeepDim(rewriter, op->getLoc(), output,
*outShapeInfo, outType, dims,
options.dimSizeIndexBits);
*outShapeInfo, outType, dims);
}
rewriter.replaceOp(op, output);
return success();

View File

@ -279,9 +279,47 @@ FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
return getDimSizesOfTensor(rewriter, op, value, dims, dimSizeIndexBits);
}
// Get the dimension sizes of the input tensor, given the dimension axes
FailureOr<SmallVector<Value, 4>>
getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value,
ArrayRef<int64_t> inpDims) {
auto valueTy = dyn_cast<RankedTensorType>(value.getType());
if (!valueTy) {
return rewriter.notifyMatchFailure(
op, "getDimIndexOfTensor(): the input is not a ranked tensor");
}
auto rank = valueTy.getRank();
auto dims = toPositiveDims(inpDims, rank);
SmallVector<Value, 4> dimSizes;
dimSizes.reserve(dims.size());
auto loc = op->getLoc();
for (auto d : dims) {
dimSizes.emplace_back(rewriter.create<tensor::DimOp>(loc, value, d));
}
return dimSizes;
}
// Get the dimension sizes of the input tensor
FailureOr<SmallVector<Value, 4>>
getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
auto valueTy = dyn_cast<RankedTensorType>(value.getType());
if (!valueTy) {
return rewriter.notifyMatchFailure(
op, "getDimIndexOfTensor(): the input is not a ranked tensor");
}
auto rank = valueTy.getRank();
// Get int vector [0, 1, ..., rank-1]
std::vector<int64_t> dims(rank);
std::iota(dims.begin(), dims.end(), 0);
return getDimIndexOfTensor(rewriter, op, value, dims);
}
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, ArrayRef<int64_t> inputUnsqzDims,
size_t dimSizeIndexBits) {
Value tensor,
ArrayRef<int64_t> inputUnsqzDims) {
// Returns a new tensor with dims of size 1 inserted at the specified
// position.
//
@ -289,8 +327,7 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
// tensor) are specified with unsqzDims. Indices must be in-order, and in
// range of tensor rank. Thus, unsqueeze a rank 1 tensor with {0, 2}, {0, 1,
// 3}, {0, 1, 2} are all valid dimension sets, but {0, 3}, {2} are not.
auto dimSizesInfo =
getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
auto dimSizesInfo = getDimIndexOfTensor(rewriter, op, tensor);
if (failed(dimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
@ -307,9 +344,8 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
auto loc = op->getLoc();
auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
auto oldShape = rankTy.getShape();
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
auto one = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1));
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
std::vector<Value> newDimSizes;
std::vector<int64_t> newShape;
@ -335,12 +371,9 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, int64_t collapseStartDim,
int64_t collapseEndDim,
size_t dimSizeIndexBits) {
auto dimSizesInfo =
getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
int64_t collapseEndDim) {
auto dimSizesInfo = getDimIndexOfTensor(rewriter, op, tensor);
if (failed(dimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
@ -356,7 +389,6 @@ FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
auto loc = op->getLoc();
auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
auto oldShape = rankTy.getShape();
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
std::vector<Value> newDimSizes;
std::vector<int64_t> newShape;
@ -364,7 +396,7 @@ FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
newShape.reserve(newRank);
Value collapseDimSize = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1));
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
int64_t collapseShape = 1;
for (int64_t k = collapseStartDim; k <= collapseEndDim; ++k) {
@ -402,10 +434,8 @@ FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
// TODO: support splitDim & outerLength to be Value
FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, int64_t splitDim,
int64_t outerLength, size_t dimSizeIndexBits) {
auto dimSizesInfo =
getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
int64_t outerLength) {
auto dimSizesInfo = getDimIndexOfTensor(rewriter, op, tensor);
if (failed(dimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
@ -417,7 +447,6 @@ FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
auto loc = op->getLoc();
auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
auto oldShape = rankTy.getShape();
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
if (splitDim < 0 || splitDim >= rank) {
return rewriter.notifyMatchFailure(
@ -426,7 +455,7 @@ FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
int64_t newRank = rank + 1;
auto outerLengthValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, outerLength));
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), outerLength));
auto innerLengthValue = rewriter.create<arith::DivSIOp>(
loc, dimSizes[splitDim], outerLengthValue);

View File

@ -323,8 +323,7 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
return success();
}
auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims,
options.dimSizeIndexBits);
auto newDimSizesInfo = hlo::getDimIndexOfTensor(rewriter, op, self, dims);
if (failed(newDimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
@ -375,8 +374,7 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
op, getTypeConverter()->convertType(op.getType()), self);
return success();
}
auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims,
options.dimSizeIndexBits);
auto newDimSizesInfo = hlo::getDimIndexOfTensor(rewriter, op, self, dims);
if (failed(newDimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
@ -406,8 +404,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
if (!isValidDim(dim, inputRank + 1))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
auto unsqzTensorInfo = hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(),
{dim}, options.dimSizeIndexBits);
auto unsqzTensorInfo =
hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(), {dim});
if (failed(unsqzTensorInfo))
return rewriter.notifyMatchFailure(op,
"failed to create unsqueezed tensor");
@ -438,8 +436,8 @@ LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "only constant end is currently supported");
auto collapseTensorInfo = hlo::collapseTensor(
rewriter, op, adaptor.getA(), start, end, options.dimSizeIndexBits);
auto collapseTensorInfo =
hlo::collapseTensor(rewriter, op, adaptor.getA(), start, end);
if (failed(collapseTensorInfo))
return rewriter.notifyMatchFailure(op, "failed to create collapsed tensor");
@ -469,8 +467,8 @@ LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "only constant outerLength is currently supported");
auto splitTensorInfo = hlo::splitTensor(
rewriter, op, adaptor.getA(), dim, outerLength, options.dimSizeIndexBits);
auto splitTensorInfo =
hlo::splitTensor(rewriter, op, adaptor.getA(), dim, outerLength);
if (failed(splitTensorInfo))
return rewriter.notifyMatchFailure(op, "failed to create split tensor");

View File

@ -36,15 +36,12 @@ func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1:
// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[10,4,5],f32> -> tensor<10x4x5xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<10x4x5xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<10x4x5xf32>
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<10x4x5xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32>
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex>
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<10x4x5xf32>, tensor<3xindex>) -> tensor<10x4x5xf32>
// CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32>
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<10x3x5xf32> to tensor<10x3x5xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32>
@ -62,15 +59,12 @@ func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg
// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,4,?],f32> -> tensor<?x4x?xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<?x4x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<?x4x?xf32>
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<?x4x?xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<?x4x?xf32>, tensor<3xi64>) -> tensor<?x4x?xf32>
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex>
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<?x4x?xf32>, tensor<3xindex>) -> tensor<?x4x?xf32>
// CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x4xf32>, tensor<?x4x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
@ -88,15 +82,12 @@ func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg
// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,120,256],f32> -> tensor<4x120x256xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<4x120x256xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256x120xf32>
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<256x120xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T8]], dims = [1, 2] : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32>
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex>
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T8]], dims = [1, 2] : (tensor<256x120xf32>, tensor<3xindex>) -> tensor<4x256x120xf32>
// CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T9]], %[[T1]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32>
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x256x256xf32> to tensor<4x256x256xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x256x256xf32> -> !torch.vtensor<[4,256,256],f32>
@ -114,15 +105,12 @@ func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>,
// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x?x256xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256x?xf32>
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x?xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32>
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex>
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x?xf32>, tensor<3xindex>) -> tensor<4x256x?xf32>
// CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32>
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x?x?xf32> to tensor<4x?x?xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32>
@ -140,12 +128,10 @@ func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>,
// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<1x?x256xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256xf32>
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64>
// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32>
// CHECK: %[[T6:.*]] = tensor.from_elements %[[T2]], %[[T4]] : tensor<2xindex>
// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xindex>) -> tensor<1x256xf32>
// CHECK: %[[T8:.*]] = stablehlo.dot_general %[[T0]], %[[T7]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32>
// CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<1x?xf32> to tensor<1x?xf32>
// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32>
@ -163,12 +149,10 @@ func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1:
// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,256,?],f32> -> tensor<?x256x?xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<?x256x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256xf32>
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64>
// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor<?x256xf32>
// CHECK: %[[T6:.*]] = tensor.from_elements %[[T2]], %[[T4]] : tensor<2xindex>
// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xindex>) -> tensor<?x256xf32>
// CHECK: %[[T8:.*]] = stablehlo.dot_general %[[T7]], %[[T1]], batching_dims = [0] x [0], contracting_dims = [1] x [1] : (tensor<?x256xf32>, tensor<?x256x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@ -231,15 +215,12 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
// CHECK: %[[T1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<256x256xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x256xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256x256xf32>
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x256xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x256xf32>, tensor<3xi64>) -> tensor<?x256x256xf32>
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex>
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x256xf32>, tensor<3xindex>) -> tensor<?x256x256xf32>
// CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x256xf32>, tensor<?x256x256xf32>) -> tensor<?x?x256xf32>
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<?x?x256xf32> to tensor<?x?x256xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x256xf32> -> !torch.vtensor<[?,?,256],f32>
@ -324,10 +305,9 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x3x3xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[T_9:.*]] = tensor.dim %[[T_2]], %[[IDX_0]] : tensor<?xf32>
// CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64
// CHECK: %[[T_11:.*]] = tensor.from_elements %[[T_10]], %[[VAL_0]], %[[VAL_0]] : tensor<3xi64>
// CHECK: %[[T_12:.*]] = stablehlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor<?xf32>, tensor<3xi64>) -> tensor<?x1x1xf32>
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
// CHECK: %[[T_11:.*]] = tensor.from_elements %[[T_9]], %[[VAL_0]], %[[VAL_0]] : tensor<3xindex>
// CHECK: %[[T_12:.*]] = stablehlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor<?xf32>, tensor<3xindex>) -> tensor<?x1x1xf32>
// CHECK: %[[T_13:.*]] = chlo.broadcast_add %[[T_8]], %[[T_12]] : (tensor<?x?x?x?xf32>, tensor<?x1x1xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T_14:.*]] = torch_c.from_builtin_tensor %[[T_13]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T_14]] : !torch.vtensor<[?,?,?,?],f32>
@ -466,24 +446,20 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
// CHECK: %[[T_7:.*]] = stablehlo.reverse %[[T_6]], dims = [0, 1] : tensor<3x3x2x2xf32>
// CHECK: %c0 = arith.constant 0 : index
// CHECK: %dim = tensor.dim %[[T_7]], %c0 : tensor<3x3x2x2xf32>
// CHECK: %[[T_8:.*]] = arith.index_cast %dim : index to i64
// CHECK: %c1 = arith.constant 1 : index
// CHECK: %dim_0 = tensor.dim %[[T_7]], %c1 : tensor<3x3x2x2xf32>
// CHECK: %[[T_9:.*]] = arith.index_cast %dim_0 : index to i64
// CHECK: %c2 = arith.constant 2 : index
// CHECK: %dim_1 = tensor.dim %[[T_7]], %c2 : tensor<3x3x2x2xf32>
// CHECK: %[[T_10:.*]] = arith.index_cast %dim_1 : index to i64
// CHECK: %c3 = arith.constant 3 : index
// CHECK: %dim_2 = tensor.dim %[[T_7]], %c3 : tensor<3x3x2x2xf32>
// CHECK: %[[T_11:.*]] = arith.index_cast %dim_2 : index to i64
// CHECK: %[[C2:.*]] = arith.constant 2 : i64
// CHECK: %[[T_12:.*]] = arith.divsi %[[T_11]], %[[C2]] : i64
// CHECK: %[[T_13:.*]] = arith.muli %[[T_10]], %[[C2]] : i64
// CHECK: %from_elements = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_10]], %[[C2]], %[[T_12]] : tensor<5xi64>
// CHECK: %[[T_14:.*]] = stablehlo.dynamic_reshape %[[T_7]], %from_elements : (tensor<3x3x2x2xf32>, tensor<5xi64>) -> tensor<3x3x2x2x1xf32>
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[T_12:.*]] = arith.divsi %dim_2, %[[C2]] : index
// CHECK: %[[T_13:.*]] = arith.muli %dim_1, %[[C2]] : index
// CHECK: %from_elements = tensor.from_elements %dim, %dim_0, %dim_1, %[[C2]], %[[T_12]] : tensor<5xindex>
// CHECK: %[[T_14:.*]] = stablehlo.dynamic_reshape %[[T_7]], %from_elements : (tensor<3x3x2x2xf32>, tensor<5xindex>) -> tensor<3x3x2x2x1xf32>
// CHECK: %[[T_15:.*]] = stablehlo.transpose %[[T_14]], dims = [0, 1, 3, 2, 4] : (tensor<3x3x2x2x1xf32>) -> tensor<3x3x2x2x1xf32>
// CHECK: %[[from_elements_3:.*]] = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_13]], %[[T_12]] : tensor<4xi64>
// CHECK: %[[T_16:.*]] = stablehlo.dynamic_reshape %[[T_15]], %[[from_elements_3]] : (tensor<3x3x2x2x1xf32>, tensor<4xi64>) -> tensor<3x3x4x1xf32>
// CHECK: %[[from_elements_3:.*]] = tensor.from_elements %dim, %dim_0, %[[T_13]], %[[T_12]] : tensor<4xindex>
// CHECK: %[[T_16:.*]] = stablehlo.dynamic_reshape %[[T_15]], %[[from_elements_3]] : (tensor<3x3x2x2x1xf32>, tensor<4xindex>) -> tensor<3x3x4x1xf32>
// CHECK: %[[T_17:.*]] = stablehlo.convolution(%[[T_0]], %[[T_16]])
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<3x3x4x1xf32>) -> tensor<1x4x15x15xf32>
// CHECK: %[[T_18:.*]] = torch_c.from_builtin_tensor %[[T_17]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32>

View File

@ -83,18 +83,15 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
// CHECK: %[[T5:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?xf32>
// CHECK: %[[T6:.*]] = arith.index_cast %[[DIM]] : index to i64
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[DIM_0]] : index to i64
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x?x?xf32>
// CHECK: %[[T8:.*]] = arith.index_cast %[[DIM_1]] : index to i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T6]], %[[T7]], %[[T8]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T7]] : i64
// CHECK: %[[FROM_ELEMENTS_2:.*]] = tensor.from_elements %[[T6]], %[[T9]] : tensor<2xi64>
// CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xi64>) -> tensor<?x?xi64>
// CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64>
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]] : tensor<3xindex>
// CHECK: %[[T9:.*]] = arith.muli %[[DIM_1]], %[[DIM_0]] : index
// CHECK: %[[FROM_ELEMENTS_2:.*]] = tensor.from_elements %[[DIM]], %[[T9]] : tensor<2xindex>
// CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xindex>) -> tensor<?x?xi64>
// CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor<?x?xi64>, tensor<3xindex>) -> tensor<?x?x?xi64>
// CHECK: %[[T12:.*]] = stablehlo.constant dense<0> : tensor<i64>
// CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) <{padding = dense<0> : tensor<3x2xi64>, window_dilations = array<i64: 1, 1, 1>, window_dimensions = array<i64: 1, 3, 3>, window_strides = array<i64: 1, 2, 2>}> ({
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<i64>, %[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<i64>):
@ -146,18 +143,14 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : index to i64
// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_10:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_10]] : index to i64
// CHECK: %[[IDX_2:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_12:.*]] = tensor.dim %[[VAL_1]], %[[IDX_2]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : index to i64
// CHECK: %[[IDX_3:.*]] = arith.constant 3 : index
// CHECK: %[[VAL_14:.*]] = tensor.dim %[[VAL_1]], %[[IDX_3]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_14]] : index to i64
// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64>
// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_8]], %[[VAL_10]], %[[VAL_12]], %[[VAL_14]] : tensor<4xindex>
// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor<f32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_18:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]])
// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({

View File

@ -8,19 +8,17 @@
// CHECK: %int0 = torch.constant.int 0
// CHECK: %[[INDEX_0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[VAR_1]], %[[INDEX_0]] : tensor<?x?xi64>
// CHECK: %[[VAR_3:.*]] = arith.index_cast %[[DIM_0]] : index to i64
// CHECK: %[[INDEX_1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM_1:.*]] = tensor.dim %1, %[[INDEX_1]] : tensor<?x?xi64>
// CHECK: %[[VAR_4:.*]] = arith.index_cast %[[DIM_1]] : index to i64
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : i64
// CHECK: %[[CONSTANT_1:.*]] = arith.constant 1 : i64
// CHECK: %[[FE_:.*]] = tensor.from_elements %[[CONSTANT_0]], %[[CONSTANT_0]] : tensor<2xi64>
// CHECK: %[[FE_1:.*]] = tensor.from_elements %[[CONSTANT_1]], %[[CONSTANT_1]] : tensor<2xi64>
// CHECK: %[[FE_2:.*]] = tensor.from_elements %[[VAR_3]], %[[VAR_4]] : tensor<2xi64>
// CHECK: %[[VAR_5:.*]] = stablehlo.real_dynamic_slice %[[VAR_2]], %[[FE_]], %[[FE_2]], %[[FE_1]] : (tensor<?x?xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<?x?xi64>
// CHECK: %[[FE_3:.*]] = tensor.from_elements %[[VAR_3]], %[[VAR_4]], %[[CONSTANT_1]] : tensor<3xi64>
// CHECK: %[[VAR_6:.*]] = stablehlo.dynamic_reshape %1, %[[FE_3]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x1xi64>
// CHECK: %[[VAR_7:.*]] = stablehlo.dynamic_iota %[[FE_3]], dim = 1 : (tensor<3xi64>) -> tensor<?x?x1xi64>
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : index
// CHECK: %[[CONSTANT_1:.*]] = arith.constant 1 : index
// CHECK: %[[FE_:.*]] = tensor.from_elements %[[CONSTANT_0]], %[[CONSTANT_0]] : tensor<2xindex>
// CHECK: %[[FE_1:.*]] = tensor.from_elements %[[CONSTANT_1]], %[[CONSTANT_1]] : tensor<2xindex>
// CHECK: %[[FE_2:.*]] = tensor.from_elements %[[DIM_0]], %[[DIM_1]] : tensor<2xindex>
// CHECK: %[[VAR_5:.*]] = stablehlo.real_dynamic_slice %[[VAR_2]], %[[FE_]], %[[FE_2]], %[[FE_1]] : (tensor<?x?xi64>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xi64>
// CHECK: %[[FE_3:.*]] = tensor.from_elements %[[DIM_0]], %[[DIM_1]], %[[CONSTANT_1]] : tensor<3xindex>
// CHECK: %[[VAR_6:.*]] = stablehlo.dynamic_reshape %1, %[[FE_3]] : (tensor<?x?xi64>, tensor<3xindex>) -> tensor<?x?x1xi64>
// CHECK: %[[VAR_7:.*]] = stablehlo.dynamic_iota %[[FE_3]], dim = 1 : (tensor<3xindex>) -> tensor<?x?x1xi64>
// CHECK: %[[VAR_8:.*]] = stablehlo.concatenate %[[VAR_6]], %[[VAR_7]], dim = 2 : (tensor<?x?x1xi64>, tensor<?x?x1xi64>) -> tensor<?x?x2xi64>
// CHECK: %[[VAR_9:.*]] = "stablehlo.scatter"(%[[VAR_0]], %[[VAR_8]], %[[VAR_5]]) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 2>, unique_indices = false}> ({
// CHECK: ^bb0(%arg3: tensor<i64>, %[[ARG_4:.*]]: tensor<i64>):

View File

@ -398,18 +398,14 @@ func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x1x?x1x?xf32>
// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x1x?x1x?xf32>
// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
// CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor<?x1x?x1x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<?x1x?x1x?xf32>
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64>
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x?x1x?xf32>
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<4xindex>
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xindex>) -> tensor<?x?x1x?xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?x1x?xf32> -> !torch.vtensor<[?,?,1,?],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[?,?,1,?],f32>
func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> {
@ -426,18 +422,14 @@ func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !
// CHECK: %[[INT:.*]]-2 = torch.constant.int -2
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x1x?x1x?xf32>
// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x1x?x1x?xf32>
// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x1x?x1x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<?x1x?x1x?xf32>
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64>
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x1x?x?xf32>
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<4xindex>
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xindex>) -> tensor<?x1x?x?xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x1x?x?xf32> -> !torch.vtensor<[?,1,?,?],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?],f32>
func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> {
@ -453,15 +445,12 @@ func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<2x1x2x1x2xf32>
// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<2x1x2x1x2xf32>
// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]] : tensor<3xi64>
// CHECK: %[[T4:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32>
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]] : tensor<3xindex>
// CHECK: %[[T4:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xindex>) -> tensor<2x2x2xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<2x2x2xf32> -> !torch.vtensor<[2,2,2],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[2,2,2],f32>
func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> {
@ -477,19 +466,15 @@ func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) ->
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64
// CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64>
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<1x?x?x?x?xf32>
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : index
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C1_I64]], %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<5xindex>
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xindex>) -> tensor<1x?x?x?x?xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x?x?x?x?xf32> -> !torch.vtensor<[1,?,?,?,?],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[1,?,?,?,?],f32>
func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> {
@ -506,19 +491,15 @@ func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64
// CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[C1_I64]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64>
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x1x?x?x?xf32>
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : index
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[C1_I64]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<5xindex>
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xindex>) -> tensor<?x1x?x?x?xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x1x?x?x?xf32> -> !torch.vtensor<[?,1,?,?,?],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?,?],f32>
func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> {
@ -535,19 +516,15 @@ func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
// CHECK: %[[INT:.*]]-2 = torch.constant.int -2
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64
// CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[C1_I64]], %[[T4]] : tensor<5xi64>
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x?x?x1x?xf32>
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : index
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[C1_I64]], %[[DIM_2]] : tensor<5xindex>
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xindex>) -> tensor<?x?x?x1x?xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?x?x1x?xf32> -> !torch.vtensor<[?,?,?,1,?],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[?,?,?,1,?],f32>
func.func @torch.aten.unsqueeze$from_end(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> {