From dd998fa4d4163af14519fed436f29e82e72673ae Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Tue, 15 Feb 2022 02:09:08 +0000 Subject: [PATCH] [LINALG] Add value tensor variant to `bernoulli_.float` This commit adds the op `PseudoAtenBernoulliFloatOp` that represents `AtenBernoulli_FloatOp` without the underscore. This is needed to make sure that the `ReduceOpVariants` pass turns the in-place op into an op that takes value tensors as inputs, otherwise the `MaximizeValueSemantics` pass will not be able to add value semantics correctly. --- .../torch-mlir/Dialect/Torch/IR/TorchOps.td | 18 ++++++++++++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 10 +++++----- .../Torch/Transforms/ReduceOpVariants.cpp | 18 +++++++++++++----- lib/Dialect/Torch/Transforms/RefineTypes.cpp | 3 ++- test/Dialect/Torch/reduce-op-variants.mlir | 17 +++++++++++++++++ 5 files changed, 55 insertions(+), 11 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index a07b9e4ef..fd8957915 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -931,6 +931,24 @@ def Torch_PseudoAtenUniformOp: Torch_Op<"pseudo.aten.uniform", [ let assemblyFormat = "$self `,` $from `,` $to `,` $generator attr-dict `:` type($self) `,` type($from) `,` type($to) `,` type($generator) `->` type($result)"; } +// The corresponding without underscore variant for `torch.aten.bernoulli_.float` +// doesn't exist in the pytorch ops registry. Add it here. +def Torch_PseudoAtenBernoulliFloatOp: Torch_Op<"pseudo.aten.bernoulli.float", [ + AllowsTypeRefinement, + HasValueSemantics, + ]> { + let summary = "`bernoulli.float op : (Tensor, float, Generator?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_FloatType:$p, + TorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` type($self) `,` type($p) `,` type($generator) `->` type($result)"; +} + // To handle runtime assertions, torchscript provides us `torch._assert` operation. // But TS compiler introduces control flow for `torch._assert` operation. The // `torch._assert` would introduce control flow like: diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 52f47b6bc..4b51c8ebb 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -778,11 +778,11 @@ public: } // namespace namespace { -class DecomposeAtenBernoulli_FloatOp - : public OpRewritePattern { +class DecomposePseudoAtenBernoulliFloatOp + : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenBernoulli_FloatOp op, + LogicalResult matchAndRewrite(PseudoAtenBernoulliFloatOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.self(); @@ -1155,8 +1155,8 @@ class DecomposeComplexOpsPass target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index c67bb5919..3f5ba2f91 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -126,14 +126,21 @@ public: : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - if (!isa(op)) + Location loc = op->getLoc(); + Operation *newOp; + if (isa(op)) { + newOp = rewriter.create(loc, op->getResultTypes(), + op->getOperands()); + } else if (isa(op)) { + newOp = rewriter.create( + loc, op->getResultTypes(), op->getOperands()); + } else { return failure(); + } - Operation *newOp = rewriter.create( - op->getLoc(), op->getResultTypes(), op->getOperands()); auto tensor = - rewriter.create(op->getLoc(), newOp->getResult(0)); - rewriter.create(op->getLoc(), tensor, op->getOperand(0)); + rewriter.create(loc, newOp->getResult(0)); + rewriter.create(loc, tensor, op->getOperand(0)); rewriter.replaceOp(op, op->getOperand(0)); return success(); } @@ -202,6 +209,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase { ConversionTarget target(*context); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *op) { if (op->hasTrait()) { auto hasValueSemantics = [](Type t) { diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 56442ca61..aa1b8c45d 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -228,7 +228,8 @@ public: AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp, AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp, AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp, - AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp>(op)) { + AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp, + PseudoAtenBernoulliFloatOp>(op)) { return getLatticeElement(op->getResult(0)).join(*operands[0]); } diff --git a/test/Dialect/Torch/reduce-op-variants.mlir b/test/Dialect/Torch/reduce-op-variants.mlir index d7ed784d3..49881fd2e 100644 --- a/test/Dialect/Torch/reduce-op-variants.mlir +++ b/test/Dialect/Torch/reduce-op-variants.mlir @@ -144,3 +144,20 @@ func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.fl %ret = torch.aten.uniform_ %t, %min, %max, %generator: !torch.tensor, !torch.float, !torch.float, !torch.none -> !torch.tensor return %ret : !torch.tensor } + +// CHECK-LABEL: func @torch.aten.bernoulli_.float( +// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor { +// CHECK: %[[GENERATOR:.*]] = torch.constant.none +// CHECK: %[[P:.*]] = torch.constant.float 5.000000e-01 +// CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor +// CHECK: %[[VRET:.*]] = torch.pseudo.aten.bernoulli.float %[[T_VTENSOR]], %[[P]], %[[GENERATOR]] : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor +// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor +// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor +// CHECK: torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor +// CHECK: return %[[T]] : !torch.tensor +func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor { + %generator = torch.constant.none + %p = torch.constant.float 5.000000e-01 + %ret = torch.aten.bernoulli_.float %t, %p, %generator : !torch.tensor, !torch.float, !torch.none -> !torch.tensor + return %ret : !torch.tensor +}