[MLIR][TORCH] Add lowering for `aten._softmax` when `half_to_float=True`

-- This commit adds decompose logic for `aten._softmax` when
   `half_to_float` is `True`.
-- An e2e test case will be added once support for half to float conversion for
   `aten._softmax` is added upstream.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
pull/1655/head
Abhishek Varma 2022-11-25 05:56:37 +00:00 committed by Prashant Kumar
parent 5a27f826b8
commit bb259f918a
1 changed files with 19 additions and 11 deletions

View File

@ -293,11 +293,10 @@ public:
// unnorm = aten.exp(input - x_max)
// softmax = unnorm / sum(unnorm, dim, keepdim = True)
template <typename OpTy>
static Value getSoftmaxResult(OpTy op, Type resultType,
static Value getSoftmaxResult(OpTy op, Value self, Type resultType,
PatternRewriter &rewriter) {
Location loc = op.getLoc();
Value dim = op.dim();
Value self = op.self();
Value xMax =
createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true);
if (!xMax)
@ -329,7 +328,7 @@ public:
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(op, "Only support floating type");
Value result = getSoftmaxResult(op, tensorType, rewriter);
Value result = getSoftmaxResult(op, self, tensorType, rewriter);
if (!result)
return failure();
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
@ -354,16 +353,25 @@ public:
return rewriter.notifyMatchFailure(
op, "Expected a boolean value for half_to_float");
// Currently, setting `halfToFloat` is not supported as the E2E testing for
// the same is not present on CPU.
if (halfToFloat)
return rewriter.notifyMatchFailure(
op, "halfToFloat is currently not supported.");
Value result = getSoftmaxResult(op, tensorType, rewriter);
BaseTensorType resultTensorType = op.getType().cast<BaseTensorType>();
// `torch.ops.aten._softmax`'s softmax with half to float conversion is not
// supported on CPU, but we go ahead with the decomposing.
// TODO: Add an e2e test once upstream support is added.
// If `half_to_float` is set, we convert the input's elemental type to match
// that of output's.
if (halfToFloat) {
Location loc = op.getLoc();
Value none = rewriter.create<ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
self = rewriter.create<AtenToDtypeOp>(
loc, resultTensorType, self,
getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype()),
/*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
}
Value result = getSoftmaxResult(op, self, resultTensorType, rewriter);
if (!result)
return op.emitError("failed to get softmax result");
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, resultTensorType,
result);
return success();
}