Decompose AtenNormalFunctionalOp into AtenRandn* and other arithmetic. (#2737)

pull/2592/merge snapshot-20240116.1085
Sungsoon Cho 2024-01-15 22:49:29 -08:00 committed by GitHub
parent f85e5c932b
commit a8538e1e3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 85 additions and 2 deletions

View File

@ -7655,6 +7655,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.randn.generator\"(%arg0: !torch.list<int>, %arg1: !torch.any, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.normal_functional\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.arange.start_step\"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
" %0 = torch.derefine %arg0 : !torch.float to !torch.union<float, int>\n"
" %1 = torch.derefine %arg1 : !torch.float to !torch.union<float, int>\n"
@ -11557,6 +11560,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.normal_functional\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.randn.generator\"(%arg0: !torch.list<int>, %arg1: !torch.any, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int6 = torch.constant.int 6\n"

View File

@ -3669,9 +3669,38 @@ public:
return success();
}
};
} // namespace
namespace {
// aten.normal_functional(mean, sigma) = randn() * sigma + mean.
class DecomposeAtenNormalFunctionalOp
: public OpRewritePattern<AtenNormalFunctionalOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNormalFunctionalOp op,
PatternRewriter &rewriter) const override {
if (!op.getGenerator().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");
Location loc = op.getLoc();
Type resultType = op.getType();
Value std = op.getStd();
Value mean = op.getMean();
Value none = rewriter.create<ConstantNoneOp>(loc);
Value one =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
Value randN = rewriter.create<AtenRandnLikeOp>(
loc, resultType, op.getSelf(), /*dtype=*/none, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
Value stdRandN =
rewriter.create<AtenMulScalarOp>(loc, resultType, randN, std);
rewriter.replaceOpWithNewOp<AtenAddScalarOp>(op, resultType, stdRandN,
mean, /*alpha=*/one);
return success();
}
};
template <typename OpTy, typename T1T2Op>
class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
@ -6591,6 +6620,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnGeneratorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNormalFunctionalOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSeluOp>(patterns);

View File

@ -494,6 +494,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenRandnOp>();
target.addIllegalOp<AtenRandnGeneratorOp>();
target.addIllegalOp<AtenRandnLikeOp>();
target.addIllegalOp<AtenNormalFunctionalOp>();
target.addIllegalOp<AtenVarMeanOp>();
target.addIllegalOp<AtenCosineSimilarityOp>();
target.addIllegalOp<AtenNewEmptyStridedOp>();

View File

@ -1484,6 +1484,7 @@ LTC_XFAIL_SET = {
"VarMeanUnbiasedModule_basic",
"RandnLikeModule_basic",
"RandnLikeDtypeModule_basic",
"NormalFunctionalModule_basic",
"BernoulliFloatModule_basic",
"BernoulliModule_basic",
"BernoulliPModule_basic",

View File

@ -902,6 +902,9 @@ def atenrandn〡shape(size: List[int], dtype: Optional[int] = None, layout: O
def atenrandngenerator〡shape(size: List[int], generator: Any, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return size
def atennormal_functional〡shape(self: List[int], mean: float = 0., std: float = 1., generator: Any = None) -> List[int]:
return self
def atenarangestart_step〡shape(start: float, end: float, step: float = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return upstream_shape_functions.arange_start_step(start, end, step, dtype, layout, device, pin_memory)
@ -3822,6 +3825,16 @@ def atenrandn〡dtype(size: List[int], dtype: Optional[int] = None, layout: O
assert not is_integer_dtype(dtype)
return dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(
num_of_tensors=1,
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}))
def atennormal_functional〡dtype(self_rank_dtype: Tuple[int, int], mean: float = 0., std: float = 1., generator: Any = None) -> int:
self_rank, self_dtype = self_rank_dtype
if self_dtype is None:
return torch.float32
assert not is_integer_dtype(self_dtype)
return self_dtype
@check_dtype_function([Invocation(size=[1], generator=None),
Invocation(size=[1], generator=None, dtype=torch.float32),
ErrorInvocation(size=[1], generator=None, dtype=torch.int32),

View File

@ -605,3 +605,24 @@ class RandnLikeDtypeModule(torch.nn.Module):
@register_test_case(module_factory=lambda: RandnLikeDtypeModule())
def RandnLikeDtypeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(256, 1024).double())
# ==============================================================================
class NormalFunctionalModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float64, True),
])
def forward(self, x):
a = torch.ops.aten.normal_functional(x, mean=-5.0, std=2.0)
mean = torch.mean(a)
std = torch.std(a)
return mean, std
@register_test_case(module_factory=lambda: NormalFunctionalModule())
def NormalFunctionalModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2048, 4096).double())