diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index e1b0e6fd4..3dec9be4c 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -33,7 +33,7 @@ static int getTensorRank(Value tensor) {
   return tensorRank;
 }
 
-static Value createAtenSum(PatternRewriter &rewriter, Location loc,
+static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
                            Operation *op, Value input, Value dim,
                            bool keepDim) {
   BaseTensorType tensorType = input.getType().cast<BaseTensorType>();
@@ -69,7 +69,7 @@ static Value createAtenSum(PatternRewriter &rewriter, Location loc,
 }
 
 // Helper for creating `aten::sub_tensor_op`.
-static Value createAtenSubTensorOp(PatternRewriter &rewriter, Location loc,
+static Value createTensorSub(PatternRewriter &rewriter, Location loc,
                                    Type tensorType, Value lhs, Value rhs) {
   Value alpha =
       rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
@@ -78,6 +78,27 @@ static Value createAtenSubTensorOp(PatternRewriter &rewriter, Location loc,
   return sub;
 }
 
+// Share code between `softmax_backward` and `log_softmax_backward` ops.
+// Returns x - y * sum(z, dim).
+static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
+                                               Location loc, Operation *op,
+                                               Type tensorType, Value x,
+                                               Value y, Value z, Value dim) {
+  Value sum = createSumAlongDimension(rewriter, loc, op, z, dim, /*keepDim=*/true);
+  if (!sum)
+    return nullptr;
+  auto broadcastSizeType =
+      Torch::ListType::get(Torch::IntType::get(op->getContext()));
+  Value broadcastSize = rewriter.create<AtenSizeOp>(loc, broadcastSizeType, z);
+  Value sumBroadcast =
+      rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
+  Value temp =
+      rewriter.create<AtenMulTensorOp>(loc, tensorType, y, sumBroadcast);
+
+  Value sub = createTensorSub(rewriter, loc, tensorType, x, temp);
+  return sub;
+}
+
 namespace {
 class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
 public:
@@ -126,7 +147,7 @@ public:
     // exp(x)
     Value exp = rewriter.create<AtenExpOp>(loc, tensorType, self);
     // sum(exp(x))
-    Value sum = createAtenSum(rewriter, loc, op, exp, dim, /*keepDim=*/true);
+    Value sum = createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true);
     if (!sum)
       return failure();
     // exp(x) / sum(exp(x))
@@ -152,7 +173,6 @@ public:
   LogicalResult matchAndRewrite(Aten_SoftmaxBackwardDataOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    MLIRContext *context = op.getContext();
     Value gradOutput = op.grad_output();
     Value output = op.output();
     Value dim = op.dim();
@@ -163,22 +183,13 @@ public:
 
     Value newGrad =
         rewriter.create<AtenMulTensorOp>(loc, tensorType, gradOutput, output);
-    // temp = output * sum(newGrad, dim)
-    Value sum =
-        createAtenSum(rewriter, loc, op, newGrad, dim, /*keepDim=*/true);
-    if (!sum)
-      return failure();
-    auto broadcastSizeType = Torch::ListType::get(Torch::IntType::get(context));
-    Value broadcastSize =
-        rewriter.create<AtenSizeOp>(loc, broadcastSizeType, output);
-    Value sumBroadcast =
-        rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
-    Value temp =
-        rewriter.create<AtenMulTensorOp>(loc, tensorType, output, sumBroadcast);
-
-    // newGrad - temp
-    Value sub = createAtenSubTensorOp(rewriter, loc, tensorType, newGrad, temp);
-    rewriter.replaceOp(op, sub);
+    Value result = createSoftmaxBackwardCommonKernel(
+        rewriter, loc, op, tensorType, newGrad, output, newGrad, dim);
+    if (!result)
+      return rewriter.notifyMatchFailure(
+          op,
+          "nullptr returned by createSoftmaxBackwardCommonKernel function.");
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -199,7 +210,7 @@ public:
     Value gradOutput = op.grad_output();
 
     // `output` is the value flowing out from tanh. Hence, tanh(x) = output.
-    // Since, dTanh(x) = (1 - tanh(x)^2) hence, dOutput = (1 - output^2).
+    // Since dTanh(x) = (1 - tanh(x)^2), dOutput = (1 - output^2).
     Value output = op.output();
 
     BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
@@ -211,10 +222,8 @@ public:
     Value gradMulTanhSquare = rewriter.create<AtenMulTensorOp>(
         loc, tensorType, tanhSquare, gradOutput);
 
-    Value alpha =
-        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
-    Value newGrad = rewriter.create<AtenSubTensorOp>(
-        loc, tensorType, gradOutput, gradMulTanhSquare, alpha);
+    Value newGrad = createTensorSub(rewriter, loc, tensorType, gradOutput,
+                                    gradMulTanhSquare);
     rewriter.replaceOp(op, newGrad);
     return success();
   }
@@ -231,7 +240,6 @@ public:
   LogicalResult matchAndRewrite(Aten_LogSoftmaxBackwardDataOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    MLIRContext *context = op.getContext();
     Value gradOutput = op.grad_output();
     Value output = op.output();
     Value dim = op.dim();
@@ -241,22 +249,13 @@ public:
       return rewriter.notifyMatchFailure(op, "Only support floating type");
 
     Value expOut = rewriter.create<AtenExpOp>(loc, tensorType, output);
-    Value sum =
-        createAtenSum(rewriter, loc, op, gradOutput, dim, /*keepDim=*/true);
-    if (!sum)
-      return failure();
-
-    auto broadcastSizeType = Torch::ListType::get(Torch::IntType::get(context));
-    Value broadcastSize =
-        rewriter.create<AtenSizeOp>(loc, broadcastSizeType, output);
-    Value sumBroadcast =
-        rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
-    Value temp =
-        rewriter.create<AtenMulTensorOp>(loc, tensorType, expOut, sumBroadcast);
-
-    Value sub =
-        createAtenSubTensorOp(rewriter, loc, tensorType, gradOutput, temp);
-    rewriter.replaceOp(op, sub);
+    Value result = createSoftmaxBackwardCommonKernel(
+        rewriter, loc, op, tensorType, gradOutput, expOut, gradOutput, dim);
+    if (!result)
+      return rewriter.notifyMatchFailure(
+          op,
+          "nullptr returned by createSoftmaxBackwardCommonKernel function.");
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
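
Note for reviewers (not part of the patch): the shared helper computes
result = x - y * sum(z, dim), which the two patterns instantiate differently.
The algebra can be sanity-checked against PyTorch autograd with a minimal
Python sketch; everything below is illustrative and uses only stock PyTorch
APIs, with tensor names chosen to mirror the kernel's arguments.

    # Verify the identities createSoftmaxBackwardCommonKernel relies on,
    # using autograd as the reference implementation.
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
    g = torch.randn(3, 5, dtype=torch.float64)  # incoming grad_output
    dim = 1

    # softmax_backward: with y = softmax(x) and newGrad = g * y, the pattern
    # calls the kernel with (x=newGrad, y=output, z=newGrad), i.e. it computes
    # newGrad - y * sum(newGrad, dim).
    y = F.softmax(x, dim)
    ref = torch.autograd.grad(y, x, grad_outputs=g)[0]
    new_grad = g * y
    assert torch.allclose(new_grad - y * new_grad.sum(dim, keepdim=True), ref)

    # log_softmax_backward: with out = log_softmax(x), the pattern calls the
    # kernel with (x=g, y=exp(out), z=g), i.e. it computes
    # g - exp(out) * sum(g, dim).
    x2 = x.detach().clone().requires_grad_(True)
    out = F.log_softmax(x2, dim)
    ref = torch.autograd.grad(out, x2, grad_outputs=g)[0]
    assert torch.allclose(g - out.exp() * g.sum(dim, keepdim=True), ref)

If the operand mapping in the two rewrite patterns is correct, both asserts
should pass, since each expression is exactly the analytic gradient of its op.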