Added transpose lowering

pull/313/head
George Petterson 2021-09-17 02:49:04 -04:00 committed by Yi Zhang
parent c24ca5d639
commit ecc334123c
3 changed files with 98 additions and 1 deletion


@@ -180,3 +180,20 @@ class MaxPool2dModule(torch.nn.Module):
@register_test_case(module_factory=lambda: MaxPool2dModule())
def MaxPool2dModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 1, 20, 20) - 0.5)


class TransposeIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 2], torch.float32, True),
    ])
    def forward(self, x):
        return torch.transpose(x, 0, 1)


@register_test_case(module_factory=lambda: TransposeIntModule())
def TransposeIntModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 2))
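For intuition, here is what the new e2e test exercises, written in plain PyTorch outside the test harness (`TestUtils`, `@export`, and `@annotate_args` come from the torch-mlir e2e framework): transposing dims 0 and 1 of a [3, 4, 2] tensor yields a [4, 3, 2] tensor.

import torch

x = torch.rand(3, 4, 2)
y = torch.transpose(x, 0, 1)

# The first two dimensions trade places; dim 2 is untouched.
assert y.shape == torch.Size([4, 3, 2])
assert y[1, 0, 0] == x[0, 1, 0]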


@@ -88,7 +88,8 @@ public:
    Operation *op = workList.pop_back_val();
    if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
      copyToValueTensorOps.push_back(copyToValueTensor);
    } else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
                   AtenTransposeIntOp>(op)) {
      viewLikeOps.push_back(op);
      llvm::append_range(workList, op->getResult(0).getUsers());
    } else {
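The reason transpose belongs in viewLikeOps: in PyTorch, torch.transpose returns a view that aliases the input's storage rather than a copy, so this pass can chase its users exactly as it does for unsqueeze and flatten. A quick sketch of the aliasing behavior in plain PyTorch, for illustration only:

import torch

x = torch.rand(3, 4, 2)
y = torch.transpose(x, 0, 1)  # a view: no data is copied

assert y.data_ptr() == x.data_ptr()  # same underlying storage
x[0, 1, 0] = 42.0
assert y[1, 0, 0] == 42.0  # writes through x are visible through y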


@@ -1275,6 +1275,83 @@ public:
};
} // namespace

namespace {
// Lowers aten.transpose.int to a linalg.generic op whose indexing maps
// swap the two transposed dimensions.
class ConvertAtenTransposeIntOp
    : public OpConversionPattern<AtenTransposeIntOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenTransposeIntOp op, llvm::ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    AtenTransposeIntOp::Adaptor adaptor(operands);
    int64_t dim0;
    if (!matchPattern(op.dim0(), m_TorchConstantInt(&dim0)))
      return rewriter.notifyMatchFailure(op, "dim0 must be constant");
    int64_t dim1;
    if (!matchPattern(op.dim1(), m_TorchConstantInt(&dim1)))
      return rewriter.notifyMatchFailure(op, "dim1 must be constant");

    auto inVector = adaptor.self();
    auto inType = inVector.getType().cast<RankedTensorType>();
    auto inputRank = inType.getRank();
    auto outType = getTypeConverter()
                       ->convertType(op->getResult(0).getType())
                       .cast<RankedTensorType>();
    auto elementType = inType.getElementType();

    // Normalize negative dims into [0, inputRank); Torch accepts dims in
    // [-inputRank, inputRank).
    if (dim0 < 0)
      dim0 += inputRank;
    if (dim0 < 0 || dim0 >= inputRank)
      return rewriter.notifyMatchFailure(op, "dim0 out of range");
    if (dim1 < 0)
      dim1 += inputRank;
    if (dim1 < 0 || dim1 >= inputRank)
      return rewriter.notifyMatchFailure(op, "dim1 out of range");

    auto loc = op.getLoc();

    // The output shape is the input shape with dim0 and dim1 swapped.
    llvm::SmallVector<Value> outputDims;
    for (auto i = 0; i < inputRank; i++)
      outputDims.push_back(getDimOp(rewriter, loc, adaptor.self(), i));
    std::swap(outputDims[dim0], outputDims[dim1]);
    Value outVector =
        rewriter.create<linalg::InitTensorOp>(loc, outputDims, elementType);

    // The input map is the identity; the output map swaps dim0 and dim1.
    SmallVector<AffineExpr> idExprs;
    SmallVector<AffineExpr> swapExprs;
    for (auto i = 0; i < inputRank; i++)
      idExprs.push_back(getAffineDimExpr(i, rewriter.getContext()));
    for (auto i = 0; i < inputRank; i++) {
      if (i == dim0) {
        swapExprs.push_back(idExprs[dim1]);
      } else if (i == dim1) {
        swapExprs.push_back(idExprs[dim0]);
      } else {
        swapExprs.push_back(idExprs[i]);
      }
    }

    SmallVector<AffineMap> indexingMaps = {
        AffineMap::get(inputRank, 0, idExprs, op.getContext()),
        AffineMap::get(inputRank, 0, swapExprs, op.getContext())};
    SmallVector<StringRef> iteratorTypes(inputRank, "parallel");
    auto transpose = rewriter
                         .create<linalg::GenericOp>(
                             loc, outVector.getType(), inVector, outVector,
                             indexingMaps, iteratorTypes,
                             [](OpBuilder &b, Location loc, ValueRange args) {
                               b.create<linalg::YieldOp>(loc, args[0]);
                             })
                         .getResult(0);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, transpose);
    return success();
  }
};
} // namespace
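The heart of the pattern is the pair of affine maps: the input is read through the identity map while the output map swaps dim0 and dim1, so the linalg.generic copies each element to its transposed position. A plain-PyTorch sketch of the same permutation logic (the helper name is mine, not part of the patch), including the negative-dim normalization that adds the rank before the bounds check:

import torch

def transpose_via_permutation(x, dim0, dim1):
    rank = x.dim()
    # Normalize negative dims, mirroring `dim += inputRank` in the pattern.
    dim0 = dim0 + rank if dim0 < 0 else dim0
    dim1 = dim1 + rank if dim1 < 0 else dim1
    assert 0 <= dim0 < rank and 0 <= dim1 < rank
    # Identity permutation with dim0 and dim1 swapped, the same role the
    # swapped affine map plays for the linalg.generic's output.
    perm = list(range(rank))
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    return x.permute(perm)

x = torch.rand(3, 4, 2)
assert torch.equal(transpose_via_permutation(x, 0, 1),
                   torch.transpose(x, 0, 1))
assert torch.equal(transpose_via_permutation(x, -1, -3),
                   torch.transpose(x, -1, -3))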
// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
@@ -1325,6 +1402,8 @@ public:
    patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
    target.addIllegalOp<AtenSumOp>();
    patterns.add<ConvertReductionOp>(typeConverter, context);
    target.addIllegalOp<AtenTransposeIntOp>();
    patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))