Add aten.bucketize op and its decomposition (#1834)

pull/1851/head snapshot-20230203.738
Jiahao Li 2023-02-03 10:20:47 +08:00 committed by GitHub
parent 5d55390111
commit f58ba19448
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 446 additions and 89 deletions

View File

@ -110,6 +110,8 @@ STABLEHLO_PASS_SET = {
"BroadcastToModule_basic",
"BroadcastToSameRankStaticModule_basic",
"BroadcastZeroRankInputStaticModule_basic",
"BucketizeTensorStaticFloatModule_basic",
"BucketizeTensorStaticModule_basic",
"CumsumStaticModule_basic",
"CumsumStaticNegativeDimModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",

View File

@ -59,6 +59,15 @@ static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType,
b, loc, elementalType, lhs, rhs);
}
static Value createGreaterThanOrEqual(OpBuilder &b, Location loc,
Type elementalType, Value lhs,
Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::UGE,
arith::CmpIPredicate::uge,
arith::CmpIPredicate::sge>(
b, loc, elementalType, lhs, rhs);
}
static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
Value lhs, Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::ULT,
@ -67,6 +76,14 @@ static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
b, loc, elementalType, lhs, rhs);
}
static Value createLessThanOrEqual(OpBuilder &b, Location loc,
Type elementalType, Value lhs, Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::ULE,
arith::CmpIPredicate::ule,
arith::CmpIPredicate::sle>(
b, loc, elementalType, lhs, rhs);
}
static Value createEqual(OpBuilder &b, Location loc, Type elementalType,
Value lhs, Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::UEQ,
@ -117,6 +134,46 @@ static Value createCalculationForMathOpWithDtypeConversion(
return b.create<MathOpTy>(loc, arg);
}
template <typename OpTy>
static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
Value lhs, Value rhs) {
static_assert(std::is_same<OpTy, AtenLtTensorOp>() ||
std::is_same<OpTy, AtenLeTensorOp>() ||
std::is_same<OpTy, AtenGtTensorOp>() ||
std::is_same<OpTy, AtenGeTensorOp>() ||
std::is_same<OpTy, AtenEqTensorOp>(),
"unimplemented: op type not supported");
Type lhsDtype = lhs.getType();
Type rhsDtype = rhs.getType();
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
// to be handled.
if (lhsDtype != rhsDtype) {
op.emitError("unimplemented: lhs and rhs dtype must be same");
return nullptr;
}
Type elementalType =
op.getSelf().getType().template cast<BaseTensorType>().getDtype();
if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) {
return createLessThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenLeTensorOp>()) {
return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGtTensorOp>()) {
return createGreaterThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGeTensorOp>()) {
return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenEqTensorOp>()) {
return createEqual(b, loc, elementalType, lhs, rhs);
}
llvm_unreachable("unimplemented: op type not supported");
}
static Value createLinalgPayloadCalculationForElementwiseOp(
OpBuilder &b, Location loc, TypeConverter *converter,
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
@ -465,64 +522,25 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<math::Atan2Op>(loc, lhs, rhs);
}
if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
return createCompareTensorOp(b, loc, ltTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto leTensor = dyn_cast<AtenLeTensorOp>(op)) {
return createCompareTensorOp(b, loc, leTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
AtenGtTensorOp::Adaptor adaptor(operands);
Type lhsDtype = payloadArgs[0].getType();
Type rhsDtype = payloadArgs[1].getType();
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
// to be handled.
if (lhsDtype != rhsDtype) {
gtTensor.emitError("unimplemented: different lhs and rhs dtype");
return nullptr;
}
Type elementalType =
gtTensor.getSelf().getType().cast<BaseTensorType>().getDtype();
return createGreaterThan(b, loc, elementalType, payloadArgs[0],
payloadArgs[1]);
return createCompareTensorOp(b, loc, gtTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto geTensor = dyn_cast<AtenGeTensorOp>(op)) {
return createCompareTensorOp(b, loc, geTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
AtenEqTensorOp::Adaptor adaptor(operands);
Type lhsDtype = payloadArgs[0].getType();
Type rhsDtype = payloadArgs[1].getType();
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
// to be handled.
if (lhsDtype != rhsDtype) {
eqTensor.emitError("unimplemented: lhs and rhs dtype must be same");
return nullptr;
}
Type elementalType =
eqTensor.getSelf().getType().cast<BaseTensorType>().getDtype();
if (elementalType.isa<mlir::FloatType>())
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
payloadArgs[0], payloadArgs[1]);
if (elementalType.isa<mlir::IntegerType>()) {
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
payloadArgs[0], payloadArgs[1]);
}
eqTensor.emitError("unimplemented: dtype isn't supported.");
return nullptr;
}
if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
AtenLtTensorOp::Adaptor adaptor(operands);
Type lhsDtype = payloadArgs[0].getType();
Type rhsDtype = payloadArgs[1].getType();
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
// to be handled.
if (lhsDtype != rhsDtype) {
ltTensor.emitError("unimplemented: lhs and rhs dtype must be same");
return nullptr;
}
Type elementalType =
ltTensor.getSelf().getType().cast<BaseTensorType>().getDtype();
return createLessThan(b, loc, elementalType, payloadArgs[0],
payloadArgs[1]);
return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
AtenDivTensorOp::Adaptor adaptor(operands);
@ -1084,10 +1102,10 @@ public:
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp,
AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op))
@ -1563,12 +1581,12 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp,
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp, AtenMaskedFillTensorOp,
AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp,
AtenTriuOp, AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp,
AtenFillScalarOp, AtenFillTensorOp>();
AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenRemainderScalarOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context);

