mirror of https://github.com/llvm/torch-mlir

[Stablehlo] support aten.any.dim, aten.min.dim (#3500)

* refactor `TorchToStablehlo/Reduction.cpp`
* add `ConvertAtenReduceWithIndicesOp` patterns

branch: pull/3513/head
parent: 73ba09c587
commit: f9fc741eef
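For orientation, the two newly supported ops correspond to these PyTorch calls (a minimal sketch in the spirit of the e2e test added at the end of this diff; shapes and values are arbitrary):

    import torch

    x = torch.rand(3, 4, 5)
    # aten.any.dim: boolean OR-reduction over a single dimension.
    mask = torch.ops.aten.any(x > 0.5, dim=0)
    # aten.min.dim: returns (values, indices); handled by the new
    # ConvertAtenReduceWithIndicesOp pattern introduced below.
    values, indices = torch.min(x, dim=1)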
@@ -30,6 +30,18 @@ using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 using namespace mlir::torch::torch_to_stablehlo;
 
+static SmallVector<int64_t> getReduceOutputShape(ArrayRef<int64_t> inputShape,
+                                                 ArrayRef<int64_t> dims) {
+  std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
+  SmallVector<int64_t> reduceResultShape;
+  for (size_t i = 0; i < inputShape.size(); i++) {
+    if (dimsSet.find(i) == dimsSet.end()) {
+      reduceResultShape.push_back(inputShape[i]);
+    }
+  }
+  return reduceResultShape;
+}
+
 static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
                                            PatternRewriter &rewriter) {
   auto constType = RankedTensorType::get({}, elementTy);
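The new getReduceOutputShape helper centralizes a shape computation that several patterns previously inlined: drop every reduced dimension from the static shape. Its behavior in Python terms (a sketch; the C++ callers pass already-normalized, non-negative dims):

    def get_reduce_output_shape(input_shape, dims):
        # Keep every dimension whose index is not being reduced.
        dims_set = set(dims)
        return [size for i, size in enumerate(input_shape) if i not in dims_set]

    assert get_reduce_output_shape([3, 4, 5], [1]) == [3, 5]
    assert get_reduce_output_shape([3, 4, 5], [0, 2]) == [4]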
@@ -42,8 +54,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
                                          /*negative=*/false)});
     return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                   constAttr);
-  } else if (isa<mlir::IntegerType>(elementTy) &&
-             elementTy.getIntOrFloatBitWidth() != 8) {
+  } else if (isa<mlir::IntegerType>(elementTy)) {
     auto constAttr = DenseElementsAttr::get(
         constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
     return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
@@ -59,8 +70,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
                                          /*negative=*/true)});
     return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                   constAttr);
-  } else if (isa<mlir::IntegerType>(elementTy) &&
-             elementTy.getIntOrFloatBitWidth() != 8) {
+  } else if (isa<mlir::IntegerType>(elementTy)) {
    auto constAttr = DenseElementsAttr::get(
        constType,
        {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
@@ -69,7 +79,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
     }
   }
 
-  if (isa<AtenAminOp, AtenMinOp>(op)) {
+  if (isa<AtenAminOp, AtenMinOp, AtenMinDimOp, AtenArgminOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       auto constAttr = DenseElementsAttr::get(
           constType,
@@ -77,8 +87,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
                                          /*negative=*/false)});
     return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                   constAttr);
-  } else if (isa<mlir::IntegerType>(elementTy) &&
-             elementTy.getIntOrFloatBitWidth() != 8) {
+  } else if (isa<mlir::IntegerType>(elementTy)) {
    auto constAttr = DenseElementsAttr::get(
        constType,
        {APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth())});
@@ -93,8 +102,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
     auto constAttr = DenseElementsAttr::get(constType, one);
     return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                   constAttr);
-  } else if (isa<mlir::IntegerType>(elementTy) &&
-             elementTy.getIntOrFloatBitWidth() != 8) {
+  } else if (isa<mlir::IntegerType>(elementTy)) {
     APInt one(elementTy.getIntOrFloatBitWidth(), 1);
     auto constAttr = DenseElementsAttr::get(constType, one);
     return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
@@ -103,13 +111,15 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
   }
 
   if (isa<AtenAllOp>(op)) {
-    auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 1)});
+    auto constAttr =
+        DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)});
     return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                   constAttr);
   }
 
-  if (isa<AtenAnyOp>(op)) {
-    auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 0)});
+  if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
+    auto constAttr =
+        DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 0)});
     return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                   constAttr);
   }
@@ -149,16 +159,17 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
   if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp>(op)) {
     result = rewriter.create<stablehlo::MaxOp>(
         op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-  } else if (isa<AtenAminOp, AtenMinOp>(op)) {
+  } else if (isa<AtenAminOp, AtenMinOp, AtenMinDimOp>(op)) {
     result = rewriter.create<stablehlo::MinOp>(
         op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-  } else if (isa<AtenSumOp>(op)) {
+  } else if (isa<AtenSumOp, AtenSumDimIntListOp, AtenFrobeniusNormDimOp,
+                 AtenLinalgVectorNormOp>(op)) {
     result = rewriter.create<stablehlo::AddOp>(
         op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
   } else if (isa<AtenAllOp>(op)) {
     result = rewriter.create<stablehlo::AndOp>(
         op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-  } else if (isa<AtenAnyOp>(op)) {
+  } else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
     result = rewriter.create<stablehlo::OrOp>(
         op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
   } else if (isa<AtenProdOp>(op)) {
@@ -174,11 +185,11 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
   return reduce.getResults()[0];
 }
 
-// Util for converting AtenArgmaxOp and AtenMaxDimOp
+// Util for converting AtenMaxDimOp/AtenMinDimOp
 static std::optional<ValueRange>
-getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
-            ArrayRef<Value> inputShapeVec, int64_t dim,
-            size_t dimSizeIndexBits) {
+createReduceOpReturnIndices(ConversionPatternRewriter &rewriter, Operation *op,
+                            Value &input, ArrayRef<Value> inputShapeVec,
+                            int64_t dim, size_t dimSizeIndexBits) {
   auto inputTy = cast<RankedTensorType>(input.getType());
   if (!inputTy) {
     return std::nullopt;
@@ -199,8 +210,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
     initIndex = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
   }
 
-  std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
-  outputShape.erase(outputShape.begin() + dim);
+  auto outputShape = getReduceOutputShape(inputShape, {dim});
   auto outputTy = RankedTensorType::get(outputShape, inputElemTy);
   auto outputIndexTy =
       RankedTensorType::get(outputShape, rewriter.getIntegerType(64));
@@ -252,6 +262,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
   stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
       stablehlo::ComparisonDirectionAttr::get(
           rewriter.getContext(), stablehlo::ComparisonDirection::GE);
+  stablehlo::ComparisonDirectionAttr compareLeDirectionAttr =
+      stablehlo::ComparisonDirectionAttr::get(
+          rewriter.getContext(), stablehlo::ComparisonDirection::LE);
   stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
       stablehlo::ComparisonDirectionAttr::get(
           rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
@@ -260,11 +273,21 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointToStart(&block);
 
-    Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
-        op->getLoc(), compareResultType, *firstValArg, *secondValArg,
-        compareGeDirectionAttr, compareTypeAttr);
+    Value compareResult;
+    if (isa<AtenMaxDimOp>(op)) {
+      compareResult = rewriter.create<stablehlo::CompareOp>(
+          op->getLoc(), compareResultType, *firstValArg, *secondValArg,
+          compareGeDirectionAttr, compareTypeAttr);
+    } else if (isa<AtenMinDimOp>(op)) {
+      compareResult = rewriter.create<stablehlo::CompareOp>(
+          op->getLoc(), compareResultType, *firstValArg, *secondValArg,
+          compareLeDirectionAttr, compareTypeAttr);
+    } else {
+      op->emitError("unimplement lowering of createReduceOpReturnIndices");
+      return std::nullopt;
+    }
     Value retValResult = rewriter.create<stablehlo::SelectOp>(
-        op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
+        op->getLoc(), compareResult, *firstValArg, *secondValArg);
 
     // get smaller index value if compared nums are equal.
     Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
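The comparator region now picks GE for aten.max.dim and LE for aten.min.dim, while keeping the tie-breaking rule that equal values yield the smaller index. The same reduction step in Python terms (a sketch of the semantics, not the MLIR API):

    def reduce_pair(lhs, rhs, use_ge):
        # lhs and rhs are (value, index) pairs flowing through the reduction;
        # use_ge=True mirrors ComparisonDirection::GE (max.dim),
        # use_ge=False mirrors ComparisonDirection::LE (min.dim).
        keep_lhs = lhs[0] >= rhs[0] if use_ge else lhs[0] <= rhs[0]
        value = lhs[0] if keep_lhs else rhs[0]
        index = lhs[1] if keep_lhs else rhs[1]
        if lhs[0] == rhs[0]:  # equal values: the smaller index wins
            index = min(lhs[1], rhs[1])
        return value, index

    assert reduce_pair((7, 2), (7, 0), use_ge=True) == (7, 0)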
@@ -273,16 +296,35 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
     Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
                                                      *secondIdxArg);
     Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
-        op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
+        op->getLoc(), compareResult, *firstIdxArg, *secondIdxArg);
     Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
         op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
 
     rewriter.create<stablehlo::ReturnOp>(
-        op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
+        op->getLoc(), ValueRange{retValResult, retIdxResult});
   }
   return stablehloReduceOp.getResults();
 }
 
+static Value reshapeReduceResultWhenKeepDim(ConversionPatternRewriter &rewriter,
+                                            Location loc, Value reduceResult,
+                                            ArrayRef<Value> inputShapeVec,
+                                            Type outType,
+                                            ArrayRef<int64_t> dims,
+                                            size_t dimSizeIndexBits) {
+  SmallVector<Value> outShapeVec(inputShapeVec);
+  Value one = rewriter.create<arith::ConstantOp>(
+      loc,
+      rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1));
+  for (auto dim : dims) {
+    outShapeVec[dim] = one;
+  }
+  auto outShapeTensor =
+      rewriter.create<tensor::FromElementsOp>(loc, outShapeVec);
+  return rewriter.create<stablehlo::DynamicReshapeOp>(
+      loc, outType, reduceResult, outShapeTensor);
+}
+
 namespace {
 template <typename AtenOpT>
 class ConvertAtenReductionOp : public ConvertAtenOp<AtenOpT> {
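The new reshapeReduceResultWhenKeepDim helper restores each reduced dimension as a size-1 axis via stablehlo.dynamic_reshape. The observable keepdim semantics it implements, in PyTorch terms:

    import torch

    x = torch.rand(3, 4, 5)
    assert torch.min(x, dim=1).values.shape == (3, 5)
    assert torch.min(x, dim=1, keepdim=True).values.shape == (3, 1, 5)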
@@ -320,14 +362,6 @@ public:
       return op.emitError(
           "only floating-point or integer datatype legalization supported");
     }
-    // Currently, (u)int8 dtype is not supported
-    if (isa<mlir::IntegerType>(inputElemTy) &&
-        inputElemTy.getIntOrFloatBitWidth() == 8) {
-      return rewriter.notifyMatchFailure(
-          op,
-          "IntegerType with bitwidth 8 unsupported in convertion to StableHLO");
-    }
-
     if (inputElemTy != outTy.getElementType()) {
       // use output type as computation type
       input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input,
@@ -347,7 +381,7 @@ public:
 };
 
 template <typename AtenOpT>
-class ConvertAtenReduceKeepDimOp : public ConvertAtenReductionOp<AtenOpT> {
+class ConvertAtenReduceOneDimOp : public ConvertAtenReductionOp<AtenOpT> {
 public:
   using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
   using OpAdaptor = typename AtenOpT::Adaptor;
@@ -356,7 +390,10 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     Value input = adaptor.getSelf();
     auto inputTy = dyn_cast<RankedTensorType>(input.getType());
-    if (!inputTy) {
+    auto outTy = dyn_cast<RankedTensorType>(
+        ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));
+    if (!inputTy || !outTy) {
       return rewriter.notifyMatchFailure(
           op, "only Tensor types supported in StableHLO");
     }
@@ -366,12 +403,78 @@ public:
       return op.emitError(
           "only floating-point or integer datatype legalization supported");
     }
-    // Currently, (u)int8 dtype is not supported
-    if (isa<mlir::IntegerType>(inputElemTy) &&
-        inputElemTy.getIntOrFloatBitWidth() == 8) {
+    if (inputElemTy != outTy.getElementType()) {
+      // use output type as computation type
+      input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input,
+                                                    outTy.getElementType());
+    }
+
+    bool keepDim = false;
+    if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
+      return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
+    }
+
+    int64_t dim;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
       return rewriter.notifyMatchFailure(
-          op,
-          "IntegerType with bitwidth 8 unsupported in convertion to StableHLO");
+          op, "non-const integer `dim` is not supported");
     }
+    dim = toPositiveDim(dim, inputTy.getRank());
+    SmallVector<int64_t> reduceResultShape =
+        getReduceOutputShape(inputTy.getShape(), {dim});
+
+    Value reduceResult = createReduceOpWithSingleRegionOp(
+        op, input,
+        RankedTensorType::get(reduceResultShape, outTy.getElementType()), {dim},
+        rewriter);
+    if (!reduceResult) {
+      return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
+    }
+
+    if (keepDim) {
+      const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
+      auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input,
+                                                   options.dimSizeIndexBits);
+      if (failed(outShapeInfo)) {
+        return rewriter.notifyMatchFailure(
+            op, "failed to get dimension sizes of the input");
+      }
+      reduceResult = reshapeReduceResultWhenKeepDim(
+          rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, {dim},
+          options.dimSizeIndexBits);
+    }
+    rewriter.replaceOp(op, reduceResult);
+    return success();
+  }
+};
+
+template <typename AtenOpT>
+class ConvertAtenReduceDimsOp : public ConvertAtenReductionOp<AtenOpT> {
+public:
+  using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input = adaptor.getSelf();
+    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+    auto outTy = dyn_cast<RankedTensorType>(
+        ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));
+    if (!inputTy || !outTy) {
+      return rewriter.notifyMatchFailure(
+          op, "only Tensor types supported in StableHLO");
+    }
+
+    auto inputElemTy = inputTy.getElementType();
+    if (!inputElemTy.isIntOrFloat()) {
+      return op.emitError(
+          "only floating-point or integer datatype legalization supported");
+    }
+    if (inputElemTy != outTy.getElementType()) {
+      // use output type as computation type
+      input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input,
+                                                    outTy.getElementType());
+    }
 
     bool keepDim = false;
@@ -393,19 +496,16 @@ public:
       }
     }
     llvm::sort(dims.begin(), dims.end());
-    std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
-    SmallVector<int64_t> reduceResultShape;
-    for (int64_t i = 0; i < inputTy.getRank(); i++) {
-      if (dimsSet.find(i) == dimsSet.end()) {
-        reduceResultShape.push_back(inputTy.getDimSize(i));
-      }
-    }
+    SmallVector<int64_t> reduceResultShape =
+        getReduceOutputShape(inputTy.getShape(), dims);
 
     Value reduceResult = createReduceOpWithSingleRegionOp(
-        op, input, RankedTensorType::get(reduceResultShape, inputElemTy), dims,
+        op, input,
+        RankedTensorType::get(reduceResultShape, outTy.getElementType()), dims,
         rewriter);
-    if (!reduceResult)
+    if (!reduceResult) {
      return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
+    }
 
     if (keepDim) {
       const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
@@ -415,215 +515,104 @@ public:
         return rewriter.notifyMatchFailure(
             op, "failed to get dimension sizes of the input");
       }
-      auto outShapeVec = *outShapeInfo;
-      auto one = rewriter.create<mlir::arith::ConstantOp>(
-          op->getLoc(),
-          rewriter.getIntegerAttr(
-              rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-      for (int64_t i : dims) {
-        outShapeVec[i] = one;
-      }
-      auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-          op->getLoc(), outShapeVec);
-      rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
-          op,
-          ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
-              op.getType()),
-          reduceResult, outShapeTensor);
-      return success();
+      reduceResult = reshapeReduceResultWhenKeepDim(
+          rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims,
+          options.dimSizeIndexBits);
     }
     rewriter.replaceOp(op, reduceResult);
     return success();
   }
 };
 } // namespace
 
-// AtenArgmaxOp
 namespace {
-template <>
-LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
-    AtenArgmaxOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value input = adaptor.getSelf();
-  auto inputTy = cast<RankedTensorType>(input.getType());
-  if (!inputTy) {
-    return rewriter.notifyMatchFailure(
-        op, "only Tensor types supported in StableHLO");
-  }
-
-  auto inputElemTy = inputTy.getElementType();
-  if (!inputElemTy.isIntOrFloat()) {
-    return op.emitError(
-        "only floating-point or integer datatype legalization supported");
-  }
-  // Currently, (u)int8 dtype is not supported!
-  if (isa<mlir::IntegerType>(inputElemTy) &&
-      inputElemTy.getIntOrFloatBitWidth() == 8) {
-    return rewriter.notifyMatchFailure(
-        op, "IntegerType with bitwidth 8 unsupported in convertion from "
-            "AtenArgmaxOp to StableHLO");
-  }
-
-  int64_t dim;
-  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
-    return rewriter.notifyMatchFailure(op, "non-int dim unsupported");
-  }
-  dim = toPositiveDim(dim, inputTy.getRank());
-  if (!isValidDim(dim, inputTy.getRank())) {
-    return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
-  }
-
-  bool keepDim = false;
-  if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
-    return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
-  }
-
-  const auto &options = getOptions();
-  auto inputShapeInfo =
-      hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
-  if (failed(inputShapeInfo)) {
-    return rewriter.notifyMatchFailure(
-        op, "failed to get dimension sizes of the input");
-  }
-  auto inputShapeVec = *inputShapeInfo;
-  auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
-                                            dim, options.dimSizeIndexBits)
-                                    .value();
-
-  if (keepDim) {
-    auto outShapeVec = inputShapeVec;
-    outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(),
-        rewriter.getIntegerAttr(
-            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-
-    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-        op->getLoc(), outShapeVec);
-    rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
-        op, typeConverter->convertType(op.getType()), stablehloReduceResults[1],
-        outShapeTensor);
-    return success();
-  }
-
-  rewriter.replaceOp(op, stablehloReduceResults[1]);
-  return success();
-}
-} // namespace
-
-// AtenMaxDimOp
-namespace {
-template <>
-LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
-    AtenMaxDimOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value input = adaptor.getSelf();
-  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
-  if (!inputTy) {
-    return rewriter.notifyMatchFailure(
-        op, "only Tensor types supported in StableHLO");
-  }
-  auto inputElemTy = inputTy.getElementType();
-  if (!inputElemTy.isIntOrFloat()) {
-    return op.emitError(
-        "Only floating-point or integer datatype legalization supported");
-  }
-  // Currently, (u)int8 dtype is not supported
-  if (isa<mlir::IntegerType>(inputElemTy) &&
-      inputElemTy.getIntOrFloatBitWidth() == 8) {
-    return rewriter.notifyMatchFailure(
-        op, "IntegerType with bitwidth 8 unsupported in convertion from "
-            "AtenMaxDimOp to StableHLO");
-  }
-
-  RankedTensorType valResultType = cast<RankedTensorType>(
-      getTypeConverter()->convertType(op.getResult(0).getType()));
-  RankedTensorType idxResultType = cast<RankedTensorType>(
-      getTypeConverter()->convertType(op.getResult(1).getType()));
-  Type idxElementType = idxResultType.getElementType();
-  if (!isa<mlir::IntegerType>(idxElementType)) {
-    return op.emitError("Aten.max.dim needs integer-like result");
-  }
-
-  int64_t dim;
-  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
-    return rewriter.notifyMatchFailure(op, "non-int dim unsupported");
-  }
-  dim = toPositiveDim(dim, inputTy.getRank());
-  if (!isValidDim(dim, inputTy.getRank())) {
-    return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
-  }
-  bool keepDim = false;
-  if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
-    return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
-  }
-
-  const auto &options = getOptions();
-  auto inputShapeInfo =
-      hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
-  if (failed(inputShapeInfo)) {
-    return rewriter.notifyMatchFailure(
-        op, "failed to get dimension sizes of the input");
-  }
-  auto inputShapeVec = *inputShapeInfo;
-
-  if (op.getResult(1).use_empty()) {
-    llvm::SmallVector<int64_t> outputShape(inputTy.getShape());
-    outputShape.erase(outputShape.begin() + dim);
-    Value reduceResult = createReduceOpWithSingleRegionOp(
-        op, input, RankedTensorType::get(outputShape, inputElemTy),
-        ArrayRef<int64_t>{dim}, rewriter);
-    if (!reduceResult)
-      return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
-
-    if (keepDim) {
-      auto outShapeVec = inputShapeVec;
-      outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
-          op->getLoc(),
-          rewriter.getIntegerAttr(
-              rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-      auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-          op->getLoc(), outShapeVec);
-
-      auto stablehloReduceValueResult =
-          rewriter.create<stablehlo::DynamicReshapeOp>(
-              op->getLoc(), valResultType, reduceResult, outShapeTensor);
-      rewriter.replaceOp(op, {stablehloReduceValueResult, Value()});
-      return success();
-    }
-    rewriter.replaceOp(op, {reduceResult, Value()});
-    return success();
-  } else {
-    auto stablehloReduceResults =
-        getMaxInDim(rewriter, op, input, inputShapeVec, dim,
-                    options.dimSizeIndexBits)
-            .value();
-
-    if (keepDim) {
-      auto outShapeVec = inputShapeVec;
-      outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
-          op->getLoc(),
-          rewriter.getIntegerAttr(
-              rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-      auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-          op->getLoc(), outShapeVec);
-
-      auto stablehloReduceValueResult =
-          rewriter.create<stablehlo::DynamicReshapeOp>(
-              op->getLoc(), valResultType, stablehloReduceResults[0],
-              outShapeTensor);
-      auto stablehloReduceIndexResult =
-          rewriter.create<stablehlo::DynamicReshapeOp>(
-              op->getLoc(), idxResultType, stablehloReduceResults[1],
-              outShapeTensor);
-      rewriter.replaceOp(
-          op, {stablehloReduceValueResult, stablehloReduceIndexResult});
-      return success();
-    }
-    rewriter.replaceOp(op,
-                       {stablehloReduceResults[0], stablehloReduceResults[1]});
-    return success();
-  }
-}
+template <typename AtenOpT>
+class ConvertAtenReduceWithIndicesOp : public ConvertAtenReductionOp<AtenOpT> {
+public:
+  using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input = adaptor.getSelf();
+    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+    if (!inputTy) {
+      return rewriter.notifyMatchFailure(
+          op, "only Tensor types supported in StableHLO");
+    }
+    auto inputElemTy = inputTy.getElementType();
+    if (!inputElemTy.isIntOrFloat()) {
+      return op.emitError(
+          "Only floating-point or integer datatype legalization supported");
+    }
+
+    RankedTensorType valResultType = cast<RankedTensorType>(
+        ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
+            op.getResult(0).getType()));
+    RankedTensorType idxResultType = cast<RankedTensorType>(
+        ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
+            op.getResult(1).getType()));
+    Type idxElementType = idxResultType.getElementType();
+    if (!isa<mlir::IntegerType>(idxElementType)) {
+      return op.emitError("indices result should to be integer tyep");
+    }
+
+    int64_t dim;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
+      return rewriter.notifyMatchFailure(op, "non-int dim unsupported");
+    }
+    dim = toPositiveDim(dim, inputTy.getRank());
+    if (!isValidDim(dim, inputTy.getRank())) {
+      return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
+    }
+    bool keepDim = false;
+    if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
+      return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
+    }
+
+    const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
+    auto inputShapeInfo =
+        hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
+    if (failed(inputShapeInfo)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to get dimension sizes of the input");
+    }
+    auto inputShapeVec = *inputShapeInfo;
+
+    if (op.getResult(1).use_empty()) {
+      llvm::SmallVector<int64_t> outputShape(inputTy.getShape());
+      outputShape.erase(outputShape.begin() + dim);
+      Value reduceResult = createReduceOpWithSingleRegionOp(
+          op, input, RankedTensorType::get(outputShape, inputElemTy),
+          ArrayRef<int64_t>{dim}, rewriter);
+      if (!reduceResult) {
+        return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
+      }
+
+      if (keepDim) {
+        reduceResult = reshapeReduceResultWhenKeepDim(
+            rewriter, op->getLoc(), reduceResult, inputShapeVec, valResultType,
+            {dim}, options.dimSizeIndexBits);
+      }
+      rewriter.replaceOp(op, {reduceResult, Value()});
+      return success();
+    } else {
+      ValueRange stablehloReduceResults =
+          createReduceOpReturnIndices(rewriter, op, input, inputShapeVec, dim,
+                                      options.dimSizeIndexBits)
+              .value();
+      if (keepDim) {
+        stablehloReduceResults[0] = reshapeReduceResultWhenKeepDim(
+            rewriter, op->getLoc(), stablehloReduceResults[0], inputShapeVec,
+            valResultType, {dim}, options.dimSizeIndexBits);
+        stablehloReduceResults[1] = reshapeReduceResultWhenKeepDim(
+            rewriter, op->getLoc(), stablehloReduceResults[1], inputShapeVec,
+            idxResultType, {dim}, options.dimSizeIndexBits);
+      }
+      rewriter.replaceOp(
+          op, {stablehloReduceResults[0], stablehloReduceResults[1]});
+      return success();
+    }
+  }
+};
 } // namespace
 
 // AtenSumDimIntListOp
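One detail worth noting in the new ConvertAtenReduceWithIndicesOp: when the indices result has no uses, the pattern emits a plain value-only single-region reduce instead of the two-result variadic reduce with an index operand. In PyTorch terms (sketch):

    import torch

    x = torch.rand(3, 4)
    values, _ = torch.min(x, dim=1)  # indices never used downstream;
    # op.getResult(1).use_empty() is true in that case, so the lowering
    # can skip computing indices entirely.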
@@ -653,17 +642,8 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
         "Only floating-point or integer datatype legalization supported");
   }
 
-  // Currently, (u)int8 dtype is not supported
-  if (isa<mlir::IntegerType>(inputElemTy) &&
-      inputElemTy.getIntOrFloatBitWidth() == 8) {
-    return rewriter.notifyMatchFailure(
-        op, "IntegerType with bitwidth 8 unsupported in convertion from "
-            "AtenSumDimIntListOp to StableHLO");
-  }
-
   SmallVector<int64_t> inputDims;
   SmallVector<int64_t> dims;
-
   if (failed(checkNotNone(rewriter, op, op.getDim()))) {
     inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
   } else {
@@ -675,7 +655,6 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
       inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
     }
   }
-
   for (auto d : inputDims) {
     d = toPositiveDim(d, inputTy.getRank());
     // Drop invalid dims
@@ -683,46 +662,22 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
       dims.push_back(d);
     }
   }
-  llvm::sort(dims.begin(), dims.end());
-
-  std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
-  SmallVector<int64_t> reduceResultShape;
-  for (int64_t i = 0; i < inputTy.getRank(); i++) {
-    if (dimsSet.find(i) == dimsSet.end()) {
-      reduceResultShape.push_back(inputTy.getDimSize(i));
-    }
-  }
+  SmallVector<int64_t> reduceResultShape =
+      getReduceOutputShape(inputTy.getShape(), dims);
 
   bool keepDim = false;
   if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
     return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
   }
-  Value initValue =
-      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
-  if (!initValue)
-    return failure();
-
-  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op.getLoc(),
-      RankedTensorType::get(reduceResultShape, outTy.getElementType()), input,
-      initValue, rewriter.getDenseI64ArrayAttr(dims));
-
-  Region &region = stablehloReduceOp.getBody();
-  Block &block = region.emplaceBlock();
-  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
-
-  block.addArgument(blockArgumentTy, op->getLoc());
-  block.addArgument(blockArgumentTy, op->getLoc());
-
-  auto *firstArgument = block.args_begin();
-  auto secondArgument = block.args_rbegin();
-
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-    Value addResult = rewriter.create<stablehlo::AddOp>(
-        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
+  llvm::sort(dims.begin(), dims.end());
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, input,
+      RankedTensorType::get(reduceResultShape, outTy.getElementType()), dims,
+      rewriter);
+  if (!reduceResult) {
+    return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
   }
 
   if (keepDim) {
@@ -733,23 +688,11 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
       return rewriter.notifyMatchFailure(
           op, "failed to get dimension sizes of the input");
     }
-    auto outShapeVec = *outShapeInfo;
-    auto one = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(),
-        rewriter.getIntegerAttr(
-            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-    for (int64_t i : dims) {
-      outShapeVec[i] = one;
-    }
-    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-        op->getLoc(), outShapeVec);
-    rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
-        op, getTypeConverter()->convertType(op.getType()),
-        stablehloReduceOp.getResult(0), outShapeTensor);
-    return success();
+    reduceResult = reshapeReduceResultWhenKeepDim(
+        rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims,
+        options.dimSizeIndexBits);
   }
-  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
-                                              stablehloReduceOp.getResults());
+  rewriter.replaceOp(op, reduceResult);
   return success();
 }
 } // namespace
@@ -789,18 +732,12 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
                                        "invalid dimension detected in `dim`");
     }
   }
 
   // Sort the dims in ascending order, making the conversion
   // stable with unordered dims.
   std::sort(dims.begin(), dims.end());
 
-  std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
-  SmallVector<int64_t> reduceResultShape;
-  for (int64_t i = 0; i < inputRank; i++) {
-    if (dimsSet.find(i) == dimsSet.end()) {
-      reduceResultShape.push_back(inputType.getDimSize(i));
-    }
-  }
+  SmallVector<int64_t> reduceResultShape =
+      getReduceOutputShape(inputType.getShape(), dims);
 
   bool keepDim = false;
   if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
@@ -810,36 +747,14 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
 
   auto squareOp = rewriter.create<stablehlo::MulOp>(op->getLoc(), input, input);
 
-  auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter);
-  if (!initValue) {
-    return failure();
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, squareOp.getResult(),
+      RankedTensorType::get(reduceResultShape, inputElemType), dims, rewriter);
+  if (!reduceResult) {
+    return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
   }
 
-  auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op->getLoc(), RankedTensorType::get(reduceResultShape, inputElemType),
-      squareOp.getResult(), initValue, rewriter.getDenseI64ArrayAttr(dims));
-
-  Region &region = reduceOp.getBody();
-  Block &block = region.emplaceBlock();
-  auto blockArgumentTy = RankedTensorType::get({}, inputElemType);
-
-  block.addArgument(blockArgumentTy, op->getLoc());
-  block.addArgument(blockArgumentTy, op->getLoc());
-
-  auto firstArgument = *block.args_begin();
-  auto secondArgument = *block.args_rbegin();
-
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-
-    auto addResult = rewriter.create<stablehlo::AddOp>(
-        op->getLoc(), firstArgument, secondArgument);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult.getResult());
-  }
-
-  auto output =
-      rewriter.create<stablehlo::SqrtOp>(op->getLoc(), reduceOp.getResult(0));
+  Value output = rewriter.create<stablehlo::SqrtOp>(op->getLoc(), reduceResult);
 
   if (keepDim) {
     auto outShapeInfo =
@@ -848,22 +763,12 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
       return rewriter.notifyMatchFailure(
           op, "failed to get dimension sizes of the input");
     }
-    auto outShapeVec = *outShapeInfo;
-    auto one = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(),
-        rewriter.getIntegerAttr(
-            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-    for (int64_t i : dims) {
-      outShapeVec[i] = one;
-    }
-    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-        op->getLoc(), outShapeVec);
-    rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
-        op, getTypeConverter()->convertType(op.getType()), output,
-        outShapeTensor);
-    return success();
+    output = reshapeReduceResultWhenKeepDim(
+        rewriter, op->getLoc(), output, *outShapeInfo,
+        getTypeConverter()->convertType(op.getType()), dims,
+        options.dimSizeIndexBits);
   }
-  rewriter.replaceOp(op, output.getResult());
+  rewriter.replaceOp(op, output);
   return success();
 }
 } // namespace
@@ -920,13 +825,8 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
     std::sort(dims.begin(), dims.end());
   }
 
-  std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
-  SmallVector<int64_t> reduceResultShape;
-  for (int64_t i = 0; i < inputType.getRank(); i++) {
-    if (dimsSet.find(i) == dimsSet.end()) {
-      reduceResultShape.push_back(inputType.getDimSize(i));
-    }
-  }
+  SmallVector<int64_t> reduceResultShape =
+      getReduceOutputShape(inputType.getShape(), dims);
 
   bool keepDim = false;
   if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
@@ -934,46 +834,27 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
         op, "non-const bool `keepdim` is not supported");
   }
 
-  auto initValue = createInitialValueForReduceOp(op, outElemType, rewriter);
-  if (!initValue) {
-    return failure();
-  }
-
   Value absValue = rewriter.create<stablehlo::AbsOp>(op->getLoc(), input);
   Value powValue = rewriter.create<chlo::BroadcastPowOp>(op->getLoc(), absValue,
                                                          ord, nullptr);
 
-  auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op->getLoc(), RankedTensorType::get(reduceResultShape, outElemType),
-      powValue, initValue, rewriter.getDenseI64ArrayAttr(dims));
-
-  Region &region = reduceOp.getBody();
-  Block &block = region.emplaceBlock();
-  auto blockArgumentTy = RankedTensorType::get({}, outElemType);
-
-  block.addArgument(blockArgumentTy, op->getLoc());
-  block.addArgument(blockArgumentTy, op->getLoc());
-
-  auto firstArgument = *block.args_begin();
-  auto secondArgument = *block.args_rbegin();
-
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-
-    auto addResult = rewriter.create<stablehlo::AddOp>(
-        op->getLoc(), firstArgument, secondArgument);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult.getResult());
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, powValue, RankedTensorType::get(reduceResultShape, outElemType), dims,
+      rewriter);
+  if (!reduceResult) {
+    return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
   }
 
+  auto scalarType = RankedTensorType::get({}, outElemType);
   auto constantOne = rewriter.create<stablehlo::ConstantOp>(
-      op->getLoc(), blockArgumentTy,
+      op->getLoc(), scalarType,
       DenseElementsAttr::get(
-          blockArgumentTy,
+          scalarType,
           APFloat(cast<mlir::FloatType>(outElemType).getFloatSemantics(), 1)));
   auto reciprocalOrd = rewriter.create<stablehlo::DivOp>(
-      op->getLoc(), blockArgumentTy, constantOne, ord);
-  auto output = rewriter.create<chlo::BroadcastPowOp>(
-      op->getLoc(), reduceOp.getResult(0), reciprocalOrd, nullptr);
+      op->getLoc(), scalarType, constantOne, ord);
+  Value output = rewriter.create<chlo::BroadcastPowOp>(
+      op->getLoc(), reduceResult, reciprocalOrd, nullptr);
 
   if (keepDim) {
     auto outShapeInfo =
@@ -982,23 +863,11 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
       return rewriter.notifyMatchFailure(
          op, "failed to get dimension sizes of the input");
     }
-    auto outShapeVec = *outShapeInfo;
-    auto one = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(),
-        rewriter.getIntegerAttr(
-            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-    for (int64_t i : dims) {
-      outShapeVec[i] = one;
-    }
-    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-        op->getLoc(), outShapeVec);
-    rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
-        op, getTypeConverter()->convertType(op.getType()), output,
-        outShapeTensor);
-    return success();
+    output = reshapeReduceResultWhenKeepDim(rewriter, op->getLoc(), output,
+                                            *outShapeInfo, outType, dims,
+                                            options.dimSizeIndexBits);
   }
 
-  rewriter.replaceOp(op, output.getResult());
+  rewriter.replaceOp(op, output);
   return success();
 }
 } // namespace
@@ -1010,9 +879,6 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
 #define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp)                               \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
-  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp);
-  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp);
@@ -1022,7 +888,6 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenReduceAllDimsOp<AtenOp>>(typeConverter, context,     \
                                                    options)
-
   INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMaxOp);
   INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMinOp);
   INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenSumOp);
@@ -1031,12 +896,25 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
   INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenAnyOp);
 #undef INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN
 
-#define INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenOp)                      \
+#define INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenOp)                       \
   target.addIllegalOp<AtenOp>();                                               \
-  patterns.add<ConvertAtenReduceKeepDimOp<AtenOp>>(typeConverter, context,     \
-                                                   options)
+  patterns.add<ConvertAtenReduceOneDimOp<AtenOp>>(typeConverter, context,      \
+                                                  options)
+  INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAnyDimOp);
+#undef INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN
 
-  INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenAmaxOp);
-  INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenAminOp);
-#undef INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN
+#define INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenOp)                          \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenReduceDimsOp<AtenOp>>(typeConverter, context, options)
+  INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenAmaxOp);
+  INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenAminOp);
+#undef INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN
 
+#define INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN(AtenOp)                     \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenReduceWithIndicesOp<AtenOp>>(typeConverter, context, \
+                                                       options)
+  INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN(AtenMaxDimOp);
+  INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN(AtenMinDimOp);
+#undef INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN
 }
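The registration block now routes ops to four pattern families. A summary of what is visible in this diff (registrations between the shown hunks, e.g. for AtenProdOp and AtenAllOp, are elided here):

    # pattern class -> aten ops it legalizes, as registered above
    ROUTING = {
        "ConvertAtenReduceAllDimsOp": ["aten.max", "aten.min", "aten.sum", "aten.any"],
        "ConvertAtenReduceOneDimOp": ["aten.any.dim"],
        "ConvertAtenReduceDimsOp": ["aten.amax", "aten.amin"],
        "ConvertAtenReduceWithIndicesOp": ["aten.max.dim", "aten.min.dim"],
    }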
(Second file in the commit: the e2e test expectation sets.)
@@ -32,6 +32,7 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
     # unimplemented lowering torch -> linalg for torchvision.deform_conv2d
     # this is added to check the torch.onnx.export -> import_onnx -> torch path
     "DeformConv2D_basic",
+    "ReduceAnyDimFloatModule_basic",
 }
 
 LINALG_CRASHING_SET = {
@@ -340,6 +341,7 @@ TORCHDYNAMO_CRASHING_SET = {
 }
 
 FX_IMPORTER_XFAIL_SET = {
+    "ReduceAnyDimFloatModule_basic",
     "AllBoolFalseModule_basic",
     "AllBoolTrueModule_basic",
     "AnyBoolFalseModule_basic",
@@ -502,7 +504,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "ArgminIntModule_multiple_mins",
     "ArgminModule_basic",
     "ArgminModule_keepDim",
-    "ArgminModule_with_dim",
     "AtenComplexImagModule_basic",
     "AtenComplexRealModule_basic",
     "AtenComplexViewModule_basic",
@@ -716,10 +717,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "ReduceAllDimFloat_basic",
     "ReduceAllDimInt_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
-    "ReduceMinAlongDimNegative_basic",
-    "ReduceMinAlongDimSignedInt_basic",
     "ReduceMinAlongDimUnsignedInt_basic",
-    "ReduceMinAlongDim_basic",
     "ReduceMinKeepDimReturnBoth_basic",
     "ReduceMinKeepDim_basic",
     "ReduceProdDimIntFloatModule_basic",
@@ -832,6 +830,11 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
 }
 
 STABLEHLO_PASS_SET = {
+    "ReduceMinAlongDimNegative_basic",
+    "ReduceMinAlongDim_basic",
+    "ArgminModule_with_dim",
+    "ReduceMinAlongDimSignedInt_basic",
+    "ReduceAnyDimFloatModule_basic",
     "MeshgridIndexingIJ_basic",
     "MeshgridIndexingXY_basic",
     "Meshgrid_basic",
@@ -2198,6 +2201,7 @@ ONNX_XFAIL_SET = {
     # Failure - cast error
     "PermuteNegativeIndexModule_basic",
     # Failure - incorrect numerics
+    "ReduceAnyDimFloatModule_basic",
     "AvgPool2dDivisorOverrideModule_basic",
     "BroadcastDynamicDimModule_basic",
     "ElementwiseAtan2TensorIntModule_basic",
(Third file in the commit: the reduction e2e test suite.)
@@ -239,6 +239,26 @@ def ReduceAnyFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5))
 
 
+class ReduceAnyDimFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.ops.aten.any(a, dim=0)
+
+
+@register_test_case(module_factory=lambda: ReduceAnyDimFloatModule())
+def ReduceAnyDimFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+
 # ==============================================================================
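To exercise the new test on its own through the StableHLO configuration, the torch-mlir e2e harness can filter by test name (an invocation sketch; the exact entry point and flag spellings depend on the checkout):

    # python -m e2e_testing.main --config=stablehlo --filter ReduceAnyDimFloatModule_basic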