mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Fix the return types of `aten.native_layer_norm`.
This commit fixes the 2nd and 3rd return types of the `aten.native_layer_norm`. Previously the mean and rSTD were returned with reduction dims removed. This commit fixes this and keeps the reduction dims of the results. Signed-Off-By: Prateek Gupta <prateek@nord-labs.com>pull/679/head snapshot-20220317.329
parent
3b66b4925a
commit
7256c9e395
|
@ -204,15 +204,35 @@ class NativeLayerNormModule(torch.nn.Module):
|
|||
])
|
||||
def forward(self, x, weight, bias):
|
||||
list = [2, 2, 3]
|
||||
# TODO: Fix the case of the other return values.
|
||||
return torch.ops.aten.native_layer_norm(
|
||||
x, list, weight, bias, eps=0.5)[0]
|
||||
x, list, weight, bias, eps=0.5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NativeLayerNormModule())
|
||||
def NativeLayerNormModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
|
||||
|
||||
class NativeLayerNormDynamicModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, weight, bias):
|
||||
list = [2, 2, 3]
|
||||
return torch.ops.aten.native_layer_norm(
|
||||
x, list, weight, bias, eps=0.5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NativeLayerNormDynamicModule())
|
||||
def NativeLayerNormDynamicModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class NativeLayerNormModule4D(torch.nn.Module):
|
||||
|
|
|
@ -1009,22 +1009,37 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Normalization formula:
|
||||
// ((input - mean) / sqrt(var + eps)) * weight + bias
|
||||
static Value createLinalgPayloadCalculationForNormOps(
|
||||
OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var,
|
||||
Value eps, Value weight, Value bias) {
|
||||
Value inputSubMean = b.create<arith::SubFOp>(loc, input, mean);
|
||||
/// Inverted STD: rSTD = 1 / sqrt(var + eps).
|
||||
static Value calculateRSTD(OpBuilder &b, Location loc, Type elemTy, Value eps,
|
||||
Value var) {
|
||||
// The eps is always f64.
|
||||
Value truncatedEps = b.create<arith::TruncFOp>(loc, elemTy, eps);
|
||||
Value varPlusEps = b.create<arith::AddFOp>(loc, var, truncatedEps);
|
||||
Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
|
||||
return rSTD;
|
||||
}
|
||||
|
||||
// Normalization formula:
|
||||
// ((input - mean) * rSTD * weight + bias
|
||||
static Value createLinalgPayloadCalculationForNormOpsWithRSTD(
|
||||
OpBuilder &b, Location loc, Type elemTy, Value input, Value mean,
|
||||
Value rSTD, Value eps, Value weight, Value bias) {
|
||||
Value inputSubMean = b.create<arith::SubFOp>(loc, input, mean);
|
||||
Value temp = b.create<arith::MulFOp>(loc, inputSubMean, rSTD);
|
||||
Value timesWeight = b.create<arith::MulFOp>(loc, temp, weight);
|
||||
Value plusBias = b.create<arith::AddFOp>(loc, timesWeight, bias);
|
||||
return plusBias;
|
||||
}
|
||||
|
||||
static Value createLinalgPayloadCalculationForNormOpsWithVar(
|
||||
OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var,
|
||||
Value eps, Value weight, Value bias) {
|
||||
Value rSTD = calculateRSTD(b, loc, elemTy, eps, var);
|
||||
Value result = createLinalgPayloadCalculationForNormOpsWithRSTD(
|
||||
b, loc, elemTy, input, mean, rSTD, eps, weight, bias);
|
||||
return result;
|
||||
}
|
||||
|
||||
namespace {
|
||||
class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
|
||||
public:
|
||||
|
@ -1117,7 +1132,8 @@ public:
|
|||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value input = args[0], weight = args[1], bias = args[2],
|
||||
mean = args[3], var = args[4];
|
||||
Value result = createLinalgPayloadCalculationForNormOps(
|
||||
Value result =
|
||||
createLinalgPayloadCalculationForNormOpsWithVar(
|
||||
b, loc, var.getType(), input, mean, var, eps, weight,
|
||||
bias);
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
|
@ -1139,13 +1155,12 @@ public:
|
|||
// | meanAndVarShape | normalizedShape |
|
||||
// +-------------------+---------------------
|
||||
// <------------+ inputShape +-------------->
|
||||
|
||||
// There are the following steps:
|
||||
// Step 1. Check if all the arguments meet the requirements.
|
||||
// Step 2. Common parts to be used for getting mean and var.
|
||||
// This includes elements count, affineMap and iteratorTypes.
|
||||
// Step 3. Get mean.
|
||||
// Step 4. Get var.
|
||||
// Step 4. Get rSTD.
|
||||
// Step 5. Get layernorm.
|
||||
namespace {
|
||||
class ConvertAtenNativeLayerNormOp
|
||||
|
@ -1283,7 +1298,7 @@ public:
|
|||
.getResult(0);
|
||||
Value mean = genMeanOrVarCalculation(sum);
|
||||
|
||||
// Step 4. Get var.
|
||||
// Step 4. Get rSTD.
|
||||
|
||||
// Calculate squareSum for the layer.
|
||||
SmallVector<AffineMap> squareSumIndexingMaps{
|
||||
|
@ -1310,6 +1325,21 @@ public:
|
|||
})
|
||||
.getResult(0);
|
||||
Value var = genMeanOrVarCalculation(squareSum);
|
||||
Value rSTDTensor = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, meanAndVarShapeSizes, elemTy);
|
||||
SmallVector<AffineMap> rSTDIndexingMap(
|
||||
2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
|
||||
|
||||
Value rSTD = rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, rSTDTensor.getType(), var, rSTDTensor,
|
||||
rSTDIndexingMap, meanAndVarIterationTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value result =
|
||||
calculateRSTD(b, loc, elemTy, eps, args[0]);
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
})
|
||||
.getResult(0);
|
||||
|
||||
// Step 5. Get layernorm.
|
||||
|
||||
|
@ -1320,7 +1350,6 @@ public:
|
|||
auto normalizedShapeAffineMap = AffineMap::get(
|
||||
/*dimCount=*/inputRank,
|
||||
/*symbolCount=*/0, normalizedShapeExprs, context);
|
||||
|
||||
auto inputSizes = getTensorSizes(rewriter, loc, input);
|
||||
Value initLayerNormTensor =
|
||||
rewriter.create<linalg::InitTensorOp>(loc, inputSizes, elemTy);
|
||||
|
@ -1334,24 +1363,48 @@ public:
|
|||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, initLayerNormTensor.getType(),
|
||||
ValueRange{input, mean, var, weight, bias}, initLayerNormTensor,
|
||||
ValueRange{input, mean, rSTD, weight, bias},
|
||||
initLayerNormTensor,
|
||||
/*indexingMaps=*/indexingMaps,
|
||||
/*iteratorTypes=*/layerNormIterationTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value input = args[0], mean = args[1], var = args[2],
|
||||
Value input = args[0], mean = args[1], rSTD = args[2],
|
||||
weight = args[3], bias = args[4];
|
||||
Value result = createLinalgPayloadCalculationForNormOps(
|
||||
b, loc, elemTy, input, mean, var, eps, weight, bias);
|
||||
Value result =
|
||||
createLinalgPayloadCalculationForNormOpsWithRSTD(
|
||||
b, loc, elemTy, input, mean, rSTD, eps, weight, bias);
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
})
|
||||
.getResult(0);
|
||||
SmallVector<int64_t> expandShape(inputRank, 1);
|
||||
for (int i = 0; i < meanAndVarShapeRank; i++) {
|
||||
// `mean` and `rstd` are not yet casted, so they will be having dynamic
|
||||
// shape. Hence to match them, for each dimension corresponding to `mean`
|
||||
// or `rstd` assign -1.
|
||||
expandShape[i] = -1;
|
||||
}
|
||||
auto expandShapeType = RankedTensorType::get(expandShape, elemTy);
|
||||
SmallVector<ReassociationIndices> reassociation(meanAndVarShapeRank);
|
||||
for (auto i : llvm::seq<int64_t>(0, meanAndVarShapeRank)) {
|
||||
reassociation[i].push_back(i);
|
||||
if (i == meanAndVarShapeRank - 1) {
|
||||
for (auto j : llvm::seq<int64_t>(0, normalizedShapeRank))
|
||||
reassociation[i].push_back(i + j + 1);
|
||||
}
|
||||
}
|
||||
Value meanResult = rewriter.create<tensor::ExpandShapeOp>(
|
||||
loc, expandShapeType, mean, reassociation);
|
||||
Value rSTDResult = rewriter.create<tensor::ExpandShapeOp>(
|
||||
loc, expandShapeType, rSTD, reassociation);
|
||||
Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
|
||||
Type meanResultType = getTypeConverter()->convertType(op.getType(1));
|
||||
Type varResultType = getTypeConverter()->convertType(op.getType(2));
|
||||
Type rSTDResultType = getTypeConverter()->convertType(op.getType(2));
|
||||
Value layerNorm_ =
|
||||
rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
|
||||
Value mean_ = rewriter.create<tensor::CastOp>(loc, meanResultType, mean);
|
||||
Value var_ = rewriter.create<tensor::CastOp>(loc, varResultType, var);
|
||||
Value mean_ =
|
||||
rewriter.create<tensor::CastOp>(loc, meanResultType, meanResult);
|
||||
Value var_ =
|
||||
rewriter.create<tensor::CastOp>(loc, rSTDResultType, rSTDResult);
|
||||
rewriter.replaceOp(op, {layerNorm_, mean_, var_});
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -1118,9 +1118,10 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
|
|||
Value normalizedShape = op.normalized_shape();
|
||||
SmallVector<Value> normalizedShapeSizesTorchInt;
|
||||
getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
|
||||
std::vector<int64_t> meanVarSizes;
|
||||
for (int i = normalizedShapeSizesTorchInt.size(); i < inputRank; i++)
|
||||
meanVarSizes.push_back(input.getSizes()[i]);
|
||||
int64_t axis = inputRank - normalizedShapeSizesTorchInt.size();
|
||||
std::vector<int64_t> meanVarSizes(inputRank, 1);
|
||||
for (int i = 0; i < axis; i++)
|
||||
meanVarSizes[i] = input.getSizes()[i];
|
||||
auto meanVarType = input.getWithSizesAndDtype(
|
||||
llvm::makeArrayRef(meanVarSizes), input.getDtype());
|
||||
auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
|
||||
|
|
|
@ -2513,20 +2513,36 @@ module {
|
|||
}
|
||||
func @"__torch_mlir_shape_fn.aten.native_layer_norm"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%str = torch.constant.str "AssertionError: "
|
||||
%none = torch.constant.none
|
||||
%true = torch.constant.bool true
|
||||
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
%1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
|
||||
%2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%3 = torch.aten.__range_length %1, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
|
||||
%1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
|
||||
%3 = torch.aten.sub.int %1, %2 : !torch.int, !torch.int -> !torch.int
|
||||
%4 = torch.aten.ge.int %3, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If %4 -> () {
|
||||
torch.prim.If.yield
|
||||
} else {
|
||||
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
|
||||
torch.prim.If.yield
|
||||
}
|
||||
torch.prim.Loop %3, %true, init() {
|
||||
^bb0(%arg5: !torch.int):
|
||||
%5 = torch.aten.__derive_index %arg5, %1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
|
||||
%6 = torch.aten.__getitem__.t %arg0, %5 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%7 = torch.aten.append.t %0, %6 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
%8 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%9 = torch.aten.append.t %0, %8 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
torch.prim.Loop.condition %true, iter()
|
||||
} : (!torch.int, !torch.bool) -> ()
|
||||
%4 = torch.prim.TupleConstruct %arg0, %0, %0 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>
|
||||
return %4 : !torch.tuple<list<int>, list<int>, list<int>>
|
||||
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%6 = torch.aten.__range_length %3, %5, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
|
||||
torch.prim.Loop %6, %true, init() {
|
||||
^bb0(%arg5: !torch.int):
|
||||
%8 = torch.aten.append.t %0, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
torch.prim.Loop.condition %true, iter()
|
||||
} : (!torch.int, !torch.bool) -> ()
|
||||
%7 = torch.prim.TupleConstruct %arg0, %0, %0 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>
|
||||
return %7 : !torch.tuple<list<int>, list<int>, list<int>>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.native_batch_norm"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {
|
||||
%int0 = torch.constant.int 0
|
||||
|
|
|
@ -773,25 +773,17 @@ def aten〇nll_loss_forward(self: List[int], target: List[int], weight: Optional
|
|||
def aten〇nll_loss_backward(grad_output: List[int], self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int, total_weight: List[int]) -> List[int]:
|
||||
return upstream_shape_helpers.unary(self)
|
||||
|
||||
# TODO: Fix shape function (see body).
|
||||
# @check_shape_function([
|
||||
# Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
|
||||
# ])
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
|
||||
])
|
||||
def aten〇native_layer_norm(input: List[int], normalized_shape: List[int], weight: Optional[List[int]], bias: Optional[List[int]], eps: float) -> Tuple[List[int], List[int], List[int]]:
|
||||
reduction_shape: List[int] = []
|
||||
# TODO: Fix buggy behavior. TorchToLinalg needs to properly handle the
|
||||
# correctly inferred shapes.
|
||||
# With input=[2, 5, 2, 2, 3] and normalized_shape=[2, 2, 3], we should get
|
||||
# [[2, 5, 2, 2, 3], [2, 5, 1, 1, 1], [2, 5, 1, 1, 1]]
|
||||
for i in range(len(normalized_shape), len(input)):
|
||||
num_unreduced_dimensions = len(input) - len(normalized_shape)
|
||||
assert num_unreduced_dimensions >= 0
|
||||
for i in range(num_unreduced_dimensions):
|
||||
reduction_shape.append(input[i])
|
||||
# Correct code:
|
||||
# num_unreduced_dimensions = len(input) - len(normalized_shape)
|
||||
# assert num_unreduced_dimensions >= 0
|
||||
# for i in range(num_unreduced_dimensions):
|
||||
# reduction_shape.append(input[i])
|
||||
# for i in range(num_unreduced_dimensions, len(input)):
|
||||
# reduction_shape.append(1)
|
||||
for i in range(num_unreduced_dimensions, len(input)):
|
||||
reduction_shape.append(1)
|
||||
return input, reduction_shape, reduction_shape
|
||||
|
||||
@check_shape_function([
|
||||
|
|
Loading…
Reference in New Issue