[TORCH][MLIR] Fix the return types of `aten.native_layer_norm`.

This commit fixes the second and third return values of `aten.native_layer_norm`.
Previously, the mean and rSTD results were returned with their reduction
dimensions removed. They are now returned with those reduction dimensions
kept as size-1 dimensions.
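For reference, a quick eager-mode check of the shapes this change targets (plain
PyTorch, not part of this patch; the eps value simply mirrors the e2e tests below):

import torch

x = torch.rand(2, 5, 2, 2, 3)
weight = torch.rand(2, 2, 3)
bias = torch.rand(2, 2, 3)
out, mean, rstd = torch.ops.aten.native_layer_norm(x, [2, 2, 3], weight, bias, 0.5)
print(out.shape)   # torch.Size([2, 5, 2, 2, 3])
print(mean.shape)  # torch.Size([2, 5, 1, 1, 1])  <- reduction dims kept as 1s
print(rstd.shape)  # torch.Size([2, 5, 1, 1, 1])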

Signed-off-by: Prateek Gupta <prateek@nord-labs.com>
Refs:   pull/679/head, snapshot-20220317.329
Author: Prateek Gupta, 2022-03-16 12:51:57 +00:00
Parent: 3b66b4925a
Commit: 7256c9e395
5 changed files with 131 additions and 49 deletions


@@ -204,15 +204,35 @@ class NativeLayerNormModule(torch.nn.Module):
     ])
     def forward(self, x, weight, bias):
         list = [2, 2, 3]
-        # TODO: Fix the case of the other return values.
         return torch.ops.aten.native_layer_norm(
-            x, list, weight, bias, eps=0.5)[0]
+            x, list, weight, bias, eps=0.5)
 
 
 @register_test_case(module_factory=lambda: NativeLayerNormModule())
 def NativeLayerNormModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
 
+
+class NativeLayerNormDynamicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x, weight, bias):
+        list = [2, 2, 3]
+        return torch.ops.aten.native_layer_norm(
+            x, list, weight, bias, eps=0.5)
+
+
+@register_test_case(module_factory=lambda: NativeLayerNormDynamicModule())
+def NativeLayerNormDynamicModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
+
 # ==============================================================================
 class NativeLayerNormModule4D(torch.nn.Module):


@@ -1009,22 +1009,37 @@ public:
 };
 } // namespace
 
-// Normalization formula:
-//   ((input - mean) / sqrt(var + eps)) * weight + bias
-static Value createLinalgPayloadCalculationForNormOps(
-    OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var,
-    Value eps, Value weight, Value bias) {
-  Value inputSubMean = b.create<arith::SubFOp>(loc, input, mean);
+// Inverted STD: rSTD = 1 / sqrt(var + eps).
+static Value calculateRSTD(OpBuilder &b, Location loc, Type elemTy, Value eps,
+                           Value var) {
   // The eps is always f64.
   Value truncatedEps = b.create<arith::TruncFOp>(loc, elemTy, eps);
   Value varPlusEps = b.create<arith::AddFOp>(loc, var, truncatedEps);
   Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
+  return rSTD;
+}
+
+// Normalization formula:
+//   ((input - mean) * rSTD) * weight + bias
+static Value createLinalgPayloadCalculationForNormOpsWithRSTD(
+    OpBuilder &b, Location loc, Type elemTy, Value input, Value mean,
+    Value rSTD, Value eps, Value weight, Value bias) {
+  Value inputSubMean = b.create<arith::SubFOp>(loc, input, mean);
   Value temp = b.create<arith::MulFOp>(loc, inputSubMean, rSTD);
   Value timesWeight = b.create<arith::MulFOp>(loc, temp, weight);
   Value plusBias = b.create<arith::AddFOp>(loc, timesWeight, bias);
   return plusBias;
 }
 
+static Value createLinalgPayloadCalculationForNormOpsWithVar(
+    OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var,
+    Value eps, Value weight, Value bias) {
+  Value rSTD = calculateRSTD(b, loc, elemTy, eps, var);
+  Value result = createLinalgPayloadCalculationForNormOpsWithRSTD(
+      b, loc, elemTy, input, mean, rSTD, eps, weight, bias);
+  return result;
+}
+
 namespace {
 class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
 public:
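Note (editorial, not part of the diff): the refactor above only re-associates the old
payload; the old var-based formula and the new rSTD-based one compute the same value.
A minimal Python check of the algebra, with made-up scalar values:

import math

x, mean, var, weight, bias, eps = 1.5, 0.5, 4.0, 2.0, 0.25, 0.5
rstd = 1.0 / math.sqrt(var + eps)                        # calculateRSTD
old = (x - mean) / math.sqrt(var + eps) * weight + bias  # old ...ForNormOps payload
new = (x - mean) * rstd * weight + bias                  # new ...WithRSTD payload
assert abs(old - new) < 1e-12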
@@ -1117,7 +1132,8 @@ public:
                 [&](OpBuilder &b, Location loc, ValueRange args) {
                   Value input = args[0], weight = args[1], bias = args[2],
                         mean = args[3], var = args[4];
-                  Value result = createLinalgPayloadCalculationForNormOps(
+                  Value result =
+                      createLinalgPayloadCalculationForNormOpsWithVar(
                           b, loc, var.getType(), input, mean, var, eps, weight,
                           bias);
                   b.create<linalg::YieldOp>(loc, result);
@@ -1139,13 +1155,12 @@ public:
 // |  meanAndVarShape  |   normalizedShape   |
 // +-------------------+---------------------
 // <------------+ inputShape +-------------->
-
 // There are the following steps:
 // Step 1. Check if all the arguments meet the requirements.
 // Step 2. Common parts to be used for getting mean and var.
 //         This includes elements count, affineMap and iteratorTypes.
 // Step 3. Get mean.
-// Step 4. Get var.
+// Step 4. Get rSTD.
 // Step 5. Get layernorm.
 namespace {
 class ConvertAtenNativeLayerNormOp
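To make Steps 3-5 concrete, here is a hedged NumPy analogue of what the pattern emits
for the running example (input [2, 5, 2, 2, 3], normalized_shape [2, 2, 3]). The real
lowering works on the collapsed mean/var shape through affine maps and only expands at
the end; `keepdims` and the biased-variance formula are used here purely for brevity:

import numpy as np

x = np.random.rand(2, 5, 2, 2, 3).astype(np.float32)
w = np.random.rand(2, 2, 3).astype(np.float32)
b = np.random.rand(2, 2, 3).astype(np.float32)
eps = 0.5
red = (2, 3, 4)                                   # dims covered by normalized_shape
mean = x.mean(axis=red, keepdims=True)            # Step 3
var = ((x - mean) ** 2).mean(axis=red, keepdims=True)
rstd = 1.0 / np.sqrt(var + eps)                   # Step 4
out = (x - mean) * rstd * w + b                   # Step 5
print(out.shape, mean.shape, rstd.shape)          # (2, 5, 2, 2, 3) (2, 5, 1, 1, 1) (2, 5, 1, 1, 1)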
@@ -1283,7 +1298,7 @@ public:
                     .getResult(0);
     Value mean = genMeanOrVarCalculation(sum);
 
-    // Step 4. Get var.
+    // Step 4. Get rSTD.
     // Calculate squareSum for the layer.
     SmallVector<AffineMap> squareSumIndexingMaps{
@@ -1310,6 +1325,21 @@ public:
                     })
                 .getResult(0);
     Value var = genMeanOrVarCalculation(squareSum);
+    Value rSTDTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, meanAndVarShapeSizes, elemTy);
+    SmallVector<AffineMap> rSTDIndexingMap(
+        2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
+
+    Value rSTD = rewriter
+                     .create<linalg::GenericOp>(
+                         loc, rSTDTensor.getType(), var, rSTDTensor,
+                         rSTDIndexingMap, meanAndVarIterationTypes,
+                         [&](OpBuilder &b, Location loc, ValueRange args) {
+                           Value result =
+                               calculateRSTD(b, loc, elemTy, eps, args[0]);
+                           b.create<linalg::YieldOp>(loc, result);
+                         })
+                     .getResult(0);
 
     // Step 5. Get layernorm.
@@ -1320,7 +1350,6 @@ public:
     auto normalizedShapeAffineMap = AffineMap::get(
         /*dimCount=*/inputRank,
         /*symbolCount=*/0, normalizedShapeExprs, context);
-
     auto inputSizes = getTensorSizes(rewriter, loc, input);
     Value initLayerNormTensor =
         rewriter.create<linalg::InitTensorOp>(loc, inputSizes, elemTy);
@@ -1334,24 +1363,48 @@ public:
         rewriter
             .create<linalg::GenericOp>(
                 loc, initLayerNormTensor.getType(),
-                ValueRange{input, mean, var, weight, bias}, initLayerNormTensor,
+                ValueRange{input, mean, rSTD, weight, bias},
+                initLayerNormTensor,
                 /*indexingMaps=*/indexingMaps,
                 /*iteratorTypes=*/layerNormIterationTypes,
                 [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value input = args[0], mean = args[1], var = args[2],
+                  Value input = args[0], mean = args[1], rSTD = args[2],
                         weight = args[3], bias = args[4];
-                  Value result = createLinalgPayloadCalculationForNormOps(
-                      b, loc, elemTy, input, mean, var, eps, weight, bias);
+                  Value result =
+                      createLinalgPayloadCalculationForNormOpsWithRSTD(
+                          b, loc, elemTy, input, mean, rSTD, eps, weight, bias);
                   b.create<linalg::YieldOp>(loc, result);
                 })
             .getResult(0);
+    SmallVector<int64_t> expandShape(inputRank, 1);
+    for (int i = 0; i < meanAndVarShapeRank; i++) {
+      // `mean` and `rstd` have not been cast yet, so they still have dynamic
+      // shapes. To match them, assign -1 (dynamic) to each dimension of the
+      // expanded shape that corresponds to `mean` or `rstd`.
+      expandShape[i] = -1;
+    }
+    auto expandShapeType = RankedTensorType::get(expandShape, elemTy);
+    SmallVector<ReassociationIndices> reassociation(meanAndVarShapeRank);
+    for (auto i : llvm::seq<int64_t>(0, meanAndVarShapeRank)) {
+      reassociation[i].push_back(i);
+      if (i == meanAndVarShapeRank - 1) {
+        for (auto j : llvm::seq<int64_t>(0, normalizedShapeRank))
+          reassociation[i].push_back(i + j + 1);
+      }
+    }
+    Value meanResult = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandShapeType, mean, reassociation);
+    Value rSTDResult = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandShapeType, rSTD, reassociation);
     Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
     Type meanResultType = getTypeConverter()->convertType(op.getType(1));
-    Type varResultType = getTypeConverter()->convertType(op.getType(2));
+    Type rSTDResultType = getTypeConverter()->convertType(op.getType(2));
     Value layerNorm_ =
         rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
-    Value mean_ = rewriter.create<tensor::CastOp>(loc, meanResultType, mean);
-    Value var_ = rewriter.create<tensor::CastOp>(loc, varResultType, var);
+    Value mean_ =
+        rewriter.create<tensor::CastOp>(loc, meanResultType, meanResult);
+    Value var_ =
+        rewriter.create<tensor::CastOp>(loc, rSTDResultType, rSTDResult);
     rewriter.replaceOp(op, {layerNorm_, mean_, var_});
     return success();
   }
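Editorial note on the `tensor.expand_shape` ops added above: with the running example,
`mean` and `rSTD` come out of the reductions with shape (2, 5) and must be reshaped to
(2, 5, 1, 1, 1) before the final casts; the reassociation built in the loop works out
to [[0], [1, 2, 3, 4]]. A hedged NumPy analogue of that reshape:

import numpy as np

mean = np.zeros((2, 5), dtype=np.float32)   # collapsed reduction result
expanded = mean.reshape(2, 5, 1, 1, 1)      # what expand_shape with [[0], [1, 2, 3, 4]] yields here
print(expanded.shape)                       # (2, 5, 1, 1, 1)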


@@ -1118,9 +1118,10 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
     Value normalizedShape = op.normalized_shape();
     SmallVector<Value> normalizedShapeSizesTorchInt;
     getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
-    std::vector<int64_t> meanVarSizes;
-    for (int i = normalizedShapeSizesTorchInt.size(); i < inputRank; i++)
-      meanVarSizes.push_back(input.getSizes()[i]);
+    int64_t axis = inputRank - normalizedShapeSizesTorchInt.size();
+    std::vector<int64_t> meanVarSizes(inputRank, 1);
+    for (int i = 0; i < axis; i++)
+      meanVarSizes[i] = input.getSizes()[i];
     auto meanVarType = input.getWithSizesAndDtype(
         llvm::makeArrayRef(meanVarSizes), input.getDtype());
     auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
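A small Python transcription of the new `meanVarSizes` logic (editorial sketch that
mirrors the C++ above; the variable names and concrete sizes are illustrative only):

input_sizes = [2, 5, 2, 2, 3]
normalized_shape = [2, 2, 3]
axis = len(input_sizes) - len(normalized_shape)   # 2
mean_var_sizes = [1] * len(input_sizes)
for i in range(axis):
    mean_var_sizes[i] = input_sizes[i]
print(mean_var_sizes)                             # [2, 5, 1, 1, 1]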


@@ -2513,20 +2513,36 @@ module {
   }
   func @"__torch_mlir_shape_fn.aten.native_layer_norm"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {
     %int1 = torch.constant.int 1
+    %int0 = torch.constant.int 0
+    %str = torch.constant.str "AssertionError: "
+    %none = torch.constant.none
     %true = torch.constant.bool true
     %0 = torch.prim.ListConstruct : () -> !torch.list<int>
-    %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
-    %2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-    %3 = torch.aten.__range_length %1, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
+    %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
+    %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
+    %3 = torch.aten.sub.int %1, %2 : !torch.int, !torch.int -> !torch.int
+    %4 = torch.aten.ge.int %3, %int0 : !torch.int, !torch.int -> !torch.bool
+    torch.prim.If %4 -> () {
+      torch.prim.If.yield
+    } else {
+      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
+      torch.prim.If.yield
+    }
     torch.prim.Loop %3, %true, init() {
     ^bb0(%arg5: !torch.int):
-      %5 = torch.aten.__derive_index %arg5, %1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-      %6 = torch.aten.__getitem__.t %arg0, %5 : !torch.list<int>, !torch.int -> !torch.int
-      %7 = torch.aten.append.t %0, %6 : !torch.list<int>, !torch.int -> !torch.list<int>
+      %8 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list<int>, !torch.int -> !torch.int
+      %9 = torch.aten.append.t %0, %8 : !torch.list<int>, !torch.int -> !torch.list<int>
       torch.prim.Loop.condition %true, iter()
     } : (!torch.int, !torch.bool) -> ()
-    %4 = torch.prim.TupleConstruct %arg0, %0, %0 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>
-    return %4 : !torch.tuple<list<int>, list<int>, list<int>>
+    %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
+    %6 = torch.aten.__range_length %3, %5, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
+    torch.prim.Loop %6, %true, init() {
+    ^bb0(%arg5: !torch.int):
+      %8 = torch.aten.append.t %0, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>
+      torch.prim.Loop.condition %true, iter()
+    } : (!torch.int, !torch.bool) -> ()
+    %7 = torch.prim.TupleConstruct %arg0, %0, %0 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>
+    return %7 : !torch.tuple<list<int>, list<int>, list<int>>
   }
   func @"__torch_mlir_shape_fn.aten.native_batch_norm"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {
     %int0 = torch.constant.int 0


@@ -773,25 +773,17 @@ def aten〇nll_loss_forward(self: List[int], target: List[int], weight: Optional
 def aten〇nll_loss_backward(grad_output: List[int], self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int, total_weight: List[int]) -> List[int]:
     return upstream_shape_helpers.unary(self)
 
-# TODO: Fix shape function (see body).
-# @check_shape_function([
-#     Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
-# ])
+@check_shape_function([
+    Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
+])
 def aten〇native_layer_norm(input: List[int], normalized_shape: List[int], weight: Optional[List[int]], bias: Optional[List[int]], eps: float) -> Tuple[List[int], List[int], List[int]]:
     reduction_shape: List[int] = []
-    # TODO: Fix buggy behavior. TorchToLinalg needs to properly handle the
-    # correctly inferred shapes.
-    # With input=[2, 5, 2, 2, 3] and normalized_shape=[2, 2, 3], we should get
-    # [[2, 5, 2, 2, 3], [2, 5, 1, 1, 1], [2, 5, 1, 1, 1]]
-    for i in range(len(normalized_shape), len(input)):
+    num_unreduced_dimensions = len(input) - len(normalized_shape)
+    assert num_unreduced_dimensions >= 0
+    for i in range(num_unreduced_dimensions):
         reduction_shape.append(input[i])
-    # Correct code:
-    # num_unreduced_dimensions = len(input) - len(normalized_shape)
-    # assert num_unreduced_dimensions >= 0
-    # for i in range(num_unreduced_dimensions):
-    #     reduction_shape.append(input[i])
-    # for i in range(num_unreduced_dimensions, len(input)):
-    #     reduction_shape.append(1)
+    for i in range(num_unreduced_dimensions, len(input)):
+        reduction_shape.append(1)
     return input, reduction_shape, reduction_shape
 
 @check_shape_function([
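Editorial check of the updated shape function: the body below is a plain-Python copy of
the new `aten〇native_layer_norm` above, under a hypothetical ASCII name so it runs
standalone, and it reproduces the shapes promised in the commit message:

from typing import List, Tuple

def native_layer_norm_shape(input: List[int], normalized_shape: List[int]) -> Tuple[List[int], List[int], List[int]]:
    reduction_shape: List[int] = []
    num_unreduced_dimensions = len(input) - len(normalized_shape)
    assert num_unreduced_dimensions >= 0
    for i in range(num_unreduced_dimensions):
        reduction_shape.append(input[i])
    for i in range(num_unreduced_dimensions, len(input)):
        reduction_shape.append(1)
    return input, reduction_shape, reduction_shape

print(native_layer_norm_shape([2, 5, 2, 2, 3], [2, 2, 3]))
# ([2, 5, 2, 2, 3], [2, 5, 1, 1, 1], [2, 5, 1, 1, 1])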