View File

@ -5903,6 +5903,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bucketize.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.contiguous\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"

View File

@ -171,6 +171,37 @@ static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
return sub;
}
// Helper function to unsqueeze the input tensor at given dim.
// Return the unsqueezed tensor or failure.
static FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter,
Operation *op, Value input, Value dim) {
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
if (!inputType.hasSizes()) {
return rewriter.notifyMatchFailure(op, "input tensor must have size");
}
SmallVector<int64_t> unsqueezedShape;
ArrayRef<int64_t> inputShape = inputType.getSizes();
// `input` has a reduced rank. Hence add 1.
int64_t unsqueezedRank = inputShape.size() + 1;
int64_t dimInt = 0;
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
dimInt = toPositiveDim(dimInt, unsqueezedRank);
if (!isValidDim(dimInt, unsqueezedRank)) {
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
}
unsqueezedShape.append(inputShape.begin(), inputShape.end());
unsqueezedShape.insert(unsqueezedShape.begin() + dimInt, 1);
} else {
unsqueezedShape.resize(unsqueezedRank, kUnknownSize);
}
Type unsqueezedType =
inputType.getWithSizesAndDtype(unsqueezedShape, inputType.getDtype());
Value unsqueezed = rewriter.create<AtenUnsqueezeOp>(
op->getLoc(), unsqueezedType, input, dim);
return unsqueezed;
}
namespace {
/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
/// number of dimensions across which the max needs to be computed.
@ -606,6 +637,128 @@ public:
};
} // namespace
// Decompose `aten.bucketize` into the following op sequence:
//
// def aten_bucketize(input, boundaries, out_int32, right):
// unsqz_input = input.unsqueeze(-1)
// if not right:
// comparison = unsqz_input <= boundaries
// else:
// comparison = unsqz_input < boundaries
// indices = torch.argmax(comparison.float(), dim=-1)
// within_bound = comparison[..., -1]
// result = torch.where(within_bound, indices, boundaries.shape[0])
// if out_int32:
// result = result.int()
// return result
//
namespace {
class DecomposeAtenBucketizeTensorOp
: public OpRewritePattern<AtenBucketizeTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenBucketizeTensorOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.getSelf();
auto inputType = input.getType().cast<BaseTensorType>();
if (!inputType.hasSizes()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: input must have known sizes");
}
ArrayRef<int64_t> inputShape = inputType.getSizes();
Value boundaries = op.getBoundaries();
auto boundariesType = boundaries.getType().cast<BaseTensorType>();
if (!boundariesType.hasSizes() || boundariesType.getSizes().size() != 1) {
return rewriter.notifyMatchFailure(op,
"unimplemented: boundaries must have "
"known sizes and must be a 1D array");
}
int64_t boundariesSize = boundariesType.getSizes()[0];
bool outInt32;
if (!matchPattern(op.getOutInt32(), m_TorchConstantBool(&outInt32))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: out_int32 must be a constant bool");
}
bool right;
if (!matchPattern(op.getRight(), m_TorchConstantBool(&right))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: right must be a constant bool");
}
// unsqueeze input at the last dim to make it broadcastable with boundaries
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(-1));
auto unsqzTensorInfo =
unsqueezeTensor(rewriter, op, input, /*dim=*/constMinusOne);
if (failed(unsqzTensorInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor");
}
Value unsqzInput = *unsqzTensorInfo;
// compare unsqueezed input with boundaries
SmallVector<int64_t> compareShape(inputShape);
compareShape.push_back(boundariesSize);
Type compareType =
inputType.getWithSizesAndDtype(compareShape, rewriter.getI1Type());
Value compare;
if (!right) {
compare = rewriter.create<AtenLeTensorOp>(loc, compareType, unsqzInput,
boundaries);
} else {
compare = rewriter.create<AtenLtTensorOp>(loc, compareType, unsqzInput,
boundaries);
}
// convert the comparison results to float32 as the argmax op input,
// which does not support integer dtype in LINALG backend
Value compareF32 =
convertTensorToDtype(rewriter, loc, compare, rewriter.getF32Type());
// get the first boundary index where the input element is less than (or
// equal to) the boundary value
Type indicesType = inputType.getWithSizesAndDtype(
inputShape, rewriter.getIntegerType(64, IntegerType::Signed));
Value constFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
Value indices = rewriter.create<AtenArgmaxOp>(loc, indicesType, compareF32,
/*dim=*/constMinusOne,
/*keepdim=*/constFalse);
// get the comparison results between each input element and the rightmost
// boundary value
Type withinUpperBoundType =
inputType.getWithSizesAndDtype(inputShape, rewriter.getI1Type());
Value withinUpperBound = rewriter.create<AtenSelectIntOp>(
loc, withinUpperBoundType, compare, /*dim=*/constMinusOne,
/*index=*/constMinusOne);
// If the input element is less than (or equal to) the rightmost boundary,
// take the max index as result. Otherwise, the element is beyond the
// rightmost boundary, so take the boundary size.
Value constZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value upperBound =
rewriter.create<AtenSizeIntOp>(loc, boundaries, /*dim=*/constZero);
Value result = rewriter.create<AtenWhereScalarOtherOp>(
loc, indicesType, withinUpperBound, indices, upperBound);
if (outInt32) {
result = convertTensorToDtype(
rewriter, loc, result,
rewriter.getIntegerType(32, IntegerType::Signed));
}
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
// To avoid overflow we use the following decomposition rule:
// x_max = aten.max(x, dim, keepdim=True)[0]
// shifted = x - x_max
@ -3193,29 +3346,13 @@ public:
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value startPlusOne =
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
BaseTensorType srcTensorType = src.getType().cast<BaseTensorType>();
SmallVector<int64_t> sizes;
if (!srcTensorType.hasSizes())
return rewriter.notifyMatchFailure(op, "src tensor must have size");
ArrayRef<int64_t> srcShape = srcTensorType.getSizes();
// `src` has a reduced rank. Hence add 1.
int64_t srcRank = srcShape.size() + 1;
int64_t dimInt = 0;
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
dimInt = toPositiveDim(dimInt, srcRank);
if (!isValidDim(dimInt, srcRank))
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
sizes.append(srcShape.begin(), srcShape.end());
sizes.insert(sizes.begin() + dimInt, 1);
} else {
sizes.resize(srcShape.size() + 1, kUnknownSize);
auto unsqueezedInfo = unsqueezeTensor(rewriter, op, src, dim);
if (failed(unsqueezedInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor op");
}
Type srcType = srcTensorType.getWithSizesAndDtype(
llvm::ArrayRef(sizes), srcTensorType.getOptionalDtype());
src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
src = *unsqueezedInfo;
rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
op, op.getSelf().getType(), self, src, dim, start, startPlusOne,
/*step=*/one);
@ -3786,6 +3923,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);
GreedyRewriteConfig config;
config.useTopDownTraversal = true;

