[MLIR][TORCH] Add lowering for `aten._softmax` when `half_to_float=True`
This commit adds decompose logic for `aten._softmax` when `half_to_float` is `True`.

An e2e test case will be added once support for half-to-float conversion for `aten._softmax` is added upstream.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
parent 5a27f826b8
commit bb259f918a
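For reference, the decomposition implemented by `getSoftmaxResult` is the standard numerically stable softmax. The sketch below is a hypothetical standalone C++ illustration of that math over a single 1-D slice; the function name `softmax1D` is invented for this example and is not part of torch-mlir, which instead builds Torch-dialect ops.

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Standalone sketch of the decomposition used by getSoftmaxResult:
//   x_max   = max(input, dim, keepdim = True)
//   unnorm  = aten.exp(input - x_max)
//   softmax = unnorm / sum(unnorm, dim, keepdim = True)
// Shown here for one non-empty 1-D slice; the real pattern emits Torch ops.
std::vector<float> softmax1D(const std::vector<float> &input) {
  assert(!input.empty() && "expects a non-empty slice");
  float xMax = *std::max_element(input.begin(), input.end());
  std::vector<float> unnorm(input.size());
  float sum = 0.0f;
  for (std::size_t i = 0; i < input.size(); ++i) {
    unnorm[i] = std::exp(input[i] - xMax); // unnorm = exp(input - x_max)
    sum += unnorm[i];
  }
  for (float &v : unnorm)
    v /= sum;                              // softmax = unnorm / sum(unnorm)
  return unnorm;
}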
@@ -293,11 +293,10 @@ public:
 // unnorm = aten.exp(input - x_max)
 // softmax = unnorm / sum(unnorm, dim, keepdim = True)
 template <typename OpTy>
-static Value getSoftmaxResult(OpTy op, Type resultType,
+static Value getSoftmaxResult(OpTy op, Value self, Type resultType,
                               PatternRewriter &rewriter) {
   Location loc = op.getLoc();
   Value dim = op.dim();
-  Value self = op.self();
   Value xMax =
       createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true);
   if (!xMax)
@@ -329,7 +328,7 @@ public:
     if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
       return rewriter.notifyMatchFailure(op, "Only support floating type");
 
-    Value result = getSoftmaxResult(op, tensorType, rewriter);
+    Value result = getSoftmaxResult(op, self, tensorType, rewriter);
     if (!result)
       return failure();
     rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
@@ -354,16 +353,25 @@ public:
       return rewriter.notifyMatchFailure(
           op, "Expected a boolean value for half_to_float");
 
-    // Currently, setting `halfToFloat` is not supported as the E2E testing for
-    // the same is not present on CPU.
-    if (halfToFloat)
-      return rewriter.notifyMatchFailure(
-          op, "halfToFloat is currently not supported.");
-
-    Value result = getSoftmaxResult(op, tensorType, rewriter);
+    BaseTensorType resultTensorType = op.getType().cast<BaseTensorType>();
+    // `torch.ops.aten._softmax`'s softmax with half to float conversion is not
+    // supported on CPU, but we go ahead with the decomposing.
+    // TODO: Add an e2e test once upstream support is added.
+    // If `half_to_float` is set, we convert the input's elemental type to match
+    // that of output's.
+    if (halfToFloat) {
+      Location loc = op.getLoc();
+      Value none = rewriter.create<ConstantNoneOp>(loc);
+      Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
+      self = rewriter.create<AtenToDtypeOp>(
+          loc, resultTensorType, self,
+          getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype()),
+          /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
+    }
+    Value result = getSoftmaxResult(op, self, resultTensorType, rewriter);
     if (!result)
       return op.emitError("failed to get softmax result");
-    rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
+    rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, resultTensorType,
                                                         result);
     return success();
   }
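As context for the new `halfToFloat` branch: `aten._softmax` with `half_to_float=True` takes a half-precision input and produces a float result, so the rewrite casts `self` to the result dtype (the `AtenToDtypeOp` above) before the decomposition runs. Below is a rough standalone illustration that reuses the hypothetical `softmax1D` sketch from earlier; again, this is not torch-mlir code.

#include <vector>

// Declared in the earlier sketch; only its declaration is needed here.
std::vector<float> softmax1D(const std::vector<float> &input);

// Hypothetical stand-in for the half_to_float path: widen the input to f32
// first (mirroring the AtenToDtypeOp cast to the result dtype), then apply
// the same exp/sum/div decomposition entirely in float.
template <typename HalfLike> // stand-in for an f16 element type (assumption)
std::vector<float> softmaxHalfToFloat(const std::vector<HalfLike> &input) {
  std::vector<float> widened(input.begin(), input.end()); // dtype cast step
  return softmax1D(widened); // decomposition runs on the widened values
}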