mirror of https://github.com/llvm/torch-mlir
[LINALG] Add value tensor variant to `bernoulli_.float`
This commit adds the op `PseudoAtenBernoulliFloatOp` that represents `AtenBernoulli_FloatOp` without the underscore. This is needed to make sure that the `ReduceOpVariants` pass turns the in-place op into an op that takes value tensors as inputs, otherwise the `MaximizeValueSemantics` pass will not be able to add value semantics correctly.bert-staging
parent
69d872b298
commit
54357ea378
|
@ -931,6 +931,24 @@ def Torch_PseudoAtenUniformOp: Torch_Op<"pseudo.aten.uniform", [
|
|||
let assemblyFormat = "$self `,` $from `,` $to `,` $generator attr-dict `:` type($self) `,` type($from) `,` type($to) `,` type($generator) `->` type($result)";
|
||||
}
|
||||
|
||||
// The corresponding without underscore variant for `torch.aten.bernoulli_.float`
|
||||
// doesn't exist in the pytorch ops registry. Add it here.
|
||||
def Torch_PseudoAtenBernoulliFloatOp: Torch_Op<"pseudo.aten.bernoulli.float", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
]> {
|
||||
let summary = "`bernoulli.float op : (Tensor, float, Generator?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_FloatType:$p,
|
||||
TorchOptionalGeneratorType:$generator
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` type($self) `,` type($p) `,` type($generator) `->` type($result)";
|
||||
}
|
||||
|
||||
// To handle runtime assertions, torchscript provides us `torch._assert` operation.
|
||||
// But TS compiler introduces control flow for `torch._assert` operation. The
|
||||
// `torch._assert` would introduce control flow like:
|
||||
|
|
|
@ -770,11 +770,11 @@ public:
|
|||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenBernoulli_FloatOp
|
||||
: public OpRewritePattern<AtenBernoulli_FloatOp> {
|
||||
class DecomposePseudoAtenBernoulliFloatOp
|
||||
: public OpRewritePattern<PseudoAtenBernoulliFloatOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenBernoulli_FloatOp op,
|
||||
LogicalResult matchAndRewrite(PseudoAtenBernoulliFloatOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value self = op.self();
|
||||
|
@ -1148,8 +1148,8 @@ class DecomposeComplexOpsPass
|
|||
target.addIllegalOp<Aten_UnsafeViewOp>();
|
||||
patterns.add<DecomposeAtenBernoulliOp>(context);
|
||||
target.addIllegalOp<AtenBernoulliOp>();
|
||||
patterns.add<DecomposeAtenBernoulli_FloatOp>(context);
|
||||
target.addIllegalOp<AtenBernoulli_FloatOp>();
|
||||
patterns.add<DecomposePseudoAtenBernoulliFloatOp>(context);
|
||||
target.addIllegalOp<PseudoAtenBernoulliFloatOp>();
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns)))) {
|
||||
|
|
|
@ -126,14 +126,21 @@ public:
|
|||
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!isa<AtenUniform_Op>(op))
|
||||
Location loc = op->getLoc();
|
||||
Operation *newOp;
|
||||
if (isa<AtenUniform_Op>(op)) {
|
||||
newOp = rewriter.create<PseudoAtenUniformOp>(loc, op->getResultTypes(),
|
||||
op->getOperands());
|
||||
} else if (isa<AtenBernoulli_FloatOp>(op)) {
|
||||
newOp = rewriter.create<PseudoAtenBernoulliFloatOp>(
|
||||
loc, op->getResultTypes(), op->getOperands());
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Operation *newOp = rewriter.create<PseudoAtenUniformOp>(
|
||||
op->getLoc(), op->getResultTypes(), op->getOperands());
|
||||
auto tensor =
|
||||
rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
|
||||
rewriter.create<OverwriteTensorOp>(op->getLoc(), tensor, op->getOperand(0));
|
||||
rewriter.create<CopyToValueTensorOp>(loc, newOp->getResult(0));
|
||||
rewriter.create<OverwriteTensorOp>(loc, tensor, op->getOperand(0));
|
||||
rewriter.replaceOp(op, op->getOperand(0));
|
||||
return success();
|
||||
}
|
||||
|
@ -202,6 +209,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
|||
ConversionTarget target(*context);
|
||||
target.addIllegalOp<NonValueTensorLiteralOp>();
|
||||
target.addIllegalOp<AtenUniform_Op>();
|
||||
target.addIllegalOp<AtenBernoulli_FloatOp>();
|
||||
target.markUnknownOpDynamicallyLegal([](Operation *op) {
|
||||
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
|
||||
auto hasValueSemantics = [](Type t) {
|
||||
|
|
|
@ -228,7 +228,8 @@ public:
|
|||
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
|
||||
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
|
||||
AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
|
||||
AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp>(op)) {
|
||||
AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp,
|
||||
PseudoAtenBernoulliFloatOp>(op)) {
|
||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||
}
|
||||
|
||||
|
|
|
@ -144,3 +144,20 @@ func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.fl
|
|||
%ret = torch.aten.uniform_ %t, %min, %max, %generator: !torch.tensor, !torch.float, !torch.float, !torch.none -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.bernoulli_.float(
|
||||
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor {
|
||||
// CHECK: %[[GENERATOR:.*]] = torch.constant.none
|
||||
// CHECK: %[[P:.*]] = torch.constant.float 5.000000e-01
|
||||
// CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor
|
||||
// CHECK: %[[VRET:.*]] = torch.pseudo.aten.bernoulli.float %[[T_VTENSOR]], %[[P]], %[[GENERATOR]] : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor
|
||||
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
|
||||
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
|
||||
// CHECK: torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
|
||||
// CHECK: return %[[T]] : !torch.tensor
|
||||
func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
|
||||
%generator = torch.constant.none
|
||||
%p = torch.constant.float 5.000000e-01
|
||||
%ret = torch.aten.bernoulli_.float %t, %p, %generator : !torch.tensor, !torch.float, !torch.none -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue