mirror of https://github.com/llvm/torch-mlir
* [tosa] Support for Maximum and Minimum
Signed-off-by: Anup Gangwar <anup.gangwar@arm.com>pull/487/head
parent
707c113463
commit
a6c3050dd0
|
@ -77,6 +77,39 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// These binary op legalizations are identical for floating-point
|
||||||
|
// or quantized types
|
||||||
|
template <typename AtenOpT, typename TosaOpT>
|
||||||
|
class ConvertAtenBinaryOp : public OpConversionPattern<AtenOpT> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||||
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Value lhs = adaptor.self();
|
||||||
|
auto lhsTy = lhs.getType().cast<TensorType>();
|
||||||
|
Value rhs = adaptor.other();
|
||||||
|
auto rhsTy = rhs.getType().cast<TensorType>();
|
||||||
|
|
||||||
|
if (!lhsTy || !rhsTy)
|
||||||
|
return op.emitError("Only Tensor types supported in TOSA");
|
||||||
|
|
||||||
|
auto lhsElemTy = lhsTy.getElementType();
|
||||||
|
auto rhsElemTy = rhsTy.getElementType();
|
||||||
|
|
||||||
|
if (lhsElemTy != rhsElemTy)
|
||||||
|
return op.emitError("Add: input datatypes mismatched");
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<TosaOpT>(
|
||||||
|
op,
|
||||||
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||||
|
op.getType()),
|
||||||
|
lhs, rhs);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// These binary op legalizations are specific to add/sub which have an
|
// These binary op legalizations are specific to add/sub which have an
|
||||||
// alpha multiplier.
|
// alpha multiplier.
|
||||||
template <typename AtenOpT, typename TosaOpT>
|
template <typename AtenOpT, typename TosaOpT>
|
||||||
|
@ -538,6 +571,13 @@ public:
|
||||||
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
|
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
|
||||||
#undef INSERT_UNARY_PATTERN
|
#undef INSERT_UNARY_PATTERN
|
||||||
|
|
||||||
|
#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \
|
||||||
|
target.addIllegalOp<AtenOp>(); \
|
||||||
|
patterns.add<ConvertAtenBinaryOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||||
|
INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp)
|
||||||
|
INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
|
||||||
|
#undef INSERT_BINARY_PATTERN
|
||||||
|
|
||||||
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \
|
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \
|
||||||
target.addIllegalOp<AtenOp>(); \
|
target.addIllegalOp<AtenOp>(); \
|
||||||
patterns.add<ConvertAtenAddSubOp<AtenOp, TosaOp>>(typeConverter, context);
|
patterns.add<ConvertAtenAddSubOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||||
|
|
|
@ -298,3 +298,35 @@ func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
|
||||||
%0 = torch.aten.rsqrt %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
%0 = torch.aten.rsqrt %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||||
return %0 : !torch.vtensor<[?,?],f32>
|
return %0 : !torch.vtensor<[?,?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.maximum$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = "tosa.maximum"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func @torch.aten.maximum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
%0 = torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
return %0 : !torch.vtensor<[?,?],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.minimum$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = "tosa.minimum"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
%0 = torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
return %0 : !torch.vtensor<[?,?],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue