[LINALG] Fix `aten.bernoulli` op lowering

- This commit adds E2E support for the `aten.rand_like` and
  `aten.bernoulli_.Tensor` ops.
- The `aten.bernoulli(x)` op was previously implemented as
  `aten.bernoulli(x) = rand_like(x) < 0.5`, assuming a default probability
  of 0.5. However, according to the PyTorch documentation
  (https://pytorch.org/docs/stable/generated/torch.bernoulli.html#torch.bernoulli),
  the input `x` in `aten.bernoulli(x)` is itself a tensor containing the
  probabilities used for drawing the binary random numbers.
- This commit therefore fixes the `aten.bernoulli(x)` implementation to:
  `aten.bernoulli(x) = rand_like(x) < x`.
- It also fixes the case where the input to `aten.bernoulli_.float` is an
  integer tensor. In that case, the input must be cast to a float type before
  being passed as an operand to the `aten.rand_like` op (see the sketch
  below):
  `aten.bernoulli_.float(x, p) = rand_like(float(x)) < p`.
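- A minimal sketch of both semantics, written against the eager PyTorch API
  purely for illustration (the helper names below are hypothetical and are
  not part of this patch's lowering):

      import torch

      def bernoulli_decomposed(x: torch.Tensor) -> torch.Tensor:
          # The input tensor itself holds the per-element probabilities;
          # rand_like draws uniform [0, 1) samples, and the comparison
          # yields Bernoulli samples, cast back to the input's dtype.
          return (torch.rand_like(x.float()) < x).to(x.dtype)

      def bernoulli_float_decomposed(x: torch.Tensor, p: float) -> torch.Tensor:
          # For an integer input, cast to float first so rand_like receives
          # a float-typed operand, then compare against the scalar probability p.
          return (torch.rand_like(x.float()) < p).to(x.dtype)

      probs = torch.tensor([[0.0, 1.0], [0.25, 0.75]], dtype=torch.float64)
      print(bernoulli_decomposed(probs))  # p=0.0/1.0 entries are deterministic
      print(bernoulli_float_decomposed(torch.ones(4, 4, dtype=torch.int64), 0.7))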

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/644/head snapshot-20220305.305
Gaurav Shukla 2022-02-25 22:05:04 +05:30
parent af551bd9cd
commit e57d3f9774
8 changed files with 420 additions and 103 deletions


@@ -4,7 +4,6 @@ from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class UniformModule(torch.nn.Module):
@@ -90,17 +89,73 @@ class BernoulliModule(torch.nn.Module):
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float64, True),
([-1, -1, -1], torch.float64, True),
([-1, -1, -1], torch.float64, True),
])
def forward(self, a):
x = torch.bernoulli(a)
mean = torch.mean(x)
std = torch.std(x)
def forward(self, x, y, z):
a = torch.bernoulli(x)
b = torch.bernoulli(y)
c = torch.bernoulli(z)
mean = torch.cat([
torch.flatten(torch.mean(a)),
torch.flatten(torch.mean(b)),
torch.flatten(torch.mean(c))
])
std = torch.cat([
torch.flatten(torch.std(a)),
torch.flatten(torch.std(b)),
torch.flatten(torch.std(c))
])
return mean, std
@register_test_case(module_factory=lambda: BernoulliModule())
def BernoulliModule_basic(module, tu: TestUtils):
module.forward(tu.rand(256, 512, 64))
module.forward(
tu.rand(256, 512, 8).double(),
tu.rand(512, 1024, 4).double(),
tu.rand(512, 256, 4).double())
# ==============================================================================
class BernoulliZerosModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float64, True),
])
def forward(self, x):
return torch.bernoulli(x)
@register_test_case(module_factory=lambda: BernoulliZerosModule())
def BernoulliZerosModule_basic(module, tu: TestUtils):
module.forward(torch.zeros(4, 8).double())
# ==============================================================================
class BernoulliOnesModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float64, True),
])
def forward(self, x):
return torch.bernoulli(x)
@register_test_case(module_factory=lambda: BernoulliOnesModule())
def BernoulliOnesModule_basic(module, tu: TestUtils):
module.forward(torch.ones(4, 8).double())
# ==============================================================================
class BernoulliFloatModule(torch.nn.Module):
def __init__(self):
@@ -113,25 +168,69 @@ class BernoulliFloatModule(torch.nn.Module):
([-1, -1, -1], torch.float64, True),
([-1, -1, -1], torch.float64, True),
])
def forward(self, a, b, c):
x = torch.ops.aten.bernoulli_(a, 0.4)
y = torch.ops.aten.bernoulli_(b, 0.7)
z = torch.ops.aten.bernoulli_(c, 0.5)
def forward(self, x, y, z):
a = torch.ops.aten.bernoulli_(x, 0.4)
b = torch.ops.aten.bernoulli_(y, 0.7)
c = torch.ops.aten.bernoulli_(z, 0.5)
mean = torch.cat([
torch.flatten(torch.mean(x)),
torch.flatten(torch.mean(y)),
torch.flatten(torch.mean(z))
torch.flatten(torch.mean(a)),
torch.flatten(torch.mean(b)),
torch.flatten(torch.mean(c))
])
std = torch.cat([
torch.flatten(torch.std(x)),
torch.flatten(torch.std(y)),
torch.flatten(torch.std(z))
torch.flatten(torch.std(a)),
torch.flatten(torch.std(b)),
torch.flatten(torch.std(c))
])
return mean, std
@register_test_case(module_factory=lambda: BernoulliFloatModule())
def BernoulliFloatModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(256, 512, 8).double(),
tu.rand(512, 1024, 4).double(),
tu.rand(512, 256, 4).double())
# ==============================================================================
class BernoulliTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
([-1, -1, -1], torch.float64, True),
([-1, -1, -1], torch.float64, True),
([-1, -1, -1], torch.float64, True),
([-1, -1, -1], torch.float64, True),
([-1, -1, -1], torch.float64, True),
])
def forward(self, x, px, y, py, z, pz):
a = torch.ops.aten.bernoulli_(x, px)
b = torch.ops.aten.bernoulli_(y, py)
c = torch.ops.aten.bernoulli_(z, pz)
mean = torch.cat([
torch.flatten(torch.mean(a)),
torch.flatten(torch.mean(b)),
torch.flatten(torch.mean(c))
])
std = torch.cat([
torch.flatten(torch.std(a)),
torch.flatten(torch.std(b)),
torch.flatten(torch.std(c))
])
return mean, std
@register_test_case(module_factory=lambda: BernoulliTensorModule())
def BernoulliTensorModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(512, 512, 8).double(),
tu.rand(512, 512, 8).double(),
tu.rand(512, 1024, 4).double(),
tu.rand(512, 1024, 4).double(),
tu.rand(512, 256, 4).double(),
tu.rand(512, 256, 4).double())


@@ -1502,6 +1502,25 @@ def Torch_AtenUniform_Op : Torch_Op<"aten.uniform_", [
let assemblyFormat = "$self `,` $from `,` $to `,` $generator attr-dict `:` qualified(type($self)) `,` qualified(type($from)) `,` qualified(type($to)) `,` qualified(type($generator)) `->` qualified(type($result))";
}
def Torch_AtenRandLikeOp : Torch_Op<"aten.rand_like", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
TorchOptionalIntType:$dtype,
TorchOptionalIntType:$layout,
TorchOptionalDeviceType:$device,
TorchOptionalBoolType:$pin_memory,
TorchOptionalIntType:$memory_format
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $dtype `,` $layout `,` $device `,` $pin_memory `,` $memory_format attr-dict `:` qualified(type($self)) `,` qualified(type($dtype)) `,` qualified(type($layout)) `,` qualified(type($device)) `,` qualified(type($pin_memory)) `,` qualified(type($memory_format)) `->` qualified(type($result))";
}
def Torch_AtenBernoulliOp : Torch_Op<"aten.bernoulli", [
AllowsTypeRefinement,
HasValueSemantics
@@ -1532,6 +1551,21 @@ def Torch_AtenBernoulli_FloatOp : Torch_Op<"aten.bernoulli_.float", [
let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` qualified(type($self)) `,` qualified(type($p)) `,` qualified(type($generator)) `->` qualified(type($result))";
}
def Torch_AtenBernoulli_TensorOp : Torch_Op<"aten.bernoulli_.Tensor", [
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::bernoulli_.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$p,
TorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` qualified(type($self)) `,` qualified(type($p)) `,` qualified(type($generator)) `->` qualified(type($result))";
}
def Torch_AtenTriuOp : Torch_Op<"aten.triu", [
AllowsTypeRefinement,
HasValueSemantics


@@ -953,6 +953,24 @@ def Torch_PseudoAtenBernoulliFloatOp: Torch_Op<"pseudo.aten.bernoulli.float", [
let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` type($self) `,` type($p) `,` type($generator) `->` type($result)";
}
// The corresponding without underscore variant for `torch.aten.bernoulli_.Tensor`
// doesn't exist in the pytorch ops registry. Add it here.
def Torch_PseudoAtenBernoulliTensorOp: Torch_Op<"pseudo.aten.bernoulli.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
]> {
let summary = "Generated op for `aten::bernoulli_.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$p,
TorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` qualified(type($self)) `,` qualified(type($p)) `,` qualified(type($generator)) `->` qualified(type($result))";
}
// The corresponding without underscore variant for `torch.aten.fill_.Scalar`
// doesn't exist in the pytorch ops registry. Add it here.
def Torch_PseudoAtenFillScalarOp: Torch_Op<"pseudo.aten.fill.Scalar", [


@@ -131,8 +131,8 @@ static Value convertTensorToDtype(PatternRewriter &rewriter, Location loc,
Value input, Type dtype) {
BaseTensorType origType = input.getType().cast<BaseTensorType>();
Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype);
// `convertIntVal` contains the corresponding integer for the dtype which is used
// by the aten.to.dtype op.
// `convertIntVal` contains the corresponding integer for the dtype which is
// used by the aten.to.dtype op.
Value convertIntVal = getDtypeIntValueForType(rewriter, loc, dtype);
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
@@ -141,22 +141,21 @@ static Value convertTensorToDtype(PatternRewriter &rewriter, Location loc,
return converted;
}
// Helper to create a tensor filled with the given scalar. Scalar would be
// converted the to the element type of the given tensor type.
// Helper to create a tensor filled with the given `scalar`. `scalar` would be
// converted to the element type of the given `resultType`.
static Value createInitTensor(PatternRewriter &rewriter, Location loc,
Type resultType, Value scalar, Value sizeList) {
BaseTensorType tensorType = resultType.cast<BaseTensorType>();
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, tensorType, sizeList, /*dtype=*/noneVal, /*layout=*/noneVal,
/*device=*/noneVal,
/*pin_memory=*/noneVal, /*memory_format=*/noneVal);
/*device=*/noneVal, /*pin_memory=*/noneVal, /*memory_format=*/noneVal);
return rewriter.create<PseudoAtenFillScalarOp>(loc, resultType, emptyTensor,
scalar);
}
// Helper to create a rank0 tensor filled with the given scalar. Scalar would be
// converted the to the element type of the given tensor type.
// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar`
// would be converted to the element type of the given `inputType`.
static Value createRank0Tensor(PatternRewriter &rewriter, Location loc,
BaseTensorType inputType, Value scalar) {
SmallVector<int64_t> sizes;
@@ -930,66 +929,125 @@ public:
};
} // namespace
// Returns a tensor with bernoulli(p) distribution.
// Decompose aten.bernoulli(x, p) to aten.gtTensor(aten.uniform(x), p).
static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
Operation *op, Location loc,
Value input, double p,
Value &result) {
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
if (!inputType.hasSizes() || !inputType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "Can't decomposeBernoulliLikeOp without sizes or dtype");
}
BaseTensorType boolType =
inputType
.getWithSizesAndDtype(
inputType.getSizes(),
IntegerType::get(op->getContext(), 1, IntegerType::Signless))
.cast<BaseTensorType>();
Value prob =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(p));
Value lb =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
Value ub =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
// Create a uniform random op with low and high set to lb and ub respectively.
Value uniformRandom = rewriter.create<PseudoAtenUniformOp>(
loc, inputType, input, lb, ub, noneVal);
Value gtValue =
rewriter.create<AtenLtScalarOp>(loc, boolType, uniformRandom, prob);
// Since `gtValue` will be a boolean tensor convert it back to the original
// type.
result = convertTensorToDtype(rewriter, loc, gtValue, inputType.getDtype());
return success();
}
namespace {
class DecomposeAtenBernoulliOp : public OpRewritePattern<AtenBernoulliOp> {
class DecomposeAtenRandLikeOp : public OpRewritePattern<AtenRandLikeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenBernoulliOp op,
LogicalResult matchAndRewrite(AtenRandLikeOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.self();
Value generator = op.generator();
if (!generator.getType().isa<Torch::NoneType>())
Value input = op.self();
auto inputType = input.getType().cast<BaseTensorType>();
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(op,
"only support floating-point type");
// TODO: Add support for layout, pin_memory and memory_format features.
// Only `none` layout is supported.
if (!op.layout().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "The generator has to ben None because only global default "
"generator is supported");
Value result;
if (failed(decomposeBernoulliLikeOp(rewriter, op, loc, self, /*p=*/0.5,
result)))
return failure();
rewriter.replaceOp(op, result);
op, "unimplemented: only default layout is supported");
// The pin_memory should be either `none` or constant `False`.
if (!op.pin_memory().getType().isa<Torch::NoneType>()) {
bool pinMemory;
if (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)))
return rewriter.notifyMatchFailure(
op, "unimplemented: pin_memory must be a constant");
else if (pinMemory)
return rewriter.notifyMatchFailure(
op, "unimplemented: pin_memory is expected to be false");
}
// Only `none` memory_format is supported.
if (!op.memory_format().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "unimplemented: only default memory format is supported");
// Create a uniform random op with low and high set to 0.0 and 1.0
// respectively.
Value none = rewriter.create<ConstantNoneOp>(loc);
Value lb =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
Value ub =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
rewriter.replaceOpWithNewOp<PseudoAtenUniformOp>(
op, op.getType(), input, lb, ub, /*generator=*/none);
return success();
}
};
} // namespace
namespace {
// Bernoulli(x, p) = (rand_like(float(x)) < p).cast(type(x)). Here,
// 1. p must be a float tensor.
// 2. The shape of p should be broadcastable to the shape of x.
// 3. Bernoulli(x, p) returns a tensor of the same type as that of x.
static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
Operation *op, Location loc,
Value input, Value prob,
Value &output) {
auto inputType = input.getType().cast<BaseTensorType>();
auto probType = prob.getType().cast<BaseTensorType>();
// Both the `input` and `prob` must be ranked tensors.
if (!inputType.hasSizes() || !inputType.hasDtype() || !probType.hasSizes() ||
!probType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "can't decompose bernoulli like ops without sizes or dtype");
}
// The `prob` is expected to be a float type tensor.
if (!probType.getDtype().isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(
op, "probabilities must be a float type tensor");
}
// Since the `aten.rand_like` op expects a float-type operand, create a
// float-type tensor with the same shape as that of the `input`.
Value floatTensor =
convertTensorToDtype(rewriter, loc, input, rewriter.getF64Type());
Value none = rewriter.create<ConstantNoneOp>(loc);
Value randomVal = rewriter.create<AtenRandLikeOp>(
loc, floatTensor.getType(), floatTensor, /*dtype=*/none, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
// Bernoulli(x, p) = rand_like(float(x)) < p.
auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(),
rewriter.getI1Type());
Value lessThanP =
rewriter.create<AtenLtTensorOp>(loc, boolResType, randomVal, prob);
// As the `output` is expected to be of the `input` type, convert the boolean
// tensor `lessThanP` to an `input`-type tensor.
output = convertTensorToDtype(rewriter, loc, lessThanP, inputType.getDtype());
return success();
}
// aten.bernoulli(x) = rand_like(x) < x. Here, the input x is a tensor
// containing probabilities to be used for drawing the binary random number.
class DecomposeAtenBernoulliOp : public OpRewritePattern<AtenBernoulliOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenBernoulliOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.self();
if (!op.generator().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "The generator has to ben None because only global default "
"generator is supported");
Value output;
if (failed(
decomposeBernoulliLikeOp(rewriter, op, loc, input, input, output)))
return rewriter.notifyMatchFailure(
op, "decomposeBernoulliLikeOp failed to decompose the op");
rewriter.replaceOp(op, output);
return success();
}
};
// aten.bernoulli.float(x, p) = (rand_like(float(x)) < tensor(p)).cast(type(x)).
// Since the input x can be an integer tensor, it's important to cast it to
// float type before passing it to the `aten.rand_like` op.
class DecomposePseudoAtenBernoulliFloatOp
: public OpRewritePattern<PseudoAtenBernoulliFloatOp> {
public:
@@ -997,20 +1055,50 @@ public:
LogicalResult matchAndRewrite(PseudoAtenBernoulliFloatOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.self();
Value generator = op.generator();
double p;
if (!matchPattern(op.p(), m_TorchConstantFloat(&p)))
return rewriter.notifyMatchFailure(op, "p should be constant float");
if (!generator.getType().isa<Torch::NoneType>())
Value input = op.self();
Value p = op.p();
if (!op.generator().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "The generator has to ben None because only global default "
"generator is supported");
Value result;
if (failed(decomposeBernoulliLikeOp(rewriter, op, loc, self, p, result)))
return failure();
rewriter.replaceOp(op, result);
auto inputType = input.getType().cast<BaseTensorType>();
SmallVector<int64_t> empty;
Type tensorType = inputType.getWithSizesAndDtype(llvm::makeArrayRef(empty),
rewriter.getF64Type());
Value prob = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType, p);
Value output;
if (failed(
decomposeBernoulliLikeOp(rewriter, op, loc, input, prob, output)))
return rewriter.notifyMatchFailure(
op, "decomposeBernoulliLikeOp failed to decompose the op");
rewriter.replaceOp(op, output);
return success();
}
};
// aten.bernoulli.Tensor(x, p) = (rand_like(float(x)) < p).cast(type(x)).
// Since the input x can be an integer tensor, it's important to cast it to
// float type before passing it to the `aten.rand_like` op.
class DecomposePseudoAtenBernoulliTensorOp
: public OpRewritePattern<PseudoAtenBernoulliTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PseudoAtenBernoulliTensorOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.self();
Value prob = op.p();
if (!op.generator().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "The generator has to ben None because only global default "
"generator is supported");
Value output;
if (failed(
decomposeBernoulliLikeOp(rewriter, op, loc, input, prob, output)))
return rewriter.notifyMatchFailure(
op, "decomposeBernoulliLikeOp failed to decompose the op");
rewriter.replaceOp(op, output);
return success();
}
};
@@ -1425,6 +1513,10 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenBernoulliOp>();
patterns.add<DecomposePseudoAtenBernoulliFloatOp>(context);
target.addIllegalOp<PseudoAtenBernoulliFloatOp>();
patterns.add<DecomposePseudoAtenBernoulliTensorOp>(context);
target.addIllegalOp<PseudoAtenBernoulliTensorOp>();
patterns.add<DecomposeAtenRandLikeOp>(context);
target.addIllegalOp<AtenRandLikeOp>();
patterns.add<DecomposeAtenHardsigmoidOp>(context);
target.addIllegalOp<AtenHardsigmoidOp>();
patterns.add<DecomposeAtenHardswishOp>(context);


@@ -152,6 +152,9 @@ public:
} else if (isa<AtenBernoulli_FloatOp>(op)) {
newOp = rewriter.create<PseudoAtenBernoulliFloatOp>(
loc, op->getResultTypes(), op->getOperands());
} else if (isa<AtenBernoulli_TensorOp>(op)) {
newOp = rewriter.create<PseudoAtenBernoulliTensorOp>(
loc, op->getResultTypes(), op->getOperands());
} else if (isa<AtenFill_ScalarOp>(op)) {
newOp = rewriter.create<PseudoAtenFillScalarOp>(loc, op->getResultTypes(),
op->getOperands());
@@ -232,6 +235,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
target.addIllegalOp<NonValueTensorLiteralOp>();
target.addIllegalOp<AtenUniform_Op>();
target.addIllegalOp<AtenBernoulli_FloatOp>();
target.addIllegalOp<AtenBernoulli_TensorOp>();
target.addIllegalOp<AtenFill_ScalarOp>();
target.markUnknownOpDynamicallyLegal([](Operation *op) {
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {


@@ -229,10 +229,10 @@ public:
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp,
PseudoAtenBernoulliFloatOp, PseudoAtenFillScalarOp,
AtenHardsigmoidOp, AtenHardswishOp, AtenSiluOp, AtenHardtanhOp>(
op)) {
AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulli_TensorOp,
PseudoAtenBernoulliFloatOp, PseudoAtenBernoulliTensorOp,
PseudoAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
AtenHardswishOp, AtenSiluOp, AtenHardtanhOp>(op)) {
return getLatticeElement(op->getResult(0)).join(*operands[0]);
}
@@ -399,6 +399,8 @@ public:
return visitConstantTensorNewLikeOp<AtenNewZerosOp>(newZeros, operands);
} else if (auto newOnes = dyn_cast<AtenNewOnesOp>(op)) {
return visitConstantTensorNewLikeOp<AtenNewOnesOp>(newOnes, operands);
} else if (auto randLike = dyn_cast<AtenRandLikeOp>(op)) {
return visitConstantTensorAllocLikeOp<AtenRandLikeOp>(randLike, operands);
} else if (auto toDtype = dyn_cast<AtenToDtypeOp>(op)) {
return visitAtenToDtypeOp(toDtype, operands);
} else if (auto toOther = dyn_cast<AtenToOtherOp>(op)) {


@@ -508,8 +508,10 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
# underscore variant doesn't exist.
emit("aten::fill_.Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::uniform_ : (Tensor, float, float, Generator?) -> (Tensor)")
emit("aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)")
emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)")
emit("aten::bernoulli_.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)")
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
emit_with_mutating_variants("aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")


@@ -377,26 +377,92 @@ func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -
// -----
// CHECK-LABEL: func @torch.aten.bernoulli
// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor {
// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FLOAT0_5:.*]] = torch.constant.float 5.000000e-01
// CHECK: %[[INT7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[CON2FLOAT:.*]] = torch.aten.to.dtype %[[INP]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[NONE_2:.*]] = torch.constant.none
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[NONE0:.*]] = torch.constant.none
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[INP]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE0]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[GT:.*]] = torch.aten.lt.Scalar %[[UNF]], %[[FLOAT0_5]] : !torch.vtensor<[?,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],i1>
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE1:.*]] = torch.constant.none
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[GT]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE1]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[INP]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1>
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
// CHECK: %[[NONE_3:.*]] = torch.constant.none
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[CMP]], %[[INT7_2]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor {
func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
%none = torch.constant.none
%0 = torch.aten.bernoulli %arg0, %none : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[?,?,?],f32>
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f32> to !torch.vtensor
%0 = torch.aten.bernoulli %arg0, %none : !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64>
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
return %1 : !torch.vtensor
}
// -----
// CHECK-LABEL: func @torch.pseudo.aten.bernoulli.float
// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[PROB:.*]] = torch.constant.float 4.000000e-01
// CHECK: %[[PROB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[PROB]] : !torch.float -> !torch.vtensor<[],f64>
// CHECK: %[[INT7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[CON2FLOAT:.*]] = torch.aten.to.dtype %[[INP]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[NONE_2:.*]] = torch.constant.none
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB_TENSOR]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?,?],i1>
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
// CHECK: %[[NONE_3:.*]] = torch.constant.none
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[CMP]], %[[INT7_2]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @torch.pseudo.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
%none = torch.constant.none
%prob = torch.constant.float 4.000000e-01
%0 = torch.pseudo.aten.bernoulli.float %arg0, %prob, %none : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
return %1 : !torch.vtensor
}
// -----
// CHECK-LABEL: func @torch.pseudo.aten.bernoulli.Tensor(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f64>,
// CHECK-SAME: %[[PROB:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[CON2FLOAT:.*]] = torch.aten.to.dtype %[[INP]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[NONE_2:.*]] = torch.constant.none
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1>
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
// CHECK: %[[NONE_3:.*]] = torch.constant.none
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[CMP]], %[[INT7_2]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @torch.pseudo.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
%none = torch.constant.none
%0 = torch.pseudo.aten.bernoulli.Tensor %arg0, %arg1, %none : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64>
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
return %1 : !torch.vtensor
}