Implement lowering of torch.aten.exponential (#2680)

https://github.com/llvm/torch-mlir/issues/2646

Decompose aten.exponential() into: -exp(1-x)/lambda
pull/2708/head snapshot-20231228.1066
Sungsoon Cho 2023-12-27 20:33:18 -08:00 committed by GitHub
parent d560698e3d
commit 8e389ff2ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 111 additions and 0 deletions

View File

@ -4739,6 +4739,31 @@ def Torch_AtenBernoulliPOp : Torch_Op<"aten.bernoulli.p", [
}];
}
def Torch_AtenExponentialOp : Torch_Op<"aten.exponential", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::exponential : (Tensor, float, Generator?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_FloatType:$lambd,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenExponentialOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenExponentialOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenMultinomialOp : Torch_Op<"aten.multinomial", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -7580,6 +7580,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.uniform\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.exponential\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.rand\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
@ -9382,6 +9385,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.exponential\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rand\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"

View File

@ -3562,6 +3562,51 @@ public:
};
} // namespace
namespace {
// Decompose exponential() to do inverse transform sampling.
// - https://en.wikipedia.org/wiki/Inverse_transform_sampling
// With the exponential distribution, F(x) = 1 - exp(-lambda * x). Thus,
// exponential() = - ln(1 - uniform(0, 1)) / lambda.
class DecomposeAtenExponentialOp : public OpRewritePattern<AtenExponentialOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenExponentialOp op,
PatternRewriter &rewriter) const override {
if (!op.getGenerator().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");
Location loc = op.getLoc();
Type resultType = op.getType();
// Create a uniform random op with low and high set to 0.0 and 1.0,
// respectively.
Value none = rewriter.create<ConstantNoneOp>(loc);
Value zero =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
Value one =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
loc, resultType, op.getSelf(), zero, /*dtype=*/none, /*layout=*/none,
/*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none);
Value x = rewriter.create<AtenUniformOp>(loc, resultType, emptyTensor,
/*from=*/zero, /*to=*/one,
/*generator=*/none);
Value negX = rewriter.create<AtenNegOp>(loc, resultType, x);
Value oneMinusX =
rewriter.create<AtenAddScalarOp>(loc, resultType, negX, one,
/*alpha=*/one);
Value lnOneMinusX = rewriter.create<AtenLogOp>(loc, resultType, oneMinusX);
Value negLambda = rewriter.create<AtenNegFloatOp>(loc, op.getLambd());
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, resultType, lnOneMinusX,
negLambda);
return success();
}
};
} // namespace
namespace {
template <typename OpTy, typename T1T2Op>
class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
@ -6410,6 +6455,7 @@ public:
addPatternIfTargetOpIsIllegal<
DecomposeAtenBernoulliLikeOp<AtenBernoulliPOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenExponentialOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);

View File

@ -427,6 +427,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
target.addIllegalOp<AtenBernoulliPOp>();
target.addIllegalOp<AtenBernoulliTensorOp>();
target.addIllegalOp<AtenExponentialOp>();
target.addIllegalOp<AtenZeroOp>();
target.addIllegalOp<AtenEyeOp>();
target.addIllegalOp<AtenEyeMOp>();

View File

@ -1397,6 +1397,7 @@ LTC_XFAIL_SET = {
"CeilFloatModule_basic",
"DivFloatModule_basic",
"EqIntModule_basic",
"ExponentialModule_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
"GeIntModule_basic",

View File

@ -831,6 +831,9 @@ def atencopy〡shape(self: List[int], src: List[int], non_blocking: bool = Fa
def atenuniform〡shape(self: List[int], from_: float = 0., to: float = 1., generator: Any = None) -> List[int]:
return self
def atenexponential〡shape(self: List[int], lambd: float = 1., generator: Any = None) -> List[int]:
return self
def atenrand〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return size
@ -2267,6 +2270,10 @@ def atenuniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0.,
self_rank, self_dtype = self_rank_dtype
return self_dtype
def atenexponential〡dtype(self_rank_dtype: Tuple[int, int], lambd: float = 1., generator: Any = None) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function([Invocation([1]),
Invocation([1], dtype=torch.float16),
Invocation([1], dtype=torch.complex64)])

View File

@ -378,6 +378,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)")
emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)")
emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)")
emit("aten::exponential : (Tensor, float, Generator?) -> (Tensor)")
emit("aten::multinomial : (Tensor, int, bool, Generator?) -> (Tensor)")
emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)")

View File

@ -157,6 +157,29 @@ def UniformNoCorrelationModule_basic(module, tu: TestUtils):
# ==============================================================================
class ExponentialModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, x):
a = torch.ops.aten.exponential(x, 3.0)
mean = torch.mean(a)
std = torch.std(a)
return mean, std
@register_test_case(module_factory=lambda: ExponentialModule())
def ExponentialModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(512, 512, 16).double())
# ==============================================================================
class BernoulliModule(torch.nn.Module):
def __init__(self):
super().__init__()