mirror of https://github.com/llvm/torch-mlir
[LINALG] Add value tensor variant to `bernoulli_.float` (#597)
This commit adds the op `PseudoAtenBernoulliFloatOp` that represents `AtenBernoulli_FloatOp` without the underscore. This is needed to make sure that the `ReduceOpVariants` pass turns the in-place op into an op that takes value tensors as inputs, otherwise the `MaximizeValueSemantics` pass will not be able to add value semantics correctly.pull/595/head
parent
dfc07d11d7
commit
413e6000d2
|
@ -931,6 +931,24 @@ def Torch_PseudoAtenUniformOp: Torch_Op<"pseudo.aten.uniform", [
|
||||||
let assemblyFormat = "$self `,` $from `,` $to `,` $generator attr-dict `:` type($self) `,` type($from) `,` type($to) `,` type($generator) `->` type($result)";
|
let assemblyFormat = "$self `,` $from `,` $to `,` $generator attr-dict `:` type($self) `,` type($from) `,` type($to) `,` type($generator) `->` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The corresponding without underscore variant for `torch.aten.bernoulli_.float`
|
||||||
|
// doesn't exist in the pytorch ops registry. Add it here.
|
||||||
|
def Torch_PseudoAtenBernoulliFloatOp: Torch_Op<"pseudo.aten.bernoulli.float", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
]> {
|
||||||
|
let summary = "`bernoulli.float op : (Tensor, float, Generator?) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
Torch_FloatType:$p,
|
||||||
|
TorchOptionalGeneratorType:$generator
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` type($self) `,` type($p) `,` type($generator) `->` type($result)";
|
||||||
|
}
|
||||||
|
|
||||||
// To handle runtime assertions, torchscript provides us `torch._assert` operation.
|
// To handle runtime assertions, torchscript provides us `torch._assert` operation.
|
||||||
// But TS compiler introduces control flow for `torch._assert` operation. The
|
// But TS compiler introduces control flow for `torch._assert` operation. The
|
||||||
// `torch._assert` would introduce control flow like:
|
// `torch._assert` would introduce control flow like:
|
||||||
|
|
|
@ -778,11 +778,11 @@ public:
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeAtenBernoulli_FloatOp
|
class DecomposePseudoAtenBernoulliFloatOp
|
||||||
: public OpRewritePattern<AtenBernoulli_FloatOp> {
|
: public OpRewritePattern<PseudoAtenBernoulliFloatOp> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
LogicalResult matchAndRewrite(AtenBernoulli_FloatOp op,
|
LogicalResult matchAndRewrite(PseudoAtenBernoulliFloatOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value self = op.self();
|
Value self = op.self();
|
||||||
|
@ -1155,8 +1155,8 @@ class DecomposeComplexOpsPass
|
||||||
target.addIllegalOp<Aten_UnsafeViewOp>();
|
target.addIllegalOp<Aten_UnsafeViewOp>();
|
||||||
patterns.add<DecomposeAtenBernoulliOp>(context);
|
patterns.add<DecomposeAtenBernoulliOp>(context);
|
||||||
target.addIllegalOp<AtenBernoulliOp>();
|
target.addIllegalOp<AtenBernoulliOp>();
|
||||||
patterns.add<DecomposeAtenBernoulli_FloatOp>(context);
|
patterns.add<DecomposePseudoAtenBernoulliFloatOp>(context);
|
||||||
target.addIllegalOp<AtenBernoulli_FloatOp>();
|
target.addIllegalOp<PseudoAtenBernoulliFloatOp>();
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns)))) {
|
std::move(patterns)))) {
|
||||||
|
|
|
@ -126,14 +126,21 @@ public:
|
||||||
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
|
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
|
||||||
LogicalResult matchAndRewrite(Operation *op,
|
LogicalResult matchAndRewrite(Operation *op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
if (!isa<AtenUniform_Op>(op))
|
Location loc = op->getLoc();
|
||||||
|
Operation *newOp;
|
||||||
|
if (isa<AtenUniform_Op>(op)) {
|
||||||
|
newOp = rewriter.create<PseudoAtenUniformOp>(loc, op->getResultTypes(),
|
||||||
|
op->getOperands());
|
||||||
|
} else if (isa<AtenBernoulli_FloatOp>(op)) {
|
||||||
|
newOp = rewriter.create<PseudoAtenBernoulliFloatOp>(
|
||||||
|
loc, op->getResultTypes(), op->getOperands());
|
||||||
|
} else {
|
||||||
return failure();
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
Operation *newOp = rewriter.create<PseudoAtenUniformOp>(
|
|
||||||
op->getLoc(), op->getResultTypes(), op->getOperands());
|
|
||||||
auto tensor =
|
auto tensor =
|
||||||
rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
|
rewriter.create<CopyToValueTensorOp>(loc, newOp->getResult(0));
|
||||||
rewriter.create<OverwriteTensorOp>(op->getLoc(), tensor, op->getOperand(0));
|
rewriter.create<OverwriteTensorOp>(loc, tensor, op->getOperand(0));
|
||||||
rewriter.replaceOp(op, op->getOperand(0));
|
rewriter.replaceOp(op, op->getOperand(0));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -202,6 +209,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addIllegalOp<NonValueTensorLiteralOp>();
|
target.addIllegalOp<NonValueTensorLiteralOp>();
|
||||||
target.addIllegalOp<AtenUniform_Op>();
|
target.addIllegalOp<AtenUniform_Op>();
|
||||||
|
target.addIllegalOp<AtenBernoulli_FloatOp>();
|
||||||
target.markUnknownOpDynamicallyLegal([](Operation *op) {
|
target.markUnknownOpDynamicallyLegal([](Operation *op) {
|
||||||
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
|
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
|
||||||
auto hasValueSemantics = [](Type t) {
|
auto hasValueSemantics = [](Type t) {
|
||||||
|
|
|
@ -228,7 +228,8 @@ public:
|
||||||
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
|
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
|
||||||
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
|
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
|
||||||
AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
|
AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
|
||||||
AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp>(op)) {
|
AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp,
|
||||||
|
PseudoAtenBernoulliFloatOp>(op)) {
|
||||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -144,3 +144,20 @@ func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.fl
|
||||||
%ret = torch.aten.uniform_ %t, %min, %max, %generator: !torch.tensor, !torch.float, !torch.float, !torch.none -> !torch.tensor
|
%ret = torch.aten.uniform_ %t, %min, %max, %generator: !torch.tensor, !torch.float, !torch.float, !torch.none -> !torch.tensor
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.bernoulli_.float(
|
||||||
|
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor {
|
||||||
|
// CHECK: %[[GENERATOR:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[P:.*]] = torch.constant.float 5.000000e-01
|
||||||
|
// CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor
|
||||||
|
// CHECK: %[[VRET:.*]] = torch.pseudo.aten.bernoulli.float %[[T_VTENSOR]], %[[P]], %[[GENERATOR]] : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor
|
||||||
|
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
|
||||||
|
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
|
||||||
|
// CHECK: torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
|
||||||
|
// CHECK: return %[[T]] : !torch.tensor
|
||||||
|
func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
|
||||||
|
%generator = torch.constant.none
|
||||||
|
%p = torch.constant.float 5.000000e-01
|
||||||
|
%ret = torch.aten.bernoulli_.float %t, %p, %generator : !torch.tensor, !torch.float, !torch.none -> !torch.tensor
|
||||||
|
return %ret : !torch.tensor
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue