mirror of https://github.com/llvm/torch-mlir
Refactor to share code in DecomposeComplexOps pass
Share code in `log_softmax_backward` and `softmax_backward` ops.
parent ea7a30f9b9
commit 1dc374014b
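The helper introduced below computes `x - y * sum(z, dim)`. The two backward decompositions map onto it as follows: `_softmax_backward_data` uses `x = z = grad_output * output` and `y = output`, while `_log_softmax_backward_data` uses `x = z = grad_output` and `y = exp(output)`. As a rough standalone illustration of that shared arithmetic (plain C++ over 1-D vectors, not the MLIR pattern code; names such as `backwardCommon` are hypothetical):

```cpp
// Standalone sketch of the shared formula, not the torch-mlir pattern code.
// backwardCommon() mirrors the "x - y * sum(z, dim)" computation for a 1-D
// tensor; all names here are illustrative.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Returns x - y * sum(z) elementwise (the sum is taken over the whole 1-D dim).
static std::vector<double> backwardCommon(const std::vector<double> &x,
                                          const std::vector<double> &y,
                                          const std::vector<double> &z) {
  double sumZ = 0.0;
  for (double v : z)
    sumZ += v;
  std::vector<double> out(x.size());
  for (size_t i = 0; i < x.size(); ++i)
    out[i] = x[i] - y[i] * sumZ;
  return out;
}

int main() {
  std::vector<double> input = {0.5, -1.0, 2.0};
  std::vector<double> grad = {1.0, 0.25, -0.5}; // upstream gradient

  // Forward results the backward ops receive: softmax and log_softmax.
  double maxIn = *std::max_element(input.begin(), input.end());
  double denom = 0.0;
  for (double v : input)
    denom += std::exp(v - maxIn);
  std::vector<double> softmax(input.size()), logSoftmax(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    softmax[i] = std::exp(input[i] - maxIn) / denom;
    logSoftmax[i] = (input[i] - maxIn) - std::log(denom);
  }

  // softmax backward: x = z = grad * softmax, y = softmax.
  std::vector<double> newGrad(input.size());
  for (size_t i = 0; i < input.size(); ++i)
    newGrad[i] = grad[i] * softmax[i];
  std::vector<double> softmaxGradIn = backwardCommon(newGrad, softmax, newGrad);

  // log_softmax backward: x = z = grad, y = exp(log_softmax).
  std::vector<double> expOut(input.size());
  for (size_t i = 0; i < input.size(); ++i)
    expOut[i] = std::exp(logSoftmax[i]);
  std::vector<double> logSoftmaxGradIn = backwardCommon(grad, expOut, grad);

  for (size_t i = 0; i < input.size(); ++i)
    std::printf("softmax grad_in[%zu] = % .6f, log_softmax grad_in[%zu] = % .6f\n",
                i, softmaxGradIn[i], i, logSoftmaxGradIn[i]);
  return 0;
}
```

Funneling both patterns through one helper keeps the sum/broadcast/subtract sequence in a single place, so a null result from the sum helper can be reported uniformly via `notifyMatchFailure`.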
@@ -33,7 +33,7 @@ static int getTensorRank(Value tensor) {
   return tensorRank;
 }
 
-static Value createAtenSum(PatternRewriter &rewriter, Location loc,
+static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
                            Operation *op, Value input, Value dim,
                            bool keepDim) {
   BaseTensorType tensorType = input.getType().cast<BaseTensorType>();
@@ -69,7 +69,7 @@ static Value createAtenSum(PatternRewriter &rewriter, Location loc,
 }
 
 // Helper for creating `aten::sub_tensor_op`.
-static Value createAtenSubTensorOp(PatternRewriter &rewriter, Location loc,
+static Value createTensorSub(PatternRewriter &rewriter, Location loc,
                                    Type tensorType, Value lhs, Value rhs) {
   Value alpha =
       rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
@@ -78,6 +78,27 @@ static Value createAtenSubTensorOp(PatternRewriter &rewriter, Location loc,
   return sub;
 }
 
+// Share code between `softmax_backward` and `log_softmax_backward` ops.
+// Returns x - y * sum(z, dim).
+static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
+                                               Location loc, Operation *op,
+                                               Type tensorType, Value x,
+                                               Value y, Value z, Value dim) {
+  Value sum = createSumAlongDimension(rewriter, loc, op, z, dim, /*keepDim=*/true);
+  if (!sum)
+    return nullptr;
+  auto broadcastSizeType =
+      Torch::ListType::get(Torch::IntType::get(op->getContext()));
+  Value broadcastSize = rewriter.create<AtenSizeOp>(loc, broadcastSizeType, z);
+  Value sumBroadcast =
+      rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
+  Value temp =
+      rewriter.create<AtenMulTensorOp>(loc, tensorType, y, sumBroadcast);
+
+  Value sub = createTensorSub(rewriter, loc, tensorType, x, temp);
+  return sub;
+}
+
 namespace {
 class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
 public:
@@ -126,7 +147,7 @@ public:
     // exp(x)
     Value exp = rewriter.create<AtenExpOp>(loc, tensorType, self);
     // sum(exp(x))
-    Value sum = createAtenSum(rewriter, loc, op, exp, dim, /*keepDim=*/true);
+    Value sum = createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true);
     if (!sum)
       return failure();
     // exp(x) / sum(exp(x))
@@ -152,7 +173,6 @@ public:
   LogicalResult matchAndRewrite(Aten_SoftmaxBackwardDataOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    MLIRContext *context = op.getContext();
     Value gradOutput = op.grad_output();
     Value output = op.output();
     Value dim = op.dim();
@@ -163,22 +183,13 @@ public:
 
     Value newGrad =
         rewriter.create<AtenMulTensorOp>(loc, tensorType, gradOutput, output);
-    // temp = output * sum(newGrad, dim)
-    Value sum =
-        createAtenSum(rewriter, loc, op, newGrad, dim, /*keepDim=*/true);
-    if (!sum)
-      return failure();
-    auto broadcastSizeType = Torch::ListType::get(Torch::IntType::get(context));
-    Value broadcastSize =
-        rewriter.create<AtenSizeOp>(loc, broadcastSizeType, output);
-    Value sumBroadcast =
-        rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
-    Value temp =
-        rewriter.create<AtenMulTensorOp>(loc, tensorType, output, sumBroadcast);
-
-    // newGrad - temp
-    Value sub = createAtenSubTensorOp(rewriter, loc, tensorType, newGrad, temp);
-    rewriter.replaceOp(op, sub);
+    Value result = createSoftmaxBackwardCommonKernel(
+        rewriter, loc, op, tensorType, newGrad, output, newGrad, dim);
+    if (!result)
+      return rewriter.notifyMatchFailure(
+          op,
+          "nullptr returned by createSoftmaxBackwardCommonKernel function.");
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -211,10 +222,8 @@ public:
     Value gradMulTanhSquare = rewriter.create<AtenMulTensorOp>(
         loc, tensorType, tanhSquare, gradOutput);
 
-    Value alpha =
-        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
-    Value newGrad = rewriter.create<AtenSubTensorOp>(
-        loc, tensorType, gradOutput, gradMulTanhSquare, alpha);
+    Value newGrad = createTensorSub(rewriter, loc, tensorType, gradOutput,
+                                    gradMulTanhSquare);
     rewriter.replaceOp(op, newGrad);
     return success();
   }
@@ -231,7 +240,6 @@ public:
   LogicalResult matchAndRewrite(Aten_LogSoftmaxBackwardDataOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    MLIRContext *context = op.getContext();
     Value gradOutput = op.grad_output();
     Value output = op.output();
     Value dim = op.dim();
@@ -241,22 +249,13 @@ public:
       return rewriter.notifyMatchFailure(op, "Only support floating type");
 
     Value expOut = rewriter.create<AtenExpOp>(loc, tensorType, output);
-    Value sum =
-        createAtenSum(rewriter, loc, op, gradOutput, dim, /*keepDim=*/true);
-    if (!sum)
-      return failure();
-
-    auto broadcastSizeType = Torch::ListType::get(Torch::IntType::get(context));
-    Value broadcastSize =
-        rewriter.create<AtenSizeOp>(loc, broadcastSizeType, output);
-    Value sumBroadcast =
-        rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
-    Value temp =
-        rewriter.create<AtenMulTensorOp>(loc, tensorType, expOut, sumBroadcast);
-
-    Value sub =
-        createAtenSubTensorOp(rewriter, loc, tensorType, gradOutput, temp);
-    rewriter.replaceOp(op, sub);
+    Value result = createSoftmaxBackwardCommonKernel(
+        rewriter, loc, op, tensorType, gradOutput, expOut, gradOutput, dim);
+    if (!result)
+      return rewriter.notifyMatchFailure(
+          op,
+          "nullptr returned by createSoftmaxBackwardCommonKernel function.");
+    rewriter.replaceOp(op, result);
    return success();
   }
 };