mirror of https://github.com/llvm/torch-mlir
[MHLO] Init MHLO reduce-like op conversion (#1133)
* [MHLO] init reduce-like op conversion from Torch to MHLO Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com> Co-authored-by: Jiawei Wu <xremold@gmail.com> Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com> Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com> Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>pull/1142/head snapshot-20220803.553
parent
0b23af27d3
commit
636f5acb10
|
@ -4,6 +4,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
|
||||||
BasicOp.cpp
|
BasicOp.cpp
|
||||||
GatherOp.cpp
|
GatherOp.cpp
|
||||||
ViewLikeOps.cpp
|
ViewLikeOps.cpp
|
||||||
|
ReductionOp.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo
|
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo
|
||||||
|
|
|
@ -25,7 +25,9 @@ void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||||
void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
|
void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||||
RewritePatternSet &patterns,
|
RewritePatternSet &patterns,
|
||||||
ConversionTarget &target);
|
ConversionTarget &target);
|
||||||
|
void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||||
|
RewritePatternSet &patterns,
|
||||||
|
ConversionTarget &target);
|
||||||
} // namespace torch_to_mhlo
|
} // namespace torch_to_mhlo
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -0,0 +1,566 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||||
|
|
||||||
|
#include "../PassDetail.h"
|
||||||
|
#include "./MhloLegalizeUtils.h"
|
||||||
|
#include "./PopulatePatterns.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
|
static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
|
||||||
|
PatternRewriter &rewriter) {
|
||||||
|
auto constType = RankedTensorType::get({}, elementTy);
|
||||||
|
if (isa<AtenSumOp, AtenSumDimIntListOp>(op)) {
|
||||||
|
if (elementTy.isa<mlir::FloatType>()) {
|
||||||
|
auto constAttr = DenseElementsAttr::get(
|
||||||
|
constType, {APFloat::getZero(
|
||||||
|
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||||
|
/*negative=*/false)});
|
||||||
|
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||||
|
constAttr);
|
||||||
|
} else if (elementTy.isa<mlir::IntegerType>() &&
|
||||||
|
elementTy.getIntOrFloatBitWidth() != 8) {
|
||||||
|
auto constAttr = DenseElementsAttr::get(
|
||||||
|
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
|
||||||
|
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||||
|
constAttr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isa<AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
|
||||||
|
if (elementTy.isa<mlir::FloatType>()) {
|
||||||
|
auto constAttr = DenseElementsAttr::get(
|
||||||
|
constType, {APFloat::getLargest(
|
||||||
|
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||||
|
/*negative=*/true)});
|
||||||
|
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||||
|
constAttr);
|
||||||
|
} else if (elementTy.isa<mlir::IntegerType>() &&
|
||||||
|
elementTy.getIntOrFloatBitWidth() != 8) {
|
||||||
|
auto constAttr = DenseElementsAttr::get(
|
||||||
|
constType,
|
||||||
|
{APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
|
||||||
|
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||||
|
constAttr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
op->emitError("unimplemented lowering in "
|
||||||
|
"createInitialValueForReduceOp");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Util for converting AtenArgmaxOp and AtenMaxDimOp
|
||||||
|
static llvm::Optional<ValueRange>
|
||||||
|
getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
||||||
|
ArrayRef<Value> inputShapeVec, int64_t dim) {
|
||||||
|
auto inputTy = input.getType().template cast<RankedTensorType>();
|
||||||
|
if (!inputTy) {
|
||||||
|
return llvm::None;
|
||||||
|
}
|
||||||
|
if (!inputTy.getElementType().isIntOrFloat()) {
|
||||||
|
return llvm::None;
|
||||||
|
}
|
||||||
|
auto inputShape = inputTy.getShape();
|
||||||
|
auto inputElemTy = inputTy.getElementType();
|
||||||
|
|
||||||
|
Value initValue = createInitialValueForReduceOp(op, inputElemTy, rewriter);
|
||||||
|
if (!initValue) return llvm::None;
|
||||||
|
|
||||||
|
Value initIndex =
|
||||||
|
mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).getValue();
|
||||||
|
|
||||||
|
DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
|
||||||
|
RankedTensorType::get({}, rewriter.getI64Type()), dim);
|
||||||
|
|
||||||
|
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||||
|
op->getLoc(), inputShapeVec);
|
||||||
|
auto indexTensor = rewriter.create<mhlo::DynamicIotaOp>(
|
||||||
|
op->getLoc(), RankedTensorType::get(inputShape, rewriter.getI64Type()),
|
||||||
|
inputShapeTensor, static_cast<uint64_t>(dim));
|
||||||
|
|
||||||
|
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
|
||||||
|
op->getLoc(), ValueRange{input, indexTensor},
|
||||||
|
ValueRange{
|
||||||
|
initValue,
|
||||||
|
initIndex,
|
||||||
|
},
|
||||||
|
dimensions);
|
||||||
|
|
||||||
|
Block &block = mhloReduceOp.body().emplaceBlock();
|
||||||
|
|
||||||
|
// Add block arguments
|
||||||
|
auto blockValArgumentType =
|
||||||
|
RankedTensorType::get({}, inputTy.getElementType());
|
||||||
|
auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type());
|
||||||
|
auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
|
||||||
|
block.addArgument(blockValArgumentType, op->getLoc());
|
||||||
|
block.addArgument(blockIdxArgumentType, op->getLoc());
|
||||||
|
|
||||||
|
block.addArgument(blockValArgumentType, op->getLoc());
|
||||||
|
block.addArgument(blockIdxArgumentType, op->getLoc());
|
||||||
|
|
||||||
|
auto *firstValArg = block.args_begin();
|
||||||
|
auto *firstIdxArg = std::next(firstValArg);
|
||||||
|
auto *secondValArg = std::next(firstIdxArg);
|
||||||
|
auto *secondIdxArg = std::next(secondValArg);
|
||||||
|
|
||||||
|
mhlo::ComparisonTypeAttr compareTypeAttr;
|
||||||
|
if (inputTy.getElementType().isa<mlir::FloatType>()) {
|
||||||
|
compareTypeAttr = mhlo::ComparisonTypeAttr::get(
|
||||||
|
rewriter.getContext(), mhlo::ComparisonType::FLOAT);
|
||||||
|
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
|
||||||
|
compareTypeAttr = mhlo::ComparisonTypeAttr::get(
|
||||||
|
rewriter.getContext(), mhlo::ComparisonType::SIGNED);
|
||||||
|
}
|
||||||
|
mhlo::ComparisonDirectionAttr compareGeDirectionAttr =
|
||||||
|
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
|
||||||
|
mhlo::ComparisonDirection::GE);
|
||||||
|
mhlo::ComparisonDirectionAttr compareEqDirectionAttr =
|
||||||
|
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
|
||||||
|
mhlo::ComparisonDirection::EQ);
|
||||||
|
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(&block);
|
||||||
|
|
||||||
|
Value compareGeResult = rewriter.create<mhlo::CompareOp>(
|
||||||
|
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
|
||||||
|
compareGeDirectionAttr, compareTypeAttr);
|
||||||
|
Value retValResult = rewriter.create<mhlo::SelectOp>(
|
||||||
|
op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
|
||||||
|
|
||||||
|
// get smaller index value if compared nums are equal.
|
||||||
|
Value compareEqResult = rewriter.create<mhlo::CompareOp>(
|
||||||
|
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
|
||||||
|
compareEqDirectionAttr, compareTypeAttr);
|
||||||
|
Value minIdx =
|
||||||
|
rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg, *secondIdxArg);
|
||||||
|
Value idxWithGeVal = rewriter.create<mhlo::SelectOp>(
|
||||||
|
op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
|
||||||
|
Value retIdxResult = rewriter.create<mhlo::SelectOp>(
|
||||||
|
op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
|
||||||
|
|
||||||
|
rewriter.create<mhlo::ReturnOp>(
|
||||||
|
op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
|
||||||
|
}
|
||||||
|
return mhloReduceOp.getResults();
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename AtenOpT>
|
||||||
|
class ConvertAtenReductionOp : public OpConversionPattern<AtenOpT> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||||
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override;
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// AtenArgmaxOp
|
||||||
|
namespace {
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
|
||||||
|
AtenArgmaxOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
Value input = adaptor.self();
|
||||||
|
auto inputTy = input.getType().template cast<RankedTensorType>();
|
||||||
|
if (!inputTy) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto inputElemTy = inputTy.getElementType();
|
||||||
|
if (!inputElemTy.isIntOrFloat()) {
|
||||||
|
return op.emitError(
|
||||||
|
"only floating-point or integer datatype legalization supported");
|
||||||
|
}
|
||||||
|
// Currently, (u)int8 dtype is not supported!
|
||||||
|
if (inputElemTy.isa<mlir::IntegerType>() &&
|
||||||
|
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||||
|
"AtenArgmaxOp to MHLO");
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t dim;
|
||||||
|
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "non-int dim unsupported");
|
||||||
|
}
|
||||||
|
dim = toPositiveDim(dim, inputTy.getRank());
|
||||||
|
if (!isValidDim(dim, inputTy.getRank())) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool keepDim = false;
|
||||||
|
if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
|
||||||
|
if (failed(inputShapeInfo)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "failed to get dimension sizes of the input");
|
||||||
|
}
|
||||||
|
auto inputShapeVec = *inputShapeInfo;
|
||||||
|
auto mhloReduceResults =
|
||||||
|
getMaxInDim(rewriter, op, input, inputShapeVec, dim).getValue();
|
||||||
|
|
||||||
|
if (keepDim) {
|
||||||
|
auto outShapeVec = inputShapeVec;
|
||||||
|
|
||||||
|
outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
|
||||||
|
op->getLoc(), rewriter.getIntegerAttr(
|
||||||
|
rewriter.getIntegerType(mhlo::kMhloDimSizeBits), 1));
|
||||||
|
|
||||||
|
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||||
|
op->getLoc(), outShapeVec);
|
||||||
|
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
|
||||||
|
op, typeConverter->convertType(op.getType()), mhloReduceResults[1],
|
||||||
|
outShapeTensor);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, mhloReduceResults[1]);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// AtenMaxDimOp
|
||||||
|
namespace {
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
|
||||||
|
AtenMaxDimOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
Value input = adaptor.self();
|
||||||
|
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
|
||||||
|
if (!inputTy) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
|
||||||
|
}
|
||||||
|
auto inputElemTy = inputTy.getElementType();
|
||||||
|
if (!inputElemTy.isIntOrFloat()) {
|
||||||
|
return op.emitError(
|
||||||
|
"Only floating-point or integer datatype legalization supported");
|
||||||
|
}
|
||||||
|
// Currently, (u)int8 dtype is not supported
|
||||||
|
if (inputElemTy.isa<mlir::IntegerType>() &&
|
||||||
|
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||||
|
"AtenMaxDimOp to MHLO");
|
||||||
|
}
|
||||||
|
|
||||||
|
RankedTensorType valResultType = getTypeConverter()
|
||||||
|
->convertType(op.getResult(0).getType())
|
||||||
|
.template cast<RankedTensorType>();
|
||||||
|
RankedTensorType idxResultType = getTypeConverter()
|
||||||
|
->convertType(op.getResult(1).getType())
|
||||||
|
.template cast<RankedTensorType>();
|
||||||
|
Type idxElementType = idxResultType.getElementType();
|
||||||
|
if (!idxElementType.isa<mlir::IntegerType>()) {
|
||||||
|
return op.emitError("Aten.max.dim needs integer-like result");
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t dim;
|
||||||
|
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "non-int dim unsupported");
|
||||||
|
}
|
||||||
|
dim = toPositiveDim(dim, inputTy.getRank());
|
||||||
|
if (!isValidDim(dim, inputTy.getRank())) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
||||||
|
}
|
||||||
|
bool keepDim = false;
|
||||||
|
if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
|
||||||
|
if (failed(inputShapeInfo)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "failed to get dimension sizes of the input");
|
||||||
|
}
|
||||||
|
auto inputShapeVec = *inputShapeInfo;
|
||||||
|
auto mhloReduceResults =
|
||||||
|
getMaxInDim(rewriter, op, input, inputShapeVec, dim).getValue();
|
||||||
|
|
||||||
|
if (keepDim) {
|
||||||
|
auto outShapeVec = inputShapeVec;
|
||||||
|
outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
|
||||||
|
op->getLoc(), rewriter.getIntegerAttr(
|
||||||
|
rewriter.getIntegerType(mhlo::kMhloDimSizeBits), 1));
|
||||||
|
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||||
|
op->getLoc(), outShapeVec);
|
||||||
|
|
||||||
|
auto mhloReduceValueResult = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||||
|
op->getLoc(), valResultType, mhloReduceResults[0], outShapeTensor);
|
||||||
|
auto mhloReduceIndexResult = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||||
|
op->getLoc(), idxResultType, mhloReduceResults[1], outShapeTensor);
|
||||||
|
rewriter.replaceOp(op, {mhloReduceValueResult, mhloReduceIndexResult});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, {mhloReduceResults[0], mhloReduceResults[1]});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// AtenSumOp
|
||||||
|
namespace {
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
||||||
|
AtenSumOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
Value input = adaptor.self();
|
||||||
|
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!inputTy) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
|
||||||
|
}
|
||||||
|
auto dtype = adaptor.dtype();
|
||||||
|
if (!dtype.getType().isa<Torch::NoneType>()) {
|
||||||
|
auto dstElemTy = getTypeConverter()
|
||||||
|
->convertType(op.getType())
|
||||||
|
.template dyn_cast<RankedTensorType>()
|
||||||
|
.getElementType();
|
||||||
|
input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||||
|
inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||||
|
}
|
||||||
|
auto inputElemTy = inputTy.getElementType();
|
||||||
|
if (!inputElemTy.isIntOrFloat()) {
|
||||||
|
return op.emitError(
|
||||||
|
"only floating-point or integer datatype legalization supported");
|
||||||
|
}
|
||||||
|
// Currently, (u)int8 dtype is not supported
|
||||||
|
if (inputElemTy.isa<mlir::IntegerType>() &&
|
||||||
|
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||||
|
"AtenSumOp to MHLO");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> dims;
|
||||||
|
for (int64_t i = 0; i < inputTy.getRank(); i++) {
|
||||||
|
dims.push_back(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value initValue =
|
||||||
|
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
|
||||||
|
if (!initValue) return failure();
|
||||||
|
|
||||||
|
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
|
||||||
|
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
|
||||||
|
|
||||||
|
Block &block = mhloReduceOp.body().emplaceBlock();
|
||||||
|
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
|
||||||
|
|
||||||
|
block.addArgument(blockArgumentTy, op->getLoc());
|
||||||
|
block.addArgument(blockArgumentTy, op->getLoc());
|
||||||
|
|
||||||
|
auto *firstArgument = block.args_begin();
|
||||||
|
auto secondArgument = block.args_rbegin();
|
||||||
|
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(&block);
|
||||||
|
Value addResult = rewriter.create<mhlo::AddOp>(
|
||||||
|
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
||||||
|
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, mhloReduceOp.getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// AtenMaxOp
|
||||||
|
namespace {
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
|
||||||
|
AtenMaxOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
Value input = adaptor.self();
|
||||||
|
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!inputTy) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
|
||||||
|
}
|
||||||
|
auto inputElemTy = inputTy.getElementType();
|
||||||
|
if (!inputElemTy.isIntOrFloat()) {
|
||||||
|
return op.emitError(
|
||||||
|
"only floating-point or integer datatype legalization supported");
|
||||||
|
}
|
||||||
|
// Currently, (u)int8 dtype is not supported
|
||||||
|
if (inputElemTy.isa<mlir::IntegerType>() &&
|
||||||
|
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||||
|
"AtenMaxOp to MHLO");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> dims;
|
||||||
|
for (int64_t i = 0; i < inputTy.getRank(); i++) {
|
||||||
|
dims.push_back(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value initValue =
|
||||||
|
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
|
||||||
|
if (!initValue) return failure();
|
||||||
|
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
|
||||||
|
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
|
||||||
|
|
||||||
|
Block &block = mhloReduceOp.body().emplaceBlock();
|
||||||
|
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
|
||||||
|
|
||||||
|
block.addArgument(blockArgumentTy, op->getLoc());
|
||||||
|
block.addArgument(blockArgumentTy, op->getLoc());
|
||||||
|
|
||||||
|
auto *firstArgument = block.args_begin();
|
||||||
|
auto secondArgument = block.args_rbegin();
|
||||||
|
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(&block);
|
||||||
|
Value maxResult = rewriter.create<mhlo::MaxOp>(
|
||||||
|
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
||||||
|
rewriter.create<mhlo::ReturnOp>(op->getLoc(), maxResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, mhloReduceOp.getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// AtenSumDimIntListOp
|
||||||
|
namespace {
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
||||||
|
AtenSumDimIntListOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
Value input = adaptor.self();
|
||||||
|
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!inputTy) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
|
||||||
|
}
|
||||||
|
auto dtype = adaptor.dtype();
|
||||||
|
if (!dtype.getType().isa<Torch::NoneType>()) {
|
||||||
|
auto dstElemTy = getTypeConverter()
|
||||||
|
->convertType(op.getType())
|
||||||
|
.template dyn_cast<RankedTensorType>()
|
||||||
|
.getElementType();
|
||||||
|
input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||||
|
inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||||
|
}
|
||||||
|
auto inputElemTy = inputTy.getElementType();
|
||||||
|
if (!inputElemTy.isIntOrFloat()) {
|
||||||
|
return op.emitError(
|
||||||
|
"Only floating-point or integer datatype legalization supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Currently, (u)int8 dtype is not supported
|
||||||
|
if (inputElemTy.isa<mlir::IntegerType>() &&
|
||||||
|
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||||
|
"AtenSumDimIntListOp to MHLO");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> inputDims;
|
||||||
|
SmallVector<int64_t> dims;
|
||||||
|
if (!matchPattern(op.dim(), m_TorchConstantIntList(inputDims))) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "non-int dim list unsupported");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto d : inputDims) {
|
||||||
|
d = toPositiveDim(d, inputTy.getRank());
|
||||||
|
// Drop invalid dims
|
||||||
|
if (isValidDim(d, inputTy.getRank())) {
|
||||||
|
dims.push_back(d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool keepDim = false;
|
||||||
|
if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
|
||||||
|
}
|
||||||
|
Value initValue =
|
||||||
|
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
|
||||||
|
if (!initValue) return failure();
|
||||||
|
|
||||||
|
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
|
||||||
|
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
|
||||||
|
|
||||||
|
Region ®ion = mhloReduceOp.body();
|
||||||
|
Block &block = region.emplaceBlock();
|
||||||
|
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
|
||||||
|
|
||||||
|
block.addArgument(blockArgumentTy, op->getLoc());
|
||||||
|
block.addArgument(blockArgumentTy, op->getLoc());
|
||||||
|
|
||||||
|
auto *firstArgument = block.args_begin();
|
||||||
|
auto secondArgument = block.args_rbegin();
|
||||||
|
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(&block);
|
||||||
|
Value addResult = rewriter.create<mhlo::AddOp>(
|
||||||
|
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
||||||
|
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (keepDim) {
|
||||||
|
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
|
||||||
|
if (failed(outShapeInfo)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "failed to get dimension sizes of the input");
|
||||||
|
}
|
||||||
|
auto outShapeVec = *outShapeInfo;
|
||||||
|
auto one = rewriter.create<mlir::arith::ConstantOp>(
|
||||||
|
op->getLoc(), rewriter.getIntegerAttr(
|
||||||
|
rewriter.getIntegerType(mhlo::kMhloDimSizeBits), 1));
|
||||||
|
for (int64_t i : dims) {
|
||||||
|
outShapeVec[i] = one;
|
||||||
|
}
|
||||||
|
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||||
|
op->getLoc(), outShapeVec);
|
||||||
|
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
|
||||||
|
op, getTypeConverter()->convertType(op.getType()),
|
||||||
|
mhloReduceOp.getResult(0), outShapeTensor);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, mhloReduceOp.getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality(
|
||||||
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
|
ConversionTarget &target) {
|
||||||
|
MLIRContext *context = patterns.getContext();
|
||||||
|
#define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \
|
||||||
|
target.addIllegalOp<AtenOp>(); \
|
||||||
|
patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context);
|
||||||
|
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp);
|
||||||
|
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
|
||||||
|
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
|
||||||
|
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
|
||||||
|
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp);
|
||||||
|
#undef INSERT_ATEN_REDUCTION_OP_PATTERN
|
||||||
|
}
|
|
@ -58,6 +58,8 @@ public:
|
||||||
target);
|
target);
|
||||||
torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns,
|
torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns,
|
||||||
target);
|
target);
|
||||||
|
torch_to_mhlo::populateReductionOpPatternsAndLegality(typeConverter,
|
||||||
|
patterns, target);
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns)))) {
|
std::move(patterns)))) {
|
||||||
|
|
|
@ -0,0 +1,243 @@
|
||||||
|
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.max.dim$keepdim(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> (!torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64>) {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %true = torch.constant.bool true
|
||||||
|
// CHECK: %int1 = torch.constant.int 1
|
||||||
|
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK: %[[VAL_2:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = arith.index_cast %[[VAL_2]] : index to i64
|
||||||
|
// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_4]] : index to i64
|
||||||
|
// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<0> : tensor<i64>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tensor.from_elements %[[VAL_3]], %[[VAL_5]] : tensor<2xi64>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
|
||||||
|
// CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
|
||||||
|
// CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) {
|
||||||
|
// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
// CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
// CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64>
|
||||||
|
// CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||||
|
// CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||||
|
// CHECK: "mhlo.return"(%[[VAL_16]], %[[VAL_20]]) : (tensor<f32>, tensor<i64>) -> ()
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: %[[VAL_21:.*]] = arith.constant 1 : i64
|
||||||
|
// CHECK: %[[VAL_22:.*]] = tensor.from_elements %[[VAL_3]], %[[VAL_21]] : tensor<2xi64>
|
||||||
|
// CHECK: %[[VAL_23:.*]] = "mhlo.dynamic_reshape"(%[[VAL_10]]#0, %[[VAL_22]]) : (tensor<?xf32>, tensor<2xi64>) -> tensor<?x1xf32>
|
||||||
|
// CHECK: %[[VAL_24:.*]] = "mhlo.dynamic_reshape"(%[[VAL_10]]#1, %[[VAL_22]]) : (tensor<?xi64>, tensor<2xi64>) -> tensor<?x1xi64>
|
||||||
|
// CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor<?x1xf32> -> !torch.vtensor<[?,1],f32>
|
||||||
|
// CHECK: %[[VAL_26:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor<?x1xi64> -> !torch.vtensor<[?,1],si64>
|
||||||
|
// CHECK: return %[[VAL_25]], %[[VAL_26]] : !torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64>
|
||||||
|
|
||||||
|
func.func @torch.aten.max.dim$keepdim(%arg0: !torch.vtensor<[?,?],f32>) -> (!torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64>) {
|
||||||
|
%true = torch.constant.bool true
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%values, %indices = torch.aten.max.dim %arg0, %int1, %true : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64>
|
||||||
|
return %values, %indices : !torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.max.dim(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> (!torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>) {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %false = torch.constant.bool false
|
||||||
|
// CHECK: %int1 = torch.constant.int 1
|
||||||
|
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK: %[[VAL_2:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = arith.index_cast %[[VAL_2]] : index to i64
|
||||||
|
// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_4]] : index to i64
|
||||||
|
// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<0> : tensor<i64>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tensor.from_elements %[[VAL_3]], %[[VAL_5]] : tensor<2xi64>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
|
||||||
|
// CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
|
||||||
|
// CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) {
|
||||||
|
// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
// CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
// CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64>
|
||||||
|
// CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||||
|
// CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||||
|
// CHECK: "mhlo.return"(%[[VAL_16]], %[[VAL_20]]) : (tensor<f32>, tensor<i64>) -> ()
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_10]]#0 : tensor<?xf32> -> !torch.vtensor<[?],f32>
|
||||||
|
// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_10]]#1 : tensor<?xi64> -> !torch.vtensor<[?],si64>
|
||||||
|
// CHECK: return %[[VAL_21]], %[[VAL_22]] : !torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>
|
||||||
|
func.func @torch.aten.max.dim(%arg0: !torch.vtensor<[?,?],f32>) -> (!torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>) {
|
||||||
|
%false = torch.constant.bool false
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%values, %indices = torch.aten.max.dim %arg0, %int1, %false : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>
|
||||||
|
return %values, %indices : !torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.argmax$keepdim(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,1],si64> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %int1 = torch.constant.int 1
|
||||||
|
// CHECK: %true = torch.constant.bool true
|
||||||
|
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK: %[[VAL_2:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = arith.index_cast %[[VAL_2]] : index to i64
|
||||||
|
// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_4]] : index to i64
|
||||||
|
// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<0> : tensor<i64>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tensor.from_elements %[[VAL_3]], %[[VAL_5]] : tensor<2xi64>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
|
||||||
|
// CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
|
||||||
|
// CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) {
|
||||||
|
// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
// CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
// CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64>
|
||||||
|
// CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||||
|
// CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||||
|
// CHECK: "mhlo.return"(%[[VAL_16]], %[[VAL_20]]) : (tensor<f32>, tensor<i64>) -> ()
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: %[[VAL_21:.*]] = arith.constant 1 : i64
|
||||||
|
// CHECK: %[[VAL_22:.*]] = tensor.from_elements %[[VAL_3]], %[[VAL_21]] : tensor<2xi64>
|
||||||
|
// CHECK: %[[VAL_23:.*]] = "mhlo.dynamic_reshape"(%[[VAL_10]]#1, %[[VAL_22]]) : (tensor<?xi64>, tensor<2xi64>) -> tensor<?x1xi64>
|
||||||
|
// CHECK: %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor<?x1xi64> -> !torch.vtensor<[?,1],si64>
|
||||||
|
// CHECK: return %[[VAL_24]] : !torch.vtensor<[?,1],si64>
|
||||||
|
func.func @torch.aten.argmax$keepdim(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,1],si64> {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%true = torch.constant.bool true
|
||||||
|
%indices = torch.aten.argmax %arg0, %int1, %true : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,1],si64>
|
||||||
|
return %indices : !torch.vtensor<[?,1],si64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.argmax(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?],si64> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %int1 = torch.constant.int 1
|
||||||
|
// CHECK: %false = torch.constant.bool false
|
||||||
|
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK: %[[VAL_2:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = arith.index_cast %[[VAL_2]] : index to i64
|
||||||
|
// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_4]] : index to i64
|
||||||
|
// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<0> : tensor<i64>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tensor.from_elements %[[VAL_3]], %[[VAL_5]] : tensor<2xi64>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
|
||||||
|
// CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
|
||||||
|
// CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) {
|
||||||
|
// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
// CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
// CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64>
|
||||||
|
// CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||||
|
// CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||||
|
// CHECK: "mhlo.return"(%[[VAL_16]], %[[VAL_20]]) : (tensor<f32>, tensor<i64>) -> ()
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]]#1 : tensor<?xi64> -> !torch.vtensor<[?],si64>
|
||||||
|
// CHECK: return %[[VAL_11]] : !torch.vtensor<[?],si64>
|
||||||
|
func.func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?],si64> {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%false = torch.constant.bool false
|
||||||
|
%indices = torch.aten.argmax %arg0, %int1, %false : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?],si64>
|
||||||
|
return %indices : !torch.vtensor<[?],si64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.sum.dim_Intlist$keepdim(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[1,1,?],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
|
||||||
|
// CHECK: %int1 = torch.constant.int 1
|
||||||
|
// CHECK: %int0 = torch.constant.int 0
|
||||||
|
// CHECK: %true = torch.constant.bool true
|
||||||
|
// CHECK: %none = torch.constant.none
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = mhlo.reduce(%[[VAL_1:.*]] init: %[[VAL_3:.*]]) applies mhlo.add across dimensions = [0, 1] : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?xf32>
|
||||||
|
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = arith.index_cast %[[VAL_5]] : index to i64
|
||||||
|
// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index
|
||||||
|
// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor<?x?x?xf32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_7]] : index to i64
|
||||||
|
// CHECK: %[[IDX_2:.*]] = arith.constant 2 : index
|
||||||
|
// CHECK: %[[VAL_9:.*]] = tensor.dim %[[VAL_1]], %[[IDX_2]] : tensor<?x?x?xf32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : index to i64
|
||||||
|
// CHECK: %[[ONE_0:.*]] = arith.constant 1 : i64
|
||||||
|
// CHECK: %[[VAL_11:.*]] = tensor.from_elements %[[ONE_0]], %[[ONE_0]], %[[VAL_10]] : tensor<3xi64>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = "mhlo.dynamic_reshape"(%[[VAL_4]], %[[VAL_11]]) : (tensor<?xf32>, tensor<3xi64>) -> tensor<1x1x?xf32>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x1x?xf32> -> !torch.vtensor<[1,1,?],f32>
|
||||||
|
// CHECK: return %[[VAL_13]] : !torch.vtensor<[1,1,?],f32>
|
||||||
|
func.func @torch.aten.sum.dim_Intlist$keepdim(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[1,1,?],f32> {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%true = torch.constant.bool true
|
||||||
|
%none = torch.constant.none
|
||||||
|
%0 = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%1 = torch.aten.sum.dim_IntList %arg0, %0, %true, %none : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,?],f32>
|
||||||
|
return %1 : !torch.vtensor<[1,1,?],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.sum.dim_Intlist(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?],f32> {
|
||||||
|
// CHECK: %[[VAL_01:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
|
||||||
|
// CHECK: %int1 = torch.constant.int 1
|
||||||
|
// CHECK: %int0 = torch.constant.int 0
|
||||||
|
// CHECK: %false = torch.constant.bool false
|
||||||
|
// CHECK: %none = torch.constant.none
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = mhlo.reduce(%[[VAL_1]] init: %[[VAL_3]]) applies mhlo.add across dimensions = [0, 1] : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?xf32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
|
||||||
|
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?],f32>
|
||||||
|
func.func @torch.aten.sum.dim_Intlist(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?],f32> {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%false = torch.constant.bool false
|
||||||
|
%none = torch.constant.none
|
||||||
|
%0 = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%1 = torch.aten.sum.dim_IntList %arg0, %0, %false, %none : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[?],f32>
|
||||||
|
return %1 : !torch.vtensor<[?],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.sum(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
|
||||||
|
// CHECK: %none = torch.constant.none
|
||||||
|
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = mhlo.reduce(%[[VAL_1]] init: %[[VAL_2]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK: return %[[VAL_4]] : !torch.vtensor<[],f32>
|
||||||
|
func.func @torch.aten.sum(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%0 = torch.aten.sum %arg0, %none : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
|
||||||
|
return %0 : !torch.vtensor<[],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.max(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = mhlo.reduce(%[[VAL_1]] init: %[[VAL_2]]) applies mhlo.maximum across dimensions = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK: return %[[VAL_4]] : !torch.vtensor<[],f32>
|
||||||
|
func.func @torch.aten.max(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
|
||||||
|
%0 = torch.aten.max %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[],f32>
|
||||||
|
return %0 : !torch.vtensor<[],f32>
|
||||||
|
}
|
Loading…
Reference in New Issue