From f461a7ebcef872b82bdf9ab8ab980a44ca32a4ec Mon Sep 17 00:00:00 2001 From: Prateek Gupta Date: Thu, 25 Nov 2021 11:19:02 +0530 Subject: [PATCH] [TORCH][MLIR] Add E2E support for aten._softmax operation. (#431) Signed-Off-By: Prateek Gupta --- e2e_testing/torchscript/basic.py | 17 +++++ .../Dialect/Torch/IR/GeneratedAtenOps.td | 16 +++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 64 ++++++++++++++++--- lib/Dialect/Torch/Transforms/RefineTypes.cpp | 33 ++++++++-- .../jit_ir/build_tools/torch_ods_gen.py | 1 + 5 files changed, 118 insertions(+), 13 deletions(-) diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py index c6641b2fd..63fedc348 100644 --- a/e2e_testing/torchscript/basic.py +++ b/e2e_testing/torchscript/basic.py @@ -457,6 +457,23 @@ class SoftmaxIntModule(torch.nn.Module): def SoftmaxIntModule_basic(module, tu: TestUtils): module.forward(torch.randn(3, 2, 4)) +class _SoftmaxModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, tensor): + return torch.ops.aten._softmax(tensor, 0, False) + + +@register_test_case(module_factory=lambda: _SoftmaxModule()) +def _SoftmaxModule_basic(module, tu: TestUtils): + module.forward(torch.randn(3, 2, 4)) + class SoftmaxIntNegDimModule(torch.nn.Module): def __init__(self): diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td index 4ac931686..f3d4d0c68 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td @@ -1392,6 +1392,22 @@ def Torch_AtenSqrtOp : Torch_Op<"aten.sqrt", [ let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; } +def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::_softmax : (Tensor, int, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + Torch_BoolType:$half_to_float + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $dim `,` $half_to_float attr-dict `:` type($self) `,` type($dim) `,` type($half_to_float) `->` type($result)"; +} + def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [ AllowsTypeRefinement ]> { diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 8c1941780..be67a0b69 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -126,6 +126,26 @@ public: }; } // namespace +// Calculates the softmax function on the given `input` tensor. Softmax(x) = +// exp(x)/sum(exp(x)). +template +static Value getSoftmaxResult(OpTy op, Type resultType, + PatternRewriter &rewriter) { + Location loc = op.getLoc(); + Value dim = op.dim(); + Value self = op.self(); + + // exp(x) + Value exp = rewriter.create(loc, resultType, self); + // sum(exp(x)) + Value sum = + createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true); + if (!sum) + return nullptr; + // exp(x) / sum(exp(x)) + return rewriter.create(loc, resultType, exp, sum); +} + // Decompose softmax into: exp(x) / sum(exp(x)) namespace { class DecomposeAtenSoftmaxIntOp : public OpRewritePattern { @@ -133,9 +153,7 @@ public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenSoftmaxIntOp op, PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); Value self = op.self(); - Value dim = op.dim(); if (!op.dtype().getType().isa()) return rewriter.notifyMatchFailure( op, "Unimplemented non-None dtype for softmax"); @@ -144,14 +162,40 @@ public: if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) return rewriter.notifyMatchFailure(op, "Only support floating type"); - // exp(x) - Value exp = rewriter.create(loc, tensorType, self); - // sum(exp(x)) - Value sum = createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true); - if (!sum) + Value result = getSoftmaxResult(op, tensorType, rewriter); + if (!result) return failure(); - // exp(x) / sum(exp(x)) - Value result = rewriter.create(loc, tensorType, exp, sum); + rewriter.replaceOpWithNewOp(op, op.getType(), + result); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAten_SoftmaxOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Aten_SoftmaxOp op, + PatternRewriter &rewriter) const override { + Value self = op.self(); + BaseTensorType tensorType = self.getType().cast(); + if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) + return rewriter.notifyMatchFailure(op, "Only support floating type"); + bool halfToFloat; + if (!matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat))) + return rewriter.notifyMatchFailure( + op, "Expected a boolean value for half_to_float"); + + // Currently, setting `halfToFloat` is not supported as the E2E testing for + // the same is not present on CPU. + if (halfToFloat) + return rewriter.notifyMatchFailure( + op, "halfToFloat is currently not supported."); + + Value result = getSoftmaxResult(op, tensorType, rewriter); + if (!result) + return op.emitError("failed to get softmax result"); rewriter.replaceOpWithNewOp(op, op.getType(), result); return success(); @@ -406,6 +450,8 @@ class DecomposeComplexOpsPass patterns.add(context); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); patterns.add(context); diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 28405930f..e2079609e 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -418,6 +418,8 @@ public: return visitAtenMatmulOp(matmul, operands); } else if (auto softmaxIntOp = dyn_cast(op)) { return visitAtenSoftmaxLikeOp(softmaxIntOp, operands); + } else if (auto _softmaxOp = dyn_cast(op)) { + return visitAten_SoftmaxOp(_softmaxOp, operands); } else if (auto logSoftmaxIntOp = dyn_cast(op)) { return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands); } else if (auto numToTensorOp = dyn_cast(op)) { @@ -541,6 +543,10 @@ private: ChangeResult visitAtenAddCLikeOp(Operation *op, ArrayRef *> operands); + + ChangeResult + visitAten_SoftmaxOp(Aten_SoftmaxOp op, + ArrayRef *> operands); }; } // namespace @@ -1332,6 +1338,16 @@ ChangeResult TypeAnalyzer::visitAtenEmbeddingOp( return getLatticeElement(op.getResult()).join(knowledge); } +static ValueKnowledge +getSameSizeAsInput(Operation *op, + ArrayRef *> operands) { + auto input = operands[0]->getValue(); + auto knowledge = + ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); + knowledge.hasSizes = input.hasSizes; + knowledge.sizes = input.sizes; + return knowledge; +} // Common template for softmax like ops, eg., log_softmax. template @@ -1339,14 +1355,23 @@ ChangeResult TypeAnalyzer::visitAtenSoftmaxLikeOp( OpTy op, ArrayRef *> operands) { auto input = operands[0]->getValue(); auto dtype = op.dtype(); - auto knowledge = - ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); - knowledge.hasSizes = input.hasSizes; - knowledge.sizes = input.sizes; + ValueKnowledge knowledge = getSameSizeAsInput(op, operands); fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, input.dtype); return getLatticeElement(op.getResult()).join(knowledge); } +ChangeResult TypeAnalyzer::visitAten_SoftmaxOp( + Aten_SoftmaxOp op, ArrayRef *> operands) { + auto input = operands[0]->getValue(); + ValueKnowledge knowledge = getSameSizeAsInput(op, operands); + bool halfToFloat; + if (matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat))) { + knowledge.dtype = + halfToFloat ? Float32Type::get(op->getContext()) : input.dtype; + } + return getLatticeElement(op.getResult()).join(knowledge); +} + ChangeResult TypeAnalyzer::visitAtenBmmOp( AtenBmmOp op, ArrayRef *> operands) { auto knowledge = diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index a745f4305..b31299695 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -516,6 +516,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry): emit("aten::mean.dim : (Tensor, int[], bool, int?) -> (Tensor)") emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::sqrt : (Tensor) -> (Tensor)") + emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)") # Misc tensor ops. emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")