mirror of https://github.com/llvm/torch-mlir
[E2E][ONNX] torch.multinomial (#3404)
This PR adds a conversion in the TorchOnnxToTorch pass for the ONNX Multinomial operation. It also adds a TorchToLinalg lowering for the `aten.multinomial` op and lightly refactors the repeated code that generates random floating-point numbers in `TorchToLinalg/Random.cpp`.
parent 0791a8860c, commit 574143448b
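At the torch level, the new conversion maps `onnx.Multinomial` to `torch.aten.multinomial` (sampling with replacement, global default generator) followed by `torch.aten.to.dtype` to honor the ONNX `dtype` attribute. A rough eager-mode sketch of that op sequence in Python (illustrative only; the helper name and defaults below are assumptions, not part of the patch):

import torch

def onnx_multinomial_reference(weights: torch.Tensor,
                               sample_size: int = 1,
                               out_dtype: torch.dtype = torch.int32) -> torch.Tensor:
    # torch.multinomial always yields int64 indices; the conversion adds an
    # explicit dtype cast to match the ONNX-requested output type.
    samples = torch.multinomial(weights, sample_size, replacement=True)
    return samples.to(out_dtype)

# e.g. a [3, 5] batch of unnormalized class weights, 4 samples per row
w = torch.rand(3, 5, dtype=torch.float64)
print(onnx_multinomial_reference(w, sample_size=4).shape)  # torch.Size([3, 4])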
@@ -591,6 +591,72 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
            binder.op, resultType, lhs, rhs);
        return success();
      });
  patterns.onOp(
      "Multinomial", 7,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value self;
        int64_t onnxDtype, sampleSize;

        if (binder.tensorOperand(self) ||
            binder.s64IntegerAttr(onnxDtype, "dtype", 6) ||
            binder.s64IntegerAttr(sampleSize, "sample_size", 1) ||
            binder.tensorResultType(resultType)) {
          return failure();
        }

        if (binder.op->hasAttr("torch.onnx.seed")) {
          return rewriter.notifyMatchFailure(
              binder.op,
              "unimplemented: support not present for seed attribute");
        }

        if (sampleSize <= 0) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "unsupported: sample_size <= 0");
        }

        std::optional<int64_t> torchDtype =
            onnxDtypeIntToTorchDtypeInt(onnxDtype);
        if (!torchDtype.has_value()) {
          return rewriter.notifyMatchFailure(
              binder.op,
              "unimplemented support for the given dtype conversion");
        }

        Value torchDtypeIntValue = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(torchDtype.value()));
        Value numSamples = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(sampleSize));

        // PRG is seeded globally by default
        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        // Sample with replacement by default (no onnx equivalent in arguments)
        Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(
            binder.getLoc(), rewriter.getBoolAttr(true));

        // Torch Multinomial always produces a LongTensor
        Torch::ValueTensorType selfType =
            cast<Torch::ValueTensorType>(self.getType());
        Type int64Dtype =
            IntegerType::get(selfType.getContext(), 64, IntegerType::Signed);
        int64_t batchSize = selfType.getSizes()[0];
        SmallVector<int64_t> outShapes({batchSize, sampleSize});
        Torch::ValueTensorType multinomialOutputType =
            Torch::ValueTensorType::get(selfType.getContext(), outShapes,
                                        int64Dtype);
        Value multinomialTensor = rewriter.create<Torch::AtenMultinomialOp>(
            binder.getLoc(), multinomialOutputType, self, numSamples, cstTrue,
            none);

        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
            binder.getLoc(), rewriter.getBoolAttr(false));
        rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
            binder.op, resultType, multinomialTensor, torchDtypeIntValue,
            cstFalse, cstFalse, none);

        return success();
      });
  patterns.onOp(
      "NegativeLogLikelihoodLoss", 13,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
@@ -12,6 +12,7 @@
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
@@ -107,6 +108,25 @@ static Value randomUniformUInt(OpBuilder &b, Location loc, Value ctr,
  return bitwiseXOr(t, shiftRight32(add(mul(x, x), y)));
}

// generate uniform random Float64
static Value randomUniformF64(OpBuilder &b, Location loc, Value ctr, Value key,
                              Value min, Value max) {
  Value randomVal = randomUniformUInt(b, loc, ctr, key);
  // scale = (max - min) * const(F64, 5.4210108E-20)
  // which is derived from rand(min,max) =
  // rand()/(RAND_MAX/(max-min)) where RAND_MAX = 2^64 - 1
  Value epsilon = b.create<arith::ConstantOp>(
      loc, b.getFloatAttr(b.getF64Type(), 5.4210108E-20));
  Value range = b.create<arith::SubFOp>(loc, max, min);
  Value scale = b.create<arith::MulFOp>(loc, range, epsilon);
  // res = cast(F64, tempN) * scale + min
  Value updateFloat = b.create<arith::UIToFPOp>(loc, b.getF64Type(), randomVal);
  Value updateScaled = b.create<arith::MulFOp>(loc, updateFloat, scale);
  Value uniformSample = b.create<arith::AddFOp>(loc, updateScaled, min);

  return uniformSample;
}

namespace {
class ConvertAtenUniformOp : public OpConversionPattern<AtenUniformOp> {
public:
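The constant 5.4210108E-20 used above is approximately 1 / (2^64 - 1), so multiplying an unsigned 64-bit counter value by (max - min) * epsilon and adding min maps it into [min, max]. A minimal Python sketch of that scaling (illustrative only; the function name below is an assumption for the example, not code from the patch):

def uniform_from_u64(rand_u64: int, lo: float = 0.0, hi: float = 1.0) -> float:
    # Same mapping as randomUniformF64: epsilon ~= 1 / (2**64 - 1).
    epsilon = 5.4210108e-20
    return float(rand_u64) * (hi - lo) * epsilon + lo

assert abs(5.4210108e-20 * (2**64 - 1) - 1.0) < 1e-6
assert 0.0 <= uniform_from_u64(0) <= uniform_from_u64(2**64 - 1) <= 1.0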
@@ -162,22 +182,9 @@ public:

          Value linearIndex =
              toLinearIndex(b, loc, indicesIntValues, sizesIntValues);
-         Value randomVal = randomUniformUInt(b, loc, linearIndex, key);
-
-         // scale = (max - min) * const(F64, 5.4210108E-20)
-         // which is derived from rand(min,max) =
-         // rand()/(RAND_MAX/(max-min)) where RAND_MAX = 2^64 - 1
-         Value epsilon = b.create<arith::ConstantOp>(
-             loc, b.getFloatAttr(min.getType(), 5.4210108E-20));
-         Value range = b.create<arith::SubFOp>(loc, max, min);
-         Value scale = b.create<arith::MulFOp>(loc, range, epsilon);
-
-         // res = cast(F64, tempN) * scale + min
-         Value updateFloat =
-             b.create<arith::UIToFPOp>(loc, f64Ty, randomVal);
-         Value updateScaled =
-             b.create<arith::MulFOp>(loc, updateFloat, scale);
-         Value res = b.create<arith::AddFOp>(loc, updateScaled, min);
+         Value res =
+             randomUniformF64(b, loc, linearIndex, key, min, max);
          Value truncRes = res;
          if (isa<Float16Type, Float32Type>(elemTy))
            truncRes = b.create<arith::TruncFOp>(loc, elemTy, res);
@@ -192,6 +199,310 @@ public:
};
} // namespace

namespace {
class ConvertAtenMultinomialOp : public OpConversionPattern<AtenMultinomialOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenMultinomialOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op.getLoc();
    Value self = adaptor.getSelf();
    Value numSamples = adaptor.getNumSamples();
    Value generator = adaptor.getGenerator();
    RankedTensorType selfType = cast<RankedTensorType>(self.getType());
    Type elemTy = selfType.getElementType();
    Type f64Ty = rewriter.getF64Type();
    Type i64Ty = rewriter.getI64Type();
    Type indexTy = rewriter.getIndexType();
    int64_t inputRank = selfType.getRank();
    bool bReplacement;

    if (!isa<mlir::FloatType>(elemTy))
      return rewriter.notifyMatchFailure(op, "This op only support float type");

    if (!mlir::isa<Torch::NoneType>(generator.getType()))
      return rewriter.notifyMatchFailure(
          op, "The generator has to be None because only global default "
              "generator is supported");

    if (!matchPattern(op.getReplacement(), m_TorchConstantBool(&bReplacement)))
      return rewriter.notifyMatchFailure(
          op, "Unsupported: replacement must be a boolean value");

    if (!bReplacement)
      return rewriter.notifyMatchFailure(op,
                                         "Unimplemented: replacement = False");

    if (!mlir::isa<mlir::IntegerType>(numSamples.getType())) {
      return rewriter.notifyMatchFailure(
          op, "Unsupported: num_samples must be an integer value");
    }

    if (!(inputRank == 1 || inputRank == 2)) {
      return rewriter.notifyMatchFailure(
          op, "torch.multinomial accepts only rank 1 or 2 tensors as weights");
    }

    Value cstZero = rewriter.create<arith::ConstantOp>(
        loc, i64Ty, rewriter.getI64IntegerAttr(0));
    Value cstOne = rewriter.create<arith::ConstantOp>(
        loc, i64Ty, rewriter.getI64IntegerAttr(1));
    Value zeroIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value oneIndex = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    Value numSamplesIndex =
        rewriter.create<arith::IndexCastOp>(loc, indexTy, numSamples);

    Value numDistributions;
    Value numCategoriesIndex;
    ValueRange resultShape;
    if (inputRank == 1) {
      numDistributions = cstOne;
      numCategoriesIndex =
          rewriter.create<tensor::DimOp>(loc, indexTy, self, zeroIndex);
      resultShape = ValueRange{numSamplesIndex};
    } else {
      Value numDistIndex =
          rewriter.create<tensor::DimOp>(loc, indexTy, self, zeroIndex);
      numCategoriesIndex =
          rewriter.create<tensor::DimOp>(loc, indexTy, self, oneIndex);
      numDistributions =
          rewriter.create<arith::IndexCastOp>(loc, i64Ty, numDistIndex);
      resultShape = ValueRange{numDistIndex, numSamplesIndex};
    }

    Value numCategories =
        rewriter.create<arith::IndexCastOp>(loc, i64Ty, numCategoriesIndex);
    Value resultTensor = rewriter.create<tensor::EmptyOp>(
        loc, getAsOpFoldResult(resultShape), i64Ty);

    // sum weights for normalization
    torch_to_linalg::ReductionOpInfo opInfo;
    if (inputRank == 1)
      opInfo = {false, self, {0}};
    else
      opInfo = {false, self, {1}};

    Value initSum = rewriter.create<arith::ConstantOp>(
        loc, f64Ty, rewriter.getF64FloatAttr(0.0));
    auto sumBody = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
      Value input = payloadArgs[0];
      Value result = payloadArgs[1];
      Value nextSum = b.create<arith::AddFOp>(loc, input, result);
      b.create<linalg::YieldOp>(loc, nextSum);
    };
    Value sumWeights = torch_to_linalg::createReductionLinalgGeneric(
        rewriter, loc, opInfo, initSum, sumBody);

    // Get multinomial samples for each weight vector
    auto multinomialComputation = [&](OpBuilder &b, Location loc, Value j,
                                      ValueRange args) {
      Value jIndex = b.create<arith::IndexCastOp>(loc, indexTy, j);

      Value sum;
      if (inputRank == 1) {
        sum = b.create<tensor::ExtractOp>(loc, sumWeights, ValueRange{});
      } else {
        sum = b.create<tensor::ExtractOp>(loc, sumWeights, ValueRange{jIndex});
      }

      // compute cdf in loop
      Value initCdf = b.create<tensor::EmptyOp>(
          loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), elemTy);
      Value cdf =
          b.create<scf::ForOp>(
               loc, cstZero, numCategories, cstOne, ValueRange{initCdf},
               [&](OpBuilder &b, Location loc, Value i, ValueRange vals) {
                 Value distribution = vals[0];
                 // if (i > 0)
                 auto comparisonPredicate = arith::CmpIPredicateAttr::get(
                     b.getContext(), arith::CmpIPredicate::sgt);
                 Value condition = b.create<arith::CmpIOp>(
                     loc, comparisonPredicate, i, cstZero);
                 Value iIndex = b.create<arith::IndexCastOp>(loc, indexTy, i);
                 // curr_cum = i > 0 ? prob[i] + prob[i-1] : prob[i]
                 ValueRange ind;
                 if (inputRank == 1) {
                   ind = ValueRange{iIndex};
                 } else {
                   ind = ValueRange{jIndex, iIndex};
                 }
                 Value currWeight = b.create<tensor::ExtractOp>(loc, self, ind);
                 Value currMass = b.create<arith::DivFOp>(loc, currWeight, sum);
                 Value currCum =
                     b.create<scf::IfOp>(
                          loc, condition,
                          [&](OpBuilder &b, Location loc) {
                            Value prevI =
                                b.create<arith::SubIOp>(loc, i, cstOne);
                            Value prevIndex = b.create<arith::IndexCastOp>(
                                loc, indexTy, prevI);
                            Value prevMass = b.create<tensor::ExtractOp>(
                                loc, distribution, ValueRange{prevIndex});
                            Value currSum = b.create<arith::AddFOp>(
                                loc, currMass, prevMass);
                            b.create<scf::YieldOp>(loc, ValueRange(currSum));
                          },
                          [&](OpBuilder &b, Location loc) {
                            b.create<scf::YieldOp>(loc, ValueRange{currMass});
                          })
                         .getResult(0);

                 Value updatedCdf = b.create<tensor::InsertOp>(
                     loc, currCum, distribution, ValueRange(iIndex));
                 b.create<scf::YieldOp>(loc, ValueRange(updatedCdf));
               })
              .getResult(0);

      /*
       * Above we've computed the CDF for the unnormalized distribution given to
       * us by the user. In order to actually sample from this distribution we
       * do the following below: 1) Sample a random floating point value, r in
       * [0,1), from a uniform distribution. 2) Perform a binary search in the
       * cdf to find the first bin in the CDF where cdf[i] < r. This guarantees
       * a random sample from the provided distribution with the appropriate
       * probabilities.
       *
       * This logic is pulled straight from PyTorch's Multinomial Kernel:
       * https://github.com/pytorch/pytorch/blob/e4623de4cf6097ff399aa9eb0cef44b44ca76da4/aten/src/ATen/native/cpu/MultinomialKernel.cpp#L23
       */

      // Get key, min and max used by RNG.
      Value key = b.create<TorchConversion::GetNextSeedOp>(loc);
      Value min = b.create<arith::ConstantOp>(loc, f64Ty,
                                              rewriter.getF64FloatAttr(0.0));
      Value max = b.create<arith::ConstantOp>(loc, f64Ty,
                                              rewriter.getF64FloatAttr(1.0));

      // iterate and sample class indices
      Value result = args[0];
      Value finalResult =
          rewriter
              .create<scf::ForOp>(
                  loc, cstZero, numSamples, cstOne, ValueRange{result},
                  [&](OpBuilder &b, Location loc, Value i, ValueRange args) {
                    // Sample random float
                    Value uniformSample =
                        randomUniformF64(b, loc, i, key, min, max);

                    // binary search in cdf to find our sample
                    Value left = b.create<arith::ConstantOp>(
                        loc, i64Ty, b.getI64IntegerAttr(0));
                    Value right = numCategories;

                    auto checkCondition = [&](OpBuilder &b, Location loc,
                                              ValueRange vals) {
                      Value left = vals[0];
                      Value right = vals[1];

                      // while (right > left)
                      auto comparisonPredicate = arith::CmpIPredicateAttr::get(
                          b.getContext(), arith::CmpIPredicate::sgt);
                      Value loopCondition = b.create<arith::CmpIOp>(
                          loc, comparisonPredicate, right, left);
                      b.create<scf::ConditionOp>(loc, loopCondition, vals);
                    };

                    ValueRange whileResults =
                        b.create<scf::WhileOp>(
                             loc, TypeRange{i64Ty, i64Ty},
                             ValueRange{left, right}, checkCondition,
                             [&](OpBuilder &b, Location loc, ValueRange vals) {
                               Value left = vals[0];
                               Value right = vals[1];

                               Value two = b.create<arith::ConstantOp>(
                                   loc, i64Ty, b.getI64IntegerAttr(2));
                               Value diff =
                                   b.create<arith::SubIOp>(loc, right, left);
                               Value diffMid =
                                   b.create<arith::DivSIOp>(loc, diff, two);
                               Value midPointer =
                                   b.create<arith::AddIOp>(loc, left, diffMid);
                               Type indexTy = b.getIndexType();
                               Value midIndex = b.create<arith::IndexCastOp>(
                                   loc, indexTy, midPointer);

                               // branch and update search indices
                               auto thenBlock = [&](OpBuilder &b,
                                                    Location loc) {
                                 // left = mid + 1
                                 Value newLeft = b.create<arith::AddIOp>(
                                     loc, midPointer, cstOne);

                                 b.create<scf::YieldOp>(
                                     loc, ValueRange{newLeft, right});
                               };
                               auto elseBlock = [&](OpBuilder &b,
                                                    Location loc) {
                                 // right = mid
                                 b.create<scf::YieldOp>(
                                     loc, ValueRange{left, midPointer});
                               };

                               Value cumProb = b.create<tensor::ExtractOp>(
                                   loc, cdf, ValueRange{midIndex});
                               auto cmpPredicate =
                                   arith::CmpFPredicateAttr::get(
                                       b.getContext(),
                                       arith::CmpFPredicate::OLT);
                               Value branchCondition = b.create<arith::CmpFOp>(
                                   loc, cmpPredicate, cumProb, uniformSample);
                               ValueRange branchResults =
                                   b.create<scf::IfOp>(loc, branchCondition,
                                                       thenBlock, elseBlock)
                                       .getResults();
                               Value newLeft = branchResults[0];
                               Value newRight = branchResults[1];

                               b.create<scf::YieldOp>(
                                   loc, ValueRange{newLeft, newRight});
                             })
                            .getResults();

                    // sample_idx = left_pointer
                    Value samplePointer = whileResults[0];
                    Value iIndex =
                        b.create<arith::IndexCastOp>(loc, indexTy, i);

                    Value prevResult = args[0];
                    Value newResult;
                    if (inputRank == 1) {
                      // result[i] = sample_idx
                      newResult = b.create<tensor::InsertOp>(
                          loc, samplePointer, prevResult, ValueRange{iIndex});
                    } else {
                      // result[j][i] = sample_idx
                      newResult = b.create<tensor::InsertOp>(
                          loc, samplePointer, prevResult,
                          ValueRange{jIndex, iIndex});
                    }

                    b.create<scf::YieldOp>(loc, ValueRange{newResult});
                  })
              .getResult(0);

      b.create<scf::YieldOp>(loc, ValueRange{finalResult});
    };

    Value finalResultTensor =
        rewriter
            .create<scf::ForOp>(loc, cstZero, numDistributions, cstOne,
                                ValueRange{resultTensor},
                                multinomialComputation)
            .getResult(0);

    Type newResultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
                                                finalResultTensor);

    return success();
  }
};
} // namespace

void mlir::torch::torch_to_linalg::populateRandomPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
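The comment block in the lowering above summarizes the sampling scheme taken from PyTorch's multinomial kernel: accumulate the normalized weights into a CDF, draw r uniformly from [0, 1), and binary-search the CDF for the first bucket whose cumulative mass reaches r. A compact Python sketch of that scheme (an illustrative reference only, not the generated IR; the loop-carried SSA values in the lowering play the role of the local variables here):

import random
from typing import List

def sample_multinomial(weights: List[float], num_samples: int) -> List[int]:
    # Normalize the weights and accumulate them into a CDF.
    total = sum(weights)
    cdf, running = [], 0.0
    for w in weights:
        running += w / total
        cdf.append(running)

    samples = []
    for _ in range(num_samples):
        r = random.random()  # uniform in [0, 1)
        # Binary search: smallest index whose cumulative mass is >= r.
        left, right = 0, len(cdf)
        while right > left:
            mid = left + (right - left) // 2
            if cdf[mid] < r:
                left = mid + 1
            else:
                right = mid
        samples.append(left)
    return samples

print(sample_multinomial([0.1, 0.2, 0.7], num_samples=5))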
@@ -200,4 +511,6 @@ void mlir::torch::torch_to_linalg::populateRandomPatternsAndLegality(
  patterns.add<ConvertAtenDropoutOp>(typeConverter, context);
  target.addIllegalOp<AtenUniformOp>();
  patterns.add<ConvertAtenUniformOp>(typeConverter, context);
  target.addIllegalOp<AtenMultinomialOp>();
  patterns.add<ConvertAtenMultinomialOp>(typeConverter, context);
}
@@ -8874,6 +8874,40 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.bernoulli\"(%arg0: !torch.list<int>, %arg1: !torch.any) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.multinomial\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %true = torch.constant.bool true\n"
" %int1 = torch.constant.int 1\n"
" %int2 = torch.constant.int 2\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %7 = torch.aten.eq.int %6, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %7 : !torch.bool\n"
" }\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.list<int>) {\n"
" %6 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %6 : !torch.list<int>\n"
" } else {\n"
" %6 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %7 = torch.prim.ListConstruct %6, %arg1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %7 : !torch.list<int>\n"
" }\n"
" return %5 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.cumsum\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
@@ -11001,6 +11035,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.multinomial\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" return %int4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_not\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
@@ -2725,6 +2725,8 @@ ONNX_XFAIL_SET = {
    "ElementwiseUnaryIntModule_basic",
    "ElementwiseFloatTensorGtIntTensorModule_basic",
    "MaskedFillTensorFloatValueModule_basic",
    "MultinomialModule_basic",
    "MultinomialModule2D_basic",
    "NativeDropoutTrainModule_basic",
    "NativeDropoutTrainStaticShapeModule_basic",
    "ReduceAnyFloatModule_basic",
@@ -1356,6 +1356,17 @@ def aten〇_index_put_impl〡shape(self: List[int], indices: List[Optional[List[
def aten〇bernoulli〡shape(self: List[int], generator: Any = None) -> List[int]:
    return self

@check_shape_function([
    Invocation(TensorOfShape(5), num_samples=3), # Vector
    Invocation(TensorOfShape(4, 5), num_samples=3), # Matrix
])
def aten〇multinomial〡shape(self: List[int], num_samples: int, replacement: bool = False, generator: Any = None) -> List[int]:
    assert len(self) == 1 or len(self) == 2
    if len(self) == 1:
        return [num_samples]
    num_rows = self[0]
    return [num_rows, num_samples]

def aten〇cumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
    return self
@@ -2574,6 +2585,10 @@ def aten〇bernoulli〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], p_rank_d
    self_rank, self_dtype = self_rank_dtype
    return self_dtype

@check_dtype_function([Invocation(TensorOfShape(5, dtype=dtype), 3) for dtype in _SORTED_TORCH_TYPES])
def aten〇multinomial〡dtype(self_rank_dtype: Tuple[int, int], num_samples: int, replacement: bool = False, generator: Any = None) -> int:
    return torch.int64

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇bitwise_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
@@ -377,6 +377,53 @@ def BernoulliPModule_basic(module, tu: TestUtils):
# ==============================================================================


class MultinomialModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.float64, True),
        ]
    )
    def forward(self, x):
        a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
        return a.mean(dtype=torch.double)


@register_test_case(module_factory=lambda: MultinomialModule())
def MultinomialModule_basic(module, tu: TestUtils):
    x = tu.rand(100).double()
    module.forward(x)


class MultinomialModule2D(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float64, True),
        ]
    )
    def forward(self, x):
        a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
        return a.mean(dtype=torch.double)


@register_test_case(module_factory=lambda: MultinomialModule2D())
def MultinomialModule2D_basic(module, tu: TestUtils):
    x = tu.rand(10, 100).double()
    module.forward(x)


# ==============================================================================


class RandLikeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -562,6 +562,36 @@ func.func @test_matmulinteger_batched(%arg0: !torch.vtensor<[7,4,3],ui8>, %arg1:

// -----

// CHECK-LABEL: func.func @test_multinomial_default
func.func @test_multinomial_default(%arg0: !torch.vtensor<[3,5],f64>) -> !torch.vtensor<[3, 1],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[VAL_1:.*]] = torch.constant.int 3
  // CHECK: %[[VAL_2:.*]] = torch.constant.int 1
  // CHECK: %[[VAL_3:.*]] = torch.constant.none
  // CHECK: %[[VAL_4:.*]] = torch.constant.bool true
  // CHECK: %[[VAL_5:.*]] = torch.aten.multinomial %arg0, %[[VAL_2]], %[[VAL_4]], %[[VAL_3]] : !torch.vtensor<[3,5],f64>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[3,1],si64>
  // CHECK: %[[VAL_6:.*]] = torch.constant.bool false
  // CHECK: %[[VAL_7:.*]] = torch.aten.to.dtype %[[VAL_5]], %[[VAL_1]], %[[VAL_6]], %[[VAL_6]], %[[VAL_3]] : !torch.vtensor<[3,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,1],si32>
  // CHECK: return %[[VAL_7]] : !torch.vtensor<[3,1],si32>
  %0 = torch.operator "onnx.Multinomial"(%arg0) : (!torch.vtensor<[3,5],f64>) -> !torch.vtensor<[3,1],si32>
  return %0 : !torch.vtensor<[3,1],si32>
}

// CHECK-LABEL: func.func @test_multinomial_dtype_double_samplenum_4
func.func @test_multinomial_dtype_double_samplenum_4(%arg0: !torch.vtensor<[3,5],f64>) -> !torch.vtensor<[3, 4],f64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[VAL_1:.*]] = torch.constant.int 7
  // CHECK: %[[VAL_2:.*]] = torch.constant.int 4
  // CHECK: %[[VAL_3:.*]] = torch.constant.none
  // CHECK: %[[VAL_4:.*]] = torch.constant.bool true
  // CHECK: %[[VAL_5:.*]] = torch.aten.multinomial %arg0, %[[VAL_2]], %[[VAL_4]], %[[VAL_3]] : !torch.vtensor<[3,5],f64>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[3,4],si64>
  // CHECK: %[[VAL_6:.*]] = torch.constant.bool false
  // CHECK: %[[VAL_7:.*]] = torch.aten.to.dtype %[[VAL_5]], %[[VAL_1]], %[[VAL_6]], %[[VAL_6]], %[[VAL_3]] : !torch.vtensor<[3,4],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f64>
  // CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f64>
  %0 = torch.operator "onnx.Multinomial"(%arg0) {torch.onnx.dtype = 11 : si64, torch.onnx.sample_size = 4 : si64} : (!torch.vtensor<[3,5],f64>) -> !torch.vtensor<[3,4],f64>
  return %0 : !torch.vtensor<[3,4],f64>
}

// -----

// CHECK-LABEL: func.func @test_maxpool_2d_default
func.func @test_maxpool_2d_default(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
  // CHECK: %[[I2:.*]] = torch.constant.int 2