mirror of https://github.com/llvm/torch-mlir
[linalg] Added `aten.clamp` support with integers to `torch-to-linalg` (#2718)
The lowering for `aten.clamp` did not support integer types. Added support for integer types including a signed integer test.pull/2736/head snapshot-20240106.1075
parent
6096fcb347
commit
985e7796a4
|
@ -1007,13 +1007,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return b.create<arith::SelectOp>(loc, pred, lhs, rhs);
|
||||
}
|
||||
if (auto clamp = dyn_cast<AtenClampOp>(op)) {
|
||||
Type dtype = converter->convertType(clamp.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
if (!dtype.isa<mlir::FloatType>()) {
|
||||
clamp.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
AtenClampOp::Adaptor adaptor(operands);
|
||||
auto min = adaptor.getMin();
|
||||
auto max = adaptor.getMax();
|
||||
|
@ -1022,19 +1015,45 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
clamp.emitError("unimplemented: runtime optional type");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Type dtype = converter->convertType(clamp.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
if (!dtype.isa<mlir::FloatType, mlir::IntegerType>()) {
|
||||
clamp.emitError("unimplement type for clamp");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Type dstOriginalDtype = clamp.getType().cast<BaseTensorType>().getDtype();
|
||||
bool isUnsigned = isa<QUInt8Type>(dstOriginalDtype);
|
||||
if (auto intTy = dstOriginalDtype.dyn_cast<IntegerType>()) {
|
||||
isUnsigned = intTy.isUnsigned();
|
||||
}
|
||||
auto cmpSelect = [&](Value input, Value clamp, bool getMax) -> Value {
|
||||
clamp = convertScalarToDtype(b, loc, clamp, dtype,
|
||||
/*srcOriginalDtype=*/std::nullopt,
|
||||
/*dstOriginalDtype=*/dstOriginalDtype);
|
||||
|
||||
Value pred;
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
auto cmp =
|
||||
getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT;
|
||||
pred = b.create<arith::CmpFOp>(loc, cmp, input, clamp);
|
||||
} else if (dtype.isa<mlir::IntegerType>()) {
|
||||
auto cmp =
|
||||
isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt;
|
||||
if (getMax)
|
||||
cmp = arith::invertPredicate(cmp);
|
||||
pred = b.create<arith::CmpIOp>(loc, cmp, input, clamp);
|
||||
}
|
||||
return b.create<arith::SelectOp>(loc, pred, clamp, input);
|
||||
};
|
||||
|
||||
auto result = payloadArgs[0];
|
||||
if (!min.getType().isa<Torch::NoneType>()) {
|
||||
auto minPromoted = convertScalarToDtype(b, loc, min, dtype);
|
||||
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||
result, minPromoted);
|
||||
result = b.create<arith::SelectOp>(loc, pred, minPromoted, result);
|
||||
}
|
||||
if (!max.getType().isa<Torch::NoneType>()) {
|
||||
auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
|
||||
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
|
||||
result, maxPromoted);
|
||||
result = b.create<arith::SelectOp>(loc, pred, maxPromoted, result);
|
||||
}
|
||||
if (!min.getType().isa<Torch::NoneType>())
|
||||
result = cmpSelect(result, min, /*getMax=*/false);
|
||||
if (!max.getType().isa<Torch::NoneType>())
|
||||
result = cmpSelect(result, max, /*getMax=*/true);
|
||||
return result;
|
||||
}
|
||||
if (auto clampTensor = dyn_cast<AtenClampTensorOp>(op)) {
|
||||
|
|
|
@ -988,6 +988,34 @@ def ElementwiseClampTensorIntModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseClampTensorInt8Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int8, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
min = -5
|
||||
max = 5
|
||||
min_clamp = torch.clamp(x, min)
|
||||
max_clamp = torch.clamp(x, max=max)
|
||||
both_clamp = torch.clamp(x, min=min, max=max)
|
||||
return min_clamp, max_clamp, both_clamp
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseClampTensorInt8Module())
|
||||
def ElementwiseClampTensorInt8Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 5, low=-10, high=10, dtype=torch.int8))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
||||
class ElementwiseClampMinTensorFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue