[linalg] Added `aten.clamp` support with integers to `torch-to-linalg` (#2718)

The `torch-to-linalg` lowering for `aten.clamp` previously handled only floating-point element types. This change adds support for integer element types, deriving signedness from the original Torch dtype, and adds a signed-integer (`int8`) e2e test.
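
For context, this is the op-level behavior the lowering has to reproduce; a quick PyTorch check (independent of this patch) on a signed int8 tensor:

    import torch

    x = torch.tensor([-9, -5, 0, 5, 9], dtype=torch.int8)
    torch.clamp(x, min=-5)         # tensor([-5, -5,  0,  5,  9], dtype=torch.int8)
    torch.clamp(x, max=5)          # tensor([-9, -5,  0,  5,  5], dtype=torch.int8)
    torch.clamp(x, min=-5, max=5)  # tensor([-5, -5,  0,  5,  5], dtype=torch.int8)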
Rob Suderman 2024-01-05 15:16:49 -08:00 committed by GitHub
parent 6096fcb347
commit 985e7796a4
2 changed files with 66 additions and 19 deletions

@@ -1007,13 +1007,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return b.create<arith::SelectOp>(loc, pred, lhs, rhs);
   }
   if (auto clamp = dyn_cast<AtenClampOp>(op)) {
-    Type dtype = converter->convertType(clamp.getType())
-                     .cast<RankedTensorType>()
-                     .getElementType();
-    if (!dtype.isa<mlir::FloatType>()) {
-      clamp.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
-    }
     AtenClampOp::Adaptor adaptor(operands);
     auto min = adaptor.getMin();
     auto max = adaptor.getMax();
@@ -1022,19 +1015,45 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
       clamp.emitError("unimplemented: runtime optional type");
       return nullptr;
     }
+    Type dtype = converter->convertType(clamp.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    if (!dtype.isa<mlir::FloatType, mlir::IntegerType>()) {
+      clamp.emitError("unimplement type for clamp");
+      return nullptr;
+    }
+
+    Type dstOriginalDtype = clamp.getType().cast<BaseTensorType>().getDtype();
+    bool isUnsigned = isa<QUInt8Type>(dstOriginalDtype);
+    if (auto intTy = dstOriginalDtype.dyn_cast<IntegerType>()) {
+      isUnsigned = intTy.isUnsigned();
+    }
+    auto cmpSelect = [&](Value input, Value clamp, bool getMax) -> Value {
+      clamp = convertScalarToDtype(b, loc, clamp, dtype,
+                                   /*srcOriginalDtype=*/std::nullopt,
+                                   /*dstOriginalDtype=*/dstOriginalDtype);
+
+      Value pred;
+      if (dtype.isa<mlir::FloatType>()) {
+        auto cmp =
+            getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT;
+        pred = b.create<arith::CmpFOp>(loc, cmp, input, clamp);
+      } else if (dtype.isa<mlir::IntegerType>()) {
+        auto cmp =
+            isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt;
+        if (getMax)
+          cmp = arith::invertPredicate(cmp);
+        pred = b.create<arith::CmpIOp>(loc, cmp, input, clamp);
+      }
+
+      return b.create<arith::SelectOp>(loc, pred, clamp, input);
+    };
+
     auto result = payloadArgs[0];
-    if (!min.getType().isa<Torch::NoneType>()) {
-      auto minPromoted = convertScalarToDtype(b, loc, min, dtype);
-      auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
-                                          result, minPromoted);
-      result = b.create<arith::SelectOp>(loc, pred, minPromoted, result);
-    }
-    if (!max.getType().isa<Torch::NoneType>()) {
-      auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
-      auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
-                                          result, maxPromoted);
-      result = b.create<arith::SelectOp>(loc, pred, maxPromoted, result);
-    }
+    if (!min.getType().isa<Torch::NoneType>())
+      result = cmpSelect(result, min, /*getMax=*/false);
+    if (!max.getType().isa<Torch::NoneType>())
+      result = cmpSelect(result, max, /*getMax=*/true);
     return result;
   }
   if (auto clampTensor = dyn_cast<AtenClampTensorOp>(op)) {

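The heart of the change is the `cmpSelect` helper above: each bound becomes a compare followed by a select, and for integer element types the predicate has to respect signedness (`slt` for signed, `ult` for unsigned, inverted to `sge`/`uge` for the max bound). Signedness is taken from the original Torch dtype because the converted MLIR integer type is signless. A minimal Python model of that logic (the names `cmp_select`/`clamp_model` are illustrative, not from the patch; the real lowering emits `arith.cmpf`/`arith.cmpi` plus `arith.select`):

    def cmp_select(value, bound, get_max, is_unsigned, bits=8):
        # Unsigned compares see the raw bit pattern, so reinterpret first.
        def raw(v):
            return v & ((1 << bits) - 1)
        lhs, rhs = (raw(value), raw(bound)) if is_unsigned else (value, bound)
        # Min bound uses "less than"; max bound uses the inverted predicate
        # ("greater than or equal"), mirroring the integer path in the C++.
        pred = lhs >= rhs if get_max else lhs < rhs
        return bound if pred else value

    def clamp_model(value, lo=None, hi=None, is_unsigned=False):
        if lo is not None:
            value = cmp_select(value, lo, get_max=False, is_unsigned=is_unsigned)
        if hi is not None:
            value = cmp_select(value, hi, get_max=True, is_unsigned=is_unsigned)
        return value

For example, `clamp_model(-10, lo=-5, hi=5)` returns `-5`, matching `torch.clamp` on int8; with `is_unsigned=True` the same bit pattern (246) clamps to `5`, which is exactly the miscompare a signedness-blind lowering would produce.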

@@ -988,6 +988,34 @@ def ElementwiseClampTensorIntModule_basic(module, tu: TestUtils):

 # ==============================================================================

+class ElementwiseClampTensorInt8Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int8, True)
+    ])
+    def forward(self, x):
+        min = -5
+        max = 5
+        min_clamp = torch.clamp(x, min)
+        max_clamp = torch.clamp(x, max=max)
+        both_clamp = torch.clamp(x, min=min, max=max)
+        return min_clamp, max_clamp, both_clamp
+
+
+@register_test_case(module_factory=lambda: ElementwiseClampTensorInt8Module())
+def ElementwiseClampTensorInt8Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 5, low=-10, high=10, dtype=torch.int8))
+
+
+# ==============================================================================
+
 class ElementwiseClampMinTensorFloatModule(torch.nn.Module):
     def __init__(self):
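
The new test draws int8 values from [-10, 10), so samples land below, inside, and above the [-5, 5] window and both select branches of each bound are exercised. A quick standalone sanity check of the same idea (plain `torch.randint` stands in for the suite's `tu.randint` helper):

    import torch

    x = torch.randint(-10, 10, (3, 5), dtype=torch.int8)
    assert torch.clamp(x, min=-5).min() >= -5
    assert torch.clamp(x, max=5).max() <= 5
    assert torch.clamp(x, min=-5, max=5).abs().max() <= 5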