[TorchToLinalg] address a dtype mismatch in `aten.multinomial` lowering (#3630)

Resolves <https://github.com/llvm/torch-mlir/issues/3628>
Unblocks compilation of one of the MIGraphX models (`AgentModel`).
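
The root cause: in the lowering, the running weight sum is already computed in `f64`, but the weights and the CDF buffer used the input element type, so any non-`f64` input produced mixed-dtype arithmetic in the CDF loop. This patch extends (or truncates) the weights to `f64` at each use and allocates the CDF buffer as `f64`.

A hedged repro sketch in plain PyTorch (eager mode runs fine; the mismatch only surfaced when lowering a graph like this through TorchToLinalg):

```python
import torch

# float32 weights: the lowering previously mixed these with f64 values
probs = torch.rand(10, 100, dtype=torch.float32)
probs = probs / probs.sum(-1, keepdim=True)
samples = torch.ops.aten.multinomial(probs, 1024, replacement=True)
print(samples.shape)  # torch.Size([10, 1024])
```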
zjgarvey 2024-08-20 13:14:48 -07:00 committed by GitHub
parent f72770a725
commit f66908f190
3 changed files with 57 additions and 15 deletions


```diff
@@ -287,8 +287,16 @@ public:
     Value initSum = rewriter.create<arith::ConstantOp>(
         loc, f64Ty, rewriter.getF64FloatAttr(0.0));
+    int64_t srcWidth = cast<mlir::FloatType>(elemTy).getWidth();
+    if (srcWidth > 64)
+      op->emitWarning("Op bitwidth will be truncated from " +
+                      std::to_string(srcWidth) + " bits to 64 bits.");
     auto sumBody = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
       Value input = payloadArgs[0];
+      if (srcWidth < 64)
+        input = b.create<arith::ExtFOp>(loc, f64Ty, input);
+      if (srcWidth > 64)
+        input = b.create<arith::TruncFOp>(loc, f64Ty, input);
       Value result = payloadArgs[1];
       Value nextSum = b.create<arith::AddFOp>(loc, input, result);
       b.create<linalg::YieldOp>(loc, nextSum);
@@ -310,7 +318,7 @@ public:
       // compute cdf in loop
       Value initCdf = b.create<tensor::EmptyOp>(
-          loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), elemTy);
+          loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), f64Ty);
       Value cdf =
           b.create<scf::ForOp>(
                loc, cstZero, numCategories, cstOne, ValueRange{initCdf},
@@ -330,6 +338,11 @@ public:
                 ind = ValueRange{jIndex, iIndex};
               }
               Value currWeight = b.create<tensor::ExtractOp>(loc, self, ind);
+              if (srcWidth < 64)
+                currWeight = b.create<arith::ExtFOp>(loc, f64Ty, currWeight);
+              if (srcWidth > 64)
+                currWeight =
+                    b.create<arith::TruncFOp>(loc, f64Ty, currWeight);
               Value currMass = b.create<arith::DivFOp>(loc, currWeight, sum);
               Value currCum =
                   b.create<scf::IfOp>(
```
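
With these changes, all of the sampling arithmetic runs in `f64` regardless of the input width: narrower floats are extended with `arith.extf`, wider ones truncated with `arith.truncf` (after a warning, since that direction is lossy). A Python sketch of the equivalent inverse-transform sampling, assuming a 1-D weights vector (the helper name is mine, not from the patch):

```python
import torch

def sample_multinomial_f64(weights: torch.Tensor, num_samples: int) -> torch.Tensor:
    # Widen to f64 up front, as the patched lowering does before
    # both the sum reduction and the CDF loop.
    w = weights.to(torch.float64)
    cdf = torch.cumsum(w / w.sum(), dim=0)  # CDF buffer is f64, not elemTy
    u = torch.rand(num_samples, dtype=torch.float64)
    # first category whose cumulative mass exceeds the uniform draw
    return torch.searchsorted(cdf, u, right=True)
```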


```diff
@@ -2318,6 +2318,8 @@ ONNX_XFAIL_SET = {
     "ElementwiseLog2IntModule_basic",
     "ElementwiseFminModule_basic",
     "ElementwiseFmaxModule_basic",
+    "MultinomialModule2D_basic",
+    "MultinomialModule2D_F32",
     "PixelShuffleModuleStaticRank4Float32_basic",
     "ReflectionPad1dModule2dInput_Right",
     "ReflectionPad1dModule2dInput_basic",
@@ -2346,6 +2348,8 @@ ONNX_XFAIL_SET = {
     "MoveDimIntNegativeIndexModule_basic",
     "ReduceL3NormKeepDimModule_basic",
     "ViewSizeFromOtherTensor_basic",
+    # incorrect shape generated by torch.onnx.export (needs an unsqueeze)
+    "MultinomialModule_basic",
     # Failure - onnx_export
     "AdaptiveAvgPool1dGeneralDynamic_basic",
     "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
@@ -2849,8 +2853,6 @@ ONNX_XFAIL_SET = {
     "ElementwiseUnaryIntModule_basic",
     "ElementwiseFloatTensorGtIntTensorModule_basic",
     "MaskedFillTensorFloatValueModule_basic",
-    "MultinomialModule_basic",
-    "MultinomialModule2D_basic",
     "NativeDropoutTrainModule_basic",
     "NativeDropoutTrainStaticShapeModule_basic",
     "ReduceAnyFloatModule_basic",
```


```diff
@@ -377,10 +377,20 @@ def BernoulliPModule_basic(module, tu: TestUtils):
 # ==============================================================================


-class MultinomialModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
+def generate_sample_distr(sizes: list[int], torchdtype, tu: TestUtils):
+    assert len(sizes) == 1 or len(sizes) == 2
+    init = tu.rand(*sizes).to(dtype=torchdtype).abs()
+    normalized = init / (init.sum(-1, True, dtype=torchdtype))
+    return normalized
+
+
+class MultinomialBase(torch.nn.Module):
+    def _forward(self, x):
+        a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
+        return a
+
+
+class MultinomialModule(MultinomialBase):
     @export
     @annotate_args(
         [
@@ -389,20 +399,36 @@ class MultinomialModule(torch.nn.Module):
         ]
     )
     def forward(self, x):
-        a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
-        return a.mean(dtype=torch.double)
+        return self._forward(x).mean(dtype=torch.double)


 @register_test_case(module_factory=lambda: MultinomialModule())
 def MultinomialModule_basic(module, tu: TestUtils):
-    x = tu.rand(100).double()
+    x = generate_sample_distr([100], torch.float64, tu)
     module.forward(x)


-class MultinomialModule2D(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
+class MultinomialModule2DF32(MultinomialBase):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        # note: this should really call mean(-1)
+        # for some reason, doing this causes a torchscript numerics error?
+        return self._forward(x).mean(dtype=torch.double)
+
+
+@register_test_case(module_factory=lambda: MultinomialModule2DF32())
+def MultinomialModule2D_F32(module, tu: TestUtils):
+    x = generate_sample_distr([10, 100], torch.float32, tu)
+    module.forward(x)
+
+
+class MultinomialModule2D(MultinomialBase):
     @export
     @annotate_args(
         [
@@ -411,13 +437,14 @@ class MultinomialModule2D(torch.nn.Module):
         ]
     )
     def forward(self, x):
-        a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
-        return a.mean(dtype=torch.double)
+        # note: this should really call mean(-1)
+        # for some reason, doing this causes a torchscript numerics error?
+        return self._forward(x).mean(dtype=torch.double)


 @register_test_case(module_factory=lambda: MultinomialModule2D())
 def MultinomialModule2D_basic(module, tu: TestUtils):
-    x = tu.rand(10, 100).double()
+    x = generate_sample_distr([10, 100], torch.float64, tu)
     module.forward(x)
```
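
The new `generate_sample_distr` helper normalizes `tu.rand` output so the tests feed `aten.multinomial` a valid probability distribution in the requested dtype. A standalone check of that invariant, substituting `torch.rand` for `tu.rand` (the `_standalone` name is hypothetical, for illustration only):

```python
import torch

def generate_sample_distr_standalone(sizes, torchdtype):
    init = torch.rand(*sizes).to(dtype=torchdtype).abs()
    return init / init.sum(-1, True, dtype=torchdtype)

x = generate_sample_distr_standalone([10, 100], torch.float32)
# every row sums to 1: a valid distribution over the last dimension
assert torch.allclose(x.sum(-1), torch.ones(10), atol=1e-5)
```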