mirror of https://github.com/llvm/torch-mlir
Fix empty tensor when select -1 (#1787)
parent
19bb8aebdf
commit
e2698433db
|
@ -234,6 +234,7 @@ MHLO_PASS_SET = {
|
||||||
"ReduceSumDtypeFloatModule_basic",
|
"ReduceSumDtypeFloatModule_basic",
|
||||||
"ReduceSumDtypeIntModule_basic",
|
"ReduceSumDtypeIntModule_basic",
|
||||||
"SelectIntModule_basic",
|
"SelectIntModule_basic",
|
||||||
|
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
||||||
"SliceSingleIdxModule_basic",
|
"SliceSingleIdxModule_basic",
|
||||||
"SqueezeDimModule_dynamic",
|
"SqueezeDimModule_dynamic",
|
||||||
"SqueezeDimModule_negDim",
|
"SqueezeDimModule_negDim",
|
||||||
|
@ -454,6 +455,7 @@ TOSA_PASS_SET = {
|
||||||
"BoolTensorReturnMixedModule_basic",
|
"BoolTensorReturnMixedModule_basic",
|
||||||
"BoolTensorHandleSignless_basic",
|
"BoolTensorHandleSignless_basic",
|
||||||
"ElementwiseRsqrtModule_basic",
|
"ElementwiseRsqrtModule_basic",
|
||||||
|
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
||||||
"SqueezeModule_static",
|
"SqueezeModule_static",
|
||||||
"SqueezeModule_noUnitDim",
|
"SqueezeModule_noUnitDim",
|
||||||
"SqueezeModule_allUnitDim",
|
"SqueezeModule_allUnitDim",
|
||||||
|
@ -662,6 +664,7 @@ LTC_XFAIL_SET = {
|
||||||
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
|
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
|
||||||
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||||
"AddIntModule_basic",
|
"AddIntModule_basic",
|
||||||
|
"AtenIntBoolOpModule_basic",
|
||||||
"BernoulliFloatModule_basic",
|
"BernoulliFloatModule_basic",
|
||||||
"BernoulliTensorModule_basic",
|
"BernoulliTensorModule_basic",
|
||||||
"BincountMinlengthModule_basic",
|
"BincountMinlengthModule_basic",
|
||||||
|
|
|
@ -9111,6 +9111,30 @@ def Torch_AtenIntScalarOp : Torch_Op<"aten.Int.Scalar", [
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenIntBoolOp : Torch_Op<"aten.Int.bool", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::Int.bool : (bool) -> (int)`";
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_BoolType:$a
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_IntType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenIntBoolOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||||
|
}
|
||||||
|
void AtenIntBoolOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 1, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_Aten__RangeLengthOp : Torch_Op<"aten.__range_length", [
|
def Torch_Aten__RangeLengthOp : Torch_Op<"aten.__range_length", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -1347,6 +1347,18 @@ OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenIntBoolOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult AtenIntBoolOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
bool b;
|
||||||
|
if (matchPattern(getOperand(), m_TorchConstantBool(&b))) {
|
||||||
|
return getI64IntegerAttr(getContext(), static_cast<long>(b));
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenSortIntOp
|
// AtenSortIntOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -258,6 +258,15 @@ public:
|
||||||
Value dim = op.getDim();
|
Value dim = op.getDim();
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
|
|
||||||
|
// convert `start` to non-negative: start += int(start < 0) * dimSize
|
||||||
|
Value zero =
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||||
|
Value isNegative = rewriter.create<AtenLtIntOp>(loc, start, zero);
|
||||||
|
isNegative = rewriter.create<AtenIntBoolOp>(loc, isNegative);
|
||||||
|
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
|
||||||
|
Value indexOffset = rewriter.create<AtenMulIntOp>(loc, isNegative, dimSize);
|
||||||
|
start = rewriter.create<AtenAddIntOp>(loc, start, indexOffset);
|
||||||
|
|
||||||
Value one =
|
Value one =
|
||||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||||
Value startPlusOne =
|
Value startPlusOne =
|
||||||
|
|
|
@ -581,6 +581,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::Float.str : (str) -> (float)")
|
emit("aten::Float.str : (str) -> (float)")
|
||||||
emit("aten::Int.float : (float) -> (int)")
|
emit("aten::Int.float : (float) -> (int)")
|
||||||
emit("aten::Int.Scalar : (Scalar) -> (int)", has_folder=True)
|
emit("aten::Int.Scalar : (Scalar) -> (int)", has_folder=True)
|
||||||
|
emit("aten::Int.bool : (bool) -> (int)", has_folder=True)
|
||||||
|
|
||||||
# Primitive ops
|
# Primitive ops
|
||||||
emit("aten::__range_length : (int, int, int) -> (int)", has_folder=True)
|
emit("aten::__range_length : (int, int, int) -> (int)", has_folder=True)
|
||||||
|
|
|
@ -339,6 +339,59 @@ def BoolIntConstantModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class AtenIntBoolOpModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([], torch.bool, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return int(torch.ops.aten.Int(x))
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenIntBoolOpModule())
|
||||||
|
def AtenIntBoolOpModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(low=0, high=2).bool())
|
||||||
|
|
||||||
|
|
||||||
|
class AtenIntBoolOpConstTrueModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return int(torch.ops.aten.Int(True))
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenIntBoolOpConstTrueModule())
|
||||||
|
def AtenIntBoolOpConstTrueModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
class AtenIntBoolOpConstFalseModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return int(torch.ops.aten.Int(False))
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenIntBoolOpConstFalseModule())
|
||||||
|
def AtenIntBoolOpConstFalseModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class AtenIntTensorByteDtypeModule(torch.nn.Module):
|
class AtenIntTensorByteDtypeModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -243,12 +243,30 @@ class SelectIntModule(torch.nn.Module):
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
])
|
])
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x.select(0,0)
|
return torch.select(x, dim=0, index=0)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: SelectIntModule())
|
@register_test_case(module_factory=lambda: SelectIntModule())
|
||||||
def SelectIntModule_basic(module, tu: TestUtils):
|
def SelectIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(5,5, high=10))
|
module.forward(tu.randint(5, 5, high=10))
|
||||||
|
|
||||||
|
|
||||||
|
class SelectIntNegativeDimAndIndexStaticModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([5, 5], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.select(x, dim=-1, index=-1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: SelectIntNegativeDimAndIndexStaticModule())
|
||||||
|
def SelectIntNegativeDimAndIndexStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(5, 5, high=10))
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue