mirror of https://github.com/llvm/torch-mlir
[torch] Make torch.aten.unflatten lower directly to linalg (#2971)
Existing lowering via aten.view does not work as well for dynamic shapes as the lowering to tensor.expand must re-infer dynamic shape matching. Better to directly lower.pull/2978/head
parent
d51e80b648
commit
19d4888278
|
@ -638,6 +638,68 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Lower aten.unflatten.int into tensor.expand_shape
|
||||
namespace {
|
||||
class ConvertAtenUnflattenIntOp
|
||||
: public OpConversionPattern<AtenUnflattenIntOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenUnflattenIntOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value self = op.getSelf();
|
||||
BaseTensorType outputTensorType = op.getType().cast<BaseTensorType>();
|
||||
if (!outputTensorType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: output must have known sizes");
|
||||
|
||||
std::optional<unsigned> maybeRank = getTensorRank(self);
|
||||
if (!maybeRank)
|
||||
return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");
|
||||
auto inputTensorType = self.getType().cast<Torch::ValueTensorType>();
|
||||
if (!inputTensorType || !inputTensorType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Expected input type having sizes");
|
||||
}
|
||||
int inputRank = inputTensorType.getSizes().size();
|
||||
int outputRank = outputTensorType.getSizes().size();
|
||||
|
||||
int64_t dimInt;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: requires dim to be constants");
|
||||
|
||||
dimInt = toPositiveDim(dimInt, inputRank);
|
||||
if (!isValidDim(dimInt, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
||||
|
||||
auto sizesOp = op.getSizes().getDefiningOp<Torch::PrimListConstructOp>();
|
||||
int numSizes = sizesOp.getNumOperands();
|
||||
|
||||
SmallVector<ReassociationIndices> reassociations(inputRank);
|
||||
if (inputRank > 0) {
|
||||
for (int i = 0; i < dimInt; ++i)
|
||||
reassociations[i].push_back(i);
|
||||
|
||||
for (int i = 0; i < numSizes; ++i)
|
||||
reassociations[dimInt].push_back(i + dimInt);
|
||||
|
||||
for (int i = dimInt + numSizes; i < outputRank; ++i)
|
||||
reassociations[i - numSizes + 1].push_back(i);
|
||||
}
|
||||
|
||||
auto expandTy = getTypeConverter()->convertType(outputTensorType);
|
||||
auto expand = rewriter
|
||||
.create<tensor::ExpandShapeOp>(
|
||||
loc, expandTy, adaptor.getSelf(), reassociations)
|
||||
.getResult();
|
||||
rewriter.replaceOp(op, expand);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
/// The `ConvertAtenViewOp` conversion pattern converts `aten.View` op to
|
||||
/// one `linalg.TensorExpandShape` op for all expanded dimensions and one
|
||||
|
@ -2043,6 +2105,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
|||
target.addIllegalOp<AtenFlattenUsingIntsOp>();
|
||||
patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenViewOp>();
|
||||
patterns.add<ConvertAtenUnflattenIntOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenUnflattenIntOp>();
|
||||
patterns.add<ConvertAtenViewOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSqueezeOp>();
|
||||
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
|
||||
|
|
|
@ -379,7 +379,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenRepeatOp>();
|
||||
target.addIllegalOp<AtenExpandOp>();
|
||||
target.addIllegalOp<AtenFlattenUsingIntsOp>();
|
||||
target.addIllegalOp<AtenUnflattenIntOp>();
|
||||
target.addIllegalOp<AtenWhereScalarOp>();
|
||||
target.addIllegalOp<AtenWhereScalarOtherOp>();
|
||||
target.addIllegalOp<AtenWhereScalarSelfOp>();
|
||||
|
|
|
@ -248,7 +248,7 @@ class ExampleArgs:
|
|||
# compiler where each backend can "own" its set of legal ops.
|
||||
BACKEND_LEGAL_OPS = {
|
||||
OutputType.TOSA: ['aten.flatten.using_ints', 'aten.native_layer_norm', 'aten.linear'],
|
||||
OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints','aten.adaptive_avg_pool1d'],
|
||||
OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints','aten.adaptive_avg_pool1d', 'aten.unflatten.int'],
|
||||
OutputType.STABLEHLO: [],
|
||||
}
|
||||
|
||||
|
|
|
@ -50,9 +50,11 @@ class LinalgOnTensorsOnnxBackend(OnnxBackend):
|
|||
f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
|
||||
"Lowering Onnx backend contract to Linalg-on-Tensors backend contract")
|
||||
|
||||
backend_legal_ops = ['aten.flatten.using_ints','aten.adaptive_avg_pool1d', 'aten.unflatten.int']
|
||||
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
|
||||
run_pipeline_with_repro_report(
|
||||
imported_module,
|
||||
f"builtin.module(torch-lower-to-backend-contract)",
|
||||
f"builtin.module(torch-lower-to-backend-contract{option_string})",
|
||||
"Lowering TorchFX IR -> Torch Backend IR",
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue