mirror of https://github.com/llvm/torch-mlir
[torch] Fix DecomposeAtenInstanceNorm decomposition (#2960)
The decomposition only suports a NCHW lowering however the operation can support arbitrary spatial dimensions. Updated the lowering to better support spatial dimensions.pull/2958/head
parent
dd673cfa8d
commit
73b6df9007
|
@ -4025,25 +4025,20 @@ class DecomposeAtenInstanceNormOp
|
|||
|
||||
auto inputTy = op.getInput().getType().cast<BaseTensorType>();
|
||||
int64_t inputRank = inputTy.getSizes().size();
|
||||
auto reduceDimInts =
|
||||
llvm::SmallVector<int64_t>({inputRank - 2, inputRank - 1});
|
||||
|
||||
SmallVector<int64_t> reducedShape(inputTy.getSizes());
|
||||
reducedShape[inputRank - 1] = 1;
|
||||
reducedShape[inputRank - 2] = 1;
|
||||
SmallVector<int64_t> reduceDimInts;
|
||||
SmallVector<Value> reduceDimVals;
|
||||
for (int i = 2; i < inputRank; ++i) {
|
||||
reducedShape[i] = 1;
|
||||
reduceDimVals.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
|
||||
Type dtype = inputTy.getOptionalDtype();
|
||||
Type reducedTy = ValueTensorType::get(op.getContext(),
|
||||
llvm::ArrayRef(reducedShape), dtype);
|
||||
|
||||
auto sizeListType = ListType::get(IntType::get(context));
|
||||
SmallVector<Value> reduceDimVals;
|
||||
reduceDimVals.reserve(reduceDimInts.size());
|
||||
std::transform(reduceDimInts.begin(), reduceDimInts.end(),
|
||||
std::back_inserter(reduceDimVals), [&](int64_t d) {
|
||||
return rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(d));
|
||||
});
|
||||
Value reduceDimList =
|
||||
rewriter.create<PrimListConstructOp>(loc, sizeListType, reduceDimVals);
|
||||
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
|
||||
|
@ -4069,9 +4064,12 @@ class DecomposeAtenInstanceNormOp
|
|||
loc, reducedTy, inputSubMeanSquare, reduceDimList, cstTrue,
|
||||
/*dtype=*/none);
|
||||
|
||||
int64_t elemCount = 1;
|
||||
for (int i = 2; i < inputRank; ++i)
|
||||
elemCount *= inputTy.getSizes()[i];
|
||||
|
||||
Value hw = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(inputTy.getSizes()[inputRank - 1] *
|
||||
inputTy.getSizes()[inputRank - 2]));
|
||||
loc, rewriter.getI64IntegerAttr(elemCount));
|
||||
Value inputVar =
|
||||
rewriter.create<AtenDivScalarOp>(loc, reducedTy, variancesum, hw);
|
||||
|
||||
|
@ -4104,19 +4102,14 @@ class DecomposeAtenInstanceNormOp
|
|||
op.getContext(), llvm::ArrayRef(newWeightShape), dtype);
|
||||
weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, zero);
|
||||
|
||||
Value two = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(2));
|
||||
newWeightShape.push_back(1);
|
||||
newWeightTy = ValueTensorType::get(op.getContext(),
|
||||
llvm::ArrayRef(newWeightShape), dtype);
|
||||
weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, two);
|
||||
|
||||
Value three = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(3));
|
||||
newWeightShape.push_back(1);
|
||||
newWeightTy = ValueTensorType::get(op.getContext(),
|
||||
llvm::ArrayRef(newWeightShape), dtype);
|
||||
weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, three);
|
||||
while (static_cast<int64_t>(newWeightShape.size()) < inputRank) {
|
||||
Value i = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(newWeightShape.size()));
|
||||
newWeightShape.push_back(1);
|
||||
newWeightTy = ValueTensorType::get(op.getContext(),
|
||||
llvm::ArrayRef(newWeightShape), dtype);
|
||||
weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, i);
|
||||
}
|
||||
|
||||
Value weightExpanded =
|
||||
rewriter.create<AtenExpandAsOp>(loc, inputTy, weight, op.getInput());
|
||||
|
@ -4134,15 +4127,14 @@ class DecomposeAtenInstanceNormOp
|
|||
llvm::ArrayRef(newBiasShape), dtype);
|
||||
bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, zero);
|
||||
|
||||
newBiasShape.push_back(1);
|
||||
newBiasTy = ValueTensorType::get(op.getContext(),
|
||||
llvm::ArrayRef(newBiasShape), dtype);
|
||||
bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, two);
|
||||
|
||||
newBiasShape.push_back(1);
|
||||
newBiasTy = ValueTensorType::get(op.getContext(),
|
||||
llvm::ArrayRef(newBiasShape), dtype);
|
||||
bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, three);
|
||||
while (static_cast<int64_t>(newBiasShape.size()) < inputRank) {
|
||||
Value i = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(newBiasShape.size()));
|
||||
newBiasShape.push_back(1);
|
||||
newBiasTy = ValueTensorType::get(op.getContext(),
|
||||
llvm::ArrayRef(newBiasShape), dtype);
|
||||
bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, i);
|
||||
}
|
||||
|
||||
Value biasExpanded =
|
||||
rewriter.create<AtenExpandAsOp>(loc, inputTy, bias, op.getInput());
|
||||
|
|
Loading…
Reference in New Issue