mirror of https://github.com/llvm/torch-mlir
parent
bbd3094c2f
commit
c23a61f4b6
|
@ -2825,7 +2825,7 @@ class DecomposeAtenNativeBatchNormOp
|
||||||
loc, ListType::get(IntType::get(context)), runningStatsShape);
|
loc, ListType::get(IntType::get(context)), runningStatsShape);
|
||||||
|
|
||||||
SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
|
SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
|
||||||
runningStatsShapeInt[1] = kUnknownSize;
|
runningStatsShapeInt[1] = runningMean.getType().cast<BaseTensorType>().getSizes()[0];
|
||||||
Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
|
Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
|
||||||
Type reshapeType = ValueTensorType::get(
|
Type reshapeType = ValueTensorType::get(
|
||||||
context, llvm::ArrayRef(runningStatsShapeInt), dtype);
|
context, llvm::ArrayRef(runningStatsShapeInt), dtype);
|
||||||
|
|
Loading…
Reference in New Issue