Refactor to share code in DecomposeComplexOps pass

Share code in `log_softmax_backward` and `softmax_backward` ops.
Branch: pull/432/head · Tag: snapshot-20211119.94
Author: Prashant Kumar · 2021-11-19 12:18:41 +00:00
parent ea7a30f9b9 · commit 1dc374014b
1 changed file with 41 additions and 42 deletions

@@ -33,7 +33,7 @@ static int getTensorRank(Value tensor) {
   return tensorRank;
 }
-static Value createAtenSum(PatternRewriter &rewriter, Location loc,
+static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
                                      Operation *op, Value input, Value dim,
                                      bool keepDim) {
   BaseTensorType tensorType = input.getType().cast<BaseTensorType>();
@@ -69,7 +69,7 @@ static Value createAtenSum(PatternRewriter &rewriter, Location loc,
 }
 // Helper for creating `aten::sub_tensor_op`.
-static Value createAtenSubTensorOp(PatternRewriter &rewriter, Location loc,
+static Value createTensorSub(PatternRewriter &rewriter, Location loc,
                              Type tensorType, Value lhs, Value rhs) {
   Value alpha =
       rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
@@ -78,6 +78,27 @@ static Value createAtenSubTensorOp(PatternRewriter &rewriter, Location loc,
   return sub;
 }
+
+// Share code between `softmax_backward` and `log_softmax_backward` ops.
+// Returns x - y * sum(z, dim).
+static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
+                                               Location loc, Operation *op,
+                                               Type tensorType, Value x,
+                                               Value y, Value z, Value dim) {
+  Value sum =
+      createSumAlongDimension(rewriter, loc, op, z, dim, /*keepDim=*/true);
+  if (!sum)
+    return nullptr;
+  auto broadcastSizeType =
+      Torch::ListType::get(Torch::IntType::get(op->getContext()));
+  Value broadcastSize = rewriter.create<AtenSizeOp>(loc, broadcastSizeType, z);
+  Value sumBroadcast =
+      rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
+  Value temp =
+      rewriter.create<AtenMulTensorOp>(loc, tensorType, y, sumBroadcast);
+  Value sub = createTensorSub(rewriter, loc, tensorType, x, temp);
+  return sub;
+}
 namespace {
 class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
 public:
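In symbols, the new helper computes, for tensors `x`, `y`, `z` and a reduction dimension `dim`,

$$\text{result} = x - y \odot \operatorname{broadcast}\Big(\sum_{\text{dim}} z\Big),$$

where the sum keeps the reduced dimension (`keepDim=true`) and is broadcast back to the shape of `z` before the elementwise multiply and subtract. The two backward patterns below differ only in the values they bind to `x`, `y`, and `z`.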
@@ -126,7 +147,7 @@ public:
     // exp(x)
     Value exp = rewriter.create<AtenExpOp>(loc, tensorType, self);
     // sum(exp(x))
-    Value sum = createAtenSum(rewriter, loc, op, exp, dim, /*keepDim=*/true);
+    Value sum = createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true);
     if (!sum)
       return failure();
     // exp(x) / sum(exp(x))
@@ -152,7 +173,6 @@ public:
   LogicalResult matchAndRewrite(Aten_SoftmaxBackwardDataOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    MLIRContext *context = op.getContext();
     Value gradOutput = op.grad_output();
     Value output = op.output();
     Value dim = op.dim();
@@ -163,22 +183,13 @@ public:
     Value newGrad =
         rewriter.create<AtenMulTensorOp>(loc, tensorType, gradOutput, output);
     // temp = output * sum(newGrad, dim)
-    Value sum =
-        createAtenSum(rewriter, loc, op, newGrad, dim, /*keepDim=*/true);
-    if (!sum)
-      return failure();
-    auto broadcastSizeType = Torch::ListType::get(Torch::IntType::get(context));
-    Value broadcastSize =
-        rewriter.create<AtenSizeOp>(loc, broadcastSizeType, output);
-    Value sumBroadcast =
-        rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
-    Value temp =
-        rewriter.create<AtenMulTensorOp>(loc, tensorType, output, sumBroadcast);
-    // newGrad - temp
-    Value sub = createAtenSubTensorOp(rewriter, loc, tensorType, newGrad, temp);
-    rewriter.replaceOp(op, sub);
+    Value result = createSoftmaxBackwardCommonKernel(
+        rewriter, loc, op, tensorType, newGrad, output, newGrad, dim);
+    if (!result)
+      return rewriter.notifyMatchFailure(
+          op,
+          "nullptr returned by createSoftmaxBackwardCommonKernel function.");
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
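The argument choice here follows from the standard softmax gradient: with $s$ = `output` and $g$ = `grad_output`,

$$\frac{\partial L}{\partial x_i} = s_i\Big(g_i - \sum_j g_j s_j\Big) = (g \odot s)_i - s_i \sum_j (g \odot s)_j,$$

i.e. `x - y * sum(z, dim)` with `x = newGrad`, `y = output`, and `z = newGrad`, where `newGrad = gradOutput * output`.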
@@ -199,7 +210,7 @@ public:
     Value gradOutput = op.grad_output();
     // `output` is the value flowing out from tanh. Hence, tanh(x) = output.
     // Since, dTanh(x) = (1 - tanh(x)^2) hence, dOutput = (1 - output^2).
     Value output = op.output();
     BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
@@ -211,10 +222,8 @@ public:
     Value gradMulTanhSquare = rewriter.create<AtenMulTensorOp>(
         loc, tensorType, tanhSquare, gradOutput);
-    Value alpha =
-        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
-    Value newGrad = rewriter.create<AtenSubTensorOp>(
-        loc, tensorType, gradOutput, gradMulTanhSquare, alpha);
+    Value newGrad = createTensorSub(rewriter, loc, tensorType, gradOutput,
+                                    gradMulTanhSquare);
     rewriter.replaceOp(op, newGrad);
     return success();
   }
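The rewrite above relies on $\tanh'(x) = 1 - \tanh^2(x)$: with $t$ = `output` and $g$ = `grad_output`,

$$g \odot (1 - t^2) = g - g \odot t^2,$$

so the decomposition is one multiply (`gradMulTanhSquare`) followed by one subtract, and `createTensorSub` supplies that subtract with its built-in `alpha = 1`.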
@@ -231,7 +240,6 @@ public:
   LogicalResult matchAndRewrite(Aten_LogSoftmaxBackwardDataOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    MLIRContext *context = op.getContext();
     Value gradOutput = op.grad_output();
     Value output = op.output();
     Value dim = op.dim();
@@ -241,22 +249,13 @@ public:
       return rewriter.notifyMatchFailure(op, "Only support floating type");
     Value expOut = rewriter.create<AtenExpOp>(loc, tensorType, output);
-    Value sum =
-        createAtenSum(rewriter, loc, op, gradOutput, dim, /*keepDim=*/true);
-    if (!sum)
-      return failure();
-    auto broadcastSizeType = Torch::ListType::get(Torch::IntType::get(context));
-    Value broadcastSize =
-        rewriter.create<AtenSizeOp>(loc, broadcastSizeType, output);
-    Value sumBroadcast =
-        rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
-    Value temp =
-        rewriter.create<AtenMulTensorOp>(loc, tensorType, expOut, sumBroadcast);
-    Value sub =
-        createAtenSubTensorOp(rewriter, loc, tensorType, gradOutput, temp);
-    rewriter.replaceOp(op, sub);
+    Value result = createSoftmaxBackwardCommonKernel(
+        rewriter, loc, op, tensorType, gradOutput, expOut, gradOutput, dim);
+    if (!result)
+      return rewriter.notifyMatchFailure(
+          op,
+          "nullptr returned by createSoftmaxBackwardCommonKernel function.");
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
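Likewise for log-softmax: with $\ell$ = `output` (the log-probabilities) and $g$ = `grad_output`, differentiating $\ell_i = x_i - \log\sum_j e^{x_j}$ gives

$$\frac{\partial L}{\partial x_i} = g_i - e^{\ell_i}\sum_j g_j,$$

which matches the call with `x = gradOutput`, `y = expOut` ($= e^{\ell}$), and `z = gradOutput`.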