mirror of https://github.com/llvm/torch-mlir
Implement lowering of torch.aten.kthvalue (#3360)
Closes [nod-ai/SHARK-Turbine#620](https://github.com/nod-ai/SHARK-Turbine/issues/620)pull/3461/merge
parent
51902ec2dc
commit
4555629246
|
@ -326,6 +326,82 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
|
|||
}];
|
||||
}
|
||||
|
||||
def TMTensor_TopkOp : TMTensor_Op<"topk",
|
||||
[DeclareOpInterfaceMethods<TMTensorInterface,
|
||||
["payloadUsesValueFromOperand"]>,
|
||||
DeclareOpInterfaceMethods<ScalarLoopOpInterface,
|
||||
["generateScalarImplementation"]>]> {
|
||||
let summary = "Top-K operator";
|
||||
let description = [{
|
||||
A Top-K operation for N-D tensors. Reduces the target dimension from the input
|
||||
size N down to K elements based on the supplied binary region.
|
||||
|
||||
Accepts an N-D tensor input consisting of values and an optioanl N-D tensor
|
||||
for indices of those values (i32 type). If input indices aren't provided, the
|
||||
index mapping is inferred based on the k dim. Both input values/indices
|
||||
tensors and output values/indicies tensors must have the same shape. Top-K is
|
||||
computed along the target dimension (from dimension()). Returns two output
|
||||
tensors of values and the indicies of Top-K results. The output dimensions
|
||||
must match the input save for the dimension that is reduced to K results.
|
||||
|
||||
Region accepts lhs=[next N input] and rhs=[exiting K output] and yeilds an
|
||||
i1. If true, the two values are swapped:
|
||||
- For Top-K compoarision: >
|
||||
- For Min-K comparision: <
|
||||
Note: when the two values are equal, the first occurence is always selected.
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyShaped>:$inputs,
|
||||
Variadic<AnyShaped>:$outputs,
|
||||
I64Attr:$dimension
|
||||
);
|
||||
|
||||
let results = (outs Variadic<AnyRankedTensor>:$results);
|
||||
let regions = (region AnyRegion:$region);
|
||||
let assemblyFormat = [{
|
||||
attr-dict
|
||||
`dimension` `(` $dimension `)`
|
||||
`ins` `(` $inputs `:` type($inputs) `)`
|
||||
`outs` `(` $outputs `:` type($outputs) `)`
|
||||
$region (`->` type($results)^)?
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{
|
||||
Value values() {
|
||||
return getInputOperand(0)->get();
|
||||
}
|
||||
std::optional<Value> indices() {
|
||||
if (getNumInputs() < 2) {
|
||||
return {};
|
||||
} else {
|
||||
return getInputOperand(1)->get();
|
||||
}
|
||||
}
|
||||
Value outputValues() {
|
||||
return getOutputOperand(0)->get();
|
||||
}
|
||||
Value outputIndices() {
|
||||
return getOutputOperand(1)->get();
|
||||
}
|
||||
ShapedType getInputType() {
|
||||
return cast<ShapedType>(values().getType());
|
||||
}
|
||||
int64_t getInputRank() {
|
||||
return getInputType().getRank();
|
||||
}
|
||||
|
||||
// Method to implement for specifying output range for
|
||||
// DestinationStyleOpInterface
|
||||
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
|
||||
std::pair<unsigned, unsigned> outputsIndexAndLength =
|
||||
getODSOperandIndexAndLength(1);
|
||||
return std::make_pair<int64_t, int64_t>(
|
||||
outputsIndexAndLength.first,
|
||||
outputsIndexAndLength.first + outputsIndexAndLength.second);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pure ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -12426,6 +12426,34 @@ def Torch_AtenCol2imOp : Torch_Op<"aten.col2im", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenKthvalueOp : Torch_Op<"aten.kthvalue", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::kthvalue : (Tensor, int, int, bool) -> (Tensor, Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$k,
|
||||
Torch_IntType:$dim,
|
||||
Torch_BoolType:$keepdim
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$values,
|
||||
AnyTorchOptionalTensorType:$indices
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenKthvalueOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 2);
|
||||
}
|
||||
void AtenKthvalueOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 2);
|
||||
}
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -254,6 +254,44 @@ static Value createTMTensorScanOp(
|
|||
return scanOp->getResult(0);
|
||||
}
|
||||
|
||||
static FailureOr<Value> createIntOrFloatCompareOp(PatternRewriter &rewriter,
|
||||
Location loc,
|
||||
Type elementType, Value lhs,
|
||||
Value rhs, bool isDescending,
|
||||
bool isEqual) {
|
||||
|
||||
Value compareOp;
|
||||
if (auto intType = dyn_cast<mlir::IntegerType>(elementType)) {
|
||||
// Case for using arith::CmpIOp.
|
||||
arith::CmpIPredicate g =
|
||||
isEqual ? arith::CmpIPredicate::sge : arith::CmpIPredicate::sgt;
|
||||
arith::CmpIPredicate l =
|
||||
isEqual ? arith::CmpIPredicate::sle : arith::CmpIPredicate::slt;
|
||||
if (intType.isUnsignedInteger()) {
|
||||
g = isEqual ? arith::CmpIPredicate::uge : arith::CmpIPredicate::ugt;
|
||||
l = isEqual ? arith::CmpIPredicate::ule : arith::CmpIPredicate::ult;
|
||||
}
|
||||
arith::CmpIPredicate predicate = isDescending ? g : l;
|
||||
compareOp = rewriter.create<arith::CmpIOp>(loc, predicate, lhs, rhs);
|
||||
return compareOp;
|
||||
}
|
||||
|
||||
if (isa<mlir::FloatType>(elementType)) {
|
||||
// Case for using arith::CmpFOp.
|
||||
arith::CmpFPredicate g =
|
||||
isEqual ? arith::CmpFPredicate::OGE : arith::CmpFPredicate::OGT;
|
||||
arith::CmpFPredicate l =
|
||||
isEqual ? arith::CmpFPredicate::OLE : arith::CmpFPredicate::OLT;
|
||||
|
||||
arith::CmpFPredicate predicate = isDescending ? g : l;
|
||||
compareOp = rewriter.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
|
||||
return compareOp;
|
||||
}
|
||||
|
||||
return rewriter.notifyMatchFailure(
|
||||
loc, "Only Integer and Floating element type expected.");
|
||||
}
|
||||
|
||||
// Utility function to create a TMTensor::SortOp.
|
||||
static FailureOr<SmallVector<Value>>
|
||||
createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc,
|
||||
|
@ -280,34 +318,60 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc,
|
|||
}
|
||||
|
||||
// Step 3. Create comparison op which will be used as the sorting predicate.
|
||||
Value compareOp;
|
||||
if (auto intType = dyn_cast<mlir::IntegerType>(elementTypes[0])) {
|
||||
// Case for using arith::CmpIOp.
|
||||
arith::CmpIPredicate ge = arith::CmpIPredicate::sge;
|
||||
arith::CmpIPredicate le = arith::CmpIPredicate::sle;
|
||||
if (intType.isUnsignedInteger()) {
|
||||
ge = arith::CmpIPredicate::uge;
|
||||
le = arith::CmpIPredicate::ule;
|
||||
}
|
||||
arith::CmpIPredicate predicate = isDescending ? ge : le;
|
||||
compareOp = rewriter.create<arith::CmpIOp>(
|
||||
loc, predicate, block->getArgument(0), block->getArgument(1));
|
||||
} else if (isa<mlir::FloatType>(elementTypes[0])) {
|
||||
// Case for using arith::CmpFOp.
|
||||
arith::CmpFPredicate predicate =
|
||||
isDescending ? arith::CmpFPredicate::OGE : arith::CmpFPredicate::OLE;
|
||||
compareOp = rewriter.create<arith::CmpFOp>(
|
||||
loc, predicate, block->getArgument(0), block->getArgument(1));
|
||||
} else {
|
||||
auto compareOpRetVal = createIntOrFloatCompareOp(
|
||||
rewriter, loc, elementTypes[0], block->getArgument(0),
|
||||
block->getArgument(1), isDescending, true);
|
||||
|
||||
if (failed(compareOpRetVal))
|
||||
return rewriter.notifyMatchFailure(
|
||||
sortOpLoc, "Only Integer and Floating element type expected.");
|
||||
}
|
||||
loc, "Only Integer and Floating element type expected.");
|
||||
|
||||
// Step 4. Create yield op for yielding the sorting predicate.
|
||||
rewriter.create<TMTensor::YieldOp>(loc, compareOp);
|
||||
rewriter.create<TMTensor::YieldOp>(loc, compareOpRetVal.value());
|
||||
return SmallVector<Value>(sortOp.getResults());
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<Value>> createTMTensorTopkOp(
|
||||
PatternRewriter &rewriter, Location topkOpLoc, llvm::ArrayRef<Value> inputs,
|
||||
llvm::ArrayRef<Value> outputs, llvm::ArrayRef<Type> elementTypes,
|
||||
int64_t dimension, bool isMinK) {
|
||||
|
||||
// Generate output types.
|
||||
SmallVector<Type> topkResultTypes;
|
||||
for (Value val : outputs) {
|
||||
topkResultTypes.push_back(val.getType());
|
||||
}
|
||||
|
||||
// Create empty TopkOp, add body later.
|
||||
auto topkOp = rewriter.create<TMTensor::TopkOp>(
|
||||
topkOpLoc, topkResultTypes, inputs, outputs,
|
||||
rewriter.getI64IntegerAttr(dimension));
|
||||
|
||||
Region *body = &topkOp.getRegion();
|
||||
Block *block = rewriter.createBlock(body);
|
||||
Location loc = body->getLoc();
|
||||
// Add arguments for each passed body region element type.
|
||||
for (Type elementType : elementTypes) {
|
||||
block->addArgument({elementType}, {loc});
|
||||
}
|
||||
|
||||
// Generate compare operator. If minK is chosen, isDescending should be false.
|
||||
// Is equal should be false, because we do not want equality to cause element
|
||||
// swap.
|
||||
auto compareOpRetVal = createIntOrFloatCompareOp(
|
||||
rewriter, loc, elementTypes[0], block->getArgument(0),
|
||||
block->getArgument(1), /*isDescending=*/!isMinK, /*isEqual=*/false);
|
||||
|
||||
// Check if correct element types are passed.
|
||||
if (failed(compareOpRetVal))
|
||||
return rewriter.notifyMatchFailure(
|
||||
loc, "Only Integer and Floating element type expected.");
|
||||
|
||||
// Yield the comparison result.
|
||||
rewriter.create<TMTensor::YieldOp>(loc, compareOpRetVal.value());
|
||||
return SmallVector<Value>(topkOp.getResults());
|
||||
}
|
||||
|
||||
namespace {
|
||||
class ConvertAtenScatterSrcOp : public OpConversionPattern<AtenScatterSrcOp> {
|
||||
public:
|
||||
|
@ -1570,6 +1634,456 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenKthvalueOp : public OpConversionPattern<AtenKthvalueOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenKthvalueOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
const llvm::StringRef opName = op->getName().getStringRef();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
auto typec = this->getTypeConverter();
|
||||
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
unsigned inputRank = inputType.getRank();
|
||||
Type inputElementType = inputType.getElementType();
|
||||
|
||||
auto valResultType =
|
||||
cast<RankedTensorType>(typec->convertType(op.getResult(0).getType()));
|
||||
auto valResultElementType =
|
||||
getElementTypeOrSelf(typec->convertType(valResultType));
|
||||
|
||||
auto idxResultType =
|
||||
cast<RankedTensorType>(typec->convertType(op.getResult(1).getType()));
|
||||
auto idxResultElementType =
|
||||
getElementTypeOrSelf(typec->convertType(idxResultType));
|
||||
|
||||
// get keepdim and check it is bool
|
||||
bool keepDim = false;
|
||||
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, opName + " requires boolean value for keepdim");
|
||||
|
||||
// get dim, check it is constant int
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only constant dim value is supported");
|
||||
|
||||
// turn dim into positive if negative, and check it is in the valid range
|
||||
dim = toPositiveDim(dim, inputRank);
|
||||
if (!isValidDim(dim, inputRank)) {
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
}
|
||||
|
||||
// get k, check it is a constant int
|
||||
int64_t k;
|
||||
if (!matchPattern(op.getK(), m_TorchConstantInt(&k)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only constant k value is supported");
|
||||
|
||||
// check if element type is float, int, or unsigned
|
||||
bool isUnsigned = false;
|
||||
if (!isa<mlir::FloatType>(inputElementType)) {
|
||||
if (!isa<mlir::IntegerType>(inputElementType)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, opName + " to linalg.* requires Float or Integer "
|
||||
"input element type");
|
||||
}
|
||||
|
||||
auto integerTy = dyn_cast<mlir::IntegerType>(
|
||||
cast<BaseTensorType>(op.getSelf().getType()).getDtype());
|
||||
isUnsigned = integerTy.isUnsigned();
|
||||
}
|
||||
|
||||
// Create the values to fill initial output tensors for
|
||||
// topk op and linalg generic op for finding max value.
|
||||
Value fillValLinalgFindMax;
|
||||
Value fillValTopK;
|
||||
if (isa<mlir::FloatType>(inputElementType)) {
|
||||
// max float for topk tensor
|
||||
fillValTopK = rewriter.create<arith::ConstantOp>(
|
||||
loc,
|
||||
rewriter.getFloatAttr(
|
||||
inputElementType,
|
||||
APFloat::getInf(
|
||||
cast<mlir::FloatType>(inputElementType).getFloatSemantics(),
|
||||
/*Negative=*/false)));
|
||||
// min float for linalg generic op tensor
|
||||
fillValLinalgFindMax = rewriter.create<arith::ConstantOp>(
|
||||
loc,
|
||||
rewriter.getFloatAttr(
|
||||
inputElementType,
|
||||
APFloat::getInf(
|
||||
cast<mlir::FloatType>(inputElementType).getFloatSemantics(),
|
||||
/*Negative=*/true)));
|
||||
} else if (!isUnsigned) {
|
||||
auto width = cast<mlir::IntegerType>(inputElementType).getWidth();
|
||||
// max signed int for topk op tensor
|
||||
auto init = APSInt::getSignedMaxValue(width);
|
||||
fillValTopK = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(inputElementType, init));
|
||||
// min signed int for linalg generic op tensor
|
||||
init = APSInt::getSignedMinValue(width);
|
||||
fillValLinalgFindMax = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(inputElementType, init));
|
||||
} else if (isUnsigned) {
|
||||
auto width = cast<mlir::IntegerType>(inputElementType).getWidth();
|
||||
// max unsigned int for topk op tensor
|
||||
auto init = APInt::getMaxValue(width);
|
||||
fillValTopK = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(inputElementType, init));
|
||||
// min unsigned int for linalg generic op tensor
|
||||
init = APInt::getMinValue(width);
|
||||
fillValLinalgFindMax = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(inputElementType, init));
|
||||
}
|
||||
|
||||
auto i32Type = rewriter.getI32Type();
|
||||
|
||||
// ======== BEGIN: Topk op section ========
|
||||
// Based on iree docs:
|
||||
// https://iree.dev/reference/mlir-dialects/LinalgExt/#iree_linalg_extsort-linalgextsortop
|
||||
|
||||
// Create the output shape of topk op.
|
||||
// For every dimension, topkShape[dimension] = inputShape[dimension],
|
||||
// except topkShape[dim] = k.
|
||||
SmallVector<Value> topkShape;
|
||||
for (unsigned i = 0; i < inputRank; i++) {
|
||||
auto currentDimSize = rewriter.create<tensor::DimOp>(loc, input, i);
|
||||
topkShape.push_back(currentDimSize);
|
||||
}
|
||||
auto dimSize = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), k));
|
||||
topkShape[dim] = dimSize;
|
||||
|
||||
// Fill the initial topk op output tensor.
|
||||
Value topkOutputVal = createInitTensor(rewriter, loc, topkShape,
|
||||
valResultElementType, fillValTopK);
|
||||
|
||||
// Create the initial value to fill the topk output indices tensor.
|
||||
// It is equal to the max 32-bit signless integer.
|
||||
auto signlessType = mlir::IntegerType::get(op.getContext(), 32,
|
||||
mlir::IntegerType::Signless);
|
||||
auto initIdx = getNumericLimit(rewriter, signlessType, /*getMin=*/false);
|
||||
auto fillValTopkIdx = rewriter.create<arith::ConstantOp>(loc, initIdx);
|
||||
// Fill the initial topk op output indices tensor.
|
||||
Value topkOutputIdx =
|
||||
createInitTensor(rewriter, loc, topkShape, i32Type, fillValTopkIdx);
|
||||
|
||||
// Input arguments for topk op contain only the input tensor.
|
||||
// Input indices will be inferred based on input shape.
|
||||
// (See docs link above).
|
||||
SmallVector<Value> topkInputs;
|
||||
topkInputs.push_back(input);
|
||||
|
||||
// Outputs contain both the values and the indices tensors.
|
||||
SmallVector<Value> topkOutputs;
|
||||
topkOutputs.push_back(topkOutputVal);
|
||||
topkOutputs.push_back(topkOutputIdx);
|
||||
|
||||
// Element types of the arguments passed to the topk op region.
|
||||
// The region accepts the next value N, and the current output
|
||||
// candidate K (see docs link above).
|
||||
// Both N and K are values from the input tensors, thus the
|
||||
// element types are the same and are taken from inputType.
|
||||
SmallVector<Type> topkElementTypes;
|
||||
topkElementTypes.push_back(inputType.getElementType());
|
||||
topkElementTypes.push_back(inputType.getElementType());
|
||||
|
||||
// Create the TMTensor TopkOp.
|
||||
FailureOr<SmallVector<Value>> topkOp;
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
topkOp = createTMTensorTopkOp(rewriter, loc, topkInputs, topkOutputs,
|
||||
topkElementTypes, dim, /*isMinK=*/true);
|
||||
}
|
||||
// Topk op creation fails with invalid element types.
|
||||
if (failed(topkOp))
|
||||
return rewriter.notifyMatchFailure(
|
||||
loc, "Only Integer and Floating element type expected.");
|
||||
|
||||
auto topkOpVal = topkOp.value();
|
||||
// ======== END: Topk op section ========
|
||||
|
||||
// ======== BEGIN: Linalg generic to find max in topk result ========
|
||||
|
||||
// Create result shape as both a vector of Value and of int64_t types.
|
||||
// We assume that keepdim is false, and fix the result later if true.
|
||||
// Result shape is equal to inputShape, with dim dimension removed.
|
||||
SmallVector<Value> resultShape;
|
||||
SmallVector<int64_t> resultShapeInt;
|
||||
for (int64_t i = 0; i < inputType.getRank(); i++) {
|
||||
if (dim != i) {
|
||||
auto currentDimSize = rewriter.create<tensor::DimOp>(loc, input, i);
|
||||
resultShape.push_back(currentDimSize);
|
||||
resultShapeInt.push_back(inputType.getShape()[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Fill the initial output tensor for linalg op for finding max value.
|
||||
Value findMaxOutputVal = createInitTensor(
|
||||
rewriter, loc, resultShape, inputElementType, fillValLinalgFindMax);
|
||||
|
||||
// Fill the initial output indices tensor for linalg op for finding max
|
||||
// value with zeros.
|
||||
Value findMaxOutputIdx =
|
||||
createZeroInitTensor(rewriter, loc, resultShape, idxResultElementType);
|
||||
|
||||
// Reduce along dim.
|
||||
SmallVector<utils::IteratorType> findMaxIteratorTypes(
|
||||
inputType.getRank(), utils::IteratorType::parallel);
|
||||
findMaxIteratorTypes[dim] = utils::IteratorType::reduction;
|
||||
|
||||
SmallVector<AffineExpr> findMaxMapExprs;
|
||||
SmallVector<AffineExpr> findMaxMapResultExprs;
|
||||
for (auto size :
|
||||
llvm::enumerate(makeShapeTorchCompatible(inputType.getShape()))) {
|
||||
findMaxMapExprs.push_back(rewriter.getAffineDimExpr(size.index()));
|
||||
if (unsigned(dim) != size.index())
|
||||
findMaxMapResultExprs.push_back(
|
||||
rewriter.getAffineDimExpr(size.index()));
|
||||
}
|
||||
|
||||
auto findMaxMaps = AffineMap::inferFromExprList(
|
||||
{findMaxMapExprs, findMaxMapResultExprs, findMaxMapResultExprs},
|
||||
rewriter.getContext());
|
||||
|
||||
// Create linalg op for finding the max value in the extracted topk values.
|
||||
auto findMaxLinalg = rewriter.create<linalg::GenericOp>(
|
||||
loc,
|
||||
ArrayRef<Type>(
|
||||
{findMaxOutputVal.getType(), findMaxOutputIdx.getType()}),
|
||||
topkOpVal.front(), ValueRange({findMaxOutputVal, findMaxOutputIdx}),
|
||||
findMaxMaps, findMaxIteratorTypes,
|
||||
[&](OpBuilder &nestedBuilder, Location nestedLoc,
|
||||
ValueRange blockArgs) {
|
||||
// Linalg generic body is the same as the decomposition for
|
||||
// AtenMinDim: lib/Conversion/TorchToLinalg/Reduction.cpp
|
||||
|
||||
Value newValue = blockArgs[0];
|
||||
Value oldValue = blockArgs[1];
|
||||
Value oldIndex = blockArgs[2];
|
||||
|
||||
Value newIndex = rewriter.create<arith::IndexCastOp>(
|
||||
nestedLoc, oldIndex.getType(),
|
||||
rewriter.create<linalg::IndexOp>(nestedLoc, dim));
|
||||
|
||||
Value resultVal, predicate;
|
||||
if (isa<mlir::FloatType>(inputElementType)) {
|
||||
resultVal = rewriter.create<arith::MaximumFOp>(nestedLoc, newValue,
|
||||
oldValue);
|
||||
predicate = rewriter.create<arith::CmpFOp>(
|
||||
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
|
||||
} else {
|
||||
arith::CmpIPredicate predType;
|
||||
predType = isUnsigned ? arith::CmpIPredicate::ugt
|
||||
: arith::CmpIPredicate::sgt;
|
||||
if (isUnsigned) {
|
||||
resultVal = rewriter.create<arith::MaxUIOp>(nestedLoc, newValue,
|
||||
oldValue);
|
||||
} else {
|
||||
resultVal = rewriter.create<arith::MaxSIOp>(nestedLoc, newValue,
|
||||
oldValue);
|
||||
}
|
||||
predicate = rewriter.create<arith::CmpIOp>(nestedLoc, predType,
|
||||
newValue, oldValue);
|
||||
}
|
||||
auto resultIndex = rewriter.create<arith::SelectOp>(
|
||||
nestedLoc, predicate, newIndex, oldIndex);
|
||||
nestedBuilder.create<linalg::YieldOp>(
|
||||
nestedLoc, ValueRange{resultVal, resultIndex});
|
||||
});
|
||||
|
||||
auto findMaxVal = findMaxLinalg.getResult(0);
|
||||
auto findMaxIdx = findMaxLinalg.getResult(1);
|
||||
auto findMaxIdxType = cast<RankedTensorType>(findMaxIdx.getType());
|
||||
|
||||
// ======== END: Linalg generic to find max in topk result ========
|
||||
|
||||
// ======== BEGIN: Linalg generic for index extraction ========
|
||||
// The linalg op for finding max returned idx of max elements in the
|
||||
// tensor returned by the topk op. We need the idx of those elements
|
||||
// in the original input. The topk op returned the idx of the top k
|
||||
// extracted elements in the original input. Using the linalg idx
|
||||
// results to index the topk idx results returns the idx of kth
|
||||
// max value in the original input. Example:
|
||||
// input = [1, 7, 3, 6, 2, 8, 9, 5], k = 4
|
||||
// topk_val = [1, 3, 2, 5], topk_idx = [0, 2, 4, 7]
|
||||
// linalg_max_val = [5], linalg_max_idx = [3] (5 is at idx 3 in topk_val)
|
||||
// index the topk_idx using linalg_max_idx -> topk_idx[3] = 7
|
||||
// kth_val = [5], kth_idx = [7]
|
||||
|
||||
// Create a tensor for the resulting idx.
|
||||
Value filledTensorExtractedIdx = createZeroInitTensor(
|
||||
rewriter, loc, getTensorSizes(rewriter, loc, findMaxIdx), i32Type);
|
||||
|
||||
// We iterate through the idx tensor returned by the linalg generic op for
|
||||
// finding max.
|
||||
SmallVector<utils::IteratorType> extractedIdxIteratorTypes(
|
||||
findMaxIdxType.getRank(), utils::IteratorType::parallel);
|
||||
|
||||
SmallVector<AffineExpr> extractedIdxMapExprs;
|
||||
for (auto size :
|
||||
llvm::enumerate(makeShapeTorchCompatible(findMaxIdxType.getShape()))) {
|
||||
extractedIdxMapExprs.push_back(rewriter.getAffineDimExpr(size.index()));
|
||||
}
|
||||
|
||||
auto extractedIdxMaps = AffineMap::inferFromExprList(
|
||||
{extractedIdxMapExprs, extractedIdxMapExprs}, rewriter.getContext());
|
||||
|
||||
// Linalg generic op for indexing the topk output idx tensor using
|
||||
// the idx tensor returned by the linalg generic op for finding max.
|
||||
// Only the idx tensor from the linalg generic op is sent as input.
|
||||
auto extractedIdxLinalg = rewriter.create<linalg::GenericOp>(
|
||||
loc, ArrayRef<Type>({filledTensorExtractedIdx.getType()}), findMaxIdx,
|
||||
filledTensorExtractedIdx, extractedIdxMaps, extractedIdxIteratorTypes,
|
||||
[&](OpBuilder &nestedBuilder, Location nestedLoc,
|
||||
ValueRange blockArgs) {
|
||||
// Get the current input idx.
|
||||
Value index = rewriter.create<arith::IndexCastOp>(
|
||||
loc, rewriter.getIndexType(), blockArgs[0]);
|
||||
|
||||
// Create idx to index the topk idx tensor.
|
||||
// Index the dim dimension using the current input idx.
|
||||
SmallVector<Value> indexTarget;
|
||||
for (unsigned i = 0; i < dim; i++)
|
||||
indexTarget.push_back(rewriter.create<linalg::IndexOp>(loc, i));
|
||||
indexTarget.push_back(index);
|
||||
for (unsigned i = dim; i < findMaxIdxType.getRank(); i++)
|
||||
indexTarget.push_back(rewriter.create<linalg::IndexOp>(loc, i));
|
||||
|
||||
// Extract the element from the topk idx tensor.
|
||||
Value extractedElement = rewriter.create<tensor::ExtractOp>(
|
||||
loc, topkOpVal.back(), indexTarget);
|
||||
rewriter.create<linalg::YieldOp>(loc, extractedElement);
|
||||
});
|
||||
|
||||
auto extractedIdx = extractedIdxLinalg.getResult(0);
|
||||
auto extractedIdxType = cast<RankedTensorType>(extractedIdx.getType());
|
||||
|
||||
// ======== END: Linalg generic for index extraction ========
|
||||
|
||||
// ======== BEGIN: Linalg generic for topk idx cast ========
|
||||
// Casts from i32 to idx result type of the Kthvalue op.
|
||||
|
||||
// Create the initial tensor for the cast result.
|
||||
Value filledTensorCastedIdx = createZeroInitTensor(
|
||||
rewriter, loc, getTensorSizes(rewriter, loc, extractedIdx),
|
||||
idxResultElementType);
|
||||
|
||||
SmallVector<utils::IteratorType> castedIdxIteratorTypes(
|
||||
extractedIdxType.getRank(), utils::IteratorType::parallel);
|
||||
|
||||
SmallVector<AffineExpr> castedIdxMapExprs;
|
||||
for (auto size : llvm::enumerate(
|
||||
makeShapeTorchCompatible(extractedIdxType.getShape()))) {
|
||||
castedIdxMapExprs.push_back(rewriter.getAffineDimExpr(size.index()));
|
||||
}
|
||||
|
||||
auto castedIdxMaps = AffineMap::inferFromExprList(
|
||||
{castedIdxMapExprs, castedIdxMapExprs}, rewriter.getContext());
|
||||
|
||||
// Linalg generic op for casting topk idx output tensor elements from i32 to
|
||||
// result idx tensor element type.
|
||||
auto castedIdxLinalg = rewriter.create<linalg::GenericOp>(
|
||||
loc, ArrayRef<Type>({filledTensorCastedIdx.getType()}), extractedIdx,
|
||||
filledTensorCastedIdx, castedIdxMaps, castedIdxIteratorTypes,
|
||||
[&](OpBuilder &nestedBuilder, Location nestedLoc,
|
||||
ValueRange blockArgs) {
|
||||
Value oldIdx = blockArgs[0];
|
||||
|
||||
// Cast from i32 to index.
|
||||
Value oldIdxToIndexType = rewriter.create<arith::IndexCastOp>(
|
||||
nestedLoc, rewriter.getIndexType(), oldIdx);
|
||||
// Cast from index to result idx element type.
|
||||
Value resultIdx = rewriter.create<arith::IndexCastOp>(
|
||||
nestedLoc, idxResultElementType, oldIdxToIndexType);
|
||||
|
||||
nestedBuilder.create<linalg::YieldOp>(nestedLoc, resultIdx);
|
||||
});
|
||||
|
||||
auto castedIdx = castedIdxLinalg.getResult(0);
|
||||
|
||||
// ======== END: Linalg generic for topk idx cast ========
|
||||
|
||||
// Create output value type ("squeezed" since we assume keepdim=False).
|
||||
auto topkValResultType =
|
||||
cast<RankedTensorType>(topkOpVal.front().getType());
|
||||
auto squeezedValType = topkValResultType.cloneWith(
|
||||
resultShapeInt,
|
||||
cast<RankedTensorType>(findMaxVal.getType()).getElementType());
|
||||
|
||||
// Create output idx type ("squeezed" since we assume keepdim=False).
|
||||
auto castedIdxType = cast<RankedTensorType>(castedIdx.getType());
|
||||
auto squeezedIdxType = castedIdxType.cloneWith(
|
||||
resultShapeInt, findMaxIdxType.getElementType());
|
||||
|
||||
if (!keepDim) {
|
||||
// If keepdim=false, cast the the outputs to appropriate type and return.
|
||||
Value retVal =
|
||||
rewriter.create<tensor::CastOp>(loc, squeezedValType, findMaxVal);
|
||||
Value retIdx =
|
||||
rewriter.create<tensor::CastOp>(loc, squeezedIdxType, castedIdx);
|
||||
llvm::SmallVector<Value> res{retVal, retIdx};
|
||||
rewriter.replaceOp(op, res);
|
||||
return success();
|
||||
}
|
||||
|
||||
// If keepdim is false, unsqueeze.
|
||||
// Unsqueezing implementation taken from AteMinMaxDimOp lowering:
|
||||
// lib/Conversion/TorchToLinalg/Reduction.cpp
|
||||
llvm::SmallVector<int64_t> valShape(valResultType.getShape());
|
||||
llvm::SmallVector<int64_t> idxShape(idxResultType.getShape());
|
||||
for (int i = dim, s = valShape.size() - 1; i < s; ++i) {
|
||||
valShape[i] = valShape[i + 1];
|
||||
idxShape[i] = idxShape[i + 1];
|
||||
}
|
||||
|
||||
valShape.resize(valShape.size() - 1);
|
||||
idxShape.resize(idxShape.size() - 1);
|
||||
|
||||
Value retVal = rewriter.create<tensor::CastOp>(
|
||||
loc, squeezedValType.clone(valShape), findMaxLinalg.getResult(0));
|
||||
Value retIdx = rewriter.create<tensor::CastOp>(
|
||||
loc, squeezedIdxType.clone(idxShape), castedIdx);
|
||||
|
||||
SmallVector<ReassociationIndices> reassociation(valShape.size());
|
||||
if (reassociation.size() > 0) {
|
||||
for (int i = 0; i < dim; ++i)
|
||||
reassociation[i].push_back(i);
|
||||
reassociation[std::max<int64_t>(0, dim - 1)].push_back(dim);
|
||||
for (int i = dim, s = reassociation.size(); i < s; ++i)
|
||||
reassociation[i].push_back(i + 1);
|
||||
}
|
||||
|
||||
valShape.push_back(0);
|
||||
idxShape.push_back(0);
|
||||
for (int i = dim, s = valShape.size() - 1; i < s; ++i) {
|
||||
valShape[i + 1] = valShape[i];
|
||||
idxShape[i + 1] = idxShape[i];
|
||||
}
|
||||
|
||||
valShape[dim] = 1;
|
||||
idxShape[dim] = 1;
|
||||
|
||||
Value unsqueezeVal = rewriter.create<tensor::ExpandShapeOp>(
|
||||
loc, valResultType, retVal, reassociation);
|
||||
|
||||
Value unsqueezeIdx = rewriter.create<tensor::ExpandShapeOp>(
|
||||
loc, idxResultType, retIdx, reassociation);
|
||||
|
||||
// Return unsqueezed.
|
||||
llvm::SmallVector<Value> unsqueezes = {unsqueezeVal, unsqueezeIdx};
|
||||
rewriter.replaceOp(op, unsqueezes);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// The pass
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -1619,6 +2133,8 @@ public:
|
|||
|
||||
target.addIllegalOp<AtenScatterSrcOp>();
|
||||
patterns.add<ConvertAtenScatterSrcOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenKthvalueOp>();
|
||||
patterns.add<ConvertAtenKthvalueOp>(typeConverter, context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
|
|
|
@ -910,6 +910,213 @@ bool SortOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
|
|||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TopkOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult TopkOp::verify() {
|
||||
Operation *op = getOperation();
|
||||
if (getNumInputs() != 1 && getNumInputs() != 2) {
|
||||
return op->emitOpError("expected one or two input operands");
|
||||
}
|
||||
if (getNumOutputs() != 2) {
|
||||
return op->emitOpError("expected two output operands");
|
||||
}
|
||||
// First check added to eliminate comparison of different int types
|
||||
if (getInputRank() < 0 ||
|
||||
(getDimension() >= static_cast<uint64_t>(getInputRank()))) {
|
||||
return op->emitOpError("dimension exceeds rank");
|
||||
}
|
||||
// Ensure input/output element types match
|
||||
auto inputValuesType = cast<ShapedType>(values().getType());
|
||||
auto outputValuesType = cast<ShapedType>(outputValues().getType());
|
||||
if (inputValuesType.getElementType() != outputValuesType.getElementType()) {
|
||||
return op->emitOpError("expected input/output value types to be identical");
|
||||
}
|
||||
// Indices must be int if provided
|
||||
auto outputIndicesType = cast<ShapedType>(outputIndices().getType());
|
||||
if (auto inputIndices = indices()) {
|
||||
auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
|
||||
if (!inputIndicesType.getElementType().isInteger(32) ||
|
||||
!outputIndicesType.getElementType().isInteger(32)) {
|
||||
return op->emitOpError("expected input/output indices types to be int32");
|
||||
}
|
||||
}
|
||||
|
||||
// Ranks must match
|
||||
if (inputValuesType.getRank() != outputValuesType.getRank()) {
|
||||
return op->emitOpError("expected input/output to have the same rank");
|
||||
}
|
||||
if (auto inputIndices = indices()) {
|
||||
auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
|
||||
if (inputIndicesType.getRank() != outputIndicesType.getRank()) {
|
||||
return op->emitOpError("expected input/output to have the same rank");
|
||||
}
|
||||
}
|
||||
// Input indicies and values must have the same shape.
|
||||
if (auto inputIndices = indices()) {
|
||||
auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
|
||||
if (failed(verifyCompatibleShape(inputValuesType, inputIndicesType)))
|
||||
return op->emitOpError("input indices/values shape must match");
|
||||
}
|
||||
// Output indicies and values must have the same shape.
|
||||
if (failed(verifyCompatibleShape(outputValuesType, outputIndicesType)))
|
||||
return op->emitOpError("output indices/values shape must match");
|
||||
// Input shape must match the output shape except for the dimension()
|
||||
uint64_t dim = getDimension();
|
||||
if (!llvm::all_of(llvm::enumerate(llvm::zip(inputValuesType.getShape(),
|
||||
outputValuesType.getShape())),
|
||||
[dim](auto e) {
|
||||
if (e.index() == dim) {
|
||||
return true;
|
||||
}
|
||||
std::tuple<int64_t, int64_t> s = e.value();
|
||||
return succeeded(verifyCompatibleShape(std::get<0>(s),
|
||||
|
||||
std::get<1>(s)));
|
||||
})) {
|
||||
return op->emitOpError("incompatible input/output shapes");
|
||||
}
|
||||
// Check region compatibility
|
||||
Block &block = getRegion().front();
|
||||
if (block.getNumArguments() != 2) {
|
||||
return op->emitOpError("region block should have 2 arguments");
|
||||
}
|
||||
if (block.getArgument(0).getType() != inputValuesType.getElementType() ||
|
||||
block.getArgument(1).getType() != inputValuesType.getElementType()) {
|
||||
return op->emitOpError("region block types must match input");
|
||||
}
|
||||
auto terminatorOp = llvm::dyn_cast<YieldOp>(block.getTerminator());
|
||||
if (!terminatorOp || !terminatorOp.getOperand(0).getType().isInteger(1)) {
|
||||
return op->emitOpError("region block must end with a linalg_ext.yield i1!");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<utils::IteratorType> TopkOp::getLoopIteratorTypes() {
|
||||
SmallVector<utils::IteratorType> iteratorTypes(getInputRank(),
|
||||
utils::IteratorType::parallel);
|
||||
iteratorTypes[getDimension()] = utils::IteratorType::reduction;
|
||||
return iteratorTypes;
|
||||
}
|
||||
|
||||
SmallVector<Range> TopkOp::getIterationDomain(OpBuilder &builder) {
|
||||
int64_t operandRank = getInputRank();
|
||||
SmallVector<Range> loopBounds(operandRank);
|
||||
Location loc = getLoc();
|
||||
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
|
||||
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
|
||||
Value source = values();
|
||||
for (auto dim : llvm::enumerate(getInputType().getShape())) {
|
||||
loopBounds[dim.index()].offset = zero;
|
||||
loopBounds[dim.index()].size =
|
||||
getDimValue(builder, loc, source, dim.index());
|
||||
loopBounds[dim.index()].stride = one;
|
||||
}
|
||||
return loopBounds;
|
||||
}
|
||||
|
||||
LogicalResult TopkOp::generateScalarImplementation(OpBuilder &b, Location loc,
|
||||
ValueRange ivs) {
|
||||
uint64_t kDim = getDimension();
|
||||
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
|
||||
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
|
||||
Value initialValue = b.create<memref::LoadOp>(loc, values(), ivs);
|
||||
|
||||
// If the indices tensor is not provided, the value index is derived from the
|
||||
// loop induction variables.
|
||||
Value initialIndex;
|
||||
if (indices()) {
|
||||
initialIndex = b.create<memref::LoadOp>(loc, *indices(), ivs);
|
||||
} else {
|
||||
Value rawInitialIndex = ivs[kDim];
|
||||
initialIndex =
|
||||
b.create<arith::IndexCastOp>(loc, b.getI32Type(), rawInitialIndex);
|
||||
}
|
||||
|
||||
// Compute K (ub) from the selected dim of the output
|
||||
Value ub = b.create<memref::DimOp>(loc, outputValues(), getDimension());
|
||||
|
||||
// Inner K loop functions:
|
||||
// Load current K value and index
|
||||
// Compare N/K using inserted block compare
|
||||
// Check if N == K using strict weak ordering, select which index came first
|
||||
// Select new K value from N/K comparison
|
||||
// Select new K index from N/K comparison or which index came first
|
||||
// Store new k value and index
|
||||
// Yield loop carry values after K selection
|
||||
Value kValue, kIndex;
|
||||
auto scfFor = b.create<scf::ForOp>(
|
||||
loc, zero, ub, one, ValueRange{initialValue, initialIndex},
|
||||
[&](OpBuilder &b, Location loc, Value iv, ValueRange loopCarryValues) {
|
||||
SmallVector<Value> indices(ivs);
|
||||
indices[kDim] = iv;
|
||||
kValue = b.create<memref::LoadOp>(loc, outputValues(), indices);
|
||||
kIndex = b.create<memref::LoadOp>(loc, outputIndices(), indices);
|
||||
});
|
||||
|
||||
SmallVector<Value> indices(ivs);
|
||||
indices[kDim] = scfFor.getInductionVar();
|
||||
auto loopCarryValues = scfFor.getRegionIterArgs();
|
||||
|
||||
// Retrieve region as black box comparision function f(x,y). Plug into op.
|
||||
auto &srcBlock = getRegion().front();
|
||||
IRMapping bvmF; // f(x,y)
|
||||
IRMapping bvmR; // f(y,x)
|
||||
{
|
||||
// Save previous insertion point. Continue within loop body.
|
||||
OpBuilder::InsertionGuard guard(b);
|
||||
b.setInsertionPointToEnd(&scfFor.getRegion().front());
|
||||
SmallVector<Value> forwardValues{loopCarryValues[0], kValue};
|
||||
SmallVector<Value> reverseValues{kValue, loopCarryValues[0]};
|
||||
for (auto it : llvm::zip(srcBlock.getArguments(), forwardValues)) {
|
||||
bvmF.map(std::get<0>(it), std::get<1>(it));
|
||||
}
|
||||
for (auto it : llvm::zip(srcBlock.getArguments(), reverseValues)) {
|
||||
bvmR.map(std::get<0>(it), std::get<1>(it));
|
||||
}
|
||||
for (auto &blockOp : srcBlock.without_terminator()) {
|
||||
b.clone(blockOp, bvmF);
|
||||
b.clone(blockOp, bvmR);
|
||||
}
|
||||
Value forwardCmpRes = bvmF.lookup(srcBlock.getTerminator()->getOperand(0));
|
||||
Value reverseCmpRes = bvmR.lookup(srcBlock.getTerminator()->getOperand(0));
|
||||
|
||||
// Check value equality using strictly weak ordering from the region:
|
||||
// f(x,y) --> forwardCmpRes
|
||||
// f(y,x) --> reverseCmpRes
|
||||
// if forwardCmpRes == reverseCmpRes then select which came first
|
||||
Value cmpValuesEqual = b.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, forwardCmpRes, reverseCmpRes);
|
||||
Value cmpFirstIndex = b.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, loopCarryValues[1], kIndex);
|
||||
Value combinedCmpEqRes =
|
||||
b.create<arith::AndIOp>(loc, cmpValuesEqual, cmpFirstIndex);
|
||||
// True if N > K or N came before K
|
||||
Value indexCmpRes =
|
||||
b.create<arith::OrIOp>(loc, forwardCmpRes, combinedCmpEqRes);
|
||||
// Select results for K based on comparisons
|
||||
Value resultKValue = b.create<arith::SelectOp>(loc, forwardCmpRes,
|
||||
loopCarryValues[0], kValue);
|
||||
Value resultKIndex =
|
||||
b.create<arith::SelectOp>(loc, indexCmpRes, loopCarryValues[1], kIndex);
|
||||
b.create<memref::StoreOp>(loc, resultKValue, outputValues(), indices);
|
||||
b.create<memref::StoreOp>(loc, resultKIndex, outputIndices(), indices);
|
||||
// Select loop carry, opposite of K results
|
||||
Value resultCarryValue = b.create<arith::SelectOp>(
|
||||
loc, forwardCmpRes, kValue, loopCarryValues[0]);
|
||||
Value resultCarryIndex =
|
||||
b.create<arith::SelectOp>(loc, indexCmpRes, kIndex, loopCarryValues[1]);
|
||||
b.create<scf::YieldOp>(loc, ValueRange{resultCarryValue, resultCarryIndex});
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
bool TopkOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
|
||||
// Set to true so that output operands are always initialized.
|
||||
return true;
|
||||
}
|
||||
|
||||
#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
|
||||
void OP_NAME::getEffects( \
|
||||
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
|
||||
|
@ -924,6 +1131,7 @@ DEFINE_OP_GET_EFFECTS(AttentionOp)
|
|||
DEFINE_OP_GET_EFFECTS(ScanOp)
|
||||
DEFINE_OP_GET_EFFECTS(ScatterOp)
|
||||
DEFINE_OP_GET_EFFECTS(SortOp)
|
||||
DEFINE_OP_GET_EFFECTS(TopkOp)
|
||||
|
||||
namespace {
|
||||
/// This is derived from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp without any
|
||||
|
|
|
@ -4877,6 +4877,42 @@ LogicalResult AtenLinalgCrossOp::verify() {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult AtenKthvalueOp::verify() {
|
||||
|
||||
auto selfType = cast<BaseTensorType>(getSelf().getType());
|
||||
|
||||
if (!selfType.hasDtype() || !selfType.hasSizes())
|
||||
return success();
|
||||
|
||||
Type selfDtype = selfType.getDtype();
|
||||
if (selfDtype.isSignlessInteger(1))
|
||||
return emitOpError("input tensors must not have bool dtype");
|
||||
|
||||
int64_t dim;
|
||||
if (!matchPattern(getDim(), m_TorchConstantInt(&dim)))
|
||||
return success();
|
||||
|
||||
ArrayRef<int64_t> selfShape = selfType.getSizes();
|
||||
int64_t selfRank = selfShape.size();
|
||||
|
||||
dim = toPositiveDim(dim, selfRank);
|
||||
if (!isValidDim(dim, selfRank))
|
||||
return emitOpError("dim expected to be in range of [")
|
||||
<< -selfRank << ", " << selfRank - 1 << "], but got " << dim;
|
||||
|
||||
// convert k to an integer type
|
||||
int64_t k;
|
||||
if (!matchPattern(getK(), m_TorchConstantInt(&k)))
|
||||
return success();
|
||||
|
||||
// check if k is in the correct range
|
||||
if (selfShape[dim] != kUnknownSize && (k < 1 || k > selfShape[dim]))
|
||||
return emitOpError("k expected to be in range of [")
|
||||
<< 1 << ", " << selfShape[dim] << "], but got " << k;
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DtypeCalculateYieldDtypesOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -6962,6 +6962,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %4 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %4 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.kthvalue\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||
" %0 = torch.derefine %arg2 : !torch.int to !torch.optional<int>\n"
|
||||
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg3) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
|
||||
" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" return %2 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten._log_softmax_backward_data\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -10897,6 +10903,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %6 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.kthvalue\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.tuple<int, int> {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||
" return %1 : !torch.tuple<int, int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
|
||||
" return %arg3 : !torch.int\n"
|
||||
" }\n"
|
||||
|
|
|
@ -1618,6 +1618,7 @@ public:
|
|||
auto idxTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true));
|
||||
llvm::SmallVector<Type, 2> types{reductionTy, idxTy};
|
||||
|
||||
reduction = rewriter
|
||||
.create<Torch::AtenMinDimOp>(loc, types, reduction,
|
||||
dimValue, op.getKeepdim())
|
||||
|
|
|
@ -2274,6 +2274,11 @@ ONNX_XFAIL_SET = {
|
|||
"AtenIntTensorCharDtypeModule_basic",
|
||||
"AtenItemFpOpModule_basic",
|
||||
"AtenItemIntOpModule_basic",
|
||||
"AtenKthvalueModule_basic",
|
||||
"AtenKthvalueKeepDimModule_basic",
|
||||
"AtenKthvalueDynamicDimsModule_basic",
|
||||
"AtenKthvalueFloat64Module_basic",
|
||||
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
||||
"AtenLinalgCrossDynamic_basic",
|
||||
"AtenMatmulQMixedSigni8Transpose_basic",
|
||||
"AtenMatmulQMixedSigni8_basic",
|
||||
|
|
|
@ -468,6 +468,14 @@ def aten〇linalg_cross〡shape(self: List[int], other: List[int], dim: int = -1
|
|||
assert (self[i] == other[i]) or self[i] == 1 or other[i] == 1, f"the size of first tensor ({self[i]}) must match the size of second tensor ({other[i]}) at dimension {i}"
|
||||
return upstream_shape_functions.broadcast(self, other)
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(2, 4, 3, device="cpu"), k=2, dim=1, keepdim=True), # keep dim,
|
||||
Invocation(TensorOfShape(2, 4, 3, device="cpu"), k=2, dim=1, keepdim=False), # don't keep dim
|
||||
])
|
||||
def aten〇kthvalue〡shape(self: List[int], k: int, dim: int = -1, keepdim: bool = False) -> Tuple[List[int], List[int]]:
|
||||
new_shape = upstream_shape_functions.argmax(self, dim, keepdim)
|
||||
return (new_shape, new_shape)
|
||||
|
||||
def aten〇_log_softmax_backward_data〡shape(grad_output: List[int], output: List[int], dim: int, input_dtype: int) -> List[int]:
|
||||
return upstream_shape_functions.unary(grad_output)
|
||||
|
||||
|
@ -2705,6 +2713,13 @@ def aten〇linalg_cross〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dty
|
|||
dtypes = [self_dtype, other_dtype]
|
||||
return promote_dtypes(ranks, dtypes)
|
||||
|
||||
@check_dtype_function([
|
||||
Invocation(TensorOfShape(2, 4, 3, dtype=torch.int32, device="cpu"), k=2, dim=-1, keepdim=False)
|
||||
])
|
||||
def aten〇kthvalue〡dtype(self_rank_dtype: Tuple[int, int], k: int, dim: int = -1, keepdim: bool = False) -> Tuple[int, int]:
|
||||
_, self_dtype = self_rank_dtype
|
||||
return (self_dtype, torch.int64)
|
||||
|
||||
@check_dtype_function(
|
||||
_check_two_tensor_op(dim=0, input_dtype=torch.float32) +
|
||||
_check_two_tensor_op(dim=0, input_dtype=torch.float64))
|
||||
|
|
|
@ -912,6 +912,10 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
)
|
||||
emit("aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)", has_verifier=True)
|
||||
emit("aten::col2im : (Tensor, int[], int[], int[], int[], int[]) -> (Tensor)")
|
||||
emit(
|
||||
"aten::kthvalue : (Tensor, int, int, bool) -> (Tensor, Tensor)",
|
||||
has_verifier=True,
|
||||
)
|
||||
|
||||
# Functionalization ops
|
||||
emit("aten::alias_copy : (Tensor) -> (Tensor)")
|
||||
|
|
|
@ -5547,3 +5547,102 @@ class CloneModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: CloneModule())
|
||||
def CloneModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenKthvalueModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([2, 6, 3], torch.int32, True)])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.kthvalue(x, k=4, dim=1, keepdim=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenKthvalueModule())
|
||||
def AtenKthvalueModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randperm(2 * 6 * 3, dtype=torch.int32).reshape(2, 6, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenKthvalueKeepDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([2, 6, 3], torch.int32, True)])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.kthvalue(x, k=4, dim=1, keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenKthvalueKeepDimModule())
|
||||
def AtenKthvalueKeepDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randperm(2 * 6 * 3, dtype=torch.int32).reshape(2, 6, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenKthvalueDynamicDimsModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-1, -1, -1, -1], torch.int32, True)])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.kthvalue(x, k=6, dim=2, keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenKthvalueDynamicDimsModule())
|
||||
def AtenKthvalueDynamicDimsModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randperm(4 * 2 * 8 * 3, dtype=torch.int32).reshape(4, 2, 8, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenKthvalueFloat64Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([4, 2, 8, 3], torch.float64, True)])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.kthvalue(x, k=3, dim=0, keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenKthvalueFloat64Module())
|
||||
def AtenKthvalueFloat64Module_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.randperm(4 * 2 * 8 * 3, dtype=torch.float64).reshape(4, 2, 8, 3)
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenKthvalueFloat64DynamicDimsModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-1, -1, -1, -1], torch.float64, True)])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.kthvalue(x, k=3, dim=3, keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenKthvalueFloat64DynamicDimsModule())
|
||||
def AtenKthvalueFloat64DynamicDimsModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.randperm(4 * 2 * 8 * 3, dtype=torch.float64).reshape(4, 2, 8, 3)
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue