mirror of https://github.com/llvm/torch-mlir
Implement lowering of torch.aten.kthvalue (#3360)
Closes [nod-ai/SHARK-Turbine#620](https://github.com/nod-ai/SHARK-Turbine/issues/620)pull/3461/merge
parent
51902ec2dc
commit
4555629246
|
@ -326,6 +326,82 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TMTensor_TopkOp : TMTensor_Op<"topk",
|
||||||
|
[DeclareOpInterfaceMethods<TMTensorInterface,
|
||||||
|
["payloadUsesValueFromOperand"]>,
|
||||||
|
DeclareOpInterfaceMethods<ScalarLoopOpInterface,
|
||||||
|
["generateScalarImplementation"]>]> {
|
||||||
|
let summary = "Top-K operator";
|
||||||
|
let description = [{
|
||||||
|
A Top-K operation for N-D tensors. Reduces the target dimension from the input
|
||||||
|
size N down to K elements based on the supplied binary region.
|
||||||
|
|
||||||
|
Accepts an N-D tensor input consisting of values and an optioanl N-D tensor
|
||||||
|
for indices of those values (i32 type). If input indices aren't provided, the
|
||||||
|
index mapping is inferred based on the k dim. Both input values/indices
|
||||||
|
tensors and output values/indicies tensors must have the same shape. Top-K is
|
||||||
|
computed along the target dimension (from dimension()). Returns two output
|
||||||
|
tensors of values and the indicies of Top-K results. The output dimensions
|
||||||
|
must match the input save for the dimension that is reduced to K results.
|
||||||
|
|
||||||
|
Region accepts lhs=[next N input] and rhs=[exiting K output] and yeilds an
|
||||||
|
i1. If true, the two values are swapped:
|
||||||
|
- For Top-K compoarision: >
|
||||||
|
- For Min-K comparision: <
|
||||||
|
Note: when the two values are equal, the first occurence is always selected.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins Variadic<AnyShaped>:$inputs,
|
||||||
|
Variadic<AnyShaped>:$outputs,
|
||||||
|
I64Attr:$dimension
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs Variadic<AnyRankedTensor>:$results);
|
||||||
|
let regions = (region AnyRegion:$region);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
attr-dict
|
||||||
|
`dimension` `(` $dimension `)`
|
||||||
|
`ins` `(` $inputs `:` type($inputs) `)`
|
||||||
|
`outs` `(` $outputs `:` type($outputs) `)`
|
||||||
|
$region (`->` type($results)^)?
|
||||||
|
}];
|
||||||
|
|
||||||
|
let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{
|
||||||
|
Value values() {
|
||||||
|
return getInputOperand(0)->get();
|
||||||
|
}
|
||||||
|
std::optional<Value> indices() {
|
||||||
|
if (getNumInputs() < 2) {
|
||||||
|
return {};
|
||||||
|
} else {
|
||||||
|
return getInputOperand(1)->get();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Value outputValues() {
|
||||||
|
return getOutputOperand(0)->get();
|
||||||
|
}
|
||||||
|
Value outputIndices() {
|
||||||
|
return getOutputOperand(1)->get();
|
||||||
|
}
|
||||||
|
ShapedType getInputType() {
|
||||||
|
return cast<ShapedType>(values().getType());
|
||||||
|
}
|
||||||
|
int64_t getInputRank() {
|
||||||
|
return getInputType().getRank();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Method to implement for specifying output range for
|
||||||
|
// DestinationStyleOpInterface
|
||||||
|
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
|
||||||
|
std::pair<unsigned, unsigned> outputsIndexAndLength =
|
||||||
|
getODSOperandIndexAndLength(1);
|
||||||
|
return std::make_pair<int64_t, int64_t>(
|
||||||
|
outputsIndexAndLength.first,
|
||||||
|
outputsIndexAndLength.first + outputsIndexAndLength.second);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Pure ops
|
// Pure ops
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -12426,6 +12426,34 @@ def Torch_AtenCol2imOp : Torch_Op<"aten.col2im", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenKthvalueOp : Torch_Op<"aten.kthvalue", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::kthvalue : (Tensor, int, int, bool) -> (Tensor, Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
Torch_IntType:$k,
|
||||||
|
Torch_IntType:$dim,
|
||||||
|
Torch_BoolType:$keepdim
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$values,
|
||||||
|
AnyTorchOptionalTensorType:$indices
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenKthvalueOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 4, 2);
|
||||||
|
}
|
||||||
|
void AtenKthvalueOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 4, 2);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
let hasVerifier = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
|
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -254,6 +254,44 @@ static Value createTMTensorScanOp(
|
||||||
return scanOp->getResult(0);
|
return scanOp->getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static FailureOr<Value> createIntOrFloatCompareOp(PatternRewriter &rewriter,
|
||||||
|
Location loc,
|
||||||
|
Type elementType, Value lhs,
|
||||||
|
Value rhs, bool isDescending,
|
||||||
|
bool isEqual) {
|
||||||
|
|
||||||
|
Value compareOp;
|
||||||
|
if (auto intType = dyn_cast<mlir::IntegerType>(elementType)) {
|
||||||
|
// Case for using arith::CmpIOp.
|
||||||
|
arith::CmpIPredicate g =
|
||||||
|
isEqual ? arith::CmpIPredicate::sge : arith::CmpIPredicate::sgt;
|
||||||
|
arith::CmpIPredicate l =
|
||||||
|
isEqual ? arith::CmpIPredicate::sle : arith::CmpIPredicate::slt;
|
||||||
|
if (intType.isUnsignedInteger()) {
|
||||||
|
g = isEqual ? arith::CmpIPredicate::uge : arith::CmpIPredicate::ugt;
|
||||||
|
l = isEqual ? arith::CmpIPredicate::ule : arith::CmpIPredicate::ult;
|
||||||
|
}
|
||||||
|
arith::CmpIPredicate predicate = isDescending ? g : l;
|
||||||
|
compareOp = rewriter.create<arith::CmpIOp>(loc, predicate, lhs, rhs);
|
||||||
|
return compareOp;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isa<mlir::FloatType>(elementType)) {
|
||||||
|
// Case for using arith::CmpFOp.
|
||||||
|
arith::CmpFPredicate g =
|
||||||
|
isEqual ? arith::CmpFPredicate::OGE : arith::CmpFPredicate::OGT;
|
||||||
|
arith::CmpFPredicate l =
|
||||||
|
isEqual ? arith::CmpFPredicate::OLE : arith::CmpFPredicate::OLT;
|
||||||
|
|
||||||
|
arith::CmpFPredicate predicate = isDescending ? g : l;
|
||||||
|
compareOp = rewriter.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
|
||||||
|
return compareOp;
|
||||||
|
}
|
||||||
|
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
loc, "Only Integer and Floating element type expected.");
|
||||||
|
}
|
||||||
|
|
||||||
// Utility function to create a TMTensor::SortOp.
|
// Utility function to create a TMTensor::SortOp.
|
||||||
static FailureOr<SmallVector<Value>>
|
static FailureOr<SmallVector<Value>>
|
||||||
createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc,
|
createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc,
|
||||||
|
@ -280,34 +318,60 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 3. Create comparison op which will be used as the sorting predicate.
|
// Step 3. Create comparison op which will be used as the sorting predicate.
|
||||||
Value compareOp;
|
auto compareOpRetVal = createIntOrFloatCompareOp(
|
||||||
if (auto intType = dyn_cast<mlir::IntegerType>(elementTypes[0])) {
|
rewriter, loc, elementTypes[0], block->getArgument(0),
|
||||||
// Case for using arith::CmpIOp.
|
block->getArgument(1), isDescending, true);
|
||||||
arith::CmpIPredicate ge = arith::CmpIPredicate::sge;
|
|
||||||
arith::CmpIPredicate le = arith::CmpIPredicate::sle;
|
if (failed(compareOpRetVal))
|
||||||
if (intType.isUnsignedInteger()) {
|
|
||||||
ge = arith::CmpIPredicate::uge;
|
|
||||||
le = arith::CmpIPredicate::ule;
|
|
||||||
}
|
|
||||||
arith::CmpIPredicate predicate = isDescending ? ge : le;
|
|
||||||
compareOp = rewriter.create<arith::CmpIOp>(
|
|
||||||
loc, predicate, block->getArgument(0), block->getArgument(1));
|
|
||||||
} else if (isa<mlir::FloatType>(elementTypes[0])) {
|
|
||||||
// Case for using arith::CmpFOp.
|
|
||||||
arith::CmpFPredicate predicate =
|
|
||||||
isDescending ? arith::CmpFPredicate::OGE : arith::CmpFPredicate::OLE;
|
|
||||||
compareOp = rewriter.create<arith::CmpFOp>(
|
|
||||||
loc, predicate, block->getArgument(0), block->getArgument(1));
|
|
||||||
} else {
|
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
sortOpLoc, "Only Integer and Floating element type expected.");
|
loc, "Only Integer and Floating element type expected.");
|
||||||
}
|
|
||||||
|
|
||||||
// Step 4. Create yield op for yielding the sorting predicate.
|
// Step 4. Create yield op for yielding the sorting predicate.
|
||||||
rewriter.create<TMTensor::YieldOp>(loc, compareOp);
|
rewriter.create<TMTensor::YieldOp>(loc, compareOpRetVal.value());
|
||||||
return SmallVector<Value>(sortOp.getResults());
|
return SmallVector<Value>(sortOp.getResults());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static FailureOr<SmallVector<Value>> createTMTensorTopkOp(
|
||||||
|
PatternRewriter &rewriter, Location topkOpLoc, llvm::ArrayRef<Value> inputs,
|
||||||
|
llvm::ArrayRef<Value> outputs, llvm::ArrayRef<Type> elementTypes,
|
||||||
|
int64_t dimension, bool isMinK) {
|
||||||
|
|
||||||
|
// Generate output types.
|
||||||
|
SmallVector<Type> topkResultTypes;
|
||||||
|
for (Value val : outputs) {
|
||||||
|
topkResultTypes.push_back(val.getType());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create empty TopkOp, add body later.
|
||||||
|
auto topkOp = rewriter.create<TMTensor::TopkOp>(
|
||||||
|
topkOpLoc, topkResultTypes, inputs, outputs,
|
||||||
|
rewriter.getI64IntegerAttr(dimension));
|
||||||
|
|
||||||
|
Region *body = &topkOp.getRegion();
|
||||||
|
Block *block = rewriter.createBlock(body);
|
||||||
|
Location loc = body->getLoc();
|
||||||
|
// Add arguments for each passed body region element type.
|
||||||
|
for (Type elementType : elementTypes) {
|
||||||
|
block->addArgument({elementType}, {loc});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate compare operator. If minK is chosen, isDescending should be false.
|
||||||
|
// Is equal should be false, because we do not want equality to cause element
|
||||||
|
// swap.
|
||||||
|
auto compareOpRetVal = createIntOrFloatCompareOp(
|
||||||
|
rewriter, loc, elementTypes[0], block->getArgument(0),
|
||||||
|
block->getArgument(1), /*isDescending=*/!isMinK, /*isEqual=*/false);
|
||||||
|
|
||||||
|
// Check if correct element types are passed.
|
||||||
|
if (failed(compareOpRetVal))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
loc, "Only Integer and Floating element type expected.");
|
||||||
|
|
||||||
|
// Yield the comparison result.
|
||||||
|
rewriter.create<TMTensor::YieldOp>(loc, compareOpRetVal.value());
|
||||||
|
return SmallVector<Value>(topkOp.getResults());
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ConvertAtenScatterSrcOp : public OpConversionPattern<AtenScatterSrcOp> {
|
class ConvertAtenScatterSrcOp : public OpConversionPattern<AtenScatterSrcOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -1570,6 +1634,456 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ConvertAtenKthvalueOp : public OpConversionPattern<AtenKthvalueOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenKthvalueOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
const llvm::StringRef opName = op->getName().getStringRef();
|
||||||
|
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
auto typec = this->getTypeConverter();
|
||||||
|
|
||||||
|
Value input = adaptor.getSelf();
|
||||||
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
|
unsigned inputRank = inputType.getRank();
|
||||||
|
Type inputElementType = inputType.getElementType();
|
||||||
|
|
||||||
|
auto valResultType =
|
||||||
|
cast<RankedTensorType>(typec->convertType(op.getResult(0).getType()));
|
||||||
|
auto valResultElementType =
|
||||||
|
getElementTypeOrSelf(typec->convertType(valResultType));
|
||||||
|
|
||||||
|
auto idxResultType =
|
||||||
|
cast<RankedTensorType>(typec->convertType(op.getResult(1).getType()));
|
||||||
|
auto idxResultElementType =
|
||||||
|
getElementTypeOrSelf(typec->convertType(idxResultType));
|
||||||
|
|
||||||
|
// get keepdim and check it is bool
|
||||||
|
bool keepDim = false;
|
||||||
|
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, opName + " requires boolean value for keepdim");
|
||||||
|
|
||||||
|
// get dim, check it is constant int
|
||||||
|
int64_t dim;
|
||||||
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: only constant dim value is supported");
|
||||||
|
|
||||||
|
// turn dim into positive if negative, and check it is in the valid range
|
||||||
|
dim = toPositiveDim(dim, inputRank);
|
||||||
|
if (!isValidDim(dim, inputRank)) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||||
|
}
|
||||||
|
|
||||||
|
// get k, check it is a constant int
|
||||||
|
int64_t k;
|
||||||
|
if (!matchPattern(op.getK(), m_TorchConstantInt(&k)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: only constant k value is supported");
|
||||||
|
|
||||||
|
// check if element type is float, int, or unsigned
|
||||||
|
bool isUnsigned = false;
|
||||||
|
if (!isa<mlir::FloatType>(inputElementType)) {
|
||||||
|
if (!isa<mlir::IntegerType>(inputElementType)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, opName + " to linalg.* requires Float or Integer "
|
||||||
|
"input element type");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto integerTy = dyn_cast<mlir::IntegerType>(
|
||||||
|
cast<BaseTensorType>(op.getSelf().getType()).getDtype());
|
||||||
|
isUnsigned = integerTy.isUnsigned();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the values to fill initial output tensors for
|
||||||
|
// topk op and linalg generic op for finding max value.
|
||||||
|
Value fillValLinalgFindMax;
|
||||||
|
Value fillValTopK;
|
||||||
|
if (isa<mlir::FloatType>(inputElementType)) {
|
||||||
|
// max float for topk tensor
|
||||||
|
fillValTopK = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc,
|
||||||
|
rewriter.getFloatAttr(
|
||||||
|
inputElementType,
|
||||||
|
APFloat::getInf(
|
||||||
|
cast<mlir::FloatType>(inputElementType).getFloatSemantics(),
|
||||||
|
/*Negative=*/false)));
|
||||||
|
// min float for linalg generic op tensor
|
||||||
|
fillValLinalgFindMax = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc,
|
||||||
|
rewriter.getFloatAttr(
|
||||||
|
inputElementType,
|
||||||
|
APFloat::getInf(
|
||||||
|
cast<mlir::FloatType>(inputElementType).getFloatSemantics(),
|
||||||
|
/*Negative=*/true)));
|
||||||
|
} else if (!isUnsigned) {
|
||||||
|
auto width = cast<mlir::IntegerType>(inputElementType).getWidth();
|
||||||
|
// max signed int for topk op tensor
|
||||||
|
auto init = APSInt::getSignedMaxValue(width);
|
||||||
|
fillValTopK = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getIntegerAttr(inputElementType, init));
|
||||||
|
// min signed int for linalg generic op tensor
|
||||||
|
init = APSInt::getSignedMinValue(width);
|
||||||
|
fillValLinalgFindMax = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getIntegerAttr(inputElementType, init));
|
||||||
|
} else if (isUnsigned) {
|
||||||
|
auto width = cast<mlir::IntegerType>(inputElementType).getWidth();
|
||||||
|
// max unsigned int for topk op tensor
|
||||||
|
auto init = APInt::getMaxValue(width);
|
||||||
|
fillValTopK = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getIntegerAttr(inputElementType, init));
|
||||||
|
// min unsigned int for linalg generic op tensor
|
||||||
|
init = APInt::getMinValue(width);
|
||||||
|
fillValLinalgFindMax = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getIntegerAttr(inputElementType, init));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto i32Type = rewriter.getI32Type();
|
||||||
|
|
||||||
|
// ======== BEGIN: Topk op section ========
|
||||||
|
// Based on iree docs:
|
||||||
|
// https://iree.dev/reference/mlir-dialects/LinalgExt/#iree_linalg_extsort-linalgextsortop
|
||||||
|
|
||||||
|
// Create the output shape of topk op.
|
||||||
|
// For every dimension, topkShape[dimension] = inputShape[dimension],
|
||||||
|
// except topkShape[dim] = k.
|
||||||
|
SmallVector<Value> topkShape;
|
||||||
|
for (unsigned i = 0; i < inputRank; i++) {
|
||||||
|
auto currentDimSize = rewriter.create<tensor::DimOp>(loc, input, i);
|
||||||
|
topkShape.push_back(currentDimSize);
|
||||||
|
}
|
||||||
|
auto dimSize = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), k));
|
||||||
|
topkShape[dim] = dimSize;
|
||||||
|
|
||||||
|
// Fill the initial topk op output tensor.
|
||||||
|
Value topkOutputVal = createInitTensor(rewriter, loc, topkShape,
|
||||||
|
valResultElementType, fillValTopK);
|
||||||
|
|
||||||
|
// Create the initial value to fill the topk output indices tensor.
|
||||||
|
// It is equal to the max 32-bit signless integer.
|
||||||
|
auto signlessType = mlir::IntegerType::get(op.getContext(), 32,
|
||||||
|
mlir::IntegerType::Signless);
|
||||||
|
auto initIdx = getNumericLimit(rewriter, signlessType, /*getMin=*/false);
|
||||||
|
auto fillValTopkIdx = rewriter.create<arith::ConstantOp>(loc, initIdx);
|
||||||
|
// Fill the initial topk op output indices tensor.
|
||||||
|
Value topkOutputIdx =
|
||||||
|
createInitTensor(rewriter, loc, topkShape, i32Type, fillValTopkIdx);
|
||||||
|
|
||||||
|
// Input arguments for topk op contain only the input tensor.
|
||||||
|
// Input indices will be inferred based on input shape.
|
||||||
|
// (See docs link above).
|
||||||
|
SmallVector<Value> topkInputs;
|
||||||
|
topkInputs.push_back(input);
|
||||||
|
|
||||||
|
// Outputs contain both the values and the indices tensors.
|
||||||
|
SmallVector<Value> topkOutputs;
|
||||||
|
topkOutputs.push_back(topkOutputVal);
|
||||||
|
topkOutputs.push_back(topkOutputIdx);
|
||||||
|
|
||||||
|
// Element types of the arguments passed to the topk op region.
|
||||||
|
// The region accepts the next value N, and the current output
|
||||||
|
// candidate K (see docs link above).
|
||||||
|
// Both N and K are values from the input tensors, thus the
|
||||||
|
// element types are the same and are taken from inputType.
|
||||||
|
SmallVector<Type> topkElementTypes;
|
||||||
|
topkElementTypes.push_back(inputType.getElementType());
|
||||||
|
topkElementTypes.push_back(inputType.getElementType());
|
||||||
|
|
||||||
|
// Create the TMTensor TopkOp.
|
||||||
|
FailureOr<SmallVector<Value>> topkOp;
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
topkOp = createTMTensorTopkOp(rewriter, loc, topkInputs, topkOutputs,
|
||||||
|
topkElementTypes, dim, /*isMinK=*/true);
|
||||||
|
}
|
||||||
|
// Topk op creation fails with invalid element types.
|
||||||
|
if (failed(topkOp))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
loc, "Only Integer and Floating element type expected.");
|
||||||
|
|
||||||
|
auto topkOpVal = topkOp.value();
|
||||||
|
// ======== END: Topk op section ========
|
||||||
|
|
||||||
|
// ======== BEGIN: Linalg generic to find max in topk result ========
|
||||||
|
|
||||||
|
// Create result shape as both a vector of Value and of int64_t types.
|
||||||
|
// We assume that keepdim is false, and fix the result later if true.
|
||||||
|
// Result shape is equal to inputShape, with dim dimension removed.
|
||||||
|
SmallVector<Value> resultShape;
|
||||||
|
SmallVector<int64_t> resultShapeInt;
|
||||||
|
for (int64_t i = 0; i < inputType.getRank(); i++) {
|
||||||
|
if (dim != i) {
|
||||||
|
auto currentDimSize = rewriter.create<tensor::DimOp>(loc, input, i);
|
||||||
|
resultShape.push_back(currentDimSize);
|
||||||
|
resultShapeInt.push_back(inputType.getShape()[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fill the initial output tensor for linalg op for finding max value.
|
||||||
|
Value findMaxOutputVal = createInitTensor(
|
||||||
|
rewriter, loc, resultShape, inputElementType, fillValLinalgFindMax);
|
||||||
|
|
||||||
|
// Fill the initial output indices tensor for linalg op for finding max
|
||||||
|
// value with zeros.
|
||||||
|
Value findMaxOutputIdx =
|
||||||
|
createZeroInitTensor(rewriter, loc, resultShape, idxResultElementType);
|
||||||
|
|
||||||
|
// Reduce along dim.
|
||||||
|
SmallVector<utils::IteratorType> findMaxIteratorTypes(
|
||||||
|
inputType.getRank(), utils::IteratorType::parallel);
|
||||||
|
findMaxIteratorTypes[dim] = utils::IteratorType::reduction;
|
||||||
|
|
||||||
|
SmallVector<AffineExpr> findMaxMapExprs;
|
||||||
|
SmallVector<AffineExpr> findMaxMapResultExprs;
|
||||||
|
for (auto size :
|
||||||
|
llvm::enumerate(makeShapeTorchCompatible(inputType.getShape()))) {
|
||||||
|
findMaxMapExprs.push_back(rewriter.getAffineDimExpr(size.index()));
|
||||||
|
if (unsigned(dim) != size.index())
|
||||||
|
findMaxMapResultExprs.push_back(
|
||||||
|
rewriter.getAffineDimExpr(size.index()));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto findMaxMaps = AffineMap::inferFromExprList(
|
||||||
|
{findMaxMapExprs, findMaxMapResultExprs, findMaxMapResultExprs},
|
||||||
|
rewriter.getContext());
|
||||||
|
|
||||||
|
// Create linalg op for finding the max value in the extracted topk values.
|
||||||
|
auto findMaxLinalg = rewriter.create<linalg::GenericOp>(
|
||||||
|
loc,
|
||||||
|
ArrayRef<Type>(
|
||||||
|
{findMaxOutputVal.getType(), findMaxOutputIdx.getType()}),
|
||||||
|
topkOpVal.front(), ValueRange({findMaxOutputVal, findMaxOutputIdx}),
|
||||||
|
findMaxMaps, findMaxIteratorTypes,
|
||||||
|
[&](OpBuilder &nestedBuilder, Location nestedLoc,
|
||||||
|
ValueRange blockArgs) {
|
||||||
|
// Linalg generic body is the same as the decomposition for
|
||||||
|
// AtenMinDim: lib/Conversion/TorchToLinalg/Reduction.cpp
|
||||||
|
|
||||||
|
Value newValue = blockArgs[0];
|
||||||
|
Value oldValue = blockArgs[1];
|
||||||
|
Value oldIndex = blockArgs[2];
|
||||||
|
|
||||||
|
Value newIndex = rewriter.create<arith::IndexCastOp>(
|
||||||
|
nestedLoc, oldIndex.getType(),
|
||||||
|
rewriter.create<linalg::IndexOp>(nestedLoc, dim));
|
||||||
|
|
||||||
|
Value resultVal, predicate;
|
||||||
|
if (isa<mlir::FloatType>(inputElementType)) {
|
||||||
|
resultVal = rewriter.create<arith::MaximumFOp>(nestedLoc, newValue,
|
||||||
|
oldValue);
|
||||||
|
predicate = rewriter.create<arith::CmpFOp>(
|
||||||
|
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
|
||||||
|
} else {
|
||||||
|
arith::CmpIPredicate predType;
|
||||||
|
predType = isUnsigned ? arith::CmpIPredicate::ugt
|
||||||
|
: arith::CmpIPredicate::sgt;
|
||||||
|
if (isUnsigned) {
|
||||||
|
resultVal = rewriter.create<arith::MaxUIOp>(nestedLoc, newValue,
|
||||||
|
oldValue);
|
||||||
|
} else {
|
||||||
|
resultVal = rewriter.create<arith::MaxSIOp>(nestedLoc, newValue,
|
||||||
|
oldValue);
|
||||||
|
}
|
||||||
|
predicate = rewriter.create<arith::CmpIOp>(nestedLoc, predType,
|
||||||
|
newValue, oldValue);
|
||||||
|
}
|
||||||
|
auto resultIndex = rewriter.create<arith::SelectOp>(
|
||||||
|
nestedLoc, predicate, newIndex, oldIndex);
|
||||||
|
nestedBuilder.create<linalg::YieldOp>(
|
||||||
|
nestedLoc, ValueRange{resultVal, resultIndex});
|
||||||
|
});
|
||||||
|
|
||||||
|
auto findMaxVal = findMaxLinalg.getResult(0);
|
||||||
|
auto findMaxIdx = findMaxLinalg.getResult(1);
|
||||||
|
auto findMaxIdxType = cast<RankedTensorType>(findMaxIdx.getType());
|
||||||
|
|
||||||
|
// ======== END: Linalg generic to find max in topk result ========
|
||||||
|
|
||||||
|
// ======== BEGIN: Linalg generic for index extraction ========
|
||||||
|
// The linalg op for finding max returned idx of max elements in the
|
||||||
|
// tensor returned by the topk op. We need the idx of those elements
|
||||||
|
// in the original input. The topk op returned the idx of the top k
|
||||||
|
// extracted elements in the original input. Using the linalg idx
|
||||||
|
// results to index the topk idx results returns the idx of kth
|
||||||
|
// max value in the original input. Example:
|
||||||
|
// input = [1, 7, 3, 6, 2, 8, 9, 5], k = 4
|
||||||
|
// topk_val = [1, 3, 2, 5], topk_idx = [0, 2, 4, 7]
|
||||||
|
// linalg_max_val = [5], linalg_max_idx = [3] (5 is at idx 3 in topk_val)
|
||||||
|
// index the topk_idx using linalg_max_idx -> topk_idx[3] = 7
|
||||||
|
// kth_val = [5], kth_idx = [7]
|
||||||
|
|
||||||
|
// Create a tensor for the resulting idx.
|
||||||
|
Value filledTensorExtractedIdx = createZeroInitTensor(
|
||||||
|
rewriter, loc, getTensorSizes(rewriter, loc, findMaxIdx), i32Type);
|
||||||
|
|
||||||
|
// We iterate through the idx tensor returned by the linalg generic op for
|
||||||
|
// finding max.
|
||||||
|
SmallVector<utils::IteratorType> extractedIdxIteratorTypes(
|
||||||
|
findMaxIdxType.getRank(), utils::IteratorType::parallel);
|
||||||
|
|
||||||
|
SmallVector<AffineExpr> extractedIdxMapExprs;
|
||||||
|
for (auto size :
|
||||||
|
llvm::enumerate(makeShapeTorchCompatible(findMaxIdxType.getShape()))) {
|
||||||
|
extractedIdxMapExprs.push_back(rewriter.getAffineDimExpr(size.index()));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto extractedIdxMaps = AffineMap::inferFromExprList(
|
||||||
|
{extractedIdxMapExprs, extractedIdxMapExprs}, rewriter.getContext());
|
||||||
|
|
||||||
|
// Linalg generic op for indexing the topk output idx tensor using
|
||||||
|
// the idx tensor returned by the linalg generic op for finding max.
|
||||||
|
// Only the idx tensor from the linalg generic op is sent as input.
|
||||||
|
auto extractedIdxLinalg = rewriter.create<linalg::GenericOp>(
|
||||||
|
loc, ArrayRef<Type>({filledTensorExtractedIdx.getType()}), findMaxIdx,
|
||||||
|
filledTensorExtractedIdx, extractedIdxMaps, extractedIdxIteratorTypes,
|
||||||
|
[&](OpBuilder &nestedBuilder, Location nestedLoc,
|
||||||
|
ValueRange blockArgs) {
|
||||||
|
// Get the current input idx.
|
||||||
|
Value index = rewriter.create<arith::IndexCastOp>(
|
||||||
|
loc, rewriter.getIndexType(), blockArgs[0]);
|
||||||
|
|
||||||
|
// Create idx to index the topk idx tensor.
|
||||||
|
// Index the dim dimension using the current input idx.
|
||||||
|
SmallVector<Value> indexTarget;
|
||||||
|
for (unsigned i = 0; i < dim; i++)
|
||||||
|
indexTarget.push_back(rewriter.create<linalg::IndexOp>(loc, i));
|
||||||
|
indexTarget.push_back(index);
|
||||||
|
for (unsigned i = dim; i < findMaxIdxType.getRank(); i++)
|
||||||
|
indexTarget.push_back(rewriter.create<linalg::IndexOp>(loc, i));
|
||||||
|
|
||||||
|
// Extract the element from the topk idx tensor.
|
||||||
|
Value extractedElement = rewriter.create<tensor::ExtractOp>(
|
||||||
|
loc, topkOpVal.back(), indexTarget);
|
||||||
|
rewriter.create<linalg::YieldOp>(loc, extractedElement);
|
||||||
|
});
|
||||||
|
|
||||||
|
auto extractedIdx = extractedIdxLinalg.getResult(0);
|
||||||
|
auto extractedIdxType = cast<RankedTensorType>(extractedIdx.getType());
|
||||||
|
|
||||||
|
// ======== END: Linalg generic for index extraction ========
|
||||||
|
|
||||||
|
// ======== BEGIN: Linalg generic for topk idx cast ========
|
||||||
|
// Casts from i32 to idx result type of the Kthvalue op.
|
||||||
|
|
||||||
|
// Create the initial tensor for the cast result.
|
||||||
|
Value filledTensorCastedIdx = createZeroInitTensor(
|
||||||
|
rewriter, loc, getTensorSizes(rewriter, loc, extractedIdx),
|
||||||
|
idxResultElementType);
|
||||||
|
|
||||||
|
SmallVector<utils::IteratorType> castedIdxIteratorTypes(
|
||||||
|
extractedIdxType.getRank(), utils::IteratorType::parallel);
|
||||||
|
|
||||||
|
SmallVector<AffineExpr> castedIdxMapExprs;
|
||||||
|
for (auto size : llvm::enumerate(
|
||||||
|
makeShapeTorchCompatible(extractedIdxType.getShape()))) {
|
||||||
|
castedIdxMapExprs.push_back(rewriter.getAffineDimExpr(size.index()));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto castedIdxMaps = AffineMap::inferFromExprList(
|
||||||
|
{castedIdxMapExprs, castedIdxMapExprs}, rewriter.getContext());
|
||||||
|
|
||||||
|
// Linalg generic op for casting topk idx output tensor elements from i32 to
|
||||||
|
// result idx tensor element type.
|
||||||
|
auto castedIdxLinalg = rewriter.create<linalg::GenericOp>(
|
||||||
|
loc, ArrayRef<Type>({filledTensorCastedIdx.getType()}), extractedIdx,
|
||||||
|
filledTensorCastedIdx, castedIdxMaps, castedIdxIteratorTypes,
|
||||||
|
[&](OpBuilder &nestedBuilder, Location nestedLoc,
|
||||||
|
ValueRange blockArgs) {
|
||||||
|
Value oldIdx = blockArgs[0];
|
||||||
|
|
||||||
|
// Cast from i32 to index.
|
||||||
|
Value oldIdxToIndexType = rewriter.create<arith::IndexCastOp>(
|
||||||
|
nestedLoc, rewriter.getIndexType(), oldIdx);
|
||||||
|
// Cast from index to result idx element type.
|
||||||
|
Value resultIdx = rewriter.create<arith::IndexCastOp>(
|
||||||
|
nestedLoc, idxResultElementType, oldIdxToIndexType);
|
||||||
|
|
||||||
|
nestedBuilder.create<linalg::YieldOp>(nestedLoc, resultIdx);
|
||||||
|
});
|
||||||
|
|
||||||
|
auto castedIdx = castedIdxLinalg.getResult(0);
|
||||||
|
|
||||||
|
// ======== END: Linalg generic for topk idx cast ========
|
||||||
|
|
||||||
|
// Create output value type ("squeezed" since we assume keepdim=False).
|
||||||
|
auto topkValResultType =
|
||||||
|
cast<RankedTensorType>(topkOpVal.front().getType());
|
||||||
|
auto squeezedValType = topkValResultType.cloneWith(
|
||||||
|
resultShapeInt,
|
||||||
|
cast<RankedTensorType>(findMaxVal.getType()).getElementType());
|
||||||
|
|
||||||
|
// Create output idx type ("squeezed" since we assume keepdim=False).
|
||||||
|
auto castedIdxType = cast<RankedTensorType>(castedIdx.getType());
|
||||||
|
auto squeezedIdxType = castedIdxType.cloneWith(
|
||||||
|
resultShapeInt, findMaxIdxType.getElementType());
|
||||||
|
|
||||||
|
if (!keepDim) {
|
||||||
|
// If keepdim=false, cast the the outputs to appropriate type and return.
|
||||||
|
Value retVal =
|
||||||
|
rewriter.create<tensor::CastOp>(loc, squeezedValType, findMaxVal);
|
||||||
|
Value retIdx =
|
||||||
|
rewriter.create<tensor::CastOp>(loc, squeezedIdxType, castedIdx);
|
||||||
|
llvm::SmallVector<Value> res{retVal, retIdx};
|
||||||
|
rewriter.replaceOp(op, res);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If keepdim is false, unsqueeze.
|
||||||
|
// Unsqueezing implementation taken from AteMinMaxDimOp lowering:
|
||||||
|
// lib/Conversion/TorchToLinalg/Reduction.cpp
|
||||||
|
llvm::SmallVector<int64_t> valShape(valResultType.getShape());
|
||||||
|
llvm::SmallVector<int64_t> idxShape(idxResultType.getShape());
|
||||||
|
for (int i = dim, s = valShape.size() - 1; i < s; ++i) {
|
||||||
|
valShape[i] = valShape[i + 1];
|
||||||
|
idxShape[i] = idxShape[i + 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
valShape.resize(valShape.size() - 1);
|
||||||
|
idxShape.resize(idxShape.size() - 1);
|
||||||
|
|
||||||
|
Value retVal = rewriter.create<tensor::CastOp>(
|
||||||
|
loc, squeezedValType.clone(valShape), findMaxLinalg.getResult(0));
|
||||||
|
Value retIdx = rewriter.create<tensor::CastOp>(
|
||||||
|
loc, squeezedIdxType.clone(idxShape), castedIdx);
|
||||||
|
|
||||||
|
SmallVector<ReassociationIndices> reassociation(valShape.size());
|
||||||
|
if (reassociation.size() > 0) {
|
||||||
|
for (int i = 0; i < dim; ++i)
|
||||||
|
reassociation[i].push_back(i);
|
||||||
|
reassociation[std::max<int64_t>(0, dim - 1)].push_back(dim);
|
||||||
|
for (int i = dim, s = reassociation.size(); i < s; ++i)
|
||||||
|
reassociation[i].push_back(i + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
valShape.push_back(0);
|
||||||
|
idxShape.push_back(0);
|
||||||
|
for (int i = dim, s = valShape.size() - 1; i < s; ++i) {
|
||||||
|
valShape[i + 1] = valShape[i];
|
||||||
|
idxShape[i + 1] = idxShape[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
valShape[dim] = 1;
|
||||||
|
idxShape[dim] = 1;
|
||||||
|
|
||||||
|
Value unsqueezeVal = rewriter.create<tensor::ExpandShapeOp>(
|
||||||
|
loc, valResultType, retVal, reassociation);
|
||||||
|
|
||||||
|
Value unsqueezeIdx = rewriter.create<tensor::ExpandShapeOp>(
|
||||||
|
loc, idxResultType, retIdx, reassociation);
|
||||||
|
|
||||||
|
// Return unsqueezed.
|
||||||
|
llvm::SmallVector<Value> unsqueezes = {unsqueezeVal, unsqueezeIdx};
|
||||||
|
rewriter.replaceOp(op, unsqueezes);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// The pass
|
// The pass
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
@ -1619,6 +2133,8 @@ public:
|
||||||
|
|
||||||
target.addIllegalOp<AtenScatterSrcOp>();
|
target.addIllegalOp<AtenScatterSrcOp>();
|
||||||
patterns.add<ConvertAtenScatterSrcOp>(typeConverter, context);
|
patterns.add<ConvertAtenScatterSrcOp>(typeConverter, context);
|
||||||
|
target.addIllegalOp<AtenKthvalueOp>();
|
||||||
|
patterns.add<ConvertAtenKthvalueOp>(typeConverter, context);
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns))))
|
std::move(patterns))))
|
||||||
|
|
|
@ -910,6 +910,213 @@ bool SortOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TopkOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult TopkOp::verify() {
|
||||||
|
Operation *op = getOperation();
|
||||||
|
if (getNumInputs() != 1 && getNumInputs() != 2) {
|
||||||
|
return op->emitOpError("expected one or two input operands");
|
||||||
|
}
|
||||||
|
if (getNumOutputs() != 2) {
|
||||||
|
return op->emitOpError("expected two output operands");
|
||||||
|
}
|
||||||
|
// First check added to eliminate comparison of different int types
|
||||||
|
if (getInputRank() < 0 ||
|
||||||
|
(getDimension() >= static_cast<uint64_t>(getInputRank()))) {
|
||||||
|
return op->emitOpError("dimension exceeds rank");
|
||||||
|
}
|
||||||
|
// Ensure input/output element types match
|
||||||
|
auto inputValuesType = cast<ShapedType>(values().getType());
|
||||||
|
auto outputValuesType = cast<ShapedType>(outputValues().getType());
|
||||||
|
if (inputValuesType.getElementType() != outputValuesType.getElementType()) {
|
||||||
|
return op->emitOpError("expected input/output value types to be identical");
|
||||||
|
}
|
||||||
|
// Indices must be int if provided
|
||||||
|
auto outputIndicesType = cast<ShapedType>(outputIndices().getType());
|
||||||
|
if (auto inputIndices = indices()) {
|
||||||
|
auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
|
||||||
|
if (!inputIndicesType.getElementType().isInteger(32) ||
|
||||||
|
!outputIndicesType.getElementType().isInteger(32)) {
|
||||||
|
return op->emitOpError("expected input/output indices types to be int32");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ranks must match
|
||||||
|
if (inputValuesType.getRank() != outputValuesType.getRank()) {
|
||||||
|
return op->emitOpError("expected input/output to have the same rank");
|
||||||
|
}
|
||||||
|
if (auto inputIndices = indices()) {
|
||||||
|
auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
|
||||||
|
if (inputIndicesType.getRank() != outputIndicesType.getRank()) {
|
||||||
|
return op->emitOpError("expected input/output to have the same rank");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Input indicies and values must have the same shape.
|
||||||
|
if (auto inputIndices = indices()) {
|
||||||
|
auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
|
||||||
|
if (failed(verifyCompatibleShape(inputValuesType, inputIndicesType)))
|
||||||
|
return op->emitOpError("input indices/values shape must match");
|
||||||
|
}
|
||||||
|
// Output indicies and values must have the same shape.
|
||||||
|
if (failed(verifyCompatibleShape(outputValuesType, outputIndicesType)))
|
||||||
|
return op->emitOpError("output indices/values shape must match");
|
||||||
|
// Input shape must match the output shape except for the dimension()
|
||||||
|
uint64_t dim = getDimension();
|
||||||
|
if (!llvm::all_of(llvm::enumerate(llvm::zip(inputValuesType.getShape(),
|
||||||
|
outputValuesType.getShape())),
|
||||||
|
[dim](auto e) {
|
||||||
|
if (e.index() == dim) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
std::tuple<int64_t, int64_t> s = e.value();
|
||||||
|
return succeeded(verifyCompatibleShape(std::get<0>(s),
|
||||||
|
|
||||||
|
std::get<1>(s)));
|
||||||
|
})) {
|
||||||
|
return op->emitOpError("incompatible input/output shapes");
|
||||||
|
}
|
||||||
|
// Check region compatibility
|
||||||
|
Block &block = getRegion().front();
|
||||||
|
if (block.getNumArguments() != 2) {
|
||||||
|
return op->emitOpError("region block should have 2 arguments");
|
||||||
|
}
|
||||||
|
if (block.getArgument(0).getType() != inputValuesType.getElementType() ||
|
||||||
|
block.getArgument(1).getType() != inputValuesType.getElementType()) {
|
||||||
|
return op->emitOpError("region block types must match input");
|
||||||
|
}
|
||||||
|
auto terminatorOp = llvm::dyn_cast<YieldOp>(block.getTerminator());
|
||||||
|
if (!terminatorOp || !terminatorOp.getOperand(0).getType().isInteger(1)) {
|
||||||
|
return op->emitOpError("region block must end with a linalg_ext.yield i1!");
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<utils::IteratorType> TopkOp::getLoopIteratorTypes() {
|
||||||
|
SmallVector<utils::IteratorType> iteratorTypes(getInputRank(),
|
||||||
|
utils::IteratorType::parallel);
|
||||||
|
iteratorTypes[getDimension()] = utils::IteratorType::reduction;
|
||||||
|
return iteratorTypes;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Range> TopkOp::getIterationDomain(OpBuilder &builder) {
|
||||||
|
int64_t operandRank = getInputRank();
|
||||||
|
SmallVector<Range> loopBounds(operandRank);
|
||||||
|
Location loc = getLoc();
|
||||||
|
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
|
||||||
|
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
|
||||||
|
Value source = values();
|
||||||
|
for (auto dim : llvm::enumerate(getInputType().getShape())) {
|
||||||
|
loopBounds[dim.index()].offset = zero;
|
||||||
|
loopBounds[dim.index()].size =
|
||||||
|
getDimValue(builder, loc, source, dim.index());
|
||||||
|
loopBounds[dim.index()].stride = one;
|
||||||
|
}
|
||||||
|
return loopBounds;
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult TopkOp::generateScalarImplementation(OpBuilder &b, Location loc,
|
||||||
|
ValueRange ivs) {
|
||||||
|
uint64_t kDim = getDimension();
|
||||||
|
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
|
||||||
|
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
|
||||||
|
Value initialValue = b.create<memref::LoadOp>(loc, values(), ivs);
|
||||||
|
|
||||||
|
// If the indices tensor is not provided, the value index is derived from the
|
||||||
|
// loop induction variables.
|
||||||
|
Value initialIndex;
|
||||||
|
if (indices()) {
|
||||||
|
initialIndex = b.create<memref::LoadOp>(loc, *indices(), ivs);
|
||||||
|
} else {
|
||||||
|
Value rawInitialIndex = ivs[kDim];
|
||||||
|
initialIndex =
|
||||||
|
b.create<arith::IndexCastOp>(loc, b.getI32Type(), rawInitialIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute K (ub) from the selected dim of the output
|
||||||
|
Value ub = b.create<memref::DimOp>(loc, outputValues(), getDimension());
|
||||||
|
|
||||||
|
// Inner K loop functions:
|
||||||
|
// Load current K value and index
|
||||||
|
// Compare N/K using inserted block compare
|
||||||
|
// Check if N == K using strict weak ordering, select which index came first
|
||||||
|
// Select new K value from N/K comparison
|
||||||
|
// Select new K index from N/K comparison or which index came first
|
||||||
|
// Store new k value and index
|
||||||
|
// Yield loop carry values after K selection
|
||||||
|
Value kValue, kIndex;
|
||||||
|
auto scfFor = b.create<scf::ForOp>(
|
||||||
|
loc, zero, ub, one, ValueRange{initialValue, initialIndex},
|
||||||
|
[&](OpBuilder &b, Location loc, Value iv, ValueRange loopCarryValues) {
|
||||||
|
SmallVector<Value> indices(ivs);
|
||||||
|
indices[kDim] = iv;
|
||||||
|
kValue = b.create<memref::LoadOp>(loc, outputValues(), indices);
|
||||||
|
kIndex = b.create<memref::LoadOp>(loc, outputIndices(), indices);
|
||||||
|
});
|
||||||
|
|
||||||
|
SmallVector<Value> indices(ivs);
|
||||||
|
indices[kDim] = scfFor.getInductionVar();
|
||||||
|
auto loopCarryValues = scfFor.getRegionIterArgs();
|
||||||
|
|
||||||
|
// Retrieve region as black box comparision function f(x,y). Plug into op.
|
||||||
|
auto &srcBlock = getRegion().front();
|
||||||
|
IRMapping bvmF; // f(x,y)
|
||||||
|
IRMapping bvmR; // f(y,x)
|
||||||
|
{
|
||||||
|
// Save previous insertion point. Continue within loop body.
|
||||||
|
OpBuilder::InsertionGuard guard(b);
|
||||||
|
b.setInsertionPointToEnd(&scfFor.getRegion().front());
|
||||||
|
SmallVector<Value> forwardValues{loopCarryValues[0], kValue};
|
||||||
|
SmallVector<Value> reverseValues{kValue, loopCarryValues[0]};
|
||||||
|
for (auto it : llvm::zip(srcBlock.getArguments(), forwardValues)) {
|
||||||
|
bvmF.map(std::get<0>(it), std::get<1>(it));
|
||||||
|
}
|
||||||
|
for (auto it : llvm::zip(srcBlock.getArguments(), reverseValues)) {
|
||||||
|
bvmR.map(std::get<0>(it), std::get<1>(it));
|
||||||
|
}
|
||||||
|
for (auto &blockOp : srcBlock.without_terminator()) {
|
||||||
|
b.clone(blockOp, bvmF);
|
||||||
|
b.clone(blockOp, bvmR);
|
||||||
|
}
|
||||||
|
Value forwardCmpRes = bvmF.lookup(srcBlock.getTerminator()->getOperand(0));
|
||||||
|
Value reverseCmpRes = bvmR.lookup(srcBlock.getTerminator()->getOperand(0));
|
||||||
|
|
||||||
|
// Check value equality using strictly weak ordering from the region:
|
||||||
|
// f(x,y) --> forwardCmpRes
|
||||||
|
// f(y,x) --> reverseCmpRes
|
||||||
|
// if forwardCmpRes == reverseCmpRes then select which came first
|
||||||
|
Value cmpValuesEqual = b.create<arith::CmpIOp>(
|
||||||
|
loc, arith::CmpIPredicate::eq, forwardCmpRes, reverseCmpRes);
|
||||||
|
Value cmpFirstIndex = b.create<arith::CmpIOp>(
|
||||||
|
loc, arith::CmpIPredicate::slt, loopCarryValues[1], kIndex);
|
||||||
|
Value combinedCmpEqRes =
|
||||||
|
b.create<arith::AndIOp>(loc, cmpValuesEqual, cmpFirstIndex);
|
||||||
|
// True if N > K or N came before K
|
||||||
|
Value indexCmpRes =
|
||||||
|
b.create<arith::OrIOp>(loc, forwardCmpRes, combinedCmpEqRes);
|
||||||
|
// Select results for K based on comparisons
|
||||||
|
Value resultKValue = b.create<arith::SelectOp>(loc, forwardCmpRes,
|
||||||
|
loopCarryValues[0], kValue);
|
||||||
|
Value resultKIndex =
|
||||||
|
b.create<arith::SelectOp>(loc, indexCmpRes, loopCarryValues[1], kIndex);
|
||||||
|
b.create<memref::StoreOp>(loc, resultKValue, outputValues(), indices);
|
||||||
|
b.create<memref::StoreOp>(loc, resultKIndex, outputIndices(), indices);
|
||||||
|
// Select loop carry, opposite of K results
|
||||||
|
Value resultCarryValue = b.create<arith::SelectOp>(
|
||||||
|
loc, forwardCmpRes, kValue, loopCarryValues[0]);
|
||||||
|
Value resultCarryIndex =
|
||||||
|
b.create<arith::SelectOp>(loc, indexCmpRes, kIndex, loopCarryValues[1]);
|
||||||
|
b.create<scf::YieldOp>(loc, ValueRange{resultCarryValue, resultCarryIndex});
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TopkOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
|
||||||
|
// Set to true so that output operands are always initialized.
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
|
#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
|
||||||
void OP_NAME::getEffects( \
|
void OP_NAME::getEffects( \
|
||||||
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
|
||||||
|
@ -924,6 +1131,7 @@ DEFINE_OP_GET_EFFECTS(AttentionOp)
|
||||||
DEFINE_OP_GET_EFFECTS(ScanOp)
|
DEFINE_OP_GET_EFFECTS(ScanOp)
|
||||||
DEFINE_OP_GET_EFFECTS(ScatterOp)
|
DEFINE_OP_GET_EFFECTS(ScatterOp)
|
||||||
DEFINE_OP_GET_EFFECTS(SortOp)
|
DEFINE_OP_GET_EFFECTS(SortOp)
|
||||||
|
DEFINE_OP_GET_EFFECTS(TopkOp)
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
/// This is derived from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp without any
|
/// This is derived from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp without any
|
||||||
|
|
|
@ -4877,6 +4877,42 @@ LogicalResult AtenLinalgCrossOp::verify() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult AtenKthvalueOp::verify() {
|
||||||
|
|
||||||
|
auto selfType = cast<BaseTensorType>(getSelf().getType());
|
||||||
|
|
||||||
|
if (!selfType.hasDtype() || !selfType.hasSizes())
|
||||||
|
return success();
|
||||||
|
|
||||||
|
Type selfDtype = selfType.getDtype();
|
||||||
|
if (selfDtype.isSignlessInteger(1))
|
||||||
|
return emitOpError("input tensors must not have bool dtype");
|
||||||
|
|
||||||
|
int64_t dim;
|
||||||
|
if (!matchPattern(getDim(), m_TorchConstantInt(&dim)))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
ArrayRef<int64_t> selfShape = selfType.getSizes();
|
||||||
|
int64_t selfRank = selfShape.size();
|
||||||
|
|
||||||
|
dim = toPositiveDim(dim, selfRank);
|
||||||
|
if (!isValidDim(dim, selfRank))
|
||||||
|
return emitOpError("dim expected to be in range of [")
|
||||||
|
<< -selfRank << ", " << selfRank - 1 << "], but got " << dim;
|
||||||
|
|
||||||
|
// convert k to an integer type
|
||||||
|
int64_t k;
|
||||||
|
if (!matchPattern(getK(), m_TorchConstantInt(&k)))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
// check if k is in the correct range
|
||||||
|
if (selfShape[dim] != kUnknownSize && (k < 1 || k > selfShape[dim]))
|
||||||
|
return emitOpError("k expected to be in range of [")
|
||||||
|
<< 1 << ", " << selfShape[dim] << "], but got " << k;
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// DtypeCalculateYieldDtypesOp
|
// DtypeCalculateYieldDtypesOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -6962,6 +6962,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %4 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
" %4 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %4 : !torch.list<int>\n"
|
" return %4 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.kthvalue\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||||
|
" %0 = torch.derefine %arg2 : !torch.int to !torch.optional<int>\n"
|
||||||
|
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg3) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
|
||||||
|
" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||||
|
" return %2 : !torch.tuple<list<int>, list<int>>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten._log_softmax_backward_data\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten._log_softmax_backward_data\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
@ -10897,6 +10903,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||||
" return %6 : !torch.int\n"
|
" return %6 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.kthvalue\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.tuple<int, int> {\n"
|
||||||
|
" %int4 = torch.constant.int 4\n"
|
||||||
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||||
|
" return %1 : !torch.tuple<int, int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
|
||||||
" return %arg3 : !torch.int\n"
|
" return %arg3 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
|
|
@ -1618,6 +1618,7 @@ public:
|
||||||
auto idxTy = rewriter.getType<Torch::ValueTensorType>(
|
auto idxTy = rewriter.getType<Torch::ValueTensorType>(
|
||||||
reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true));
|
reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true));
|
||||||
llvm::SmallVector<Type, 2> types{reductionTy, idxTy};
|
llvm::SmallVector<Type, 2> types{reductionTy, idxTy};
|
||||||
|
|
||||||
reduction = rewriter
|
reduction = rewriter
|
||||||
.create<Torch::AtenMinDimOp>(loc, types, reduction,
|
.create<Torch::AtenMinDimOp>(loc, types, reduction,
|
||||||
dimValue, op.getKeepdim())
|
dimValue, op.getKeepdim())
|
||||||
|
|
|
@ -2274,6 +2274,11 @@ ONNX_XFAIL_SET = {
|
||||||
"AtenIntTensorCharDtypeModule_basic",
|
"AtenIntTensorCharDtypeModule_basic",
|
||||||
"AtenItemFpOpModule_basic",
|
"AtenItemFpOpModule_basic",
|
||||||
"AtenItemIntOpModule_basic",
|
"AtenItemIntOpModule_basic",
|
||||||
|
"AtenKthvalueModule_basic",
|
||||||
|
"AtenKthvalueKeepDimModule_basic",
|
||||||
|
"AtenKthvalueDynamicDimsModule_basic",
|
||||||
|
"AtenKthvalueFloat64Module_basic",
|
||||||
|
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
||||||
"AtenLinalgCrossDynamic_basic",
|
"AtenLinalgCrossDynamic_basic",
|
||||||
"AtenMatmulQMixedSigni8Transpose_basic",
|
"AtenMatmulQMixedSigni8Transpose_basic",
|
||||||
"AtenMatmulQMixedSigni8_basic",
|
"AtenMatmulQMixedSigni8_basic",
|
||||||
|
|
|
@ -468,6 +468,14 @@ def aten〇linalg_cross〡shape(self: List[int], other: List[int], dim: int = -1
|
||||||
assert (self[i] == other[i]) or self[i] == 1 or other[i] == 1, f"the size of first tensor ({self[i]}) must match the size of second tensor ({other[i]}) at dimension {i}"
|
assert (self[i] == other[i]) or self[i] == 1 or other[i] == 1, f"the size of first tensor ({self[i]}) must match the size of second tensor ({other[i]}) at dimension {i}"
|
||||||
return upstream_shape_functions.broadcast(self, other)
|
return upstream_shape_functions.broadcast(self, other)
|
||||||
|
|
||||||
|
@check_shape_function([
|
||||||
|
Invocation(TensorOfShape(2, 4, 3, device="cpu"), k=2, dim=1, keepdim=True), # keep dim,
|
||||||
|
Invocation(TensorOfShape(2, 4, 3, device="cpu"), k=2, dim=1, keepdim=False), # don't keep dim
|
||||||
|
])
|
||||||
|
def aten〇kthvalue〡shape(self: List[int], k: int, dim: int = -1, keepdim: bool = False) -> Tuple[List[int], List[int]]:
|
||||||
|
new_shape = upstream_shape_functions.argmax(self, dim, keepdim)
|
||||||
|
return (new_shape, new_shape)
|
||||||
|
|
||||||
def aten〇_log_softmax_backward_data〡shape(grad_output: List[int], output: List[int], dim: int, input_dtype: int) -> List[int]:
|
def aten〇_log_softmax_backward_data〡shape(grad_output: List[int], output: List[int], dim: int, input_dtype: int) -> List[int]:
|
||||||
return upstream_shape_functions.unary(grad_output)
|
return upstream_shape_functions.unary(grad_output)
|
||||||
|
|
||||||
|
@ -2705,6 +2713,13 @@ def aten〇linalg_cross〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dty
|
||||||
dtypes = [self_dtype, other_dtype]
|
dtypes = [self_dtype, other_dtype]
|
||||||
return promote_dtypes(ranks, dtypes)
|
return promote_dtypes(ranks, dtypes)
|
||||||
|
|
||||||
|
@check_dtype_function([
|
||||||
|
Invocation(TensorOfShape(2, 4, 3, dtype=torch.int32, device="cpu"), k=2, dim=-1, keepdim=False)
|
||||||
|
])
|
||||||
|
def aten〇kthvalue〡dtype(self_rank_dtype: Tuple[int, int], k: int, dim: int = -1, keepdim: bool = False) -> Tuple[int, int]:
|
||||||
|
_, self_dtype = self_rank_dtype
|
||||||
|
return (self_dtype, torch.int64)
|
||||||
|
|
||||||
@check_dtype_function(
|
@check_dtype_function(
|
||||||
_check_two_tensor_op(dim=0, input_dtype=torch.float32) +
|
_check_two_tensor_op(dim=0, input_dtype=torch.float32) +
|
||||||
_check_two_tensor_op(dim=0, input_dtype=torch.float64))
|
_check_two_tensor_op(dim=0, input_dtype=torch.float64))
|
||||||
|
|
|
@ -912,6 +912,10 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
)
|
)
|
||||||
emit("aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)", has_verifier=True)
|
emit("aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)", has_verifier=True)
|
||||||
emit("aten::col2im : (Tensor, int[], int[], int[], int[], int[]) -> (Tensor)")
|
emit("aten::col2im : (Tensor, int[], int[], int[], int[], int[]) -> (Tensor)")
|
||||||
|
emit(
|
||||||
|
"aten::kthvalue : (Tensor, int, int, bool) -> (Tensor, Tensor)",
|
||||||
|
has_verifier=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Functionalization ops
|
# Functionalization ops
|
||||||
emit("aten::alias_copy : (Tensor) -> (Tensor)")
|
emit("aten::alias_copy : (Tensor) -> (Tensor)")
|
||||||
|
|
|
@ -5547,3 +5547,102 @@ class CloneModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: CloneModule())
|
@register_test_case(module_factory=lambda: CloneModule())
|
||||||
def CloneModule_basic(module, tu: TestUtils):
|
def CloneModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(5, 5))
|
module.forward(tu.rand(5, 5))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class AtenKthvalueModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([None, ([2, 6, 3], torch.int32, True)])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.kthvalue(x, k=4, dim=1, keepdim=False)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenKthvalueModule())
|
||||||
|
def AtenKthvalueModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randperm(2 * 6 * 3, dtype=torch.int32).reshape(2, 6, 3))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class AtenKthvalueKeepDimModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([None, ([2, 6, 3], torch.int32, True)])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.kthvalue(x, k=4, dim=1, keepdim=True)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenKthvalueKeepDimModule())
|
||||||
|
def AtenKthvalueKeepDimModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randperm(2 * 6 * 3, dtype=torch.int32).reshape(2, 6, 3))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class AtenKthvalueDynamicDimsModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([None, ([-1, -1, -1, -1], torch.int32, True)])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.kthvalue(x, k=6, dim=2, keepdim=True)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenKthvalueDynamicDimsModule())
|
||||||
|
def AtenKthvalueDynamicDimsModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randperm(4 * 2 * 8 * 3, dtype=torch.int32).reshape(4, 2, 8, 3))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class AtenKthvalueFloat64Module(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([None, ([4, 2, 8, 3], torch.float64, True)])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.kthvalue(x, k=3, dim=0, keepdim=True)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenKthvalueFloat64Module())
|
||||||
|
def AtenKthvalueFloat64Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
torch.randperm(4 * 2 * 8 * 3, dtype=torch.float64).reshape(4, 2, 8, 3)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class AtenKthvalueFloat64DynamicDimsModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([None, ([-1, -1, -1, -1], torch.float64, True)])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.kthvalue(x, k=3, dim=3, keepdim=True)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenKthvalueFloat64DynamicDimsModule())
|
||||||
|
def AtenKthvalueFloat64DynamicDimsModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
torch.randperm(4 * 2 * 8 * 3, dtype=torch.float64).reshape(4, 2, 8, 3)
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue