[torch] Make torch.aten.unflatten lower directly to linalg (#2971)

The existing lowering via aten.view does not handle dynamic shapes well,
since the lowering to tensor.expand_shape must re-infer how the dynamic
dimensions match up. Better to lower directly.
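For context, a minimal PyTorch-level illustration of the op being lowered (the concrete shapes are made up for the example):

```python
import torch

# aten.unflatten.int splits one dimension into several: here dim 1 of a
# (2, 12) tensor is unflattened into (3, 4), giving a (2, 3, 4) result.
# Since only a single dimension is regrouped, the op corresponds to a
# tensor.expand_shape with a fixed reassociation of dimensions, so no
# general view-style shape re-inference is needed.
x = torch.randn(2, 12)
y = x.unflatten(1, (3, 4))
assert y.shape == (2, 3, 4)
```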
Rob Suderman 2024-03-04 10:17:42 -08:00 committed by GitHub
parent d51e80b648
commit 19d4888278
4 changed files with 68 additions and 3 deletions

@@ -638,6 +638,68 @@ public:
};
} // namespace
// Lower aten.unflatten.int into tensor.expand_shape
namespace {
class ConvertAtenUnflattenIntOp
: public OpConversionPattern<AtenUnflattenIntOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenUnflattenIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
BaseTensorType outputTensorType = op.getType().cast<BaseTensorType>();
if (!outputTensorType.hasSizes())
return rewriter.notifyMatchFailure(
op, "unimplemented: output must have known sizes");
std::optional<unsigned> maybeRank = getTensorRank(self);
if (!maybeRank)
return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");
auto inputTensorType = self.getType().cast<Torch::ValueTensorType>();
if (!inputTensorType || !inputTensorType.hasSizes()) {
return rewriter.notifyMatchFailure(op,
"Expected input type having sizes");
}
int inputRank = inputTensorType.getSizes().size();
int outputRank = outputTensorType.getSizes().size();
int64_t dimInt;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt)))
return rewriter.notifyMatchFailure(
op, "unimplemented: requires dim to be constants");
dimInt = toPositiveDim(dimInt, inputRank);
if (!isValidDim(dimInt, inputRank))
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
auto sizesOp = op.getSizes().getDefiningOp<Torch::PrimListConstructOp>();
int numSizes = sizesOp.getNumOperands();
SmallVector<ReassociationIndices> reassociations(inputRank);
if (inputRank > 0) {
for (int i = 0; i < dimInt; ++i)
reassociations[i].push_back(i);
for (int i = 0; i < numSizes; ++i)
reassociations[dimInt].push_back(i + dimInt);
for (int i = dimInt + numSizes; i < outputRank; ++i)
reassociations[i - numSizes + 1].push_back(i);
}
auto expandTy = getTypeConverter()->convertType(outputTensorType);
auto expand = rewriter
.create<tensor::ExpandShapeOp>(
loc, expandTy, adaptor.getSelf(), reassociations)
.getResult();
rewriter.replaceOp(op, expand);
return success();
}
};
} // namespace
namespace {
/// The `ConvertAtenViewOp` conversion pattern converts `aten.View` op to
/// one `linalg.TensorExpandShape` op for all expanded dimensions and one
@@ -2043,6 +2105,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
target.addIllegalOp<AtenFlattenUsingIntsOp>();
patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
target.addIllegalOp<AtenViewOp>();
patterns.add<ConvertAtenUnflattenIntOp>(typeConverter, context);
target.addIllegalOp<AtenUnflattenIntOp>();
patterns.add<ConvertAtenViewOp>(typeConverter, context);
target.addIllegalOp<AtenSqueezeOp>();
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);

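As a sanity check on the reassociation loops in ConvertAtenUnflattenIntOp above, here is a small Python sketch (not part of the patch) that mirrors the same index computation:

```python
def unflatten_reassociations(input_rank, dim, num_sizes):
    # Mirrors the C++ loops: dims before `dim` map 1:1, the `num_sizes`
    # expanded output dims all fold back into input dim `dim`, and the
    # remaining output dims map to the trailing input dims, shifted by
    # num_sizes - 1.
    output_rank = input_rank + num_sizes - 1
    reassociations = [[] for _ in range(input_rank)]
    if input_rank > 0:
        for i in range(dim):
            reassociations[i].append(i)
        for i in range(num_sizes):
            reassociations[dim].append(i + dim)
        for i in range(dim + num_sizes, output_rank):
            reassociations[i - num_sizes + 1].append(i)
    return reassociations

# Unflattening dim 1 of a rank-3 tensor into two dims: input dims
# (d0, d1, d2) -> output dims (d0, d1a, d1b, d2).
assert unflatten_reassociations(3, 1, 2) == [[0], [1, 2], [3]]
```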
@@ -379,7 +379,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenRepeatOp>();
target.addIllegalOp<AtenExpandOp>();
target.addIllegalOp<AtenFlattenUsingIntsOp>();
target.addIllegalOp<AtenUnflattenIntOp>();
target.addIllegalOp<AtenWhereScalarOp>();
target.addIllegalOp<AtenWhereScalarOtherOp>();
target.addIllegalOp<AtenWhereScalarSelfOp>();

@@ -248,7 +248,7 @@ class ExampleArgs:
# compiler where each backend can "own" its set of legal ops.
BACKEND_LEGAL_OPS = {
OutputType.TOSA: ['aten.flatten.using_ints', 'aten.native_layer_norm', 'aten.linear'],
OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints','aten.adaptive_avg_pool1d'],
OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints','aten.adaptive_avg_pool1d', 'aten.unflatten.int'],
OutputType.STABLEHLO: [],
}

@@ -50,9 +50,11 @@ class LinalgOnTensorsOnnxBackend(OnnxBackend):
f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
"Lowering Onnx backend contract to Linalg-on-Tensors backend contract")
backend_legal_ops = ['aten.flatten.using_ints','aten.adaptive_avg_pool1d', 'aten.unflatten.int']
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
run_pipeline_with_repro_report(
imported_module,
f"builtin.module(torch-lower-to-backend-contract)",
f"builtin.module(torch-lower-to-backend-contract{option_string})",
"Lowering TorchFX IR -> Torch Backend IR",
)
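For reference, the option string built above expands into a pipeline spec like the following (a standalone sketch reusing the list from the patch):

```python
backend_legal_ops = ['aten.flatten.using_ints', 'aten.adaptive_avg_pool1d',
                     'aten.unflatten.int']
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
pipeline = f"builtin.module(torch-lower-to-backend-contract{option_string})"
# -> builtin.module(torch-lower-to-backend-contract{backend-legal-ops=
#      aten.flatten.using_ints,aten.adaptive_avg_pool1d,aten.unflatten.int})
print(pipeline)
```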