mirror of https://github.com/llvm/torch-mlir
[LINALG] Fix `aten.bernoulli` op lowering
- This commit adds E2E support for the `aten.rand_like` and `aten.bernoulli_.Tensor` ops.
- `aten.bernoulli(x)` was previously implemented as `aten.bernoulli(x) = rand_like(x) < 0.5`, assuming a default probability of 0.5. According to the PyTorch documentation (https://pytorch.org/docs/stable/generated/torch.bernoulli.html#torch.bernoulli), the input `x` of `aten.bernoulli(x)` is itself a tensor of probabilities to be used for drawing the binary random numbers.
- This commit therefore fixes the `aten.bernoulli(x)` implementation to `aten.bernoulli(x) = rand_like(x) < x`.
- It also fixes the case where the input to `aten.bernoulli_.float` is an integer tensor. In that case the input must be cast to a float type before being passed as an operand to the `aten.rand_like` op: `aten.bernoulli_.float(x, p) = rand_like(float(x)) < p`.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
parent af551bd9cd
commit e57d3f9774
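For reference, the intended semantics can be sketched in eager PyTorch as follows. This is only an illustrative sketch of the behaviour described above, not the actual Torch-MLIR lowering, and the helper names are invented for this example:

import torch

def bernoulli_reference(x: torch.Tensor) -> torch.Tensor:
    # aten.bernoulli(x): `x` itself holds the per-element probabilities.
    return (torch.rand_like(x) < x).to(x.dtype)

def bernoulli_float_reference(x: torch.Tensor, p: float) -> torch.Tensor:
    # aten.bernoulli_.float(x, p): `x` may be an integer tensor, so cast it to
    # a float type before drawing the uniform samples, then cast the boolean
    # comparison result back to the dtype of `x`.
    xf = x.to(torch.float64)
    return (torch.rand_like(xf) < p).to(x.dtype)

The new e2e tests below exercise this behaviour statistically, comparing the mean and standard deviation of the generated samples via torch.mean and torch.std.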
@@ -4,7 +4,6 @@ from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export


# ==============================================================================

class UniformModule(torch.nn.Module):

@@ -90,17 +89,73 @@ class BernoulliModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
    ])
    def forward(self, a):
        x = torch.bernoulli(a)
        mean = torch.mean(x)
        std = torch.std(x)
    def forward(self, x, y, z):
        a = torch.bernoulli(x)
        b = torch.bernoulli(y)
        c = torch.bernoulli(z)
        mean = torch.cat([
            torch.flatten(torch.mean(a)),
            torch.flatten(torch.mean(b)),
            torch.flatten(torch.mean(c))
        ])
        std = torch.cat([
            torch.flatten(torch.std(a)),
            torch.flatten(torch.std(b)),
            torch.flatten(torch.std(c))
        ])
        return mean, std


@register_test_case(module_factory=lambda: BernoulliModule())
def BernoulliModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(256, 512, 64))
    module.forward(
        tu.rand(256, 512, 8).double(),
        tu.rand(512, 1024, 4).double(),
        tu.rand(512, 256, 4).double())

# ==============================================================================

class BernoulliZerosModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float64, True),
    ])
    def forward(self, x):
        return torch.bernoulli(x)


@register_test_case(module_factory=lambda: BernoulliZerosModule())
def BernoulliZerosModule_basic(module, tu: TestUtils):
    module.forward(torch.zeros(4, 8).double())

# ==============================================================================

class BernoulliOnesModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float64, True),
    ])
    def forward(self, x):
        return torch.bernoulli(x)


@register_test_case(module_factory=lambda: BernoulliOnesModule())
def BernoulliOnesModule_basic(module, tu: TestUtils):
    module.forward(torch.ones(4, 8).double())

# ==============================================================================

class BernoulliFloatModule(torch.nn.Module):
    def __init__(self):
@@ -113,25 +168,69 @@ class BernoulliFloatModule(torch.nn.Module):
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
    ])
    def forward(self, a, b, c):
        x = torch.ops.aten.bernoulli_(a, 0.4)
        y = torch.ops.aten.bernoulli_(b, 0.7)
        z = torch.ops.aten.bernoulli_(c, 0.5)
    def forward(self, x, y, z):
        a = torch.ops.aten.bernoulli_(x, 0.4)
        b = torch.ops.aten.bernoulli_(y, 0.7)
        c = torch.ops.aten.bernoulli_(z, 0.5)
        mean = torch.cat([
            torch.flatten(torch.mean(x)),
            torch.flatten(torch.mean(y)),
            torch.flatten(torch.mean(z))
            torch.flatten(torch.mean(a)),
            torch.flatten(torch.mean(b)),
            torch.flatten(torch.mean(c))
        ])
        std = torch.cat([
            torch.flatten(torch.std(x)),
            torch.flatten(torch.std(y)),
            torch.flatten(torch.std(z))
            torch.flatten(torch.std(a)),
            torch.flatten(torch.std(b)),
            torch.flatten(torch.std(c))
        ])
        return mean, std


@register_test_case(module_factory=lambda: BernoulliFloatModule())
def BernoulliFloatModule_basic(module, tu: TestUtils):
    module.forward(
        tu.rand(256, 512, 8).double(),
        tu.rand(512, 1024, 4).double(),
        tu.rand(512, 256, 4).double())

# ==============================================================================

class BernoulliTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
    ])
    def forward(self, x, px, y, py, z, pz):
        a = torch.ops.aten.bernoulli_(x, px)
        b = torch.ops.aten.bernoulli_(y, py)
        c = torch.ops.aten.bernoulli_(z, pz)
        mean = torch.cat([
            torch.flatten(torch.mean(a)),
            torch.flatten(torch.mean(b)),
            torch.flatten(torch.mean(c))
        ])
        std = torch.cat([
            torch.flatten(torch.std(a)),
            torch.flatten(torch.std(b)),
            torch.flatten(torch.std(c))
        ])
        return mean, std


@register_test_case(module_factory=lambda: BernoulliTensorModule())
def BernoulliTensorModule_basic(module, tu: TestUtils):
    module.forward(
        tu.rand(512, 512, 8).double(),
        tu.rand(512, 512, 8).double(),
        tu.rand(512, 1024, 4).double(),
        tu.rand(512, 1024, 4).double(),
        tu.rand(512, 256, 4).double(),
        tu.rand(512, 256, 4).double())

@@ -1502,6 +1502,25 @@ def Torch_AtenUniform_Op : Torch_Op<"aten.uniform_", [
  let assemblyFormat = "$self `,` $from `,` $to `,` $generator attr-dict `:` qualified(type($self)) `,` qualified(type($from)) `,` qualified(type($to)) `,` qualified(type($generator)) `->` qualified(type($result))";
}

def Torch_AtenRandLikeOp : Torch_Op<"aten.rand_like", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    TorchOptionalIntType:$dtype,
    TorchOptionalIntType:$layout,
    TorchOptionalDeviceType:$device,
    TorchOptionalBoolType:$pin_memory,
    TorchOptionalIntType:$memory_format
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $dtype `,` $layout `,` $device `,` $pin_memory `,` $memory_format attr-dict `:` qualified(type($self)) `,` qualified(type($dtype)) `,` qualified(type($layout)) `,` qualified(type($device)) `,` qualified(type($pin_memory)) `,` qualified(type($memory_format)) `->` qualified(type($result))";
}

def Torch_AtenBernoulliOp : Torch_Op<"aten.bernoulli", [
    AllowsTypeRefinement,
    HasValueSemantics

@@ -1532,6 +1551,21 @@ def Torch_AtenBernoulli_FloatOp : Torch_Op<"aten.bernoulli_.float", [
  let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` qualified(type($self)) `,` qualified(type($p)) `,` qualified(type($generator)) `->` qualified(type($result))";
}

def Torch_AtenBernoulli_TensorOp : Torch_Op<"aten.bernoulli_.Tensor", [
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::bernoulli_.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$p,
    TorchOptionalGeneratorType:$generator
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` qualified(type($self)) `,` qualified(type($p)) `,` qualified(type($generator)) `->` qualified(type($result))";
}

def Torch_AtenTriuOp : Torch_Op<"aten.triu", [
    AllowsTypeRefinement,
    HasValueSemantics

@@ -953,6 +953,24 @@ def Torch_PseudoAtenBernoulliFloatOp: Torch_Op<"pseudo.aten.bernoulli.float", [
  let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` type($self) `,` type($p) `,` type($generator) `->` type($result)";
}

// The corresponding without underscore variant for `torch.aten.bernoulli_.Tensor`
// doesn't exist in the pytorch ops registry. Add it here.
def Torch_PseudoAtenBernoulliTensorOp: Torch_Op<"pseudo.aten.bernoulli.Tensor", [
    AllowsTypeRefinement,
    HasValueSemantics,
  ]> {
  let summary = "Generated op for `aten::bernoulli_.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$p,
    TorchOptionalGeneratorType:$generator
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` qualified(type($self)) `,` qualified(type($p)) `,` qualified(type($generator)) `->` qualified(type($result))";
}

// The corresponding without underscore variant for `torch.aten.fill_.Scalar`
// doesn't exist in the pytorch ops registry. Add it here.
def Torch_PseudoAtenFillScalarOp: Torch_Op<"pseudo.aten.fill.Scalar", [

@@ -131,8 +131,8 @@ static Value convertTensorToDtype(PatternRewriter &rewriter, Location loc,
                                  Value input, Type dtype) {
  BaseTensorType origType = input.getType().cast<BaseTensorType>();
  Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype);
  // `convertIntVal` contains the corresponding integer for the dtype which is used
  // by the aten.to.dtype op.
  // `convertIntVal` contains the corresponding integer for the dtype which is
  // used by the aten.to.dtype op.
  Value convertIntVal = getDtypeIntValueForType(rewriter, loc, dtype);
  Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
  Value noneVal = rewriter.create<ConstantNoneOp>(loc);

@@ -141,22 +141,21 @@ static Value convertTensorToDtype(PatternRewriter &rewriter, Location loc,
  return converted;
}

// Helper to create a tensor filled with the given scalar. Scalar would be
// converted the to the element type of the given tensor type.
// Helper to create a tensor filled with the given `scalar`. `scalar` would be
// converted to the element type of the given `resultType`.
static Value createInitTensor(PatternRewriter &rewriter, Location loc,
                              Type resultType, Value scalar, Value sizeList) {
  BaseTensorType tensorType = resultType.cast<BaseTensorType>();
  Value noneVal = rewriter.create<ConstantNoneOp>(loc);
  Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
      loc, tensorType, sizeList, /*dtype=*/noneVal, /*layout=*/noneVal,
      /*device=*/noneVal,
      /*pin_memory=*/noneVal, /*memory_format=*/noneVal);
      /*device=*/noneVal, /*pin_memory=*/noneVal, /*memory_format=*/noneVal);
  return rewriter.create<PseudoAtenFillScalarOp>(loc, resultType, emptyTensor,
                                                 scalar);
}

// Helper to create a rank0 tensor filled with the given scalar. Scalar would be
// converted the to the element type of the given tensor type.
// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar`
// would be converted to the element type of the given `inputType`.
static Value createRank0Tensor(PatternRewriter &rewriter, Location loc,
                               BaseTensorType inputType, Value scalar) {
  SmallVector<int64_t> sizes;

@@ -930,66 +929,125 @@ public:
};
} // namespace

// Returns a tensor with bernoulli(p) distribution.
// Decompose aten.bernoulli(x, p) to aten.gtTensor(aten.uniform(x), p).
static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
                                              Operation *op, Location loc,
                                              Value input, double p,
                                              Value &result) {
  BaseTensorType inputType = input.getType().cast<BaseTensorType>();
  if (!inputType.hasSizes() || !inputType.hasDtype()) {
    return rewriter.notifyMatchFailure(
        op, "Can't decomposeBernoulliLikeOp without sizes or dtype");
  }
  BaseTensorType boolType =
      inputType
          .getWithSizesAndDtype(
              inputType.getSizes(),
              IntegerType::get(op->getContext(), 1, IntegerType::Signless))
          .cast<BaseTensorType>();
  Value prob =
      rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(p));
  Value lb =
      rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
  Value ub =
      rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));

  Value noneVal = rewriter.create<ConstantNoneOp>(loc);
  // Create a uniform random op with low and high set to lb and ub respectively.
  Value uniformRandom = rewriter.create<PseudoAtenUniformOp>(
      loc, inputType, input, lb, ub, noneVal);
  Value gtValue =
      rewriter.create<AtenLtScalarOp>(loc, boolType, uniformRandom, prob);
  // Since `gtValue` will be a boolean tensor convert it back to the original
  // type.
  result = convertTensorToDtype(rewriter, loc, gtValue, inputType.getDtype());
  return success();
}

namespace {
class DecomposeAtenBernoulliOp : public OpRewritePattern<AtenBernoulliOp> {
class DecomposeAtenRandLikeOp : public OpRewritePattern<AtenRandLikeOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenBernoulliOp op,
  LogicalResult matchAndRewrite(AtenRandLikeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.self();
    Value generator = op.generator();
    if (!generator.getType().isa<Torch::NoneType>())
    Value input = op.self();
    auto inputType = input.getType().cast<BaseTensorType>();
    if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
      return rewriter.notifyMatchFailure(op,
                                         "only support floating-point type");

    // TODO: Add support for layout, pin_memory and memory_format features.
    // Only `none` layout is supported.
    if (!op.layout().getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "The generator has to ben None because only global default "
              "generator is supported");
    Value result;
    if (failed(decomposeBernoulliLikeOp(rewriter, op, loc, self, /*p=*/0.5,
                                        result)))
      return failure();
    rewriter.replaceOp(op, result);
          op, "unimplemented: only default layout is supported");

    // The pin_memory should be either `none` or constant `False`.
    if (!op.pin_memory().getType().isa<Torch::NoneType>()) {
      bool pinMemory;
      if (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)))
        return rewriter.notifyMatchFailure(
            op, "unimplemented: pin_memory must be a constant");
      else if (pinMemory)
        return rewriter.notifyMatchFailure(
            op, "unimplemented: pin_memory is expected to be false");
    }

    // Only `none` memory_format is supported.
    if (!op.memory_format().getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only default memory format is supported");

    // Create a uniform random op with low and high set to 0.0 and 1.0
    // respectively.
    Value none = rewriter.create<ConstantNoneOp>(loc);
    Value lb =
        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
    Value ub =
        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
    rewriter.replaceOpWithNewOp<PseudoAtenUniformOp>(
        op, op.getType(), input, lb, ub, /*generator=*/none);
    return success();
  }
};
} // namespace

namespace {
// Bernoulli(x, p) = (rand_like(float(x)) < p).cast(type(x)). Here,
// 1. p must be a float tensor.
// 2. The shape of p should be broadcastable to the shape of x.
// 3. Bernoulli(x, p) returns a tensor of the same type as that of x.
static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
                                              Operation *op, Location loc,
                                              Value input, Value prob,
                                              Value &output) {
  auto inputType = input.getType().cast<BaseTensorType>();
  auto probType = prob.getType().cast<BaseTensorType>();
  // Both the `input` and `prob` must be ranked tensors.
  if (!inputType.hasSizes() || !inputType.hasDtype() || !probType.hasSizes() ||
      !probType.hasDtype()) {
    return rewriter.notifyMatchFailure(
        op, "can't decompose bernoulli like ops without sizes or dtype");
  }
  // The `prob` is expected to be a float type tensor.
  if (!probType.getDtype().isa<mlir::FloatType>()) {
    return rewriter.notifyMatchFailure(
        op, "probabilities must be a float type tensor");
  }

  // Since the `aten.rand_like` op expects float-type operand, create a
  // float-type tensor with the same shape as that of the `input`.
  Value floatTensor =
      convertTensorToDtype(rewriter, loc, input, rewriter.getF64Type());
  Value none = rewriter.create<ConstantNoneOp>(loc);
  Value randomVal = rewriter.create<AtenRandLikeOp>(
      loc, floatTensor.getType(), floatTensor, /*dtype=*/none, /*layout=*/none,
      /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);

  // Bernoulli(x, p) = rand_like(float(x)) < p.
  auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(),
                                                    rewriter.getI1Type());
  Value lessThanP =
      rewriter.create<AtenLtTensorOp>(loc, boolResType, randomVal, prob);

  // As the `output` is expected to be of the `input` type, convert the boolean
  // tensor `lessThanP` to a `input` type tensor.
  output = convertTensorToDtype(rewriter, loc, lessThanP, inputType.getDtype());
  return success();
}

// aten.bernoulli(x) = rand_like(x) < x. Here, the input x is a tensor
// containing probabilities to be used for drawing the binary random number.
class DecomposeAtenBernoulliOp : public OpRewritePattern<AtenBernoulliOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenBernoulliOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.self();
    if (!op.generator().getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "The generator has to ben None because only global default "
              "generator is supported");
    Value output;
    if (failed(
            decomposeBernoulliLikeOp(rewriter, op, loc, input, input, output)))
      return rewriter.notifyMatchFailure(
          op, "decomposeBernoulliLikeOp failed to decompose the op");
    rewriter.replaceOp(op, output);
    return success();
  }
};

// aten.bernoulli.float(x, p) = (rand_like(float(x)) < tensor(p)).cast(type(x)).
// Since the input x can be an integer tensor, it's important to cast it to
// float type before passing it to the `aten.rand_like` op.
class DecomposePseudoAtenBernoulliFloatOp
    : public OpRewritePattern<PseudoAtenBernoulliFloatOp> {
public:

@@ -997,20 +1055,50 @@ public:
  LogicalResult matchAndRewrite(PseudoAtenBernoulliFloatOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.self();
    Value generator = op.generator();
    double p;
    if (!matchPattern(op.p(), m_TorchConstantFloat(&p)))
      return rewriter.notifyMatchFailure(op, "p should be constant float");

    if (!generator.getType().isa<Torch::NoneType>())
    Value input = op.self();
    Value p = op.p();
    if (!op.generator().getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "The generator has to ben None because only global default "
              "generator is supported");
    Value result;
    if (failed(decomposeBernoulliLikeOp(rewriter, op, loc, self, p, result)))
      return failure();
    rewriter.replaceOp(op, result);

    auto inputType = input.getType().cast<BaseTensorType>();
    SmallVector<int64_t> empty;
    Type tensorType = inputType.getWithSizesAndDtype(llvm::makeArrayRef(empty),
                                                     rewriter.getF64Type());
    Value prob = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType, p);
    Value output;
    if (failed(
            decomposeBernoulliLikeOp(rewriter, op, loc, input, prob, output)))
      return rewriter.notifyMatchFailure(
          op, "decomposeBernoulliLikeOp failed to decompose the op");
    rewriter.replaceOp(op, output);
    return success();
  }
};

// aten.bernoulli.Tensor(x, p) = (rand_like(float(x)) < p).cast(type(x)).
// Since the input x can be an integer tensor, it's important to cast it to
// float type before passing it to the `aten.rand_like` op.
class DecomposePseudoAtenBernoulliTensorOp
    : public OpRewritePattern<PseudoAtenBernoulliTensorOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(PseudoAtenBernoulliTensorOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.self();
    Value prob = op.p();
    if (!op.generator().getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "The generator has to ben None because only global default "
              "generator is supported");
    Value output;
    if (failed(
            decomposeBernoulliLikeOp(rewriter, op, loc, input, prob, output)))
      return rewriter.notifyMatchFailure(
          op, "decomposeBernoulliLikeOp failed to decompose the op");
    rewriter.replaceOp(op, output);
    return success();
  }
};

@@ -1425,6 +1513,10 @@ class DecomposeComplexOpsPass
    target.addIllegalOp<AtenBernoulliOp>();
    patterns.add<DecomposePseudoAtenBernoulliFloatOp>(context);
    target.addIllegalOp<PseudoAtenBernoulliFloatOp>();
    patterns.add<DecomposePseudoAtenBernoulliTensorOp>(context);
    target.addIllegalOp<PseudoAtenBernoulliTensorOp>();
    patterns.add<DecomposeAtenRandLikeOp>(context);
    target.addIllegalOp<AtenRandLikeOp>();
    patterns.add<DecomposeAtenHardsigmoidOp>(context);
    target.addIllegalOp<AtenHardsigmoidOp>();
    patterns.add<DecomposeAtenHardswishOp>(context);

@@ -152,6 +152,9 @@ public:
    } else if (isa<AtenBernoulli_FloatOp>(op)) {
      newOp = rewriter.create<PseudoAtenBernoulliFloatOp>(
          loc, op->getResultTypes(), op->getOperands());
    } else if (isa<AtenBernoulli_TensorOp>(op)) {
      newOp = rewriter.create<PseudoAtenBernoulliTensorOp>(
          loc, op->getResultTypes(), op->getOperands());
    } else if (isa<AtenFill_ScalarOp>(op)) {
      newOp = rewriter.create<PseudoAtenFillScalarOp>(loc, op->getResultTypes(),
                                                      op->getOperands());

@@ -232,6 +235,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
    target.addIllegalOp<NonValueTensorLiteralOp>();
    target.addIllegalOp<AtenUniform_Op>();
    target.addIllegalOp<AtenBernoulli_FloatOp>();
    target.addIllegalOp<AtenBernoulli_TensorOp>();
    target.addIllegalOp<AtenFill_ScalarOp>();
    target.markUnknownOpDynamicallyLegal([](Operation *op) {
      if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {

@@ -229,10 +229,10 @@ public:
            AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
            AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
            AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
            AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp,
            PseudoAtenBernoulliFloatOp, PseudoAtenFillScalarOp,
            AtenHardsigmoidOp, AtenHardswishOp, AtenSiluOp, AtenHardtanhOp>(
            op)) {
            AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulli_TensorOp,
            PseudoAtenBernoulliFloatOp, PseudoAtenBernoulliTensorOp,
            PseudoAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
            AtenHardswishOp, AtenSiluOp, AtenHardtanhOp>(op)) {
      return getLatticeElement(op->getResult(0)).join(*operands[0]);
    }

@@ -399,6 +399,8 @@ public:
      return visitConstantTensorNewLikeOp<AtenNewZerosOp>(newZeros, operands);
    } else if (auto newOnes = dyn_cast<AtenNewOnesOp>(op)) {
      return visitConstantTensorNewLikeOp<AtenNewOnesOp>(newOnes, operands);
    } else if (auto randLike = dyn_cast<AtenRandLikeOp>(op)) {
      return visitConstantTensorAllocLikeOp<AtenRandLikeOp>(randLike, operands);
    } else if (auto toDtype = dyn_cast<AtenToDtypeOp>(op)) {
      return visitAtenToDtypeOp(toDtype, operands);
    } else if (auto toOther = dyn_cast<AtenToOtherOp>(op)) {

@@ -508,8 +508,10 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    # underscore variant doesn't exist.
    emit("aten::fill_.Scalar : (Tensor, Scalar) -> (Tensor)")
    emit("aten::uniform_ : (Tensor, float, float, Generator?) -> (Tensor)")
    emit("aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
    emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)")
    emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)")
    emit("aten::bernoulli_.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)")

    emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
    emit_with_mutating_variants("aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")

@@ -377,26 +377,92 @@ func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -

// -----
// CHECK-LABEL: func @torch.aten.bernoulli
// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor {
// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FLOAT0_5:.*]] = torch.constant.float 5.000000e-01
// CHECK: %[[INT7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[CON2FLOAT:.*]] = torch.aten.to.dtype %[[INP]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[NONE_2:.*]] = torch.constant.none
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[NONE0:.*]] = torch.constant.none
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[INP]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE0]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[GT:.*]] = torch.aten.lt.Scalar %[[UNF]], %[[FLOAT0_5]] : !torch.vtensor<[?,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],i1>
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE1:.*]] = torch.constant.none
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[GT]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE1]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[INP]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1>
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
// CHECK: %[[NONE_3:.*]] = torch.constant.none
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[CMP]], %[[INT7_2]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor {
func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
  %none = torch.constant.none
  %0 = torch.aten.bernoulli %arg0, %none : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[?,?,?],f32>
  %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f32> to !torch.vtensor
  %0 = torch.aten.bernoulli %arg0, %none : !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64>
  %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
  return %1 : !torch.vtensor
}

// -----
// CHECK-LABEL: func @torch.pseudo.aten.bernoulli.float
// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[PROB:.*]] = torch.constant.float 4.000000e-01
// CHECK: %[[PROB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[PROB]] : !torch.float -> !torch.vtensor<[],f64>
// CHECK: %[[INT7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[CON2FLOAT:.*]] = torch.aten.to.dtype %[[INP]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[NONE_2:.*]] = torch.constant.none
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB_TENSOR]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?,?],i1>
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
// CHECK: %[[NONE_3:.*]] = torch.constant.none
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[CMP]], %[[INT7_2]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @torch.pseudo.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
  %none = torch.constant.none
  %prob = torch.constant.float 4.000000e-01
  %0 = torch.pseudo.aten.bernoulli.float %arg0, %prob, %none : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
  %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
  return %1 : !torch.vtensor
}

// -----
// CHECK-LABEL: func @torch.pseudo.aten.bernoulli.Tensor(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f64>,
// CHECK-SAME: %[[PROB:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[CON2FLOAT:.*]] = torch.aten.to.dtype %[[INP]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[NONE_2:.*]] = torch.constant.none
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1>
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
// CHECK: %[[NONE_3:.*]] = torch.constant.none
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[CMP]], %[[INT7_2]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @torch.pseudo.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
  %none = torch.constant.none
  %0 = torch.pseudo.aten.bernoulli.Tensor %arg0, %arg1, %none : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64>
  %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
  return %1 : !torch.vtensor
}