mirror of https://github.com/llvm/torch-mlir
Fix handling of non-int tensors in `getScalarValue` (#1914)
The current implementation of `getScalarValue` does not check that the input to a `ValueTensorLiteralOp` is an i64 before extracting the value, and it does not check that the result type of the `PrimNumToTensorScalarOp` is also an i64. This leads to crashes or invalid IR generated when the `input` is something other than an i64 tensor or `!torch.int`. This commit addresses those issues. In addition, the function `getScalarValue` is renamed to `getScalarIntValue` to make it clear that it *only* extracts scalar integers.pull/1918/head
parent
62250dabbb
commit
671be048fe
|
@ -89,6 +89,8 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||
#ERROR: value (-56) is not equal to golden value (200)
|
||||
"AtenIntTensorByteDtypeModule_basic",
|
||||
# ERROR: assert isinstance(e, FakeTensor)
|
||||
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
||||
}
|
||||
|
||||
STABLEHLO_PASS_SET = {
|
||||
|
@ -155,6 +157,8 @@ STABLEHLO_PASS_SET = {
|
|||
"ElementwiseAddScalarFloatModule_basic",
|
||||
"ElementwiseAddScalarInt64Module_basic",
|
||||
"ElementwiseAddScalarIntModule_basic",
|
||||
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
||||
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
|
||||
"ElementwiseDivScalarModule_basic",
|
||||
"ElementwiseEqDiffWidthScalarModule_basic",
|
||||
"ElementwiseEqFloatScalarModule_basic",
|
||||
|
@ -537,6 +541,7 @@ TOSA_PASS_SET = {
|
|||
"ElementwiseDivScalarModule_basic",
|
||||
"ElementwiseSubScalarFloatModule_basic",
|
||||
"ElementwiseAddScalarFloatModule_basic",
|
||||
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
|
||||
"ElementwiseMulScalarModule_float",
|
||||
"ElementwiseCeilModule_basic",
|
||||
"ElementwiseReciprocalModule_basic",
|
||||
|
|
|
@ -128,32 +128,36 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
|
|||
return FloatAttr::get(Float64Type::get(context), value);
|
||||
}
|
||||
|
||||
static Value getScalarValue(Value input, Location loc,
|
||||
PatternRewriter &rewriter) {
|
||||
static Value getScalarIntValue(Value input, Location loc,
|
||||
PatternRewriter &rewriter) {
|
||||
auto inputType = input.getType();
|
||||
if (inputType.isa<Torch::IntType>()) {
|
||||
return input;
|
||||
}
|
||||
Value scalar = nullptr;
|
||||
|
||||
auto inputTensorType = inputType.dyn_cast<BaseTensorType>();
|
||||
if (!inputTensorType)
|
||||
return nullptr;
|
||||
|
||||
Type inputDtype = inputTensorType.getOptionalDtype();
|
||||
if (!inputDtype || !inputDtype.isInteger(64))
|
||||
return nullptr;
|
||||
|
||||
std::optional<unsigned> inputRank = getTensorRank(input);
|
||||
if (!inputRank || *inputRank != 0)
|
||||
return nullptr;
|
||||
|
||||
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
|
||||
std::optional<unsigned> tensorRank =
|
||||
getTensorRank(valueTensorLiteralOp.getResult());
|
||||
if (valueTensorLiteralOp && tensorRank && *tensorRank == 0) {
|
||||
auto tensorType =
|
||||
valueTensorLiteralOp.getValue().getType().cast<RankedTensorType>();
|
||||
if (tensorType.getElementType().isa<mlir::IntegerType>()) {
|
||||
auto val = valueTensorLiteralOp.getValue()
|
||||
.cast<DenseElementsAttr>()
|
||||
.getSplatValue<int64_t>();
|
||||
scalar = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(val));
|
||||
}
|
||||
}
|
||||
auto val = valueTensorLiteralOp.getValue()
|
||||
.cast<DenseElementsAttr>()
|
||||
.getSplatValue<int64_t>();
|
||||
return rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(val));
|
||||
} else if (auto primNumToTensorScalarOp =
|
||||
input.getDefiningOp<PrimNumToTensorScalarOp>()) {
|
||||
scalar = primNumToTensorScalarOp.getA();
|
||||
return primNumToTensorScalarOp.getA();
|
||||
}
|
||||
return scalar;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -869,8 +873,8 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
|
|||
if (op->getNumOperands() < 2) {
|
||||
return failure();
|
||||
}
|
||||
auto lhs = getScalarValue(op->getOperand(0), loc, rewriter);
|
||||
auto rhs = getScalarValue(op->getOperand(1), loc, rewriter);
|
||||
auto lhs = getScalarIntValue(op->getOperand(0), loc, rewriter);
|
||||
auto rhs = getScalarIntValue(op->getOperand(1), loc, rewriter);
|
||||
auto outType = op->getResult(0).getType();
|
||||
|
||||
if (!lhs || !rhs) {
|
||||
|
@ -879,7 +883,7 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
|
|||
}
|
||||
if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenAddTensorOp, AtenAddScalarOp>(
|
||||
op)) {
|
||||
Value alpha = getScalarValue(op->getOperand(2), loc, rewriter);
|
||||
Value alpha = getScalarIntValue(op->getOperand(2), loc, rewriter);
|
||||
if (!alpha) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only int scalar alpha is supported");
|
||||
|
|
|
@ -1919,6 +1919,52 @@ def ElementwiseAddScalarFloatModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAddScalar_NumToTensorFloat_Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
x = torch.ops.prim.NumToTensor(5.0)
|
||||
return torch.add(x, 3)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseAddScalar_NumToTensorFloat_Module())
|
||||
def ElementwiseAddScalar_NumToTensorFloat_Module_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAddScalar_TensorLiteralInt32_Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.x = torch.tensor(2, dtype=torch.int32)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.add(self.x, 3)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseAddScalar_TensorLiteralInt32_Module())
|
||||
def ElementwiseAddScalar_TensorLiteralInt32_Module_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseCloneModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue