mirror of https://github.com/llvm/torch-mlir
Implement lowering of torch.aten.exponential (#2680)
https://github.com/llvm/torch-mlir/issues/2646 Decompose aten.exponential() into: -exp(1-x)/lambdapull/2708/head snapshot-20231228.1066
parent
d560698e3d
commit
8e389ff2ff
|
@ -4739,6 +4739,31 @@ def Torch_AtenBernoulliPOp : Torch_Op<"aten.bernoulli.p", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenExponentialOp : Torch_Op<"aten.exponential", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::exponential : (Tensor, float, Generator?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_FloatType:$lambd,
|
||||
AnyTorchOptionalGeneratorType:$generator
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenExponentialOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenExponentialOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenMultinomialOp : Torch_Op<"aten.multinomial", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -7580,6 +7580,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" func.func @\"__torch_mlir_shape_fn.aten.uniform\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.exponential\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.rand\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
|
@ -9382,6 +9385,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.exponential\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.rand\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %none = torch.constant.none\n"
|
||||
|
|
|
@ -3562,6 +3562,51 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose exponential() to do inverse transform sampling.
|
||||
// - https://en.wikipedia.org/wiki/Inverse_transform_sampling
|
||||
// With the exponential distribution, F(x) = 1 - exp(-lambda * x). Thus,
|
||||
// exponential() = - ln(1 - uniform(0, 1)) / lambda.
|
||||
class DecomposeAtenExponentialOp : public OpRewritePattern<AtenExponentialOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenExponentialOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!op.getGenerator().getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "The generator has to be None because only global default "
|
||||
"generator is supported");
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Type resultType = op.getType();
|
||||
|
||||
// Create a uniform random op with low and high set to 0.0 and 1.0,
|
||||
// respectively.
|
||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||
Value zero =
|
||||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
|
||||
Value one =
|
||||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
||||
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
|
||||
loc, resultType, op.getSelf(), zero, /*dtype=*/none, /*layout=*/none,
|
||||
/*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none);
|
||||
Value x = rewriter.create<AtenUniformOp>(loc, resultType, emptyTensor,
|
||||
/*from=*/zero, /*to=*/one,
|
||||
/*generator=*/none);
|
||||
|
||||
Value negX = rewriter.create<AtenNegOp>(loc, resultType, x);
|
||||
Value oneMinusX =
|
||||
rewriter.create<AtenAddScalarOp>(loc, resultType, negX, one,
|
||||
/*alpha=*/one);
|
||||
Value lnOneMinusX = rewriter.create<AtenLogOp>(loc, resultType, oneMinusX);
|
||||
Value negLambda = rewriter.create<AtenNegFloatOp>(loc, op.getLambd());
|
||||
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, resultType, lnOneMinusX,
|
||||
negLambda);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
template <typename OpTy, typename T1T2Op>
|
||||
class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
|
||||
|
@ -6410,6 +6455,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenBernoulliLikeOp<AtenBernoulliPOp>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenExponentialOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);
|
||||
|
|
|
@ -427,6 +427,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
|
||||
target.addIllegalOp<AtenBernoulliPOp>();
|
||||
target.addIllegalOp<AtenBernoulliTensorOp>();
|
||||
target.addIllegalOp<AtenExponentialOp>();
|
||||
target.addIllegalOp<AtenZeroOp>();
|
||||
target.addIllegalOp<AtenEyeOp>();
|
||||
target.addIllegalOp<AtenEyeMOp>();
|
||||
|
|
|
@ -1397,6 +1397,7 @@ LTC_XFAIL_SET = {
|
|||
"CeilFloatModule_basic",
|
||||
"DivFloatModule_basic",
|
||||
"EqIntModule_basic",
|
||||
"ExponentialModule_basic",
|
||||
"GeFloatIntModule_basic",
|
||||
"GeFloatModule_basic",
|
||||
"GeIntModule_basic",
|
||||
|
|
|
@ -831,6 +831,9 @@ def aten〇copy〡shape(self: List[int], src: List[int], non_blocking: bool = Fa
|
|||
def aten〇uniform〡shape(self: List[int], from_: float = 0., to: float = 1., generator: Any = None) -> List[int]:
|
||||
return self
|
||||
|
||||
def aten〇exponential〡shape(self: List[int], lambd: float = 1., generator: Any = None) -> List[int]:
|
||||
return self
|
||||
|
||||
def aten〇rand〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||
return size
|
||||
|
||||
|
@ -2267,6 +2270,10 @@ def aten〇uniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0.,
|
|||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
def aten〇exponential〡dtype(self_rank_dtype: Tuple[int, int], lambd: float = 1., generator: Any = None) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function([Invocation([1]),
|
||||
Invocation([1], dtype=torch.float16),
|
||||
Invocation([1], dtype=torch.complex64)])
|
||||
|
|
|
@ -378,6 +378,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)")
|
||||
emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)")
|
||||
emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)")
|
||||
emit("aten::exponential : (Tensor, float, Generator?) -> (Tensor)")
|
||||
emit("aten::multinomial : (Tensor, int, bool, Generator?) -> (Tensor)")
|
||||
emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
|
|
|
@ -157,6 +157,29 @@ def UniformNoCorrelationModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ExponentialModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
a = torch.ops.aten.exponential(x, 3.0)
|
||||
mean = torch.mean(a)
|
||||
std = torch.std(a)
|
||||
return mean, std
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ExponentialModule())
|
||||
def ExponentialModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(512, 512, 16).double())
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class BernoulliModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue