[torch] Fix DecomposeAtenInstanceNorm decomposition (#2960)

The decomposition only supports an NCHW lowering; however, the operation can
accept inputs with an arbitrary number of spatial dimensions. Updated the
lowering to handle arbitrary spatial dimensions.
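
For illustration, the shape logic being generalized can be sketched in plain C++ (standalone, no MLIR/torch-mlir APIs; the helper names reduceDims and spatialElemCount are illustrative only): instead of hard-coding a reduction over the last two (H, W) dimensions, the reduction dims become [2, rank) and the divisor becomes the product of all spatial extents.

// Standalone sketch; reduceDims and spatialElemCount are illustrative names,
// not part of the patch or of torch-mlir.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Instance norm reduces over every dimension after batch (0) and channel (1),
// i.e. dims [2, rank), rather than a hard-coded {rank - 2, rank - 1}.
std::vector<int64_t> reduceDims(int64_t rank) {
  std::vector<int64_t> dims;
  for (int64_t i = 2; i < rank; ++i)
    dims.push_back(i);
  return dims;
}

// Elements averaged per (batch, channel) slice: the product of all spatial
// extents, mirroring the elemCount loop in the updated lowering.
int64_t spatialElemCount(const std::vector<int64_t> &shape) {
  int64_t count = 1;
  for (size_t i = 2; i < shape.size(); ++i)
    count *= shape[i];
  return count;
}

int main() {
  std::vector<int64_t> ncdhw = {2, 3, 4, 5, 6}; // rank 5: three spatial dims
  assert(reduceDims(ncdhw.size()) == (std::vector<int64_t>{2, 3, 4}));
  std::cout << spatialElemCount(ncdhw) << "\n"; // 4 * 5 * 6 = 120
}

The same rank-driven approach drives the weight and bias handling in the diff below, which unsqueezes size-1 dims in a loop until the shapes match the input rank instead of unsqueezing exactly twice.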
Rob Suderman 2024-02-28 10:27:19 -08:00 committed by GitHub
parent dd673cfa8d
commit 73b6df9007
1 changed file with 28 additions and 36 deletions

@@ -4025,25 +4025,20 @@ class DecomposeAtenInstanceNormOp
     auto inputTy = op.getInput().getType().cast<BaseTensorType>();
     int64_t inputRank = inputTy.getSizes().size();
-    auto reduceDimInts =
-        llvm::SmallVector<int64_t>({inputRank - 2, inputRank - 1});
     SmallVector<int64_t> reducedShape(inputTy.getSizes());
-    reducedShape[inputRank - 1] = 1;
-    reducedShape[inputRank - 2] = 1;
+    SmallVector<int64_t> reduceDimInts;
+    SmallVector<Value> reduceDimVals;
+    for (int i = 2; i < inputRank; ++i) {
+      reducedShape[i] = 1;
+      reduceDimVals.push_back(rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(i)));
+    }
     Type dtype = inputTy.getOptionalDtype();
     Type reducedTy = ValueTensorType::get(op.getContext(),
                                           llvm::ArrayRef(reducedShape), dtype);
     auto sizeListType = ListType::get(IntType::get(context));
-    SmallVector<Value> reduceDimVals;
-    reduceDimVals.reserve(reduceDimInts.size());
-    std::transform(reduceDimInts.begin(), reduceDimInts.end(),
-                   std::back_inserter(reduceDimVals), [&](int64_t d) {
-                     return rewriter.create<Torch::ConstantIntOp>(
-                         loc, rewriter.getI64IntegerAttr(d));
-                   });
     Value reduceDimList =
         rewriter.create<PrimListConstructOp>(loc, sizeListType, reduceDimVals);
     Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
@@ -4069,9 +4064,12 @@ class DecomposeAtenInstanceNormOp
         loc, reducedTy, inputSubMeanSquare, reduceDimList, cstTrue,
         /*dtype=*/none);
+    int64_t elemCount = 1;
+    for (int i = 2; i < inputRank; ++i)
+      elemCount *= inputTy.getSizes()[i];
     Value hw = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(inputTy.getSizes()[inputRank - 1] *
-                                        inputTy.getSizes()[inputRank - 2]));
+        loc, rewriter.getI64IntegerAttr(elemCount));
     Value inputVar =
         rewriter.create<AtenDivScalarOp>(loc, reducedTy, variancesum, hw);
@@ -4104,19 +4102,14 @@ class DecomposeAtenInstanceNormOp
         op.getContext(), llvm::ArrayRef(newWeightShape), dtype);
     weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, zero);
-    Value two = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(2));
+    while (static_cast<int64_t>(newWeightShape.size()) < inputRank) {
+      Value i = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(newWeightShape.size()));
     newWeightShape.push_back(1);
     newWeightTy = ValueTensorType::get(op.getContext(),
                                        llvm::ArrayRef(newWeightShape), dtype);
-    weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, two);
-    Value three = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(3));
-    newWeightShape.push_back(1);
-    newWeightTy = ValueTensorType::get(op.getContext(),
-                                       llvm::ArrayRef(newWeightShape), dtype);
-    weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, three);
+      weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, i);
+    }
     Value weightExpanded =
         rewriter.create<AtenExpandAsOp>(loc, inputTy, weight, op.getInput());
@@ -4134,15 +4127,14 @@ class DecomposeAtenInstanceNormOp
         llvm::ArrayRef(newBiasShape), dtype);
     bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, zero);
+    while (static_cast<int64_t>(newBiasShape.size()) < inputRank) {
+      Value i = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(newBiasShape.size()));
     newBiasShape.push_back(1);
     newBiasTy = ValueTensorType::get(op.getContext(),
                                      llvm::ArrayRef(newBiasShape), dtype);
-    bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, two);
-    newBiasShape.push_back(1);
-    newBiasTy = ValueTensorType::get(op.getContext(),
-                                     llvm::ArrayRef(newBiasShape), dtype);
-    bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, three);
+      bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, i);
+    }
     Value biasExpanded =
         rewriter.create<AtenExpandAsOp>(loc, inputTy, bias, op.getInput());