mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] address a dtype mismatch in `aten.multinomial` lowering (#3630)
Resolves <https://github.com/llvm/torch-mlir/issues/3628>. Unblocks a compile failure for one of the MIGraphX models (`AgentModel`).
parent f72770a725
commit f66908f190
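Root cause: the lowering accumulates the weight sum in f64, but the weights and the cdf buffer stayed in the input element type, so any non-f64 input (e.g. f32 weights) produced mixed-dtype arithmetic in the generated IR. A minimal sketch of the kind of eager-level call that exercises this path (shapes and sample counts are illustrative, not taken from `AgentModel`):

import torch

# f32 weights normalized into a sampling distribution; before this patch,
# lowering aten.multinomial on a non-f64 input produced IR that mixed the
# f32 element type with the lowering's internal f64 accumulator.
weights = torch.rand(10, 100, dtype=torch.float32)
weights = weights / weights.sum(-1, keepdim=True)
samples = torch.multinomial(weights, 16, replacement=True)
print(samples.shape)  # torch.Size([10, 16])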
@@ -287,8 +287,16 @@ public:
     Value initSum = rewriter.create<arith::ConstantOp>(
         loc, f64Ty, rewriter.getF64FloatAttr(0.0));
+    int64_t srcWidth = cast<mlir::FloatType>(elemTy).getWidth();
+    if (srcWidth > 64)
+      op->emitWarning("Op bitwidth will be truncated from " +
+                      std::to_string(srcWidth) + " bits to 64 bits.");
     auto sumBody = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
       Value input = payloadArgs[0];
+      if (srcWidth < 64)
+        input = b.create<arith::ExtFOp>(loc, f64Ty, input);
+      if (srcWidth > 64)
+        input = b.create<arith::TruncFOp>(loc, f64Ty, input);
       Value result = payloadArgs[1];
       Value nextSum = b.create<arith::AddFOp>(loc, input, result);
       b.create<linalg::YieldOp>(loc, nextSum);
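The hunk above widens every weight to f64 before it reaches the f64 running sum: `arith.extf` for element types narrower than 64 bits, `arith.truncf` (plus a warning about precision loss) for wider ones. A rough Python analogue of the reduction's dtype behavior (the explicit loop is illustrative; the real lowering emits a linalg reduction):

import torch

weights = torch.rand(100, dtype=torch.float32)  # input element type: f32
acc = torch.tensor(0.0, dtype=torch.float64)    # initSum: an f64 constant
for w in weights:
    acc = acc + w.to(torch.float64)             # extf, then the f64 add
print(acc.dtype)  # torch.float64 throughout the reduction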
@@ -310,7 +318,7 @@ public:

     // compute cdf in loop
     Value initCdf = b.create<tensor::EmptyOp>(
-        loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), elemTy);
+        loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), f64Ty);
     Value cdf =
         b.create<scf::ForOp>(
             loc, cstZero, numCategories, cstOne, ValueRange{initCdf},
@@ -330,6 +338,11 @@ public:
                   ind = ValueRange{jIndex, iIndex};
                 }
                 Value currWeight = b.create<tensor::ExtractOp>(loc, self, ind);
+                if (srcWidth < 64)
+                  currWeight = b.create<arith::ExtFOp>(loc, f64Ty, currWeight);
+                if (srcWidth > 64)
+                  currWeight =
+                      b.create<arith::TruncFOp>(loc, f64Ty, currWeight);
                 Value currMass = b.create<arith::DivFOp>(loc, currWeight, sum);
                 Value currCum =
                     b.create<scf::IfOp>(
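Together with the `elemTy` -> `f64Ty` change to the cdf buffer above, both operands of the `currWeight / sum` divide are now f64, so the cdf is built in one consistent dtype. A rough eager-mode analogue of the cdf this loop computes (illustrative, not the emitted IR):

import torch

weights = torch.rand(100, dtype=torch.float32)
total = weights.to(torch.float64).sum()                        # f64 running sum
cdf = torch.cumsum(weights.to(torch.float64) / total, dim=-1)  # f64 cdf buffer
print(cdf[-1])  # ~1.0: the last entry of a cdf over normalized masses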
@@ -2318,6 +2318,8 @@ ONNX_XFAIL_SET = {
     "ElementwiseLog2IntModule_basic",
     "ElementwiseFminModule_basic",
     "ElementwiseFmaxModule_basic",
+    "MultinomialModule2D_basic",
+    "MultinomialModule2D_F32",
     "PixelShuffleModuleStaticRank4Float32_basic",
     "ReflectionPad1dModule2dInput_Right",
     "ReflectionPad1dModule2dInput_basic",
@@ -2346,6 +2348,8 @@ ONNX_XFAIL_SET = {
     "MoveDimIntNegativeIndexModule_basic",
     "ReduceL3NormKeepDimModule_basic",
     "ViewSizeFromOtherTensor_basic",
+    # incorrect shape generated by torch.onnx.export (needs an unsqueeze)
+    "MultinomialModule_basic",
     # Failure - onnx_export
     "AdaptiveAvgPool1dGeneralDynamic_basic",
     "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
@@ -2849,8 +2853,6 @@ ONNX_XFAIL_SET = {
     "ElementwiseUnaryIntModule_basic",
     "ElementwiseFloatTensorGtIntTensorModule_basic",
     "MaskedFillTensorFloatValueModule_basic",
-    "MultinomialModule_basic",
-    "MultinomialModule2D_basic",
     "NativeDropoutTrainModule_basic",
     "NativeDropoutTrainStaticShapeModule_basic",
     "ReduceAnyFloatModule_basic",
@@ -377,10 +377,20 @@ def BernoulliPModule_basic(module, tu: TestUtils):

 # ==============================================================================


-class MultinomialModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
+def generate_sample_distr(sizes: list[int], torchdtype, tu: TestUtils):
+    assert len(sizes) == 1 or len(sizes) == 2
+    init = tu.rand(*sizes).to(dtype=torchdtype).abs()
+    normalized = init / (init.sum(-1, True, dtype=torchdtype))
+    return normalized
+
+
+class MultinomialBase(torch.nn.Module):
+    def _forward(self, x):
+        a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
+        return a
+
+
+class MultinomialModule(MultinomialBase):
     @export
     @annotate_args(
         [
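For reference, `generate_sample_distr` just builds a row-normalized nonnegative tensor in the requested dtype. A standalone sketch of what it returns, with `torch.rand` standing in for `tu.rand`:

import torch

init = torch.rand(10, 100).to(dtype=torch.float32).abs()
dist = init / init.sum(-1, True, dtype=torch.float32)
print(dist.sum(-1))  # every row sums to ~1.0, i.e. a valid distribution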
@@ -389,20 +399,36 @@ class MultinomialModule(torch.nn.Module):
         ]
     )
     def forward(self, x):
-        a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
-        return a.mean(dtype=torch.double)
+        return self._forward(x).mean(dtype=torch.double)


 @register_test_case(module_factory=lambda: MultinomialModule())
 def MultinomialModule_basic(module, tu: TestUtils):
-    x = tu.rand(100).double()
+    x = generate_sample_distr([100], torch.float64, tu)
     module.forward(x)


-class MultinomialModule2D(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
+class MultinomialModule2DF32(MultinomialBase):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        # note: this should really call mean(-1)
+        # for some reason, doing this causes a torchscript numerics error?
+        return self._forward(x).mean(dtype=torch.double)
+
+
+@register_test_case(module_factory=lambda: MultinomialModule2DF32())
+def MultinomialModule2D_F32(module, tu: TestUtils):
+    x = generate_sample_distr([10, 100], torch.float32, tu)
+    module.forward(x)
+
+
+class MultinomialModule2D(MultinomialBase):
     @export
     @annotate_args(
         [
@@ -411,13 +437,14 @@ class MultinomialModule2D(torch.nn.Module):
         ]
     )
     def forward(self, x):
-        a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
-        return a.mean(dtype=torch.double)
+        # note: this should really call mean(-1)
+        # for some reason, doing this causes a torchscript numerics error?
+        return self._forward(x).mean(dtype=torch.double)


 @register_test_case(module_factory=lambda: MultinomialModule2D())
 def MultinomialModule2D_basic(module, tu: TestUtils):
-    x = tu.rand(10, 100).double()
+    x = generate_sample_distr([10, 100], torch.float64, tu)
     module.forward(x)
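These tests validate via a scalar statistic rather than raw samples: with 1024 * 1024 replacement draws, the double-precision mean of the sampled indices is stable enough for a fuzzy comparison against eager PyTorch. A small illustration of that intuition (the distribution values here are made up):

import torch

dist = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=torch.float64)
samples = torch.multinomial(dist, 1024 * 1024, replacement=True)
expected = (dist * torch.arange(4)).sum()  # E[index] = sum_i i * p_i = 2.0
print(samples.double().mean(), expected)   # the sample mean lands near 2.0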