@@ -71,7 +71,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
     }
   }
 
-  if (isa<AtenMinOp>(op)) {
+  if (isa<AtenAminOp, AtenMinOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       auto constAttr = DenseElementsAttr::get(
           constType,
@@ -151,6 +151,21 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
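+    // The reduction body combines the two block arguments with the binary op
+    // that matches the ATen reduction being lowered.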
     if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp>(op)) {
       result = rewriter.create<stablehlo::MaxOp>(
           op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    } else if (isa<AtenAminOp, AtenMinOp>(op)) {
+      result = rewriter.create<stablehlo::MinOp>(
+          op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    } else if (isa<AtenSumOp>(op)) {
+      result = rewriter.create<stablehlo::AddOp>(
+          op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    } else if (isa<AtenAllOp>(op)) {
+      result = rewriter.create<stablehlo::AndOp>(
+          op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    } else if (isa<AtenAnyOp>(op)) {
+      result = rewriter.create<stablehlo::OrOp>(
+          op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    } else if (isa<AtenProdOp>(op)) {
+      result = rewriter.create<stablehlo::MulOp>(
+          op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
     } else {
       op->emitError("unimplemented lowering in "
                     "createReduceOpWithSingleRegionOp");
@@ -278,7 +293,150 @@ public:
   using OpAdaptor = typename AtenOpT::Adaptor;
   LogicalResult
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
+                  ConversionPatternRewriter &rewriter) const override {
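+    // Unspecialized fallback: each supported op either specializes
+    // matchAndRewrite or is handled by one of the subclasses below.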
+    assert(false && "Unimplemented");
+    return failure();
+  };
 };
 
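+// Reduces every dimension of the input in one stablehlo.reduce (used for
+// aten.max/min/sum/prod/all/any with no dim argument).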
+template <typename AtenOpT>
+class ConvertAtenReduceAllDimsOp : public ConvertAtenReductionOp<AtenOpT> {
+public:
+  using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input = adaptor.getSelf();
+    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+    auto outTy = dyn_cast<RankedTensorType>(
+        ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));
+    if (!inputTy || !outTy) {
+      return rewriter.notifyMatchFailure(
+          op, "only Tensor types supported in StableHLO");
+    }
+
+    auto inputElemTy = inputTy.getElementType();
+    if (!inputElemTy.isIntOrFloat()) {
+      return op.emitError(
+          "only floating-point or integer datatype legalization supported");
+    }
+    // Currently, (u)int8 dtype is not supported
+    if (isa<mlir::IntegerType>(inputElemTy) &&
+        inputElemTy.getIntOrFloatBitWidth() == 8) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "IntegerType with bitwidth 8 unsupported in conversion to StableHLO");
+    }
+
+    if (inputElemTy != outTy.getElementType()) {
+      // use output type as computation type
+      input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input,
+                                                    outTy.getElementType());
+    }
+
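+    // A full reduction collapses every dimension: reduce over [0, rank).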
+    SmallVector<int64_t> dims =
+        llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank()));
+    Value result =
+        createReduceOpWithSingleRegionOp(op, input, outTy, dims, rewriter);
+    if (!result) {
+      return op->emitError("createReduceOpWithSingleRegionOp returned nullptr");
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
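+// Reduces an explicit list of dimensions and, when keepdim=true, restores the
+// reduced dimensions with size 1 (used for aten.amax/aten.amin).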
+template <typename AtenOpT>
+class ConvertAtenReduceKeepDimOp : public ConvertAtenReductionOp<AtenOpT> {
+public:
+  using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input = adaptor.getSelf();
+    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+    if (!inputTy) {
+      return rewriter.notifyMatchFailure(
+          op, "only Tensor types supported in StableHLO");
+    }
+
+    auto inputElemTy = inputTy.getElementType();
+    if (!inputElemTy.isIntOrFloat()) {
+      return op.emitError(
+          "only floating-point or integer datatype legalization supported");
+    }
+    // Currently, (u)int8 dtype is not supported
+    if (isa<mlir::IntegerType>(inputElemTy) &&
+        inputElemTy.getIntOrFloatBitWidth() == 8) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "IntegerType with bitwidth 8 unsupported in conversion to StableHLO");
+    }
+
+    bool keepDim = false;
+    if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
+      return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
+    }
+
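+    // `dim` must be a compile-time constant list; entries are normalized to
+    // positive indices and out-of-range entries are dropped.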
+    SmallVector<int64_t> inputDims;
+    SmallVector<int64_t> dims;
+    if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
+      return rewriter.notifyMatchFailure(
+          op, "non-const integer `dim` is not supported");
+    }
+    for (auto d : inputDims) {
+      d = toPositiveDim(d, inputTy.getRank());
+      // Drop invalid dims
+      if (isValidDim(d, inputTy.getRank())) {
+        dims.push_back(d);
+      }
+    }
+    llvm::sort(dims.begin(), dims.end());
+    std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
+    SmallVector<int64_t> reduceResultShape;
+    for (int64_t i = 0; i < inputTy.getRank(); i++) {
+      if (dimsSet.find(i) == dimsSet.end()) {
+        reduceResultShape.push_back(inputTy.getDimSize(i));
+      }
+    }
+
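+    // Reduce with the op-specific region; the intermediate result type drops
+    // the reduced dimensions entirely.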
+    Value reduceResult = createReduceOpWithSingleRegionOp(
+        op, input, RankedTensorType::get(reduceResultShape, inputElemTy), dims,
+        rewriter);
+    if (!reduceResult)
+      return op->emitError("createReduceOpWithSingleRegionOp returned nullptr");
+
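+    // keepdim=true: reinsert the reduced dims as size 1 by dynamically
+    // reshaping the reduced value back to the original rank.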
+    if (keepDim) {
+      const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
+      auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input,
+                                                   options.dimSizeIndexBits);
+      if (failed(outShapeInfo)) {
+        return rewriter.notifyMatchFailure(
+            op, "failed to get dimension sizes of the input");
+      }
+      auto outShapeVec = *outShapeInfo;
+      auto one = rewriter.create<mlir::arith::ConstantOp>(
+          op->getLoc(),
+          rewriter.getIntegerAttr(
+              rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+      for (int64_t i : dims) {
+        outShapeVec[i] = one;
+      }
+      auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+          op->getLoc(), outShapeVec);
+      rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
+          op,
+          ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
+              op.getType()),
+          reduceResult, outShapeTensor);
+      return success();
+    }
+    rewriter.replaceOp(op, reduceResult);
+    return success();
+  }
+};
 } // namespace
@@ -419,7 +577,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
       op, input, RankedTensorType::get(outputShape, inputElemTy),
       ArrayRef<int64_t>{dim}, rewriter);
   if (!reduceResult)
-    return failure();
+    return op->emitError("createReduceOpWithSingleRegionOp returned nullptr");
 
   if (keepDim) {
     auto outShapeVec = inputShapeVec;
@@ -472,483 +630,6 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
 }
 } // namespace
 
-// AtenSumOp
-namespace {
-template <>
-LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
-    AtenSumOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value input = adaptor.getSelf();
-  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
-  auto outTy = getTypeConverter()
-                   ->convertType(op.getType())
-                   .template dyn_cast<RankedTensorType>();
-  if (!inputTy) {
-    return rewriter.notifyMatchFailure(
-        op, "only Tensor types supported in StableHLO");
-  }
-  if (inputTy.getElementType() != outTy.getElementType()) {
-    // Use output element type as computation type.
-    auto dstElemTy = outTy.getElementType();
-    input =
-        rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
-    inputTy = dyn_cast<RankedTensorType>(input.getType());
-  }
-  auto inputElemTy = inputTy.getElementType();
-  if (!inputElemTy.isIntOrFloat()) {
-    return op.emitError(
-        "only floating-point or integer datatype legalization supported");
-  }
-  // Currently, (u)int8 dtype is not supported
-  if (isa<mlir::IntegerType>(inputElemTy) &&
-      inputElemTy.getIntOrFloatBitWidth() == 8) {
-    return rewriter.notifyMatchFailure(
-        op, "IntegerType with bitwidth 8 unsupported in convertion from "
-            "AtenSumOp to StableHLO");
-  }
-
-  SmallVector<int64_t> dims;
-  for (int64_t i = 0; i < inputTy.getRank(); i++) {
-    dims.push_back(i);
-  }
-  Value initValue =
-      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
-  if (!initValue)
-    return failure();
-
-  llvm::sort(dims.begin(), dims.end());
-  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op.getLoc(), RankedTensorType::get({}, outTy.getElementType()), input,
-      initValue, rewriter.getDenseI64ArrayAttr(dims));
-
-  Block &block = stablehloReduceOp.getBody().emplaceBlock();
-  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
-
-  block.addArgument(blockArgumentTy, op->getLoc());
-  block.addArgument(blockArgumentTy, op->getLoc());
-
-  auto *firstArgument = block.args_begin();
-  auto secondArgument = block.args_rbegin();
-
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-    Value addResult = rewriter.create<stablehlo::AddOp>(
-        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
-  }
-
-  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
-                                              stablehloReduceOp.getResults());
-  return success();
-}
-} // namespace
-
-// AtenAllOp
-namespace {
-template <>
-LogicalResult ConvertAtenReductionOp<AtenAllOp>::matchAndRewrite(
-    AtenAllOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value input = adaptor.getSelf();
-  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
-  if (!inputTy) {
-    return rewriter.notifyMatchFailure(
-        op, "only Tensor types supported in StableHLO");
-  }
-  auto inputElemTy = inputTy.getElementType();
-
-  // Currently, (u)int8 dtype is not supported
-  if (isa<mlir::IntegerType>(inputElemTy) &&
-      inputElemTy.getIntOrFloatBitWidth() == 8) {
-    return rewriter.notifyMatchFailure(
-        op, "IntegerType with bitwidth 8 unsupported in convertion from "
-            "AtenAllOp to StableHLO");
-  }
-  auto outTy = getTypeConverter()
-                   ->convertType(op.getType())
-                   .template dyn_cast<RankedTensorType>();
-
-  if (inputElemTy != outTy.getElementType()) {
-    // Use output bool type as computation type.
-    auto dstElemTy = outTy.getElementType();
-    input =
-        rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
-    inputTy = input.getType().dyn_cast<RankedTensorType>();
-    inputElemTy = inputTy.getElementType();
-  }
-
-  SmallVector<int64_t> dims;
-  for (int64_t i = 0; i < inputTy.getRank(); i++) {
-    dims.push_back(i);
-  }
-
-  Value initValue =
-      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
-  if (!initValue)
-    return failure();
-  llvm::sort(dims.begin(), dims.end());
-  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
-      rewriter.getDenseI64ArrayAttr(dims));
-
-  Block &block = stablehloReduceOp.getBody().emplaceBlock();
-  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
-
-  block.addArgument(blockArgumentTy, op->getLoc());
-  block.addArgument(blockArgumentTy, op->getLoc());
-
-  auto *firstArgument = block.args_begin();
-  auto secondArgument = block.args_rbegin();
-
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-    Value allResult = rewriter.create<stablehlo::AndOp>(
-        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), allResult);
-  }
-
-  rewriter.replaceOpWithNewOp<tensor::CastOp>(
-      op, getTypeConverter()->convertType(op.getType()),
-      stablehloReduceOp.getResults());
-  return success();
-}
-} // namespace
-
-// AtenAnyOp
-namespace {
-template <>
-LogicalResult ConvertAtenReductionOp<AtenAnyOp>::matchAndRewrite(
-    AtenAnyOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value input = adaptor.getSelf();
-  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
-  if (!inputTy) {
-    return rewriter.notifyMatchFailure(
-        op, "only Tensor types supported in StableHLO");
-  }
-  auto inputElemTy = inputTy.getElementType();
-
-  // Currently, (u)int8 dtype is not supported
-  if (isa<mlir::IntegerType>(inputElemTy) &&
-      inputElemTy.getIntOrFloatBitWidth() == 8) {
-    return rewriter.notifyMatchFailure(
-        op, "IntegerType with bitwidth 8 unsupported in convertion from "
-            "AtenAllOp to StableHLO");
-  }
-  auto outTy = getTypeConverter()
-                   ->convertType(op.getType())
-                   .template dyn_cast<RankedTensorType>();
-
-  if (inputElemTy != outTy.getElementType()) {
-    // Use output bool type as computation type.
-    auto dstElemTy = outTy.getElementType();
-    input =
-        rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
-    inputTy = input.getType().dyn_cast<RankedTensorType>();
-    inputElemTy = inputTy.getElementType();
-  }
-
-  SmallVector<int64_t> dims;
-  for (int64_t i = 0; i < inputTy.getRank(); i++) {
-    dims.push_back(i);
-  }
-
-  Value initValue =
-      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
-  if (!initValue)
-    return failure();
-  llvm::sort(dims.begin(), dims.end());
-  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
-      rewriter.getDenseI64ArrayAttr(dims));
-
-  Block &block = stablehloReduceOp.getBody().emplaceBlock();
-  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
-
-  block.addArgument(blockArgumentTy, op->getLoc());
-  block.addArgument(blockArgumentTy, op->getLoc());
-
-  auto *firstArgument = block.args_begin();
-  auto secondArgument = block.args_rbegin();
-
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-    Value anyResult = rewriter.create<stablehlo::OrOp>(
-        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), anyResult);
-  }
-
-  rewriter.replaceOpWithNewOp<tensor::CastOp>(
-      op, getTypeConverter()->convertType(op.getType()),
-      stablehloReduceOp.getResults());
-  return success();
-}
-} // namespace
-
-// AtenProdOp
-namespace {
-template <>
-LogicalResult ConvertAtenReductionOp<AtenProdOp>::matchAndRewrite(
-    AtenProdOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value input = adaptor.getSelf();
-  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
-  auto outTy = getTypeConverter()
-                   ->convertType(op.getType())
-                   .template dyn_cast<RankedTensorType>();
-  if (!inputTy) {
-    return rewriter.notifyMatchFailure(
-        op, "only Tensor types supported in StableHLO");
-  }
-  if (inputTy.getElementType() != outTy.getElementType()) {
-    // Use output element type as computation type.
-    auto dstElemTy = outTy.getElementType();
-    input =
-        rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
-    inputTy = dyn_cast<RankedTensorType>(input.getType());
-  }
-  auto inputElemTy = inputTy.getElementType();
-  if (!inputElemTy.isIntOrFloat()) {
-    return op.emitError(
-        "only floating-point or integer datatype legalization supported");
-  }
-  // Currently, (u)int8 dtype is not supported
-  if (isa<mlir::IntegerType>(inputElemTy) &&
-      inputElemTy.getIntOrFloatBitWidth() == 8) {
-    return rewriter.notifyMatchFailure(
-        op, "IntegerType with bitwidth 8 unsupported in convertion from "
-            "AtenProdOp to StableHLO");
-  }
-
-  SmallVector<int64_t> dims;
-  for (int64_t i = 0; i < inputTy.getRank(); i++) {
-    dims.push_back(i);
-  }
-  Value initValue =
-      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
-  if (!initValue)
-    return failure();
-
-  llvm::sort(dims.begin(), dims.end());
-  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op.getLoc(), RankedTensorType::get({}, outTy.getElementType()), input,
-      initValue, rewriter.getDenseI64ArrayAttr(dims));
-
-  Block &block = stablehloReduceOp.getBody().emplaceBlock();
-  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
-
-  block.addArgument(blockArgumentTy, op->getLoc());
-  block.addArgument(blockArgumentTy, op->getLoc());
-
-  auto *firstArgument = block.args_begin();
-  auto secondArgument = block.args_rbegin();
-
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-    Value mulResult = rewriter.create<stablehlo::MulOp>(
-        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), mulResult);
-  }
-
-  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
-                                              stablehloReduceOp.getResults());
-
-  return success();
-}
-} // namespace
-
-// AtenAmaxOp
-namespace {
-template <>
-LogicalResult ConvertAtenReductionOp<AtenAmaxOp>::matchAndRewrite(
-    AtenAmaxOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value input = adaptor.getSelf();
-  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
-  if (!inputTy) {
-    return rewriter.notifyMatchFailure(
-        op, "only Tensor types supported in StableHLO");
-  }
-  auto inputElemTy = inputTy.getElementType();
-  if (!inputElemTy.isIntOrFloat()) {
-    return op.emitError(
-        "only floating-point or integer datatype legalization supported");
-  }
-  // Currently, (u)int8 dtype is not supported
-  if (isa<mlir::IntegerType>(inputElemTy) &&
-      inputElemTy.getIntOrFloatBitWidth() == 8) {
-    return rewriter.notifyMatchFailure(
-        op, "IntegerType with bitwidth 8 unsupported in convertion from "
-            "AtenMaxOp to StableHLO");
-  }
-
-  bool keepDim = false;
-  if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
-    return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
-  }
-
-  SmallVector<int64_t> inputDims;
-  SmallVector<int64_t> dims;
-  if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
-    return rewriter.notifyMatchFailure(
-        op, "non-const integer `dim` is not supported");
-  }
-  for (auto d : inputDims) {
-    d = toPositiveDim(d, inputTy.getRank());
-    // Drop invalid dims
-    if (isValidDim(d, inputTy.getRank())) {
-      dims.push_back(d);
-    }
-  }
-  llvm::sort(dims.begin(), dims.end());
-  std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
-  SmallVector<int64_t> reduceResultShape;
-  for (int64_t i = 0; i < inputTy.getRank(); i++) {
-    if (dimsSet.find(i) == dimsSet.end()) {
-      reduceResultShape.push_back(inputTy.getDimSize(i));
-    }
-  }
-
-  Value reduceResult = createReduceOpWithSingleRegionOp(
-      op, input, RankedTensorType::get(reduceResultShape, inputElemTy), dims,
-      rewriter);
-  if (!reduceResult)
-    return failure();
-
-  if (keepDim) {
-    const auto &options = getOptions();
-    auto outShapeInfo =
-        hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
-    if (failed(outShapeInfo)) {
-      return rewriter.notifyMatchFailure(
-          op, "failed to get dimension sizes of the input");
-    }
-    auto outShapeVec = *outShapeInfo;
-    auto one = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(),
-        rewriter.getIntegerAttr(
-            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-    for (int64_t i : dims) {
-      outShapeVec[i] = one;
-    }
-    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-        op->getLoc(), outShapeVec);
-    rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
-        op, getTypeConverter()->convertType(op.getType()), reduceResult,
-        outShapeTensor);
-    return success();
-  }
-  rewriter.replaceOp(op, reduceResult);
-  return success();
-}
-} // namespace
-
-// AtenMaxOp
-namespace {
-template <>
-LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
-    AtenMaxOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value input = adaptor.getSelf();
-  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
-  if (!inputTy) {
-    return rewriter.notifyMatchFailure(
-        op, "only Tensor types supported in StableHLO");
-  }
-  auto inputElemTy = inputTy.getElementType();
-  if (!inputElemTy.isIntOrFloat()) {
-    return op.emitError(
-        "only floating-point or integer datatype legalization supported");
-  }
-  // Currently, (u)int8 dtype is not supported
-  if (isa<mlir::IntegerType>(inputElemTy) &&
-      inputElemTy.getIntOrFloatBitWidth() == 8) {
-    return rewriter.notifyMatchFailure(
-        op, "IntegerType with bitwidth 8 unsupported in convertion from "
-            "AtenMaxOp to StableHLO");
-  }
-
-  SmallVector<int64_t> dims =
-      llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank()));
-
-  Value reduceResult = createReduceOpWithSingleRegionOp(
-      op, input, RankedTensorType::get({}, inputElemTy), dims, rewriter);
-  if (!reduceResult)
-    return failure();
-
-  rewriter.replaceOpWithNewOp<tensor::CastOp>(
-      op, getTypeConverter()->convertType(op.getType()), reduceResult);
-  return success();
-}
-} // namespace
-
-// AtenMinOp
-namespace {
-template <>
-LogicalResult ConvertAtenReductionOp<AtenMinOp>::matchAndRewrite(
-    AtenMinOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value input = adaptor.getSelf();
-  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
-  if (!inputTy) {
-    return rewriter.notifyMatchFailure(
-        op, "only Tensor types supported in StableHLO");
-  }
-  auto inputElemTy = inputTy.getElementType();
-  if (!inputElemTy.isIntOrFloat()) {
-    return op.emitError(
-        "only floating-point or integer datatype legalization supported");
-  }
-  // Currently, (u)int8 dtype is not supported
-  if (isa<mlir::IntegerType>(inputElemTy) &&
-      inputElemTy.getIntOrFloatBitWidth() == 8) {
-    return rewriter.notifyMatchFailure(
-        op, "IntegerType with bitwidth 8 unsupported in convertion from "
-            "AtenMinOp to StableHLO");
-  }
-
-  SmallVector<int64_t> dims;
-  for (int64_t i = 0; i < inputTy.getRank(); i++) {
-    dims.push_back(i);
-  }
-
-  Value initValue =
-      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
-  if (!initValue)
-    return failure();
-  llvm::sort(dims.begin(), dims.end());
-  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
-      rewriter.getDenseI64ArrayAttr(dims));
-
-  Block &block = stablehloReduceOp.getBody().emplaceBlock();
-  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
-
-  block.addArgument(blockArgumentTy, op->getLoc());
-  block.addArgument(blockArgumentTy, op->getLoc());
-
-  auto *firstArgument = block.args_begin();
-  auto secondArgument = block.args_rbegin();
-
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-    Value minResult = rewriter.create<stablehlo::MinOp>(
-        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), minResult);
-  }
-
-  rewriter.replaceOpWithNewOp<tensor::CastOp>(
-      op, getTypeConverter()->convertType(op.getType()),
-      stablehloReduceOp.getResults());
-  return success();
-}
-} // namespace
-
 // AtenSumDimIntListOp
 namespace {
 template <>
@@ -1334,17 +1015,33 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
 #define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp)                              \
   target.addIllegalOp<AtenOp>();                                              \
   patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
 
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
-  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAmaxOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
-  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
-  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdOp);
-  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAllOp);
-  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAnyOp);
-  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp);
-  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMinOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp);
 #undef INSERT_ATEN_REDUCTION_OP_PATTERN
+
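+// Reductions over all input dimensions share ConvertAtenReduceAllDimsOp.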
+#define INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenOp)                     \
+  target.addIllegalOp<AtenOp>();                                              \
+  patterns.add<ConvertAtenReduceAllDimsOp<AtenOp>>(typeConverter, context,    \
+                                                   options)
+
+  INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMaxOp);
+  INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMinOp);
+  INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenSumOp);
+  INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenProdOp);
+  INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenAllOp);
+  INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenAnyOp);
+#undef INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN
+
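+// aten.amax/aten.amin carry a dim list plus keepdim and use the keep-dim
+// lowering.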
+#define INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenOp)                     \
+  target.addIllegalOp<AtenOp>();                                              \
+  patterns.add<ConvertAtenReduceKeepDimOp<AtenOp>>(typeConverter, context,    \
+                                                   options)
+
+  INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenAmaxOp);
+  INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenAminOp);
+#undef INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN
+}