Handle torch.none type in tosa.clamp op (#2739)

This PR updates the torch-to-tosa conversion with the following changes:

- Support torch.none as the min/max input argument for the tosa.clamp op
- Support a negative value as the start index for the tosa.slice op
- Add tosa.logical_or lowering support
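
For illustration only (not taken from the patch or the e2e suite), a small PyTorch module that exercises the three newly supported patterns; the module name, shapes, and values below are made up:

import torch

class TosaSmokeModule(torch.nn.Module):
    # Hypothetical example module; shapes and values are illustrative only.
    def forward(self, x, a, b):
        # torch.clamp with a None bound now maps to tosa.clamp with a
        # default (widest representable) bound on that side.
        clamped = torch.clamp(x, min=None, max=0)
        # A negative start index now maps to tosa.slice with the start
        # normalized into the valid range (e.g. -16 -> dim_size - 16).
        sliced = x[:, -16:, :]
        # Elementwise logical_or now lowers to tosa.logical_or.
        ored = torch.logical_or(a, b)
        return clamped, sliced, ored

x = torch.randint(-5, 5, (4, 65, 256))
a = torch.rand(2, 3) > 0.5
b = torch.rand(2, 3) > 0.5
outputs = TosaSmokeModule()(x, a, b)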

e2e test:
python -m e2e_testing.main --config=tosa

LIT tests:
cmake --build build --target tools/torch-mlir/all

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
Author: Ze Zhang
Date:   2024-01-11 10:36:48 -08:00 (committed by GitHub)
parent 47ffc90db4
commit 670a99ae19
3 changed files with 128 additions and 31 deletions


@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@@ -3336,9 +3337,11 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
   if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
     return rewriter.notifyMatchFailure(op, "start must be a Scalar constant");
-  if (start < 0)
-    return rewriter.notifyMatchFailure(op, "Currently unsupported: start < 0");
+  if (start < 0) {
+    start = toPositiveDim(start, selfType.getShape()[dim]);
+    if (!isValidDim(start, selfType.getShape()[dim]))
+      return rewriter.notifyMatchFailure(op, "start is not a valid index");
+  }
   start = std::min(selfType.getShape()[dim], start);

   int64_t end;
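
Note (not part of the diff): the normalization above is the usual Python-style index wrap-around provided by the Torch dialect's toPositiveDim/isValidDim helpers. A minimal Python sketch of the intended semantics, with a hypothetical helper name:

def normalize_slice_start(start, dim_size):
    # Mimics the wrap-around the lowering now performs for a negative start.
    if start < 0:
        start += dim_size  # e.g. start=-16, dim_size=65 -> 49
        if not (0 <= start < dim_size):
            raise ValueError("start is not a valid index")
    return min(start, dim_size)  # clamp to the dimension size, as in the C++ above

With start=-16 and dim_size=65 this yields 49, which matches the start offset checked by the new torch.aten.slice.negative_start LIT test below.
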
@@ -3984,36 +3987,46 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "only tensor types input are currently supported");

-  IntegerAttr min_int, max_int;
-  FloatAttr min_fp, max_fp;
-  if (op.getMin().getType().isa<Torch::FloatType>()) {
-    double fp_min, fp_max;
-    if (!matchPattern(op.getMin(), m_TorchConstantFloat(&fp_min)))
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: value `fp_min` should be a torch constant float");
-
-    if (!matchPattern(op.getMax(), m_TorchConstantFloat(&fp_max)))
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: value `fp_max` should be a torch constant float");
-
-    min_int = rewriter.getI64IntegerAttr(static_cast<int64_t>(fp_min));
-    max_int = rewriter.getI64IntegerAttr(static_cast<int64_t>(fp_max));
-    min_fp = rewriter.getF32FloatAttr(static_cast<float>(fp_min));
-    max_fp = rewriter.getF32FloatAttr(static_cast<float>(fp_max));
-  } else {
-    int64_t int_min, int_max;
-    if (!matchPattern(op.getMin(), m_TorchConstantInt(&int_min)))
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: value `int_min` should be a torch constant int");
-
-    if (!matchPattern(op.getMax(), m_TorchConstantInt(&int_max)))
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: value `int_max` should be a torch constant int");
-
-    min_int = rewriter.getI64IntegerAttr(int_min);
-    max_int = rewriter.getI64IntegerAttr(int_max);
-    min_fp = rewriter.getF32FloatAttr(static_cast<float>(int_min));
-    max_fp = rewriter.getF32FloatAttr(static_cast<float>(int_max));
+  IntegerAttr min_int =
+      rewriter.getI64IntegerAttr(std::numeric_limits<int64_t>::min());
+  IntegerAttr max_int =
+      rewriter.getI64IntegerAttr(std::numeric_limits<int64_t>::max());
+  FloatAttr min_fp =
+      rewriter.getF32FloatAttr(std::numeric_limits<float>::lowest());
+  FloatAttr max_fp =
+      rewriter.getF32FloatAttr(std::numeric_limits<float>::max());
+
+  auto getValAttr = [&](Value operand, IntegerAttr &intAttr,
+                        FloatAttr &fpAttr) -> LogicalResult {
+    double valFloat;
+    int64_t valInt;
+    if (matchPattern(operand, m_TorchConstantFloat(&valFloat))) {
+      intAttr = rewriter.getI64IntegerAttr(static_cast<int64_t>(valFloat));
+      fpAttr = rewriter.getF32FloatAttr(static_cast<float>(valFloat));
+    } else if (matchPattern(operand, m_TorchConstantInt(&valInt))) {
+      intAttr = rewriter.getI64IntegerAttr(valInt);
+      fpAttr = rewriter.getF32FloatAttr(static_cast<float>(valInt));
+    } else {
+      return failure();
+    }
+    return success();
+  };
+
+  LogicalResult minAttrResult = getValAttr(op.getMin(), min_int, min_fp);
+  LogicalResult maxAttrResult = getValAttr(op.getMax(), max_int, max_fp);
+  if (failed(minAttrResult) && failed(maxAttrResult)) {
+    return rewriter.notifyMatchFailure(
+        op, "either `min` or `max` should be a torch constant");
+  }
+  if (failed(minAttrResult) &&
+      succeeded(checkNotNone(rewriter, op, op.getMin()))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "min attr should be a torch constant");
+  }
+  if (failed(maxAttrResult) &&
+      succeeded(checkNotNone(rewriter, op, op.getMax()))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "max attr should be a torch constant");
   }

   auto outType = getTypeConverter()->convertType(op.getType());
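
Note (not part of the diff): the net effect of the rewrite is that whichever bound is torch.none keeps the widest representable default, so the emitted tosa.clamp is a no-op on that side. A rough Python sketch of the attribute-filling logic, mirroring the lambda above (the helper name and constant handling are illustrative):

F32_MAX = 3.4028234663852886e+38   # float32 max, printed as 3.40282347E+38 in the LIT checks
I64_MIN, I64_MAX = -(2 ** 63), 2 ** 63 - 1

def clamp_attrs(min_val=None, max_val=None):
    # Start from the widest possible bounds, then overwrite whichever
    # operand is a constant (int or float); None keeps the default.
    min_int, max_int = I64_MIN, I64_MAX
    min_fp, max_fp = -F32_MAX, F32_MAX
    if min_val is not None:
        min_int, min_fp = int(min_val), float(min_val)
    if max_val is not None:
        max_int, max_fp = int(max_val), float(max_val)
    return {"min_int": min_int, "max_int": max_int,
            "min_fp": min_fp, "max_fp": max_fp}

# torch.aten.clamp %x, none, 0 -> min_int/min_fp stay at their defaults.
print(clamp_attrs(None, 0))
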
@@ -5025,6 +5038,7 @@ public:
   patterns.add<ConvertAtenBinaryOp<AtenOp, TosaOp>>(typeConverter, context);
     INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp)
     INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
+    INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
 #undef INSERT_BINARY_PATTERN

 #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \


@@ -1035,6 +1035,15 @@ TOSA_PASS_SET = {
     "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
     "ElementwiseAtenDivIntScalarModule_basic",
     "ElementwiseAtenIsinfOpModule_basic",
+    "ElementwiseAtenLogicalOrOpBrodcastModule_basic",
+    "ElementwiseAtenLogicalOrOpDiffArgs1Module_basic",
+    "ElementwiseAtenLogicalOrOpDiffArgs2Module_basic",
+    "ElementwiseAtenLogicalOrOpDiffArgs3Module_basic",
+    "ElementwiseAtenLogicalOrOpModule_basic",
+    "ElementwiseAtenLogicalOrOpNegativeModule_basic",
+    "ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule_basic",
+    "ElementwiseAtenLogicalOrOpRandomFloatModule_basic",
+    "ElementwiseAtenLogicalOrOpRandomModule_basic",
     "ElementwiseAtenWhereSelfModule_basic",
     "ElementwiseBinaryModule_basic",
     "ElementwiseBinaryStaticShapeModule_basic",
@@ -1047,6 +1056,9 @@ TOSA_PASS_SET = {
     "ElementwiseBitwiseXorModule_basic",
     "ElementwiseBitwiseXorStaticShapeModule_basic",
     "ElementwiseCeilModule_basic",
+    "ElementwiseClampMaxModule_basic",
+    "ElementwiseClampMinModule_basic",
+    "ElementwiseClampModule_basic",
     "ElementwiseCloneChannelsLastMemoryFormatModule_basic",
     "ElementwiseCloneContiguousModule_basic",
     "ElementwiseCloneModule_basic",


@@ -645,6 +645,22 @@ func.func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 // -----

+// CHECK-LABEL: func.func @torch.aten.logical_or$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],i1>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
+// CHECK: %[[VAL_4:.*]] = tosa.logical_or %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+// CHECK: }
+func.func @torch.aten.logical_or$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.logical_or %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @forward(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,2,4],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,2],f32> -> tensor<3x4x2xf32>
@@ -1055,6 +1071,61 @@ func.func @torch.aten.Scalar$basic(%arg0: !torch.vtensor<[1,1,128,128],si64>) ->
   return %0 : !torch.vtensor<[1,1,128,128],si64>
 }

+// -----
+
+// CHECK-LABEL: func.func @torch.aten.slice.negative_start(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,16,256],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 100
+// CHECK: %[[VAL_5:.*]] = torch.constant.int -16
+// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 4, 16, 256>, start = array<i64: 0, 49, 0>} : (tensor<4x65x256xf32>) -> tensor<4x16x256xf32>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x16x256xf32> -> !torch.vtensor<[4,16,256],f32>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,16,256],f32>
+// CHECK: }
+func.func @torch.aten.slice.negative_start(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,16,256],f32> {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int100 = torch.constant.int 100
+  %int-16 = torch.constant.int -16
+  %0 = torch.aten.slice.Tensor %arg0, %int1, %int-16, %int100, %int1 : !torch.vtensor<[4,65,256],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,16,256],f32>
+  return %0 : !torch.vtensor<[4,16,256],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.clamp.min_none(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],si64> -> tensor<1x1x128x128xi64>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.none
+// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0.000000e+00 : f32, max_int = 0 : i64, min_fp = -3.40282347E+38 : f32, min_int = -9223372036854775808 : i64} : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi64>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[1,1,128,128],si64>
+// CHECK: }
+func.func @torch.aten.clamp.min_none(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {
+  %int0 = torch.constant.int 0
+  %none = torch.constant.none
+  %0 = torch.aten.clamp %arg0, %none, %int0 : !torch.vtensor<[1,1,128,128],si64>, !torch.none, !torch.int -> !torch.vtensor<[1,1,128,128],si64>
+  return %0 : !torch.vtensor<[1,1,128,128],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.clamp.max_none(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],si64> -> tensor<1x1x128x128xi64>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.none
+// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 3.40282347E+38 : f32, max_int = 9223372036854775807 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi64>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[1,1,128,128],si64>
+// CHECK: }
+func.func @torch.aten.clamp.max_none(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {
+  %int0 = torch.constant.int 0
+  %none = torch.constant.none
+  %0 = torch.aten.clamp %arg0, %int0, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.int, !torch.none -> !torch.vtensor<[1,1,128,128],si64>
+  return %0 : !torch.vtensor<[1,1,128,128],si64>
+}
+
 // -----

 // CHECK-LABEL: func.func @torch.aten.clamp(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {