Implement lowering of torch.aten.kthvalue (#3360)

Closes
[nod-ai/SHARK-Turbine#620](https://github.com/nod-ai/SHARK-Turbine/issues/620)
pull/3461/merge
ptrifunovic98 2024-06-15 07:48:39 +02:00 committed by GitHub
parent 51902ec2dc
commit 4555629246
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 1022 additions and 22 deletions

View File

@ -326,6 +326,82 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
}];
}
def TMTensor_TopkOp : TMTensor_Op<"topk",
[DeclareOpInterfaceMethods<TMTensorInterface,
["payloadUsesValueFromOperand"]>,
DeclareOpInterfaceMethods<ScalarLoopOpInterface,
["generateScalarImplementation"]>]> {
let summary = "Top-K operator";
let description = [{
A Top-K operation for N-D tensors. Reduces the target dimension from the input
size N down to K elements based on the supplied binary region.
Accepts an N-D tensor input consisting of values and an optioanl N-D tensor
for indices of those values (i32 type). If input indices aren't provided, the
index mapping is inferred based on the k dim. Both input values/indices
tensors and output values/indicies tensors must have the same shape. Top-K is
computed along the target dimension (from dimension()). Returns two output
tensors of values and the indicies of Top-K results. The output dimensions
must match the input save for the dimension that is reduced to K results.
Region accepts lhs=[next N input] and rhs=[exiting K output] and yeilds an
i1. If true, the two values are swapped:
- For Top-K compoarision: >
- For Min-K comparision: <
Note: when the two values are equal, the first occurence is always selected.
}];
let arguments = (ins Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs,
I64Attr:$dimension
);
let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region AnyRegion:$region);
let assemblyFormat = [{
attr-dict
`dimension` `(` $dimension `)`
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
$region (`->` type($results)^)?
}];
let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{
Value values() {
return getInputOperand(0)->get();
}
std::optional<Value> indices() {
if (getNumInputs() < 2) {
return {};
} else {
return getInputOperand(1)->get();
}
}
Value outputValues() {
return getOutputOperand(0)->get();
}
Value outputIndices() {
return getOutputOperand(1)->get();
}
ShapedType getInputType() {
return cast<ShapedType>(values().getType());
}
int64_t getInputRank() {
return getInputType().getRank();
}
// Method to implement for specifying output range for
// DestinationStyleOpInterface
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
std::pair<unsigned, unsigned> outputsIndexAndLength =
getODSOperandIndexAndLength(1);
return std::make_pair<int64_t, int64_t>(
outputsIndexAndLength.first,
outputsIndexAndLength.first + outputsIndexAndLength.second);
}
}];
}
//===----------------------------------------------------------------------===//
// Pure ops
//===----------------------------------------------------------------------===//

View File