View File

@ -440,6 +440,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenRandnLikeOp>();
target.addIllegalOp<AtenVarMeanOp>();
target.addIllegalOp<AtenNewEmptyStridedOp>();
target.addIllegalOp<AtenBucketizeTensorOp>();
for (std::string opName : backendLegalOps) {
target.addLegalOp(OperationName(opName, context));
}

View File

@ -706,8 +706,9 @@ void TypeAnalysis::visitOperation(Operation *op,
// Dtype is always i1.
if (isa<AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp, AtenLtScalarOp,
AtenLeScalarOp, AtenNeScalarOp, AtenAnyOp, AtenAllOp, AtenEqTensorOp,
AtenGtTensorOp, AtenLtTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp>(op)) {
AtenGtTensorOp, AtenGeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp,
AtenLogicalNotOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = IntegerType::get(op->getContext(), 1);
@ -1191,6 +1192,22 @@ void TypeAnalysis::visitOperation(Operation *op,
return;
}
if (auto bucketize = dyn_cast<AtenBucketizeTensorOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
bool outInt32;
if (matchPattern(bucketize.getOutInt32(), m_TorchConstantBool(&outInt32)) &&
outInt32) {
knowledge.dtype =
IntegerType::get(op->getContext(), 32, IntegerType::Signed);
} else {
knowledge.dtype =
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
}
incorporateKnowledge(bucketize.getResult(), knowledge);
return;
}
// Otherwise, this is an unknown operation, so reset the state.
setAllToEntryStates(results);
return;

View File

@ -49,5 +49,12 @@ std::vector<torch::lazy::Shape> compute_shape_where(
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_bucketize(
const at::Tensor& self, const at::Tensor& boundaries, bool out_int32,
bool right) {
auto dtype = out_int32 ? at::kInt : at::kLong;
return {Shape(dtype, self.sizes().vec())};
}
} // namespace lazy
} // namespace torch

View File

@ -209,6 +209,9 @@ def atendropout〡shape(input: List[int], p: float, train: bool) -> List[int]
def atengelu〡shape(self: List[int], approximate: str = "none") -> List[int]:
return upstream_shape_functions.unary(self)
def atenbucketizeTensor〡shape(self: List[int], boundaries: List[int], out_int32: bool = False, right: bool = False) -> List[int]:
return upstream_shape_functions.unary(self)
def atencontiguous〡shape(self: List[int], memory_format: int = 0) -> List[int]:
return upstream_shape_functions.unary(self)

View File

@ -3161,4 +3161,91 @@ class SortIntListReverse(torch.nn.Module):
@register_test_case(module_factory=lambda: SortIntListReverse())
def SortIntListReverse_basic(module, tu: TestUtils):
module.forward()
module.forward()
# ==============================================================================
class BucketizeTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, input, boundaries):
return torch.bucketize(input, boundaries)
@register_test_case(module_factory=lambda: BucketizeTensorModule())
def BucketizeTensorModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]]), torch.tensor([1, 4, 6]))
class BucketizeTensorOutInt32RightModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, input, boundaries):
return torch.bucketize(input, boundaries, out_int32=True, right=True)
@register_test_case(module_factory=lambda: BucketizeTensorOutInt32RightModule())
def BucketizeTensorOutInt32RightModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]]), torch.tensor([1, 4, 6]))
class BucketizeTensorFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1], torch.float32, True),
])
def forward(self, input, boundaries):
return torch.bucketize(input, boundaries)
@register_test_case(module_factory=lambda: BucketizeTensorFloatModule())
def BucketizeTensorFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(15, 17), torch.sort(tu.rand(16)).values)
class BucketizeTensorStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 4], torch.int64, True),
([3], torch.int64, True),
])
def forward(self, input, boundaries):
return torch.bucketize(input, boundaries)
@register_test_case(module_factory=lambda: BucketizeTensorStaticModule())
def BucketizeTensorStaticModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]]), torch.tensor([1, 4, 6]))
class BucketizeTensorStaticFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([15, 17], torch.float32, True),
([16], torch.float32, True),
])
def forward(self, input, boundaries):
return torch.bucketize(input, boundaries)
@register_test_case(module_factory=lambda: BucketizeTensorStaticFloatModule())
def BucketizeTensorStaticFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(15, 17), torch.sort(tu.rand(16)).values)

View File

@ -144,6 +144,46 @@ def ElementwiseGeFloatIntScalarModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseGeFloatTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1], torch.float32, True),
])
def forward(self, x, y):
return torch.ge(x, y)
@register_test_case(module_factory=lambda: ElementwiseGeFloatTensorModule())
def ElementwiseGeFloatTensorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5), tu.rand(5))
# ==============================================================================
class ElementwiseGeIntTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, x, y):
return torch.ge(x, y)
@register_test_case(module_factory=lambda: ElementwiseGeIntTensorModule())
def ElementwiseGeIntTensorModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
# ==============================================================================
class ElementwiseGtFloatTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -318,6 +358,46 @@ def ElementwiseLeFloatIntScalarModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseLeFloatTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1], torch.float32, True),
])
def forward(self, x, y):
return torch.le(x, y)
@register_test_case(module_factory=lambda: ElementwiseLeFloatTensorModule())
def ElementwiseLeFloatTensorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5), tu.rand(5))
# ==============================================================================
class ElementwiseLeIntTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, x, y):
return torch.le(x, y)
@register_test_case(module_factory=lambda: ElementwiseLeIntTensorModule())
def ElementwiseLeIntTensorModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
# ==============================================================================
class ElementwiseLtFloatTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()