mirror of https://github.com/llvm/torch-mlir
[Stablehlo] fix crashing on AtenEmbeddingBagSumExample_basic (#3389)
parent
27169dcda9
commit
28aeb047c1
|
@ -22,6 +22,10 @@ namespace hlo {
|
|||
|
||||
using mlir::ConversionPatternRewriter;
|
||||
|
||||
// Create chlo::ConstantLikeOp
|
||||
template <typename T>
|
||||
Value getConstantLike(OpBuilder &rewriter, Location loc, T constant, Value val);
|
||||
|
||||
// Create a 32-bit float constant operator from a float
|
||||
Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
||||
float val);
|
||||
|
|
|
@ -36,34 +36,6 @@ using namespace mlir::torch;
|
|||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::torch_to_stablehlo;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
static Value getConstantLike(OpBuilder &b, Location loc, T constant,
|
||||
Value val) {
|
||||
Type ty = getElementTypeOrSelf(val.getType());
|
||||
auto getAttr = [&]() -> Attribute {
|
||||
if (isa<mlir::IntegerType>(ty))
|
||||
return b.getIntegerAttr(ty, constant);
|
||||
if (isa<mlir::FloatType>(ty))
|
||||
return b.getFloatAttr(ty, constant);
|
||||
if (auto complexTy = dyn_cast<mlir::ComplexType>(ty))
|
||||
return complex::NumberAttr::get(complexTy, constant, 0);
|
||||
llvm_unreachable("unhandled element type");
|
||||
};
|
||||
return b.create<mlir::chlo::ConstantLikeOp>(loc, cast<TypedAttr>(getAttr()),
|
||||
val);
|
||||
}
|
||||
|
||||
Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant,
|
||||
Value val) {
|
||||
Type ty = getElementTypeOrSelf(val.getType());
|
||||
return b.create<mlir::chlo::ConstantLikeOp>(loc, b.getFloatAttr(ty, constant),
|
||||
val);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
|
||||
mlir::Value &self, mlir::Value &other,
|
||||
size_t dimSizeIndexBits) {
|
||||
|
@ -928,7 +900,8 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
|
|||
"for AtenReciprocalOp");
|
||||
}
|
||||
|
||||
Value oneTensor = getConstantLike(rewriter, op->getLoc(), 1, input);
|
||||
Value oneTensor =
|
||||
hlo::getConstantLike<int64_t>(rewriter, op->getLoc(), 1, input);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, oneTensor, input);
|
||||
return success();
|
||||
}
|
||||
|
@ -1070,12 +1043,8 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
|||
return op->emitError("only float tensor in relu op is supported");
|
||||
}
|
||||
|
||||
Value zeroTensor;
|
||||
zeroTensor = getConstantLike(
|
||||
rewriter, op->getLoc(),
|
||||
APFloat::getZero(cast<mlir::FloatType>(lhsElemTy).getFloatSemantics(),
|
||||
false),
|
||||
lhs);
|
||||
Value zeroTensor =
|
||||
hlo::getConstantLike<int64_t>(rewriter, op->getLoc(), 0, lhs);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::MaxOp>(op, lhs, zeroTensor);
|
||||
return success();
|
||||
}
|
||||
|
@ -1102,13 +1071,13 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
|
|||
return op.emitError("unsupported approximate: ") << approximate;
|
||||
}
|
||||
|
||||
Value one = getConstantLike(rewriter, loc, 1.0, input);
|
||||
Value two = getConstantLike(rewriter, loc, 2.0, input);
|
||||
Value three = getConstantLike(rewriter, loc, 3.0, input);
|
||||
Value half = getConstantLike(rewriter, loc, 0.5, input);
|
||||
Value one = hlo::getConstantLike(rewriter, loc, 1.0, input);
|
||||
Value two = hlo::getConstantLike(rewriter, loc, 2.0, input);
|
||||
Value three = hlo::getConstantLike(rewriter, loc, 3.0, input);
|
||||
Value half = hlo::getConstantLike(rewriter, loc, 0.5, input);
|
||||
// 2/pi
|
||||
Value twoDivPi = getConstantLike(rewriter, loc, M_2_PI, input);
|
||||
Value t = getConstantLike(rewriter, loc, 0.044715, input);
|
||||
Value twoDivPi = hlo::getConstantLike(rewriter, loc, M_2_PI, input);
|
||||
Value t = hlo::getConstantLike(rewriter, loc, 0.044715, input);
|
||||
|
||||
// x * 0.5
|
||||
auto inputMulHalf = rewriter.create<stablehlo::MulOp>(loc, input, half);
|
||||
|
@ -1147,7 +1116,7 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
|
|||
auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
||||
|
||||
auto two = getConstantLike(rewriter, op.getLoc(), 2.0, input);
|
||||
auto two = hlo::getConstantLike(rewriter, op.getLoc(), 2.0, input);
|
||||
auto log2Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), two);
|
||||
auto logInputOp = rewriter.create<stablehlo::LogOp>(op.getLoc(), input);
|
||||
|
||||
|
@ -1169,7 +1138,7 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
|
|||
auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
||||
|
||||
auto ten = getConstantLike(rewriter, op.getLoc(), 10.0, input);
|
||||
auto ten = hlo::getConstantLike(rewriter, op.getLoc(), 10.0, input);
|
||||
auto log10Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), ten);
|
||||
auto logInputOp = rewriter.create<stablehlo::LogOp>(op.getLoc(), input);
|
||||
|
||||
|
@ -1764,12 +1733,13 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(op, "Unsupported value of approximate");
|
||||
}
|
||||
// Create constant value
|
||||
Value kAlpha = getConstantLike(rewriter, loc, 0.70710678118654752440, input);
|
||||
Value kAlpha =
|
||||
hlo::getConstantLike(rewriter, loc, 0.70710678118654752440, input);
|
||||
Value cstAlpha0 =
|
||||
getConstantLike(rewriter, loc, 1.12837916709551257390, input);
|
||||
Value half = getConstantLike(rewriter, loc, .5, input);
|
||||
Value one = getConstantLike(rewriter, loc, 1.0, input);
|
||||
Value negHalf = getConstantLike(rewriter, loc, -0.5, input);
|
||||
hlo::getConstantLike(rewriter, loc, 1.12837916709551257390, input);
|
||||
Value half = hlo::getConstantLike(rewriter, loc, .5, input);
|
||||
Value one = hlo::getConstantLike(rewriter, loc, 1.0, input);
|
||||
Value negHalf = hlo::getConstantLike(rewriter, loc, -0.5, input);
|
||||
|
||||
// Compute
|
||||
Value kBeta0 =
|
||||
|
|
|
@ -32,6 +32,9 @@ namespace {
|
|||
static Value createInitialValueForGatherScatterOp(Operation *op,
|
||||
RankedTensorType constType,
|
||||
PatternRewriter &rewriter) {
|
||||
if (!constType.hasStaticShape()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto elementTy = constType.getElementType();
|
||||
if (isa<AtenEmbeddingBagPaddingIdxOp>(op)) {
|
||||
if (isa<mlir::FloatType>(elementTy)) {
|
||||
|
|
|
@ -9,8 +9,10 @@
|
|||
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Complex/IR/Complex.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "stablehlo/dialect/ChloOps.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
@ -24,6 +26,31 @@ using namespace mlir::torch::Torch;
|
|||
namespace mlir {
|
||||
namespace hlo {
|
||||
|
||||
// Create chlo::ConstantLikeOp
|
||||
template <typename T>
|
||||
Value getConstantLike(OpBuilder &rewriter, Location loc, T constant,
|
||||
Value val) {
|
||||
Type ty = getElementTypeOrSelf(val.getType());
|
||||
auto getAttr = [&]() -> Attribute {
|
||||
if (isa<mlir::IntegerType>(ty))
|
||||
return rewriter.getIntegerAttr(ty, constant);
|
||||
if (isa<mlir::FloatType>(ty))
|
||||
return rewriter.getFloatAttr(ty, constant);
|
||||
if (auto complexTy = dyn_cast<mlir::ComplexType>(ty))
|
||||
return mlir::complex::NumberAttr::get(complexTy, constant, 0);
|
||||
llvm_unreachable("unhandled element type");
|
||||
};
|
||||
return rewriter.create<mlir::chlo::ConstantLikeOp>(
|
||||
loc, cast<TypedAttr>(getAttr()), val);
|
||||
}
|
||||
|
||||
// Template instantiation
|
||||
template Value getConstantLike<int64_t>(OpBuilder &rewriter, Location loc,
|
||||
int64_t constant, Value val);
|
||||
|
||||
template Value getConstantLike<double>(OpBuilder &rewriter, Location loc,
|
||||
double constant, Value val);
|
||||
|
||||
// Create a 32-bit float constant operator from a float
|
||||
Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
||||
float val) {
|
||||
|
|
|
@ -1442,9 +1442,7 @@ STABLEHLO_PASS_SET = {
|
|||
"ElementwiseSoftshrinkStaticModule_basic",
|
||||
}
|
||||
|
||||
STABLEHLO_CRASHING_SET = {
|
||||
"AtenEmbeddingBagSumExample_basic",
|
||||
}
|
||||
STABLEHLO_CRASHING_SET = set()
|
||||
|
||||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
|
|
Loading…
Reference in New Issue