@ -12426,6 +12426,34 @@ def Torch_AtenCol2imOp : Torch_Op<"aten.col2im", [
}];
}
def Torch_AtenKthvalueOp : Torch_Op<"aten.kthvalue", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::kthvalue : (Tensor, int, int, bool) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$k,
Torch_IntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchOptionalTensorType:$values,
AnyTorchOptionalTensorType:$indices
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenKthvalueOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 2);
}
void AtenKthvalueOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 2);
}
}];
let hasVerifier = 1;
}
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -254,6 +254,44 @@ static Value createTMTensorScanOp(
return scanOp->getResult(0);
}
static FailureOr<Value> createIntOrFloatCompareOp(PatternRewriter &rewriter,
Location loc,
Type elementType, Value lhs,
Value rhs, bool isDescending,
bool isEqual) {
Value compareOp;
if (auto intType = dyn_cast<mlir::IntegerType>(elementType)) {
// Case for using arith::CmpIOp.
arith::CmpIPredicate g =
isEqual ? arith::CmpIPredicate::sge : arith::CmpIPredicate::sgt;
arith::CmpIPredicate l =
isEqual ? arith::CmpIPredicate::sle : arith::CmpIPredicate::slt;
if (intType.isUnsignedInteger()) {
g = isEqual ? arith::CmpIPredicate::uge : arith::CmpIPredicate::ugt;
l = isEqual ? arith::CmpIPredicate::ule : arith::CmpIPredicate::ult;
}
arith::CmpIPredicate predicate = isDescending ? g : l;
compareOp = rewriter.create<arith::CmpIOp>(loc, predicate, lhs, rhs);
return compareOp;
}
if (isa<mlir::FloatType>(elementType)) {
// Case for using arith::CmpFOp.
arith::CmpFPredicate g =
isEqual ? arith::CmpFPredicate::OGE : arith::CmpFPredicate::OGT;
arith::CmpFPredicate l =
isEqual ? arith::CmpFPredicate::OLE : arith::CmpFPredicate::OLT;
arith::CmpFPredicate predicate = isDescending ? g : l;
compareOp = rewriter.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
return compareOp;
}
return rewriter.notifyMatchFailure(
loc, "Only Integer and Floating element type expected.");
}
// Utility function to create a TMTensor::SortOp.
static FailureOr<SmallVector<Value>>
createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc,
@ -280,34 +318,60 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc,
}
// Step 3. Create comparison op which will be used as the sorting predicate.
Value compareOp;
if (auto intType = dyn_cast<mlir::IntegerType>(elementTypes[0])) {
// Case for using arith::CmpIOp.
arith::CmpIPredicate ge = arith::CmpIPredicate::sge;
arith::CmpIPredicate le = arith::CmpIPredicate::sle;
if (intType.isUnsignedInteger()) {
ge = arith::CmpIPredicate::uge;
le = arith::CmpIPredicate::ule;
}
arith::CmpIPredicate predicate = isDescending ? ge : le;
compareOp = rewriter.create<arith::CmpIOp>(
loc, predicate, block->getArgument(0), block->getArgument(1));
} else if (isa<mlir::FloatType>(elementTypes[0])) {
// Case for using arith::CmpFOp.
arith::CmpFPredicate predicate =
isDescending ? arith::CmpFPredicate::OGE : arith::CmpFPredicate::OLE;
compareOp = rewriter.create<arith::CmpFOp>(
loc, predicate, block->getArgument(0), block->getArgument(1));
} else {
auto compareOpRetVal = createIntOrFloatCompareOp(
rewriter, loc, elementTypes[0], block->getArgument(0),
block->getArgument(1), isDescending, true);
if (failed(compareOpRetVal))
return rewriter.notifyMatchFailure(
sortOpLoc, "Only Integer and Floating element type expected.");
}
loc, "Only Integer and Floating element type expected.");
// Step 4. Create yield op for yielding the sorting predicate.
rewriter.create<TMTensor::YieldOp>(loc, compareOp);
rewriter.create<TMTensor::YieldOp>(loc, compareOpRetVal.value());
return SmallVector<Value>(sortOp.getResults());
}
static FailureOr<SmallVector<Value>> createTMTensorTopkOp(
PatternRewriter &rewriter, Location topkOpLoc, llvm::ArrayRef<Value> inputs,
llvm::ArrayRef<Value> outputs, llvm::ArrayRef<Type> elementTypes,
int64_t dimension, bool isMinK) {
// Generate output types.
SmallVector<Type> topkResultTypes;
for (Value val : outputs) {
topkResultTypes.push_back(val.getType());
}
// Create empty TopkOp, add body later.
auto topkOp = rewriter.create<TMTensor::TopkOp>(
topkOpLoc, topkResultTypes, inputs, outputs,
rewriter.getI64IntegerAttr(dimension));
Region *body = &topkOp.getRegion();
Block *block = rewriter.createBlock(body);
Location loc = body->getLoc();
// Add arguments for each passed body region element type.
for (Type elementType : elementTypes) {
block->addArgument({elementType}, {loc});
}
// Generate compare operator. If minK is chosen, isDescending should be false.
// Is equal should be false, because we do not want equality to cause element
// swap.
auto compareOpRetVal = createIntOrFloatCompareOp(
rewriter, loc, elementTypes[0], block->getArgument(0),
block->getArgument(1), /*isDescending=*/!isMinK, /*isEqual=*/false);
// Check if correct element types are passed.
if (failed(compareOpRetVal))
return rewriter.notifyMatchFailure(
loc, "Only Integer and Floating element type expected.");
// Yield the comparison result.
rewriter.create<TMTensor::YieldOp>(loc, compareOpRetVal.value());
return SmallVector<Value>(topkOp.getResults());
}
namespace {
class ConvertAtenScatterSrcOp : public OpConversionPattern<AtenScatterSrcOp> {
public:
@ -1570,6 +1634,456 @@ public:
};
} // namespace
namespace {
class ConvertAtenKthvalueOp : public OpConversionPattern<AtenKthvalueOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenKthvalueOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const llvm::StringRef opName = op->getName().getStringRef();
Location loc = op.getLoc();
auto typec = this->getTypeConverter();
Value input = adaptor.getSelf();
auto inputType = cast<RankedTensorType>(input.getType());
unsigned inputRank = inputType.getRank();
Type inputElementType = inputType.getElementType();
auto valResultType =
cast<RankedTensorType>(typec->convertType(op.getResult(0).getType()));
auto valResultElementType =
getElementTypeOrSelf(typec->convertType(valResultType));
auto idxResultType =
cast<RankedTensorType>(typec->convertType(op.getResult(1).getType()));
auto idxResultElementType =
getElementTypeOrSelf(typec->convertType(idxResultType));
// get keepdim and check it is bool
bool keepDim = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
return rewriter.notifyMatchFailure(
op, opName + " requires boolean value for keepdim");
// get dim, check it is constant int
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "unimplemented: only constant dim value is supported");
// turn dim into positive if negative, and check it is in the valid range
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank)) {
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
}
// get k, check it is a constant int
int64_t k;
if (!matchPattern(op.getK(), m_TorchConstantInt(&k)))
return rewriter.notifyMatchFailure(
op, "unimplemented: only constant k value is supported");
// check if element type is float, int, or unsigned
bool isUnsigned = false;
if (!isa<mlir::FloatType>(inputElementType)) {
if (!isa<mlir::IntegerType>(inputElementType)) {
return rewriter.notifyMatchFailure(
op, opName + " to linalg.* requires Float or Integer "
"input element type");
}
auto integerTy = dyn_cast<mlir::IntegerType>(
cast<BaseTensorType>(op.getSelf().getType()).getDtype());
isUnsigned = integerTy.isUnsigned();
}
// Create the values to fill initial output tensors for
// topk op and linalg generic op for finding max value.
Value fillValLinalgFindMax;
Value fillValTopK;
if (isa<mlir::FloatType>(inputElementType)) {
// max float for topk tensor
fillValTopK = rewriter.create<arith::ConstantOp>(
loc,
rewriter.getFloatAttr(
inputElementType,
APFloat::getInf(
cast<mlir::FloatType>(inputElementType).getFloatSemantics(),
/*Negative=*/false)));
// min float for linalg generic op tensor
fillValLinalgFindMax = rewriter.create<arith::ConstantOp>(
loc,
rewriter.getFloatAttr(
inputElementType,
APFloat::getInf(
cast<mlir::FloatType>(inputElementType).getFloatSemantics(),
/*Negative=*/true)));
} else if (!isUnsigned) {
auto width = cast<mlir::IntegerType>(inputElementType).getWidth();
// max signed int for topk op tensor
auto init = APSInt::getSignedMaxValue(width);
fillValTopK = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(inputElementType, init));
// min signed int for linalg generic op tensor
init = APSInt::getSignedMinValue(width);
fillValLinalgFindMax = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(inputElementType, init));
} else if (isUnsigned) {
auto width = cast<mlir::IntegerType>(inputElementType).getWidth();
// max unsigned int for topk op tensor
auto init = APInt::getMaxValue(width);
fillValTopK = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(inputElementType, init));
// min unsigned int for linalg generic op tensor
init = APInt::getMinValue(width);
fillValLinalgFindMax = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(inputElementType, init));
}
auto i32Type = rewriter.getI32Type();
// ======== BEGIN: Topk op section ========
// Based on iree docs:
// https://iree.dev/reference/mlir-dialects/LinalgExt/#iree_linalg_extsort-linalgextsortop
// Create the output shape of topk op.
// For every dimension, topkShape[dimension] = inputShape[dimension],
// except topkShape[dim] = k.
SmallVector<Value> topkShape;
for (unsigned i = 0; i < inputRank; i++) {
auto currentDimSize = rewriter.create<tensor::DimOp>(loc, input, i);
topkShape.push_back(currentDimSize);
}
auto dimSize = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), k));
topkShape[dim] = dimSize;
// Fill the initial topk op output tensor.
Value topkOutputVal = createInitTensor(rewriter, loc, topkShape,
valResultElementType, fillValTopK);
// Create the initial value to fill the topk output indices tensor.
// It is equal to the max 32-bit signless integer.
auto signlessType = mlir::IntegerType::get(op.getContext(), 32,
mlir::IntegerType::Signless);
auto initIdx = getNumericLimit(rewriter, signlessType, /*getMin=*/false);
auto fillValTopkIdx = rewriter.create<arith::ConstantOp>(loc, initIdx);
// Fill the initial topk op output indices tensor.
Value topkOutputIdx =
createInitTensor(rewriter, loc, topkShape, i32Type, fillValTopkIdx);
// Input arguments for topk op contain only the input tensor.
// Input indices will be inferred based on input shape.
// (See docs link above).
SmallVector<Value> topkInputs;
topkInputs.push_back(input);
// Outputs contain both the values and the indices tensors.
SmallVector<Value> topkOutputs;
topkOutputs.push_back(topkOutputVal);
topkOutputs.push_back(topkOutputIdx);
// Element types of the arguments passed to the topk op region.
// The region accepts the next value N, and the current output
// candidate K (see docs link above).
// Both N and K are values from the input tensors, thus the
// element types are the same and are taken from inputType.
SmallVector<Type> topkElementTypes;
topkElementTypes.push_back(inputType.getElementType());
topkElementTypes.push_back(inputType.getElementType());
// Create the TMTensor TopkOp.
FailureOr<SmallVector<Value>> topkOp;
{
OpBuilder::InsertionGuard guard(rewriter);
topkOp = createTMTensorTopkOp(rewriter, loc, topkInputs, topkOutputs,
topkElementTypes, dim, /*isMinK=*/true);
}
// Topk op creation fails with invalid element types.
if (failed(topkOp))
return rewriter.notifyMatchFailure(
loc, "Only Integer and Floating element type expected.");
auto topkOpVal = topkOp.value();
// ======== END: Topk op section ========
// ======== BEGIN: Linalg generic to find max in topk result ========
// Create result shape as both a vector of Value and of int64_t types.
// We assume that keepdim is false, and fix the result later if true.
// Result shape is equal to inputShape, with dim dimension removed.
SmallVector<Value> resultShape;
SmallVector<int64_t> resultShapeInt;
for (int64_t i = 0; i < inputType.getRank(); i++) {
if (dim != i) {
auto currentDimSize = rewriter.create<tensor::DimOp>(loc, input, i);
resultShape.push_back(currentDimSize);
resultShapeInt.push_back(inputType.getShape()[i]);
}
}
// Fill the initial output tensor for linalg op for finding max value.
Value findMaxOutputVal = createInitTensor(
rewriter, loc, resultShape, inputElementType, fillValLinalgFindMax);
// Fill the initial output indices tensor for linalg op for finding max
// value with zeros.
Value findMaxOutputIdx =
createZeroInitTensor(rewriter, loc, resultShape, idxResultElementType);
// Reduce along dim.
SmallVector<utils::IteratorType> findMaxIteratorTypes(
inputType.getRank(), utils::IteratorType::parallel);
findMaxIteratorTypes[dim] = utils::IteratorType::reduction;
SmallVector<AffineExpr> findMaxMapExprs;
SmallVector<AffineExpr> findMaxMapResultExprs;
for (auto size :
llvm::enumerate(makeShapeTorchCompatible(inputType.getShape()))) {
findMaxMapExprs.push_back(rewriter.getAffineDimExpr(size.index()));
if (unsigned(dim) != size.index())
findMaxMapResultExprs.push_back(
rewriter.getAffineDimExpr(size.index()));
}
auto findMaxMaps = AffineMap::inferFromExprList(
{findMaxMapExprs, findMaxMapResultExprs, findMaxMapResultExprs},
rewriter.getContext());
// Create linalg op for finding the max value in the extracted topk values.
auto findMaxLinalg = rewriter.create<linalg::GenericOp>(
loc,
ArrayRef<Type>(
{findMaxOutputVal.getType(), findMaxOutputIdx.getType()}),
topkOpVal.front(), ValueRange({findMaxOutputVal, findMaxOutputIdx}),
findMaxMaps, findMaxIteratorTypes,
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange blockArgs) {
// Linalg generic body is the same as the decomposition for
// AtenMinDim: lib/Conversion/TorchToLinalg/Reduction.cpp
Value newValue = blockArgs[0];
Value oldValue = blockArgs[1];
Value oldIndex = blockArgs[2];
Value newIndex = rewriter.create<arith::IndexCastOp>(
nestedLoc, oldIndex.getType(),
rewriter.create<linalg::IndexOp>(nestedLoc, dim));
Value resultVal, predicate;
if (isa<mlir::FloatType>(inputElementType)) {
resultVal = rewriter.create<arith::MaximumFOp>(nestedLoc, newValue,
oldValue);
predicate = rewriter.create<arith::CmpFOp>(
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
} else {
arith::CmpIPredicate predType;
predType = isUnsigned ? arith::CmpIPredicate::ugt
: arith::CmpIPredicate::sgt;
if (isUnsigned) {
resultVal = rewriter.create<arith::MaxUIOp>(nestedLoc, newValue,
oldValue);
} else {
resultVal = rewriter.create<arith::MaxSIOp>(nestedLoc, newValue,
oldValue);
}
predicate = rewriter.create<arith::CmpIOp>(nestedLoc, predType,
newValue, oldValue);
}
auto resultIndex = rewriter.create<arith::SelectOp>(
nestedLoc, predicate, newIndex, oldIndex);
nestedBuilder.create<linalg::YieldOp>(
nestedLoc, ValueRange{resultVal, resultIndex});
});
auto findMaxVal = findMaxLinalg.getResult(0);
auto findMaxIdx = findMaxLinalg.getResult(1);
auto findMaxIdxType = cast<RankedTensorType>(findMaxIdx.getType());
// ======== END: Linalg generic to find max in topk result ========
// ======== BEGIN: Linalg generic for index extraction ========
// The linalg op for finding max returned idx of max elements in the
// tensor returned by the topk op. We need the idx of those elements
// in the original input. The topk op returned the idx of the top k
// extracted elements in the original input. Using the linalg idx
// results to index the topk idx results returns the idx of kth
// max value in the original input. Example:
// input = [1, 7, 3, 6, 2, 8, 9, 5], k = 4
// topk_val = [1, 3, 2, 5], topk_idx = [0, 2, 4, 7]
// linalg_max_val = [5], linalg_max_idx = [3] (5 is at idx 3 in topk_val)
// index the topk_idx using linalg_max_idx -> topk_idx[3] = 7
// kth_val = [5], kth_idx = [7]
// Create a tensor for the resulting idx.
Value filledTensorExtractedIdx = createZeroInitTensor(
rewriter, loc, getTensorSizes(rewriter, loc, findMaxIdx), i32Type);
// We iterate through the idx tensor returned by the linalg generic op for
// finding max.
SmallVector<utils::IteratorType> extractedIdxIteratorTypes(
findMaxIdxType.getRank(), utils::IteratorType::parallel);
SmallVector<AffineExpr> extractedIdxMapExprs;
for (auto size :
llvm::enumerate(makeShapeTorchCompatible(findMaxIdxType.getShape()))) {
extractedIdxMapExprs.push_back(rewriter.getAffineDimExpr(size.index()));
}
auto extractedIdxMaps = AffineMap::inferFromExprList(
{extractedIdxMapExprs, extractedIdxMapExprs}, rewriter.getContext());
// Linalg generic op for indexing the topk output idx tensor using
// the idx tensor returned by the linalg generic op for finding max.
// Only the idx tensor from the linalg generic op is sent as input.
auto extractedIdxLinalg = rewriter.create<linalg::GenericOp>(
loc, ArrayRef<Type>({filledTensorExtractedIdx.getType()}), findMaxIdx,
filledTensorExtractedIdx, extractedIdxMaps, extractedIdxIteratorTypes,
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange blockArgs) {
// Get the current input idx.
Value index = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), blockArgs[0]);
// Create idx to index the topk idx tensor.
// Index the dim dimension using the current input idx.
SmallVector<Value> indexTarget;
for (unsigned i = 0; i < dim; i++)
indexTarget.push_back(rewriter.create<linalg::IndexOp>(loc, i));
indexTarget.push_back(index);
for (unsigned i = dim; i < findMaxIdxType.getRank(); i++)
indexTarget.push_back(rewriter.create<linalg::IndexOp>(loc, i));
// Extract the element from the topk idx tensor.
Value extractedElement = rewriter.create<tensor::ExtractOp>(
loc, topkOpVal.back(), indexTarget);
rewriter.create<linalg::YieldOp>(loc, extractedElement);
});
auto extractedIdx = extractedIdxLinalg.getResult(0);
auto extractedIdxType = cast<RankedTensorType>(extractedIdx.getType());
// ======== END: Linalg generic for index extraction ========
// ======== BEGIN: Linalg generic for topk idx cast ========
// Casts from i32 to idx result type of the Kthvalue op.
// Create the initial tensor for the cast result.
Value filledTensorCastedIdx = createZeroInitTensor(
rewriter, loc, getTensorSizes(rewriter, loc, extractedIdx),
idxResultElementType);
SmallVector<utils::IteratorType> castedIdxIteratorTypes(
extractedIdxType.getRank(), utils::IteratorType::parallel);
SmallVector<AffineExpr> castedIdxMapExprs;
for (auto size : llvm::enumerate(
makeShapeTorchCompatible(extractedIdxType.getShape()))) {
castedIdxMapExprs.push_back(rewriter.getAffineDimExpr(size.index()));
}
auto castedIdxMaps = AffineMap::inferFromExprList(
{castedIdxMapExprs, castedIdxMapExprs}, rewriter.getContext());
// Linalg generic op for casting topk idx output tensor elements from i32 to
// result idx tensor element type.
auto castedIdxLinalg = rewriter.create<linalg::GenericOp>(
loc, ArrayRef<Type>({filledTensorCastedIdx.getType()}), extractedIdx,
filledTensorCastedIdx, castedIdxMaps, castedIdxIteratorTypes,
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange blockArgs) {
Value oldIdx = blockArgs[0];
// Cast from i32 to index.
Value oldIdxToIndexType = rewriter.create<arith::IndexCastOp>(
nestedLoc, rewriter.getIndexType(), oldIdx);
// Cast from index to result idx element type.
Value resultIdx = rewriter.create<arith::IndexCastOp>(
nestedLoc, idxResultElementType, oldIdxToIndexType);
nestedBuilder.create<linalg::YieldOp>(nestedLoc, resultIdx);
});
auto castedIdx = castedIdxLinalg.getResult(0);
// ======== END: Linalg generic for topk idx cast ========
// Create output value type ("squeezed" since we assume keepdim=False).
auto topkValResultType =
cast<RankedTensorType>(topkOpVal.front().getType());
auto squeezedValType = topkValResultType.cloneWith(
resultShapeInt,
cast<RankedTensorType>(findMaxVal.getType()).getElementType());
// Create output idx type ("squeezed" since we assume keepdim=False).
auto castedIdxType = cast<RankedTensorType>(castedIdx.getType());
auto squeezedIdxType = castedIdxType.cloneWith(
resultShapeInt, findMaxIdxType.getElementType());
if (!keepDim) {
// If keepdim=false, cast the the outputs to appropriate type and return.
Value retVal =
rewriter.create<tensor::CastOp>(loc, squeezedValType, findMaxVal);
Value retIdx =
rewriter.create<tensor::CastOp>(loc, squeezedIdxType, castedIdx);
llvm::SmallVector<Value> res{retVal, retIdx};
rewriter.replaceOp(op, res);
return success();
}
// If keepdim is false, unsqueeze.
// Unsqueezing implementation taken from AteMinMaxDimOp lowering:
// lib/Conversion/TorchToLinalg/Reduction.cpp
llvm::SmallVector<int64_t> valShape(valResultType.getShape());
llvm::SmallVector<int64_t> idxShape(idxResultType.getShape());
for (int i = dim, s = valShape.size() - 1; i < s; ++i) {
valShape[i] = valShape[i + 1];
idxShape[i] = idxShape[i + 1];
}
valShape.resize(valShape.size() - 1);
idxShape.resize(idxShape.size() - 1);
Value retVal = rewriter.create<tensor::CastOp>(
loc, squeezedValType.clone(valShape), findMaxLinalg.getResult(0));
Value retIdx = rewriter.create<tensor::CastOp>(
loc, squeezedIdxType.clone(idxShape), castedIdx);
SmallVector<ReassociationIndices> reassociation(valShape.size());
if (reassociation.size() > 0) {
for (int i = 0; i < dim; ++i)
reassociation[i].push_back(i);
reassociation[std::max<int64_t>(0, dim - 1)].push_back(dim);
for (int i = dim, s = reassociation.size(); i < s; ++i)
reassociation[i].push_back(i + 1);
}
valShape.push_back(0);
idxShape.push_back(0);
for (int i = dim, s = valShape.size() - 1; i < s; ++i) {
valShape[i + 1] = valShape[i];
idxShape[i + 1] = idxShape[i];
}
valShape[dim] = 1;
idxShape[dim] = 1;
Value unsqueezeVal = rewriter.create<tensor::ExpandShapeOp>(
loc, valResultType, retVal, reassociation);
Value unsqueezeIdx = rewriter.create<tensor::ExpandShapeOp>(
loc, idxResultType, retIdx, reassociation);
// Return unsqueezed.
llvm::SmallVector<Value> unsqueezes = {unsqueezeVal, unsqueezeIdx};
rewriter.replaceOp(op, unsqueezes);
return success();
}
};
} // namespace
// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
@ -1619,6 +2133,8 @@ public:
target.addIllegalOp<AtenScatterSrcOp>();
patterns.add<ConvertAtenScatterSrcOp>(typeConverter, context);
target.addIllegalOp<AtenKthvalueOp>();
patterns.add<ConvertAtenKthvalueOp>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))

