mirror of https://github.com/llvm/torch-mlir
Decompose AtenNormalFunctionalOp into AtenRandn* and other arithmetic. (#2737)
parent
f85e5c932b
commit
a8538e1e3f
|
@ -7655,6 +7655,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" func.func @\"__torch_mlir_shape_fn.aten.randn.generator\"(%arg0: !torch.list<int>, %arg1: !torch.any, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.normal_functional\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.arange.start_step\"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||
" %0 = torch.derefine %arg0 : !torch.float to !torch.union<float, int>\n"
|
||||
" %1 = torch.derefine %arg1 : !torch.float to !torch.union<float, int>\n"
|
||||
|
@ -11557,6 +11560,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.normal_functional\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.randn.generator\"(%arg0: !torch.list<int>, %arg1: !torch.any, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
|
|
|
@ -3669,9 +3669,38 @@ public:
|
|||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// aten.normal_functional(mean, sigma) = randn() * sigma + mean.
|
||||
class DecomposeAtenNormalFunctionalOp
|
||||
: public OpRewritePattern<AtenNormalFunctionalOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenNormalFunctionalOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!op.getGenerator().getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "The generator has to be None because only global default "
|
||||
"generator is supported");
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Type resultType = op.getType();
|
||||
Value std = op.getStd();
|
||||
Value mean = op.getMean();
|
||||
|
||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||
Value one =
|
||||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
||||
Value randN = rewriter.create<AtenRandnLikeOp>(
|
||||
loc, resultType, op.getSelf(), /*dtype=*/none, /*layout=*/none,
|
||||
/*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
|
||||
Value stdRandN =
|
||||
rewriter.create<AtenMulScalarOp>(loc, resultType, randN, std);
|
||||
rewriter.replaceOpWithNewOp<AtenAddScalarOp>(op, resultType, stdRandN,
|
||||
mean, /*alpha=*/one);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpTy, typename T1T2Op>
|
||||
class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
|
||||
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||
|
@ -6591,6 +6620,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnGeneratorOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnLikeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNormalFunctionalOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEluOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSeluOp>(patterns);
|
||||
|
|
|
@ -494,6 +494,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenRandnOp>();
|
||||
target.addIllegalOp<AtenRandnGeneratorOp>();
|
||||
target.addIllegalOp<AtenRandnLikeOp>();
|
||||
target.addIllegalOp<AtenNormalFunctionalOp>();
|
||||
target.addIllegalOp<AtenVarMeanOp>();
|
||||
target.addIllegalOp<AtenCosineSimilarityOp>();
|
||||
target.addIllegalOp<AtenNewEmptyStridedOp>();
|
||||
|
|
|
@ -1484,6 +1484,7 @@ LTC_XFAIL_SET = {
|
|||
"VarMeanUnbiasedModule_basic",
|
||||
"RandnLikeModule_basic",
|
||||
"RandnLikeDtypeModule_basic",
|
||||
"NormalFunctionalModule_basic",
|
||||
"BernoulliFloatModule_basic",
|
||||
"BernoulliModule_basic",
|
||||
"BernoulliPModule_basic",
|
||||
|
|
|
@ -902,6 +902,9 @@ def aten〇randn〡shape(size: List[int], dtype: Optional[int] = None, layout: O
|
|||
def aten〇randn〇generator〡shape(size: List[int], generator: Any, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||
return size
|
||||
|
||||
def aten〇normal_functional〡shape(self: List[int], mean: float = 0., std: float = 1., generator: Any = None) -> List[int]:
|
||||
return self
|
||||
|
||||
def aten〇arange〇start_step〡shape(start: float, end: float, step: float = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||
return upstream_shape_functions.arange_start_step(start, end, step, dtype, layout, device, pin_memory)
|
||||
|
||||
|
@ -3822,6 +3825,16 @@ def aten〇randn〡dtype(size: List[int], dtype: Optional[int] = None, layout: O
|
|||
assert not is_integer_dtype(dtype)
|
||||
return dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(
|
||||
num_of_tensors=1,
|
||||
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}))
|
||||
def aten〇normal_functional〡dtype(self_rank_dtype: Tuple[int, int], mean: float = 0., std: float = 1., generator: Any = None) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
if self_dtype is None:
|
||||
return torch.float32
|
||||
assert not is_integer_dtype(self_dtype)
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function([Invocation(size=[1], generator=None),
|
||||
Invocation(size=[1], generator=None, dtype=torch.float32),
|
||||
ErrorInvocation(size=[1], generator=None, dtype=torch.int32),
|
||||
|
|
|
@ -605,3 +605,24 @@ class RandnLikeDtypeModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: RandnLikeDtypeModule())
|
||||
def RandnLikeDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(256, 1024).double())
|
||||
# ==============================================================================
|
||||
|
||||
class NormalFunctionalModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
a = torch.ops.aten.normal_functional(x, mean=-5.0, std=2.0)
|
||||
mean = torch.mean(a)
|
||||
std = torch.std(a)
|
||||
return mean, std
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NormalFunctionalModule())
|
||||
def NormalFunctionalModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2048, 4096).double())
|
||||
|
|
Loading…
Reference in New Issue