mirror of https://github.com/llvm/torch-mlir

Decompose AtenNormalFunctionalOp into AtenRandn* and other arithmetic. (#2737)

parent: f85e5c932b
commit: a8538e1e3f
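For context: the rewrite relies on the standard reparameterization of a Gaussian, normal_functional(self, mean, std) ~ randn_like(self) * std + mean. A minimal PyTorch sketch of that equivalence (illustrative only, not part of the commit; it calls the op the same way the new e2e test below does):

```python
import torch

x = torch.empty(2048, 4096, dtype=torch.float64)

# Draw from N(-5, 2^2) directly, and via the decomposed form.
a = torch.ops.aten.normal_functional(x, mean=-5.0, std=2.0)
b = torch.randn_like(x) * 2.0 + (-5.0)

# The two agree in distribution, not elementwise (the draws are independent):
# both sample means sit near -5.0 and both sample stds near 2.0.
print(a.mean().item(), a.std().item())
print(b.mean().item(), b.std().item())
```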
@@ -7655,6 +7655,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.randn.generator\"(%arg0: !torch.list<int>, %arg1: !torch.any, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.normal_functional\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.arange.start_step\"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
 "    %0 = torch.derefine %arg0 : !torch.float to !torch.union<float, int>\n"
 "    %1 = torch.derefine %arg1 : !torch.float to !torch.union<float, int>\n"
@@ -11557,6 +11560,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.normal_functional\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.randn.generator\"(%arg0: !torch.list<int>, %arg1: !torch.any, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"
 "    %int6 = torch.constant.int 6\n"
@@ -3669,9 +3669,38 @@ public:
     return success();
   }
 };
 } // namespace
 
+namespace {
+// aten.normal_functional(mean, sigma) = randn() * sigma + mean.
+class DecomposeAtenNormalFunctionalOp
+    : public OpRewritePattern<AtenNormalFunctionalOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNormalFunctionalOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.getGenerator().getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "The generator has to be None because only global default "
+              "generator is supported");
+
+    Location loc = op.getLoc();
+    Type resultType = op.getType();
+    Value std = op.getStd();
+    Value mean = op.getMean();
+
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    Value one =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+    Value randN = rewriter.create<AtenRandnLikeOp>(
+        loc, resultType, op.getSelf(), /*dtype=*/none, /*layout=*/none,
+        /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
+    Value stdRandN =
+        rewriter.create<AtenMulScalarOp>(loc, resultType, randN, std);
+    rewriter.replaceOpWithNewOp<AtenAddScalarOp>(op, resultType, stdRandN,
+                                                 mean, /*alpha=*/one);
+    return success();
+  }
+};
+} // namespace
+
 template <typename OpTy, typename T1T2Op>
 class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
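The pattern above emits three ops. In PyTorch terms the sequence is the sketch below; note that aten.add.Scalar computes self + alpha * other, which is why the rewrite materializes the constant one for /*alpha=*/. The function name is hypothetical, a restatement for illustration rather than code from the commit:

```python
import torch

def decomposed_normal_functional(self: torch.Tensor,
                                 mean: float = 0.0,
                                 std: float = 1.0) -> torch.Tensor:
    # AtenRandnLikeOp: standard-normal sample with self's shape and dtype.
    rand_n = torch.randn_like(self)
    # AtenMulScalarOp: scale by std.
    std_rand_n = torch.mul(rand_n, std)
    # AtenAddScalarOp with alpha=1: shift by mean (add computes t + alpha * s).
    return torch.add(std_rand_n, mean, alpha=1.0)
```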
@@ -6591,6 +6620,7 @@ public:
     addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRandnGeneratorOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRandnLikeOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenNormalFunctionalOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenEluOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSeluOp>(patterns);
@@ -494,6 +494,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenRandnOp>();
   target.addIllegalOp<AtenRandnGeneratorOp>();
   target.addIllegalOp<AtenRandnLikeOp>();
+  target.addIllegalOp<AtenNormalFunctionalOp>();
   target.addIllegalOp<AtenVarMeanOp>();
   target.addIllegalOp<AtenCosineSimilarityOp>();
   target.addIllegalOp<AtenNewEmptyStridedOp>();
@@ -1484,6 +1484,7 @@ LTC_XFAIL_SET = {
     "VarMeanUnbiasedModule_basic",
     "RandnLikeModule_basic",
     "RandnLikeDtypeModule_basic",
+    "NormalFunctionalModule_basic",
    "BernoulliFloatModule_basic",
     "BernoulliModule_basic",
     "BernoulliPModule_basic",
@@ -902,6 +902,9 @@ def aten〇randn〡shape(size: List[int], dtype: Optional[int] = None, layout: O
 def aten〇randn〇generator〡shape(size: List[int], generator: Any, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
     return size
 
+def aten〇normal_functional〡shape(self: List[int], mean: float = 0., std: float = 1., generator: Any = None) -> List[int]:
+    return self
+
 def aten〇arange〇start_step〡shape(start: float, end: float, step: float = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
     return upstream_shape_functions.arange_start_step(start, end, step, dtype, layout, device, pin_memory)
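These Python functions are the source from which the C++ string library in the first two hunks is generated (via torch-mlir's `build_tools/update_abstract_interp_lib.sh`). The new shape transfer function is the identity on the input shape, since normal_functional produces a tensor shaped like self. A standalone restatement, with a plain local name standing in for the 〇-mangled one above:

```python
from typing import Any, List

def normal_functional_shape(self: List[int], mean: float = 0.,
                            std: float = 1., generator: Any = None) -> List[int]:
    # The result aliases the input's shape list unchanged.
    return self

assert normal_functional_shape([2048, 4096]) == [2048, 4096]
```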
@@ -3822,6 +3825,16 @@ def aten〇randn〡dtype(size: List[int], dtype: Optional[int] = None, layout: O
     assert not is_integer_dtype(dtype)
     return dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(
+    num_of_tensors=1,
+    error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}))
+def aten〇normal_functional〡dtype(self_rank_dtype: Tuple[int, int], mean: float = 0., std: float = 1., generator: Any = None) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    if self_dtype is None:
+        return torch.float32
+    assert not is_integer_dtype(self_dtype)
+    return self_dtype
+
 @check_dtype_function([Invocation(size=[1], generator=None),
                        Invocation(size=[1], generator=None, dtype=torch.float32),
                        ErrorInvocation(size=[1], generator=None, dtype=torch.int32),
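The dtype rule keeps the input dtype and rejects integer and bool inputs, matching the error_types set in the decorator. A self-contained sketch of the same check, where is_int is a hypothetical stand-in for the library's is_integer_dtype helper:

```python
import torch

def is_int(dtype: torch.dtype) -> bool:
    return dtype in (torch.bool, torch.uint8, torch.int8, torch.int16,
                     torch.int32, torch.int64)

def normal_functional_dtype(self_dtype: torch.dtype) -> torch.dtype:
    # Sampling a Gaussian into an integer tensor is rejected up front.
    assert not is_int(self_dtype)
    return self_dtype

assert normal_functional_dtype(torch.float64) == torch.float64
```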
@@ -605,3 +605,24 @@ class RandnLikeDtypeModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: RandnLikeDtypeModule())
 def RandnLikeDtypeModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(256, 1024).double())
+
+# ==============================================================================
+
+class NormalFunctionalModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float64, True),
+    ])
+    def forward(self, x):
+        a = torch.ops.aten.normal_functional(x, mean=-5.0, std=2.0)
+        mean = torch.mean(a)
+        std = torch.std(a)
+        return mean, std
+
+
+@register_test_case(module_factory=lambda: NormalFunctionalModule())
+def NormalFunctionalModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2048, 4096).double())
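The e2e test returns distribution statistics rather than elementwise values, which is presumably why it draws 2048 x 4096 samples: the sample mean of N i.i.d. draws from N(-5, 2^2) has standard error std / sqrt(N), small enough at this N for a tolerance-based comparison against eager mode. A quick check of that arithmetic:

```python
import math

n = 2048 * 4096
print(2.0 / math.sqrt(n))  # standard error of the sample mean, ~0.0007
```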