View File

@ -910,6 +910,213 @@ bool SortOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
return true;
}
//===----------------------------------------------------------------------===//
// TopkOp
//===----------------------------------------------------------------------===//
LogicalResult TopkOp::verify() {
Operation *op = getOperation();
if (getNumInputs() != 1 && getNumInputs() != 2) {
return op->emitOpError("expected one or two input operands");
}
if (getNumOutputs() != 2) {
return op->emitOpError("expected two output operands");
}
// First check added to eliminate comparison of different int types
if (getInputRank() < 0 ||
(getDimension() >= static_cast<uint64_t>(getInputRank()))) {
return op->emitOpError("dimension exceeds rank");
}
// Ensure input/output element types match
auto inputValuesType = cast<ShapedType>(values().getType());
auto outputValuesType = cast<ShapedType>(outputValues().getType());
if (inputValuesType.getElementType() != outputValuesType.getElementType()) {
return op->emitOpError("expected input/output value types to be identical");
}
// Indices must be int if provided
auto outputIndicesType = cast<ShapedType>(outputIndices().getType());
if (auto inputIndices = indices()) {
auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
if (!inputIndicesType.getElementType().isInteger(32) ||
!outputIndicesType.getElementType().isInteger(32)) {
return op->emitOpError("expected input/output indices types to be int32");
}
}
// Ranks must match
if (inputValuesType.getRank() != outputValuesType.getRank()) {
return op->emitOpError("expected input/output to have the same rank");
}
if (auto inputIndices = indices()) {
auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
if (inputIndicesType.getRank() != outputIndicesType.getRank()) {
return op->emitOpError("expected input/output to have the same rank");
}
}
// Input indicies and values must have the same shape.
if (auto inputIndices = indices()) {
auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
if (failed(verifyCompatibleShape(inputValuesType, inputIndicesType)))
return op->emitOpError("input indices/values shape must match");
}
// Output indicies and values must have the same shape.
if (failed(verifyCompatibleShape(outputValuesType, outputIndicesType)))
return op->emitOpError("output indices/values shape must match");
// Input shape must match the output shape except for the dimension()
uint64_t dim = getDimension();
if (!llvm::all_of(llvm::enumerate(llvm::zip(inputValuesType.getShape(),
outputValuesType.getShape())),
[dim](auto e) {
if (e.index() == dim) {
return true;
}
std::tuple<int64_t, int64_t> s = e.value();
return succeeded(verifyCompatibleShape(std::get<0>(s),
std::get<1>(s)));
})) {
return op->emitOpError("incompatible input/output shapes");
}
// Check region compatibility
Block &block = getRegion().front();
if (block.getNumArguments() != 2) {
return op->emitOpError("region block should have 2 arguments");
}
if (block.getArgument(0).getType() != inputValuesType.getElementType() ||
block.getArgument(1).getType() != inputValuesType.getElementType()) {
return op->emitOpError("region block types must match input");
}
auto terminatorOp = llvm::dyn_cast<YieldOp>(block.getTerminator());
if (!terminatorOp || !terminatorOp.getOperand(0).getType().isInteger(1)) {
return op->emitOpError("region block must end with a linalg_ext.yield i1!");
}
return success();
}
SmallVector<utils::IteratorType> TopkOp::getLoopIteratorTypes() {
SmallVector<utils::IteratorType> iteratorTypes(getInputRank(),
utils::IteratorType::parallel);
iteratorTypes[getDimension()] = utils::IteratorType::reduction;
return iteratorTypes;
}
SmallVector<Range> TopkOp::getIterationDomain(OpBuilder &builder) {
int64_t operandRank = getInputRank();
SmallVector<Range> loopBounds(operandRank);
Location loc = getLoc();
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
Value source = values();
for (auto dim : llvm::enumerate(getInputType().getShape())) {
loopBounds[dim.index()].offset = zero;
loopBounds[dim.index()].size =
getDimValue(builder, loc, source, dim.index());
loopBounds[dim.index()].stride = one;
}
return loopBounds;
}
LogicalResult TopkOp::generateScalarImplementation(OpBuilder &b, Location loc,
ValueRange ivs) {
uint64_t kDim = getDimension();
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
Value initialValue = b.create<memref::LoadOp>(loc, values(), ivs);
// If the indices tensor is not provided, the value index is derived from the
// loop induction variables.
Value initialIndex;
if (indices()) {
initialIndex = b.create<memref::LoadOp>(loc, *indices(), ivs);
} else {
Value rawInitialIndex = ivs[kDim];
initialIndex =
b.create<arith::IndexCastOp>(loc, b.getI32Type(), rawInitialIndex);
}
// Compute K (ub) from the selected dim of the output
Value ub = b.create<memref::DimOp>(loc, outputValues(), getDimension());
// Inner K loop functions:
// Load current K value and index
// Compare N/K using inserted block compare
// Check if N == K using strict weak ordering, select which index came first
// Select new K value from N/K comparison
// Select new K index from N/K comparison or which index came first
// Store new k value and index
// Yield loop carry values after K selection
Value kValue, kIndex;
auto scfFor = b.create<scf::ForOp>(
loc, zero, ub, one, ValueRange{initialValue, initialIndex},
[&](OpBuilder &b, Location loc, Value iv, ValueRange loopCarryValues) {
SmallVector<Value> indices(ivs);
indices[kDim] = iv;
kValue = b.create<memref::LoadOp>(loc, outputValues(), indices);
kIndex = b.create<memref::LoadOp>(loc, outputIndices(), indices);
});
SmallVector<Value> indices(ivs);
indices[kDim] = scfFor.getInductionVar();
auto loopCarryValues = scfFor.getRegionIterArgs();
// Retrieve region as black box comparision function f(x,y). Plug into op.
auto &srcBlock = getRegion().front();
IRMapping bvmF; // f(x,y)
IRMapping bvmR; // f(y,x)
{
// Save previous insertion point. Continue within loop body.
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToEnd(&scfFor.getRegion().front());
SmallVector<Value> forwardValues{loopCarryValues[0], kValue};
SmallVector<Value> reverseValues{kValue, loopCarryValues[0]};
for (auto it : llvm::zip(srcBlock.getArguments(), forwardValues)) {
bvmF.map(std::get<0>(it), std::get<1>(it));
}
for (auto it : llvm::zip(srcBlock.getArguments(), reverseValues)) {
bvmR.map(std::get<0>(it), std::get<1>(it));
}
for (auto &blockOp : srcBlock.without_terminator()) {
b.clone(blockOp, bvmF);
b.clone(blockOp, bvmR);
}
Value forwardCmpRes = bvmF.lookup(srcBlock.getTerminator()->getOperand(0));
Value reverseCmpRes = bvmR.lookup(srcBlock.getTerminator()->getOperand(0));
// Check value equality using strictly weak ordering from the region:
// f(x,y) --> forwardCmpRes
// f(y,x) --> reverseCmpRes
// if forwardCmpRes == reverseCmpRes then select which came first
Value cmpValuesEqual = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, forwardCmpRes, reverseCmpRes);
Value cmpFirstIndex = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, loopCarryValues[1], kIndex);
Value combinedCmpEqRes =
b.create<arith::AndIOp>(loc, cmpValuesEqual, cmpFirstIndex);
// True if N > K or N came before K
Value indexCmpRes =
b.create<arith::OrIOp>(loc, forwardCmpRes, combinedCmpEqRes);
// Select results for K based on comparisons
Value resultKValue = b.create<arith::SelectOp>(loc, forwardCmpRes,
loopCarryValues[0], kValue);
Value resultKIndex =
b.create<arith::SelectOp>(loc, indexCmpRes, loopCarryValues[1], kIndex);
b.create<memref::StoreOp>(loc, resultKValue, outputValues(), indices);
b.create<memref::StoreOp>(loc, resultKIndex, outputIndices(), indices);
// Select loop carry, opposite of K results
Value resultCarryValue = b.create<arith::SelectOp>(
loc, forwardCmpRes, kValue, loopCarryValues[0]);
Value resultCarryIndex =
b.create<arith::SelectOp>(loc, indexCmpRes, kIndex, loopCarryValues[1]);
b.create<scf::YieldOp>(loc, ValueRange{resultCarryValue, resultCarryIndex});
}
return success();
}
bool TopkOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
// Set to true so that output operands are always initialized.
return true;
}
#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
void OP_NAME::getEffects( \
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
@ -924,6 +1131,7 @@ DEFINE_OP_GET_EFFECTS(AttentionOp)
DEFINE_OP_GET_EFFECTS(ScanOp)
DEFINE_OP_GET_EFFECTS(ScatterOp)
DEFINE_OP_GET_EFFECTS(SortOp)
DEFINE_OP_GET_EFFECTS(TopkOp)
namespace {
/// This is derived from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp without any

View File

@ -4877,6 +4877,42 @@ LogicalResult AtenLinalgCrossOp::verify() {
return success();
}
LogicalResult AtenKthvalueOp::verify() {
auto selfType = cast<BaseTensorType>(getSelf().getType());
if (!selfType.hasDtype() || !selfType.hasSizes())
return success();
Type selfDtype = selfType.getDtype();
if (selfDtype.isSignlessInteger(1))
return emitOpError("input tensors must not have bool dtype");
int64_t dim;
if (!matchPattern(getDim(), m_TorchConstantInt(&dim)))
return success();
ArrayRef<int64_t> selfShape = selfType.getSizes();
int64_t selfRank = selfShape.size();
dim = toPositiveDim(dim, selfRank);
if (!isValidDim(dim, selfRank))
return emitOpError("dim expected to be in range of [")
<< -selfRank << ", " << selfRank - 1 << "], but got " << dim;
// convert k to an integer type
int64_t k;
if (!matchPattern(getK(), m_TorchConstantInt(&k)))
return success();
// check if k is in the correct range
if (selfShape[dim] != kUnknownSize && (k < 1 || k > selfShape[dim]))
return emitOpError("k expected to be in range of [")
<< 1 << ", " << selfShape[dim] << "], but got " << k;
return success();
}
//===----------------------------------------------------------------------===//
// DtypeCalculateYieldDtypesOp
//===----------------------------------------------------------------------===//

View File

@ -6962,6 +6962,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.kthvalue\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = torch.derefine %arg2 : !torch.int to !torch.optional<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg3) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %2 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._log_softmax_backward_data\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -10897,6 +10903,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.kthvalue\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.tuple<int, int> {\n"
" %int4 = torch.constant.int 4\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" return %arg3 : !torch.int\n"
" }\n"

View File

@ -1618,6 +1618,7 @@ public:
auto idxTy = rewriter.getType<Torch::ValueTensorType>(
reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true));
llvm::SmallVector<Type, 2> types{reductionTy, idxTy};
reduction = rewriter
.create<Torch::AtenMinDimOp>(loc, types, reduction,
dimValue, op.getKeepdim())

View File

@ -2274,6 +2274,11 @@ ONNX_XFAIL_SET = {
"AtenIntTensorCharDtypeModule_basic",
"AtenItemFpOpModule_basic",
"AtenItemIntOpModule_basic",
"AtenKthvalueModule_basic",
"AtenKthvalueKeepDimModule_basic",
"AtenKthvalueDynamicDimsModule_basic",
"AtenKthvalueFloat64Module_basic",
"AtenKthvalueFloat64DynamicDimsModule_basic",
"AtenLinalgCrossDynamic_basic",
"AtenMatmulQMixedSigni8Transpose_basic",
"AtenMatmulQMixedSigni8_basic",

View File

@ -468,6 +468,14 @@ def atenlinalg_cross〡shape(self: List[int], other: List[int], dim: int = -1
assert (self[i] == other[i]) or self[i] == 1 or other[i] == 1, f"the size of first tensor ({self[i]}) must match the size of second tensor ({other[i]}) at dimension {i}"
return upstream_shape_functions.broadcast(self, other)
@check_shape_function([
Invocation(TensorOfShape(2, 4, 3, device="cpu"), k=2, dim=1, keepdim=True), # keep dim,
Invocation(TensorOfShape(2, 4, 3, device="cpu"), k=2, dim=1, keepdim=False), # don't keep dim
])
def atenkthvalue〡shape(self: List[int], k: int, dim: int = -1, keepdim: bool = False) -> Tuple[List[int], List[int]]:
new_shape = upstream_shape_functions.argmax(self, dim, keepdim)
return (new_shape, new_shape)
def aten_log_softmax_backward_data〡shape(grad_output: List[int], output: List[int], dim: int, input_dtype: int) -> List[int]:
return upstream_shape_functions.unary(grad_output)
@ -2705,6 +2713,13 @@ def atenlinalg_cross〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dty
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function([
Invocation(TensorOfShape(2, 4, 3, dtype=torch.int32, device="cpu"), k=2, dim=-1, keepdim=False)
])
def atenkthvalue〡dtype(self_rank_dtype: Tuple[int, int], k: int, dim: int = -1, keepdim: bool = False) -> Tuple[int, int]:
_, self_dtype = self_rank_dtype
return (self_dtype, torch.int64)
@check_dtype_function(
_check_two_tensor_op(dim=0, input_dtype=torch.float32) +
_check_two_tensor_op(dim=0, input_dtype=torch.float64))

View File

@ -912,6 +912,10 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
)
emit("aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)", has_verifier=True)
emit("aten::col2im : (Tensor, int[], int[], int[], int[], int[]) -> (Tensor)")
emit(
"aten::kthvalue : (Tensor, int, int, bool) -> (Tensor, Tensor)",
has_verifier=True,
)
# Functionalization ops
emit("aten::alias_copy : (Tensor) -> (Tensor)")

View File

@ -5547,3 +5547,102 @@ class CloneModule(torch.nn.Module):
@register_test_case(module_factory=lambda: CloneModule())
def CloneModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 5))
# ==============================================================================
class AtenKthvalueModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([None, ([2, 6, 3], torch.int32, True)])
def forward(self, x):
return torch.ops.aten.kthvalue(x, k=4, dim=1, keepdim=False)
@register_test_case(module_factory=lambda: AtenKthvalueModule())
def AtenKthvalueModule_basic(module, tu: TestUtils):
module.forward(torch.randperm(2 * 6 * 3, dtype=torch.int32).reshape(2, 6, 3))
# ==============================================================================
class AtenKthvalueKeepDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([None, ([2, 6, 3], torch.int32, True)])
def forward(self, x):
return torch.ops.aten.kthvalue(x, k=4, dim=1, keepdim=True)
@register_test_case(module_factory=lambda: AtenKthvalueKeepDimModule())
def AtenKthvalueKeepDimModule_basic(module, tu: TestUtils):
module.forward(torch.randperm(2 * 6 * 3, dtype=torch.int32).reshape(2, 6, 3))
# ==============================================================================
class AtenKthvalueDynamicDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([None, ([-1, -1, -1, -1], torch.int32, True)])
def forward(self, x):
return torch.ops.aten.kthvalue(x, k=6, dim=2, keepdim=True)
@register_test_case(module_factory=lambda: AtenKthvalueDynamicDimsModule())
def AtenKthvalueDynamicDimsModule_basic(module, tu: TestUtils):
module.forward(torch.randperm(4 * 2 * 8 * 3, dtype=torch.int32).reshape(4, 2, 8, 3))
# ==============================================================================
class AtenKthvalueFloat64Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([None, ([4, 2, 8, 3], torch.float64, True)])
def forward(self, x):
return torch.ops.aten.kthvalue(x, k=3, dim=0, keepdim=True)
@register_test_case(module_factory=lambda: AtenKthvalueFloat64Module())
def AtenKthvalueFloat64Module_basic(module, tu: TestUtils):
module.forward(
torch.randperm(4 * 2 * 8 * 3, dtype=torch.float64).reshape(4, 2, 8, 3)
)
# ==============================================================================
class AtenKthvalueFloat64DynamicDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([None, ([-1, -1, -1, -1], torch.float64, True)])
def forward(self, x):
return torch.ops.aten.kthvalue(x, k=3, dim=3, keepdim=True)
@register_test_case(module_factory=lambda: AtenKthvalueFloat64DynamicDimsModule())
def AtenKthvalueFloat64DynamicDimsModule_basic(module, tu: TestUtils):
module.forward(
torch.randperm(4 * 2 * 8 * 3, dtype=torch.float64).reshape(4, 2, 8, 3)
)