[Stablehlo] fix crashing on AtenEmbeddingBagSumExample_basic (#3389)

pull/3391/head
Yuanqiang Liu 2024-05-26 12:34:56 +08:00 committed by GitHub
parent 27169dcda9
commit 28aeb047c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 53 additions and 51 deletions

View File

@ -22,6 +22,10 @@ namespace hlo {
using mlir::ConversionPatternRewriter;
// Create chlo::ConstantLikeOp
template <typename T>
Value getConstantLike(OpBuilder &rewriter, Location loc, T constant, Value val);
// Create a 32-bit float constant operator from a float
Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val);

View File

@ -36,34 +36,6 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo;
namespace {
template <typename T>
static Value getConstantLike(OpBuilder &b, Location loc, T constant,
Value val) {
Type ty = getElementTypeOrSelf(val.getType());
auto getAttr = [&]() -> Attribute {
if (isa<mlir::IntegerType>(ty))
return b.getIntegerAttr(ty, constant);
if (isa<mlir::FloatType>(ty))
return b.getFloatAttr(ty, constant);
if (auto complexTy = dyn_cast<mlir::ComplexType>(ty))
return complex::NumberAttr::get(complexTy, constant, 0);
llvm_unreachable("unhandled element type");
};
return b.create<mlir::chlo::ConstantLikeOp>(loc, cast<TypedAttr>(getAttr()),
val);
}
Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant,
Value val) {
Type ty = getElementTypeOrSelf(val.getType());
return b.create<mlir::chlo::ConstantLikeOp>(loc, b.getFloatAttr(ty, constant),
val);
}
} // namespace
LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
mlir::Value &self, mlir::Value &other,
size_t dimSizeIndexBits) {
@ -928,7 +900,8 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
"for AtenReciprocalOp");
}
Value oneTensor = getConstantLike(rewriter, op->getLoc(), 1, input);
Value oneTensor =
hlo::getConstantLike<int64_t>(rewriter, op->getLoc(), 1, input);
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, oneTensor, input);
return success();
}
@ -1070,12 +1043,8 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
return op->emitError("only float tensor in relu op is supported");
}
Value zeroTensor;
zeroTensor = getConstantLike(
rewriter, op->getLoc(),
APFloat::getZero(cast<mlir::FloatType>(lhsElemTy).getFloatSemantics(),
false),
lhs);
Value zeroTensor =
hlo::getConstantLike<int64_t>(rewriter, op->getLoc(), 0, lhs);
rewriter.replaceOpWithNewOp<stablehlo::MaxOp>(op, lhs, zeroTensor);
return success();
}
@ -1102,13 +1071,13 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
return op.emitError("unsupported approximate: ") << approximate;
}
Value one = getConstantLike(rewriter, loc, 1.0, input);
Value two = getConstantLike(rewriter, loc, 2.0, input);
Value three = getConstantLike(rewriter, loc, 3.0, input);
Value half = getConstantLike(rewriter, loc, 0.5, input);
Value one = hlo::getConstantLike(rewriter, loc, 1.0, input);
Value two = hlo::getConstantLike(rewriter, loc, 2.0, input);
Value three = hlo::getConstantLike(rewriter, loc, 3.0, input);
Value half = hlo::getConstantLike(rewriter, loc, 0.5, input);
// 2/pi
Value twoDivPi = getConstantLike(rewriter, loc, M_2_PI, input);
Value t = getConstantLike(rewriter, loc, 0.044715, input);
Value twoDivPi = hlo::getConstantLike(rewriter, loc, M_2_PI, input);
Value t = hlo::getConstantLike(rewriter, loc, 0.044715, input);
// x * 0.5
auto inputMulHalf = rewriter.create<stablehlo::MulOp>(loc, input, half);
@ -1147,7 +1116,7 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
auto two = getConstantLike(rewriter, op.getLoc(), 2.0, input);
auto two = hlo::getConstantLike(rewriter, op.getLoc(), 2.0, input);
auto log2Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), two);
auto logInputOp = rewriter.create<stablehlo::LogOp>(op.getLoc(), input);
@ -1169,7 +1138,7 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
auto ten = getConstantLike(rewriter, op.getLoc(), 10.0, input);
auto ten = hlo::getConstantLike(rewriter, op.getLoc(), 10.0, input);
auto log10Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), ten);
auto logInputOp = rewriter.create<stablehlo::LogOp>(op.getLoc(), input);
@ -1764,12 +1733,13 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "Unsupported value of approximate");
}
// Create constant value
Value kAlpha = getConstantLike(rewriter, loc, 0.70710678118654752440, input);
Value kAlpha =
hlo::getConstantLike(rewriter, loc, 0.70710678118654752440, input);
Value cstAlpha0 =
getConstantLike(rewriter, loc, 1.12837916709551257390, input);
Value half = getConstantLike(rewriter, loc, .5, input);
Value one = getConstantLike(rewriter, loc, 1.0, input);
Value negHalf = getConstantLike(rewriter, loc, -0.5, input);
hlo::getConstantLike(rewriter, loc, 1.12837916709551257390, input);
Value half = hlo::getConstantLike(rewriter, loc, .5, input);
Value one = hlo::getConstantLike(rewriter, loc, 1.0, input);
Value negHalf = hlo::getConstantLike(rewriter, loc, -0.5, input);
// Compute
Value kBeta0 =

View File

@ -32,6 +32,9 @@ namespace {
static Value createInitialValueForGatherScatterOp(Operation *op,
RankedTensorType constType,
PatternRewriter &rewriter) {
if (!constType.hasStaticShape()) {
return nullptr;
}
auto elementTy = constType.getElementType();
if (isa<AtenEmbeddingBagPaddingIdxOp>(op)) {
if (isa<mlir::FloatType>(elementTy)) {

View File

@ -9,8 +9,10 @@
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -24,6 +26,31 @@ using namespace mlir::torch::Torch;
namespace mlir {
namespace hlo {
// Create chlo::ConstantLikeOp
template <typename T>
Value getConstantLike(OpBuilder &rewriter, Location loc, T constant,
Value val) {
Type ty = getElementTypeOrSelf(val.getType());
auto getAttr = [&]() -> Attribute {
if (isa<mlir::IntegerType>(ty))
return rewriter.getIntegerAttr(ty, constant);
if (isa<mlir::FloatType>(ty))
return rewriter.getFloatAttr(ty, constant);
if (auto complexTy = dyn_cast<mlir::ComplexType>(ty))
return mlir::complex::NumberAttr::get(complexTy, constant, 0);
llvm_unreachable("unhandled element type");
};
return rewriter.create<mlir::chlo::ConstantLikeOp>(
loc, cast<TypedAttr>(getAttr()), val);
}
// Template instantiation
template Value getConstantLike<int64_t>(OpBuilder &rewriter, Location loc,
int64_t constant, Value val);
template Value getConstantLike<double>(OpBuilder &rewriter, Location loc,
double constant, Value val);
// Create a 32-bit float constant operator from a float
Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val) {

View File

@ -1442,9 +1442,7 @@ STABLEHLO_PASS_SET = {
"ElementwiseSoftshrinkStaticModule_basic",
}
STABLEHLO_CRASHING_SET = {
"AtenEmbeddingBagSumExample_basic",
}
STABLEHLO_CRASHING_SET = set()
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.