Fix empty tensor when select -1 (#1787)

pull/1804/head
Jiahao Li 2023-01-18 02:14:14 +08:00 committed by GitHub
parent 19bb8aebdf
commit e2698433db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 122 additions and 2 deletions

View File

@ -234,6 +234,7 @@ MHLO_PASS_SET = {
"ReduceSumDtypeFloatModule_basic",
"ReduceSumDtypeIntModule_basic",
"SelectIntModule_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic",
"SliceSingleIdxModule_basic",
"SqueezeDimModule_dynamic",
"SqueezeDimModule_negDim",
@ -454,6 +455,7 @@ TOSA_PASS_SET = {
"BoolTensorReturnMixedModule_basic",
"BoolTensorHandleSignless_basic",
"ElementwiseRsqrtModule_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic",
"SqueezeModule_static",
"SqueezeModule_noUnitDim",
"SqueezeModule_allUnitDim",
@ -662,6 +664,7 @@ LTC_XFAIL_SET = {
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AddIntModule_basic",
"AtenIntBoolOpModule_basic",
"BernoulliFloatModule_basic",
"BernoulliTensorModule_basic",
"BincountMinlengthModule_basic",

View File

@ -9111,6 +9111,30 @@ def Torch_AtenIntScalarOp : Torch_Op<"aten.Int.Scalar", [
let hasFolder = 1;
}
def Torch_AtenIntBoolOp : Torch_Op<"aten.Int.bool", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::Int.bool : (bool) -> (int)`";
let arguments = (ins
Torch_BoolType:$a
);
let results = (outs
Torch_IntType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIntBoolOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenIntBoolOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
}
def Torch_Aten__RangeLengthOp : Torch_Op<"aten.__range_length", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -1347,6 +1347,18 @@ OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenIntBoolOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenIntBoolOp::fold(ArrayRef<Attribute> operands) {
bool b;
if (matchPattern(getOperand(), m_TorchConstantBool(&b))) {
return getI64IntegerAttr(getContext(), static_cast<long>(b));
}
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenSortIntOp
//===----------------------------------------------------------------------===//

View File

@ -258,6 +258,15 @@ public:
Value dim = op.getDim();
Value self = op.getSelf();
// convert `start` to non-negative: start += int(start < 0) * dimSize
Value zero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value isNegative = rewriter.create<AtenLtIntOp>(loc, start, zero);
isNegative = rewriter.create<AtenIntBoolOp>(loc, isNegative);
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
Value indexOffset = rewriter.create<AtenMulIntOp>(loc, isNegative, dimSize);
start = rewriter.create<AtenAddIntOp>(loc, start, indexOffset);
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value startPlusOne =

View File

@ -581,6 +581,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::Float.str : (str) -> (float)")
emit("aten::Int.float : (float) -> (int)")
emit("aten::Int.Scalar : (Scalar) -> (int)", has_folder=True)
emit("aten::Int.bool : (bool) -> (int)", has_folder=True)
# Primitive ops
emit("aten::__range_length : (int, int, int) -> (int)", has_folder=True)

View File

@ -339,6 +339,59 @@ def BoolIntConstantModule_basic(module, tu: TestUtils):
# ==============================================================================
class AtenIntBoolOpModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([], torch.bool, True),
])
def forward(self, x):
return int(torch.ops.aten.Int(x))
@register_test_case(module_factory=lambda: AtenIntBoolOpModule())
def AtenIntBoolOpModule_basic(module, tu: TestUtils):
module.forward(tu.randint(low=0, high=2).bool())
class AtenIntBoolOpConstTrueModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return int(torch.ops.aten.Int(True))
@register_test_case(module_factory=lambda: AtenIntBoolOpConstTrueModule())
def AtenIntBoolOpConstTrueModule_basic(module, tu: TestUtils):
module.forward()
class AtenIntBoolOpConstFalseModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return int(torch.ops.aten.Int(False))
@register_test_case(module_factory=lambda: AtenIntBoolOpConstFalseModule())
def AtenIntBoolOpConstFalseModule_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class AtenIntTensorByteDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()

View File

@ -243,12 +243,30 @@ class SelectIntModule(torch.nn.Module):
([-1, -1], torch.int64, True),
])
def forward(self, x):
return x.select(0,0)
return torch.select(x, dim=0, index=0)
@register_test_case(module_factory=lambda: SelectIntModule())
def SelectIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(5,5, high=10))
module.forward(tu.randint(5, 5, high=10))
class SelectIntNegativeDimAndIndexStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([5, 5], torch.int64, True),
])
def forward(self, x):
return torch.select(x, dim=-1, index=-1)
@register_test_case(module_factory=lambda: SelectIntNegativeDimAndIndexStaticModule())
def SelectIntNegativeDimAndIndexStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(5, 5, high=10))
# ==============================================================================