[LINALG] Add value tensor variant to `bernoulli_.float` (#597)

This commit adds the op `PseudoAtenBernoulliFloatOp`, which represents
`AtenBernoulli_FloatOp` without the trailing underscore. It is needed so
that the `ReduceOpVariants` pass can turn the in-place op into an op
that takes value tensors as inputs; otherwise, the
`MaximizeValueSemantics` pass is unable to add value semantics
correctly.
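
Concretely, after this change `ReduceOpVariants` rewrites the in-place op on a
non-value tensor into the new value-semantics pseudo op followed by an explicit
copy back into the original tensor, as exercised by the regression test added
at the end of this commit (value names are illustrative):

// Before: in-place update of a non-value tensor.
%ret = torch.aten.bernoulli_.float %t, %p, %generator : !torch.tensor, !torch.float, !torch.none -> !torch.tensor

// After ReduceOpVariants: a pure op on value tensors, plus a copy-back.
%t_vtensor = torch.copy.to_vtensor %t : !torch.vtensor
%vret = torch.pseudo.aten.bernoulli.float %t_vtensor, %p, %generator : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor
%new = torch.copy.to_tensor %vret : !torch.tensor
%new_vtensor = torch.copy.to_vtensor %new : !torch.vtensor
torch.overwrite.tensor %new_vtensor overwrites %t : !torch.vtensor, !torch.tensor
// Uses of %ret are then replaced with %t.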
Ramiro Leal-Cavazos 2022-02-14 18:58:48 -08:00 committed by GitHub
parent dfc07d11d7
commit 413e6000d2
5 changed files with 55 additions and 11 deletions


@@ -931,6 +931,24 @@ def Torch_PseudoAtenUniformOp: Torch_Op<"pseudo.aten.uniform", [
   let assemblyFormat = "$self `,` $from `,` $to `,` $generator attr-dict `:` type($self) `,` type($from) `,` type($to) `,` type($generator) `->` type($result)";
 }
 
+// The without-underscore variant corresponding to `torch.aten.bernoulli_.float`
+// doesn't exist in the PyTorch ops registry. Add it here.
+def Torch_PseudoAtenBernoulliFloatOp: Torch_Op<"pseudo.aten.bernoulli.float", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+  ]> {
+  let summary = "`bernoulli.float op : (Tensor, float, Generator?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_FloatType:$p,
+    TorchOptionalGeneratorType:$generator
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` type($self) `,` type($p) `,` type($generator) `->` type($result)";
+}
+
 // To handle runtime assertions, torchscript provides us `torch._assert` operation.
 // But TS compiler introduces control flow for `torch._assert` operation. The
 // `torch._assert` would introduce control flow like:
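
For reference, the `assemblyFormat` declared above prints the new op in the
following form (a minimal sketch; the value names are illustrative):

%result = torch.pseudo.aten.bernoulli.float %self, %p, %generator : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor

The `HasValueSemantics` trait is the key property here: it marks the op as a
pure tensor computation, which is what the dynamic-legality check in
`ReduceOpVariants` and the `MaximizeValueSemantics` pass rely on.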


@@ -778,11 +778,11 @@ public:
 } // namespace
 
 namespace {
-class DecomposeAtenBernoulli_FloatOp
-    : public OpRewritePattern<AtenBernoulli_FloatOp> {
+class DecomposePseudoAtenBernoulliFloatOp
+    : public OpRewritePattern<PseudoAtenBernoulliFloatOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenBernoulli_FloatOp op,
+  LogicalResult matchAndRewrite(PseudoAtenBernoulliFloatOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value self = op.self();
@@ -1155,8 +1155,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp<Aten_UnsafeViewOp>();
     patterns.add<DecomposeAtenBernoulliOp>(context);
     target.addIllegalOp<AtenBernoulliOp>();
-    patterns.add<DecomposeAtenBernoulli_FloatOp>(context);
-    target.addIllegalOp<AtenBernoulli_FloatOp>();
+    patterns.add<DecomposePseudoAtenBernoulliFloatOp>(context);
+    target.addIllegalOp<PseudoAtenBernoulliFloatOp>();
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
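
The pattern body is elided in this hunk. A plausible decomposition, given the
`PseudoAtenUniformOp` plumbing already present in this file, is to sample
uniformly in [0, 1) and threshold against `p`; the sketch below is an
assumption about that approach, not necessarily the exact op sequence this
pass emits:

// Hypothetical lowering: bernoulli(self, p) ~ (uniform(self, 0, 1) < p),
// then cast back to self's dtype.
%f0 = torch.constant.float 0.000000e+00
%f1 = torch.constant.float 1.000000e+00
%rand = torch.pseudo.aten.uniform %self, %f0, %f1, %generator : !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
%mask = torch.aten.lt.Scalar %rand, %p : !torch.vtensor, !torch.float -> !torch.vtensor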


@@ -126,14 +126,21 @@ public:
       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (!isa<AtenUniform_Op>(op))
+    Location loc = op->getLoc();
+    Operation *newOp;
+    if (isa<AtenUniform_Op>(op)) {
+      newOp = rewriter.create<PseudoAtenUniformOp>(loc, op->getResultTypes(),
+                                                   op->getOperands());
+    } else if (isa<AtenBernoulli_FloatOp>(op)) {
+      newOp = rewriter.create<PseudoAtenBernoulliFloatOp>(
+          loc, op->getResultTypes(), op->getOperands());
+    } else {
       return failure();
+    }
-    Operation *newOp = rewriter.create<PseudoAtenUniformOp>(
-        op->getLoc(), op->getResultTypes(), op->getOperands());
 
     auto tensor =
-        rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
-    rewriter.create<OverwriteTensorOp>(op->getLoc(), tensor, op->getOperand(0));
+        rewriter.create<CopyToValueTensorOp>(loc, newOp->getResult(0));
+    rewriter.create<OverwriteTensorOp>(loc, tensor, op->getOperand(0));
     rewriter.replaceOp(op, op->getOperand(0));
     return success();
   }
@@ -202,6 +209,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
     ConversionTarget target(*context);
     target.addIllegalOp<NonValueTensorLiteralOp>();
     target.addIllegalOp<AtenUniform_Op>();
+    target.addIllegalOp<AtenBernoulli_FloatOp>();
     target.markUnknownOpDynamicallyLegal([](Operation *op) {
       if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
         auto hasValueSemantics = [](Type t) {


@@ -228,7 +228,8 @@ public:
             AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
             AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
             AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
-            AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp>(op)) {
+            AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp,
+            PseudoAtenBernoulliFloatOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
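
Adding `PseudoAtenBernoulliFloatOp` to this list makes the analysis join the
result's lattice element with that of operand 0, so shape and dtype knowledge
propagates from the input tensor to the result, e.g. (illustrative types):

%out = torch.pseudo.aten.bernoulli.float %t, %p, %none : !torch.vtensor<[2,3],f32>, !torch.float, !torch.none -> !torch.vtensor<[2,3],f32>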


@@ -144,3 +144,20 @@ func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.fl
   %ret = torch.aten.uniform_ %t, %min, %max, %generator: !torch.tensor, !torch.float, !torch.float, !torch.none -> !torch.tensor
   return %ret : !torch.tensor
 }
+
+// CHECK-LABEL: func @torch.aten.bernoulli_.float(
+// CHECK-SAME:    %[[T:.*]]: !torch.tensor) -> !torch.tensor {
+// CHECK:         %[[GENERATOR:.*]] = torch.constant.none
+// CHECK:         %[[P:.*]] = torch.constant.float 5.000000e-01
+// CHECK:         %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor
+// CHECK:         %[[VRET:.*]] = torch.pseudo.aten.bernoulli.float %[[T_VTENSOR]], %[[P]], %[[GENERATOR]] : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor
+// CHECK:         %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
+// CHECK:         %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
+// CHECK:         torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
+// CHECK:         return %[[T]] : !torch.tensor
+func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
+  %generator = torch.constant.none
+  %p = torch.constant.float 5.000000e-01
+  %ret = torch.aten.bernoulli_.float %t, %p, %generator : !torch.tensor, !torch.float, !torch.none -> !torch.tensor
+  return %ret : !torch.tensor
+}
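
The new test case presumably runs under the file's existing RUN line; assuming
the standard pass flag registered for `ReduceOpVariants`, that would be
something like:

// RUN: torch-mlir-opt <%s -torch-reduce-op-variants | FileCheck %s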