diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py
index a97fa4d7e..0935c859d 100644
--- a/e2e_testing/torchscript/basic.py
+++ b/e2e_testing/torchscript/basic.py
@@ -180,3 +180,20 @@ class MaxPool2dModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: MaxPool2dModule())
 def MaxPool2dModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 1, 20, 20) - 0.5)
+
+class TransposeIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 4, 2], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.transpose(x, 0, 1)
+
+
+@register_test_case(module_factory=lambda: TransposeIntModule())
+def TransposeIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 2))
diff --git a/external/torch-mlir/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/external/torch-mlir/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
index 29ea3d281..417643cc9 100644
--- a/external/torch-mlir/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
+++ b/external/torch-mlir/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
@@ -88,7 +88,8 @@ public:
       Operation *op = workList.pop_back_val();
       if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
         copyToValueTensorOps.push_back(copyToValueTensor);
-      } else if (isa<AtenSqueezeOp, AtenFlattenUsingIntsOp>(op)) {
+      } else if (isa<AtenSqueezeOp, AtenFlattenUsingIntsOp,
+                     AtenTransposeIntOp>(op)) {
         viewLikeOps.push_back(op);
         llvm::append_range(workList, op->getResult(0).getUsers());
       } else {
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 0276496ae..002abf35c 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1275,6 +1275,83 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertAtenTransposeIntOp
+    : public OpConversionPattern<AtenTransposeIntOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenTransposeIntOp op, llvm::ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    AtenTransposeIntOp::Adaptor adaptor(operands);
+
+    int64_t dim0;
+    if (!matchPattern(op.dim0(), m_TorchConstantInt(&dim0)))
+      return rewriter.notifyMatchFailure(op, "dim0 must be constant");
+    int64_t dim1;
+    if (!matchPattern(op.dim1(), m_TorchConstantInt(&dim1)))
+      return rewriter.notifyMatchFailure(op, "dim1 must be constant");
+
+    auto inVector = adaptor.self();
+    auto inType = inVector.getType().cast<RankedTensorType>();
+    auto inputRank = inType.getRank();
+    auto outType = getTypeConverter()
+                       ->convertType(op->getResult(0).getType())
+                       .cast<RankedTensorType>();
+    auto elementType = inType.getElementType();
+
+    if (dim0 < 0)
+      dim0 += inputRank;
+    if (dim0 < 0 || dim0 >= inputRank)
+      return rewriter.notifyMatchFailure(op, "dim0 out of range");
+    if (dim1 < 0)
+      dim1 += inputRank;
+    if (dim1 < 0 || dim1 >= inputRank)
+      return rewriter.notifyMatchFailure(op, "dim1 out of range");
+
+    auto loc = op.getLoc();
+
+    llvm::SmallVector<Value> outputDims;
+    for (auto i = 0; i < inputRank; i++)
+      outputDims.push_back(getDimOp(rewriter, loc, adaptor.self(), i));
+    std::swap(outputDims[dim0], outputDims[dim1]);
+
+    Value outVector =
+        rewriter.create<linalg::InitTensorOp>(loc, outputDims, elementType);
+    SmallVector<AffineExpr> idExprs;
+    SmallVector<AffineExpr> swapExprs;
+    for (auto i = 0; i < inputRank; i++)
+      idExprs.push_back(getAffineDimExpr(i, rewriter.getContext()));
+    for (auto i = 0; i < inputRank; i++) {
+      if (i == dim0) {
+        swapExprs.push_back(idExprs[dim1]);
+      } else if (i == dim1) {
+        swapExprs.push_back(idExprs[dim0]);
+      } else {
+        swapExprs.push_back(idExprs[i]);
+      }
+    }
+
+    SmallVector<AffineMap> indexingMaps = {
+        AffineMap::get(inputRank, 0, idExprs, op.getContext()),
+        AffineMap::get(inputRank, 0, swapExprs, op.getContext())};
+    SmallVector<StringRef> iteratorTypes(inputRank, "parallel");
+    auto transpose = rewriter
+                         .create<linalg::GenericOp>(
+                             loc, outVector.getType(), inVector, outVector,
+                             indexingMaps, iteratorTypes,
+                             [](OpBuilder &b, Location loc, ValueRange args) {
+                               b.create<linalg::YieldOp>(loc, args[0]);
+                             })
+                         .getResult(0);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, transpose);
+    return success();
+  }
+};
+} // namespace
+
 // -----------------------------------------------------------------------------
 // The pass
 // -----------------------------------------------------------------------------
@@ -1325,6 +1402,8 @@ public:
     patterns.add(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
+    target.addIllegalOp<AtenTransposeIntOp>();
+    patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
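For reference, the behavior the new TransposeIntModule_basic test pins down can be checked standalone: torch.transpose(x, 0, 1) swaps the first two dimensions, which is the same dimension swap that swapExprs encodes in ConvertAtenTransposeIntOp. A minimal sketch assuming only a stock PyTorch install, outside the e2e harness:

import torch

# Mirrors TransposeIntModule_basic: a (3, 4, 2) input with dims 0 and 1 swapped.
x = torch.rand(3, 4, 2)
y = torch.transpose(x, 0, 1)
assert y.shape == (4, 3, 2)
# Element-wise, y[j, i, k] == x[i, j, k] -- the relation the swapped affine
# maps ((d0, d1, d2) -> (d1, d0, d2)) express in the generated linalg.generic.
assert torch.equal(y, x.permute(1, 0, 2))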