From ef897dbb1924d90272f1a6a8a0a31f2b30077f1b Mon Sep 17 00:00:00 2001
From: Prashant Kumar
Date: Tue, 2 Nov 2021 17:06:04 +0000
Subject: [PATCH] Add lowering of `aten.log_softmax` op.

The `aten.log_softmax` op is decomposed into `aten.softmax` and `aten.log` ops.
---
 e2e_testing/torchscript/basic.py              | 17 +++++++++++
 .../Dialect/Torch/IR/GeneratedAtenOps.td      | 16 ++++++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 30 +++++++++++++++++++
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  | 19 ++++++++----
 .../jit_ir/build_tools/torch_ods_gen.py       |  3 ++
 5 files changed, 79 insertions(+), 6 deletions(-)

diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py
index 99d82585c..5200df6a2 100644
--- a/e2e_testing/torchscript/basic.py
+++ b/e2e_testing/torchscript/basic.py
@@ -512,3 +512,20 @@ class TensorToInt(torch.nn.Module):
 @register_test_case(module_factory=lambda: TensorToInt())
 def TensorToInt_basic(module, tu: TestUtils):
     module.forward(torch.randint(10,[]), tu.rand())
+
+class LogSoftmaxIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.log_softmax = torch.nn.LogSoftmax(2)
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float64, True),
+    ])
+    def forward(self, tensor):
+        return self.log_softmax.forward(tensor)
+
+@register_test_case(module_factory=lambda: LogSoftmaxIntModule())
+def LogSoftmaxIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 2, 4).double())
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index 102df7087..395e575f2 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -1088,6 +1088,22 @@ def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [
   let assemblyFormat = "$self `,` $dim `,` $dtype attr-dict `:` type($self) `,` type($dim) `,` type($dtype) `->` type($result)";
 }
 
+def Torch_AtenLogSoftmaxIntOp : Torch_Op<"aten.log_softmax.int", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::log_softmax.int : (Tensor, int, int?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim,
+    TorchOptionalIntType:$dtype
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $dim `,` $dtype attr-dict `:` type($self) `,` type($dim) `,` type($dtype) `->` type($result)";
+}
+
 def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
     AllowsTypeRefinement,
     HasValueSemantics
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 5beb4c5cb..e5aa6503e 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -88,6 +88,34 @@ public:
 };
 } // namespace
 
+// Decompose aten.log_softmax op into: log(softmax(x))
+namespace {
+class DecomposeAtenLogSoftmaxIntOp
+    : public OpRewritePattern<AtenLogSoftmaxIntOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenLogSoftmaxIntOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.self();
+    Value dim = op.dim();
+    if (!op.dtype().getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented non-None dtype for log_softmax");
+
+    BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
+    if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
+      return rewriter.notifyMatchFailure(op, "Only support floating type");
+
+    // softmax(x, dim)
+    Value softmax = rewriter.create<AtenSoftmaxIntOp>(loc, tensorType, self,
+                                                      dim, op.dtype());
+    rewriter.replaceOpWithNewOp<AtenLogOp>(op, op.getType(), softmax);
+    return success();
+  }
+};
+} // namespace
+
 // Decompose torch.matmul into: torch.mm and torch.bmm according to ranks.
 namespace {
 class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> {
@@ -125,6 +153,8 @@ class DecomposeComplexOpsPass
 
     patterns.add<DecomposeAtenSoftmaxIntOp>(context);
     target.addIllegalOp<AtenSoftmaxIntOp>();
+    patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
+    target.addIllegalOp<AtenLogSoftmaxIntOp>();
     patterns.add<DecomposeAtenMatmulOp>(context);
     target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
       int lhsRank = getTensorRank(op.self());
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 90fb6449b..0d30d606e 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -411,7 +411,9 @@ public:
     } else if (auto matmul = dyn_cast<AtenMatmulOp>(op)) {
       return visitAtenMatmulOp(matmul, operands);
     } else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
-      return visitAtenSoftmaxIntOp(softmaxIntOp, operands);
+      return visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
+    } else if (auto logSoftmaxIntOp = dyn_cast<AtenLogSoftmaxIntOp>(op)) {
+      return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands);
     }
 
     // Otherwise, this is an unknown operation. Just mark all results as
@@ -511,11 +513,13 @@ private:
   visitAtenBmmOp(AtenBmmOp op,
                  ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   ChangeResult
-  visitAtenSoftmaxIntOp(AtenSoftmaxIntOp op,
-                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
-  ChangeResult
   visitAtenMatmulOp(AtenMatmulOp op,
                     ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+
+  template <typename OpTy>
+  ChangeResult
+  visitAtenSoftmaxLikeOp(OpTy op,
+                         ArrayRef<LatticeElement<ValueKnowledge> *> operands);
 };
 } // namespace
 
@@ -1259,8 +1263,11 @@ ChangeResult TypeAnalyzer::visitAtenEmbeddingOp(
   return getLatticeElement(op.getResult()).join(knowledge);
 }
 
-ChangeResult TypeAnalyzer::visitAtenSoftmaxIntOp(
-    AtenSoftmaxIntOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+
+// Common template for softmax-like ops, e.g., log_softmax.
+template <typename OpTy>
+ChangeResult TypeAnalyzer::visitAtenSoftmaxLikeOp(
+    OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto input = operands[0]->getValue();
   auto dtype = op.dtype();
   auto knowledge =
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index a2699e987..10f8b969d 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -495,6 +495,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit(
         "aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
     )
+    emit(
+        "aten::log_softmax.int : (Tensor, int, int?) -> (Tensor)"
+    )
     emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
     emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
     emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
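
Note (editorial, not part of the patch): the decomposition above relies on the
identity log_softmax(x, dim) == log(softmax(x, dim)). A minimal eager-mode
sketch of the equivalence that the new LogSoftmaxIntModule e2e test exercises,
written as a standalone illustration rather than code from this patch:

    import torch

    x = torch.randn(3, 2, 4).double()
    # What the DecomposeComplexOps pattern emits: aten.log applied to
    # the result of aten.softmax.int along the same dimension.
    decomposed = torch.log(torch.softmax(x, dim=2))
    # What the original aten.log_softmax.int op computes.
    reference = torch.nn.LogSoftmax(dim=2)(x)
    assert torch.allclose(decomposed, reference)

One caveat of this decomposition: log(softmax(x)) can underflow to log(0) for
large-magnitude inputs, unlike a fused log_softmax that uses the log-sum-exp
trick; the e2e test only covers well-scaled double-precision inputs.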