Refactor to share code in DecomposeComplexOps pass

Share code in `log_softmax_backward` and `softmax_backward` ops.
pull/432/head snapshot-20211119.94
Prashant Kumar 2021-11-19 12:18:41 +00:00
parent ea7a30f9b9
commit 1dc374014b
1 changed file with 41 additions and 42 deletions


@@ -33,7 +33,7 @@ static int getTensorRank(Value tensor) {
   return tensorRank;
 }
 
-static Value createAtenSum(PatternRewriter &rewriter, Location loc,
-                           Operation *op, Value input, Value dim,
-                           bool keepDim) {
+static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
+                                     Operation *op, Value input, Value dim,
+                                     bool keepDim) {
   BaseTensorType tensorType = input.getType().cast<BaseTensorType>();
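The renamed helper's body lies outside this hunk. For orientation, here is a minimal sketch of what a sum-along-dimension helper in the Torch dialect can look like; the name `sumAlongDimSketch`, the `resultType` parameter, and the elided error handling are illustrative assumptions, not the commit's actual code. Note that with keepDim = true the reduced axis is kept as size 1 (summing a [B, C] tensor along dim 1 yields [B, 1]), which is what lets callers broadcast the sum back to the input shape.

// Sketch only: reduce `input` along `dim` by emitting aten.sum.dim_IntList.
// Assumes the caller has already computed `resultType`; dim validation and
// unranked-tensor handling are elided.
static Value sumAlongDimSketch(PatternRewriter &rewriter, Location loc,
                               Type resultType, Value input, Value dim,
                               bool keepDim) {
  // aten.sum.dim_IntList takes a list of dims, a keepdim flag, and an
  // optional dtype (torch.constant.none when unspecified).
  Value dimList = rewriter.create<PrimListConstructOp>(
      loc, Torch::ListType::get(dim.getType()), ValueRange{dim});
  Value keepDimCst =
      rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(keepDim));
  Value noneDtype = rewriter.create<ConstantNoneOp>(loc);
  return rewriter.create<AtenSumDimIntListOp>(loc, resultType, input, dimList,
                                              keepDimCst, noneDtype);
}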
@@ -69,7 +69,7 @@ static Value createAtenSum(PatternRewriter &rewriter, Location loc,
 }
 
 // Helper for creating `aten::sub_tensor_op`.
-static Value createAtenSubTensorOp(PatternRewriter &rewriter, Location loc,
-                                   Type tensorType, Value lhs, Value rhs) {
+static Value createTensorSub(PatternRewriter &rewriter, Location loc,
+                             Type tensorType, Value lhs, Value rhs) {
   Value alpha =
       rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
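For reference, `aten::sub.Tensor` computes lhs - alpha * rhs; the helper pins alpha to 1.0 (the ConstantFloatOp above), so call sites read as a plain elementwise subtraction.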
@@ -78,6 +78,27 @@ static Value createAtenSubTensorOp(PatternRewriter &rewriter, Location loc,
   return sub;
 }
 
+// Share code between `softmax_backward` and `log_softmax_backward` ops.
+// Returns x - y * sum(z, dim).
+static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
+                                               Location loc, Operation *op,
+                                               Type tensorType, Value x,
+                                               Value y, Value z, Value dim) {
+  Value sum =
+      createSumAlongDimension(rewriter, loc, op, z, dim, /*keepDim=*/true);
+  if (!sum)
+    return nullptr;
+  auto broadcastSizeType =
+      Torch::ListType::get(Torch::IntType::get(op->getContext()));
+  Value broadcastSize = rewriter.create<AtenSizeOp>(loc, broadcastSizeType, z);
+  Value sumBroadcast =
+      rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
+  Value temp =
+      rewriter.create<AtenMulTensorOp>(loc, tensorType, y, sumBroadcast);
+  Value sub = createTensorSub(rewriter, loc, tensorType, x, temp);
+  return sub;
+}
+
 namespace {
 class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
 public:
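To see why one kernel covers both backward ops, it helps to write out the standard gradients (not part of the commit, included here for context): both have the shape x - y * sum(z, dim). With logits a, softmax s, log-softmax l, and upstream gradient g, summing over the softmax dimension:

\[
\frac{\partial L}{\partial a_j}
  = \sum_i g_i\, s_i(\delta_{ij} - s_j)
  = \underbrace{g_j s_j}_{x} - \underbrace{s_j}_{y} \sum_i \underbrace{g_i s_i}_{z}
  \qquad\text{(softmax)}
\]
\[
\frac{\partial L}{\partial a_j}
  = g_j - s_j \sum_i g_i
  = \underbrace{g_j}_{x} - \underbrace{e^{l_j}}_{y} \sum_i \underbrace{g_i}_{z}
  \qquad\text{(log-softmax)}
\]

Hence each matchAndRewrite body below reduces to a single call: softmax backward passes (x, y, z) = (g * s, s, g * s) and log-softmax backward passes (g, exp(l), g).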
@@ -126,7 +147,7 @@ public:
     // exp(x)
     Value exp = rewriter.create<AtenExpOp>(loc, tensorType, self);
     // sum(exp(x))
-    Value sum = createAtenSum(rewriter, loc, op, exp, dim, /*keepDim=*/true);
+    Value sum = createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true);
     if (!sum)
       return failure();
     // exp(x) / sum(exp(x))
@@ -152,7 +173,6 @@ public:
   LogicalResult matchAndRewrite(Aten_SoftmaxBackwardDataOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    MLIRContext *context = op.getContext();
     Value gradOutput = op.grad_output();
     Value output = op.output();
     Value dim = op.dim();
@@ -163,22 +183,13 @@ public:
 
     Value newGrad =
         rewriter.create<AtenMulTensorOp>(loc, tensorType, gradOutput, output);
-    // temp = output * sum(newGrad, dim)
-    Value sum =
-        createAtenSum(rewriter, loc, op, newGrad, dim, /*keepDim=*/true);
-    if (!sum)
-      return failure();
-
-    auto broadcastSizeType = Torch::ListType::get(Torch::IntType::get(context));
-    Value broadcastSize =
-        rewriter.create<AtenSizeOp>(loc, broadcastSizeType, output);
-    Value sumBroadcast =
-        rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
-    Value temp =
-        rewriter.create<AtenMulTensorOp>(loc, tensorType, output, sumBroadcast);
-    // newGrad - temp
-    Value sub = createAtenSubTensorOp(rewriter, loc, tensorType, newGrad, temp);
-    rewriter.replaceOp(op, sub);
+    Value result = createSoftmaxBackwardCommonKernel(
+        rewriter, loc, op, tensorType, newGrad, output, newGrad, dim);
+    if (!result)
+      return rewriter.notifyMatchFailure(
+          op,
+          "nullptr returned by createSoftmaxBackwardCommonKernel function.");
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
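Here the kernel is instantiated with (x, y, z) = (newGrad, output, newGrad), where newGrad = grad_output * output, giving grad_input = g * s - s * sum(g * s, dim): the softmax case of the shared formula above.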
@@ -211,10 +222,8 @@ public:
     Value gradMulTanhSquare = rewriter.create<AtenMulTensorOp>(
         loc, tensorType, tanhSquare, gradOutput);
-    Value alpha =
-        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
-    Value newGrad = rewriter.create<AtenSubTensorOp>(
-        loc, tensorType, gradOutput, gradMulTanhSquare, alpha);
+    Value newGrad = createTensorSub(rewriter, loc, tensorType, gradOutput,
+                                    gradMulTanhSquare);
     rewriter.replaceOp(op, newGrad);
     return success();
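This matches the tanh derivative d tanh(x)/dx = 1 - tanh(x)^2: the backward is grad_output - grad_output * tanh(x)^2, now routed through the shared createTensorSub helper instead of an inline aten.sub.Tensor with an explicit alpha of 1.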
@@ -231,7 +240,6 @@ public:
   LogicalResult matchAndRewrite(Aten_LogSoftmaxBackwardDataOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    MLIRContext *context = op.getContext();
     Value gradOutput = op.grad_output();
     Value output = op.output();
     Value dim = op.dim();
@@ -241,22 +249,13 @@ public:
       return rewriter.notifyMatchFailure(op, "Only support floating type");
 
     Value expOut = rewriter.create<AtenExpOp>(loc, tensorType, output);
-    Value sum =
-        createAtenSum(rewriter, loc, op, gradOutput, dim, /*keepDim=*/true);
-    if (!sum)
-      return failure();
-
-    auto broadcastSizeType = Torch::ListType::get(Torch::IntType::get(context));
-    Value broadcastSize =
-        rewriter.create<AtenSizeOp>(loc, broadcastSizeType, output);
-    Value sumBroadcast =
-        rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
-    Value temp =
-        rewriter.create<AtenMulTensorOp>(loc, tensorType, expOut, sumBroadcast);
-    Value sub =
-        createAtenSubTensorOp(rewriter, loc, tensorType, gradOutput, temp);
-    rewriter.replaceOp(op, sub);
+    Value result = createSoftmaxBackwardCommonKernel(
+        rewriter, loc, op, tensorType, gradOutput, expOut, gradOutput, dim);
+    if (!result)
+      return rewriter.notifyMatchFailure(
+          op,
+          "nullptr returned by createSoftmaxBackwardCommonKernel function.");
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
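Here the instantiation is (x, y, z) = (gradOutput, expOut, gradOutput), giving grad_input = g - exp(l) * sum(g, dim); since `output` holds log-softmax values, exp(output) recovers the softmax, matching the log-softmax case of the shared formula above.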