mirror of https://github.com/llvm/torch-mlir
Added transpose lowering
parent
c24ca5d639
commit
ecc334123c
|
@ -180,3 +180,20 @@ class MaxPool2dModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: MaxPool2dModule())
|
@register_test_case(module_factory=lambda: MaxPool2dModule())
|
||||||
def MaxPool2dModule_basic(module, tu: TestUtils):
|
def MaxPool2dModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 1, 20, 20) - 0.5)
|
module.forward(tu.rand(1, 1, 20, 20) - 0.5)
|
||||||
|
|
||||||
|
class TransposeIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([3, 4, 2], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.transpose(x, 0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TransposeIntModule())
|
||||||
|
def TransposeIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, 2))
|
||||||
|
|
|
@ -88,7 +88,8 @@ public:
|
||||||
Operation *op = workList.pop_back_val();
|
Operation *op = workList.pop_back_val();
|
||||||
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
|
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
|
||||||
copyToValueTensorOps.push_back(copyToValueTensor);
|
copyToValueTensorOps.push_back(copyToValueTensor);
|
||||||
} else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp>(op)) {
|
} else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
|
||||||
|
AtenTransposeIntOp>(op)) {
|
||||||
viewLikeOps.push_back(op);
|
viewLikeOps.push_back(op);
|
||||||
llvm::append_range(workList, op->getResult(0).getUsers());
|
llvm::append_range(workList, op->getResult(0).getUsers());
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -1275,6 +1275,83 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ConvertAtenTransposeIntOp
|
||||||
|
: public OpConversionPattern<AtenTransposeIntOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenTransposeIntOp op, llvm::ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
|
return failure();
|
||||||
|
AtenTransposeIntOp::Adaptor adaptor(operands);
|
||||||
|
|
||||||
|
int64_t dim0;
|
||||||
|
if (!matchPattern(op.dim0(), m_TorchConstantInt(&dim0)))
|
||||||
|
return rewriter.notifyMatchFailure(op, "dim0 must be constant");
|
||||||
|
int64_t dim1;
|
||||||
|
if (!matchPattern(op.dim1(), m_TorchConstantInt(&dim1)))
|
||||||
|
return rewriter.notifyMatchFailure(op, "dim1 must be constant");
|
||||||
|
|
||||||
|
auto inVector = adaptor.self();
|
||||||
|
auto inType = inVector.getType().cast<RankedTensorType>();
|
||||||
|
auto inputRank = inType.getRank();
|
||||||
|
auto outType = getTypeConverter()
|
||||||
|
->convertType(op->getResult(0).getType())
|
||||||
|
.cast<RankedTensorType>();
|
||||||
|
auto elementType = inType.getElementType();
|
||||||
|
|
||||||
|
if (dim0 < 0)
|
||||||
|
dim0 += inputRank + 1;
|
||||||
|
if (dim0 < 0 || dim0 >= inputRank)
|
||||||
|
return rewriter.notifyMatchFailure(op, "dim0 out of range");
|
||||||
|
if (dim1 < 0)
|
||||||
|
dim1 += inputRank + 1;
|
||||||
|
if (dim1 < 0 || dim1 >= inputRank)
|
||||||
|
return rewriter.notifyMatchFailure(op, "dim1 out of range");
|
||||||
|
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
|
||||||
|
llvm::SmallVector<Value> outputDims;
|
||||||
|
for (auto i = 0; i < inputRank; i++)
|
||||||
|
outputDims.push_back(getDimOp(rewriter, loc, adaptor.self(), i));
|
||||||
|
std::swap(outputDims[dim0], outputDims[dim1]);
|
||||||
|
|
||||||
|
Value outVector =
|
||||||
|
rewriter.create<linalg::InitTensorOp>(loc, outputDims, elementType);
|
||||||
|
SmallVector<AffineExpr> idExprs;
|
||||||
|
SmallVector<AffineExpr> swapExprs;
|
||||||
|
for (auto i = 0; i < inputRank; i++)
|
||||||
|
idExprs.push_back(getAffineDimExpr(i, rewriter.getContext()));
|
||||||
|
for (auto i = 0; i < inputRank; i++) {
|
||||||
|
if (i == dim0) {
|
||||||
|
swapExprs.push_back(idExprs[dim1]);
|
||||||
|
} else if (i == dim1) {
|
||||||
|
swapExprs.push_back(idExprs[dim0]);
|
||||||
|
} else {
|
||||||
|
swapExprs.push_back(idExprs[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<AffineMap> indexingMaps = {
|
||||||
|
AffineMap::get(inputRank, 0, idExprs, op.getContext()),
|
||||||
|
AffineMap::get(inputRank, 0, swapExprs, op.getContext())};
|
||||||
|
SmallVector<StringRef> iteratorTypes(inputRank, "parallel");
|
||||||
|
auto transpose = rewriter
|
||||||
|
.create<linalg::GenericOp>(
|
||||||
|
loc, outVector.getType(), inVector, outVector,
|
||||||
|
indexingMaps, iteratorTypes,
|
||||||
|
[](OpBuilder &b, Location loc, ValueRange args) {
|
||||||
|
b.create<linalg::YieldOp>(loc, args[0]);
|
||||||
|
})
|
||||||
|
.getResult(0);
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, transpose);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// The pass
|
// The pass
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
@ -1325,6 +1402,8 @@ public:
|
||||||
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
|
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenSumOp>();
|
target.addIllegalOp<AtenSumOp>();
|
||||||
patterns.add<ConvertReductionOp>(typeConverter, context);
|
patterns.add<ConvertReductionOp>(typeConverter, context);
|
||||||
|
target.addIllegalOp<AtenTransposeIntOp>();
|
||||||
|
patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns))))
|
std::move(patterns))))
|
||||||
|
|
Loading…
Reference in New Issue