mirror of https://github.com/llvm/torch-mlir
Refactor to share code in DecomposeComplexOps pass
Share code in `log_softmax_backward` and `softmax_backward` ops.
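
Both rewrites are instances of the pattern the new helper computes, `x - y * sum(z, dim)`. Reading x, y, and z off the two call sites in this diff (with `grad_input` standing in for each op's result):

    softmax_backward:      newGrad    = grad_output * output
                           grad_input = newGrad - output * sum(newGrad, dim)
    log_softmax_backward:  expOut     = exp(output)
                           grad_input = grad_output - expOut * sum(grad_output, dim)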
parent ea7a30f9b9
commit 1dc374014b
@@ -33,7 +33,7 @@ static int getTensorRank(Value tensor) {
   return tensorRank;
 }

-static Value createAtenSum(PatternRewriter &rewriter, Location loc,
+static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
                                      Operation *op, Value input, Value dim,
                                      bool keepDim) {
   BaseTensorType tensorType = input.getType().cast<BaseTensorType>();
@@ -69,7 +69,7 @@ static Value createAtenSum(PatternRewriter &rewriter, Location loc,
 }

 // Helper for creating `aten::sub_tensor_op`.
-static Value createAtenSubTensorOp(PatternRewriter &rewriter, Location loc,
+static Value createTensorSub(PatternRewriter &rewriter, Location loc,
                              Type tensorType, Value lhs, Value rhs) {
   Value alpha =
       rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
@@ -78,6 +78,27 @@ static Value createAtenSubTensorOp(PatternRewriter &rewriter, Location loc,
   return sub;
 }

+// Share code between `softmax_backward` and `log_softmax_backward` ops.
+// Returns x - y * sum(z, dim).
+static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
+                                               Location loc, Operation *op,
+                                               Type tensorType, Value x,
+                                               Value y, Value z, Value dim) {
+  Value sum = createSumAlongDimension(rewriter, loc, op, z, dim, /*keepDim=*/true);
+  if (!sum)
+    return nullptr;
+  auto broadcastSizeType =
+      Torch::ListType::get(Torch::IntType::get(op->getContext()));
+  Value broadcastSize = rewriter.create<AtenSizeOp>(loc, broadcastSizeType, z);
+  Value sumBroadcast =
+      rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
+  Value temp =
+      rewriter.create<AtenMulTensorOp>(loc, tensorType, y, sumBroadcast);
+
+  Value sub = createTensorSub(rewriter, loc, tensorType, x, temp);
+  return sub;
+}
+
 namespace {
 class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
 public:
@@ -126,7 +147,7 @@ public:
     // exp(x)
     Value exp = rewriter.create<AtenExpOp>(loc, tensorType, self);
     // sum(exp(x))
-    Value sum = createAtenSum(rewriter, loc, op, exp, dim, /*keepDim=*/true);
+    Value sum = createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true);
     if (!sum)
       return failure();
     // exp(x) / sum(exp(x))
@@ -152,7 +173,6 @@ public:
   LogicalResult matchAndRewrite(Aten_SoftmaxBackwardDataOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    MLIRContext *context = op.getContext();
     Value gradOutput = op.grad_output();
     Value output = op.output();
     Value dim = op.dim();
@@ -163,22 +183,13 @@ public:

     Value newGrad =
         rewriter.create<AtenMulTensorOp>(loc, tensorType, gradOutput, output);
-    // temp = output * sum(newGrad, dim)
-    Value sum =
-        createAtenSum(rewriter, loc, op, newGrad, dim, /*keepDim=*/true);
-    if (!sum)
-      return failure();
-    auto broadcastSizeType = Torch::ListType::get(Torch::IntType::get(context));
-    Value broadcastSize =
-        rewriter.create<AtenSizeOp>(loc, broadcastSizeType, output);
-    Value sumBroadcast =
-        rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
-    Value temp =
-        rewriter.create<AtenMulTensorOp>(loc, tensorType, output, sumBroadcast);
-
-    // newGrad - temp
-    Value sub = createAtenSubTensorOp(rewriter, loc, tensorType, newGrad, temp);
-    rewriter.replaceOp(op, sub);
+    Value result = createSoftmaxBackwardCommonKernel(
+        rewriter, loc, op, tensorType, newGrad, output, newGrad, dim);
+    if (!result)
+      return rewriter.notifyMatchFailure(
+          op,
+          "nullptr returned by createSoftmaxBackwardCommonKernel function.");
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -199,7 +210,7 @@ public:
     Value gradOutput = op.grad_output();

     // `output` is the value flowing out from tanh. Hence, tanh(x) = output.
-    // Since, dTanh(x) = (1 - tanh(x)^2) hence, dOutput = (1 - output^2).
+    // Since, dTanh(x) = (1 - tanh(x)^2) hence, dOutput = (1 - output^2).
     Value output = op.output();

     BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
@@ -211,10 +222,8 @@ public:
     Value gradMulTanhSquare = rewriter.create<AtenMulTensorOp>(
         loc, tensorType, tanhSquare, gradOutput);

-    Value alpha =
-        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
-    Value newGrad = rewriter.create<AtenSubTensorOp>(
-        loc, tensorType, gradOutput, gradMulTanhSquare, alpha);
+    Value newGrad = createTensorSub(rewriter, loc, tensorType, gradOutput,
+                                    gradMulTanhSquare);
     rewriter.replaceOp(op, newGrad);
     return success();
   }
@@ -231,7 +240,6 @@ public:
   LogicalResult matchAndRewrite(Aten_LogSoftmaxBackwardDataOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    MLIRContext *context = op.getContext();
     Value gradOutput = op.grad_output();
     Value output = op.output();
     Value dim = op.dim();
@@ -241,22 +249,13 @@ public:
       return rewriter.notifyMatchFailure(op, "Only support floating type");

     Value expOut = rewriter.create<AtenExpOp>(loc, tensorType, output);
-    Value sum =
-        createAtenSum(rewriter, loc, op, gradOutput, dim, /*keepDim=*/true);
-    if (!sum)
-      return failure();
-
-    auto broadcastSizeType = Torch::ListType::get(Torch::IntType::get(context));
-    Value broadcastSize =
-        rewriter.create<AtenSizeOp>(loc, broadcastSizeType, output);
-    Value sumBroadcast =
-        rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
-    Value temp =
-        rewriter.create<AtenMulTensorOp>(loc, tensorType, expOut, sumBroadcast);
-
-    Value sub =
-        createAtenSubTensorOp(rewriter, loc, tensorType, gradOutput, temp);
-    rewriter.replaceOp(op, sub);
+    Value result = createSoftmaxBackwardCommonKernel(
+        rewriter, loc, op, tensorType, gradOutput, expOut, gradOutput, dim);
+    if (!result)
+      return rewriter.notifyMatchFailure(
+          op,
+          "nullptr returned by createSoftmaxBackwardCommonKernel function.");
+    rewriter.replaceOp(op, result);
     return success();
   }
 };