mirror of https://github.com/llvm/torch-mlir
update AtenClampOp in torch-to-tosa to handle fp inputs (#2516)
As titled. --------- Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>pull/2517/head
parent
14a4da923b
commit
4279b750da
|
@ -3984,19 +3984,37 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only tensor types input are currently supported");
|
op, "only tensor types input are currently supported");
|
||||||
|
|
||||||
int64_t int_min, int_max;
|
IntegerAttr min_int, max_int;
|
||||||
if (!matchPattern(op.getMin(), m_TorchConstantInt(&int_min)))
|
FloatAttr min_fp, max_fp;
|
||||||
return rewriter.notifyMatchFailure(
|
if (selfType.getElementType().isa<mlir::FloatType>()) {
|
||||||
op, "unimplemented: value `int_min` should be a torch constant int");
|
double fp_min, fp_max;
|
||||||
|
if (!matchPattern(op.getMin(), m_TorchConstantFloat(&fp_min)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: value `fp_min` should be a torch constant float");
|
||||||
|
|
||||||
if (!matchPattern(op.getMax(), m_TorchConstantInt(&int_max)))
|
if (!matchPattern(op.getMax(), m_TorchConstantFloat(&fp_max)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "unimplemented: value `int_max` should be a torch constant int");
|
op, "unimplemented: value `fp_max` should be a torch constant float");
|
||||||
|
|
||||||
IntegerAttr min_int = rewriter.getI64IntegerAttr(int_min);
|
min_int = rewriter.getI64IntegerAttr(static_cast<int64_t>(fp_min));
|
||||||
IntegerAttr max_int = rewriter.getI64IntegerAttr(int_max);
|
max_int = rewriter.getI64IntegerAttr(static_cast<int64_t>(fp_max));
|
||||||
FloatAttr min_fp = rewriter.getF32FloatAttr(float(int_min));
|
min_fp = rewriter.getF32FloatAttr(static_cast<float>(fp_min));
|
||||||
FloatAttr max_fp = rewriter.getF32FloatAttr(float(int_max));
|
max_fp = rewriter.getF32FloatAttr(static_cast<float>(fp_max));
|
||||||
|
} else {
|
||||||
|
int64_t int_min, int_max;
|
||||||
|
if (!matchPattern(op.getMin(), m_TorchConstantInt(&int_min)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: value `int_min` should be a torch constant int");
|
||||||
|
|
||||||
|
if (!matchPattern(op.getMax(), m_TorchConstantInt(&int_max)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: value `int_max` should be a torch constant int");
|
||||||
|
|
||||||
|
min_int = rewriter.getI64IntegerAttr(int_min);
|
||||||
|
max_int = rewriter.getI64IntegerAttr(int_max);
|
||||||
|
min_fp = rewriter.getF32FloatAttr(static_cast<float>(int_min));
|
||||||
|
max_fp = rewriter.getF32FloatAttr(static_cast<float>(int_max));
|
||||||
|
}
|
||||||
|
|
||||||
auto outType = getTypeConverter()->convertType(op.getType());
|
auto outType = getTypeConverter()->convertType(op.getType());
|
||||||
rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, outType, adaptor.getSelf(),
|
rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, outType, adaptor.getSelf(),
|
||||||
|
|
|
@ -1072,6 +1072,23 @@ func.func @torch.aten.clamp(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch
|
||||||
return %0 : !torch.vtensor<[1,1,128,128],si64>
|
return %0 : !torch.vtensor<[1,1,128,128],si64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.clamp.float(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],f32>) -> !torch.vtensor<[1,1,128,128],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],f32> -> tensor<1x1x128x128xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.float 6.432100e+00
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 6.432100e+00 : f32, max_int = 6 : i64, min_fp = 3.123400e+00 : f32, min_int = 3 : i64} : (tensor<1x1x128x128xf32>) -> tensor<1x1x128x128xf32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x1x128x128xf32> -> !torch.vtensor<[1,1,128,128],f32>
|
||||||
|
// CHECK: return %[[VAL_5]] : !torch.vtensor<[1,1,128,128],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.clamp.float(%arg0: !torch.vtensor<[1,1,128,128],f32>) -> !torch.vtensor<[1,1,128,128],f32> {
|
||||||
|
%fp_min = torch.constant.float 3.123400e+00
|
||||||
|
%fp_max = torch.constant.float 6.432100e+00
|
||||||
|
%0 = torch.aten.clamp %arg0, %fp_min, %fp_max : !torch.vtensor<[1,1,128,128],f32>, !torch.float, !torch.float -> !torch.vtensor<[1,1,128,128],f32>
|
||||||
|
return %0 : !torch.vtensor<[1,1,128,128],f32>
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: func.func @torch.aten.masked_fill.Scalar(
|
// CHECK-LABEL: func.func @torch.aten.masked_fill.Scalar(
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>,
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>,
|
||||||
|
|
Loading…
Reference in New Issue