From 4279b750da10b4ded10ca6ccb1c120d7a4187a51 Mon Sep 17 00:00:00 2001
From: Ze Zhang
Date: Tue, 17 Oct 2023 14:49:47 -0700
Subject: [PATCH] update AtenClampOp in torch-to-tosa to handle fp inputs (#2516)

As titled.

---------

Co-authored-by: Ze Zhang
---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 40 ++++++++++++++++------
 test/Conversion/TorchToTosa/basic.mlir     | 17 +++++++++
 2 files changed, 46 insertions(+), 11 deletions(-)

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 970ef15d8..c2c73708d 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3984,19 +3984,37 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "only tensor types input are currently supported");
 
-  int64_t int_min, int_max;
-  if (!matchPattern(op.getMin(), m_TorchConstantInt(&int_min)))
-    return rewriter.notifyMatchFailure(
-        op, "unimplemented: value `int_min` should be a torch constant int");
+  IntegerAttr min_int, max_int;
+  FloatAttr min_fp, max_fp;
+  if (selfType.getElementType().isa<mlir::FloatType>()) {
+    double fp_min, fp_max;
+    if (!matchPattern(op.getMin(), m_TorchConstantFloat(&fp_min)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: value `fp_min` should be a torch constant float");
 
-  if (!matchPattern(op.getMax(), m_TorchConstantInt(&int_max)))
-    return rewriter.notifyMatchFailure(
-        op, "unimplemented: value `int_max` should be a torch constant int");
+    if (!matchPattern(op.getMax(), m_TorchConstantFloat(&fp_max)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: value `fp_max` should be a torch constant float");
 
-  IntegerAttr min_int = rewriter.getI64IntegerAttr(int_min);
-  IntegerAttr max_int = rewriter.getI64IntegerAttr(int_max);
-  FloatAttr min_fp = rewriter.getF32FloatAttr(float(int_min));
-  FloatAttr max_fp = rewriter.getF32FloatAttr(float(int_max));
+    min_int = rewriter.getI64IntegerAttr(static_cast<int64_t>(fp_min));
+    max_int = rewriter.getI64IntegerAttr(static_cast<int64_t>(fp_max));
+    min_fp = rewriter.getF32FloatAttr(static_cast<float>(fp_min));
+    max_fp = rewriter.getF32FloatAttr(static_cast<float>(fp_max));
+  } else {
+    int64_t int_min, int_max;
+    if (!matchPattern(op.getMin(), m_TorchConstantInt(&int_min)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: value `int_min` should be a torch constant int");
+
+    if (!matchPattern(op.getMax(), m_TorchConstantInt(&int_max)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: value `int_max` should be a torch constant int");
+
+    min_int = rewriter.getI64IntegerAttr(int_min);
+    max_int = rewriter.getI64IntegerAttr(int_max);
+    min_fp = rewriter.getF32FloatAttr(static_cast<float>(int_min));
+    max_fp = rewriter.getF32FloatAttr(static_cast<float>(int_max));
+  }
 
   auto outType = getTypeConverter()->convertType(op.getType());
   rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, outType, adaptor.getSelf(),
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index f04109873..180f48bce 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -1072,6 +1072,23 @@ func.func @torch.aten.clamp(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch
   return %0 : !torch.vtensor<[1,1,128,128],si64>
 }
 
+// -----
+// CHECK-LABEL: func.func @torch.aten.clamp.float(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],f32>) -> !torch.vtensor<[1,1,128,128],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],f32> -> tensor<1x1x128x128xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
+// CHECK: %[[VAL_3:.*]] = torch.constant.float 6.432100e+00
+// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 6.432100e+00 : f32, max_int = 6 : i64, min_fp = 3.123400e+00 : f32, min_int = 3 : i64} : (tensor<1x1x128x128xf32>) -> tensor<1x1x128x128xf32>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x1x128x128xf32> -> !torch.vtensor<[1,1,128,128],f32>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[1,1,128,128],f32>
+// CHECK: }
+func.func @torch.aten.clamp.float(%arg0: !torch.vtensor<[1,1,128,128],f32>) -> !torch.vtensor<[1,1,128,128],f32> {
+  %fp_min = torch.constant.float 3.123400e+00
+  %fp_max = torch.constant.float 6.432100e+00
+  %0 = torch.aten.clamp %arg0, %fp_min, %fp_max : !torch.vtensor<[1,1,128,128],f32>, !torch.float, !torch.float -> !torch.vtensor<[1,1,128,128],f32>
+  return %0 : !torch.vtensor<[1,1,128,128],f32>
+}
+
 // -----
 // CHECK-LABEL: func.func @torch.aten.masked_fill.Scalar(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>,