mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] address a dtype mismatch in `aten.multinomial` lowering (#3630)
Resolves <https://github.com/llvm/torch-mlir/issues/3628> Unblocks a compile failure for one of the MiGraphx models (`AgentModel`).pull/3657/head
parent
f72770a725
commit
f66908f190
|
@ -287,8 +287,16 @@ public:
|
|||
|
||||
Value initSum = rewriter.create<arith::ConstantOp>(
|
||||
loc, f64Ty, rewriter.getF64FloatAttr(0.0));
|
||||
int64_t srcWidth = cast<mlir::FloatType>(elemTy).getWidth();
|
||||
if (srcWidth > 64)
|
||||
op->emitWarning("Op bitwidth will be truncated from " +
|
||||
std::to_string(srcWidth) + " bits to 64 bits.");
|
||||
auto sumBody = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
|
||||
Value input = payloadArgs[0];
|
||||
if (srcWidth < 64)
|
||||
input = b.create<arith::ExtFOp>(loc, f64Ty, input);
|
||||
if (srcWidth > 64)
|
||||
input = b.create<arith::TruncFOp>(loc, f64Ty, input);
|
||||
Value result = payloadArgs[1];
|
||||
Value nextSum = b.create<arith::AddFOp>(loc, input, result);
|
||||
b.create<linalg::YieldOp>(loc, nextSum);
|
||||
|
@ -310,7 +318,7 @@ public:
|
|||
|
||||
// compute cdf in loop
|
||||
Value initCdf = b.create<tensor::EmptyOp>(
|
||||
loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), elemTy);
|
||||
loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), f64Ty);
|
||||
Value cdf =
|
||||
b.create<scf::ForOp>(
|
||||
loc, cstZero, numCategories, cstOne, ValueRange{initCdf},
|
||||
|
@ -330,6 +338,11 @@ public:
|
|||
ind = ValueRange{jIndex, iIndex};
|
||||
}
|
||||
Value currWeight = b.create<tensor::ExtractOp>(loc, self, ind);
|
||||
if (srcWidth < 64)
|
||||
currWeight = b.create<arith::ExtFOp>(loc, f64Ty, currWeight);
|
||||
if (srcWidth > 64)
|
||||
currWeight =
|
||||
b.create<arith::TruncFOp>(loc, f64Ty, currWeight);
|
||||
Value currMass = b.create<arith::DivFOp>(loc, currWeight, sum);
|
||||
Value currCum =
|
||||
b.create<scf::IfOp>(
|
||||
|
|
|
@ -2318,6 +2318,8 @@ ONNX_XFAIL_SET = {
|
|||
"ElementwiseLog2IntModule_basic",
|
||||
"ElementwiseFminModule_basic",
|
||||
"ElementwiseFmaxModule_basic",
|
||||
"MultinomialModule2D_basic",
|
||||
"MultinomialModule2D_F32",
|
||||
"PixelShuffleModuleStaticRank4Float32_basic",
|
||||
"ReflectionPad1dModule2dInput_Right",
|
||||
"ReflectionPad1dModule2dInput_basic",
|
||||
|
@ -2346,6 +2348,8 @@ ONNX_XFAIL_SET = {
|
|||
"MoveDimIntNegativeIndexModule_basic",
|
||||
"ReduceL3NormKeepDimModule_basic",
|
||||
"ViewSizeFromOtherTensor_basic",
|
||||
# incorrect shape generated by torch.onnx.export (needs an unsqueeze)
|
||||
"MultinomialModule_basic",
|
||||
# Failure - onnx_export
|
||||
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
||||
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
|
||||
|
@ -2849,8 +2853,6 @@ ONNX_XFAIL_SET = {
|
|||
"ElementwiseUnaryIntModule_basic",
|
||||
"ElementwiseFloatTensorGtIntTensorModule_basic",
|
||||
"MaskedFillTensorFloatValueModule_basic",
|
||||
"MultinomialModule_basic",
|
||||
"MultinomialModule2D_basic",
|
||||
"NativeDropoutTrainModule_basic",
|
||||
"NativeDropoutTrainStaticShapeModule_basic",
|
||||
"ReduceAnyFloatModule_basic",
|
||||
|
|
|
@ -377,10 +377,20 @@ def BernoulliPModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class MultinomialModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def generate_sample_distr(sizes: list[int], torchdtype, tu: TestUtils):
|
||||
assert len(sizes) == 1 or len(sizes) == 2
|
||||
init = tu.rand(*sizes).to(dtype=torchdtype).abs()
|
||||
normalized = init / (init.sum(-1, True, dtype=torchdtype))
|
||||
return normalized
|
||||
|
||||
|
||||
class MultinomialBase(torch.nn.Module):
|
||||
def _forward(self, x):
|
||||
a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
|
||||
return a
|
||||
|
||||
|
||||
class MultinomialModule(MultinomialBase):
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
|
@ -389,20 +399,36 @@ class MultinomialModule(torch.nn.Module):
|
|||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
|
||||
return a.mean(dtype=torch.double)
|
||||
return self._forward(x).mean(dtype=torch.double)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MultinomialModule())
|
||||
def MultinomialModule_basic(module, tu: TestUtils):
|
||||
x = tu.rand(100).double()
|
||||
x = generate_sample_distr([100], torch.float64, tu)
|
||||
module.forward(x)
|
||||
|
||||
|
||||
class MultinomialModule2D(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
class MultinomialModule2DF32(MultinomialBase):
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
# note: this should really call mean(-1)
|
||||
# for some reason, doing this causes a torchscript numerics error?
|
||||
return self._forward(x).mean(dtype=torch.double)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MultinomialModule2DF32())
|
||||
def MultinomialModule2D_F32(module, tu: TestUtils):
|
||||
x = generate_sample_distr([10, 100], torch.float32, tu)
|
||||
module.forward(x)
|
||||
|
||||
|
||||
class MultinomialModule2D(MultinomialBase):
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
|
@ -411,13 +437,14 @@ class MultinomialModule2D(torch.nn.Module):
|
|||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
|
||||
return a.mean(dtype=torch.double)
|
||||
# note: this should really call mean(-1)
|
||||
# for some reason, doing this causes a torchscript numerics error?
|
||||
return self._forward(x).mean(dtype=torch.double)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MultinomialModule2D())
|
||||
def MultinomialModule2D_basic(module, tu: TestUtils):
|
||||
x = tu.rand(10, 100).double()
|
||||
x = generate_sample_distr([10, 100], torch.float64, tu)
|
||||
module.forward(x)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue