[E2E][ONNX] torch.multinomial (#3404)

This PR adds a conversion in the TorchOnnxToTorch pass for the ONNX
Multinomial operation. It also adds a TorchToLinalg lowering for the
`aten.multinomial` op and lightly refactors the repeated code that
generates random floating-point numbers in
`TorchToLinalg/Random.cpp`.
Arham Khan 2024-07-16 12:39:39 -05:00 committed by GitHub
parent 0791a8860c
commit 574143448b
7 changed files with 526 additions and 15 deletions
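For reference, the new ONNX-to-Torch conversion amounts to the eager-PyTorch rewrite sketched below (illustrative only, not the MLIR pattern itself; `weights`, `sample_size`, and `out_dtype` are placeholder names):

import torch

def onnx_multinomial_as_torch(weights: torch.Tensor, sample_size: int = 1,
                              out_dtype: torch.dtype = torch.int32) -> torch.Tensor:
    # ONNX Multinomial always samples with replacement, so the conversion
    # hard-codes replacement=True and uses the default (global) generator.
    samples = torch.multinomial(weights, sample_size, replacement=True)
    # torch.multinomial returns int64; ONNX lets the caller pick the output
    # dtype (default int32, ONNX dtype code 6), hence the trailing
    # aten.to.dtype in the emitted pattern.
    return samples.to(out_dtype)

# A batch of 3 unnormalized 5-way distributions, 4 draws per row.
w = torch.rand(3, 5, dtype=torch.float64)
print(onnx_multinomial_as_torch(w, sample_size=4).shape)  # torch.Size([3, 4])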


@@ -591,6 +591,72 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp(
"Multinomial", 7,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value self;
int64_t onnxDtype, sampleSize;
if (binder.tensorOperand(self) ||
binder.s64IntegerAttr(onnxDtype, "dtype", 6) ||
binder.s64IntegerAttr(sampleSize, "sample_size", 1) ||
binder.tensorResultType(resultType)) {
return failure();
}
if (binder.op->hasAttr("torch.onnx.seed")) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: support not present for seed attribute");
}
if (sampleSize <= 0) {
return rewriter.notifyMatchFailure(binder.op,
"unsupported: sample_size <= 0");
}
std::optional<int64_t> torchDtype =
onnxDtypeIntToTorchDtypeInt(onnxDtype);
if (!torchDtype.has_value()) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
}
Value torchDtypeIntValue = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(torchDtype.value()));
Value numSamples = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(sampleSize));
// PRNG is seeded globally by default
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
// Sample with replacement by default (no onnx equivalent in arguments)
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getBoolAttr(true));
// Torch Multinomial always produces a LongTensor
Torch::ValueTensorType selfType =
cast<Torch::ValueTensorType>(self.getType());
Type int64Dtype =
IntegerType::get(selfType.getContext(), 64, IntegerType::Signed);
int64_t batchSize = selfType.getSizes()[0];
SmallVector<int64_t> outShapes({batchSize, sampleSize});
Torch::ValueTensorType multinomialOutputType =
Torch::ValueTensorType::get(selfType.getContext(), outShapes,
int64Dtype);
Value multinomialTensor = rewriter.create<Torch::AtenMultinomialOp>(
binder.getLoc(), multinomialOutputType, self, numSamples, cstTrue,
none);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getBoolAttr(false));
rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
binder.op, resultType, multinomialTensor, torchDtypeIntValue,
cstFalse, cstFalse, none);
return success();
});
patterns.onOp(
"NegativeLogLikelihoodLoss", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {


@@ -12,6 +12,7 @@
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
@@ -107,6 +108,25 @@ static Value randomUniformUInt(OpBuilder &b, Location loc, Value ctr,
return bitwiseXOr(t, shiftRight32(add(mul(x, x), y)));
}
// generate uniform random Float64
static Value randomUniformF64(OpBuilder &b, Location loc, Value ctr, Value key,
Value min, Value max) {
Value randomVal = randomUniformUInt(b, loc, ctr, key);
// scale = (max - min) * const(F64, 5.4210108E-20)
// which is derived from rand(min,max) =
// rand()/(RAND_MAX/(max-min)) where RAND_MAX = 2^64 - 1
Value epsilon = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(b.getF64Type(), 5.4210108E-20));
Value range = b.create<arith::SubFOp>(loc, max, min);
Value scale = b.create<arith::MulFOp>(loc, range, epsilon);
// res = cast(F64, randomVal) * scale + min
Value updateFloat = b.create<arith::UIToFPOp>(loc, b.getF64Type(), randomVal);
Value updateScaled = b.create<arith::MulFOp>(loc, updateFloat, scale);
Value uniformSample = b.create<arith::AddFOp>(loc, updateScaled, min);
return uniformSample;
}
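A rough Python model of what `randomUniformF64` computes from a 64-bit counter-based random value (a sketch with illustrative names, not the patch's code):

def uniform_f64_from_u64(rand_u64: int, lo: float, hi: float) -> float:
    # 5.4210108e-20 is approximately 1 / (2**64 - 1), so rand_u64 * epsilon
    # falls in [0, 1]; scaling by (hi - lo) and shifting by lo then yields a
    # value in [lo, hi].
    epsilon = 5.4210108e-20
    return rand_u64 * ((hi - lo) * epsilon) + lo

print(uniform_f64_from_u64(0, 0.0, 1.0))           # 0.0
print(uniform_f64_from_u64(2**64 - 1, 0.0, 1.0))   # ~1.0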
namespace {
class ConvertAtenUniformOp : public OpConversionPattern<AtenUniformOp> {
public:
@@ -162,22 +182,9 @@ public:
Value linearIndex =
toLinearIndex(b, loc, indicesIntValues, sizesIntValues);
Value randomVal = randomUniformUInt(b, loc, linearIndex, key);
// scale = (max - min) * const(F64, 5.4210108E-20)
// which is derived from rand(min,max) =
// rand()/(RAND_MAX/(max-min)) where RAND_MAX = 2^64 - 1
Value epsilon = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(min.getType(), 5.4210108E-20));
Value range = b.create<arith::SubFOp>(loc, max, min);
Value scale = b.create<arith::MulFOp>(loc, range, epsilon);
// res = cast(F64, tempN) * scale + min
Value updateFloat =
b.create<arith::UIToFPOp>(loc, f64Ty, randomVal);
Value updateScaled =
b.create<arith::MulFOp>(loc, updateFloat, scale);
Value res = b.create<arith::AddFOp>(loc, updateScaled, min);
Value res =
randomUniformF64(b, loc, linearIndex, key, min, max);
Value truncRes = res;
if (isa<Float16Type, Float32Type>(elemTy))
truncRes = b.create<arith::TruncFOp>(loc, elemTy, res);
@@ -192,6 +199,310 @@ public:
};
} // namespace
namespace {
class ConvertAtenMultinomialOp : public OpConversionPattern<AtenMultinomialOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenMultinomialOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
Value self = adaptor.getSelf();
Value numSamples = adaptor.getNumSamples();
Value generator = adaptor.getGenerator();
RankedTensorType selfType = cast<RankedTensorType>(self.getType());
Type elemTy = selfType.getElementType();
Type f64Ty = rewriter.getF64Type();
Type i64Ty = rewriter.getI64Type();
Type indexTy = rewriter.getIndexType();
int64_t inputRank = selfType.getRank();
bool bReplacement;
if (!isa<mlir::FloatType>(elemTy))
return rewriter.notifyMatchFailure(op, "This op only supports float type");
if (!mlir::isa<Torch::NoneType>(generator.getType()))
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");
if (!matchPattern(op.getReplacement(), m_TorchConstantBool(&bReplacement)))
return rewriter.notifyMatchFailure(
op, "Unsupported: replacement must be a boolean value");
if (!bReplacement)
return rewriter.notifyMatchFailure(op,
"Unimplemented: replacement = False");
if (!mlir::isa<mlir::IntegerType>(numSamples.getType())) {
return rewriter.notifyMatchFailure(
op, "Unsupported: num_samples must be an integer value");
}
if (!(inputRank == 1 || inputRank == 2)) {
return rewriter.notifyMatchFailure(
op, "torch.multinomial accepts only rank 1 or 2 tensors as weights");
}
Value cstZero = rewriter.create<arith::ConstantOp>(
loc, i64Ty, rewriter.getI64IntegerAttr(0));
Value cstOne = rewriter.create<arith::ConstantOp>(
loc, i64Ty, rewriter.getI64IntegerAttr(1));
Value zeroIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value oneIndex = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value numSamplesIndex =
rewriter.create<arith::IndexCastOp>(loc, indexTy, numSamples);
Value numDistributions;
Value numCategoriesIndex;
ValueRange resultShape;
if (inputRank == 1) {
numDistributions = cstOne;
numCategoriesIndex =
rewriter.create<tensor::DimOp>(loc, indexTy, self, zeroIndex);
resultShape = ValueRange{numSamplesIndex};
} else {
Value numDistIndex =
rewriter.create<tensor::DimOp>(loc, indexTy, self, zeroIndex);
numCategoriesIndex =
rewriter.create<tensor::DimOp>(loc, indexTy, self, oneIndex);
numDistributions =
rewriter.create<arith::IndexCastOp>(loc, i64Ty, numDistIndex);
resultShape = ValueRange{numDistIndex, numSamplesIndex};
}
Value numCategories =
rewriter.create<arith::IndexCastOp>(loc, i64Ty, numCategoriesIndex);
Value resultTensor = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(resultShape), i64Ty);
// sum weights for normalization
torch_to_linalg::ReductionOpInfo opInfo;
if (inputRank == 1)
opInfo = {false, self, {0}};
else
opInfo = {false, self, {1}};
Value initSum = rewriter.create<arith::ConstantOp>(
loc, f64Ty, rewriter.getF64FloatAttr(0.0));
auto sumBody = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Value input = payloadArgs[0];
Value result = payloadArgs[1];
Value nextSum = b.create<arith::AddFOp>(loc, input, result);
b.create<linalg::YieldOp>(loc, nextSum);
};
Value sumWeights = torch_to_linalg::createReductionLinalgGeneric(
rewriter, loc, opInfo, initSum, sumBody);
// Get multinomial samples for each weight vector
auto multinomialComputation = [&](OpBuilder &b, Location loc, Value j,
ValueRange args) {
Value jIndex = b.create<arith::IndexCastOp>(loc, indexTy, j);
Value sum;
if (inputRank == 1) {
sum = b.create<tensor::ExtractOp>(loc, sumWeights, ValueRange{});
} else {
sum = b.create<tensor::ExtractOp>(loc, sumWeights, ValueRange{jIndex});
}
// compute cdf in loop
Value initCdf = b.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), elemTy);
Value cdf =
b.create<scf::ForOp>(
loc, cstZero, numCategories, cstOne, ValueRange{initCdf},
[&](OpBuilder &b, Location loc, Value i, ValueRange vals) {
Value distribution = vals[0];
// if (i > 0)
auto comparisonPredicate = arith::CmpIPredicateAttr::get(
b.getContext(), arith::CmpIPredicate::sgt);
Value condition = b.create<arith::CmpIOp>(
loc, comparisonPredicate, i, cstZero);
Value iIndex = b.create<arith::IndexCastOp>(loc, indexTy, i);
// curr_cum = i > 0 ? prob[i] + prob[i-1] : prob[i]
ValueRange ind;
if (inputRank == 1) {
ind = ValueRange{iIndex};
} else {
ind = ValueRange{jIndex, iIndex};
}
Value currWeight = b.create<tensor::ExtractOp>(loc, self, ind);
Value currMass = b.create<arith::DivFOp>(loc, currWeight, sum);
Value currCum =
b.create<scf::IfOp>(
loc, condition,
[&](OpBuilder &b, Location loc) {
Value prevI =
b.create<arith::SubIOp>(loc, i, cstOne);
Value prevIndex = b.create<arith::IndexCastOp>(
loc, indexTy, prevI);
Value prevMass = b.create<tensor::ExtractOp>(
loc, distribution, ValueRange{prevIndex});
Value currSum = b.create<arith::AddFOp>(
loc, currMass, prevMass);
b.create<scf::YieldOp>(loc, ValueRange(currSum));
},
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(loc, ValueRange{currMass});
})
.getResult(0);
Value updatedCdf = b.create<tensor::InsertOp>(
loc, currCum, distribution, ValueRange(iIndex));
b.create<scf::YieldOp>(loc, ValueRange(updatedCdf));
})
.getResult(0);
/*
* Above we've computed the CDF of the normalized distribution derived from
* the weights the user gave us. To actually sample from this distribution
* we do the following below:
* 1) Sample a random floating point value, r in [0,1), from a uniform
*    distribution.
* 2) Perform a binary search in the CDF for the smallest index i such that
*    cdf[i] >= r. This yields a sample from the provided distribution with
*    the appropriate probabilities.
*
* This logic is pulled straight from PyTorch's Multinomial Kernel:
* https://github.com/pytorch/pytorch/blob/e4623de4cf6097ff399aa9eb0cef44b44ca76da4/aten/src/ATen/native/cpu/MultinomialKernel.cpp#L23
*/
// Get key, min and max used by RNG.
Value key = b.create<TorchConversion::GetNextSeedOp>(loc);
Value min = b.create<arith::ConstantOp>(loc, f64Ty,
rewriter.getF64FloatAttr(0.0));
Value max = b.create<arith::ConstantOp>(loc, f64Ty,
rewriter.getF64FloatAttr(1.0));
// iterate and sample class indices
Value result = args[0];
Value finalResult =
rewriter
.create<scf::ForOp>(
loc, cstZero, numSamples, cstOne, ValueRange{result},
[&](OpBuilder &b, Location loc, Value i, ValueRange args) {
// Sample random float
Value uniformSample =
randomUniformF64(b, loc, i, key, min, max);
// binary search in cdf to find our sample
Value left = b.create<arith::ConstantOp>(
loc, i64Ty, b.getI64IntegerAttr(0));
Value right = numCategories;
auto checkCondition = [&](OpBuilder &b, Location loc,
ValueRange vals) {
Value left = vals[0];
Value right = vals[1];
// while (right > left)
auto comparisonPredicate = arith::CmpIPredicateAttr::get(
b.getContext(), arith::CmpIPredicate::sgt);
Value loopCondition = b.create<arith::CmpIOp>(
loc, comparisonPredicate, right, left);
b.create<scf::ConditionOp>(loc, loopCondition, vals);
};
ValueRange whileResults =
b.create<scf::WhileOp>(
loc, TypeRange{i64Ty, i64Ty},
ValueRange{left, right}, checkCondition,
[&](OpBuilder &b, Location loc, ValueRange vals) {
Value left = vals[0];
Value right = vals[1];
Value two = b.create<arith::ConstantOp>(
loc, i64Ty, b.getI64IntegerAttr(2));
Value diff =
b.create<arith::SubIOp>(loc, right, left);
Value diffMid =
b.create<arith::DivSIOp>(loc, diff, two);
Value midPointer =
b.create<arith::AddIOp>(loc, left, diffMid);
Type indexTy = b.getIndexType();
Value midIndex = b.create<arith::IndexCastOp>(
loc, indexTy, midPointer);
// branch and update search indices
auto thenBlock = [&](OpBuilder &b,
Location loc) {
// left = mid + 1
Value newLeft = b.create<arith::AddIOp>(
loc, midPointer, cstOne);
b.create<scf::YieldOp>(
loc, ValueRange{newLeft, right});
};
auto elseBlock = [&](OpBuilder &b,
Location loc) {
// right = mid
b.create<scf::YieldOp>(
loc, ValueRange{left, midPointer});
};
Value cumProb = b.create<tensor::ExtractOp>(
loc, cdf, ValueRange{midIndex});
auto cmpPredicate =
arith::CmpFPredicateAttr::get(
b.getContext(),
arith::CmpFPredicate::OLT);
Value branchCondition = b.create<arith::CmpFOp>(
loc, cmpPredicate, cumProb, uniformSample);
ValueRange branchResults =
b.create<scf::IfOp>(loc, branchCondition,
thenBlock, elseBlock)
.getResults();
Value newLeft = branchResults[0];
Value newRight = branchResults[1];
b.create<scf::YieldOp>(
loc, ValueRange{newLeft, newRight});
})
.getResults();
// sample_idx = left_pointer
Value samplePointer = whileResults[0];
Value iIndex =
b.create<arith::IndexCastOp>(loc, indexTy, i);
Value prevResult = args[0];
Value newResult;
if (inputRank == 1) {
// result[i] = sample_idx
newResult = b.create<tensor::InsertOp>(
loc, samplePointer, prevResult, ValueRange{iIndex});
} else {
// result[j][i] = sample_idx
newResult = b.create<tensor::InsertOp>(
loc, samplePointer, prevResult,
ValueRange{jIndex, iIndex});
}
b.create<scf::YieldOp>(loc, ValueRange{newResult});
})
.getResult(0);
b.create<scf::YieldOp>(loc, ValueRange{finalResult});
};
Value finalResultTensor =
rewriter
.create<scf::ForOp>(loc, cstZero, numDistributions, cstOne,
ValueRange{resultTensor},
multinomialComputation)
.getResult(0);
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
finalResultTensor);
return success();
}
};
} // namespace
void mlir::torch::torch_to_linalg::populateRandomPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
@@ -200,4 +511,6 @@ void mlir::torch::torch_to_linalg::populateRandomPatternsAndLegality(
patterns.add<ConvertAtenDropoutOp>(typeConverter, context);
target.addIllegalOp<AtenUniformOp>();
patterns.add<ConvertAtenUniformOp>(typeConverter, context);
target.addIllegalOp<AtenMultinomialOp>();
patterns.add<ConvertAtenMultinomialOp>(typeConverter, context);
}
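The sampling scheme implemented by `ConvertAtenMultinomialOp` above (normalize the weights into a CDF, draw a uniform r in [0,1), binary-search for the smallest index whose CDF entry is >= r) can be sketched in plain Python for a single weight vector as follows (an illustration only; the real lowering expresses the same loops with scf.for/scf.while over tensors):

import random
from typing import List

def sample_category(weights: List[float]) -> int:
    # Build the CDF of the normalized distribution, mirroring the inner
    # scf.for loop: cdf[i] = cdf[i-1] + weights[i] / sum(weights).
    total = sum(weights)
    cdf, running = [], 0.0
    for w in weights:
        running += w / total
        cdf.append(running)
    r = random.random()  # uniform sample in [0, 1)
    # Binary search for the smallest index i with cdf[i] >= r, mirroring
    # the scf.while loop in the lowering.
    left, right = 0, len(cdf)
    while right > left:
        mid = left + (right - left) // 2
        if cdf[mid] < r:
            left = mid + 1
        else:
            right = mid
    return left

print(sample_category([0.1, 0.2, 0.7]))  # 0, 1, or 2; 2 about 70% of the time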


@@ -8874,6 +8874,40 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.bernoulli\"(%arg0: !torch.list<int>, %arg1: !torch.any) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.multinomial\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %true = torch.constant.bool true\n"
" %int1 = torch.constant.int 1\n"
" %int2 = torch.constant.int 2\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %7 = torch.aten.eq.int %6, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %7 : !torch.bool\n"
" }\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.list<int>) {\n"
" %6 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %6 : !torch.list<int>\n"
" } else {\n"
" %6 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %7 = torch.prim.ListConstruct %6, %arg1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %7 : !torch.list<int>\n"
" }\n"
" return %5 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.cumsum\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
@@ -11001,6 +11035,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.multinomial\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" return %int4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_not\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"


@@ -2725,6 +2725,8 @@ ONNX_XFAIL_SET = {
"ElementwiseUnaryIntModule_basic",
"ElementwiseFloatTensorGtIntTensorModule_basic",
"MaskedFillTensorFloatValueModule_basic",
"MultinomialModule_basic",
"MultinomialModule2D_basic",
"NativeDropoutTrainModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"ReduceAnyFloatModule_basic",


@@ -1356,6 +1356,17 @@ def aten〇_index_put_impl〡shape(self: List[int], indices: List[Optional[List[
def aten〇bernoulli〡shape(self: List[int], generator: Any = None) -> List[int]:
return self
@check_shape_function([
Invocation(TensorOfShape(5), num_samples=3), # Vector
Invocation(TensorOfShape(4, 5), num_samples=3), # Matrix
])
def aten〇multinomial〡shape(self: List[int], num_samples: int, replacement: bool = False, generator: Any = None) -> List[int]:
assert len(self) == 1 or len(self) == 2
if len(self) == 1:
return [num_samples]
num_rows = self[0]
return [num_rows, num_samples]
def aten〇cumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
return self
@@ -2574,6 +2585,10 @@ def aten〇bernoulli〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], p_rank_d
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function([Invocation(TensorOfShape(5, dtype=dtype), 3) for dtype in _SORTED_TORCH_TYPES])
def aten〇multinomial〡dtype(self_rank_dtype: Tuple[int, int], num_samples: int, replacement: bool = False, generator: Any = None) -> int:
return torch.int64
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇bitwise_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype


@@ -377,6 +377,53 @@ def BernoulliPModule_basic(module, tu: TestUtils):
# ==============================================================================
class MultinomialModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1], torch.float64, True),
]
)
def forward(self, x):
a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
return a.mean(dtype=torch.double)
@register_test_case(module_factory=lambda: MultinomialModule())
def MultinomialModule_basic(module, tu: TestUtils):
x = tu.rand(100).double()
module.forward(x)
class MultinomialModule2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1, -1], torch.float64, True),
]
)
def forward(self, x):
a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
return a.mean(dtype=torch.double)
@register_test_case(module_factory=lambda: MultinomialModule2D())
def MultinomialModule2D_basic(module, tu: TestUtils):
x = tu.rand(10, 100).double()
module.forward(x)
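For context, the 2D case treats each row of the weight matrix as an independent distribution, so the result has shape [num_rows, num_samples] with int64 entries; a quick eager-PyTorch check of that shape contract (not part of the test suite):

import torch

w = torch.rand(10, 100, dtype=torch.float64)
samples = torch.ops.aten.multinomial(w, 7, True)  # num_samples=7, replacement=True
print(samples.shape, samples.dtype)  # torch.Size([10, 7]) torch.int64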
# ==============================================================================
class RandLikeModule(torch.nn.Module):
def __init__(self):
super().__init__()


@@ -562,6 +562,36 @@ func.func @test_matmulinteger_batched(%arg0: !torch.vtensor<[7,4,3],ui8>, %arg1:
// -----
// CHECK-LABEL: func.func @test_multinomial_default
func.func @test_multinomial_default(%arg0: !torch.vtensor<[3,5],f64>) -> !torch.vtensor<[3, 1],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[VAL_1:.*]] = torch.constant.int 3
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.constant.none
// CHECK: %[[VAL_4:.*]] = torch.constant.bool true
// CHECK: %[[VAL_5:.*]] = torch.aten.multinomial %arg0, %[[VAL_2]], %[[VAL_4]], %[[VAL_3]] : !torch.vtensor<[3,5],f64>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[3,1],si64>
// CHECK: %[[VAL_6:.*]] = torch.constant.bool false
// CHECK: %[[VAL_7:.*]] = torch.aten.to.dtype %[[VAL_5]], %[[VAL_1]], %[[VAL_6]], %[[VAL_6]], %[[VAL_3]] : !torch.vtensor<[3,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,1],si32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,1],si32>
%0 = torch.operator "onnx.Multinomial"(%arg0) : (!torch.vtensor<[3,5],f64>) -> !torch.vtensor<[3,1],si32>
return %0 : !torch.vtensor<[3,1],si32>
}
// CHECK-LABEL: func.func @test_multinomial_dtype_double_samplenum_4
func.func @test_multinomial_dtype_double_samplenum_4(%arg0: !torch.vtensor<[3,5],f64>) -> !torch.vtensor<[3, 4],f64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[VAL_1:.*]] = torch.constant.int 7
// CHECK: %[[VAL_2:.*]] = torch.constant.int 4
// CHECK: %[[VAL_3:.*]] = torch.constant.none
// CHECK: %[[VAL_4:.*]] = torch.constant.bool true
// CHECK: %[[VAL_5:.*]] = torch.aten.multinomial %arg0, %[[VAL_2]], %[[VAL_4]], %[[VAL_3]] : !torch.vtensor<[3,5],f64>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[3,4],si64>
// CHECK: %[[VAL_6:.*]] = torch.constant.bool false
// CHECK: %[[VAL_7:.*]] = torch.aten.to.dtype %[[VAL_5]], %[[VAL_1]], %[[VAL_6]], %[[VAL_6]], %[[VAL_3]] : !torch.vtensor<[3,4],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f64>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f64>
%0 = torch.operator "onnx.Multinomial"(%arg0) {torch.onnx.dtype = 11 : si64, torch.onnx.sample_size = 4 : si64} : (!torch.vtensor<[3,5],f64>) -> !torch.vtensor<[3,4],f64>
return %0 : !torch.vtensor<[3,4],f64>
}
// -----
// CHECK-LABEL: func.func @test_maxpool_2d_default
func.func @test_maxpool_2d_default(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
// CHECK: %[[I2:.*]] = torch.constant.int 2