From 4555629246ab56a5588dbd1f0676830484c65a40 Mon Sep 17 00:00:00 2001 From: ptrifunovic98 <156185835+ptrifunovic98@users.noreply.github.com> Date: Sat, 15 Jun 2024 07:48:39 +0200 Subject: [PATCH] Implement lowering of torch.aten.kthvalue (#3360) Closes [nod-ai/SHARK-Turbine#620](https://github.com/nod-ai/SHARK-Turbine/issues/620) --- .../Dialect/TMTensor/IR/TMTensorOps.td | 76 +++ .../Dialect/Torch/IR/GeneratedTorchOps.td | 28 + .../TorchToTMTensor/TorchToTMTensor.cpp | 560 +++++++++++++++++- lib/Dialect/TMTensor/IR/TMTensorOps.cpp | 208 +++++++ lib/Dialect/Torch/IR/TorchOps.cpp | 36 ++ .../Transforms/AbstractInterpLibrary.cpp | 12 + .../Torch/Transforms/DecomposeComplexOps.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 5 + .../build_tools/abstract_interp_lib_gen.py | 15 + .../build_tools/torch_ods_gen.py | 4 + .../torch_mlir_e2e_test/test_suite/basic.py | 99 ++++ 11 files changed, 1022 insertions(+), 22 deletions(-) diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td index dc745097c..e1a8bf452 100644 --- a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td +++ b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td @@ -326,6 +326,82 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention", }]; } +def TMTensor_TopkOp : TMTensor_Op<"topk", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "Top-K operator"; + let description = [{ + A Top-K operation for N-D tensors. Reduces the target dimension from the input + size N down to K elements based on the supplied binary region. + + Accepts an N-D tensor input consisting of values and an optional N-D tensor + for indices of those values (i32 type). If input indices aren't provided, the + index mapping is inferred based on the k dim. Both input values/indices + tensors and output values/indices tensors must have the same shape. Top-K is + computed along the target dimension (from dimension()). Returns two output + tensors of values and the indices of Top-K results. The output dimensions + must match the input save for the dimension that is reduced to K results. + + Region accepts lhs=[next N input] and rhs=[existing K output] and yields an + i1. If true, the two values are swapped: + - For Top-K comparison: > + - For Min-K comparison: < + Note: when the two values are equal, the first occurrence is always selected. + }]; + + let arguments = (ins Variadic:$inputs, + Variadic:$outputs, + I64Attr:$dimension + ); + + let results = (outs Variadic:$results); + let regions = (region AnyRegion:$region); + let assemblyFormat = [{ + attr-dict + `dimension` `(` $dimension `)` + `ins` `(` $inputs `:` type($inputs) `)` + `outs` `(` $outputs `:` type($outputs) `)` + $region (`->` type($results)^)? 
+ }]; + + let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{ + Value values() { + return getInputOperand(0)->get(); + } + std::optional indices() { + if (getNumInputs() < 2) { + return {}; + } else { + return getInputOperand(1)->get(); + } + } + Value outputValues() { + return getOutputOperand(0)->get(); + } + Value outputIndices() { + return getOutputOperand(1)->get(); + } + ShapedType getInputType() { + return cast(values().getType()); + } + int64_t getInputRank() { + return getInputType().getRank(); + } + + // Method to implement for specifying output range for + // DestinationStyleOpInterface + std::pair getDpsInitsPositionRange() { + std::pair outputsIndexAndLength = + getODSOperandIndexAndLength(1); + return std::make_pair( + outputsIndexAndLength.first, + outputsIndexAndLength.first + outputsIndexAndLength.second); + } + }]; +} + //===----------------------------------------------------------------------===// // Pure ops //===----------------------------------------------------------------------===// diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 5af6873d8..90e497117 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -12426,6 +12426,34 @@ def Torch_AtenCol2imOp : Torch_Op<"aten.col2im", [ }]; } +def Torch_AtenKthvalueOp : Torch_Op<"aten.kthvalue", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::kthvalue : (Tensor, int, int, bool) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$k, + Torch_IntType:$dim, + Torch_BoolType:$keepdim + ); + let results = (outs + AnyTorchOptionalTensorType:$values, + AnyTorchOptionalTensorType:$indices + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenKthvalueOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 2); + } + void AtenKthvalueOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 2); + } + }]; + let hasVerifier = 1; +} + def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index 684f7f681..9d0a764c1 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -254,6 +254,44 @@ static Value createTMTensorScanOp( return scanOp->getResult(0); } +static FailureOr createIntOrFloatCompareOp(PatternRewriter &rewriter, + Location loc, + Type elementType, Value lhs, + Value rhs, bool isDescending, + bool isEqual) { + + Value compareOp; + if (auto intType = dyn_cast(elementType)) { + // Case for using arith::CmpIOp. + arith::CmpIPredicate g = + isEqual ? arith::CmpIPredicate::sge : arith::CmpIPredicate::sgt; + arith::CmpIPredicate l = + isEqual ? arith::CmpIPredicate::sle : arith::CmpIPredicate::slt; + if (intType.isUnsignedInteger()) { + g = isEqual ? arith::CmpIPredicate::uge : arith::CmpIPredicate::ugt; + l = isEqual ? arith::CmpIPredicate::ule : arith::CmpIPredicate::ult; + } + arith::CmpIPredicate predicate = isDescending ? g : l; + compareOp = rewriter.create(loc, predicate, lhs, rhs); + return compareOp; + } + + if (isa(elementType)) { + // Case for using arith::CmpFOp. + arith::CmpFPredicate g = + isEqual ? 
arith::CmpFPredicate::OGE : arith::CmpFPredicate::OGT; + arith::CmpFPredicate l = + isEqual ? arith::CmpFPredicate::OLE : arith::CmpFPredicate::OLT; + + arith::CmpFPredicate predicate = isDescending ? g : l; + compareOp = rewriter.create(loc, predicate, lhs, rhs); + return compareOp; + } + + return rewriter.notifyMatchFailure( + loc, "Only Integer and Floating element type expected."); +} + // Utility function to create a TMTensor::SortOp. static FailureOr> createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc, @@ -280,34 +318,60 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc, } // Step 3. Create comparison op which will be used as the sorting predicate. - Value compareOp; - if (auto intType = dyn_cast(elementTypes[0])) { - // Case for using arith::CmpIOp. - arith::CmpIPredicate ge = arith::CmpIPredicate::sge; - arith::CmpIPredicate le = arith::CmpIPredicate::sle; - if (intType.isUnsignedInteger()) { - ge = arith::CmpIPredicate::uge; - le = arith::CmpIPredicate::ule; - } - arith::CmpIPredicate predicate = isDescending ? ge : le; - compareOp = rewriter.create( - loc, predicate, block->getArgument(0), block->getArgument(1)); - } else if (isa(elementTypes[0])) { - // Case for using arith::CmpFOp. - arith::CmpFPredicate predicate = - isDescending ? arith::CmpFPredicate::OGE : arith::CmpFPredicate::OLE; - compareOp = rewriter.create( - loc, predicate, block->getArgument(0), block->getArgument(1)); - } else { + auto compareOpRetVal = createIntOrFloatCompareOp( + rewriter, loc, elementTypes[0], block->getArgument(0), + block->getArgument(1), isDescending, true); + + if (failed(compareOpRetVal)) return rewriter.notifyMatchFailure( - sortOpLoc, "Only Integer and Floating element type expected."); - } + loc, "Only Integer and Floating element type expected."); // Step 4. Create yield op for yielding the sorting predicate. - rewriter.create(loc, compareOp); + rewriter.create(loc, compareOpRetVal.value()); return SmallVector(sortOp.getResults()); } +static FailureOr> createTMTensorTopkOp( + PatternRewriter &rewriter, Location topkOpLoc, llvm::ArrayRef inputs, + llvm::ArrayRef outputs, llvm::ArrayRef elementTypes, + int64_t dimension, bool isMinK) { + + // Generate output types. + SmallVector topkResultTypes; + for (Value val : outputs) { + topkResultTypes.push_back(val.getType()); + } + + // Create empty TopkOp, add body later. + auto topkOp = rewriter.create( + topkOpLoc, topkResultTypes, inputs, outputs, + rewriter.getI64IntegerAttr(dimension)); + + Region *body = &topkOp.getRegion(); + Block *block = rewriter.createBlock(body); + Location loc = body->getLoc(); + // Add arguments for each passed body region element type. + for (Type elementType : elementTypes) { + block->addArgument({elementType}, {loc}); + } + + // Generate compare operator. If minK is chosen, isDescending should be false. + // Is equal should be false, because we do not want equality to cause element + // swap. + auto compareOpRetVal = createIntOrFloatCompareOp( + rewriter, loc, elementTypes[0], block->getArgument(0), + block->getArgument(1), /*isDescending=*/!isMinK, /*isEqual=*/false); + + // Check if correct element types are passed. + if (failed(compareOpRetVal)) + return rewriter.notifyMatchFailure( + loc, "Only Integer and Floating element type expected."); + + // Yield the comparison result. 
+ rewriter.create(loc, compareOpRetVal.value()); + return SmallVector(topkOp.getResults()); +} + namespace { class ConvertAtenScatterSrcOp : public OpConversionPattern { public: @@ -1570,6 +1634,456 @@ public: }; } // namespace +namespace { +class ConvertAtenKthvalueOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenKthvalueOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const llvm::StringRef opName = op->getName().getStringRef(); + + Location loc = op.getLoc(); + auto typec = this->getTypeConverter(); + + Value input = adaptor.getSelf(); + auto inputType = cast(input.getType()); + unsigned inputRank = inputType.getRank(); + Type inputElementType = inputType.getElementType(); + + auto valResultType = + cast(typec->convertType(op.getResult(0).getType())); + auto valResultElementType = + getElementTypeOrSelf(typec->convertType(valResultType)); + + auto idxResultType = + cast(typec->convertType(op.getResult(1).getType())); + auto idxResultElementType = + getElementTypeOrSelf(typec->convertType(idxResultType)); + + // get keepdim and check it is bool + bool keepDim = false; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) + return rewriter.notifyMatchFailure( + op, opName + " requires boolean value for keepdim"); + + // get dim, check it is constant int + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "unimplemented: only constant dim value is supported"); + + // turn dim into positive if negative, and check it is in the valid range + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) { + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + } + + // get k, check it is a constant int + int64_t k; + if (!matchPattern(op.getK(), m_TorchConstantInt(&k))) + return rewriter.notifyMatchFailure( + op, "unimplemented: only constant k value is supported"); + + // check if element type is float, int, or unsigned + bool isUnsigned = false; + if (!isa(inputElementType)) { + if (!isa(inputElementType)) { + return rewriter.notifyMatchFailure( + op, opName + " to linalg.* requires Float or Integer " + "input element type"); + } + + auto integerTy = dyn_cast( + cast(op.getSelf().getType()).getDtype()); + isUnsigned = integerTy.isUnsigned(); + } + + // Create the values to fill initial output tensors for + // topk op and linalg generic op for finding max value. 
+ Value fillValLinalgFindMax; + Value fillValTopK; + if (isa(inputElementType)) { + // max float for topk tensor + fillValTopK = rewriter.create( + loc, + rewriter.getFloatAttr( + inputElementType, + APFloat::getInf( + cast(inputElementType).getFloatSemantics(), + /*Negative=*/false))); + // min float for linalg generic op tensor + fillValLinalgFindMax = rewriter.create( + loc, + rewriter.getFloatAttr( + inputElementType, + APFloat::getInf( + cast(inputElementType).getFloatSemantics(), + /*Negative=*/true))); + } else if (!isUnsigned) { + auto width = cast(inputElementType).getWidth(); + // max signed int for topk op tensor + auto init = APSInt::getSignedMaxValue(width); + fillValTopK = rewriter.create( + loc, rewriter.getIntegerAttr(inputElementType, init)); + // min signed int for linalg generic op tensor + init = APSInt::getSignedMinValue(width); + fillValLinalgFindMax = rewriter.create( + loc, rewriter.getIntegerAttr(inputElementType, init)); + } else if (isUnsigned) { + auto width = cast(inputElementType).getWidth(); + // max unsigned int for topk op tensor + auto init = APInt::getMaxValue(width); + fillValTopK = rewriter.create( + loc, rewriter.getIntegerAttr(inputElementType, init)); + // min unsigned int for linalg generic op tensor + init = APInt::getMinValue(width); + fillValLinalgFindMax = rewriter.create( + loc, rewriter.getIntegerAttr(inputElementType, init)); + } + + auto i32Type = rewriter.getI32Type(); + + // ======== BEGIN: Topk op section ======== + // Based on iree docs: + // https://iree.dev/reference/mlir-dialects/LinalgExt/#iree_linalg_extsort-linalgextsortop + + // Create the output shape of topk op. + // For every dimension, topkShape[dimension] = inputShape[dimension], + // except topkShape[dim] = k. + SmallVector topkShape; + for (unsigned i = 0; i < inputRank; i++) { + auto currentDimSize = rewriter.create(loc, input, i); + topkShape.push_back(currentDimSize); + } + auto dimSize = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getI64Type(), k)); + topkShape[dim] = dimSize; + + // Fill the initial topk op output tensor. + Value topkOutputVal = createInitTensor(rewriter, loc, topkShape, + valResultElementType, fillValTopK); + + // Create the initial value to fill the topk output indices tensor. + // It is equal to the max 32-bit signless integer. + auto signlessType = mlir::IntegerType::get(op.getContext(), 32, + mlir::IntegerType::Signless); + auto initIdx = getNumericLimit(rewriter, signlessType, /*getMin=*/false); + auto fillValTopkIdx = rewriter.create(loc, initIdx); + // Fill the initial topk op output indices tensor. + Value topkOutputIdx = + createInitTensor(rewriter, loc, topkShape, i32Type, fillValTopkIdx); + + // Input arguments for topk op contain only the input tensor. + // Input indices will be inferred based on input shape. + // (See docs link above). + SmallVector topkInputs; + topkInputs.push_back(input); + + // Outputs contain both the values and the indices tensors. + SmallVector topkOutputs; + topkOutputs.push_back(topkOutputVal); + topkOutputs.push_back(topkOutputIdx); + + // Element types of the arguments passed to the topk op region. + // The region accepts the next value N, and the current output + // candidate K (see docs link above). + // Both N and K are values from the input tensors, thus the + // element types are the same and are taken from inputType. 
+ SmallVector topkElementTypes; + topkElementTypes.push_back(inputType.getElementType()); + topkElementTypes.push_back(inputType.getElementType()); + + // Create the TMTensor TopkOp. + FailureOr> topkOp; + { + OpBuilder::InsertionGuard guard(rewriter); + topkOp = createTMTensorTopkOp(rewriter, loc, topkInputs, topkOutputs, + topkElementTypes, dim, /*isMinK=*/true); + } + // Topk op creation fails with invalid element types. + if (failed(topkOp)) + return rewriter.notifyMatchFailure( + loc, "Only Integer and Floating element type expected."); + + auto topkOpVal = topkOp.value(); + // ======== END: Topk op section ======== + + // ======== BEGIN: Linalg generic to find max in topk result ======== + + // Create result shape as both a vector of Value and of int64_t types. + // We assume that keepdim is false, and fix the result later if true. + // Result shape is equal to inputShape, with dim dimension removed. + SmallVector resultShape; + SmallVector resultShapeInt; + for (int64_t i = 0; i < inputType.getRank(); i++) { + if (dim != i) { + auto currentDimSize = rewriter.create(loc, input, i); + resultShape.push_back(currentDimSize); + resultShapeInt.push_back(inputType.getShape()[i]); + } + } + + // Fill the initial output tensor for linalg op for finding max value. + Value findMaxOutputVal = createInitTensor( + rewriter, loc, resultShape, inputElementType, fillValLinalgFindMax); + + // Fill the initial output indices tensor for linalg op for finding max + // value with zeros. + Value findMaxOutputIdx = + createZeroInitTensor(rewriter, loc, resultShape, idxResultElementType); + + // Reduce along dim. + SmallVector findMaxIteratorTypes( + inputType.getRank(), utils::IteratorType::parallel); + findMaxIteratorTypes[dim] = utils::IteratorType::reduction; + + SmallVector findMaxMapExprs; + SmallVector findMaxMapResultExprs; + for (auto size : + llvm::enumerate(makeShapeTorchCompatible(inputType.getShape()))) { + findMaxMapExprs.push_back(rewriter.getAffineDimExpr(size.index())); + if (unsigned(dim) != size.index()) + findMaxMapResultExprs.push_back( + rewriter.getAffineDimExpr(size.index())); + } + + auto findMaxMaps = AffineMap::inferFromExprList( + {findMaxMapExprs, findMaxMapResultExprs, findMaxMapResultExprs}, + rewriter.getContext()); + + // Create linalg op for finding the max value in the extracted topk values. + auto findMaxLinalg = rewriter.create( + loc, + ArrayRef( + {findMaxOutputVal.getType(), findMaxOutputIdx.getType()}), + topkOpVal.front(), ValueRange({findMaxOutputVal, findMaxOutputIdx}), + findMaxMaps, findMaxIteratorTypes, + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange blockArgs) { + // Linalg generic body is the same as the decomposition for + // AtenMinDim: lib/Conversion/TorchToLinalg/Reduction.cpp + + Value newValue = blockArgs[0]; + Value oldValue = blockArgs[1]; + Value oldIndex = blockArgs[2]; + + Value newIndex = rewriter.create( + nestedLoc, oldIndex.getType(), + rewriter.create(nestedLoc, dim)); + + Value resultVal, predicate; + if (isa(inputElementType)) { + resultVal = rewriter.create(nestedLoc, newValue, + oldValue); + predicate = rewriter.create( + nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); + } else { + arith::CmpIPredicate predType; + predType = isUnsigned ? 
arith::CmpIPredicate::ugt + : arith::CmpIPredicate::sgt; + if (isUnsigned) { + resultVal = rewriter.create(nestedLoc, newValue, + oldValue); + } else { + resultVal = rewriter.create(nestedLoc, newValue, + oldValue); + } + predicate = rewriter.create(nestedLoc, predType, + newValue, oldValue); + } + auto resultIndex = rewriter.create( + nestedLoc, predicate, newIndex, oldIndex); + nestedBuilder.create( + nestedLoc, ValueRange{resultVal, resultIndex}); + }); + + auto findMaxVal = findMaxLinalg.getResult(0); + auto findMaxIdx = findMaxLinalg.getResult(1); + auto findMaxIdxType = cast(findMaxIdx.getType()); + + // ======== END: Linalg generic to find max in topk result ======== + + // ======== BEGIN: Linalg generic for index extraction ======== + // The linalg op for finding max returned idx of max elements in the + // tensor returned by the topk op. We need the idx of those elements + // in the original input. The topk op returned the idx of the top k + // extracted elements in the original input. Using the linalg idx + // results to index the topk idx results returns the idx of kth + // max value in the original input. Example: + // input = [1, 7, 3, 6, 2, 8, 9, 5], k = 4 + // topk_val = [1, 3, 2, 5], topk_idx = [0, 2, 4, 7] + // linalg_max_val = [5], linalg_max_idx = [3] (5 is at idx 3 in topk_val) + // index the topk_idx using linalg_max_idx -> topk_idx[3] = 7 + // kth_val = [5], kth_idx = [7] + + // Create a tensor for the resulting idx. + Value filledTensorExtractedIdx = createZeroInitTensor( + rewriter, loc, getTensorSizes(rewriter, loc, findMaxIdx), i32Type); + + // We iterate through the idx tensor returned by the linalg generic op for + // finding max. + SmallVector extractedIdxIteratorTypes( + findMaxIdxType.getRank(), utils::IteratorType::parallel); + + SmallVector extractedIdxMapExprs; + for (auto size : + llvm::enumerate(makeShapeTorchCompatible(findMaxIdxType.getShape()))) { + extractedIdxMapExprs.push_back(rewriter.getAffineDimExpr(size.index())); + } + + auto extractedIdxMaps = AffineMap::inferFromExprList( + {extractedIdxMapExprs, extractedIdxMapExprs}, rewriter.getContext()); + + // Linalg generic op for indexing the topk output idx tensor using + // the idx tensor returned by the linalg generic op for finding max. + // Only the idx tensor from the linalg generic op is sent as input. + auto extractedIdxLinalg = rewriter.create( + loc, ArrayRef({filledTensorExtractedIdx.getType()}), findMaxIdx, + filledTensorExtractedIdx, extractedIdxMaps, extractedIdxIteratorTypes, + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange blockArgs) { + // Get the current input idx. + Value index = rewriter.create( + loc, rewriter.getIndexType(), blockArgs[0]); + + // Create idx to index the topk idx tensor. + // Index the dim dimension using the current input idx. + SmallVector indexTarget; + for (unsigned i = 0; i < dim; i++) + indexTarget.push_back(rewriter.create(loc, i)); + indexTarget.push_back(index); + for (unsigned i = dim; i < findMaxIdxType.getRank(); i++) + indexTarget.push_back(rewriter.create(loc, i)); + + // Extract the element from the topk idx tensor. 
+ Value extractedElement = rewriter.create( + loc, topkOpVal.back(), indexTarget); + rewriter.create(loc, extractedElement); + }); + + auto extractedIdx = extractedIdxLinalg.getResult(0); + auto extractedIdxType = cast(extractedIdx.getType()); + + // ======== END: Linalg generic for index extraction ======== + + // ======== BEGIN: Linalg generic for topk idx cast ======== + // Casts from i32 to idx result type of the Kthvalue op. + + // Create the initial tensor for the cast result. + Value filledTensorCastedIdx = createZeroInitTensor( + rewriter, loc, getTensorSizes(rewriter, loc, extractedIdx), + idxResultElementType); + + SmallVector castedIdxIteratorTypes( + extractedIdxType.getRank(), utils::IteratorType::parallel); + + SmallVector castedIdxMapExprs; + for (auto size : llvm::enumerate( + makeShapeTorchCompatible(extractedIdxType.getShape()))) { + castedIdxMapExprs.push_back(rewriter.getAffineDimExpr(size.index())); + } + + auto castedIdxMaps = AffineMap::inferFromExprList( + {castedIdxMapExprs, castedIdxMapExprs}, rewriter.getContext()); + + // Linalg generic op for casting topk idx output tensor elements from i32 to + // result idx tensor element type. + auto castedIdxLinalg = rewriter.create( + loc, ArrayRef({filledTensorCastedIdx.getType()}), extractedIdx, + filledTensorCastedIdx, castedIdxMaps, castedIdxIteratorTypes, + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange blockArgs) { + Value oldIdx = blockArgs[0]; + + // Cast from i32 to index. + Value oldIdxToIndexType = rewriter.create( + nestedLoc, rewriter.getIndexType(), oldIdx); + // Cast from index to result idx element type. + Value resultIdx = rewriter.create( + nestedLoc, idxResultElementType, oldIdxToIndexType); + + nestedBuilder.create(nestedLoc, resultIdx); + }); + + auto castedIdx = castedIdxLinalg.getResult(0); + + // ======== END: Linalg generic for topk idx cast ======== + + // Create output value type ("squeezed" since we assume keepdim=False). + auto topkValResultType = + cast(topkOpVal.front().getType()); + auto squeezedValType = topkValResultType.cloneWith( + resultShapeInt, + cast(findMaxVal.getType()).getElementType()); + + // Create output idx type ("squeezed" since we assume keepdim=False). + auto castedIdxType = cast(castedIdx.getType()); + auto squeezedIdxType = castedIdxType.cloneWith( + resultShapeInt, findMaxIdxType.getElementType()); + + if (!keepDim) { + // If keepdim=false, cast the outputs to the appropriate type and return. + Value retVal = + rewriter.create(loc, squeezedValType, findMaxVal); + Value retIdx = + rewriter.create(loc, squeezedIdxType, castedIdx); + llvm::SmallVector res{retVal, retIdx}; + rewriter.replaceOp(op, res); + return success(); + } + + // If keepdim is true, unsqueeze. 
+ // Unsqueezing implementation taken from AtenMinMaxDimOp lowering: + // lib/Conversion/TorchToLinalg/Reduction.cpp + llvm::SmallVector valShape(valResultType.getShape()); + llvm::SmallVector idxShape(idxResultType.getShape()); + for (int i = dim, s = valShape.size() - 1; i < s; ++i) { + valShape[i] = valShape[i + 1]; + idxShape[i] = idxShape[i + 1]; + } + + valShape.resize(valShape.size() - 1); + idxShape.resize(idxShape.size() - 1); + + Value retVal = rewriter.create( + loc, squeezedValType.clone(valShape), findMaxLinalg.getResult(0)); + Value retIdx = rewriter.create( + loc, squeezedIdxType.clone(idxShape), castedIdx); + + SmallVector reassociation(valShape.size()); + if (reassociation.size() > 0) { + for (int i = 0; i < dim; ++i) + reassociation[i].push_back(i); + reassociation[std::max(0, dim - 1)].push_back(dim); + for (int i = dim, s = reassociation.size(); i < s; ++i) + reassociation[i].push_back(i + 1); + } + + valShape.push_back(0); + idxShape.push_back(0); + for (int i = dim, s = valShape.size() - 1; i < s; ++i) { + valShape[i + 1] = valShape[i]; + idxShape[i + 1] = idxShape[i]; + } + + valShape[dim] = 1; + idxShape[dim] = 1; + + Value unsqueezeVal = rewriter.create( + loc, valResultType, retVal, reassociation); + + Value unsqueezeIdx = rewriter.create( + loc, idxResultType, retIdx, reassociation); + + // Return unsqueezed. + llvm::SmallVector unsqueezes = {unsqueezeVal, unsqueezeIdx}; + rewriter.replaceOp(op, unsqueezes); + return success(); + } +}; +} // namespace + // ----------------------------------------------------------------------------- // The pass // ----------------------------------------------------------------------------- @@ -1619,6 +2133,8 @@ public: target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index 218ecad33..05258f506 100644 --- a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -910,6 +910,213 @@ bool SortOp::payloadUsesValueFromOperand(OpOperand *opOperand) { return true; } + //===----------------------------------------------------------------------===// // TopkOp //===----------------------------------------------------------------------===// + +LogicalResult TopkOp::verify() { + Operation *op = getOperation(); + if (getNumInputs() != 1 && getNumInputs() != 2) { + return op->emitOpError("expected one or two input operands"); + } + if (getNumOutputs() != 2) { + return op->emitOpError("expected two output operands"); + } + // First check added to eliminate comparison of different int types + if (getInputRank() < 0 || + (getDimension() >= static_cast(getInputRank()))) { + return op->emitOpError("dimension exceeds rank"); + } + // Ensure input/output element types match + auto inputValuesType = cast(values().getType()); + auto outputValuesType = cast(outputValues().getType()); + if (inputValuesType.getElementType() != outputValuesType.getElementType()) { + return op->emitOpError("expected input/output value types to be identical"); + } + // Indices must be int if provided + auto outputIndicesType = cast(outputIndices().getType()); + if (auto inputIndices = indices()) { + auto inputIndicesType = cast(inputIndices->getType()); + if (!inputIndicesType.getElementType().isInteger(32) || + !outputIndicesType.getElementType().isInteger(32)) { + return 
op->emitOpError("expected input/output indices types to be int32"); + } + } + + // Ranks must match + if (inputValuesType.getRank() != outputValuesType.getRank()) { + return op->emitOpError("expected input/output to have the same rank"); + } + if (auto inputIndices = indices()) { + auto inputIndicesType = cast(inputIndices->getType()); + if (inputIndicesType.getRank() != outputIndicesType.getRank()) { + return op->emitOpError("expected input/output to have the same rank"); + } + } + // Input indices and values must have the same shape. + if (auto inputIndices = indices()) { + auto inputIndicesType = cast(inputIndices->getType()); + if (failed(verifyCompatibleShape(inputValuesType, inputIndicesType))) + return op->emitOpError("input indices/values shape must match"); + } + // Output indices and values must have the same shape. + if (failed(verifyCompatibleShape(outputValuesType, outputIndicesType))) + return op->emitOpError("output indices/values shape must match"); + // Input shape must match the output shape except for the dimension() + uint64_t dim = getDimension(); + if (!llvm::all_of(llvm::enumerate(llvm::zip(inputValuesType.getShape(), + outputValuesType.getShape())), + [dim](auto e) { + if (e.index() == dim) { + return true; + } + std::tuple s = e.value(); + return succeeded(verifyCompatibleShape(std::get<0>(s), + std::get<1>(s))); + })) { + return op->emitOpError("incompatible input/output shapes"); + } + // Check region compatibility + Block &block = getRegion().front(); + if (block.getNumArguments() != 2) { + return op->emitOpError("region block should have 2 arguments"); + } + if (block.getArgument(0).getType() != inputValuesType.getElementType() || + block.getArgument(1).getType() != inputValuesType.getElementType()) { + return op->emitOpError("region block types must match input"); + } + auto terminatorOp = llvm::dyn_cast(block.getTerminator()); + if (!terminatorOp || !terminatorOp.getOperand(0).getType().isInteger(1)) { + return op->emitOpError("region block must end with a linalg_ext.yield i1!"); + } + return success(); +} + +SmallVector TopkOp::getLoopIteratorTypes() { + SmallVector iteratorTypes(getInputRank(), + utils::IteratorType::parallel); + iteratorTypes[getDimension()] = utils::IteratorType::reduction; + return iteratorTypes; +} + +SmallVector TopkOp::getIterationDomain(OpBuilder &builder) { + int64_t operandRank = getInputRank(); + SmallVector loopBounds(operandRank); + Location loc = getLoc(); + Value zero = builder.create(loc, 0); + Value one = builder.create(loc, 1); + Value source = values(); + for (auto dim : llvm::enumerate(getInputType().getShape())) { + loopBounds[dim.index()].offset = zero; + loopBounds[dim.index()].size = + getDimValue(builder, loc, source, dim.index()); + loopBounds[dim.index()].stride = one; + } + return loopBounds; +} + +LogicalResult TopkOp::generateScalarImplementation(OpBuilder &b, Location loc, + ValueRange ivs) { + uint64_t kDim = getDimension(); + Value zero = b.create(loc, 0); + Value one = b.create(loc, 1); + Value initialValue = b.create(loc, values(), ivs); + + // If the indices tensor is not provided, the value index is derived from the + // loop induction variables. 
+ Value initialIndex; + if (indices()) { + initialIndex = b.create(loc, *indices(), ivs); + } else { + Value rawInitialIndex = ivs[kDim]; + initialIndex = + b.create(loc, b.getI32Type(), rawInitialIndex); + } + + // Compute K (ub) from the selected dim of the output + Value ub = b.create(loc, outputValues(), getDimension()); + + // Inner K loop functions: + // Load current K value and index + // Compare N/K using inserted block compare + // Check if N == K using strict weak ordering, select which index came first + // Select new K value from N/K comparison + // Select new K index from N/K comparison or which index came first + // Store new k value and index + // Yield loop carry values after K selection + Value kValue, kIndex; + auto scfFor = b.create( + loc, zero, ub, one, ValueRange{initialValue, initialIndex}, + [&](OpBuilder &b, Location loc, Value iv, ValueRange loopCarryValues) { + SmallVector indices(ivs); + indices[kDim] = iv; + kValue = b.create(loc, outputValues(), indices); + kIndex = b.create(loc, outputIndices(), indices); + }); + + SmallVector indices(ivs); + indices[kDim] = scfFor.getInductionVar(); + auto loopCarryValues = scfFor.getRegionIterArgs(); + + // Retrieve region as black box comparision function f(x,y). Plug into op. + auto &srcBlock = getRegion().front(); + IRMapping bvmF; // f(x,y) + IRMapping bvmR; // f(y,x) + { + // Save previous insertion point. Continue within loop body. + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToEnd(&scfFor.getRegion().front()); + SmallVector forwardValues{loopCarryValues[0], kValue}; + SmallVector reverseValues{kValue, loopCarryValues[0]}; + for (auto it : llvm::zip(srcBlock.getArguments(), forwardValues)) { + bvmF.map(std::get<0>(it), std::get<1>(it)); + } + for (auto it : llvm::zip(srcBlock.getArguments(), reverseValues)) { + bvmR.map(std::get<0>(it), std::get<1>(it)); + } + for (auto &blockOp : srcBlock.without_terminator()) { + b.clone(blockOp, bvmF); + b.clone(blockOp, bvmR); + } + Value forwardCmpRes = bvmF.lookup(srcBlock.getTerminator()->getOperand(0)); + Value reverseCmpRes = bvmR.lookup(srcBlock.getTerminator()->getOperand(0)); + + // Check value equality using strictly weak ordering from the region: + // f(x,y) --> forwardCmpRes + // f(y,x) --> reverseCmpRes + // if forwardCmpRes == reverseCmpRes then select which came first + Value cmpValuesEqual = b.create( + loc, arith::CmpIPredicate::eq, forwardCmpRes, reverseCmpRes); + Value cmpFirstIndex = b.create( + loc, arith::CmpIPredicate::slt, loopCarryValues[1], kIndex); + Value combinedCmpEqRes = + b.create(loc, cmpValuesEqual, cmpFirstIndex); + // True if N > K or N came before K + Value indexCmpRes = + b.create(loc, forwardCmpRes, combinedCmpEqRes); + // Select results for K based on comparisons + Value resultKValue = b.create(loc, forwardCmpRes, + loopCarryValues[0], kValue); + Value resultKIndex = + b.create(loc, indexCmpRes, loopCarryValues[1], kIndex); + b.create(loc, resultKValue, outputValues(), indices); + b.create(loc, resultKIndex, outputIndices(), indices); + // Select loop carry, opposite of K results + Value resultCarryValue = b.create( + loc, forwardCmpRes, kValue, loopCarryValues[0]); + Value resultCarryIndex = + b.create(loc, indexCmpRes, kIndex, loopCarryValues[1]); + b.create(loc, ValueRange{resultCarryValue, resultCarryIndex}); + } + return success(); +} + +bool TopkOp::payloadUsesValueFromOperand(OpOperand *opOperand) { + // Set to true so that output operands are always initialized. 
+ return true; +} + #define DEFINE_OP_GET_EFFECTS(OP_NAME) \ void OP_NAME::getEffects( \ SmallVectorImpl> \ @@ -924,6 +1131,7 @@ DEFINE_OP_GET_EFFECTS(AttentionOp) DEFINE_OP_GET_EFFECTS(ScanOp) DEFINE_OP_GET_EFFECTS(ScatterOp) DEFINE_OP_GET_EFFECTS(SortOp) +DEFINE_OP_GET_EFFECTS(TopkOp) namespace { /// This is derived from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp without any diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 140549ed5..500a861da 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -4877,6 +4877,42 @@ LogicalResult AtenLinalgCrossOp::verify() { return success(); } +LogicalResult AtenKthvalueOp::verify() { + + auto selfType = cast(getSelf().getType()); + + if (!selfType.hasDtype() || !selfType.hasSizes()) + return success(); + + Type selfDtype = selfType.getDtype(); + if (selfDtype.isSignlessInteger(1)) + return emitOpError("input tensors must not have bool dtype"); + + int64_t dim; + if (!matchPattern(getDim(), m_TorchConstantInt(&dim))) + return success(); + + ArrayRef selfShape = selfType.getSizes(); + int64_t selfRank = selfShape.size(); + + dim = toPositiveDim(dim, selfRank); + if (!isValidDim(dim, selfRank)) + return emitOpError("dim expected to be in range of [") + << -selfRank << ", " << selfRank - 1 << "], but got " << dim; + + // convert k to an integer type + int64_t k; + if (!matchPattern(getK(), m_TorchConstantInt(&k))) + return success(); + + // check if k is in the correct range + if (selfShape[dim] != kUnknownSize && (k < 1 || k > selfShape[dim])) + return emitOpError("k expected to be in range of [") + << 1 << ", " << selfShape[dim] << "], but got " << k; + + return success(); +} + //===----------------------------------------------------------------------===// // DtypeCalculateYieldDtypesOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 2eca3ab44..c587fd9f9 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6962,6 +6962,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %4 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.kthvalue\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.tuple, list> {\n" +" %0 = torch.derefine %arg2 : !torch.int to !torch.optional\n" +" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg3) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %2 : !torch.tuple, list>\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten._log_softmax_backward_data\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -10897,6 +10903,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list>, !torch.list) -> !torch.int\n" " return %6 : !torch.int\n" " }\n" +" func.func 
@\"__torch_mlir_dtype_fn.aten.kthvalue\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" " return %arg3 : !torch.int\n" " }\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index e1759ceb0..bc3ba0c07 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1618,6 +1618,7 @@ public: auto idxTy = rewriter.getType( reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true)); llvm::SmallVector types{reductionTy, idxTy}; + reduction = rewriter .create(loc, types, reduction, dimValue, op.getKeepdim()) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 7eb3d5e4e..058ada5b4 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2274,6 +2274,11 @@ ONNX_XFAIL_SET = { "AtenIntTensorCharDtypeModule_basic", "AtenItemFpOpModule_basic", "AtenItemIntOpModule_basic", + "AtenKthvalueModule_basic", + "AtenKthvalueKeepDimModule_basic", + "AtenKthvalueDynamicDimsModule_basic", + "AtenKthvalueFloat64Module_basic", + "AtenKthvalueFloat64DynamicDimsModule_basic", "AtenLinalgCrossDynamic_basic", "AtenMatmulQMixedSigni8Transpose_basic", "AtenMatmulQMixedSigni8_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 3aa1a5ef2..da2681e76 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -468,6 +468,14 @@ def aten〇linalg_cross〡shape(self: List[int], other: List[int], dim: int = -1 assert (self[i] == other[i]) or self[i] == 1 or other[i] == 1, f"the size of first tensor ({self[i]}) must match the size of second tensor ({other[i]}) at dimension {i}" return upstream_shape_functions.broadcast(self, other) +@check_shape_function([ + Invocation(TensorOfShape(2, 4, 3, device="cpu"), k=2, dim=1, keepdim=True), # keep dim, + Invocation(TensorOfShape(2, 4, 3, device="cpu"), k=2, dim=1, keepdim=False), # don't keep dim +]) +def aten〇kthvalue〡shape(self: List[int], k: int, dim: int = -1, keepdim: bool = False) -> Tuple[List[int], List[int]]: + new_shape = upstream_shape_functions.argmax(self, dim, keepdim) + return (new_shape, new_shape) + def aten〇_log_softmax_backward_data〡shape(grad_output: List[int], output: List[int], dim: int, input_dtype: int) -> List[int]: return upstream_shape_functions.unary(grad_output) @@ -2705,6 +2713,13 @@ def aten〇linalg_cross〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dty dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function([ + Invocation(TensorOfShape(2, 4, 3, dtype=torch.int32, device="cpu"), k=2, dim=-1, keepdim=False) +]) +def aten〇kthvalue〡dtype(self_rank_dtype: Tuple[int, int], k: int, dim: int = -1, keepdim: bool = False) -> Tuple[int, int]: + _, self_dtype 
= self_rank_dtype + return (self_dtype, torch.int64) + @check_dtype_function( _check_two_tensor_op(dim=0, input_dtype=torch.float32) + _check_two_tensor_op(dim=0, input_dtype=torch.float64)) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 17c706f25..5a0632bed 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -912,6 +912,10 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): ) emit("aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)", has_verifier=True) emit("aten::col2im : (Tensor, int[], int[], int[], int[], int[]) -> (Tensor)") + emit( + "aten::kthvalue : (Tensor, int, int, bool) -> (Tensor, Tensor)", + has_verifier=True, + ) # Functionalization ops emit("aten::alias_copy : (Tensor) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index b483f9d3c..552f51af1 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -5547,3 +5547,102 @@ class CloneModule(torch.nn.Module): @register_test_case(module_factory=lambda: CloneModule()) def CloneModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 5)) + + +# ============================================================================== + + +class AtenKthvalueModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([2, 6, 3], torch.int32, True)]) + def forward(self, x): + return torch.ops.aten.kthvalue(x, k=4, dim=1, keepdim=False) + + +@register_test_case(module_factory=lambda: AtenKthvalueModule()) +def AtenKthvalueModule_basic(module, tu: TestUtils): + module.forward(torch.randperm(2 * 6 * 3, dtype=torch.int32).reshape(2, 6, 3)) + + +# ============================================================================== + + +class AtenKthvalueKeepDimModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([2, 6, 3], torch.int32, True)]) + def forward(self, x): + return torch.ops.aten.kthvalue(x, k=4, dim=1, keepdim=True) + + +@register_test_case(module_factory=lambda: AtenKthvalueKeepDimModule()) +def AtenKthvalueKeepDimModule_basic(module, tu: TestUtils): + module.forward(torch.randperm(2 * 6 * 3, dtype=torch.int32).reshape(2, 6, 3)) + + +# ============================================================================== + + +class AtenKthvalueDynamicDimsModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1, -1, -1], torch.int32, True)]) + def forward(self, x): + return torch.ops.aten.kthvalue(x, k=6, dim=2, keepdim=True) + + +@register_test_case(module_factory=lambda: AtenKthvalueDynamicDimsModule()) +def AtenKthvalueDynamicDimsModule_basic(module, tu: TestUtils): + module.forward(torch.randperm(4 * 2 * 8 * 3, dtype=torch.int32).reshape(4, 2, 8, 3)) + + +# ============================================================================== + + +class AtenKthvalueFloat64Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([4, 2, 8, 3], torch.float64, True)]) + def forward(self, x): + return torch.ops.aten.kthvalue(x, k=3, dim=0, keepdim=True) + + +@register_test_case(module_factory=lambda: 
AtenKthvalueFloat64Module()) +def AtenKthvalueFloat64Module_basic(module, tu: TestUtils): + module.forward( + torch.randperm(4 * 2 * 8 * 3, dtype=torch.float64).reshape(4, 2, 8, 3) + ) + + +# ============================================================================== + + +class AtenKthvalueFloat64DynamicDimsModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1, -1, -1], torch.float64, True)]) + def forward(self, x): + return torch.ops.aten.kthvalue(x, k=3, dim=3, keepdim=True) + + +@register_test_case(module_factory=lambda: AtenKthvalueFloat64DynamicDimsModule()) +def AtenKthvalueFloat64DynamicDimsModule_basic(module, tu: TestUtils): + module.forward( + torch.randperm(4 * 2 * 8 * 3, dtype=torch.float64).reshape(4, 2, 8, 3) + )
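
The semantics these tests exercise can be sanity-checked against eager PyTorch: torch.kthvalue returns the k-th smallest value along dim together with its index (int64), which is why the lowering above drives the TMTensor topk op in min-K mode (/*isMinK=*/true). A minimal reference sketch, assuming only a stock PyTorch install; it is illustrative only and not part of the patch:

import torch

# Mirror AtenKthvalueModule_basic: unique int32 values, k=4 along dim=1.
x = torch.randperm(2 * 6 * 3, dtype=torch.int32).reshape(2, 6, 3)
values, indices = torch.kthvalue(x, k=4, dim=1, keepdim=False)

# The reduced dimension is dropped (keepdim=False) and indices are int64.
assert values.shape == (2, 3) and indices.dtype == torch.int64
# Each returned value is the 4th smallest entry of its slice along dim=1.
assert torch.equal(values, x.sort(dim=1).values[:, 3, :])
# The returned index points back at that entry in the original input.
assert torch.equal(values, x.gather(1, indices.unsqueeze(1)).squeeze(1))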