mirror of https://github.com/llvm/torch-mlir
Fix handling of non-int tensors in `getScalarValue` (#1914)
The current implementation of `getScalarValue` does not check that the input to a `ValueTensorLiteralOp` is an i64 before extracting the value, and it does not check that the result type of the `PrimNumToTensorScalarOp` is also an i64. This leads to crashes or invalid IR generated when the `input` is something other than an i64 tensor or `!torch.int`. This commit addresses those issues. In addition, the function `getScalarValue` is renamed to `getScalarIntValue` to make it clear that it *only* extracts scalar integers.pull/1918/head
parent
62250dabbb
commit
671be048fe
|
@ -89,6 +89,8 @@ TORCHDYNAMO_XFAIL_SET = {
|
||||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||||
#ERROR: value (-56) is not equal to golden value (200)
|
#ERROR: value (-56) is not equal to golden value (200)
|
||||||
"AtenIntTensorByteDtypeModule_basic",
|
"AtenIntTensorByteDtypeModule_basic",
|
||||||
|
# ERROR: assert isinstance(e, FakeTensor)
|
||||||
|
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
STABLEHLO_PASS_SET = {
|
STABLEHLO_PASS_SET = {
|
||||||
|
@ -155,6 +157,8 @@ STABLEHLO_PASS_SET = {
|
||||||
"ElementwiseAddScalarFloatModule_basic",
|
"ElementwiseAddScalarFloatModule_basic",
|
||||||
"ElementwiseAddScalarInt64Module_basic",
|
"ElementwiseAddScalarInt64Module_basic",
|
||||||
"ElementwiseAddScalarIntModule_basic",
|
"ElementwiseAddScalarIntModule_basic",
|
||||||
|
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
||||||
|
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
|
||||||
"ElementwiseDivScalarModule_basic",
|
"ElementwiseDivScalarModule_basic",
|
||||||
"ElementwiseEqDiffWidthScalarModule_basic",
|
"ElementwiseEqDiffWidthScalarModule_basic",
|
||||||
"ElementwiseEqFloatScalarModule_basic",
|
"ElementwiseEqFloatScalarModule_basic",
|
||||||
|
@ -537,6 +541,7 @@ TOSA_PASS_SET = {
|
||||||
"ElementwiseDivScalarModule_basic",
|
"ElementwiseDivScalarModule_basic",
|
||||||
"ElementwiseSubScalarFloatModule_basic",
|
"ElementwiseSubScalarFloatModule_basic",
|
||||||
"ElementwiseAddScalarFloatModule_basic",
|
"ElementwiseAddScalarFloatModule_basic",
|
||||||
|
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
|
||||||
"ElementwiseMulScalarModule_float",
|
"ElementwiseMulScalarModule_float",
|
||||||
"ElementwiseCeilModule_basic",
|
"ElementwiseCeilModule_basic",
|
||||||
"ElementwiseReciprocalModule_basic",
|
"ElementwiseReciprocalModule_basic",
|
||||||
|
|
|
@ -128,32 +128,36 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
|
||||||
return FloatAttr::get(Float64Type::get(context), value);
|
return FloatAttr::get(Float64Type::get(context), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value getScalarValue(Value input, Location loc,
|
static Value getScalarIntValue(Value input, Location loc,
|
||||||
PatternRewriter &rewriter) {
|
PatternRewriter &rewriter) {
|
||||||
auto inputType = input.getType();
|
auto inputType = input.getType();
|
||||||
if (inputType.isa<Torch::IntType>()) {
|
if (inputType.isa<Torch::IntType>()) {
|
||||||
return input;
|
return input;
|
||||||
}
|
}
|
||||||
Value scalar = nullptr;
|
|
||||||
|
auto inputTensorType = inputType.dyn_cast<BaseTensorType>();
|
||||||
|
if (!inputTensorType)
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
Type inputDtype = inputTensorType.getOptionalDtype();
|
||||||
|
if (!inputDtype || !inputDtype.isInteger(64))
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
std::optional<unsigned> inputRank = getTensorRank(input);
|
||||||
|
if (!inputRank || *inputRank != 0)
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
|
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
|
||||||
std::optional<unsigned> tensorRank =
|
auto val = valueTensorLiteralOp.getValue()
|
||||||
getTensorRank(valueTensorLiteralOp.getResult());
|
.cast<DenseElementsAttr>()
|
||||||
if (valueTensorLiteralOp && tensorRank && *tensorRank == 0) {
|
.getSplatValue<int64_t>();
|
||||||
auto tensorType =
|
return rewriter.create<Torch::ConstantIntOp>(
|
||||||
valueTensorLiteralOp.getValue().getType().cast<RankedTensorType>();
|
loc, rewriter.getI64IntegerAttr(val));
|
||||||
if (tensorType.getElementType().isa<mlir::IntegerType>()) {
|
|
||||||
auto val = valueTensorLiteralOp.getValue()
|
|
||||||
.cast<DenseElementsAttr>()
|
|
||||||
.getSplatValue<int64_t>();
|
|
||||||
scalar = rewriter.create<Torch::ConstantIntOp>(
|
|
||||||
loc, rewriter.getI64IntegerAttr(val));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (auto primNumToTensorScalarOp =
|
} else if (auto primNumToTensorScalarOp =
|
||||||
input.getDefiningOp<PrimNumToTensorScalarOp>()) {
|
input.getDefiningOp<PrimNumToTensorScalarOp>()) {
|
||||||
scalar = primNumToTensorScalarOp.getA();
|
return primNumToTensorScalarOp.getA();
|
||||||
}
|
}
|
||||||
return scalar;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -869,8 +873,8 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
|
||||||
if (op->getNumOperands() < 2) {
|
if (op->getNumOperands() < 2) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
auto lhs = getScalarValue(op->getOperand(0), loc, rewriter);
|
auto lhs = getScalarIntValue(op->getOperand(0), loc, rewriter);
|
||||||
auto rhs = getScalarValue(op->getOperand(1), loc, rewriter);
|
auto rhs = getScalarIntValue(op->getOperand(1), loc, rewriter);
|
||||||
auto outType = op->getResult(0).getType();
|
auto outType = op->getResult(0).getType();
|
||||||
|
|
||||||
if (!lhs || !rhs) {
|
if (!lhs || !rhs) {
|
||||||
|
@ -879,7 +883,7 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
|
||||||
}
|
}
|
||||||
if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenAddTensorOp, AtenAddScalarOp>(
|
if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenAddTensorOp, AtenAddScalarOp>(
|
||||||
op)) {
|
op)) {
|
||||||
Value alpha = getScalarValue(op->getOperand(2), loc, rewriter);
|
Value alpha = getScalarIntValue(op->getOperand(2), loc, rewriter);
|
||||||
if (!alpha) {
|
if (!alpha) {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"only int scalar alpha is supported");
|
"only int scalar alpha is supported");
|
||||||
|
|
|
@ -1919,6 +1919,52 @@ def ElementwiseAddScalarFloatModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseAddScalar_NumToTensorFloat_Module(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
x = torch.ops.prim.NumToTensor(5.0)
|
||||||
|
return torch.add(x, 3)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseAddScalar_NumToTensorFloat_Module())
|
||||||
|
def ElementwiseAddScalar_NumToTensorFloat_Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseAddScalar_TensorLiteralInt32_Module(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.x = torch.tensor(2, dtype=torch.int32)
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.add(self.x, 3)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseAddScalar_TensorLiteralInt32_Module())
|
||||||
|
def ElementwiseAddScalar_TensorLiteralInt32_Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseCloneModule(torch.nn.Module):
|
class ElementwiseCloneModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
Loading…
Reference in New Issue