[Stablehlo] support aten.any.dim, aten.min.dim (#3500)

* refactor `TorchToStablehlo/Reduction.cpp`
* add `ConvertAtenReduceWithIndicesOp` patterns
Yuanqiang Liu 2024-06-29 16:53:33 +08:00 committed by GitHub
parent 73ba09c587
commit f9fc741eef
3 changed files with 328 additions and 426 deletions
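
For reference, a small PyTorch sketch of the op semantics this commit lowers (illustrative values only, not taken from the test suite):

import torch

x = torch.tensor([[1.0, 3.0], [2.0, 0.5]])

# aten.any.dim: boolean reduction along a single dimension.
torch.ops.aten.any(x > 1.0, dim=0)  # tensor([True, True])

# aten.min.dim: returns (values, indices), the two results that the new
# ConvertAtenReduceWithIndicesOp pattern maps onto a two-result
# stablehlo.reduce over a (value, index) pair.
values, indices = torch.min(x, dim=1)  # values: [1.0, 0.5], indices: [0, 1]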

lib/Conversion/TorchToStablehlo/Reduction.cpp

@ -30,6 +30,18 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo;
static SmallVector<int64_t> getReduceOutputShape(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> dims) {
std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
SmallVector<int64_t> reduceResultShape;
for (size_t i = 0; i < inputShape.size(); i++) {
if (dimsSet.find(i) == dimsSet.end()) {
reduceResultShape.push_back(inputShape[i]);
}
}
return reduceResultShape;
}
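// For example, inputShape = [2, 3, 4] with dims = [1] yields a reduced
// shape of [2, 4]; keepdim handling is applied separately afterwards.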
static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) {
auto constType = RankedTensorType::get({}, elementTy);
@ -42,8 +54,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
/*negative=*/false)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
} else if (isa<mlir::IntegerType>(elementTy) &&
elementTy.getIntOrFloatBitWidth() != 8) {
} else if (isa<mlir::IntegerType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
@ -59,8 +70,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
/*negative=*/true)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
} else if (isa<mlir::IntegerType>(elementTy) &&
elementTy.getIntOrFloatBitWidth() != 8) {
} else if (isa<mlir::IntegerType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType,
{APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
@ -69,7 +79,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
}
}
if (isa<AtenAminOp, AtenMinOp>(op)) {
if (isa<AtenAminOp, AtenMinOp, AtenMinDimOp, AtenArgminOp>(op)) {
if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType,
@ -77,8 +87,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
/*negative=*/false)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
} else if (isa<mlir::IntegerType>(elementTy) &&
elementTy.getIntOrFloatBitWidth() != 8) {
} else if (isa<mlir::IntegerType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType,
{APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth())});
@ -93,8 +102,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
auto constAttr = DenseElementsAttr::get(constType, one);
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
} else if (isa<mlir::IntegerType>(elementTy) &&
elementTy.getIntOrFloatBitWidth() != 8) {
} else if (isa<mlir::IntegerType>(elementTy)) {
APInt one(elementTy.getIntOrFloatBitWidth(), 1);
auto constAttr = DenseElementsAttr::get(constType, one);
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
@ -103,13 +111,15 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
}
if (isa<AtenAllOp>(op)) {
auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 1)});
auto constAttr =
DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
}
if (isa<AtenAnyOp>(op)) {
auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 0)});
if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
auto constAttr =
DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 0)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
}
@ -149,16 +159,17 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp>(op)) {
result = rewriter.create<stablehlo::MaxOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
} else if (isa<AtenAminOp, AtenMinOp>(op)) {
} else if (isa<AtenAminOp, AtenMinOp, AtenMinDimOp>(op)) {
result = rewriter.create<stablehlo::MinOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
} else if (isa<AtenSumOp>(op)) {
} else if (isa<AtenSumOp, AtenSumDimIntListOp, AtenFrobeniusNormDimOp,
AtenLinalgVectorNormOp>(op)) {
result = rewriter.create<stablehlo::AddOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
} else if (isa<AtenAllOp>(op)) {
result = rewriter.create<stablehlo::AndOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
} else if (isa<AtenAnyOp>(op)) {
} else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
result = rewriter.create<stablehlo::OrOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
} else if (isa<AtenProdOp>(op)) {
@ -174,11 +185,11 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
return reduce.getResults()[0];
}
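// The reduce op above carries a single binary op in its region
// (max/min/add/and/or/mul), so it produces exactly one result; ops that
// also need indices go through createReduceOpReturnIndices below.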
// Util for converting AtenArgmaxOp and AtenMaxDimOp
// Util for converting AtenMaxDimOp/AtenMinDimOp
static std::optional<ValueRange>
getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
ArrayRef<Value> inputShapeVec, int64_t dim,
size_t dimSizeIndexBits) {
createReduceOpReturnIndices(ConversionPatternRewriter &rewriter, Operation *op,
Value &input, ArrayRef<Value> inputShapeVec,
int64_t dim, size_t dimSizeIndexBits) {
auto inputTy = cast<RankedTensorType>(input.getType());
if (!inputTy) {
return std::nullopt;
@ -199,8 +210,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
initIndex = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
}
std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
outputShape.erase(outputShape.begin() + dim);
auto outputShape = getReduceOutputShape(inputShape, {dim});
auto outputTy = RankedTensorType::get(outputShape, inputElemTy);
auto outputIndexTy =
RankedTensorType::get(outputShape, rewriter.getIntegerType(64));
@ -252,6 +262,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
stablehlo::ComparisonDirectionAttr::get(
rewriter.getContext(), stablehlo::ComparisonDirection::GE);
stablehlo::ComparisonDirectionAttr compareLeDirectionAttr =
stablehlo::ComparisonDirectionAttr::get(
rewriter.getContext(), stablehlo::ComparisonDirection::LE);
stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
stablehlo::ComparisonDirectionAttr::get(
rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
@ -260,11 +273,21 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareGeDirectionAttr, compareTypeAttr);
Value compareResult;
if (isa<AtenMaxDimOp>(op)) {
compareResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareGeDirectionAttr, compareTypeAttr);
} else if (isa<AtenMinDimOp>(op)) {
compareResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareLeDirectionAttr, compareTypeAttr);
} else {
op->emitError("unimplement lowering of createReduceOpReturnIndices");
return std::nullopt;
}
Value retValResult = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
op->getLoc(), compareResult, *firstValArg, *secondValArg);
// Return the smaller index if the compared values are equal.
Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
@ -273,16 +296,35 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
*secondIdxArg);
Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
op->getLoc(), compareResult, *firstIdxArg, *secondIdxArg);
Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
rewriter.create<stablehlo::ReturnOp>(
op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
op->getLoc(), ValueRange{retValResult, retIdxResult});
}
return stablehloReduceOp.getResults();
}
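// results[0] holds the reduced values and results[1] the corresponding
// indices (ties broken toward the smaller index in the region above).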
static Value reshapeReduceResultWhenKeepDim(ConversionPatternRewriter &rewriter,
Location loc, Value reduceResult,
ArrayRef<Value> inputShapeVec,
Type outType,
ArrayRef<int64_t> dims,
size_t dimSizeIndexBits) {
SmallVector<Value> outShapeVec(inputShapeVec);
Value one = rewriter.create<arith::ConstantOp>(
loc,
rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1));
for (auto dim : dims) {
outShapeVec[dim] = one;
}
auto outShapeTensor =
rewriter.create<tensor::FromElementsOp>(loc, outShapeVec);
return rewriter.create<stablehlo::DynamicReshapeOp>(
loc, outType, reduceResult, outShapeTensor);
}
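// E.g. a reduction over dims = [1] of a [2, 3, 4] input produces a [2, 4]
// result; with keepdim=true it is dynamically reshaped back to [2, 1, 4]
// so the output rank matches the input.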
namespace {
template <typename AtenOpT>
class ConvertAtenReductionOp : public ConvertAtenOp<AtenOpT> {
@ -320,14 +362,6 @@ public:
return op.emitError(
"only floating-point or integer datatype legalization supported");
}
// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op,
"IntegerType with bitwidth 8 unsupported in convertion to StableHLO");
}
if (inputElemTy != outTy.getElementType()) {
// use output type as computation type
input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input,
@ -347,7 +381,7 @@ public:
};
template <typename AtenOpT>
class ConvertAtenReduceKeepDimOp : public ConvertAtenReductionOp<AtenOpT> {
class ConvertAtenReduceOneDimOp : public ConvertAtenReductionOp<AtenOpT> {
public:
using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
using OpAdaptor = typename AtenOpT::Adaptor;
@ -356,7 +390,10 @@ public:
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy) {
auto outTy = dyn_cast<RankedTensorType>(
ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
if (!inputTy || !outTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
@ -366,12 +403,78 @@ public:
return op.emitError(
"only floating-point or integer datatype legalization supported");
}
// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
if (inputElemTy != outTy.getElementType()) {
// use output type as computation type
input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input,
outTy.getElementType());
}
bool keepDim = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
}
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
return rewriter.notifyMatchFailure(
op,
"IntegerType with bitwidth 8 unsupported in convertion to StableHLO");
op, "non-const integer `dim` is not supported");
}
dim = toPositiveDim(dim, inputTy.getRank());
SmallVector<int64_t> reduceResultShape =
getReduceOutputShape(inputTy.getShape(), {dim});
Value reduceResult = createReduceOpWithSingleRegionOp(
op, input,
RankedTensorType::get(reduceResultShape, outTy.getElementType()), {dim},
rewriter);
if (!reduceResult) {
return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
}
if (keepDim) {
const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input,
options.dimSizeIndexBits);
if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
reduceResult = reshapeReduceResultWhenKeepDim(
rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, {dim},
options.dimSizeIndexBits);
}
rewriter.replaceOp(op, reduceResult);
return success();
}
};
template <typename AtenOpT>
class ConvertAtenReduceDimsOp : public ConvertAtenReductionOp<AtenOpT> {
public:
using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
auto outTy = dyn_cast<RankedTensorType>(
ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
if (!inputTy || !outTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) {
return op.emitError(
"only floating-point or integer datatype legalization supported");
}
if (inputElemTy != outTy.getElementType()) {
// use output type as computation type
input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input,
outTy.getElementType());
}
bool keepDim = false;
@ -393,19 +496,16 @@ public:
}
}
llvm::sort(dims.begin(), dims.end());
std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
SmallVector<int64_t> reduceResultShape;
for (int64_t i = 0; i < inputTy.getRank(); i++) {
if (dimsSet.find(i) == dimsSet.end()) {
reduceResultShape.push_back(inputTy.getDimSize(i));
}
}
SmallVector<int64_t> reduceResultShape =
getReduceOutputShape(inputTy.getShape(), dims);
Value reduceResult = createReduceOpWithSingleRegionOp(
op, input, RankedTensorType::get(reduceResultShape, inputElemTy), dims,
op, input,
RankedTensorType::get(reduceResultShape, outTy.getElementType()), dims,
rewriter);
if (!reduceResult)
if (!reduceResult) {
return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
}
if (keepDim) {
const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
@ -415,215 +515,104 @@ public:
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto outShapeVec = *outShapeInfo;
auto one = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
for (int64_t i : dims) {
outShapeVec[i] = one;
}
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op,
ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
reduceResult, outShapeTensor);
return success();
reduceResult = reshapeReduceResultWhenKeepDim(
rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims,
options.dimSizeIndexBits);
}
rewriter.replaceOp(op, reduceResult);
return success();
}
};
} // namespace
// AtenArgmaxOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
AtenArgmaxOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = cast<RankedTensorType>(input.getType());
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) {
return op.emitError(
"only floating-point or integer datatype legalization supported");
}
// Currently, (u)int8 dtype is not supported!
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenArgmaxOp to StableHLO");
}
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
return rewriter.notifyMatchFailure(op, "non-int dim unsupported");
}
dim = toPositiveDim(dim, inputTy.getRank());
if (!isValidDim(dim, inputTy.getRank())) {
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
}
bool keepDim = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
}
const auto &options = getOptions();
auto inputShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto inputShapeVec = *inputShapeInfo;
auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
dim, options.dimSizeIndexBits)
.value();
if (keepDim) {
auto outShapeVec = inputShapeVec;
outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, typeConverter->convertType(op.getType()), stablehloReduceResults[1],
outShapeTensor);
return success();
}
rewriter.replaceOp(op, stablehloReduceResults[1]);
return success();
}
} // namespace
// AtenMaxDimOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
AtenMaxDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) {
return op.emitError(
"Only floating-point or integer datatype legalization supported");
}
// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenMaxDimOp to StableHLO");
}
RankedTensorType valResultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op.getResult(0).getType()));
RankedTensorType idxResultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op.getResult(1).getType()));
Type idxElementType = idxResultType.getElementType();
if (!isa<mlir::IntegerType>(idxElementType)) {
return op.emitError("Aten.max.dim needs integer-like result");
}
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
return rewriter.notifyMatchFailure(op, "non-int dim unsupported");
}
dim = toPositiveDim(dim, inputTy.getRank());
if (!isValidDim(dim, inputTy.getRank())) {
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
}
bool keepDim = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
}
const auto &options = getOptions();
auto inputShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto inputShapeVec = *inputShapeInfo;
if (op.getResult(1).use_empty()) {
llvm::SmallVector<int64_t> outputShape(inputTy.getShape());
outputShape.erase(outputShape.begin() + dim);
Value reduceResult = createReduceOpWithSingleRegionOp(
op, input, RankedTensorType::get(outputShape, inputElemTy),
ArrayRef<int64_t>{dim}, rewriter);
if (!reduceResult)
return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
if (keepDim) {
auto outShapeVec = inputShapeVec;
outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
auto stablehloReduceValueResult =
rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), valResultType, reduceResult, outShapeTensor);
rewriter.replaceOp(op, {stablehloReduceValueResult, Value()});
return success();
template <typename AtenOpT>
class ConvertAtenReduceWithIndicesOp : public ConvertAtenReductionOp<AtenOpT> {
public:
using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) {
return op.emitError(
"Only floating-point or integer datatype legalization supported");
}
rewriter.replaceOp(op, {reduceResult, Value()});
return success();
} else {
auto stablehloReduceResults =
getMaxInDim(rewriter, op, input, inputShapeVec, dim,
options.dimSizeIndexBits)
.value();
if (keepDim) {
auto outShapeVec = inputShapeVec;
outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
RankedTensorType valResultType = cast<RankedTensorType>(
ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
op.getResult(0).getType()));
RankedTensorType idxResultType = cast<RankedTensorType>(
ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
op.getResult(1).getType()));
Type idxElementType = idxResultType.getElementType();
if (!isa<mlir::IntegerType>(idxElementType)) {
return op.emitError("indices result should to be integer tyep");
}
auto stablehloReduceValueResult =
rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), valResultType, stablehloReduceResults[0],
outShapeTensor);
auto stablehloReduceIndexResult =
rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), idxResultType, stablehloReduceResults[1],
outShapeTensor);
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
return rewriter.notifyMatchFailure(op, "non-int dim unsupported");
}
dim = toPositiveDim(dim, inputTy.getRank());
if (!isValidDim(dim, inputTy.getRank())) {
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
}
bool keepDim = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
}
const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
auto inputShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto inputShapeVec = *inputShapeInfo;
if (op.getResult(1).use_empty()) {
llvm::SmallVector<int64_t> outputShape(inputTy.getShape());
outputShape.erase(outputShape.begin() + dim);
Value reduceResult = createReduceOpWithSingleRegionOp(
op, input, RankedTensorType::get(outputShape, inputElemTy),
ArrayRef<int64_t>{dim}, rewriter);
if (!reduceResult) {
return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
}
if (keepDim) {
reduceResult = reshapeReduceResultWhenKeepDim(
rewriter, op->getLoc(), reduceResult, inputShapeVec, valResultType,
{dim}, options.dimSizeIndexBits);
}
rewriter.replaceOp(op, {reduceResult, Value()});
return success();
} else {
ValueRange stablehloReduceResults =
createReduceOpReturnIndices(rewriter, op, input, inputShapeVec, dim,
options.dimSizeIndexBits)
.value();
// ValueRange is a non-owning view, and assigning through its operator[]
// only writes to a temporary; copy the results out so the keepdim
// reshape below actually takes effect.
SmallVector<Value> reduceResults(stablehloReduceResults.begin(),
stablehloReduceResults.end());
if (keepDim) {
reduceResults[0] = reshapeReduceResultWhenKeepDim(
rewriter, op->getLoc(), reduceResults[0], inputShapeVec,
valResultType, {dim}, options.dimSizeIndexBits);
reduceResults[1] = reshapeReduceResultWhenKeepDim(
rewriter, op->getLoc(), reduceResults[1], inputShapeVec,
idxResultType, {dim}, options.dimSizeIndexBits);
}
rewriter.replaceOp(
op, {stablehloReduceValueResult, stablehloReduceIndexResult});
op, {reduceResults[0], reduceResults[1]});
return success();
}
rewriter.replaceOp(op,
{stablehloReduceResults[0], stablehloReduceResults[1]});
return success();
}
}
};
};
} // namespace
// AtenSumDimIntListOp
@ -653,17 +642,8 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
"Only floating-point or integer datatype legalization supported");
}
// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenSumDimIntListOp to StableHLO");
}
SmallVector<int64_t> inputDims;
SmallVector<int64_t> dims;
if (failed(checkNotNone(rewriter, op, op.getDim()))) {
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
} else {
@ -675,7 +655,6 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
}
}
for (auto d : inputDims) {
d = toPositiveDim(d, inputTy.getRank());
// Drop invalid dims
@ -683,46 +662,22 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
dims.push_back(d);
}
}
llvm::sort(dims.begin(), dims.end());
std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
SmallVector<int64_t> reduceResultShape;
for (int64_t i = 0; i < inputTy.getRank(); i++) {
if (dimsSet.find(i) == dimsSet.end()) {
reduceResultShape.push_back(inputTy.getDimSize(i));
}
}
SmallVector<int64_t> reduceResultShape =
getReduceOutputShape(inputTy.getShape(), dims);
bool keepDim = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
}
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(),
RankedTensorType::get(reduceResultShape, outTy.getElementType()), input,
initValue, rewriter.getDenseI64ArrayAttr(dims));
Region &region = stablehloReduceOp.getBody();
Block &block = region.emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());
auto *firstArgument = block.args_begin();
auto secondArgument = block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value addResult = rewriter.create<stablehlo::AddOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
Value reduceResult = createReduceOpWithSingleRegionOp(
op, input,
RankedTensorType::get(reduceResultShape, outTy.getElementType()), dims,
rewriter);
if (!reduceResult) {
return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
}
if (keepDim) {
@ -733,23 +688,11 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto outShapeVec = *outShapeInfo;
auto one = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
for (int64_t i : dims) {
outShapeVec[i] = one;
}
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()),
stablehloReduceOp.getResult(0), outShapeTensor);
return success();
reduceResult = reshapeReduceResultWhenKeepDim(
rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims,
options.dimSizeIndexBits);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
stablehloReduceOp.getResults());
rewriter.replaceOp(op, reduceResult);
return success();
}
} // namespace
@ -789,18 +732,12 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
"invalid dimension detected in `dim`");
}
}
// Sort the dims in ascending order, making the conversion
// stable with unordered dims.
std::sort(dims.begin(), dims.end());
std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
SmallVector<int64_t> reduceResultShape;
for (int64_t i = 0; i < inputRank; i++) {
if (dimsSet.find(i) == dimsSet.end()) {
reduceResultShape.push_back(inputType.getDimSize(i));
}
}
SmallVector<int64_t> reduceResultShape =
getReduceOutputShape(inputType.getShape(), dims);
bool keepDim = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
@ -810,36 +747,14 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
auto squareOp = rewriter.create<stablehlo::MulOp>(op->getLoc(), input, input);
auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter);
if (!initValue) {
return failure();
Value reduceResult = createReduceOpWithSingleRegionOp(
op, squareOp.getResult(),
RankedTensorType::get(reduceResultShape, inputElemType), dims, rewriter);
if (!reduceResult) {
return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
}
auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), RankedTensorType::get(reduceResultShape, inputElemType),
squareOp.getResult(), initValue, rewriter.getDenseI64ArrayAttr(dims));
Region &region = reduceOp.getBody();
Block &block = region.emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputElemType);
block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());
auto firstArgument = *block.args_begin();
auto secondArgument = *block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
auto addResult = rewriter.create<stablehlo::AddOp>(
op->getLoc(), firstArgument, secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult.getResult());
}
auto output =
rewriter.create<stablehlo::SqrtOp>(op->getLoc(), reduceOp.getResult(0));
Value output = rewriter.create<stablehlo::SqrtOp>(op->getLoc(), reduceResult);
if (keepDim) {
auto outShapeInfo =
@ -848,22 +763,12 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto outShapeVec = *outShapeInfo;
auto one = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
for (int64_t i : dims) {
outShapeVec[i] = one;
}
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), output,
outShapeTensor);
return success();
output = reshapeReduceResultWhenKeepDim(
rewriter, op->getLoc(), output, *outShapeInfo,
getTypeConverter()->convertType(op.getType()), dims,
options.dimSizeIndexBits);
}
rewriter.replaceOp(op, output.getResult());
rewriter.replaceOp(op, output);
return success();
}
} // namespace
@ -920,13 +825,8 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
std::sort(dims.begin(), dims.end());
}
std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
SmallVector<int64_t> reduceResultShape;
for (int64_t i = 0; i < inputType.getRank(); i++) {
if (dimsSet.find(i) == dimsSet.end()) {
reduceResultShape.push_back(inputType.getDimSize(i));
}
}
SmallVector<int64_t> reduceResultShape =
getReduceOutputShape(inputType.getShape(), dims);
bool keepDim = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
@ -934,46 +834,27 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
op, "non-const bool `keepdim` is not supported");
}
auto initValue = createInitialValueForReduceOp(op, outElemType, rewriter);
if (!initValue) {
return failure();
}
Value absValue = rewriter.create<stablehlo::AbsOp>(op->getLoc(), input);
Value powValue = rewriter.create<chlo::BroadcastPowOp>(op->getLoc(), absValue,
ord, nullptr);
auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), RankedTensorType::get(reduceResultShape, outElemType),
powValue, initValue, rewriter.getDenseI64ArrayAttr(dims));
Region &region = reduceOp.getBody();
Block &block = region.emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, outElemType);
block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());
auto firstArgument = *block.args_begin();
auto secondArgument = *block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
auto addResult = rewriter.create<stablehlo::AddOp>(
op->getLoc(), firstArgument, secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult.getResult());
Value reduceResult = createReduceOpWithSingleRegionOp(
op, powValue, RankedTensorType::get(reduceResultShape, outElemType), dims,
rewriter);
if (!reduceResult) {
return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
}
auto scalarType = RankedTensorType::get({}, outElemType);
auto constantOne = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), blockArgumentTy,
op->getLoc(), scalarType,
DenseElementsAttr::get(
blockArgumentTy,
scalarType,
APFloat(cast<mlir::FloatType>(outElemType).getFloatSemantics(), 1)));
auto reciprocalOrd = rewriter.create<stablehlo::DivOp>(
op->getLoc(), blockArgumentTy, constantOne, ord);
auto output = rewriter.create<chlo::BroadcastPowOp>(
op->getLoc(), reduceOp.getResult(0), reciprocalOrd, nullptr);
op->getLoc(), scalarType, constantOne, ord);
Value output = rewriter.create<chlo::BroadcastPowOp>(
op->getLoc(), reduceResult, reciprocalOrd, nullptr);
if (keepDim) {
auto outShapeInfo =
@ -982,23 +863,11 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto outShapeVec = *outShapeInfo;
auto one = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
for (int64_t i : dims) {
outShapeVec[i] = one;
}
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), output,
outShapeTensor);
return success();
output = reshapeReduceResultWhenKeepDim(rewriter, op->getLoc(), output,
*outShapeInfo, outType, dims,
options.dimSizeIndexBits);
}
rewriter.replaceOp(op, output.getResult());
rewriter.replaceOp(op, output);
return success();
}
} // namespace
@ -1010,9 +879,6 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
#define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp);
@ -1022,7 +888,6 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenReduceAllDimsOp<AtenOp>>(typeConverter, context, \
options)
INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMaxOp);
INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMinOp);
INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenSumOp);
@ -1031,12 +896,25 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenAnyOp);
#undef INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN
#define INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenOp) \
#define INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenReduceKeepDimOp<AtenOp>>(typeConverter, context, \
options)
patterns.add<ConvertAtenReduceOneDimOp<AtenOp>>(typeConverter, context, \
options)
INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAnyDimOp);
#undef INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN
INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenAmaxOp);
INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenAminOp);
#undef INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN
#define INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenReduceDimsOp<AtenOp>>(typeConverter, context, options)
INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenAmaxOp);
INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenAminOp);
#undef INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN
#define INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenReduceWithIndicesOp<AtenOp>>(typeConverter, context, \
options)
INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN(AtenMaxDimOp);
INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN(AtenMinDimOp);
#undef INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN
}

projects/pt1/e2e_testing/xfail_sets.py

@ -32,6 +32,7 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
# unimplemented lowering torch -> linalg for torchvision.deform_conv2d
# this is added to check the torch.onnx.export -> import_onnx -> torch path
"DeformConv2D_basic",
"ReduceAnyDimFloatModule_basic",
}
LINALG_CRASHING_SET = {
@ -340,6 +341,7 @@ TORCHDYNAMO_CRASHING_SET = {
}
FX_IMPORTER_XFAIL_SET = {
"ReduceAnyDimFloatModule_basic",
"AllBoolFalseModule_basic",
"AllBoolTrueModule_basic",
"AnyBoolFalseModule_basic",
@ -502,7 +504,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"ArgminIntModule_multiple_mins",
"ArgminModule_basic",
"ArgminModule_keepDim",
"ArgminModule_with_dim",
"AtenComplexImagModule_basic",
"AtenComplexRealModule_basic",
"AtenComplexViewModule_basic",
@ -716,10 +717,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"ReduceAllDimFloat_basic",
"ReduceAllDimInt_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimNegative_basic",
"ReduceMinAlongDimSignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"ReduceMinAlongDim_basic",
"ReduceMinKeepDimReturnBoth_basic",
"ReduceMinKeepDim_basic",
"ReduceProdDimIntFloatModule_basic",
@ -832,6 +830,11 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
}
STABLEHLO_PASS_SET = {
"ReduceMinAlongDimNegative_basic",
"ReduceMinAlongDim_basic",
"ArgminModule_with_dim",
"ReduceMinAlongDimSignedInt_basic",
"ReduceAnyDimFloatModule_basic",
"MeshgridIndexingIJ_basic",
"MeshgridIndexingXY_basic",
"Meshgrid_basic",
@ -2198,6 +2201,7 @@ ONNX_XFAIL_SET = {
# Failure - cast error
"PermuteNegativeIndexModule_basic",
# Failure - incorrect numerics
"ReduceAnyDimFloatModule_basic",
"AvgPool2dDivisorOverrideModule_basic",
"BroadcastDynamicDimModule_basic",
"ElementwiseAtan2TensorIntModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py

@ -239,6 +239,26 @@ def ReduceAnyFloatModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))


class ReduceAnyDimFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float32, True),
        ]
    )
    def forward(self, a):
        return torch.ops.aten.any(a, dim=0)


@register_test_case(module_factory=lambda: ReduceAnyDimFloatModule())
def ReduceAnyDimFloatModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))
# ==============================================================================