mirror of https://github.com/llvm/torch-mlir
parent 5d55390111
commit f58ba19448
@@ -110,6 +110,8 @@ STABLEHLO_PASS_SET = {
    "BroadcastToModule_basic",
    "BroadcastToSameRankStaticModule_basic",
    "BroadcastZeroRankInputStaticModule_basic",
    "BucketizeTensorStaticFloatModule_basic",
    "BucketizeTensorStaticModule_basic",
    "CumsumStaticModule_basic",
    "CumsumStaticNegativeDimModule_basic",
    "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",

@@ -59,6 +59,15 @@ static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType,
      b, loc, elementalType, lhs, rhs);
}

static Value createGreaterThanOrEqual(OpBuilder &b, Location loc,
                                      Type elementalType, Value lhs,
                                      Value rhs) {
  return createComparisonTemplate<arith::CmpFPredicate::UGE,
                                  arith::CmpIPredicate::uge,
                                  arith::CmpIPredicate::sge>(
      b, loc, elementalType, lhs, rhs);
}

static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
                            Value lhs, Value rhs) {
  return createComparisonTemplate<arith::CmpFPredicate::ULT,

@@ -67,6 +76,14 @@ static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
      b, loc, elementalType, lhs, rhs);
}

static Value createLessThanOrEqual(OpBuilder &b, Location loc,
                                   Type elementalType, Value lhs, Value rhs) {
  return createComparisonTemplate<arith::CmpFPredicate::ULE,
                                  arith::CmpIPredicate::ule,
                                  arith::CmpIPredicate::sle>(
      b, loc, elementalType, lhs, rhs);
}

static Value createEqual(OpBuilder &b, Location loc, Type elementalType,
                         Value lhs, Value rhs) {
  return createComparisonTemplate<arith::CmpFPredicate::UEQ,

@@ -117,6 +134,46 @@ static Value createCalculationForMathOpWithDtypeConversion(
  return b.create<MathOpTy>(loc, arg);
}

template <typename OpTy>
static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
                                   Value lhs, Value rhs) {
  static_assert(std::is_same<OpTy, AtenLtTensorOp>() ||
                    std::is_same<OpTy, AtenLeTensorOp>() ||
                    std::is_same<OpTy, AtenGtTensorOp>() ||
                    std::is_same<OpTy, AtenGeTensorOp>() ||
                    std::is_same<OpTy, AtenEqTensorOp>(),
                "unimplemented: op type not supported");

  Type lhsDtype = lhs.getType();
  Type rhsDtype = rhs.getType();

  // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
  // to be handled.
  if (lhsDtype != rhsDtype) {
    op.emitError("unimplemented: lhs and rhs dtype must be same");
    return nullptr;
  }

  Type elementalType =
      op.getSelf().getType().template cast<BaseTensorType>().getDtype();
  if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) {
    return createLessThan(b, loc, elementalType, lhs, rhs);
  }
  if constexpr (std::is_same<OpTy, AtenLeTensorOp>()) {
    return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
  }
  if constexpr (std::is_same<OpTy, AtenGtTensorOp>()) {
    return createGreaterThan(b, loc, elementalType, lhs, rhs);
  }
  if constexpr (std::is_same<OpTy, AtenGeTensorOp>()) {
    return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
  }
  if constexpr (std::is_same<OpTy, AtenEqTensorOp>()) {
    return createEqual(b, loc, elementalType, lhs, rhs);
  }
  llvm_unreachable("unimplemented: op type not supported");
}

static Value createLinalgPayloadCalculationForElementwiseOp(
    OpBuilder &b, Location loc, TypeConverter *converter,
    ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {

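Note: the templated helper above dispatches with if constexpr and, as its TODO says, bails out when the lhs and rhs dtypes differ instead of promoting them the way eager PyTorch does. A minimal Python sketch of that gap, assuming a recent PyTorch build (the tensor names are illustrative only):

import torch

x = torch.rand(3, 4)                     # float32
y = torch.randint(0, 5, (3, 4))          # int64
# Eager PyTorch promotes the operands before comparing, so this just works.
print(torch.gt(x, y).dtype)              # torch.bool
# The lowering above requires both operands to already share a dtype, so a
# frontend or decomposition would have to insert the cast explicitly:
print(torch.gt(x, y.to(x.dtype)).dtype)  # torch.bool
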
@@ -465,64 +522,25 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
    return b.create<math::Atan2Op>(loc, lhs, rhs);
  }
  if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
    return createCompareTensorOp(b, loc, ltTensor, payloadArgs[0],
                                 payloadArgs[1]);
  }
  if (auto leTensor = dyn_cast<AtenLeTensorOp>(op)) {
    return createCompareTensorOp(b, loc, leTensor, payloadArgs[0],
                                 payloadArgs[1]);
  }
  if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
    AtenGtTensorOp::Adaptor adaptor(operands);
    Type lhsDtype = payloadArgs[0].getType();
    Type rhsDtype = payloadArgs[1].getType();

    // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
    // to be handled.
    if (lhsDtype != rhsDtype) {
      gtTensor.emitError("unimplemented: different lhs and rhs dtype");
      return nullptr;
    }

    Type elementalType =
        gtTensor.getSelf().getType().cast<BaseTensorType>().getDtype();
    return createGreaterThan(b, loc, elementalType, payloadArgs[0],
                             payloadArgs[1]);
    return createCompareTensorOp(b, loc, gtTensor, payloadArgs[0],
                                 payloadArgs[1]);
  }
  if (auto geTensor = dyn_cast<AtenGeTensorOp>(op)) {
    return createCompareTensorOp(b, loc, geTensor, payloadArgs[0],
                                 payloadArgs[1]);
  }
  if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
    AtenEqTensorOp::Adaptor adaptor(operands);
    Type lhsDtype = payloadArgs[0].getType();
    Type rhsDtype = payloadArgs[1].getType();

    // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
    // to be handled.
    if (lhsDtype != rhsDtype) {
      eqTensor.emitError("unimplemented: lhs and rhs dtype must be same");
      return nullptr;
    }

    Type elementalType =
        eqTensor.getSelf().getType().cast<BaseTensorType>().getDtype();

    if (elementalType.isa<mlir::FloatType>())
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
                                     payloadArgs[0], payloadArgs[1]);
    if (elementalType.isa<mlir::IntegerType>()) {
      return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                     payloadArgs[0], payloadArgs[1]);
    }
    eqTensor.emitError("unimplemented: dtype isn't supported.");
    return nullptr;
  }
  if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
    AtenLtTensorOp::Adaptor adaptor(operands);
    Type lhsDtype = payloadArgs[0].getType();
    Type rhsDtype = payloadArgs[1].getType();

    // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
    // to be handled.
    if (lhsDtype != rhsDtype) {
      ltTensor.emitError("unimplemented: lhs and rhs dtype must be same");
      return nullptr;
    }

    Type elementalType =
        ltTensor.getSelf().getType().cast<BaseTensorType>().getDtype();
    return createLessThan(b, loc, elementalType, payloadArgs[0],
                          payloadArgs[1]);
    return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0],
                                 payloadArgs[1]);
  }
  if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
    AtenDivTensorOp::Adaptor adaptor(operands);

@@ -1084,10 +1102,10 @@ public:
             AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
             AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp,
             AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
             AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
             AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
             AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
             AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
             AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
             AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp,
             AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
             AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
             AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
             AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp,
             AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op))

@@ -1563,12 +1581,12 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
      AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
      AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp,
      AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
      AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
      AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
      AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp, AtenMaskedFillTensorOp,
      AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp,
      AtenTriuOp, AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp,
      AtenFillScalarOp, AtenFillTensorOp>();
      AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
      AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
      AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
      AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
      AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenRemainderScalarOp,
      AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>();
  patterns.add<ConvertElementwiseOp>(typeConverter, context);
  target.addIllegalOp<AtenNllLossForwardOp>();
  patterns.add<ConvertAtenDetachOp>(typeConverter, context);

@@ -5903,6 +5903,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bucketize.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.contiguous\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"

@@ -171,6 +171,37 @@ static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
  return sub;
}

// Helper function to unsqueeze the input tensor at given dim.
// Return the unsqueezed tensor or failure.
static FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter,
                                        Operation *op, Value input, Value dim) {
  BaseTensorType inputType = input.getType().cast<BaseTensorType>();
  if (!inputType.hasSizes()) {
    return rewriter.notifyMatchFailure(op, "input tensor must have size");
  }

  SmallVector<int64_t> unsqueezedShape;
  ArrayRef<int64_t> inputShape = inputType.getSizes();
  // `input` has a reduced rank. Hence add 1.
  int64_t unsqueezedRank = inputShape.size() + 1;
  int64_t dimInt = 0;
  if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
    dimInt = toPositiveDim(dimInt, unsqueezedRank);
    if (!isValidDim(dimInt, unsqueezedRank)) {
      return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
    }
    unsqueezedShape.append(inputShape.begin(), inputShape.end());
    unsqueezedShape.insert(unsqueezedShape.begin() + dimInt, 1);
  } else {
    unsqueezedShape.resize(unsqueezedRank, kUnknownSize);
  }
  Type unsqueezedType =
      inputType.getWithSizesAndDtype(unsqueezedShape, inputType.getDtype());
  Value unsqueezed = rewriter.create<AtenUnsqueezeOp>(
      op->getLoc(), unsqueezedType, input, dim);
  return unsqueezed;
}

namespace {
/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
/// number of dimensions across which the max needs to be computed.

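Note: when the dim is a constant, unsqueezeTensor mirrors the shape arithmetic of torch.unsqueeze: the rank grows by one and a size-1 dimension is inserted at the positive dim. A small illustrative check of that rule (a sketch only, not part of the patch):

import torch

x = torch.rand(2, 3)
for dim in (-1, 0, 1, 2):
    pos = dim if dim >= 0 else dim + x.dim() + 1  # toPositiveDim against rank + 1
    expected = list(x.shape)
    expected.insert(pos, 1)
    assert list(x.unsqueeze(dim).shape) == expected
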
@@ -606,6 +637,128 @@ public:
};
} // namespace

// Decompose `aten.bucketize` into the following op sequence:
//
// def aten_bucketize(input, boundaries, out_int32, right):
//     unsqz_input = input.unsqueeze(-1)
//     if not right:
//         comparison = unsqz_input <= boundaries
//     else:
//         comparison = unsqz_input < boundaries
//     indices = torch.argmax(comparison.float(), dim=-1)
//     within_bound = comparison[..., -1]
//     result = torch.where(within_bound, indices, boundaries.shape[0])
//     if out_int32:
//         result = result.int()
//     return result
//
namespace {
class DecomposeAtenBucketizeTensorOp
    : public OpRewritePattern<AtenBucketizeTensorOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenBucketizeTensorOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Value input = op.getSelf();
    auto inputType = input.getType().cast<BaseTensorType>();
    if (!inputType.hasSizes()) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: input must have known sizes");
    }
    ArrayRef<int64_t> inputShape = inputType.getSizes();

    Value boundaries = op.getBoundaries();
    auto boundariesType = boundaries.getType().cast<BaseTensorType>();
    if (!boundariesType.hasSizes() || boundariesType.getSizes().size() != 1) {
      return rewriter.notifyMatchFailure(op,
                                         "unimplemented: boundaries must have "
                                         "known sizes and must be a 1D array");
    }
    int64_t boundariesSize = boundariesType.getSizes()[0];

    bool outInt32;
    if (!matchPattern(op.getOutInt32(), m_TorchConstantBool(&outInt32))) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: out_int32 must be a constant bool");
    }

    bool right;
    if (!matchPattern(op.getRight(), m_TorchConstantBool(&right))) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: right must be a constant bool");
    }

    // unsqueeze input at the last dim to make it broadcastable with boundaries
    Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(-1));
    auto unsqzTensorInfo =
        unsqueezeTensor(rewriter, op, input, /*dim=*/constMinusOne);
    if (failed(unsqzTensorInfo)) {
      return rewriter.notifyMatchFailure(op,
                                         "cannot generate unsqueeze tensor");
    }
    Value unsqzInput = *unsqzTensorInfo;

    // compare unsqueezed input with boundaries
    SmallVector<int64_t> compareShape(inputShape);
    compareShape.push_back(boundariesSize);
    Type compareType =
        inputType.getWithSizesAndDtype(compareShape, rewriter.getI1Type());
    Value compare;
    if (!right) {
      compare = rewriter.create<AtenLeTensorOp>(loc, compareType, unsqzInput,
                                                boundaries);
    } else {
      compare = rewriter.create<AtenLtTensorOp>(loc, compareType, unsqzInput,
                                                boundaries);
    }

    // convert the comparison results to float32 as the argmax op input,
    // which does not support integer dtype in LINALG backend
    Value compareF32 =
        convertTensorToDtype(rewriter, loc, compare, rewriter.getF32Type());

    // get the first boundary index where the input element is less than (or
    // equal to) the boundary value
    Type indicesType = inputType.getWithSizesAndDtype(
        inputShape, rewriter.getIntegerType(64, IntegerType::Signed));
    Value constFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
    Value indices = rewriter.create<AtenArgmaxOp>(loc, indicesType, compareF32,
                                                  /*dim=*/constMinusOne,
                                                  /*keepdim=*/constFalse);

    // get the comparison results between each input element and the rightmost
    // boundary value
    Type withinUpperBoundType =
        inputType.getWithSizesAndDtype(inputShape, rewriter.getI1Type());
    Value withinUpperBound = rewriter.create<AtenSelectIntOp>(
        loc, withinUpperBoundType, compare, /*dim=*/constMinusOne,
        /*index=*/constMinusOne);

    // If the input element is less than (or equal to) the rightmost boundary,
    // take the max index as result. Otherwise, the element is beyond the
    // rightmost boundary, so take the boundary size.
    Value constZero = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(0));
    Value upperBound =
        rewriter.create<AtenSizeIntOp>(loc, boundaries, /*dim=*/constZero);
    Value result = rewriter.create<AtenWhereScalarOtherOp>(
        loc, indicesType, withinUpperBound, indices, upperBound);

    if (outInt32) {
      result = convertTensorToDtype(
          rewriter, loc, result,
          rewriter.getIntegerType(32, IntegerType::Signed));
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

// To avoid overflow we use the following decomposition rule:
//     x_max = aten.max(x, dim, keepdim=True)[0]
//     shifted = x - x_max

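Note: the comment at the top of the pattern gives the decomposition as pseudocode. A runnable Python version of the same sequence, checked against torch.bucketize itself, is sketched below; it assumes a recent PyTorch and, like the pattern, relies on torch.argmax returning the index of the first maximal element (the function name is illustrative only):

import torch

def bucketize_decomposed(inp, boundaries, out_int32=False, right=False):
    unsqz = inp.unsqueeze(-1)
    comparison = unsqz < boundaries if right else unsqz <= boundaries
    # argmax over float32, matching the LINALG-backend restriction noted above
    indices = torch.argmax(comparison.float(), dim=-1)
    within_bound = comparison[..., -1]
    result = torch.where(within_bound, indices, boundaries.shape[0])
    return result.int() if out_int32 else result

inp = torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]])
bnd = torch.tensor([1, 4, 6])
assert torch.equal(bucketize_decomposed(inp, bnd), torch.bucketize(inp, bnd))
assert torch.equal(bucketize_decomposed(inp, bnd, right=True),
                   torch.bucketize(inp, bnd, right=True))
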
@@ -3193,29 +3346,13 @@ public:
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
    Value startPlusOne =
        rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
    BaseTensorType srcTensorType = src.getType().cast<BaseTensorType>();
    SmallVector<int64_t> sizes;
    if (!srcTensorType.hasSizes())
      return rewriter.notifyMatchFailure(op, "src tensor must have size");

    ArrayRef<int64_t> srcShape = srcTensorType.getSizes();
    // `src` has a reduced rank. Hence add 1.
    int64_t srcRank = srcShape.size() + 1;
    int64_t dimInt = 0;
    if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
      dimInt = toPositiveDim(dimInt, srcRank);
      if (!isValidDim(dimInt, srcRank))
        return rewriter.notifyMatchFailure(op, "dim is not a valid dim");

      sizes.append(srcShape.begin(), srcShape.end());
      sizes.insert(sizes.begin() + dimInt, 1);

    } else {
      sizes.resize(srcShape.size() + 1, kUnknownSize);
    auto unsqueezedInfo = unsqueezeTensor(rewriter, op, src, dim);
    if (failed(unsqueezedInfo)) {
      return rewriter.notifyMatchFailure(op,
                                         "cannot generate unsqueeze tensor op");
    }
    Type srcType = srcTensorType.getWithSizesAndDtype(
        llvm::ArrayRef(sizes), srcTensorType.getOptionalDtype());
    src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
    src = *unsqueezedInfo;
    rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
        op, op.getSelf().getType(), self, src, dim, start, startPlusOne,
        /*step=*/one);

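Note: the simplified pattern above now reuses unsqueezeTensor and then emits a slice_scatter over the one-element slice [start, start + 1). That is the usual way a select_scatter-style update is expressed; a hedged eager-mode illustration of the equivalence (torch.select_scatter and torch.slice_scatter are available in recent PyTorch; the values are arbitrary):

import torch

base = torch.zeros(3, 4)
src = torch.arange(4.)
dim, index = 0, 1

# Scattering src into row `index` along `dim`...
a = torch.select_scatter(base, src, dim, index)
# ...matches a one-element slice_scatter of src unsqueezed at `dim`.
b = torch.slice_scatter(base, src.unsqueeze(dim), dim, index, index + 1, 1)
assert torch.equal(a, b)
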
@@ -3786,6 +3923,7 @@ public:
    addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);

    GreedyRewriteConfig config;
    config.useTopDownTraversal = true;

@@ -440,6 +440,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
  target.addIllegalOp<AtenRandnLikeOp>();
  target.addIllegalOp<AtenVarMeanOp>();
  target.addIllegalOp<AtenNewEmptyStridedOp>();
  target.addIllegalOp<AtenBucketizeTensorOp>();
  for (std::string opName : backendLegalOps) {
    target.addLegalOp(OperationName(opName, context));
  }

@@ -706,8 +706,9 @@ void TypeAnalysis::visitOperation(Operation *op,
  // Dtype is always i1.
  if (isa<AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp, AtenLtScalarOp,
          AtenLeScalarOp, AtenNeScalarOp, AtenAnyOp, AtenAllOp, AtenEqTensorOp,
          AtenGtTensorOp, AtenLtTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
          AtenLogicalXorOp, AtenLogicalNotOp>(op)) {
          AtenGtTensorOp, AtenGeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
          AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp,
          AtenLogicalNotOp>(op)) {
    auto knowledge =
        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
    knowledge.dtype = IntegerType::get(op->getContext(), 1);

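Note: the rule above records that these comparison and logical ops always produce i1 (bool) tensors regardless of operand dtype, matching eager PyTorch:

import torch

x, y = torch.rand(3, 5), torch.rand(5)
assert torch.ge(x, y).dtype == torch.bool
xi, yi = torch.randint(0, 10, (3, 5)), torch.randint(0, 10, (5,))
assert torch.le(xi, yi).dtype == torch.bool
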
@@ -1191,6 +1192,22 @@ void TypeAnalysis::visitOperation(Operation *op,
    return;
  }

  if (auto bucketize = dyn_cast<AtenBucketizeTensorOp>(op)) {
    auto knowledge =
        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
    bool outInt32;
    if (matchPattern(bucketize.getOutInt32(), m_TorchConstantBool(&outInt32)) &&
        outInt32) {
      knowledge.dtype =
          IntegerType::get(op->getContext(), 32, IntegerType::Signed);
    } else {
      knowledge.dtype =
          IntegerType::get(op->getContext(), 64, IntegerType::Signed);
    }
    incorporateKnowledge(bucketize.getResult(), knowledge);
    return;
  }

  // Otherwise, this is an unknown operation, so reset the state.
  setAllToEntryStates(results);
  return;

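Note: the transfer function above picks signed 32-bit only when out_int32 is a constant true, and signed 64-bit otherwise, which is the eager-mode behaviour:

import torch

x = torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]])
bnd = torch.tensor([1, 4, 6])
assert torch.bucketize(x, bnd).dtype == torch.int64
assert torch.bucketize(x, bnd, out_int32=True).dtype == torch.int32
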
@@ -49,5 +49,12 @@ std::vector<torch::lazy::Shape> compute_shape_where(
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_bucketize(
    const at::Tensor& self, const at::Tensor& boundaries, bool out_int32,
    bool right) {
  auto dtype = out_int32 ? at::kInt : at::kLong;
  return {Shape(dtype, self.sizes().vec())};
}

} // namespace lazy
} // namespace torch

@@ -209,6 +209,9 @@ def aten〇dropout〡shape(input: List[int], p: float, train: bool) -> List[int]
def aten〇gelu〡shape(self: List[int], approximate: str = "none") -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇bucketize〇Tensor〡shape(self: List[int], boundaries: List[int], out_int32: bool = False, right: bool = False) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇contiguous〡shape(self: List[int], memory_format: int = 0) -> List[int]:
    return upstream_shape_functions.unary(self)

@@ -3161,4 +3161,91 @@ class SortIntListReverse(torch.nn.Module):

@register_test_case(module_factory=lambda: SortIntListReverse())
def SortIntListReverse_basic(module, tu: TestUtils):
    module.forward()
    module.forward()

# ==============================================================================

class BucketizeTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, input, boundaries):
        return torch.bucketize(input, boundaries)

@register_test_case(module_factory=lambda: BucketizeTensorModule())
def BucketizeTensorModule_basic(module, tu: TestUtils):
    module.forward(torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]]), torch.tensor([1, 4, 6]))

class BucketizeTensorOutInt32RightModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, input, boundaries):
        return torch.bucketize(input, boundaries, out_int32=True, right=True)

@register_test_case(module_factory=lambda: BucketizeTensorOutInt32RightModule())
def BucketizeTensorOutInt32RightModule_basic(module, tu: TestUtils):
    module.forward(torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]]), torch.tensor([1, 4, 6]))

class BucketizeTensorFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, input, boundaries):
        return torch.bucketize(input, boundaries)

@register_test_case(module_factory=lambda: BucketizeTensorFloatModule())
def BucketizeTensorFloatModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(15, 17), torch.sort(tu.rand(16)).values)

class BucketizeTensorStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 4], torch.int64, True),
        ([3], torch.int64, True),
    ])
    def forward(self, input, boundaries):
        return torch.bucketize(input, boundaries)

@register_test_case(module_factory=lambda: BucketizeTensorStaticModule())
def BucketizeTensorStaticModule_basic(module, tu: TestUtils):
    module.forward(torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]]), torch.tensor([1, 4, 6]))

class BucketizeTensorStaticFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([15, 17], torch.float32, True),
        ([16], torch.float32, True),
    ])
    def forward(self, input, boundaries):
        return torch.bucketize(input, boundaries)

@register_test_case(module_factory=lambda: BucketizeTensorStaticFloatModule())
def BucketizeTensorStaticFloatModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(15, 17), torch.sort(tu.rand(16)).values)

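Note: for the integer inputs used in these tests, the semantics spelled out in the decomposition comment give the following buckets (values worked out by hand, so treat them as illustrative rather than authoritative):

import torch

inp = torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]])
bnd = torch.tensor([1, 4, 6])
# right=False (default): first index i with inp <= bnd[i], else len(bnd)
print(torch.bucketize(inp, bnd))
# expected: [[0, 1, 2, 3], [0, 1, 1, 2]]
# right=True: first index i with inp < bnd[i], else len(bnd)
print(torch.bucketize(inp, bnd, out_int32=True, right=True))
# expected: [[0, 1, 2, 3], [1, 1, 2, 3]], dtype=torch.int32
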
@@ -144,6 +144,46 @@ def ElementwiseGeFloatIntScalarModule_basic(module, tu: TestUtils):

# ==============================================================================

class ElementwiseGeFloatTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, x, y):
        return torch.ge(x, y)


@register_test_case(module_factory=lambda: ElementwiseGeFloatTensorModule())
def ElementwiseGeFloatTensorModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5), tu.rand(5))

# ==============================================================================

class ElementwiseGeIntTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, x, y):
        return torch.ge(x, y)


@register_test_case(module_factory=lambda: ElementwiseGeIntTensorModule())
def ElementwiseGeIntTensorModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))

# ==============================================================================

class ElementwiseGtFloatTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

@@ -318,6 +358,46 @@ def ElementwiseLeFloatIntScalarModule_basic(module, tu: TestUtils):

# ==============================================================================

class ElementwiseLeFloatTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, x, y):
        return torch.le(x, y)


@register_test_case(module_factory=lambda: ElementwiseLeFloatTensorModule())
def ElementwiseLeFloatTensorModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5), tu.rand(5))

# ==============================================================================

class ElementwiseLeIntTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, x, y):
        return torch.le(x, y)


@register_test_case(module_factory=lambda: ElementwiseLeIntTensorModule())
def ElementwiseLeIntTensorModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))

# ==============================================================================

class ElementwiseLtFloatTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()