mirror of https://github.com/llvm/torch-mlir

[Torch Dialect] Add Support for AtenGroupNormOp and AtenNativeGroupNormOp (#2591)

Co-authored-by: LiuYuanqiang <liuyuanqiang.yqliu@bytedance.com>

parent 74f7a0c9d6
commit 7cf52ae73f
@@ -5640,6 +5640,34 @@ def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
   }];
 }
 
+def Torch_AtenGroupNormOp : Torch_Op<"aten.group_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    Torch_IntType:$num_groups,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    Torch_FloatType:$eps,
+    Torch_BoolType:$cudnn_enabled
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenGroupNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 1);
+    }
+    void AtenGroupNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 1);
+    }
+  }];
+}
+
 def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
     AllowsTypeRefinement,
     HasValueSemantics,
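
The ODS definition above mirrors the eager-mode ATen signature. As a rough illustration (assuming a stock PyTorch install; this snippet is not part of the patch), the same operator can be exercised directly from Python:

import torch

x = torch.rand(2, 4, 6, 7)                 # (N, C, H, W)
weight, bias = torch.rand(4), torch.rand(4)
# aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)
y = torch.ops.aten.group_norm(x, 2, weight, bias, 1e-5, False)
assert y.shape == x.shape                  # shape-preserving
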
@@ -8074,6 +8074,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0 = call @__torch__.torch.jit._shape_functions.batch_norm(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.bool, !torch.float, !torch.float, !torch.bool) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
 " }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.group_norm\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float, %arg5: !torch.bool) -> !torch.list<int> {\n"
+" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+" return %0 : !torch.list<int>\n"
+" }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.native_group_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
+" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+" %1 = torch.prim.ListConstruct %arg3, %arg6 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+" %2 = torch.prim.ListConstruct %arg3, %arg6 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+" %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
+" return %3 : !torch.tuple<list<int>, list<int>, list<int>>\n"
+" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {\n"
 " %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
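
The two shape functions registered above encode simple rules: aten.group_norm is shape-preserving, and aten.native_group_norm additionally reports (N, group) for the mean and rstd results. A rough Python restatement of those rules (a sketch for readability, not the generated library code; the function names are made up):

from typing import List, Tuple

def group_norm_shape(input: List[int], num_groups: int) -> List[int]:
    # Output has exactly the input's shape.
    return list(input)

def native_group_norm_shape(input: List[int], N: int, C: int, HxW: int,
                            group: int) -> Tuple[List[int], List[int], List[int]]:
    # (normalized output, per-group mean, per-group rstd)
    return list(input), [N, group], [N, group]
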
@@ -8748,6 +8759,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " return %0#1 : !torch.int\n"
 " }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.group_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.float, %arg5: !torch.bool) -> !torch.int {\n"
+" %none = torch.constant.none\n"
+" %str = torch.constant.str \"AssertionError: \"\n"
+" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
+" torch.prim.If %2 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" return %0#1 : !torch.int\n"
+" }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.native_group_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.float) -> !torch.tuple<int, int, int> {\n"
+" %none = torch.constant.none\n"
+" %str = torch.constant.str \"AssertionError: \"\n"
+" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
+" torch.prim.If %2 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" %3 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
+" return %3 : !torch.tuple<int, int, int>\n"
+" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n"
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " return %0#1 : !torch.int\n"
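
Both dtype functions unpack the input's (rank, dtype) pair, raise an AssertionError for integer input dtypes, and otherwise propagate the input dtype unchanged (three copies of it for native_group_norm). As a rough eager-mode cross-check of that rule (assuming a stock PyTorch install; not part of this patch):

import torch

x = torch.rand(2, 4, 6, 7, dtype=torch.float64)

y = torch.ops.aten.group_norm(x, 2, None, None, 1e-5, False)
assert y.dtype == torch.float64  # output dtype follows the input dtype

out, mean, rstd = torch.ops.aten.native_group_norm(
    x, None, None, 2, 4, 6 * 7, 2, 1e-5)
assert out.dtype == mean.dtype == rstd.dtype == torch.float64
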
@@ -3753,6 +3753,165 @@ class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern<OpTy> {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenGroupNormOp : public OpRewritePattern<AtenGroupNormOp> {
+  using OpRewritePattern<AtenGroupNormOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenGroupNormOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = op.getContext();
+
+    Value input = op.getInput();
+    Value weight = op.getWeight();
+    Value bias = op.getBias();
+    Value numGroups = op.getNumGroups();
+    Value eps = op.getEps();
+
+    Value cstZero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    Value cstOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
+
+    Value N = rewriter.create<AtenSizeIntOp>(loc, input, cstZero);
+    Value C = rewriter.create<AtenSizeIntOp>(loc, input, cstOne);
+    Value numElements = rewriter.create<AtenNumelOp>(loc, input);
+    Value numElementsDivN =
+        rewriter.create<AtenFloordivIntOp>(loc, numElements, N);
+    Value HxW = rewriter.create<AtenFloordivIntOp>(loc, numElementsDivN, C);
+
+    AtenNativeGroupNormOp newOp = rewriter.create<AtenNativeGroupNormOp>(
+        loc, ArrayRef<Type>{op.getResult().getType(), baseType, baseType},
+        input, weight, bias, N, C, HxW, numGroups, eps);
+
+    rewriter.replaceOp(op, newOp.getResult0());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class DecomposeAtenNativeGroupNormOp
+    : public OpRewritePattern<AtenNativeGroupNormOp> {
+  using OpRewritePattern<AtenNativeGroupNormOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNativeGroupNormOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = op.getContext();
+
+    Value input = op.getInput();
+    Value weight = op.getWeight();
+    Value bias = op.getBias();
+    Value numGroups = op.getGroup();
+    Value eps = op.getEps();
+
+    // Check the rank of the input/outputs tensor.
+    auto inputType = input.getType().cast<BaseTensorType>();
+    auto outputType = op.getResult0().getType().cast<BaseTensorType>();
+    auto meanType = op.getResult1().getType().cast<BaseTensorType>();
+    auto rsqrtVarType = op.getResult2().getType().cast<BaseTensorType>();
+    if (!inputType.hasSizes() || !outputType.hasSizes() ||
+        !meanType.hasSizes() || !rsqrtVarType.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "input/outputs tensor should have known sizes.");
+    }
+
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    Value cstZero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    Value cstOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value cstNegtiveOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
+    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
+    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+    auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
+
+    // GroupNorm requires the channel dimension (C) to be exactly divisible by
+    // the number of groups.
+    Value channel = rewriter.create<AtenSizeIntOp>(loc, input, cstOne);
+    Value remainder =
+        rewriter.create<AtenRemainderIntOp>(loc, channel, numGroups);
+    Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, remainder, cstZero);
+    rewriter.create<RuntimeAssertOp>(
+        loc, eqOrNot,
+        rewriter.getStringAttr("the number of channels must be divisible by "
+                               "the number of groups"));
+
+    // Reshape the input tensor to (N, numGroups, -1) to apply normalization.
+    SmallVector<Value> newShape;
+    newShape.push_back(rewriter.create<AtenSizeIntOp>(loc, input, cstZero));
+    newShape.push_back(numGroups);
+    newShape.push_back(cstNegtiveOne);
+    Value reshapedInput = rewriter.create<AtenViewOp>(
+        loc, baseType, input,
+        rewriter.create<PrimListConstructOp>(
+            loc, Torch::ListType::get(IntType::get(context)), newShape));
+
+    // Now we proceed with the normalization steps across the 'groupSize'
+    // Compute the mean and variance for each group
+    Value dimList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
+        ArrayRef<Value>{cstNegtiveOne});
+    auto mean = rewriter.create<AtenMeanDimOp>(
+        loc, baseType, reshapedInput, /*dims=*/dimList, /*keepdim=*/cstTrue,
+        /*dtype=*/none);
+    auto var = rewriter.create<AtenVarDimOp>(
+        loc, baseType, reshapedInput, /*dims=*/dimList, /*unbiased=*/cstFalse,
+        /*keepdim=*/cstTrue);
+
+    // Compute the normalized output: (input - mean) * rsqrt(var + eps)
+    auto varPlusEps = rewriter.create<AtenAddScalarOp>(loc, baseType, var, eps,
+                                                       /*alpha=*/cstOne);
+    auto invStd = rewriter.create<AtenRsqrtOp>(loc, baseType, varPlusEps);
+    auto inputSubMean = rewriter.create<AtenSubTensorOp>(
+        loc, baseType, reshapedInput, mean, /*alpha=*/cstOne);
+    auto normalizedOutput =
+        rewriter.create<AtenMulTensorOp>(loc, baseType, inputSubMean, invStd);
+
+    // Reshape normalized output back to the original input shape
+    auto inputShape = rewriter.create<AtenSizeOp>(
+        loc, Torch::ListType::get(IntType::get(context)), input);
+    auto reshapedOutput = rewriter.create<AtenViewOp>(
+        loc, inputType, normalizedOutput, /*shape=*/inputShape);
+
+    // Apply weight and bias if they are not None
+    // Reshape weight and bias to C,1,1,...
+    SmallVector<Value> viewShape = {channel};
+    for (unsigned i = 2; i < inputType.getSizes().size(); i++) {
+      viewShape.push_back(cstOne);
+    }
+    Value viewShapeSizeList = rewriter.create<PrimListConstructOp>(
+        loc, ListType::get(IntType::get(context)), viewShape);
+
+    Value groupNormOutput = reshapedOutput;
+    if (!weight.getType().isa<Torch::NoneType>()) {
+      auto weightReshaped = rewriter.create<AtenViewOp>(
+          loc, baseType, weight, /*shape=*/viewShapeSizeList);
+      groupNormOutput = rewriter.create<AtenMulTensorOp>(
+          loc, inputType, groupNormOutput, weightReshaped);
+    }
+    if (!bias.getType().isa<Torch::NoneType>()) {
+      auto biasReshaped = rewriter.create<AtenViewOp>(
+          loc, baseType, bias, /*shape=*/viewShapeSizeList);
+      groupNormOutput = rewriter.create<AtenAddTensorOp>(
+          loc, inputType, groupNormOutput, biasReshaped,
+          /*alpha=*/cstOne);
+    }
+
+    Value squeezedMean =
+        rewriter.create<AtenSqueezeDimOp>(loc, meanType, mean, cstNegtiveOne);
+    Value squeezedRsqrtVar = rewriter.create<AtenSqueezeDimOp>(
+        loc, rsqrtVarType, invStd, cstNegtiveOne);
+
+    rewriter.replaceOp(
+        op, ArrayRef<Value>{groupNormOutput, squeezedMean, squeezedRsqrtVar});
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenNativeBatchNormOp
     : public OpRewritePattern<AtenNativeBatchNormOp> {
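
The DecomposeAtenNativeGroupNormOp pattern above reshapes the input to (N, numGroups, -1), normalizes each group with (input - mean) * rsqrt(var + eps), reshapes back, and then applies the optional per-channel weight and bias. A hedged PyTorch restatement of the same arithmetic, useful for checking the rewrite by eye (a sketch, not the pattern's code; the helper name is made up):

import torch

def native_group_norm_ref(x, weight, bias, N, C, HxW, group, eps):
    # Normalize each of the `group` groups independently over its elements.
    xg = x.reshape(N, group, -1)
    mean = xg.mean(dim=-1, keepdim=True)
    var = xg.var(dim=-1, unbiased=False, keepdim=True)
    rstd = torch.rsqrt(var + eps)
    out = ((xg - mean) * rstd).reshape(x.shape)
    # Broadcast the optional affine parameters as (C, 1, 1, ...).
    view_shape = [C] + [1] * (x.dim() - 2)
    if weight is not None:
        out = out * weight.reshape(view_shape)
    if bias is not None:
        out = out + bias.reshape(view_shape)
    return out, mean.squeeze(-1), rstd.squeeze(-1)

x, w, b = torch.rand(2, 6, 2, 2), torch.rand(6), torch.rand(6)
expected = torch.ops.aten.native_group_norm(x, w, b, 2, 6, 4, 3, 1e-6)
actual = native_group_norm_ref(x, w, b, 2, 6, 4, 3, 1e-6)
for e, a in zip(expected, actual):
    torch.testing.assert_close(a, e)

The squeezed mean and rstd come out with the (N, group) shapes reported by the shape function above.
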
@@ -6204,6 +6363,8 @@ public:
         DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenGroupNormOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenNativeGroupNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeBatchNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>>(patterns);
@@ -407,6 +407,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenAddcdivOp>();
   target.addIllegalOp<AtenLayerNormOp>();
   target.addIllegalOp<AtenNativeLayerNormOp>();
+  target.addIllegalOp<AtenGroupNormOp>();
+  target.addIllegalOp<AtenNativeGroupNormOp>();
   target.addIllegalOp<AtenNativeBatchNormOp>();
   target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
   target.addIllegalOp<AtenConvolutionBackwardOp>();
@@ -306,6 +306,10 @@ TORCHDYNAMO_XFAIL_SET = {
 
     # ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4]))
     "ArangeStartOutViewModule_basic",
+
+    # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
+    "GroupNormModule_basic",
+    "GroupNormNoWeightAndBiasModule_basic",
 }
 
 TORCHDYNAMO_CRASHING_SET = {
@@ -586,6 +590,7 @@ STABLEHLO_PASS_SET = {
     "NewFullModuleInt2DStatic_basic",
     "NewFullModuleInt2D_basic",
     "NewFullModuleInt3D_basic",
+    "GroupNormModule_basic",
    "GatherStaticModule_basic",
     "GatherModule_basic",
     "Gather2DInputModdule_basic",
@@ -1130,6 +1130,12 @@ def aten〇convolution_backward〡shape(grad_output: List[int], input: List[int]
 def aten〇batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
     return upstream_shape_functions.batch_norm(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled)
 
+def aten〇group_norm〡shape(input: List[int], num_groups: int, weight: Optional[List[int]] = None, bias: Optional[List[int]] = None, eps: float = 1.0000000000000001e-05, cudnn_enabled: bool = True) -> List[int]:
+    return upstream_shape_functions.unary(input)
+
+def aten〇native_group_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[List[int], List[int], List[int]]:
+    return upstream_shape_functions.unary(input), [N, group], [N, group]
+
 def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
     return upstream_shape_functions.slice(self, dim, start, end, step)
 
@@ -1671,6 +1677,18 @@ def aten〇batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dty
     input_rank, input_dtype = input_rank_dtype
     return input_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], error_types={*all_integer_dtypes()}, num_groups=1))
+def aten〇group_norm〡dtype(input_rank_dtype: Tuple[int, int], num_groups: int, weight_rank_dtype: Optional[Tuple[int, int]] = None, bias_rank_dtype: Optional[Tuple[int, int]] = None, eps: float = 1.0000000000000001e-05, cudnn_enabled: bool = True) -> int:
+    input_rank, input_dtype = input_rank_dtype
+    assert not is_integer_dtype(input_dtype)
+    return input_dtype
+
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7), (3,), (3,)], error_types={*all_integer_dtypes()}, N=2, C=3, HxW=35, group=1, eps=0.000001))
+def aten〇native_group_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[int, int, int]:
+    input_rank, input_dtype = input_rank_dtype
+    assert not is_integer_dtype(input_dtype)
+    return input_dtype, input_dtype, input_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int:
     self_rank, self_dtype = self_rank_dtype
@@ -421,6 +421,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit(
         "aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
     )
+    emit(
+        'aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)'
+    )
     emit(
         "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
     )
@@ -10,7 +10,6 @@
 from torch_mlir._version import torch_version_for_comparison, version
 
 COMMON_TORCH_MLIR_LOWERING_XFAILS = {
-    "NativeGroupNormModule_basic",
     "NativeGroupNormBackwardModule_basic",
     "QuantizedMLP_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
@@ -243,6 +243,42 @@ def NativeBatchNormNoneWeightModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class GroupNormModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 4, 6, 7], torch.float32, True),
+        ([4], torch.float32, True),
+        ([4], torch.float32, True),
+    ])
+    def forward(self, x, weight, bias):
+        return torch.ops.aten.group_norm(x, 2, weight, bias, 1.0000000000000001e-05, False)
+
+@register_test_case(module_factory=lambda: GroupNormModule())
+def GroupNormModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 6, 7), tu.rand(4), tu.rand(4))
+
+class GroupNormNoWeightAndBiasModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 4, 6, 7], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.group_norm(x, 2, None, None, 1.0000000000000001e-05, False)
+
+@register_test_case(module_factory=lambda: GroupNormNoWeightAndBiasModule())
+def GroupNormNoWeightAndBiasModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 6, 7))
+
+# ==============================================================================
+
 class NativeGroupNormModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
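
The new test modules drive torch.ops.aten.group_norm through the e2e harness. Outside the harness, the same call can be sanity-checked against torch.nn.functional.group_norm (a rough check assuming a stock PyTorch install; not part of this patch):

import torch
import torch.nn.functional as F

x, w, b = torch.rand(2, 4, 6, 7), torch.rand(4), torch.rand(4)
torch.testing.assert_close(
    torch.ops.aten.group_norm(x, 2, w, b, 1e-5, False),
    F.group_norm(x, 2, w, b, eps=1e-5))
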
@@ -257,13 +293,15 @@ class NativeGroupNormModule(torch.nn.Module):
     def forward(self, x, weight, bias):
         return torch.ops.aten.native_group_norm(
             x, weight, bias,
-            2, 6, 4, 3, 0.000001);
+            2, 6, 4, 3, 0.000001)
 
 
 @register_test_case(module_factory=lambda: NativeGroupNormModule())
 def NativeGroupNormModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 6, 2, 2), tu.rand(6), tu.rand(6))
 
+
+# ==============================================================================
 
 class NativeGroupNormBackwardModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -280,7 +318,7 @@ class NativeGroupNormBackwardModule(torch.nn.Module):
     def forward(self, grad_out, x, mean, rstd, weight):
         return torch.ops.aten.native_group_norm_backward(
             grad_out, x, mean, rstd, weight,
-            2, 6, 4, 3, [True, True, True]);
+            2, 6, 4, 3, [True, True, True])
 
 
 @register_test_case(module_factory=lambda: NativeGroupNormBackwardModule())
@@ -450,3 +488,4 @@ class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: LayerNormNormalizeOverAllDimsModule())
 def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 2, 3))
+