mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Fix the return types of `aten.native_layer_norm`.
This commit fixes the 2nd and 3rd return types of the `aten.native_layer_norm`. Previously the mean and rSTD were returned with reduction dims removed. This commit fixes this and keeps the reduction dims of the results. Signed-Off-By: Prateek Gupta <prateek@nord-labs.com>pull/679/head snapshot-20220317.329
parent
3b66b4925a
commit
7256c9e395
|
@ -204,15 +204,35 @@ class NativeLayerNormModule(torch.nn.Module):
|
||||||
])
|
])
|
||||||
def forward(self, x, weight, bias):
|
def forward(self, x, weight, bias):
|
||||||
list = [2, 2, 3]
|
list = [2, 2, 3]
|
||||||
# TODO: Fix the case of the other return values.
|
|
||||||
return torch.ops.aten.native_layer_norm(
|
return torch.ops.aten.native_layer_norm(
|
||||||
x, list, weight, bias, eps=0.5)[0]
|
x, list, weight, bias, eps=0.5)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NativeLayerNormModule())
|
@register_test_case(module_factory=lambda: NativeLayerNormModule())
|
||||||
def NativeLayerNormModule_basic(module, tu: TestUtils):
|
def NativeLayerNormModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
|
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
|
||||||
|
|
||||||
|
class NativeLayerNormDynamicModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, weight, bias):
|
||||||
|
list = [2, 2, 3]
|
||||||
|
return torch.ops.aten.native_layer_norm(
|
||||||
|
x, list, weight, bias, eps=0.5)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: NativeLayerNormDynamicModule())
|
||||||
|
def NativeLayerNormDynamicModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class NativeLayerNormModule4D(torch.nn.Module):
|
class NativeLayerNormModule4D(torch.nn.Module):
|
||||||
|
|
|
@ -1009,22 +1009,37 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Normalization formula:
|
/// Inverted STD: rSTD = 1 / sqrt(var + eps).
|
||||||
// ((input - mean) / sqrt(var + eps)) * weight + bias
|
static Value calculateRSTD(OpBuilder &b, Location loc, Type elemTy, Value eps,
|
||||||
static Value createLinalgPayloadCalculationForNormOps(
|
Value var) {
|
||||||
OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var,
|
|
||||||
Value eps, Value weight, Value bias) {
|
|
||||||
Value inputSubMean = b.create<arith::SubFOp>(loc, input, mean);
|
|
||||||
// The eps is always f64.
|
// The eps is always f64.
|
||||||
Value truncatedEps = b.create<arith::TruncFOp>(loc, elemTy, eps);
|
Value truncatedEps = b.create<arith::TruncFOp>(loc, elemTy, eps);
|
||||||
Value varPlusEps = b.create<arith::AddFOp>(loc, var, truncatedEps);
|
Value varPlusEps = b.create<arith::AddFOp>(loc, var, truncatedEps);
|
||||||
Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
|
Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
|
||||||
|
return rSTD;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalization formula:
|
||||||
|
// ((input - mean) * rSTD * weight + bias
|
||||||
|
static Value createLinalgPayloadCalculationForNormOpsWithRSTD(
|
||||||
|
OpBuilder &b, Location loc, Type elemTy, Value input, Value mean,
|
||||||
|
Value rSTD, Value eps, Value weight, Value bias) {
|
||||||
|
Value inputSubMean = b.create<arith::SubFOp>(loc, input, mean);
|
||||||
Value temp = b.create<arith::MulFOp>(loc, inputSubMean, rSTD);
|
Value temp = b.create<arith::MulFOp>(loc, inputSubMean, rSTD);
|
||||||
Value timesWeight = b.create<arith::MulFOp>(loc, temp, weight);
|
Value timesWeight = b.create<arith::MulFOp>(loc, temp, weight);
|
||||||
Value plusBias = b.create<arith::AddFOp>(loc, timesWeight, bias);
|
Value plusBias = b.create<arith::AddFOp>(loc, timesWeight, bias);
|
||||||
return plusBias;
|
return plusBias;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value createLinalgPayloadCalculationForNormOpsWithVar(
|
||||||
|
OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var,
|
||||||
|
Value eps, Value weight, Value bias) {
|
||||||
|
Value rSTD = calculateRSTD(b, loc, elemTy, eps, var);
|
||||||
|
Value result = createLinalgPayloadCalculationForNormOpsWithRSTD(
|
||||||
|
b, loc, elemTy, input, mean, rSTD, eps, weight, bias);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
|
class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -1117,9 +1132,10 @@ public:
|
||||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||||
Value input = args[0], weight = args[1], bias = args[2],
|
Value input = args[0], weight = args[1], bias = args[2],
|
||||||
mean = args[3], var = args[4];
|
mean = args[3], var = args[4];
|
||||||
Value result = createLinalgPayloadCalculationForNormOps(
|
Value result =
|
||||||
b, loc, var.getType(), input, mean, var, eps, weight,
|
createLinalgPayloadCalculationForNormOpsWithVar(
|
||||||
bias);
|
b, loc, var.getType(), input, mean, var, eps, weight,
|
||||||
|
bias);
|
||||||
b.create<linalg::YieldOp>(loc, result);
|
b.create<linalg::YieldOp>(loc, result);
|
||||||
})
|
})
|
||||||
.getResult(0);
|
.getResult(0);
|
||||||
|
@ -1139,13 +1155,12 @@ public:
|
||||||
// | meanAndVarShape | normalizedShape |
|
// | meanAndVarShape | normalizedShape |
|
||||||
// +-------------------+---------------------
|
// +-------------------+---------------------
|
||||||
// <------------+ inputShape +-------------->
|
// <------------+ inputShape +-------------->
|
||||||
|
|
||||||
// There are the following steps:
|
// There are the following steps:
|
||||||
// Step 1. Check if all the arguments meet the requirements.
|
// Step 1. Check if all the arguments meet the requirements.
|
||||||
// Step 2. Common parts to be used for getting mean and var.
|
// Step 2. Common parts to be used for getting mean and var.
|
||||||
// This includes elements count, affineMap and iteratorTypes.
|
// This includes elements count, affineMap and iteratorTypes.
|
||||||
// Step 3. Get mean.
|
// Step 3. Get mean.
|
||||||
// Step 4. Get var.
|
// Step 4. Get rSTD.
|
||||||
// Step 5. Get layernorm.
|
// Step 5. Get layernorm.
|
||||||
namespace {
|
namespace {
|
||||||
class ConvertAtenNativeLayerNormOp
|
class ConvertAtenNativeLayerNormOp
|
||||||
|
@ -1283,7 +1298,7 @@ public:
|
||||||
.getResult(0);
|
.getResult(0);
|
||||||
Value mean = genMeanOrVarCalculation(sum);
|
Value mean = genMeanOrVarCalculation(sum);
|
||||||
|
|
||||||
// Step 4. Get var.
|
// Step 4. Get rSTD.
|
||||||
|
|
||||||
// Calculate squareSum for the layer.
|
// Calculate squareSum for the layer.
|
||||||
SmallVector<AffineMap> squareSumIndexingMaps{
|
SmallVector<AffineMap> squareSumIndexingMaps{
|
||||||
|
@ -1310,6 +1325,21 @@ public:
|
||||||
})
|
})
|
||||||
.getResult(0);
|
.getResult(0);
|
||||||
Value var = genMeanOrVarCalculation(squareSum);
|
Value var = genMeanOrVarCalculation(squareSum);
|
||||||
|
Value rSTDTensor = rewriter.create<linalg::InitTensorOp>(
|
||||||
|
loc, meanAndVarShapeSizes, elemTy);
|
||||||
|
SmallVector<AffineMap> rSTDIndexingMap(
|
||||||
|
2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
|
||||||
|
|
||||||
|
Value rSTD = rewriter
|
||||||
|
.create<linalg::GenericOp>(
|
||||||
|
loc, rSTDTensor.getType(), var, rSTDTensor,
|
||||||
|
rSTDIndexingMap, meanAndVarIterationTypes,
|
||||||
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||||
|
Value result =
|
||||||
|
calculateRSTD(b, loc, elemTy, eps, args[0]);
|
||||||
|
b.create<linalg::YieldOp>(loc, result);
|
||||||
|
})
|
||||||
|
.getResult(0);
|
||||||
|
|
||||||
// Step 5. Get layernorm.
|
// Step 5. Get layernorm.
|
||||||
|
|
||||||
|
@ -1320,7 +1350,6 @@ public:
|
||||||
auto normalizedShapeAffineMap = AffineMap::get(
|
auto normalizedShapeAffineMap = AffineMap::get(
|
||||||
/*dimCount=*/inputRank,
|
/*dimCount=*/inputRank,
|
||||||
/*symbolCount=*/0, normalizedShapeExprs, context);
|
/*symbolCount=*/0, normalizedShapeExprs, context);
|
||||||
|
|
||||||
auto inputSizes = getTensorSizes(rewriter, loc, input);
|
auto inputSizes = getTensorSizes(rewriter, loc, input);
|
||||||
Value initLayerNormTensor =
|
Value initLayerNormTensor =
|
||||||
rewriter.create<linalg::InitTensorOp>(loc, inputSizes, elemTy);
|
rewriter.create<linalg::InitTensorOp>(loc, inputSizes, elemTy);
|
||||||
|
@ -1334,24 +1363,48 @@ public:
|
||||||
rewriter
|
rewriter
|
||||||
.create<linalg::GenericOp>(
|
.create<linalg::GenericOp>(
|
||||||
loc, initLayerNormTensor.getType(),
|
loc, initLayerNormTensor.getType(),
|
||||||
ValueRange{input, mean, var, weight, bias}, initLayerNormTensor,
|
ValueRange{input, mean, rSTD, weight, bias},
|
||||||
|
initLayerNormTensor,
|
||||||
/*indexingMaps=*/indexingMaps,
|
/*indexingMaps=*/indexingMaps,
|
||||||
/*iteratorTypes=*/layerNormIterationTypes,
|
/*iteratorTypes=*/layerNormIterationTypes,
|
||||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||||
Value input = args[0], mean = args[1], var = args[2],
|
Value input = args[0], mean = args[1], rSTD = args[2],
|
||||||
weight = args[3], bias = args[4];
|
weight = args[3], bias = args[4];
|
||||||
Value result = createLinalgPayloadCalculationForNormOps(
|
Value result =
|
||||||
b, loc, elemTy, input, mean, var, eps, weight, bias);
|
createLinalgPayloadCalculationForNormOpsWithRSTD(
|
||||||
|
b, loc, elemTy, input, mean, rSTD, eps, weight, bias);
|
||||||
b.create<linalg::YieldOp>(loc, result);
|
b.create<linalg::YieldOp>(loc, result);
|
||||||
})
|
})
|
||||||
.getResult(0);
|
.getResult(0);
|
||||||
|
SmallVector<int64_t> expandShape(inputRank, 1);
|
||||||
|
for (int i = 0; i < meanAndVarShapeRank; i++) {
|
||||||
|
// `mean` and `rstd` are not yet casted, so they will be having dynamic
|
||||||
|
// shape. Hence to match them, for each dimension corresponding to `mean`
|
||||||
|
// or `rstd` assign -1.
|
||||||
|
expandShape[i] = -1;
|
||||||
|
}
|
||||||
|
auto expandShapeType = RankedTensorType::get(expandShape, elemTy);
|
||||||
|
SmallVector<ReassociationIndices> reassociation(meanAndVarShapeRank);
|
||||||
|
for (auto i : llvm::seq<int64_t>(0, meanAndVarShapeRank)) {
|
||||||
|
reassociation[i].push_back(i);
|
||||||
|
if (i == meanAndVarShapeRank - 1) {
|
||||||
|
for (auto j : llvm::seq<int64_t>(0, normalizedShapeRank))
|
||||||
|
reassociation[i].push_back(i + j + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Value meanResult = rewriter.create<tensor::ExpandShapeOp>(
|
||||||
|
loc, expandShapeType, mean, reassociation);
|
||||||
|
Value rSTDResult = rewriter.create<tensor::ExpandShapeOp>(
|
||||||
|
loc, expandShapeType, rSTD, reassociation);
|
||||||
Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
|
Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
|
||||||
Type meanResultType = getTypeConverter()->convertType(op.getType(1));
|
Type meanResultType = getTypeConverter()->convertType(op.getType(1));
|
||||||
Type varResultType = getTypeConverter()->convertType(op.getType(2));
|
Type rSTDResultType = getTypeConverter()->convertType(op.getType(2));
|
||||||
Value layerNorm_ =
|
Value layerNorm_ =
|
||||||
rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
|
rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
|
||||||
Value mean_ = rewriter.create<tensor::CastOp>(loc, meanResultType, mean);
|
Value mean_ =
|
||||||
Value var_ = rewriter.create<tensor::CastOp>(loc, varResultType, var);
|
rewriter.create<tensor::CastOp>(loc, meanResultType, meanResult);
|
||||||
|
Value var_ =
|
||||||
|
rewriter.create<tensor::CastOp>(loc, rSTDResultType, rSTDResult);
|
||||||
rewriter.replaceOp(op, {layerNorm_, mean_, var_});
|
rewriter.replaceOp(op, {layerNorm_, mean_, var_});
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -1118,9 +1118,10 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
|
||||||
Value normalizedShape = op.normalized_shape();
|
Value normalizedShape = op.normalized_shape();
|
||||||
SmallVector<Value> normalizedShapeSizesTorchInt;
|
SmallVector<Value> normalizedShapeSizesTorchInt;
|
||||||
getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
|
getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
|
||||||
std::vector<int64_t> meanVarSizes;
|
int64_t axis = inputRank - normalizedShapeSizesTorchInt.size();
|
||||||
for (int i = normalizedShapeSizesTorchInt.size(); i < inputRank; i++)
|
std::vector<int64_t> meanVarSizes(inputRank, 1);
|
||||||
meanVarSizes.push_back(input.getSizes()[i]);
|
for (int i = 0; i < axis; i++)
|
||||||
|
meanVarSizes[i] = input.getSizes()[i];
|
||||||
auto meanVarType = input.getWithSizesAndDtype(
|
auto meanVarType = input.getWithSizesAndDtype(
|
||||||
llvm::makeArrayRef(meanVarSizes), input.getDtype());
|
llvm::makeArrayRef(meanVarSizes), input.getDtype());
|
||||||
auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
|
auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
|
||||||
|
|
|
@ -2513,20 +2513,36 @@ module {
|
||||||
}
|
}
|
||||||
func @"__torch_mlir_shape_fn.aten.native_layer_norm"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {
|
func @"__torch_mlir_shape_fn.aten.native_layer_norm"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {
|
||||||
%int1 = torch.constant.int 1
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%str = torch.constant.str "AssertionError: "
|
||||||
|
%none = torch.constant.none
|
||||||
%true = torch.constant.bool true
|
%true = torch.constant.bool true
|
||||||
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
|
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
%1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
|
%1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||||
%2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
%2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
|
||||||
%3 = torch.aten.__range_length %1, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
|
%3 = torch.aten.sub.int %1, %2 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%4 = torch.aten.ge.int %3, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
torch.prim.If %4 -> () {
|
||||||
|
torch.prim.If.yield
|
||||||
|
} else {
|
||||||
|
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
|
||||||
|
torch.prim.If.yield
|
||||||
|
}
|
||||||
torch.prim.Loop %3, %true, init() {
|
torch.prim.Loop %3, %true, init() {
|
||||||
^bb0(%arg5: !torch.int):
|
^bb0(%arg5: !torch.int):
|
||||||
%5 = torch.aten.__derive_index %arg5, %1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
|
%8 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
%6 = torch.aten.__getitem__.t %arg0, %5 : !torch.list<int>, !torch.int -> !torch.int
|
%9 = torch.aten.append.t %0, %8 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
%7 = torch.aten.append.t %0, %6 : !torch.list<int>, !torch.int -> !torch.list<int>
|
|
||||||
torch.prim.Loop.condition %true, iter()
|
torch.prim.Loop.condition %true, iter()
|
||||||
} : (!torch.int, !torch.bool) -> ()
|
} : (!torch.int, !torch.bool) -> ()
|
||||||
%4 = torch.prim.TupleConstruct %arg0, %0, %0 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>
|
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||||
return %4 : !torch.tuple<list<int>, list<int>, list<int>>
|
%6 = torch.aten.__range_length %3, %5, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
|
||||||
|
torch.prim.Loop %6, %true, init() {
|
||||||
|
^bb0(%arg5: !torch.int):
|
||||||
|
%8 = torch.aten.append.t %0, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
|
torch.prim.Loop.condition %true, iter()
|
||||||
|
} : (!torch.int, !torch.bool) -> ()
|
||||||
|
%7 = torch.prim.TupleConstruct %arg0, %0, %0 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>
|
||||||
|
return %7 : !torch.tuple<list<int>, list<int>, list<int>>
|
||||||
}
|
}
|
||||||
func @"__torch_mlir_shape_fn.aten.native_batch_norm"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {
|
func @"__torch_mlir_shape_fn.aten.native_batch_norm"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {
|
||||||
%int0 = torch.constant.int 0
|
%int0 = torch.constant.int 0
|
||||||
|
|
|
@ -773,25 +773,17 @@ def aten〇nll_loss_forward(self: List[int], target: List[int], weight: Optional
|
||||||
def aten〇nll_loss_backward(grad_output: List[int], self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int, total_weight: List[int]) -> List[int]:
|
def aten〇nll_loss_backward(grad_output: List[int], self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int, total_weight: List[int]) -> List[int]:
|
||||||
return upstream_shape_helpers.unary(self)
|
return upstream_shape_helpers.unary(self)
|
||||||
|
|
||||||
# TODO: Fix shape function (see body).
|
@check_shape_function([
|
||||||
# @check_shape_function([
|
Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
|
||||||
# Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
|
])
|
||||||
# ])
|
|
||||||
def aten〇native_layer_norm(input: List[int], normalized_shape: List[int], weight: Optional[List[int]], bias: Optional[List[int]], eps: float) -> Tuple[List[int], List[int], List[int]]:
|
def aten〇native_layer_norm(input: List[int], normalized_shape: List[int], weight: Optional[List[int]], bias: Optional[List[int]], eps: float) -> Tuple[List[int], List[int], List[int]]:
|
||||||
reduction_shape: List[int] = []
|
reduction_shape: List[int] = []
|
||||||
# TODO: Fix buggy behavior. TorchToLinalg needs to properly handle the
|
num_unreduced_dimensions = len(input) - len(normalized_shape)
|
||||||
# correctly inferred shapes.
|
assert num_unreduced_dimensions >= 0
|
||||||
# With input=[2, 5, 2, 2, 3] and normalized_shape=[2, 2, 3], we should get
|
for i in range(num_unreduced_dimensions):
|
||||||
# [[2, 5, 2, 2, 3], [2, 5, 1, 1, 1], [2, 5, 1, 1, 1]]
|
|
||||||
for i in range(len(normalized_shape), len(input)):
|
|
||||||
reduction_shape.append(input[i])
|
reduction_shape.append(input[i])
|
||||||
# Correct code:
|
for i in range(num_unreduced_dimensions, len(input)):
|
||||||
# num_unreduced_dimensions = len(input) - len(normalized_shape)
|
reduction_shape.append(1)
|
||||||
# assert num_unreduced_dimensions >= 0
|
|
||||||
# for i in range(num_unreduced_dimensions):
|
|
||||||
# reduction_shape.append(input[i])
|
|
||||||
# for i in range(num_unreduced_dimensions, len(input)):
|
|
||||||
# reduction_shape.append(1)
|
|
||||||
return input, reduction_shape, reduction_shape
|
return input, reduction_shape, reduction_shape
|
||||||
|
|
||||||
@check_shape_function([
|
@check_shape_function([
|
||||||
|
|
Loading…
Reference in New Issue