Fix handling of non-int tensors in `getScalarValue` (#1914)

The current implementation of `getScalarValue` does not check that the
input to a `ValueTensorLiteralOp` is an i64 before extracting the
value, and it does not check that the result type of the
`PrimNumToTensorScalarOp` is also an i64. This leads to crashes or
invalid IR generated when the `input` is something other than an i64
tensor or `!torch.int`.

This commit addresses those issues. In addition, the function
`getScalarValue` is renamed to `getScalarIntValue` to make it clear
that it *only* extracts scalar integers.
pull/1918/head
Ramiro Leal-Cavazos 2023-03-06 10:12:58 -08:00 committed by GitHub
parent 62250dabbb
commit 671be048fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 76 additions and 21 deletions

View File

@ -89,6 +89,8 @@ TORCHDYNAMO_XFAIL_SET = {
"ReduceMaxAlongDimUnsignedInt_basic",
#ERROR: value (-56) is not equal to golden value (200)
"AtenIntTensorByteDtypeModule_basic",
# ERROR: assert isinstance(e, FakeTensor)
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
}
STABLEHLO_PASS_SET = {
@ -155,6 +157,8 @@ STABLEHLO_PASS_SET = {
"ElementwiseAddScalarFloatModule_basic",
"ElementwiseAddScalarInt64Module_basic",
"ElementwiseAddScalarIntModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
"ElementwiseDivScalarModule_basic",
"ElementwiseEqDiffWidthScalarModule_basic",
"ElementwiseEqFloatScalarModule_basic",
@ -537,6 +541,7 @@ TOSA_PASS_SET = {
"ElementwiseDivScalarModule_basic",
"ElementwiseSubScalarFloatModule_basic",
"ElementwiseAddScalarFloatModule_basic",
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
"ElementwiseMulScalarModule_float",
"ElementwiseCeilModule_basic",
"ElementwiseReciprocalModule_basic",

View File

@ -128,32 +128,36 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
return FloatAttr::get(Float64Type::get(context), value);
}
static Value getScalarValue(Value input, Location loc,
PatternRewriter &rewriter) {
static Value getScalarIntValue(Value input, Location loc,
PatternRewriter &rewriter) {
auto inputType = input.getType();
if (inputType.isa<Torch::IntType>()) {
return input;
}
Value scalar = nullptr;
auto inputTensorType = inputType.dyn_cast<BaseTensorType>();
if (!inputTensorType)
return nullptr;
Type inputDtype = inputTensorType.getOptionalDtype();
if (!inputDtype || !inputDtype.isInteger(64))
return nullptr;
std::optional<unsigned> inputRank = getTensorRank(input);
if (!inputRank || *inputRank != 0)
return nullptr;
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
std::optional<unsigned> tensorRank =
getTensorRank(valueTensorLiteralOp.getResult());
if (valueTensorLiteralOp && tensorRank && *tensorRank == 0) {
auto tensorType =
valueTensorLiteralOp.getValue().getType().cast<RankedTensorType>();
if (tensorType.getElementType().isa<mlir::IntegerType>()) {
auto val = valueTensorLiteralOp.getValue()
.cast<DenseElementsAttr>()
.getSplatValue<int64_t>();
scalar = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(val));
}
}
auto val = valueTensorLiteralOp.getValue()
.cast<DenseElementsAttr>()
.getSplatValue<int64_t>();
return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(val));
} else if (auto primNumToTensorScalarOp =
input.getDefiningOp<PrimNumToTensorScalarOp>()) {
scalar = primNumToTensorScalarOp.getA();
return primNumToTensorScalarOp.getA();
}
return scalar;
return nullptr;
}
//===----------------------------------------------------------------------===//
@ -869,8 +873,8 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
if (op->getNumOperands() < 2) {
return failure();
}
auto lhs = getScalarValue(op->getOperand(0), loc, rewriter);
auto rhs = getScalarValue(op->getOperand(1), loc, rewriter);
auto lhs = getScalarIntValue(op->getOperand(0), loc, rewriter);
auto rhs = getScalarIntValue(op->getOperand(1), loc, rewriter);
auto outType = op->getResult(0).getType();
if (!lhs || !rhs) {
@ -879,7 +883,7 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
}
if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenAddTensorOp, AtenAddScalarOp>(
op)) {
Value alpha = getScalarValue(op->getOperand(2), loc, rewriter);
Value alpha = getScalarIntValue(op->getOperand(2), loc, rewriter);
if (!alpha) {
return rewriter.notifyMatchFailure(op,
"only int scalar alpha is supported");

View File

@ -1919,6 +1919,52 @@ def ElementwiseAddScalarFloatModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseAddScalar_NumToTensorFloat_Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
x = torch.ops.prim.NumToTensor(5.0)
return torch.add(x, 3)
@register_test_case(
module_factory=lambda: ElementwiseAddScalar_NumToTensorFloat_Module())
def ElementwiseAddScalar_NumToTensorFloat_Module_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class ElementwiseAddScalar_TensorLiteralInt32_Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.x = torch.tensor(2, dtype=torch.int32)
@export
@annotate_args([
None,
])
def forward(self):
return torch.add(self.x, 3)
@register_test_case(
module_factory=lambda: ElementwiseAddScalar_TensorLiteralInt32_Module())
def ElementwiseAddScalar_TensorLiteralInt32_Module_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class ElementwiseCloneModule(torch.nn.Module):
def __init__(self):