fix NativeBatchNorm decompose

tanyo/fix_upstream
TanyoKwok 2022-12-01 21:46:16 +08:00
parent e7edcc62fd
commit 05149869d4
1 changed files with 25 additions and 8 deletions

View File

@ -2167,6 +2167,7 @@ class DecomposeAtenNativeBatchNormOp
// to make it broadcast-compatible with (N, C, D?, H?, W?).
// 1. runningMean = runningMean.view(1, C, 1?, 1?, 1?)
// 2. runningVar = runningVar.view(1, C, 1?, 1?, 1?)
SmallVector<Value> runningStatsShape(inputRank, one);
runningStatsShape[1] = numFeatures;
Value runningStatsSizeList = rewriter.create<PrimListConstructOp>(
@ -2178,11 +2179,29 @@ class DecomposeAtenNativeBatchNormOp
Type reshapeType = ValueTensorType::get(
context, llvm::makeArrayRef(runningStatsShapeInt), dtype);
runningMean = rewriter.create<AtenViewOp>(loc, reshapeType, runningMean,
runningStatsSizeList);
runningVar = rewriter.create<AtenViewOp>(loc, reshapeType, runningVar,
runningStatsSizeList);
// Helper shared by the running statistics and the affine parameters:
// views `runningStat` with the sizes in `runningStatsSizeList` (keeping the
// operand's own dtype), then converts the viewed value to `dtype` when the
// two element types differ. Returns the viewed (and possibly converted) value.
auto convertRuningStat = [&](Value runningStat) -> Value {
  // The view must not change the element type, so its result type is built
  // from the operand's dtype rather than from `reshapeType`.
  Type statDtype = runningStat.getType().cast<ValueTensorType>().getDtype();
  Type viewType = ValueTensorType::get(
      context, llvm::makeArrayRef(runningStatsShapeInt), statDtype);
  Value reshaped = rewriter.create<AtenViewOp>(loc, viewType, runningStat,
                                               runningStatsSizeList);

  // Nothing more to do when the operand already has the target dtype.
  const bool needsCast = dtype != statDtype;
  if (!needsCast)
    return reshaped;

  // Operands for the dtype conversion; `reshapeType` carries the same shape
  // with `dtype` as its element type.
  Value falseVal = rewriter.create<Torch::ConstantBoolOp>(loc, false);
  Value noneVal = rewriter.create<ConstantNoneOp>(loc);
  Value dtypeInt = getDtypeIntValueForType(rewriter, loc, dtype);
  return rewriter.create<AtenToDtypeOp>(loc, reshapeType, reshaped, dtypeInt,
                                        /*non_blocking=*/falseVal,
                                        /*copy=*/falseVal,
                                        /*memory_format=*/noneVal);
};
runningMean = convertRuningStat(runningMean);
runningVar = convertRuningStat(runningVar);
// normalizedInput = (input - runningMean) / (sqrt(runningVar + eps)).
Value inputSubMean = rewriter.create<AtenSubTensorOp>(
loc, input.getType(), input, runningMean, /*alpha=*/one);
@ -2202,8 +2221,7 @@ class DecomposeAtenNativeBatchNormOp
// Rank of `weight` must be exactly 1.
if (getTensorRank(weight) != 1)
return rewriter.notifyMatchFailure(op, "expected weight to be rank 1");
weight = rewriter.create<AtenViewOp>(loc, reshapeType, weight,
runningStatsSizeList);
weight = convertRuningStat(weight);
batchNormOutput = rewriter.create<AtenMulTensorOp>(
loc, batchNormOutput.getType(), batchNormOutput, weight);
}
@ -2211,8 +2229,7 @@ class DecomposeAtenNativeBatchNormOp
// Rank of `bias` must be exactly 1.
if (getTensorRank(bias) != 1)
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
bias = rewriter.create<AtenViewOp>(loc, reshapeType, bias,
runningStatsSizeList);
bias = convertRuningStat(bias);
batchNormOutput = rewriter.create<AtenAddTensorOp>(
loc, batchNormOutput.getType(), batchNormOutput, bias, /*alpha=*/one);
}