* [tosa] Support for Maximum and Minimum

Signed-off-by: Anup Gangwar <anup.gangwar@arm.com>
pull/487/head
Anup Gangwar 2021-12-15 11:19:25 -08:00 committed by Sean Silva
parent 707c113463
commit a6c3050dd0
2 changed files with 72 additions and 0 deletions

View File

@ -77,6 +77,39 @@ public:
}
};
// These binary op legalizations are identical for floating-point
// or quantized types
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenBinaryOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.self();
auto lhsTy = lhs.getType().cast<TensorType>();
Value rhs = adaptor.other();
auto rhsTy = rhs.getType().cast<TensorType>();
if (!lhsTy || !rhsTy)
return op.emitError("Only Tensor types supported in TOSA");
auto lhsElemTy = lhsTy.getElementType();
auto rhsElemTy = rhsTy.getElementType();
if (lhsElemTy != rhsElemTy)
return op.emitError("Add: input datatypes mismatched");
rewriter.replaceOpWithNewOp<TosaOpT>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
lhs, rhs);
return success();
}
};
// These binary op legalizations are specific to add/sub which have an
// alpha multiplier.
template <typename AtenOpT, typename TosaOpT>
@ -538,6 +571,13 @@ public:
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
#undef INSERT_UNARY_PATTERN
#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenBinaryOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp)
INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
#undef INSERT_BINARY_PATTERN
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAddSubOp<AtenOp, TosaOp>>(typeConverter, context);

View File

@ -298,3 +298,35 @@ func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
%0 = torch.aten.rsqrt %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.maximum$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = "tosa.maximum"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.maximum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.minimum$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = "tosa.minimum"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}