mirror of https://github.com/llvm/torch-mlir
[Stablehlo] use index type as dim size, avoid to generate index_cast (#3526)
For example, the original IR is:

```
module attributes {torch.debug_module_name = "Matmul3D"} {
  func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
    %0 = arith.index_cast %dim : index to i64
    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
    %1 = arith.index_cast %dim_0 : index to i64
    %dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
    %2 = arith.index_cast %dim_1 : index to i64
    %from_elements = tensor.from_elements %0, %1, %2 : tensor<3xi64>
    %3 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
    %4 = stablehlo.dot_general %arg0, %3, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    return %4 : tensor<?x?x?xf32>
  }
}
```

After using IndexType, the IR is:

```
module attributes {torch.debug_module_name = "Matmul3D"} {
  func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
    %dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
    %from_elements = tensor.from_elements %dim, %dim_0, %dim_1 : tensor<3xindex>
    %0 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
    %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    return %1 : tensor<?x?x?xf32>
  }
}
```

The benefits of using IndexType for the shape tensor:

* it simplifies the IR and avoids generating `arith.index_cast` ops
* it lets the backend compiler decide the index width of the shape tensor
* it gives the stablehlo backend a chance to serialize dynamic-shape IR via [shape_legalize_to_stablehlo](https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/shape_legalize_to_stablehlo.mlir)
branch pull/3525/head
parent d466d5b809
commit 3225f20ab1
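The mechanics are easiest to see in the new `getDimIndexOfTensor` helper in the diff below: `tensor.dim` already yields an `index`-typed value, so the dim sizes can feed `tensor.from_elements` directly. A minimal sketch of the resulting pattern follows (the `rewriter`, `loc`, `input`, and `inputTy` names are assumed conversion-pattern boilerplate for illustration, not part of this commit):

```cpp
// Sketch only: collect the dim sizes as `index` values and pack them
// straight into a shape operand, with no arith.index_cast in between.
SmallVector<Value, 4> dimSizes;
for (int64_t d = 0, rank = inputTy.getRank(); d < rank; ++d)
  dimSizes.push_back(rewriter.create<tensor::DimOp>(loc, input, d));
// tensor.from_elements infers a tensor<Nxindex> result from the element
// type; previously each element was first widened to iN via index_cast.
Value shape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
```

Because the shape elements are now `index`, the conversion no longer needs to thread `options.dimSizeIndexBits` through every helper, which is why the `size_t dimSizeIndexBits` parameters disappear from the signatures throughout the diff.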
@@ -68,21 +68,29 @@ FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
                                                      Operation *op, Value value,
                                                      size_t dimSizeIndexBits);

+// Get the dimension sizes of the input tensor, given the dimension axes
+FailureOr<SmallVector<Value, 4>> getDimIndexOfTensor(PatternRewriter &rewriter,
+                                                     Operation *op, Value value,
+                                                     ArrayRef<int64_t> inpDims);
+
+// Get the dimension sizes of the input tensor
+FailureOr<SmallVector<Value, 4>>
+getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value);
+
 // Get a tensor that unsqueezed the specified dimensions of the input tensor
 FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
-                                 Value tensor, ArrayRef<int64_t> inputUnsqzDims,
-                                 size_t dimSizeIndexBits);
+                                 Value tensor,
+                                 ArrayRef<int64_t> inputUnsqzDims);

 // Get a tensor that collapse the specified dimensions of the input tensor
 FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
                                 Value tensor, int64_t collapseStartDim,
-                                int64_t collapseEndDim,
-                                size_t dimSizeIndexBits);
+                                int64_t collapseEndDim);

 // Get a tensor that splits the specified dimensions of the input tensor
 FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
                              Value tensor, int64_t splitDim,
-                             int64_t outerLength, size_t dimSizeIndexBits);
+                             int64_t outerLength);

 Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
                          const APFloat &constant, Value shape,
@@ -35,8 +35,7 @@ using namespace mlir::torch::Torch;
 using namespace mlir::torch::torch_to_stablehlo;

 LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
-                             mlir::Value &self, mlir::Value &other,
-                             size_t dimSizeIndexBits) {
+                             mlir::Value &self, mlir::Value &other) {
   auto selfTy = dyn_cast<RankedTensorType>(self.getType());
   auto otherTy = dyn_cast<RankedTensorType>(other.getType());
   auto selfRank = selfTy.getRank();

@@ -46,16 +45,16 @@ LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
   if (selfRank > otherRank) {
     auto unsqueezeDims =
         llvm::to_vector<4>(llvm::seq<int64_t>(0, selfRank - otherRank));
-    auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, other,
-                                              unsqueezeDims, dimSizeIndexBits);
+    auto unsqueezeInfo =
+        hlo::unsqueezeTensor(rewriter, op, other, unsqueezeDims);
     if (failed(unsqueezeInfo))
       return failure();
     other = *unsqueezeInfo;
   } else if (otherRank > selfRank) {
     auto unsqueezeDims =
         llvm::to_vector<4>(llvm::seq<int64_t>(0, otherRank - selfRank));
-    auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, self, unsqueezeDims,
-                                              dimSizeIndexBits);
+    auto unsqueezeInfo =
+        hlo::unsqueezeTensor(rewriter, op, self, unsqueezeDims);
     if (failed(unsqueezeInfo))
       return failure();
     self = *unsqueezeInfo;

@@ -740,12 +739,10 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
   self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
   other = hlo::promoteType(rewriter, op.getLoc(), other, outType);

-  if (failed(
-          broadcastRanks(rewriter, op, self, cond, options.dimSizeIndexBits)))
+  if (failed(broadcastRanks(rewriter, op, self, cond)))
     return op.emitError("failed broadcast self and condition ranks");

-  if (failed(
-          broadcastRanks(rewriter, op, other, cond, options.dimSizeIndexBits)))
+  if (failed(broadcastRanks(rewriter, op, other, cond)))
     return op.emitError("failed broadcast other and condition ranks");

   rewriter.replaceOpWithNewOp<chlo::BroadcastSelectOp>(
@@ -438,16 +438,14 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
     rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
   }

-  auto outShapeInfo =
-      hlo::getDimSizesOfTensor(rewriter, op, weight, options.dimSizeIndexBits);
+  auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, weight);
   if (failed(outShapeInfo)) {
     return rewriter.notifyMatchFailure(
         op, "failed to get dimension sizes of the input");
   }
   auto outShapeVec = *outShapeInfo;
   auto one = rewriter.create<mlir::arith::ConstantOp>(
-      op->getLoc(), rewriter.getIntegerAttr(
-                        rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+      op->getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
   outShapeVec[0] = one;
   auto outShapeTensor =
       rewriter.create<mlir::tensor::FromElementsOp>(op->getLoc(), outShapeVec);

@@ -537,16 +535,13 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
         op, "only constant boolean `sparse_grad` param supported");
   }

-  auto options = getOptions();
-  auto indexShapeInfo =
-      hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
+  auto indexShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, index);
   if (failed(indexShapeInfo)) {
     return rewriter.notifyMatchFailure(
         op, "failed to get dim sizes of `index` param");
   }
-  auto intType = rewriter.getIntegerType(options.dimSizeIndexBits);
   auto one = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getIntegerAttr(intType, 1));
+      loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
   auto toConcatIndexShapeValueVec = *indexShapeInfo;
   toConcatIndexShapeValueVec.push_back(one);
   auto toConcatIndexShape =

@@ -672,24 +667,20 @@ public:
       return rewriter.notifyMatchFailure(op, "invalid `dim` param detected");
     }

-    auto options = this->getOptions();
-
-    auto indexShapeInfo =
-        hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
+    auto indexShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, index);
     if (failed(indexShapeInfo)) {
       return rewriter.notifyMatchFailure(
           op, "failed to get dim sizes of `index` param");
     }
-    auto intType = rewriter.getIntegerType(options.dimSizeIndexBits);

     // slice src tensor to have the same shape bound of index tensor in the
     // leading dimensions. PyTorch has guaranteed that src tensor size will not
     // be smaller than that of index tensor. REF:
     // https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_
     auto zero = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(intType, 0));
+        loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
     auto one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(intType, 1));
+        loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
     SmallVector<Value> sliceIndicies(srcType.getRank(), zero);
     SmallVector<Value> sliceStrides(srcType.getRank(), one);
@@ -148,10 +148,9 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
     std::vector<int64_t> newShape(rhsShape.begin(),
                                   rhsShape.begin() + leadingRank);
     newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end());
-    auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims,
-                                                 dimSizeIndexBits);
-    auto lhsDimSizes =
-        *hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
+    auto newDimSizes =
+        *hlo::getDimIndexOfTensor(rewriter, op, rhs, leadingDims);
+    auto lhsDimSizes = *hlo::getDimIndexOfTensor(rewriter, op, lhs);
     newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(),
                        lhsDimSizes.end());
     lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes,

@@ -160,10 +159,9 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
     std::vector<int64_t> newShape(lhsShape.begin(),
                                   lhsShape.begin() + leadingRank);
     newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end());
-    auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims,
-                                                 dimSizeIndexBits);
-    auto rhsDimSizes =
-        *hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
+    auto newDimSizes =
+        *hlo::getDimIndexOfTensor(rewriter, op, lhs, leadingDims);
+    auto rhsDimSizes = *hlo::getDimIndexOfTensor(rewriter, op, rhs);
     newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(),
                        rhsDimSizes.end());
     rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes,

@@ -207,10 +205,8 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
     return;
   }

-  auto lhsDimSizes =
-      *hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
-  auto rhsDimSizes =
-      *hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
+  auto lhsDimSizes = *hlo::getDimIndexOfTensor(rewriter, op, lhs);
+  auto rhsDimSizes = *hlo::getDimIndexOfTensor(rewriter, op, rhs);

   if (!lhsBroadcastDims.empty()) {
     SmallVector<int64_t> lhsNewShape(newBatchShape);

@@ -526,16 +522,15 @@ public:
     auto weightTy = cast<RankedTensorType>(weight.getType());
     auto weightElemTy = weightTy.getElementType();
     auto rank = weightTy.getRank();
-    const auto &options = getOptions();
-    SmallVector<Value> weightShapeVec = *hlo::getDimSizesOfTensor(
-        rewriter, op, weight, options.dimSizeIndexBits);
+    SmallVector<Value> weightShapeVec =
+        *hlo::getDimIndexOfTensor(rewriter, op, weight);
     auto weightShape = weightTy.getShape();
     SmallVector<int64_t> weightShapeInt(rank);
     std::copy(weightShape.begin(), weightShape.end(), weightShapeInt.begin());

     // 1. [H, W, ..., OC, IC] => [H, W, ..., OC, G, IC//G]
     Value GValue = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(), rewriter.getI64IntegerAttr(groups));
+        op->getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), groups));
     Value ICDivGValue = rewriter.create<mlir::arith::DivSIOp>(
         op->getLoc(), weightShapeVec[rank - 1], GValue);
     Value OCMulGValue = rewriter.create<mlir::arith::MulIOp>(

@@ -839,9 +834,7 @@ public:
       auto inputUnsqzDims =
           llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));

-      const auto &options = getOptions();
-      bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
-                                   options.dimSizeIndexBits);
+      bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims);
      bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy);

      DenseI64ArrayAttr bcastDimensions;
@@ -146,9 +146,7 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
                     rewriter.getI64Type()),
       stablehloPadding);

-  const auto &options = getOptions();
-  auto inputShapeInfo =
-      hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
+  auto inputShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
   if (failed(inputShapeInfo)) {
     return rewriter.notifyMatchFailure(
         op, "failed to get dimension sizes of the input");

@@ -536,9 +534,7 @@ public:
         hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
     windowSizeConst =
         hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy);
-    const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
-    auto inputShapeVec = *hlo::getDimSizesOfTensor(rewriter, op, input,
-                                                   options.dimSizeIndexBits);
+    auto inputShapeVec = *hlo::getDimIndexOfTensor(rewriter, op, input);
     auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
         op->getLoc(), inputShapeVec);

@@ -310,12 +310,10 @@ static Value reshapeReduceResultWhenKeepDim(ConversionPatternRewriter &rewriter,
                                             Location loc, Value reduceResult,
                                             ArrayRef<Value> inputShapeVec,
                                             Type outType,
-                                            ArrayRef<int64_t> dims,
-                                            size_t dimSizeIndexBits) {
+                                            ArrayRef<int64_t> dims) {
   SmallVector<Value> outShapeVec(inputShapeVec);
   Value one = rewriter.create<arith::ConstantOp>(
-      loc,
-      rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1));
+      loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
   for (auto dim : dims) {
     outShapeVec[dim] = one;
   }

@@ -432,16 +430,13 @@ public:
     }

     if (keepDim) {
-      const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
-      auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input,
-                                                   options.dimSizeIndexBits);
+      auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
      if (failed(outShapeInfo)) {
        return rewriter.notifyMatchFailure(
            op, "failed to get dimension sizes of the input");
      }
      reduceResult = reshapeReduceResultWhenKeepDim(
-          rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, {dim},
-          options.dimSizeIndexBits);
+          rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, {dim});
    }
    rewriter.replaceOp(op, reduceResult);
    return success();

@@ -512,16 +507,13 @@ public:
     }

     if (keepDim) {
-      const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
-      auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input,
-                                                   options.dimSizeIndexBits);
+      auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
      if (failed(outShapeInfo)) {
        return rewriter.notifyMatchFailure(
            op, "failed to get dimension sizes of the input");
      }
      reduceResult = reshapeReduceResultWhenKeepDim(
-          rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims,
-          options.dimSizeIndexBits);
+          rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims);
    }
    rewriter.replaceOp(op, reduceResult);
    return success();

@@ -573,8 +565,7 @@ public:
     }

     const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
-    auto inputShapeInfo =
-        hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
+    auto inputShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
     if (failed(inputShapeInfo)) {
       return rewriter.notifyMatchFailure(
           op, "failed to get dimension sizes of the input");

@@ -592,9 +583,9 @@ public:
     }

     if (keepDim) {
-      reduceResult = reshapeReduceResultWhenKeepDim(
-          rewriter, op->getLoc(), reduceResult, inputShapeVec, valResultType,
-          {dim}, options.dimSizeIndexBits);
+      reduceResult =
+          reshapeReduceResultWhenKeepDim(rewriter, op->getLoc(), reduceResult,
+                                         inputShapeVec, valResultType, {dim});
     }
     rewriter.replaceOp(op, {reduceResult, Value()});
     return success();

@@ -603,16 +594,16 @@ public:
         createReduceOpReturnIndices(rewriter, op, input, inputShapeVec, dim,
                                     options.dimSizeIndexBits)
             .value();
+    SmallVector<Value> reduceResults(stablehloReduceResults);
     if (keepDim) {
-      stablehloReduceResults[0] = reshapeReduceResultWhenKeepDim(
-          rewriter, op->getLoc(), stablehloReduceResults[0], inputShapeVec,
-          valResultType, {dim}, options.dimSizeIndexBits);
-      stablehloReduceResults[1] = reshapeReduceResultWhenKeepDim(
-          rewriter, op->getLoc(), stablehloReduceResults[1], inputShapeVec,
-          idxResultType, {dim}, options.dimSizeIndexBits);
+      reduceResults[0] = reshapeReduceResultWhenKeepDim(
+          rewriter, op->getLoc(), reduceResults[0], inputShapeVec,
+          valResultType, {dim});
+      reduceResults[1] = reshapeReduceResultWhenKeepDim(
+          rewriter, op->getLoc(), reduceResults[1], inputShapeVec,
+          idxResultType, {dim});
     }
-    rewriter.replaceOp(
-        op, {stablehloReduceResults[0], stablehloReduceResults[1]});
+    rewriter.replaceOp(op, reduceResults);
     return success();
   }
 };

@@ -685,16 +676,13 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
   }

   if (keepDim) {
-    const auto &options = getOptions();
-    auto outShapeInfo =
-        hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
+    auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
     if (failed(outShapeInfo)) {
       return rewriter.notifyMatchFailure(
           op, "failed to get dimension sizes of the input");
     }
     reduceResult = reshapeReduceResultWhenKeepDim(
-        rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims,
-        options.dimSizeIndexBits);
+        rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims);
   }
   rewriter.replaceOp(op, reduceResult);
   return success();

@@ -709,7 +697,6 @@ template <>
 LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
     AtenFrobeniusNormDimOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-  const TorchToStablehloOptions &options = getOptions();

   Value input = adaptor.getSelf();
   auto inputType = dyn_cast<RankedTensorType>(input.getType());

@@ -761,16 +748,14 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
   Value output = rewriter.create<stablehlo::SqrtOp>(op->getLoc(), reduceResult);

   if (keepDim) {
-    auto outShapeInfo =
-        hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
+    auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
     if (failed(outShapeInfo)) {
       return rewriter.notifyMatchFailure(
           op, "failed to get dimension sizes of the input");
     }
     output = reshapeReduceResultWhenKeepDim(
         rewriter, op->getLoc(), output, *outShapeInfo,
-        getTypeConverter()->convertType(op.getType()), dims,
-        options.dimSizeIndexBits);
+        getTypeConverter()->convertType(op.getType()), dims);
   }
   rewriter.replaceOp(op, output);
   return success();

@@ -783,7 +768,6 @@ template <>
 LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
     AtenLinalgVectorNormOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-  const TorchToStablehloOptions &options = getOptions();

   Value input = adaptor.getSelf();
   auto inputType = dyn_cast<RankedTensorType>(input.getType());

@@ -861,15 +845,13 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
       op->getLoc(), reduceResult, reciprocalOrd, nullptr);

   if (keepDim) {
-    auto outShapeInfo =
-        hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
+    auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
     if (failed(outShapeInfo)) {
       return rewriter.notifyMatchFailure(
           op, "failed to get dimension sizes of the input");
     }
     output = reshapeReduceResultWhenKeepDim(rewriter, op->getLoc(), output,
-                                            *outShapeInfo, outType, dims,
-                                            options.dimSizeIndexBits);
+                                            *outShapeInfo, outType, dims);
   }
   rewriter.replaceOp(op, output);
   return success();
@@ -279,9 +279,47 @@ FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
   return getDimSizesOfTensor(rewriter, op, value, dims, dimSizeIndexBits);
 }

+// Get the dimension sizes of the input tensor, given the dimension axes
+FailureOr<SmallVector<Value, 4>>
+getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value,
+                    ArrayRef<int64_t> inpDims) {
+  auto valueTy = dyn_cast<RankedTensorType>(value.getType());
+  if (!valueTy) {
+    return rewriter.notifyMatchFailure(
+        op, "getDimIndexOfTensor(): the input is not a ranked tensor");
+  }
+
+  auto rank = valueTy.getRank();
+  auto dims = toPositiveDims(inpDims, rank);
+  SmallVector<Value, 4> dimSizes;
+  dimSizes.reserve(dims.size());
+
+  auto loc = op->getLoc();
+  for (auto d : dims) {
+    dimSizes.emplace_back(rewriter.create<tensor::DimOp>(loc, value, d));
+  }
+  return dimSizes;
+}
+
+// Get the dimension sizes of the input tensor
+FailureOr<SmallVector<Value, 4>>
+getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
+  auto valueTy = dyn_cast<RankedTensorType>(value.getType());
+  if (!valueTy) {
+    return rewriter.notifyMatchFailure(
+        op, "getDimIndexOfTensor(): the input is not a ranked tensor");
+  }
+
+  auto rank = valueTy.getRank();
+  // Get int vector [0, 1, ..., rank-1]
+  std::vector<int64_t> dims(rank);
+  std::iota(dims.begin(), dims.end(), 0);
+  return getDimIndexOfTensor(rewriter, op, value, dims);
+}
+
 FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
-                                 Value tensor, ArrayRef<int64_t> inputUnsqzDims,
-                                 size_t dimSizeIndexBits) {
+                                 Value tensor,
+                                 ArrayRef<int64_t> inputUnsqzDims) {
   // Returns a new tensor with dims of size 1 inserted at the specified
   // position.
   //

@@ -289,8 +327,7 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
   // tensor) are specified with unsqzDims. Indices must be in-order, and in
   // range of tensor rank. Thus, unsqueeze a rank 1 tensor with {0, 2}, {0, 1,
   // 3}, {0, 1, 2} are all valid dimension sets, but {0, 3}, {2} are not.
-  auto dimSizesInfo =
-      getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
+  auto dimSizesInfo = getDimIndexOfTensor(rewriter, op, tensor);
   if (failed(dimSizesInfo))
     return rewriter.notifyMatchFailure(
         op, "failed to get dimension sizes of the input");

@@ -307,9 +344,8 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
   auto loc = op->getLoc();
   auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
   auto oldShape = rankTy.getShape();
-  Type intType = rewriter.getIntegerType(dimSizeIndexBits);
   auto one = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getIntegerAttr(intType, 1));
+      loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));

   std::vector<Value> newDimSizes;
   std::vector<int64_t> newShape;

@@ -335,12 +371,9 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,

 FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
                                 Value tensor, int64_t collapseStartDim,
-                                int64_t collapseEndDim,
-                                size_t dimSizeIndexBits) {
-
-  auto dimSizesInfo =
-      getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
+                                int64_t collapseEndDim) {
+
+  auto dimSizesInfo = getDimIndexOfTensor(rewriter, op, tensor);
   if (failed(dimSizesInfo))
     return rewriter.notifyMatchFailure(
         op, "failed to get dimension sizes of the input");

@@ -356,7 +389,6 @@ FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
   auto loc = op->getLoc();
   auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
   auto oldShape = rankTy.getShape();
-  Type intType = rewriter.getIntegerType(dimSizeIndexBits);

   std::vector<Value> newDimSizes;
   std::vector<int64_t> newShape;

@@ -364,7 +396,7 @@ FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
   newShape.reserve(newRank);

   Value collapseDimSize = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getIntegerAttr(intType, 1));
+      loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
   int64_t collapseShape = 1;

   for (int64_t k = collapseStartDim; k <= collapseEndDim; ++k) {

@@ -402,10 +434,8 @@ FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
 // TODO: support splitDim & outerLength to be Value
 FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
                              Value tensor, int64_t splitDim,
-                             int64_t outerLength, size_t dimSizeIndexBits) {
-  auto dimSizesInfo =
-      getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
-
+                             int64_t outerLength) {
+  auto dimSizesInfo = getDimIndexOfTensor(rewriter, op, tensor);
   if (failed(dimSizesInfo))
     return rewriter.notifyMatchFailure(
         op, "failed to get dimension sizes of the input");

@@ -417,7 +447,6 @@ FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
   auto loc = op->getLoc();
   auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
   auto oldShape = rankTy.getShape();
-  Type intType = rewriter.getIntegerType(dimSizeIndexBits);

   if (splitDim < 0 || splitDim >= rank) {
     return rewriter.notifyMatchFailure(

@@ -426,7 +455,7 @@ FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,

   int64_t newRank = rank + 1;
   auto outerLengthValue = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getIntegerAttr(intType, outerLength));
+      loc, rewriter.getIntegerAttr(rewriter.getIndexType(), outerLength));

   auto innerLengthValue = rewriter.create<arith::DivSIOp>(
       loc, dimSizes[splitDim], outerLengthValue);
@@ -323,8 +323,7 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
     return success();
   }

-  auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims,
-                                                  options.dimSizeIndexBits);
+  auto newDimSizesInfo = hlo::getDimIndexOfTensor(rewriter, op, self, dims);
   if (failed(newDimSizesInfo))
     return rewriter.notifyMatchFailure(
         op, "failed to get dimension sizes of the input");

@@ -375,8 +374,7 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
         op, getTypeConverter()->convertType(op.getType()), self);
     return success();
   }
-  auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims,
-                                                  options.dimSizeIndexBits);
+  auto newDimSizesInfo = hlo::getDimIndexOfTensor(rewriter, op, self, dims);
   if (failed(newDimSizesInfo))
     return rewriter.notifyMatchFailure(
         op, "failed to get dimension sizes of the input");

@@ -406,8 +404,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
   if (!isValidDim(dim, inputRank + 1))
     return rewriter.notifyMatchFailure(op, "dim is statically invalid");

-  auto unsqzTensorInfo = hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(),
-                                              {dim}, options.dimSizeIndexBits);
+  auto unsqzTensorInfo =
+      hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(), {dim});
   if (failed(unsqzTensorInfo))
     return rewriter.notifyMatchFailure(op,
                                        "failed to create unsqueezed tensor");

@@ -438,8 +436,8 @@ LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "only constant end is currently supported");

-  auto collapseTensorInfo = hlo::collapseTensor(
-      rewriter, op, adaptor.getA(), start, end, options.dimSizeIndexBits);
+  auto collapseTensorInfo =
+      hlo::collapseTensor(rewriter, op, adaptor.getA(), start, end);
   if (failed(collapseTensorInfo))
     return rewriter.notifyMatchFailure(op, "failed to create collapsed tensor");

@@ -469,8 +467,8 @@ LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "only constant outerLength is currently supported");

-  auto splitTensorInfo = hlo::splitTensor(
-      rewriter, op, adaptor.getA(), dim, outerLength, options.dimSizeIndexBits);
+  auto splitTensorInfo =
+      hlo::splitTensor(rewriter, op, adaptor.getA(), dim, outerLength);

   if (failed(splitTensorInfo))
     return rewriter.notifyMatchFailure(op, "failed to create split tensor");
@@ -36,15 +36,12 @@ func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1:
 // CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[10,4,5],f32> -> tensor<10x4x5xf32>
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<10x4x5xf32>
-// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<10x4x5xf32>
-// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<10x4x5xf32>
-// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
-// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
-// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32>
+// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex>
+// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<10x4x5xf32>, tensor<3xindex>) -> tensor<10x4x5xf32>
 // CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32>
 // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<10x3x5xf32> to tensor<10x3x5xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32>

@@ -62,15 +59,12 @@ func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg
 // CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,4,?],f32> -> tensor<?x4x?xf32>
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<?x4x?xf32>
-// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<?x4x?xf32>
-// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<?x4x?xf32>
-// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
-// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
-// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<?x4x?xf32>, tensor<3xi64>) -> tensor<?x4x?xf32>
+// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex>
+// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<?x4x?xf32>, tensor<3xindex>) -> tensor<?x4x?xf32>
 // CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x4xf32>, tensor<?x4x?xf32>) -> tensor<?x?x?xf32>
 // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>

@@ -88,15 +82,12 @@ func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg
 // CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,120,256],f32> -> tensor<4x120x256xf32>
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<4x120x256xf32>
-// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
 // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
 // CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256x120xf32>
-// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<256x120xf32>
-// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
-// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
-// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T8]], dims = [1, 2] : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32>
+// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex>
+// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T8]], dims = [1, 2] : (tensor<256x120xf32>, tensor<3xindex>) -> tensor<4x256x120xf32>
 // CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T9]], %[[T1]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32>
 // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x256x256xf32> to tensor<4x256x256xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x256x256xf32> -> !torch.vtensor<[4,256,256],f32>

@@ -114,15 +105,12 @@ func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>,
 // CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32>
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x?x256xf32>
-// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
 // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
 // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256x?xf32>
-// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x?xf32>
-// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
-// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
-// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32>
+// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex>
+// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x?xf32>, tensor<3xindex>) -> tensor<4x256x?xf32>
 // CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32>
 // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x?x?xf32> to tensor<4x?x?xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32>

@@ -140,12 +128,10 @@ func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>,
 // CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<1x?x256xf32>
-// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
 // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
 // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256xf32>
-// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
-// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64>
-// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32>
+// CHECK: %[[T6:.*]] = tensor.from_elements %[[T2]], %[[T4]] : tensor<2xindex>
+// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xindex>) -> tensor<1x256xf32>
 // CHECK: %[[T8:.*]] = stablehlo.dot_general %[[T0]], %[[T7]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32>
 // CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<1x?xf32> to tensor<1x?xf32>
 // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32>

@@ -163,12 +149,10 @@ func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1:
 // CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,256,?],f32> -> tensor<?x256x?xf32>
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<?x256x?xf32>
-// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
 // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
 // CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256xf32>
-// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
-// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64>
-// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor<?x256xf32>
+// CHECK: %[[T6:.*]] = tensor.from_elements %[[T2]], %[[T4]] : tensor<2xindex>
+// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xindex>) -> tensor<?x256xf32>
 // CHECK: %[[T8:.*]] = stablehlo.dot_general %[[T7]], %[[T1]], batching_dims = [0] x [0], contracting_dims = [1] x [1] : (tensor<?x256xf32>, tensor<?x256x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<?x?xf32> to tensor<?x?xf32>
 // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>

@@ -231,15 +215,12 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
 // CHECK: %[[T1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<256x256xf32>
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x256xf32>
-// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
 // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
 // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256x256xf32>
-// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x256xf32>
-// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
-// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
-// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x256xf32>, tensor<3xi64>) -> tensor<?x256x256xf32>
+// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex>
+// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x256xf32>, tensor<3xindex>) -> tensor<?x256x256xf32>
 // CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x256xf32>, tensor<?x256x256xf32>) -> tensor<?x?x256xf32>
 // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<?x?x256xf32> to tensor<?x?x256xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x256xf32> -> !torch.vtensor<[?,?,256],f32>

@@ -324,10 +305,9 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
 // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x3x3xf32>) -> tensor<?x?x?x?xf32>
 // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
 // CHECK: %[[T_9:.*]] = tensor.dim %[[T_2]], %[[IDX_0]] : tensor<?xf32>
-// CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64
-// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64
-// CHECK: %[[T_11:.*]] = tensor.from_elements %[[T_10]], %[[VAL_0]], %[[VAL_0]] : tensor<3xi64>
-// CHECK: %[[T_12:.*]] = stablehlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor<?xf32>, tensor<3xi64>) -> tensor<?x1x1xf32>
+// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK: %[[T_11:.*]] = tensor.from_elements %[[T_9]], %[[VAL_0]], %[[VAL_0]] : tensor<3xindex>
+// CHECK: %[[T_12:.*]] = stablehlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor<?xf32>, tensor<3xindex>) -> tensor<?x1x1xf32>
 // CHECK: %[[T_13:.*]] = chlo.broadcast_add %[[T_8]], %[[T_12]] : (tensor<?x?x?x?xf32>, tensor<?x1x1xf32>) -> tensor<?x?x?x?xf32>
 // CHECK: %[[T_14:.*]] = torch_c.from_builtin_tensor %[[T_13]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
 // CHECK: return %[[T_14]] : !torch.vtensor<[?,?,?,?],f32>

@@ -466,24 +446,20 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
 // CHECK: %[[T_7:.*]] = stablehlo.reverse %[[T_6]], dims = [0, 1] : tensor<3x3x2x2xf32>
 // CHECK: %c0 = arith.constant 0 : index
 // CHECK: %dim = tensor.dim %[[T_7]], %c0 : tensor<3x3x2x2xf32>
-// CHECK: %[[T_8:.*]] = arith.index_cast %dim : index to i64
 // CHECK: %c1 = arith.constant 1 : index
 // CHECK: %dim_0 = tensor.dim %[[T_7]], %c1 : tensor<3x3x2x2xf32>
-// CHECK: %[[T_9:.*]] = arith.index_cast %dim_0 : index to i64
 // CHECK: %c2 = arith.constant 2 : index
 // CHECK: %dim_1 = tensor.dim %[[T_7]], %c2 : tensor<3x3x2x2xf32>
-// CHECK: %[[T_10:.*]] = arith.index_cast %dim_1 : index to i64
 // CHECK: %c3 = arith.constant 3 : index
 // CHECK: %dim_2 = tensor.dim %[[T_7]], %c3 : tensor<3x3x2x2xf32>
-// CHECK: %[[T_11:.*]] = arith.index_cast %dim_2 : index to i64
-// CHECK: %[[C2:.*]] = arith.constant 2 : i64
-// CHECK: %[[T_12:.*]] = arith.divsi %[[T_11]], %[[C2]] : i64
-// CHECK: %[[T_13:.*]] = arith.muli %[[T_10]], %[[C2]] : i64
-// CHECK: %from_elements = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_10]], %[[C2]], %[[T_12]] : tensor<5xi64>
-// CHECK: %[[T_14:.*]] = stablehlo.dynamic_reshape %[[T_7]], %from_elements : (tensor<3x3x2x2xf32>, tensor<5xi64>) -> tensor<3x3x2x2x1xf32>
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[T_12:.*]] = arith.divsi %dim_2, %[[C2]] : index
+// CHECK: %[[T_13:.*]] = arith.muli %dim_1, %[[C2]] : index
+// CHECK: %from_elements = tensor.from_elements %dim, %dim_0, %dim_1, %[[C2]], %[[T_12]] : tensor<5xindex>
+// CHECK: %[[T_14:.*]] = stablehlo.dynamic_reshape %[[T_7]], %from_elements : (tensor<3x3x2x2xf32>, tensor<5xindex>) -> tensor<3x3x2x2x1xf32>
 // CHECK: %[[T_15:.*]] = stablehlo.transpose %[[T_14]], dims = [0, 1, 3, 2, 4] : (tensor<3x3x2x2x1xf32>) -> tensor<3x3x2x2x1xf32>
-// CHECK: %[[from_elements_3:.*]] = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_13]], %[[T_12]] : tensor<4xi64>
-// CHECK: %[[T_16:.*]] = stablehlo.dynamic_reshape %[[T_15]], %[[from_elements_3]] : (tensor<3x3x2x2x1xf32>, tensor<4xi64>) -> tensor<3x3x4x1xf32>
+// CHECK: %[[from_elements_3:.*]] = tensor.from_elements %dim, %dim_0, %[[T_13]], %[[T_12]] : tensor<4xindex>
+// CHECK: %[[T_16:.*]] = stablehlo.dynamic_reshape %[[T_15]], %[[from_elements_3]] : (tensor<3x3x2x2x1xf32>, tensor<4xindex>) -> tensor<3x3x4x1xf32>
 // CHECK: %[[T_17:.*]] = stablehlo.convolution(%[[T_0]], %[[T_16]])
 // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<3x3x4x1xf32>) -> tensor<1x4x15x15xf32>
 // CHECK: %[[T_18:.*]] = torch_c.from_builtin_tensor %[[T_17]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32>
@@ -83,18 +83,15 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
 // CHECK: %[[T5:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[T6:.*]] = arith.index_cast %[[DIM]] : index to i64
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[T7:.*]] = arith.index_cast %[[DIM_0]] : index to i64
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[T8:.*]] = arith.index_cast %[[DIM_1]] : index to i64
-// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T6]], %[[T7]], %[[T8]] : tensor<3xi64>
-// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T7]] : i64
-// CHECK: %[[FROM_ELEMENTS_2:.*]] = tensor.from_elements %[[T6]], %[[T9]] : tensor<2xi64>
-// CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xi64>) -> tensor<?x?xi64>
-// CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64>
+// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]] : tensor<3xindex>
+// CHECK: %[[T9:.*]] = arith.muli %[[DIM_1]], %[[DIM_0]] : index
+// CHECK: %[[FROM_ELEMENTS_2:.*]] = tensor.from_elements %[[DIM]], %[[T9]] : tensor<2xindex>
+// CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xindex>) -> tensor<?x?xi64>
+// CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor<?x?xi64>, tensor<3xindex>) -> tensor<?x?x?xi64>
 // CHECK: %[[T12:.*]] = stablehlo.constant dense<0> : tensor<i64>
 // CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) <{padding = dense<0> : tensor<3x2xi64>, window_dilations = array<i64: 1, 1, 1>, window_dimensions = array<i64: 1, 3, 3>, window_strides = array<i64: 1, 2, 2>}> ({
 // CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<i64>, %[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<i64>):

@@ -146,18 +143,14 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
 // CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
 // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
 // CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32>
-// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : index to i64
 // CHECK: %[[IDX_1:.*]] = arith.constant 1 : index
 // CHECK: %[[VAL_10:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor<?x?x?x?xf32>
-// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_10]] : index to i64
 // CHECK: %[[IDX_2:.*]] = arith.constant 2 : index
 // CHECK: %[[VAL_12:.*]] = tensor.dim %[[VAL_1]], %[[IDX_2]] : tensor<?x?x?x?xf32>
-// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : index to i64
 // CHECK: %[[IDX_3:.*]] = arith.constant 3 : index
 // CHECK: %[[VAL_14:.*]] = tensor.dim %[[VAL_1]], %[[IDX_3]] : tensor<?x?x?x?xf32>
-// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_14]] : index to i64
-// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64>
-// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
+// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_8]], %[[VAL_10]], %[[VAL_12]], %[[VAL_14]] : tensor<4xindex>
+// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor<f32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
 // CHECK: %[[VAL_18:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
 // CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]])
 // CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({
@@ -8,19 +8,17 @@
 // CHECK: %int0 = torch.constant.int 0
 // CHECK: %[[INDEX_0:.*]] = arith.constant 0 : index
 // CHECK: %[[DIM_0:.*]] = tensor.dim %[[VAR_1]], %[[INDEX_0]] : tensor<?x?xi64>
-// CHECK: %[[VAR_3:.*]] = arith.index_cast %[[DIM_0]] : index to i64
 // CHECK: %[[INDEX_1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM_1:.*]] = tensor.dim %1, %[[INDEX_1]] : tensor<?x?xi64>
-// CHECK: %[[VAR_4:.*]] = arith.index_cast %[[DIM_1]] : index to i64
-// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : i64
-// CHECK: %[[CONSTANT_1:.*]] = arith.constant 1 : i64
-// CHECK: %[[FE_:.*]] = tensor.from_elements %[[CONSTANT_0]], %[[CONSTANT_0]] : tensor<2xi64>
-// CHECK: %[[FE_1:.*]] = tensor.from_elements %[[CONSTANT_1]], %[[CONSTANT_1]] : tensor<2xi64>
-// CHECK: %[[FE_2:.*]] = tensor.from_elements %[[VAR_3]], %[[VAR_4]] : tensor<2xi64>
-// CHECK: %[[VAR_5:.*]] = stablehlo.real_dynamic_slice %[[VAR_2]], %[[FE_]], %[[FE_2]], %[[FE_1]] : (tensor<?x?xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<?x?xi64>
-// CHECK: %[[FE_3:.*]] = tensor.from_elements %[[VAR_3]], %[[VAR_4]], %[[CONSTANT_1]] : tensor<3xi64>
-// CHECK: %[[VAR_6:.*]] = stablehlo.dynamic_reshape %1, %[[FE_3]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x1xi64>
-// CHECK: %[[VAR_7:.*]] = stablehlo.dynamic_iota %[[FE_3]], dim = 1 : (tensor<3xi64>) -> tensor<?x?x1xi64>
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : index
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant 1 : index
+// CHECK: %[[FE_:.*]] = tensor.from_elements %[[CONSTANT_0]], %[[CONSTANT_0]] : tensor<2xindex>
+// CHECK: %[[FE_1:.*]] = tensor.from_elements %[[CONSTANT_1]], %[[CONSTANT_1]] : tensor<2xindex>
+// CHECK: %[[FE_2:.*]] = tensor.from_elements %[[DIM_0]], %[[DIM_1]] : tensor<2xindex>
+// CHECK: %[[VAR_5:.*]] = stablehlo.real_dynamic_slice %[[VAR_2]], %[[FE_]], %[[FE_2]], %[[FE_1]] : (tensor<?x?xi64>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xi64>
+// CHECK: %[[FE_3:.*]] = tensor.from_elements %[[DIM_0]], %[[DIM_1]], %[[CONSTANT_1]] : tensor<3xindex>
+// CHECK: %[[VAR_6:.*]] = stablehlo.dynamic_reshape %1, %[[FE_3]] : (tensor<?x?xi64>, tensor<3xindex>) -> tensor<?x?x1xi64>
+// CHECK: %[[VAR_7:.*]] = stablehlo.dynamic_iota %[[FE_3]], dim = 1 : (tensor<3xindex>) -> tensor<?x?x1xi64>
 // CHECK: %[[VAR_8:.*]] = stablehlo.concatenate %[[VAR_6]], %[[VAR_7]], dim = 2 : (tensor<?x?x1xi64>, tensor<?x?x1xi64>) -> tensor<?x?x2xi64>
 // CHECK: %[[VAR_9:.*]] = "stablehlo.scatter"(%[[VAR_0]], %[[VAR_8]], %[[VAR_5]]) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 2>, unique_indices = false}> ({
 // CHECK: ^bb0(%arg3: tensor<i64>, %[[ARG_4:.*]]: tensor<i64>):
@ -398,18 +398,14 @@ func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32
|
|||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x1x?x1x?xf32>
|
||||
// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
|
||||
// CHECK: %[[C2:.*]] = arith.constant 2 : index
|
||||
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x1x?x1x?xf32>
|
||||
// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
|
||||
// CHECK: %[[C3:.*]] = arith.constant 3 : index
|
||||
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor<?x1x?x1x?xf32>
|
||||
// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64
|
||||
// CHECK: %[[C4:.*]] = arith.constant 4 : index
|
||||
// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<?x1x?x1x?xf32>
|
||||
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64>
|
||||
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x?x1x?xf32>
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<4xindex>
|
||||
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xindex>) -> tensor<?x?x1x?xf32>
|
||||
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?x1x?xf32> -> !torch.vtensor<[?,?,1,?],f32>
|
||||
// CHECK: return %[[T6]] : !torch.vtensor<[?,?,1,?],f32>
|
||||
func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> {
|
||||
|
@@ -426,18 +422,14 @@ func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !
 // CHECK: %[[INT:.*]]-2 = torch.constant.int -2
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x1x?x1x?xf32>
-// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x1x?x1x?xf32>
-// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x1x?x1x?xf32>
-// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64
 // CHECK: %[[C4:.*]] = arith.constant 4 : index
 // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<?x1x?x1x?xf32>
-// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
-// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64>
-// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x1x?x?xf32>
+// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<4xindex>
+// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xindex>) -> tensor<?x1x?x?xf32>
 // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x1x?x?xf32> -> !torch.vtensor<[?,1,?,?],f32>
 // CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?],f32>
 func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> {
@@ -453,15 +445,12 @@ func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32>
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<2x1x2x1x2xf32>
-// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<2x1x2x1x2xf32>
-// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
 // CHECK: %[[C4:.*]] = arith.constant 4 : index
 // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32>
-// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64
-// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]] : tensor<3xi64>
-// CHECK: %[[T4:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32>
+// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]] : tensor<3xindex>
+// CHECK: %[[T4:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xindex>) -> tensor<2x2x2xf32>
 // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<2x2x2xf32> -> !torch.vtensor<[2,2,2],f32>
 // CHECK: return %[[T5]] : !torch.vtensor<[2,2,2],f32>
 func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> {
@@ -477,19 +466,15 @@ func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) ->
 // CHECK: %[[INT0:.*]] = torch.constant.int 0
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?x?xf32>
-// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?x?xf32>
-// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x?x?x?xf32>
-// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64
 // CHECK: %[[C3:.*]] = arith.constant 3 : index
 // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor<?x?x?x?xf32>
-// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
-// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
-// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64>
-// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<1x?x?x?x?xf32>
+// CHECK: %[[C1_I64:.*]] = arith.constant 1 : index
+// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C1_I64]], %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<5xindex>
+// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xindex>) -> tensor<1x?x?x?x?xf32>
 // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x?x?x?x?xf32> -> !torch.vtensor<[1,?,?,?,?],f32>
 // CHECK: return %[[T6]] : !torch.vtensor<[1,?,?,?,?],f32>
 func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> {
@@ -506,19 +491,15 @@ func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
 // CHECK: %[[INT1:.*]] = torch.constant.int 1
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?x?xf32>
-// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?x?xf32>
-// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x?x?x?xf32>
-// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64
 // CHECK: %[[C3:.*]] = arith.constant 3 : index
 // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor<?x?x?x?xf32>
-// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
-// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
-// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[C1_I64]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64>
-// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x1x?x?x?xf32>
+// CHECK: %[[C1_I64:.*]] = arith.constant 1 : index
+// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[C1_I64]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<5xindex>
+// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xindex>) -> tensor<?x1x?x?x?xf32>
 // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x1x?x?x?xf32> -> !torch.vtensor<[?,1,?,?,?],f32>
 // CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?,?],f32>
 func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> {
@@ -535,19 +516,15 @@ func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
 // CHECK: %[[INT:.*]]-2 = torch.constant.int -2
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?x?xf32>
-// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?x?xf32>
-// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x?x?x?xf32>
-// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64
 // CHECK: %[[C3:.*]] = arith.constant 3 : index
 // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor<?x?x?x?xf32>
-// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
-// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
-// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[C1_I64]], %[[T4]] : tensor<5xi64>
-// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x?x?x1x?xf32>
+// CHECK: %[[C1_I64:.*]] = arith.constant 1 : index
+// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[C1_I64]], %[[DIM_2]] : tensor<5xindex>
+// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xindex>) -> tensor<?x?x?x1x?xf32>
 // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?x?x1x?xf32> -> !torch.vtensor<[?,?,?,1,?],f32>
 // CHECK: return %[[T6]] : !torch.vtensor<[?,?,?,1,?],f32>
 func.func @torch.aten.unsqueeze$from_end(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> {
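For reference, here is a minimal sketch of the post-change lowering that the updated CHECK lines in the `unsqueeze$dim$1` hunk describe, written out as complete IR. The function name `@forward` and the SSA value names are illustrative only (the tests match values by pattern, not by name): the shape tensor is assembled directly from index-typed `tensor.dim` results plus an index-typed constant `1` for the inserted axis, so no `arith.index_cast` appears.
```
// Sketch: unsqueeze at dim 1 of a tensor<?x?x?x?xf32>, with an
// index-typed shape tensor feeding stablehlo.dynamic_reshape.
func.func @forward(%arg0: tensor<?x?x?x?xf32>) -> tensor<?x1x?x?x?xf32> {
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg0, %c0 : tensor<?x?x?x?xf32>
  %c1 = arith.constant 1 : index
  %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?x?xf32>
  %c2 = arith.constant 2 : index
  %dim_1 = tensor.dim %arg0, %c2 : tensor<?x?x?x?xf32>
  %c3 = arith.constant 3 : index
  %dim_2 = tensor.dim %arg0, %c3 : tensor<?x?x?x?xf32>
  // The unsqueezed axis contributes a constant 1; every element is index-typed.
  %c1_0 = arith.constant 1 : index
  %shape = tensor.from_elements %dim, %c1_0, %dim_0, %dim_1, %dim_2 : tensor<5xindex>
  %0 = stablehlo.dynamic_reshape %arg0, %shape : (tensor<?x?x?x?xf32>, tensor<5xindex>) -> tensor<?x1x?x?x?xf32>
  return %0 : tensor<?x1x?x?x?xf32>
}
```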