mirror of https://github.com/llvm/torch-mlir
parent 5d55390111
commit f58ba19448

@@ -110,6 +110,8 @@ STABLEHLO_PASS_SET = {
     "BroadcastToModule_basic",
     "BroadcastToSameRankStaticModule_basic",
     "BroadcastZeroRankInputStaticModule_basic",
+    "BucketizeTensorStaticFloatModule_basic",
+    "BucketizeTensorStaticModule_basic",
     "CumsumStaticModule_basic",
     "CumsumStaticNegativeDimModule_basic",
     "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",

@@ -59,6 +59,15 @@ static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType,
       b, loc, elementalType, lhs, rhs);
 }
 
+static Value createGreaterThanOrEqual(OpBuilder &b, Location loc,
+                                      Type elementalType, Value lhs,
+                                      Value rhs) {
+  return createComparisonTemplate<arith::CmpFPredicate::UGE,
+                                  arith::CmpIPredicate::uge,
+                                  arith::CmpIPredicate::sge>(
+      b, loc, elementalType, lhs, rhs);
+}
+
 static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
                             Value lhs, Value rhs) {
   return createComparisonTemplate<arith::CmpFPredicate::ULT,

@@ -67,6 +76,14 @@ static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
       b, loc, elementalType, lhs, rhs);
 }
 
+static Value createLessThanOrEqual(OpBuilder &b, Location loc,
+                                   Type elementalType, Value lhs, Value rhs) {
+  return createComparisonTemplate<arith::CmpFPredicate::ULE,
+                                  arith::CmpIPredicate::ule,
+                                  arith::CmpIPredicate::sle>(
+      b, loc, elementalType, lhs, rhs);
+}
+
 static Value createEqual(OpBuilder &b, Location loc, Type elementalType,
                          Value lhs, Value rhs) {
   return createComparisonTemplate<arith::CmpFPredicate::UEQ,

@@ -117,6 +134,46 @@ static Value createCalculationForMathOpWithDtypeConversion(
   return b.create<MathOpTy>(loc, arg);
 }
 
+template <typename OpTy>
+static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
+                                   Value lhs, Value rhs) {
+  static_assert(std::is_same<OpTy, AtenLtTensorOp>() ||
+                    std::is_same<OpTy, AtenLeTensorOp>() ||
+                    std::is_same<OpTy, AtenGtTensorOp>() ||
+                    std::is_same<OpTy, AtenGeTensorOp>() ||
+                    std::is_same<OpTy, AtenEqTensorOp>(),
+                "unimplemented: op type not supported");
+
+  Type lhsDtype = lhs.getType();
+  Type rhsDtype = rhs.getType();
+
+  // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
+  // to be handled.
+  if (lhsDtype != rhsDtype) {
+    op.emitError("unimplemented: lhs and rhs dtype must be same");
+    return nullptr;
+  }
+
+  Type elementalType =
+      op.getSelf().getType().template cast<BaseTensorType>().getDtype();
+  if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) {
+    return createLessThan(b, loc, elementalType, lhs, rhs);
+  }
+  if constexpr (std::is_same<OpTy, AtenLeTensorOp>()) {
+    return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
+  }
+  if constexpr (std::is_same<OpTy, AtenGtTensorOp>()) {
+    return createGreaterThan(b, loc, elementalType, lhs, rhs);
+  }
+  if constexpr (std::is_same<OpTy, AtenGeTensorOp>()) {
+    return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
+  }
+  if constexpr (std::is_same<OpTy, AtenEqTensorOp>()) {
+    return createEqual(b, loc, elementalType, lhs, rhs);
+  }
+  llvm_unreachable("unimplemented: op type not supported");
+}
+
 static Value createLinalgPayloadCalculationForElementwiseOp(
     OpBuilder &b, Location loc, TypeConverter *converter,
     ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
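
Note: the tensor-comparison ops that the new createCompareTensorOp helper dispatches on correspond to the following PyTorch calls. A minimal sketch (plain PyTorch, not part of the commit), with both operands kept at the same dtype since the lowering above rejects mixed dtypes:

import torch

x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([2.0, 2.0, 2.0])

# Each call maps to one Aten*TensorOp handled by createCompareTensorOp and
# produces a bool (i1) tensor, matching the dtype rule in the type analysis.
print(torch.lt(x, y))  # AtenLtTensorOp -> tensor([ True, False, False])
print(torch.le(x, y))  # AtenLeTensorOp -> tensor([ True,  True, False])
print(torch.gt(x, y))  # AtenGtTensorOp -> tensor([False, False,  True])
print(torch.ge(x, y))  # AtenGeTensorOp -> tensor([False,  True,  True])
print(torch.eq(x, y))  # AtenEqTensorOp -> tensor([False,  True, False])
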

@@ -465,64 +522,25 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
     return b.create<math::Atan2Op>(loc, lhs, rhs);
   }
+  if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
+    return createCompareTensorOp(b, loc, ltTensor, payloadArgs[0],
+                                 payloadArgs[1]);
+  }
+  if (auto leTensor = dyn_cast<AtenLeTensorOp>(op)) {
+    return createCompareTensorOp(b, loc, leTensor, payloadArgs[0],
+                                 payloadArgs[1]);
+  }
   if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
-    AtenGtTensorOp::Adaptor adaptor(operands);
-    Type lhsDtype = payloadArgs[0].getType();
-    Type rhsDtype = payloadArgs[1].getType();
-
-    // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
-    // to be handled.
-    if (lhsDtype != rhsDtype) {
-      gtTensor.emitError("unimplemented: different lhs and rhs dtype");
-      return nullptr;
-    }
-
-    Type elementalType =
-        gtTensor.getSelf().getType().cast<BaseTensorType>().getDtype();
-    return createGreaterThan(b, loc, elementalType, payloadArgs[0],
-                             payloadArgs[1]);
+    return createCompareTensorOp(b, loc, gtTensor, payloadArgs[0],
+                                 payloadArgs[1]);
+  }
+  if (auto geTensor = dyn_cast<AtenGeTensorOp>(op)) {
+    return createCompareTensorOp(b, loc, geTensor, payloadArgs[0],
+                                 payloadArgs[1]);
   }
   if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
-    AtenEqTensorOp::Adaptor adaptor(operands);
-    Type lhsDtype = payloadArgs[0].getType();
-    Type rhsDtype = payloadArgs[1].getType();
-
-    // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
-    // to be handled.
-    if (lhsDtype != rhsDtype) {
-      eqTensor.emitError("unimplemented: lhs and rhs dtype must be same");
-      return nullptr;
-    }
-
-    Type elementalType =
-        eqTensor.getSelf().getType().cast<BaseTensorType>().getDtype();
-
-    if (elementalType.isa<mlir::FloatType>())
-      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
-                                     payloadArgs[0], payloadArgs[1]);
-    if (elementalType.isa<mlir::IntegerType>()) {
-      return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
-                                     payloadArgs[0], payloadArgs[1]);
-    }
-    eqTensor.emitError("unimplemented: dtype isn't supported.");
-    return nullptr;
-  }
-  if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
-    AtenLtTensorOp::Adaptor adaptor(operands);
-    Type lhsDtype = payloadArgs[0].getType();
-    Type rhsDtype = payloadArgs[1].getType();
-
-    // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
-    // to be handled.
-    if (lhsDtype != rhsDtype) {
-      ltTensor.emitError("unimplemented: lhs and rhs dtype must be same");
-      return nullptr;
-    }
-
-    Type elementalType =
-        ltTensor.getSelf().getType().cast<BaseTensorType>().getDtype();
-    return createLessThan(b, loc, elementalType, payloadArgs[0],
-                          payloadArgs[1]);
+    return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0],
+                                 payloadArgs[1]);
   }
   if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
     AtenDivTensorOp::Adaptor adaptor(operands);

@@ -1084,10 +1102,10 @@ public:
             AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
             AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp,
             AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
-            AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
-            AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
-            AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
-            AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
+            AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
+            AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp,
+            AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
+            AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
             AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
             AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp,
             AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op))

@@ -1563,12 +1581,12 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
       AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp,
       AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
-      AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
-      AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
-      AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp, AtenMaskedFillTensorOp,
-      AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp,
-      AtenTriuOp, AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp,
-      AtenFillScalarOp, AtenFillTensorOp>();
+      AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
+      AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
+      AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
+      AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
+      AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenRemainderScalarOp,
+      AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossForwardOp>();
   patterns.add<ConvertAtenDetachOp>(typeConverter, context);

@@ -5903,6 +5903,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.bucketize.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.contiguous\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"

@@ -171,6 +171,37 @@ static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
   return sub;
 }
 
+// Helper function to unsqueeze the input tensor at given dim.
+// Return the unsqueezed tensor or failure.
+static FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter,
+                                        Operation *op, Value input, Value dim) {
+  BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+  if (!inputType.hasSizes()) {
+    return rewriter.notifyMatchFailure(op, "input tensor must have size");
+  }
+
+  SmallVector<int64_t> unsqueezedShape;
+  ArrayRef<int64_t> inputShape = inputType.getSizes();
+  // `input` has a reduced rank. Hence add 1.
+  int64_t unsqueezedRank = inputShape.size() + 1;
+  int64_t dimInt = 0;
+  if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
+    dimInt = toPositiveDim(dimInt, unsqueezedRank);
+    if (!isValidDim(dimInt, unsqueezedRank)) {
+      return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
+    }
+    unsqueezedShape.append(inputShape.begin(), inputShape.end());
+    unsqueezedShape.insert(unsqueezedShape.begin() + dimInt, 1);
+  } else {
+    unsqueezedShape.resize(unsqueezedRank, kUnknownSize);
+  }
+  Type unsqueezedType =
+      inputType.getWithSizesAndDtype(unsqueezedShape, inputType.getDtype());
+  Value unsqueezed = rewriter.create<AtenUnsqueezeOp>(
+      op->getLoc(), unsqueezedType, input, dim);
+  return unsqueezed;
+}
+
 namespace {
 /// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
 /// number of dimensions across which the max needs to be computed.
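
As a quick illustration of what the new unsqueezeTensor helper models at the Torch level (a plain PyTorch sketch, not part of the commit): unsqueeze inserts a size-1 dimension at the requested position, and a negative dim is normalized against the rank after insertion.

import torch

x = torch.rand(2, 4)
print(torch.unsqueeze(x, -1).shape)  # torch.Size([2, 4, 1]); dim=-1 becomes dim=2
print(torch.unsqueeze(x, 0).shape)   # torch.Size([1, 2, 4])
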

@@ -606,6 +637,128 @@ public:
 };
 } // namespace
 
+// Decompose `aten.bucketize` into the following op sequence:
+//
+// def aten_bucketize(input, boundaries, out_int32, right):
+//     unsqz_input = input.unsqueeze(-1)
+//     if not right:
+//         comparison = unsqz_input <= boundaries
+//     else:
+//         comparison = unsqz_input < boundaries
+//     indices = torch.argmax(comparison.float(), dim=-1)
+//     within_bound = comparison[..., -1]
+//     result = torch.where(within_bound, indices, boundaries.shape[0])
+//     if out_int32:
+//         result = result.int()
+//     return result
+//
+namespace {
+class DecomposeAtenBucketizeTensorOp
+    : public OpRewritePattern<AtenBucketizeTensorOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenBucketizeTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    Value input = op.getSelf();
+    auto inputType = input.getType().cast<BaseTensorType>();
+    if (!inputType.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: input must have known sizes");
+    }
+    ArrayRef<int64_t> inputShape = inputType.getSizes();
+
+    Value boundaries = op.getBoundaries();
+    auto boundariesType = boundaries.getType().cast<BaseTensorType>();
+    if (!boundariesType.hasSizes() || boundariesType.getSizes().size() != 1) {
+      return rewriter.notifyMatchFailure(op,
+                                         "unimplemented: boundaries must have "
+                                         "known sizes and must be a 1D array");
+    }
+    int64_t boundariesSize = boundariesType.getSizes()[0];
+
+    bool outInt32;
+    if (!matchPattern(op.getOutInt32(), m_TorchConstantBool(&outInt32))) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: out_int32 must be a constant bool");
+    }
+
+    bool right;
+    if (!matchPattern(op.getRight(), m_TorchConstantBool(&right))) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: right must be a constant bool");
+    }
+
+    // unsqueeze input at the last dim to make it broadcastable with boundaries
+    Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(-1));
+    auto unsqzTensorInfo =
+        unsqueezeTensor(rewriter, op, input, /*dim=*/constMinusOne);
+    if (failed(unsqzTensorInfo)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "cannot generate unsqueeze tensor");
+    }
+    Value unsqzInput = *unsqzTensorInfo;
+
+    // compare unsqueezed input with boundaries
+    SmallVector<int64_t> compareShape(inputShape);
+    compareShape.push_back(boundariesSize);
+    Type compareType =
+        inputType.getWithSizesAndDtype(compareShape, rewriter.getI1Type());
+    Value compare;
+    if (!right) {
+      compare = rewriter.create<AtenLeTensorOp>(loc, compareType, unsqzInput,
+                                                boundaries);
+    } else {
+      compare = rewriter.create<AtenLtTensorOp>(loc, compareType, unsqzInput,
+                                                boundaries);
+    }
+
+    // convert the comparison results to float32 as the argmax op input,
+    // which does not support integer dtype in LINALG backend
+    Value compareF32 =
+        convertTensorToDtype(rewriter, loc, compare, rewriter.getF32Type());
+
+    // get the first boundary index where the input element is less than (or
+    // equal to) the boundary value
+    Type indicesType = inputType.getWithSizesAndDtype(
+        inputShape, rewriter.getIntegerType(64, IntegerType::Signed));
+    Value constFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+    Value indices = rewriter.create<AtenArgmaxOp>(loc, indicesType, compareF32,
+                                                  /*dim=*/constMinusOne,
+                                                  /*keepdim=*/constFalse);
+
+    // get the comparison results between each input element and the rightmost
+    // boundary value
+    Type withinUpperBoundType =
+        inputType.getWithSizesAndDtype(inputShape, rewriter.getI1Type());
+    Value withinUpperBound = rewriter.create<AtenSelectIntOp>(
+        loc, withinUpperBoundType, compare, /*dim=*/constMinusOne,
+        /*index=*/constMinusOne);
+
+    // If the input element is less than (or equal to) the rightmost boundary,
+    // take the max index as result. Otherwise, the element is beyond the
+    // rightmost boundary, so take the boundary size.
+    Value constZero = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(0));
+    Value upperBound =
+        rewriter.create<AtenSizeIntOp>(loc, boundaries, /*dim=*/constZero);
+    Value result = rewriter.create<AtenWhereScalarOtherOp>(
+        loc, indicesType, withinUpperBound, indices, upperBound);
+
+    if (outInt32) {
+      result = convertTensorToDtype(
+          rewriter, loc, result,
+          rewriter.getIntegerType(32, IntegerType::Signed));
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 // To avoid overflow we use the following decomposition rule:
 // x_max = aten.max(x, dim, keepdim=True)[0]
 // shifted = x - x_max
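
The decomposition recipe in the comment above can be restated and checked in plain PyTorch. The sketch below is not part of the commit; it simply re-implements the commented pseudo-code and compares it against torch.bucketize on the same inputs the e2e tests use.

import torch

def bucketize_decomposed(inp, boundaries, out_int32=False, right=False):
    # Compare every input element against every boundary, then take the first
    # boundary index that satisfies the comparison (mirrors the comment above).
    unsqz = inp.unsqueeze(-1)
    comparison = (unsqz < boundaries) if right else (unsqz <= boundaries)
    indices = torch.argmax(comparison.float(), dim=-1)
    within_bound = comparison[..., -1]
    # Elements beyond the last boundary map to len(boundaries).
    result = torch.where(within_bound, indices, boundaries.shape[0])
    return result.int() if out_int32 else result

inp = torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]])
boundaries = torch.tensor([1, 4, 6])
assert torch.equal(bucketize_decomposed(inp, boundaries),
                   torch.bucketize(inp, boundaries))
assert torch.equal(bucketize_decomposed(inp, boundaries, out_int32=True, right=True),
                   torch.bucketize(inp, boundaries, out_int32=True, right=True))
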

@@ -3193,29 +3346,13 @@ public:
         rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
     Value startPlusOne =
         rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
-    BaseTensorType srcTensorType = src.getType().cast<BaseTensorType>();
-    SmallVector<int64_t> sizes;
-    if (!srcTensorType.hasSizes())
-      return rewriter.notifyMatchFailure(op, "src tensor must have size");
-
-    ArrayRef<int64_t> srcShape = srcTensorType.getSizes();
-    // `src` has a reduced rank. Hence add 1.
-    int64_t srcRank = srcShape.size() + 1;
-    int64_t dimInt = 0;
-    if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
-      dimInt = toPositiveDim(dimInt, srcRank);
-      if (!isValidDim(dimInt, srcRank))
-        return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
-
-      sizes.append(srcShape.begin(), srcShape.end());
-      sizes.insert(sizes.begin() + dimInt, 1);
-
-    } else {
-      sizes.resize(srcShape.size() + 1, kUnknownSize);
+    auto unsqueezedInfo = unsqueezeTensor(rewriter, op, src, dim);
+    if (failed(unsqueezedInfo)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "cannot generate unsqueeze tensor op");
     }
-    Type srcType = srcTensorType.getWithSizesAndDtype(
-        llvm::ArrayRef(sizes), srcTensorType.getOptionalDtype());
-    src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
+    src = *unsqueezedInfo;
     rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
         op, op.getSelf().getType(), self, src, dim, start, startPlusOne,
         /*step=*/one);

@@ -3786,6 +3923,7 @@ public:
     addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);
 
     GreedyRewriteConfig config;
     config.useTopDownTraversal = true;

@@ -440,6 +440,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenRandnLikeOp>();
   target.addIllegalOp<AtenVarMeanOp>();
   target.addIllegalOp<AtenNewEmptyStridedOp>();
+  target.addIllegalOp<AtenBucketizeTensorOp>();
   for (std::string opName : backendLegalOps) {
     target.addLegalOp(OperationName(opName, context));
   }

@@ -706,8 +706,9 @@ void TypeAnalysis::visitOperation(Operation *op,
   // Dtype is always i1.
   if (isa<AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp, AtenLtScalarOp,
           AtenLeScalarOp, AtenNeScalarOp, AtenAnyOp, AtenAllOp, AtenEqTensorOp,
-          AtenGtTensorOp, AtenLtTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
-          AtenLogicalXorOp, AtenLogicalNotOp>(op)) {
+          AtenGtTensorOp, AtenGeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
+          AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp,
+          AtenLogicalNotOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = IntegerType::get(op->getContext(), 1);

@@ -1191,6 +1192,22 @@ void TypeAnalysis::visitOperation(Operation *op,
     return;
   }
 
+  if (auto bucketize = dyn_cast<AtenBucketizeTensorOp>(op)) {
+    auto knowledge =
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
+    bool outInt32;
+    if (matchPattern(bucketize.getOutInt32(), m_TorchConstantBool(&outInt32)) &&
+        outInt32) {
+      knowledge.dtype =
+          IntegerType::get(op->getContext(), 32, IntegerType::Signed);
+    } else {
+      knowledge.dtype =
+          IntegerType::get(op->getContext(), 64, IntegerType::Signed);
+    }
+    incorporateKnowledge(bucketize.getResult(), knowledge);
+    return;
+  }
+
   // Otherwise, this is an unknown operation, so reset the state.
   setAllToEntryStates(results);
   return;
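
The dtype rule encoded here can be checked directly against PyTorch (a small sketch, not part of the commit): bucketize defaults to int64 indices and switches to int32 only when out_int32=True.

import torch

inp = torch.rand(4)
boundaries = torch.sort(torch.rand(3)).values
print(torch.bucketize(inp, boundaries).dtype)                  # torch.int64
print(torch.bucketize(inp, boundaries, out_int32=True).dtype)  # torch.int32
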

@@ -49,5 +49,12 @@ std::vector<torch::lazy::Shape> compute_shape_where(
   return {Shape(self.scalar_type(), self.sizes().vec())};
 }
 
+std::vector<torch::lazy::Shape> compute_shape_bucketize(
+    const at::Tensor& self, const at::Tensor& boundaries, bool out_int32,
+    bool right) {
+  auto dtype = out_int32 ? at::kInt : at::kLong;
+  return {Shape(dtype, self.sizes().vec())};
+}
+
 } // namespace lazy
 } // namespace torch

@@ -209,6 +209,9 @@ def aten〇dropout〡shape(input: List[int], p: float, train: bool) -> List[int]:
 def aten〇gelu〡shape(self: List[int], approximate: str = "none") -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇bucketize〇Tensor〡shape(self: List[int], boundaries: List[int], out_int32: bool = False, right: bool = False) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇contiguous〡shape(self: List[int], memory_format: int = 0) -> List[int]:
     return upstream_shape_functions.unary(self)
 
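
The shape function simply mirrors the input shape: bucketize produces one index per input element, regardless of how many boundaries there are. A quick check (plain PyTorch, not part of the commit):

import torch

inp = torch.rand(15, 17)
boundaries = torch.sort(torch.rand(16)).values
assert torch.bucketize(inp, boundaries).shape == inp.shape  # (15, 17)
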

@@ -3161,4 +3161,91 @@ class SortIntListReverse(torch.nn.Module):
 
 @register_test_case(module_factory=lambda: SortIntListReverse())
 def SortIntListReverse_basic(module, tu: TestUtils):
     module.forward()
+
+
+# ==============================================================================
+
+
+class BucketizeTensorModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1], torch.int64, True),
+    ])
+    def forward(self, input, boundaries):
+        return torch.bucketize(input, boundaries)
+
+
+@register_test_case(module_factory=lambda: BucketizeTensorModule())
+def BucketizeTensorModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]]), torch.tensor([1, 4, 6]))
+
+
+class BucketizeTensorOutInt32RightModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1], torch.int64, True),
+    ])
+    def forward(self, input, boundaries):
+        return torch.bucketize(input, boundaries, out_int32=True, right=True)
+
+
+@register_test_case(module_factory=lambda: BucketizeTensorOutInt32RightModule())
+def BucketizeTensorOutInt32RightModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]]), torch.tensor([1, 4, 6]))
+
+
+class BucketizeTensorFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, input, boundaries):
+        return torch.bucketize(input, boundaries)
+
+
+@register_test_case(module_factory=lambda: BucketizeTensorFloatModule())
+def BucketizeTensorFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(15, 17), torch.sort(tu.rand(16)).values)
+
+
+class BucketizeTensorStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 4], torch.int64, True),
+        ([3], torch.int64, True),
+    ])
+    def forward(self, input, boundaries):
+        return torch.bucketize(input, boundaries)
+
+
+@register_test_case(module_factory=lambda: BucketizeTensorStaticModule())
+def BucketizeTensorStaticModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]]), torch.tensor([1, 4, 6]))
+
+
+class BucketizeTensorStaticFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([15, 17], torch.float32, True),
+        ([16], torch.float32, True),
+    ])
+    def forward(self, input, boundaries):
+        return torch.bucketize(input, boundaries)
+
+
+@register_test_case(module_factory=lambda: BucketizeTensorStaticFloatModule())
+def BucketizeTensorStaticFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(15, 17), torch.sort(tu.rand(16)).values)
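
For orientation, the expected indices for the int64 test inputs above can be reproduced with Python's bisect module as an independent reference (a sketch, not part of the test suite): right=False corresponds to bisect_left and right=True to bisect_right.

import bisect
import torch

inp = torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]])
boundaries = [1, 4, 6]

via_left = [[bisect.bisect_left(boundaries, int(v)) for v in row] for row in inp]
via_right = [[bisect.bisect_right(boundaries, int(v)) for v in row] for row in inp]
assert torch.equal(torch.tensor(via_left),
                   torch.bucketize(inp, torch.tensor(boundaries)))
assert torch.equal(torch.tensor(via_right),
                   torch.bucketize(inp, torch.tensor(boundaries), right=True))
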

@@ -144,6 +144,46 @@ def ElementwiseGeFloatIntScalarModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+
+class ElementwiseGeFloatTensorModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        return torch.ge(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseGeFloatTensorModule())
+def ElementwiseGeFloatTensorModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5), tu.rand(5))
+
+
+# ==============================================================================
+
+
+class ElementwiseGeIntTensorModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1], torch.int64, True),
+    ])
+    def forward(self, x, y):
+        return torch.ge(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseGeIntTensorModule())
+def ElementwiseGeIntTensorModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
+
+
+# ==============================================================================
+
+
 class ElementwiseGtFloatTensorModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
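
The new ge/le test cases exercise broadcasting between a rank-2 and a rank-1 operand; a plain PyTorch sketch of what they compute (not part of the commit):

import torch

x = torch.rand(3, 5)
y = torch.rand(5)
result = torch.ge(x, y)            # y is broadcast across the leading dimension
print(result.shape, result.dtype)  # torch.Size([3, 5]) torch.bool
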

@@ -318,6 +358,46 @@ def ElementwiseLeFloatIntScalarModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+
+class ElementwiseLeFloatTensorModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        return torch.le(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseLeFloatTensorModule())
+def ElementwiseLeFloatTensorModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5), tu.rand(5))
+
+
+# ==============================================================================
+
+
+class ElementwiseLeIntTensorModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1], torch.int64, True),
+    ])
+    def forward(self, x, y):
+        return torch.le(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseLeIntTensorModule())
+def ElementwiseLeIntTensorModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
+
+
+# ==============================================================================
+
+
 class ElementwiseLtFloatTensorModule(torch.nn.Module):
     def __init__(self):
         super().__init__()