[TorchToLinalg] address a dtype mismatch in `aten.multinomial` lowering (#3630)

Resolves <https://github.com/llvm/torch-mlir/issues/3628>
Unblocks a compile failure for one of the MIGraphX models
(`AgentModel`).
Author: zjgarvey, 2024-08-20 13:14:48 -07:00, committed by GitHub
Commit: f66908f190 (parent: f72770a725)
3 changed files with 57 additions and 15 deletions
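
Background on the bug: the lowering already accumulated the weight sum into an f64 value (`initSum` is created with `f64Ty`), but the payload values fed into the `arith.addf` kept the input's element type, so any non-f64 input produced mixed-type arithmetic. The patch extends (or truncates) each weight to f64 first and keeps the CDF buffer in f64 as well. Below is a minimal PyTorch sketch of the numeric behavior the fixed lowering implements; the variable names are illustrative only, not the emitted IR:

```python
import torch

# f32 weights, the case that previously tripped the dtype mismatch
weights = torch.rand(100, dtype=torch.float32)

w64 = weights.to(torch.float64)         # mirrors the added arith.extf to f64
total = w64.sum()                       # f64 accumulator, no dtype mismatch
cdf = torch.cumsum(w64 / total, dim=0)  # the CDF is likewise kept in f64
```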


@@ -287,8 +287,16 @@ public:
     Value initSum = rewriter.create<arith::ConstantOp>(
         loc, f64Ty, rewriter.getF64FloatAttr(0.0));
+    int64_t srcWidth = cast<mlir::FloatType>(elemTy).getWidth();
+    if (srcWidth > 64)
+      op->emitWarning("Op bitwidth will be truncated from " +
+                      std::to_string(srcWidth) + " bits to 64 bits.");
     auto sumBody = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
       Value input = payloadArgs[0];
+      if (srcWidth < 64)
+        input = b.create<arith::ExtFOp>(loc, f64Ty, input);
+      if (srcWidth > 64)
+        input = b.create<arith::TruncFOp>(loc, f64Ty, input);
       Value result = payloadArgs[1];
       Value nextSum = b.create<arith::AddFOp>(loc, input, result);
       b.create<linalg::YieldOp>(loc, nextSum);
@@ -310,7 +318,7 @@ public:
       // compute cdf in loop
       Value initCdf = b.create<tensor::EmptyOp>(
-          loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), elemTy);
+          loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), f64Ty);
       Value cdf =
           b.create<scf::ForOp>(
                loc, cstZero, numCategories, cstOne, ValueRange{initCdf},
@@ -330,6 +338,11 @@ public:
                   ind = ValueRange{jIndex, iIndex};
                 }
                 Value currWeight = b.create<tensor::ExtractOp>(loc, self, ind);
+                if (srcWidth < 64)
+                  currWeight = b.create<arith::ExtFOp>(loc, f64Ty, currWeight);
+                if (srcWidth > 64)
+                  currWeight =
+                      b.create<arith::TruncFOp>(loc, f64Ty, currWeight);
                 Value currMass = b.create<arith::DivFOp>(loc, currWeight, sum);
                 Value currCum =
                     b.create<scf::IfOp>(
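
For context, the surrounding lowering computes a cumulative distribution in an `scf.for` loop and then inverts it against uniform samples. A rough Python model of that algorithm for 1-D weights, assuming the generated loops behave like the tensor ops below (an illustration, not the emitted IR):

```python
import torch

def multinomial_model(weights: torch.Tensor, num_samples: int) -> torch.Tensor:
    # Promote to f64 up front, as the fixed lowering now does throughout.
    w = weights.to(torch.float64)
    # Normalize and accumulate: cdf[j] = sum(w[:j+1]) / sum(w).
    cdf = torch.cumsum(w / w.sum(), dim=-1)
    # For each uniform draw, pick the first category whose CDF entry
    # exceeds it; this plays the role of the scf.if comparison above.
    u = torch.rand(num_samples, dtype=torch.float64)
    return torch.searchsorted(cdf, u)
```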


@@ -2318,6 +2318,8 @@ ONNX_XFAIL_SET = {
     "ElementwiseLog2IntModule_basic",
     "ElementwiseFminModule_basic",
     "ElementwiseFmaxModule_basic",
+    "MultinomialModule2D_basic",
+    "MultinomialModule2D_F32",
     "PixelShuffleModuleStaticRank4Float32_basic",
     "ReflectionPad1dModule2dInput_Right",
     "ReflectionPad1dModule2dInput_basic",
@@ -2346,6 +2348,8 @@ ONNX_XFAIL_SET = {
     "MoveDimIntNegativeIndexModule_basic",
     "ReduceL3NormKeepDimModule_basic",
     "ViewSizeFromOtherTensor_basic",
+    # incorrect shape generated by torch.onnx.export (needs an unsqueeze)
+    "MultinomialModule_basic",
     # Failure - onnx_export
     "AdaptiveAvgPool1dGeneralDynamic_basic",
     "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
@@ -2849,8 +2853,6 @@ ONNX_XFAIL_SET = {
     "ElementwiseUnaryIntModule_basic",
     "ElementwiseFloatTensorGtIntTensorModule_basic",
     "MaskedFillTensorFloatValueModule_basic",
-    "MultinomialModule_basic",
-    "MultinomialModule2D_basic",
     "NativeDropoutTrainModule_basic",
     "NativeDropoutTrainStaticShapeModule_basic",
     "ReduceAnyFloatModule_basic",


@@ -377,10 +377,20 @@ def BernoulliPModule_basic(module, tu: TestUtils):
 # ==============================================================================


-class MultinomialModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
+def generate_sample_distr(sizes: list[int], torchdtype, tu: TestUtils):
+    assert len(sizes) == 1 or len(sizes) == 2
+    init = tu.rand(*sizes).to(dtype=torchdtype).abs()
+    normalized = init / (init.sum(-1, True, dtype=torchdtype))
+    return normalized
+
+
+class MultinomialBase(torch.nn.Module):
+    def _forward(self, x):
+        a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
+        return a
+
+
+class MultinomialModule(MultinomialBase):
     @export
     @annotate_args(
         [
@@ -389,20 +399,36 @@ class MultinomialModule(torch.nn.Module):
         ]
     )
     def forward(self, x):
-        a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
-        return a.mean(dtype=torch.double)
+        return self._forward(x).mean(dtype=torch.double)


 @register_test_case(module_factory=lambda: MultinomialModule())
 def MultinomialModule_basic(module, tu: TestUtils):
-    x = tu.rand(100).double()
+    x = generate_sample_distr([100], torch.float64, tu)
     module.forward(x)


-class MultinomialModule2D(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
+class MultinomialModule2DF32(MultinomialBase):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        # note: this should really call mean(-1)
+        # for some reason, doing this causes a torchscript numerics error?
+        return self._forward(x).mean(dtype=torch.double)
+
+
+@register_test_case(module_factory=lambda: MultinomialModule2DF32())
+def MultinomialModule2D_F32(module, tu: TestUtils):
+    x = generate_sample_distr([10, 100], torch.float32, tu)
+    module.forward(x)
+
+
+class MultinomialModule2D(MultinomialBase):
     @export
     @annotate_args(
         [
@@ -411,13 +437,14 @@ class MultinomialModule2D(torch.nn.Module):
         ]
     )
     def forward(self, x):
-        a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
-        return a.mean(dtype=torch.double)
+        # note: this should really call mean(-1)
+        # for some reason, doing this causes a torchscript numerics error?
+        return self._forward(x).mean(dtype=torch.double)


 @register_test_case(module_factory=lambda: MultinomialModule2D())
 def MultinomialModule2D_basic(module, tu: TestUtils):
-    x = tu.rand(10, 100).double()
+    x = generate_sample_distr([10, 100], torch.float64, tu)
     module.forward(x)
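
As a usage note, `generate_sample_distr` keeps each row of the weights normalized, which makes the `mean(dtype=torch.double)` comparison stable across the large sample count. An illustrative call, assuming `tu` is a `TestUtils` instance as in the test cases above:

```python
probs = generate_sample_distr([10, 100], torch.float32, tu)
assert probs.shape == (10, 100)
# each row sums to 1, so it is a valid sample distribution
assert torch.allclose(probs.sum(-1), torch.ones(10))
samples = torch.ops.aten.multinomial(probs, 1024 * 1024, replacement=True)
```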