mirror of https://github.com/llvm/torch-mlir
parent
6040d7ce00
commit
f5b689e12f
|
@ -107,6 +107,8 @@ MHLO_PASS_SET = {
|
|||
"BroadcastToModule_basic",
|
||||
"BroadcastToSameRankStaticModule_basic",
|
||||
"BroadcastZeroRankInputStaticModule_basic",
|
||||
"CumsumStaticModule_basic",
|
||||
"CumsumStaticNegativeDimModule_basic",
|
||||
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
|
||||
"ElementwiseAtenLogicalNotOpModule_basic",
|
||||
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
|
||||
|
|
|
@ -34,7 +34,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
|||
PatternRewriter &rewriter) {
|
||||
auto constType = RankedTensorType::get({}, elementTy);
|
||||
// Avg pooling
|
||||
if (isa<AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp>(op)) {
|
||||
if (isa<AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp, AtenCumsumOp>(op)) {
|
||||
if (elementTy.isa<mlir::FloatType>()) {
|
||||
auto constAttr = DenseElementsAttr::get(
|
||||
constType, {APFloat::getZero(
|
||||
|
@ -531,6 +531,87 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
// AtenCumsumOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
|
||||
AtenCumsumOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
auto inputRank = inputTy.getRank();
|
||||
auto inputShape = inputTy.getShape();
|
||||
auto outTy =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: dim must be a constant int");
|
||||
}
|
||||
dim = toPositiveDim(dim, inputRank);
|
||||
if (!isValidDim(dim, inputRank)) {
|
||||
return rewriter.notifyMatchFailure(op, "dim is out of range");
|
||||
}
|
||||
if (inputTy.isDynamicDim(dim)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: cumsum dim must be static");
|
||||
}
|
||||
|
||||
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||
|
||||
SmallVector<int64_t> mhloKernelSize(inputRank, 1);
|
||||
mhloKernelSize[dim] = inputShape[dim];
|
||||
SmallVector<int64_t> mhloStride(inputRank, 1);
|
||||
SmallVector<int64_t> mhloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
|
||||
mhloPadding[dim * 2] = inputShape[dim] - 1;
|
||||
|
||||
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloKernelSize);
|
||||
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloStride);
|
||||
DenseIntElementsAttr baseDilations;
|
||||
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloDilation);
|
||||
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||
rewriter.getI64Type()),
|
||||
mhloPadding);
|
||||
|
||||
auto reduceWindowSum = rewriter.create<mhlo::ReduceWindowOp>(
|
||||
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
|
||||
baseDilations, windowDilations, pad);
|
||||
|
||||
Block &sumBlock = reduceWindowSum.getBody().emplaceBlock();
|
||||
|
||||
// Add bb argument
|
||||
auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
|
||||
sumBlock.addArgument(blockArgumentType, op->getLoc());
|
||||
sumBlock.addArgument(blockArgumentType, op->getLoc());
|
||||
auto *firstArg = sumBlock.args_begin();
|
||||
auto *secondArg = std::next(firstArg);
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&sumBlock);
|
||||
|
||||
Value sumResult =
|
||||
rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, reduceWindowSum.getResults());
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
|
@ -542,4 +623,6 @@ void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
|
|||
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
|
||||
patterns.add<ConvertAtenOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
|
||||
context, options);
|
||||
target.addIllegalOp<AtenCumsumOp>();
|
||||
patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options);
|
||||
}
|
||||
|
|
|
@ -3035,6 +3035,23 @@ class CumsumStaticModule(torch.nn.Module):
|
|||
def CumsumStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 7, 4))
|
||||
|
||||
class CumsumStaticNegativeDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([2, 7, 4], torch.float32, True),
|
||||
])
|
||||
def forward(self, val):
|
||||
return torch.ops.aten.cumsum(val, dim=-1)
|
||||
|
||||
@register_test_case(module_factory=lambda: CumsumStaticNegativeDimModule())
|
||||
def CumsumStaticNegativeDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 7, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenToDeviceModule(torch.nn.Module):
|
||||
|
|
Loading…
Reference in New Issue