[Torch Dialect] Add Support for AtenGroupNormOp and AtenNativeGroupNormOp (#2591)

Co-authored-by: LiuYuanqiang <liuyuanqiang.yqliu@bytedance.com>
pull/2637/head
JianzheXiao 2023-12-12 19:05:12 -08:00 committed by GitHub
parent 74f7a0c9d6
commit 7cf52ae73f
9 changed files with 298 additions and 3 deletions


@@ -5640,6 +5640,34 @@ def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
}];
}
def Torch_AtenGroupNormOp : Torch_Op<"aten.group_norm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
Torch_IntType:$num_groups,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
Torch_FloatType:$eps,
Torch_BoolType:$cudnn_enabled
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenGroupNormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenGroupNormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
AllowsTypeRefinement,
HasValueSemantics,


@@ -8074,6 +8074,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.batch_norm(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.bool, !torch.float, !torch.float, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.group_norm\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float, %arg5: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.native_group_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" %1 = torch.prim.ListConstruct %arg3, %arg6 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %2 = torch.prim.ListConstruct %arg3, %arg6 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" return %3 : !torch.tuple<list<int>, list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -8748,6 +8759,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.group_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.float, %arg5: !torch.bool) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.native_group_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.float) -> !torch.tuple<int, int, int> {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
" return %3 : !torch.tuple<int, int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"


@@ -3753,6 +3753,165 @@ class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern<OpTy> {
};
} // namespace
namespace {
class DecomposeAtenGroupNormOp : public OpRewritePattern<AtenGroupNormOp> {
using OpRewritePattern<AtenGroupNormOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenGroupNormOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
MLIRContext *context = op.getContext();
Value input = op.getInput();
Value weight = op.getWeight();
Value bias = op.getBias();
Value numGroups = op.getNumGroups();
Value eps = op.getEps();
Value cstZero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value cstOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
Value N = rewriter.create<AtenSizeIntOp>(loc, input, cstZero);
Value C = rewriter.create<AtenSizeIntOp>(loc, input, cstOne);
Value numElements = rewriter.create<AtenNumelOp>(loc, input);
Value numElementsDivN =
rewriter.create<AtenFloordivIntOp>(loc, numElements, N);
Value HxW = rewriter.create<AtenFloordivIntOp>(loc, numElementsDivN, C);
AtenNativeGroupNormOp newOp = rewriter.create<AtenNativeGroupNormOp>(
loc, ArrayRef<Type>{op.getResult().getType(), baseType, baseType},
input, weight, bias, N, C, HxW, numGroups, eps);
rewriter.replaceOp(op, newOp.getResult0());
return success();
}
};
} // namespace
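
The pattern above simply re-expresses aten.group_norm through aten.native_group_norm and keeps only the first result. As a minimal eager-mode sketch (not part of the patch; shapes are borrowed from the GroupNormModule test added later in this commit), the same relationship can be checked directly in PyTorch:

import torch

x = torch.randn(2, 4, 6, 7)
weight, bias = torch.randn(4), torch.randn(4)
num_groups, eps = 2, 1e-5

# Mirrors the AtenSizeIntOp/AtenNumelOp/AtenFloordivIntOp chain in the pattern.
N, C = x.size(0), x.size(1)
HxW = x.numel() // N // C

via_group_norm = torch.ops.aten.group_norm(x, num_groups, weight, bias, eps, False)
via_native, _mean, _rstd = torch.ops.aten.native_group_norm(
    x, weight, bias, N, C, HxW, num_groups, eps)
assert torch.allclose(via_group_norm, via_native)
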
namespace {
class DecomposeAtenNativeGroupNormOp
: public OpRewritePattern<AtenNativeGroupNormOp> {
using OpRewritePattern<AtenNativeGroupNormOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNativeGroupNormOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
MLIRContext *context = op.getContext();
Value input = op.getInput();
Value weight = op.getWeight();
Value bias = op.getBias();
Value numGroups = op.getGroup();
Value eps = op.getEps();
// Check the rank of the input/outputs tensor.
auto inputType = input.getType().cast<BaseTensorType>();
auto outputType = op.getResult0().getType().cast<BaseTensorType>();
auto meanType = op.getResult1().getType().cast<BaseTensorType>();
auto rsqrtVarType = op.getResult2().getType().cast<BaseTensorType>();
if (!inputType.hasSizes() || !outputType.hasSizes() ||
!meanType.hasSizes() || !rsqrtVarType.hasSizes()) {
return rewriter.notifyMatchFailure(
op, "input/outputs tensor should have known sizes.");
}
Value none = rewriter.create<ConstantNoneOp>(loc);
Value cstZero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value cstOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value cstNegtiveOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
// GroupNorm requires the channel dimension (C) to be exactly divisible by
// the number of groups.
Value channel = rewriter.create<AtenSizeIntOp>(loc, input, cstOne);
Value remainder =
rewriter.create<AtenRemainderIntOp>(loc, channel, numGroups);
Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, remainder, cstZero);
rewriter.create<RuntimeAssertOp>(
loc, eqOrNot,
rewriter.getStringAttr("the number of channels must be divisible by "
"the number of groups"));
// Reshape the input tensor to (N, numGroups, -1) to apply normalization.
SmallVector<Value> newShape;
newShape.push_back(rewriter.create<AtenSizeIntOp>(loc, input, cstZero));
newShape.push_back(numGroups);
newShape.push_back(cstNegtiveOne);
Value reshapedInput = rewriter.create<AtenViewOp>(
loc, baseType, input,
rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(IntType::get(context)), newShape));
// Now we proceed with the normalization steps across the 'groupSize'
// Compute the mean and variance for each group
Value dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
ArrayRef<Value>{cstNegtiveOne});
auto mean = rewriter.create<AtenMeanDimOp>(
loc, baseType, reshapedInput, /*dims=*/dimList, /*keepdim=*/cstTrue,
/*dtype=*/none);
auto var = rewriter.create<AtenVarDimOp>(
loc, baseType, reshapedInput, /*dims=*/dimList, /*unbiased=*/cstFalse,
/*keepdim=*/cstTrue);
// Compute the normalized output: (input - mean) * rsqrt(var + eps)
auto varPlusEps = rewriter.create<AtenAddScalarOp>(loc, baseType, var, eps,
/*alpha=*/cstOne);
auto invStd = rewriter.create<AtenRsqrtOp>(loc, baseType, varPlusEps);
auto inputSubMean = rewriter.create<AtenSubTensorOp>(
loc, baseType, reshapedInput, mean, /*alpha=*/cstOne);
auto normalizedOutput =
rewriter.create<AtenMulTensorOp>(loc, baseType, inputSubMean, invStd);
// Reshape normalized output back to the original input shape
auto inputShape = rewriter.create<AtenSizeOp>(
loc, Torch::ListType::get(IntType::get(context)), input);
auto reshapedOutput = rewriter.create<AtenViewOp>(
loc, inputType, normalizedOutput, /*shape=*/inputShape);
// Apply weight and bias if they are not None
// Reshape weight and bias to C,1,1,...
SmallVector<Value> viewShape = {channel};
for (unsigned i = 2; i < inputType.getSizes().size(); i++) {
viewShape.push_back(cstOne);
}
Value viewShapeSizeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), viewShape);
Value groupNormOutput = reshapedOutput;
if (!weight.getType().isa<Torch::NoneType>()) {
auto weightReshaped = rewriter.create<AtenViewOp>(
loc, baseType, weight, /*shape=*/viewShapeSizeList);
groupNormOutput = rewriter.create<AtenMulTensorOp>(
loc, inputType, groupNormOutput, weightReshaped);
}
if (!bias.getType().isa<Torch::NoneType>()) {
auto biasReshaped = rewriter.create<AtenViewOp>(
loc, baseType, bias, /*shape=*/viewShapeSizeList);
groupNormOutput = rewriter.create<AtenAddTensorOp>(
loc, inputType, groupNormOutput, biasReshaped,
/*alpha=*/cstOne);
}
Value squeezedMean =
rewriter.create<AtenSqueezeDimOp>(loc, meanType, mean, cstNegtiveOne);
Value squeezedRsqrtVar = rewriter.create<AtenSqueezeDimOp>(
loc, rsqrtVarType, invStd, cstNegtiveOne);
rewriter.replaceOp(
op, ArrayRef<Value>{groupNormOutput, squeezedMean, squeezedRsqrtVar});
return success();
}
};
} // namespace
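
For reference, a minimal eager-mode sketch of the same computation (not part of the patch; the helper name decomposed_native_group_norm is made up for illustration). It follows the steps of the pattern above: reshape to (N, numGroups, -1), biased variance with keepdim, rsqrt, affine with weight and bias viewed as (C, 1, ..., 1), then squeeze the mean and rstd, and checks the result against aten.native_group_norm:

import torch

def decomposed_native_group_norm(x, weight, bias, N, C, HxW, group, eps):
    xg = x.view(N, group, -1)                    # reshape to (N, numGroups, -1)
    mean = xg.mean(dim=-1, keepdim=True)
    var = xg.var(dim=-1, unbiased=False, keepdim=True)
    rstd = (var + eps).rsqrt()
    out = ((xg - mean) * rstd).view(x.shape)     # normalize, restore input shape
    view_shape = (C,) + (1,) * (x.dim() - 2)     # broadcast weight/bias as C,1,1,...
    if weight is not None:
        out = out * weight.view(view_shape)
    if bias is not None:
        out = out + bias.view(view_shape)
    return out, mean.squeeze(-1), rstd.squeeze(-1)

x, w, b = torch.randn(2, 6, 2, 2), torch.randn(6), torch.randn(6)
got = decomposed_native_group_norm(x, w, b, 2, 6, 4, 3, 1e-6)
ref = torch.ops.aten.native_group_norm(x, w, b, 2, 6, 4, 3, 1e-6)
for ours, theirs in zip(got, ref):
    assert torch.allclose(ours, theirs, atol=1e-5)
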
namespace {
class DecomposeAtenNativeBatchNormOp
: public OpRewritePattern<AtenNativeBatchNormOp> {
@@ -6204,6 +6363,8 @@ public:
DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenGroupNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeGroupNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeBatchNormOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>>(patterns);


@@ -407,6 +407,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenAddcdivOp>();
target.addIllegalOp<AtenLayerNormOp>();
target.addIllegalOp<AtenNativeLayerNormOp>();
target.addIllegalOp<AtenGroupNormOp>();
target.addIllegalOp<AtenNativeGroupNormOp>();
target.addIllegalOp<AtenNativeBatchNormOp>();
target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
target.addIllegalOp<AtenConvolutionBackwardOp>();


@@ -306,6 +306,10 @@ TORCHDYNAMO_XFAIL_SET = {
# ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4]))
"ArangeStartOutViewModule_basic",
# ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
"GroupNormModule_basic",
"GroupNormNoWeightAndBiasModule_basic",
}
TORCHDYNAMO_CRASHING_SET = {
@@ -586,6 +590,7 @@ STABLEHLO_PASS_SET = {
"NewFullModuleInt2DStatic_basic",
"NewFullModuleInt2D_basic",
"NewFullModuleInt3D_basic",
"GroupNormModule_basic",
"GatherStaticModule_basic",
"GatherModule_basic",
"Gather2DInputModdule_basic",


@@ -1130,6 +1130,12 @@ def aten〇convolution_backward〡shape(grad_output: List[int], input: List[int]
def aten〇batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
return upstream_shape_functions.batch_norm(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled)
def aten〇group_norm〡shape(input: List[int], num_groups: int, weight: Optional[List[int]] = None, bias: Optional[List[int]] = None, eps: float = 1.0000000000000001e-05, cudnn_enabled: bool = True) -> List[int]:
return upstream_shape_functions.unary(input)
def aten〇native_group_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[List[int], List[int], List[int]]:
return upstream_shape_functions.unary(input), [N, group], [N, group]
def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
return upstream_shape_functions.slice(self, dim, start, end, step)
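
A quick sketch of what the new shape functions compute (not part of the patch; native_group_norm_result_shapes is a made-up name), using the operand values from the NativeGroupNormModule e2e test further down: the output keeps the input shape, while mean and rstd are [N, group].

from typing import List, Tuple

def native_group_norm_result_shapes(input: List[int], N: int, group: int) -> Tuple[List[int], List[int], List[int]]:
    # Same result as unary(input), [N, group], [N, group] above.
    return list(input), [N, group], [N, group]

print(native_group_norm_result_shapes([2, 6, 2, 2], N=2, group=3))
# ([2, 6, 2, 2], [2, 3], [2, 3])
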
@@ -1671,6 +1677,18 @@ def aten〇batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dty
input_rank, input_dtype = input_rank_dtype
return input_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], error_types={*all_integer_dtypes()}, num_groups=1))
def aten〇group_norm〡dtype(input_rank_dtype: Tuple[int, int], num_groups: int, weight_rank_dtype: Optional[Tuple[int, int]] = None, bias_rank_dtype: Optional[Tuple[int, int]] = None, eps: float = 1.0000000000000001e-05, cudnn_enabled: bool = True) -> int:
input_rank, input_dtype = input_rank_dtype
assert not is_integer_dtype(input_dtype)
return input_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7), (3,), (3,)], error_types={*all_integer_dtypes()}, N=2, C=3, HxW=35, group=1, eps=0.000001))
def aten〇native_group_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[int, int, int]:
input_rank, input_dtype = input_rank_dtype
assert not is_integer_dtype(input_dtype)
return input_dtype, input_dtype, input_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int:
self_rank, self_dtype = self_rank_dtype
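
The dtype rules above pass the input dtype through and reject integer inputs. A small eager-mode sanity check (not part of the patch, and assuming eager PyTorch raises a RuntimeError for integer inputs, as it does for unimplemented integer kernels):

import torch

x = torch.randn(2, 4, 6, 7)
out = torch.ops.aten.group_norm(x, 2, None, None, 1e-5, False)
assert out.dtype == x.dtype            # dtype is propagated unchanged

try:
    torch.ops.aten.group_norm(x.to(torch.int64), 2, None, None, 1e-5, False)
except RuntimeError:
    pass                               # integer dtypes rejected, matching the assert above
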


@@ -421,6 +421,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit(
"aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
)
emit(
"aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)"
)
emit(
"aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
)


@@ -10,7 +10,6 @@
from torch_mlir._version import torch_version_for_comparison, version
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"NativeGroupNormModule_basic",
"NativeGroupNormBackwardModule_basic",
"QuantizedMLP_basic",
"ReduceMaxAlongDimUnsignedInt_basic",


@@ -243,6 +243,42 @@ def NativeBatchNormNoneWeightModule_basic(module, tu: TestUtils):
# ==============================================================================
class GroupNormModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 4, 6, 7], torch.float32, True),
([4], torch.float32, True),
([4], torch.float32, True),
])
def forward(self, x, weight, bias):
return torch.ops.aten.group_norm(x, 2, weight, bias, 1.0000000000000001e-05, False)
@register_test_case(module_factory=lambda: GroupNormModule())
def GroupNormModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 6, 7), tu.rand(4), tu.rand(4))
class GroupNormNoWeightAndBiasModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 4, 6, 7], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.group_norm(x, 2, None, None, 1.0000000000000001e-05, False)
@register_test_case(module_factory=lambda: GroupNormNoWeightAndBiasModule())
def GroupNormNoWeightAndBiasModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 6, 7))
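
As a quick eager-mode cross-check of the semantics exercised by these tests (not part of the patch), the aten op agrees with torch.nn.functional.group_norm:

import torch
import torch.nn.functional as F

x, w, b = torch.rand(2, 4, 6, 7), torch.rand(4), torch.rand(4)
assert torch.allclose(
    torch.ops.aten.group_norm(x, 2, w, b, 1e-5, False),
    F.group_norm(x, 2, w, b, eps=1e-5),
)
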
# ==============================================================================
class NativeGroupNormModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -257,13 +293,15 @@ class NativeGroupNormModule(torch.nn.Module):
def forward(self, x, weight, bias):
return torch.ops.aten.native_group_norm(
x, weight, bias,
2, 6, 4, 3, 0.000001);
2, 6, 4, 3, 0.000001)
@register_test_case(module_factory=lambda: NativeGroupNormModule())
def NativeGroupNormModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 6, 2, 2), tu.rand(6), tu.rand(6))
# ==============================================================================
class NativeGroupNormBackwardModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -280,7 +318,7 @@ class NativeGroupNormBackwardModule(torch.nn.Module):
def forward(self, grad_out, x, mean, rstd, weight):
return torch.ops.aten.native_group_norm_backward(
grad_out, x, mean, rstd, weight,
2, 6, 4, 3, [True, True, True]);
2, 6, 4, 3, [True, True, True])
@register_test_case(module_factory=lambda: NativeGroupNormBackwardModule())
@@ -450,3 +488,4 @@ class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
@register_test_case(module_factory=lambda: LayerNormNormalizeOverAllDimsModule())
def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 2, 3))