mirror of https://github.com/llvm/torch-mlir
Added transpose lowering
parent
c24ca5d639
commit
ecc334123c
|
@ -180,3 +180,20 @@ class MaxPool2dModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: MaxPool2dModule())
|
||||
def MaxPool2dModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1, 20, 20) - 0.5)
|
||||
|
||||
class TransposeIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4, 2], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.transpose(x, 0, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TransposeIntModule())
|
||||
def TransposeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 2))
|
||||
|
|
|
@ -88,7 +88,8 @@ public:
|
|||
Operation *op = workList.pop_back_val();
|
||||
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
|
||||
copyToValueTensorOps.push_back(copyToValueTensor);
|
||||
} else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp>(op)) {
|
||||
} else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
|
||||
AtenTransposeIntOp>(op)) {
|
||||
viewLikeOps.push_back(op);
|
||||
llvm::append_range(workList, op->getResult(0).getUsers());
|
||||
} else {
|
||||
|
|
|
@ -1275,6 +1275,83 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenTransposeIntOp
|
||||
: public OpConversionPattern<AtenTransposeIntOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenTransposeIntOp op, llvm::ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
AtenTransposeIntOp::Adaptor adaptor(operands);
|
||||
|
||||
int64_t dim0;
|
||||
if (!matchPattern(op.dim0(), m_TorchConstantInt(&dim0)))
|
||||
return rewriter.notifyMatchFailure(op, "dim0 must be constant");
|
||||
int64_t dim1;
|
||||
if (!matchPattern(op.dim1(), m_TorchConstantInt(&dim1)))
|
||||
return rewriter.notifyMatchFailure(op, "dim1 must be constant");
|
||||
|
||||
auto inVector = adaptor.self();
|
||||
auto inType = inVector.getType().cast<RankedTensorType>();
|
||||
auto inputRank = inType.getRank();
|
||||
auto outType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto elementType = inType.getElementType();
|
||||
|
||||
if (dim0 < 0)
|
||||
dim0 += inputRank + 1;
|
||||
if (dim0 < 0 || dim0 >= inputRank)
|
||||
return rewriter.notifyMatchFailure(op, "dim0 out of range");
|
||||
if (dim1 < 0)
|
||||
dim1 += inputRank + 1;
|
||||
if (dim1 < 0 || dim1 >= inputRank)
|
||||
return rewriter.notifyMatchFailure(op, "dim1 out of range");
|
||||
|
||||
auto loc = op.getLoc();
|
||||
|
||||
llvm::SmallVector<Value> outputDims;
|
||||
for (auto i = 0; i < inputRank; i++)
|
||||
outputDims.push_back(getDimOp(rewriter, loc, adaptor.self(), i));
|
||||
std::swap(outputDims[dim0], outputDims[dim1]);
|
||||
|
||||
Value outVector =
|
||||
rewriter.create<linalg::InitTensorOp>(loc, outputDims, elementType);
|
||||
SmallVector<AffineExpr> idExprs;
|
||||
SmallVector<AffineExpr> swapExprs;
|
||||
for (auto i = 0; i < inputRank; i++)
|
||||
idExprs.push_back(getAffineDimExpr(i, rewriter.getContext()));
|
||||
for (auto i = 0; i < inputRank; i++) {
|
||||
if (i == dim0) {
|
||||
swapExprs.push_back(idExprs[dim1]);
|
||||
} else if (i == dim1) {
|
||||
swapExprs.push_back(idExprs[dim0]);
|
||||
} else {
|
||||
swapExprs.push_back(idExprs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<AffineMap> indexingMaps = {
|
||||
AffineMap::get(inputRank, 0, idExprs, op.getContext()),
|
||||
AffineMap::get(inputRank, 0, swapExprs, op.getContext())};
|
||||
SmallVector<StringRef> iteratorTypes(inputRank, "parallel");
|
||||
auto transpose = rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, outVector.getType(), inVector, outVector,
|
||||
indexingMaps, iteratorTypes,
|
||||
[](OpBuilder &b, Location loc, ValueRange args) {
|
||||
b.create<linalg::YieldOp>(loc, args[0]);
|
||||
})
|
||||
.getResult(0);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, transpose);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// The pass
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -1325,6 +1402,8 @@ public:
|
|||
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSumOp>();
|
||||
patterns.add<ConvertReductionOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenTransposeIntOp>();
|
||||
patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
|
|
Loading…
Reference in New Issue