mirror of https://github.com/llvm/torch-mlir
fix NativeBatchNorm decompose
parent e7edcc62fd
commit 05149869d4
@@ -2167,6 +2167,7 @@ class DecomposeAtenNativeBatchNormOp
     // to make it broadcast-compatible with (N, C, D?, H?, W?).
     // 1. runningMean = runningMean.view(1, C, 1?, 1?, 1?)
     // 2. runningVar = runningVar.view(1, C, 1?, 1?, 1?)
+
     SmallVector<Value> runningStatsShape(inputRank, one);
     runningStatsShape[1] = numFeatures;
     Value runningStatsSizeList = rewriter.create<PrimListConstructOp>(
@@ -2178,11 +2179,29 @@ class DecomposeAtenNativeBatchNormOp
     Type reshapeType = ValueTensorType::get(
         context, llvm::makeArrayRef(runningStatsShapeInt), dtype);
 
-    runningMean = rewriter.create<AtenViewOp>(loc, reshapeType, runningMean,
-                                              runningStatsSizeList);
-    runningVar = rewriter.create<AtenViewOp>(loc, reshapeType, runningVar,
-                                             runningStatsSizeList);
+    auto convertRuningStat = [&](Value runningStat) -> Value {
+      Type runningStatDtype =
+          runningStat.getType().cast<ValueTensorType>().getDtype();
+      runningStat = rewriter.create<AtenViewOp>(
+          loc,
+          ValueTensorType::get(context,
+                               llvm::makeArrayRef(runningStatsShapeInt),
+                               runningStatDtype),
+          runningStat, runningStatsSizeList);
 
+      if (dtype != runningStatDtype) {
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+        Value none = rewriter.create<ConstantNoneOp>(loc);
+        return rewriter.create<AtenToDtypeOp>(
+            loc, reshapeType, runningStat,
+            getDtypeIntValueForType(rewriter, loc, dtype),
+            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+            /*memory_format=*/none);
+      }
+      return runningStat;
+    };
+    runningMean = convertRuningStat(runningMean);
+    runningVar = convertRuningStat(runningVar);
 
     // normalizedInput = (input - runningMean) / (sqrt(runningVar + eps)).
     Value inputSubMean = rewriter.create<AtenSubTensorOp>(
         loc, input.getType(), input, runningMean, /*alpha=*/one);
@@ -2202,8 +2221,7 @@ class DecomposeAtenNativeBatchNormOp
       // Rank of `weight` must be exactly 1.
       if (getTensorRank(weight) != 1)
         return rewriter.notifyMatchFailure(op, "expected weight to be rank 1");
-      weight = rewriter.create<AtenViewOp>(loc, reshapeType, weight,
-                                           runningStatsSizeList);
+      weight = convertRuningStat(weight);
       batchNormOutput = rewriter.create<AtenMulTensorOp>(
           loc, batchNormOutput.getType(), batchNormOutput, weight);
     }
@@ -2211,8 +2229,7 @@ class DecomposeAtenNativeBatchNormOp
       // Rank of `bias` must be exactly 1.
       if (getTensorRank(bias) != 1)
         return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
-      bias = rewriter.create<AtenViewOp>(loc, reshapeType, bias,
-                                         runningStatsSizeList);
+      bias = convertRuningStat(bias);
       batchNormOutput = rewriter.create<AtenAddTensorOp>(
           loc, batchNormOutput.getType(), batchNormOutput, bias, /*alpha=*/one);
     }
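
The hunks above only show the rewrite pattern. As a rough standalone illustration of the numerics this decomposition emits (plain C++, not torch-mlir code; the helper name batchNormInference is made up for this sketch), the example below runs inference-mode batch norm over an NCHW buffer whose running stats are stored in a narrower dtype than the input. The stats are converted to the input dtype first, which is what the new AtenToDtypeOp path in convertRuningStat does, applied per channel, which is what the view to (1, C, 1?, 1?, 1?) achieves, and then combined as (input - runningMean) / sqrt(runningVar + eps) * weight + bias.

// Standalone sketch, not torch-mlir code: inference-mode batch norm over an
// NCHW tensor whose running stats use a narrower dtype than the input.
#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical helper; the real pass builds AtenViewOp/AtenToDtypeOp/
// AtenSubTensorOp/... instead of doing the arithmetic directly.
std::vector<double> batchNormInference(
    const std::vector<double> &input,      // N*C*H*W, input dtype (double)
    const std::vector<float> &runningMean, // C, stat dtype (float)
    const std::vector<float> &runningVar,  // C
    const std::vector<float> &weight,      // C (gamma)
    const std::vector<float> &bias,        // C (beta)
    int N, int C, int H, int W, double eps) {
  std::vector<double> out(input.size());
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c) {
      // Promote the per-channel stats to the input dtype before they enter
      // the arithmetic, as the patched decomposition does via AtenToDtypeOp.
      double mean = static_cast<double>(runningMean[c]);
      double var = static_cast<double>(runningVar[c]);
      double gamma = static_cast<double>(weight[c]);
      double beta = static_cast<double>(bias[c]);
      double invStd = 1.0 / std::sqrt(var + eps);
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w) {
          size_t idx = ((size_t(n) * C + c) * H + h) * W + w;
          // normalizedInput = (input - runningMean) / sqrt(runningVar + eps),
          // then scale by weight and shift by bias.
          out[idx] = (input[idx] - mean) * invStd * gamma + beta;
        }
    }
  return out;
}

int main() {
  int N = 1, C = 2, H = 1, W = 2;
  std::vector<double> x = {1.0, 2.0, 3.0, 4.0};
  std::vector<float> mean = {1.5f, 3.5f}, var = {0.25f, 0.25f};
  std::vector<float> gamma = {1.0f, 2.0f}, beta = {0.0f, 1.0f};
  for (double v :
       batchNormInference(x, mean, var, gamma, beta, N, C, H, W, /*eps=*/1e-5))
    std::printf("%f\n", v);
  return 0;
}

In the pass itself the same conversion is applied to weight and bias through convertRuningStat, so, per the diff above, every operand of the subtract, multiply, and add ends up in the input's dtype.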