diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index de68e6146..e6a1ae292 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -293,11 +293,10 @@ public:
 // unnorm = aten.exp(input - x_max)
 // softmax = unnorm / sum(unnorm, dim, keepdim = True)
 template <typename OpTy>
-static Value getSoftmaxResult(OpTy op, Type resultType,
+static Value getSoftmaxResult(OpTy op, Value self, Type resultType,
                               PatternRewriter &rewriter) {
   Location loc = op.getLoc();
   Value dim = op.dim();
-  Value self = op.self();
   Value xMax =
       createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true);
   if (!xMax)
@@ -329,7 +328,7 @@ public:
     if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
       return rewriter.notifyMatchFailure(op, "Only support floating type");
 
-    Value result = getSoftmaxResult(op, tensorType, rewriter);
+    Value result = getSoftmaxResult(op, self, tensorType, rewriter);
     if (!result)
       return failure();
     rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
@@ -354,16 +353,25 @@ public:
       return rewriter.notifyMatchFailure(
           op, "Expected a boolean value for half_to_float");
 
-    // Currently, setting `halfToFloat` is not supported as the E2E testing for
-    // the same is not present on CPU.
-    if (halfToFloat)
-      return rewriter.notifyMatchFailure(
-          op, "halfToFloat is currently not supported.");
-
-    Value result = getSoftmaxResult(op, tensorType, rewriter);
+    BaseTensorType resultTensorType = op.getType().cast<BaseTensorType>();
+    // `torch.ops.aten._softmax`'s softmax with half to float conversion is not
+    // supported on CPU, but we go ahead with the decomposing.
+    // TODO: Add an e2e test once upstream support is added.
+    // If `half_to_float` is set, we convert the input's elemental type to match
+    // that of output's.
+    if (halfToFloat) {
+      Location loc = op.getLoc();
+      Value none = rewriter.create<ConstantNoneOp>(loc);
+      Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
+      self = rewriter.create<AtenToDtypeOp>(
+          loc, resultTensorType, self,
+          getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype()),
+          /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
+    }
+    Value result = getSoftmaxResult(op, self, resultTensorType, rewriter);
     if (!result)
       return op.emitError("failed to get softmax result");
-    rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
+    rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, resultTensorType,
                                                         result);
     return success();
   }
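
For reference, the arithmetic this decomposition emits is the numerically stable softmax documented in the comment above `getSoftmaxResult`, with an optional up-cast of the input when `half_to_float` is set. Below is a minimal PyTorch sketch of that behavior, illustrative only: the function name `softmax_decomposition` and the hard-coded `torch.float32` target dtype are assumptions for the example, not part of the patch.

```python
import torch

def softmax_decomposition(x: torch.Tensor, dim: int,
                          half_to_float: bool) -> torch.Tensor:
    # Mirrors the inserted AtenToDtypeOp: when half_to_float is set, the
    # input is converted to the result's element type (assumed float32
    # here for a half input) before any of the softmax math runs.
    if half_to_float:
        x = x.to(torch.float32)
    # x_max = max(input, dim, keepdim = True); subtracting the max keeps
    # exp() from overflowing.
    x_max = x.amax(dim=dim, keepdim=True)
    # unnorm = aten.exp(input - x_max)
    unnorm = torch.exp(x - x_max)
    # softmax = unnorm / sum(unnorm, dim, keepdim = True)
    return unnorm / unnorm.sum(dim=dim, keepdim=True)

# Quick check against PyTorch's own softmax-with-dtype, which likewise
# casts the input before computing.
x = torch.randn(4, 8, dtype=torch.float16)
expected = torch.nn.functional.softmax(x, dim=-1, dtype=torch.float32)
assert torch.allclose(softmax_decomposition(x, -1, True), expected)
```

Note that the cast happens before, not after, the exponentiation, so the intermediate exp and sum run at float32 precision; that is why the patch rewrites `self` ahead of calling `getSoftmaxResult` rather than casting the final result.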