DecomposeComplexOps: Use static shape if available (#2289)

pull/2302/head snapshot-20230712.897
Matthias Gehre 2023-07-12 10:07:30 +02:00 committed by GitHub
parent bbd3094c2f
commit c23a61f4b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -2825,7 +2825,7 @@ class DecomposeAtenNativeBatchNormOp
loc, ListType::get(IntType::get(context)), runningStatsShape); loc, ListType::get(IntType::get(context)), runningStatsShape);
SmallVector<int64_t> runningStatsShapeInt(inputRank, 1); SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
runningStatsShapeInt[1] = kUnknownSize; runningStatsShapeInt[1] = runningMean.getType().cast<BaseTensorType>().getSizes()[0];
Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype(); Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
Type reshapeType = ValueTensorType::get( Type reshapeType = ValueTensorType::get(
context, llvm::ArrayRef(runningStatsShapeInt), dtype); context, llvm::ArrayRef(runningStatsShapeInt), dtype);