mirror of https://github.com/llvm/torch-mlir
[TOSA] Add aten.add/sub.Scalar/Tensor si64 type support (#1604)
parent 73bd32d06c
commit 163d19cce6
@@ -611,6 +611,12 @@ TOSA_PASS_SET = {
     "_LogSoftmaxModuleStable_basic",
     "ElementwiseAtenWhereSelfModule_basic",
     "ElementwiseUnsqueezeBroadcastModule_basic",
+    "ElementwiseAddScalarInt64Module_basic",
+    "TensorLiteralModule_basic",
+    "TensorOpaqueLiteralModule_basic",
+    "TypePromotionDifferentCategoryModule_basic",
+    "TypePromotionSameCategoryDifferentWidthModule_basic",
+    "TypePromotionZeroRankHigherCategoryModule_basic",
     "LiftFreshCopyModule_basic",
     "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
     "ReduceSumDimIntListFloatModule_basic",
@@ -219,6 +219,11 @@ public:
   LogicalResult
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // left  : tensor: tensor<i32/i64/f32>
+    // right : scalar: i32/i64/f32
+    //         tensor: tensor<i32/i64/f32>
+    // alpha : scalar: i32/i64/f32
+    // output: tensor: tensor<i32/i64/f32>
     Value lhs = adaptor.getSelf();
     auto lhsType = lhs.getType().dyn_cast<TensorType>();
     Value rhs = adaptor.getOther();
@@ -229,11 +234,12 @@ public:
                   "Only Tensor types supported in TOSA");
 
     if (auto lhsElemTy = lhsType.getElementType().dyn_cast<IntegerType>()) {
-      if (lhsElemTy.getWidth() > 32)
+      if (lhsElemTy.getWidth() > 64)
         return rewriter.notifyMatchFailure(
-            op, "Integers with widths greater than 32 are not supported");
+            op, "Integers with widths greater than 64 are not supported");
     }
 
+    // Get output type: tensor<i32/i64/f32>
     auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
                        ->convertType(op.getType())
                        .template cast<TensorType>();
@@ -244,40 +250,91 @@ public:
           op, "Only floating-point or integer datatype legalization supported");
     }
 
+    Type rhsAlphaMulElemType;
+    if (outElemTy.isa<mlir::FloatType>()) {
+      rhsAlphaMulElemType = outElemTy;
+    } else {
+      // For integer outputs (including i64), do the arithmetic in i32.
+      rhsAlphaMulElemType = rewriter.getIntegerType(32);
+    }
+
+    // If right is a scalar, rhsType == None and it needs to be manually cast
+    // to TensorType; else right is a tensor, rhsType == tensor<i32/i64/f32>.
     Value rhsAsTensor;
     if (!rhsType) {
-      if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(), rhsAsTensor,
-                                         outElemTy, {})))
+      if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(),
+                                         rhsAsTensor, rhsAlphaMulElemType, {})))
         return rewriter.notifyMatchFailure(
             op, "Currently only scalar constants are supported for "
                 "conversion in TOSA operation");
+    } else if (rhsType.getElementType() != rhsAlphaMulElemType) {
+      // right is a tensor, rhsType == tensor<i32/i64/f32>
+      // right must be cast to the same type as alpha so that the MulOp succeeds
+      rhs = rewriter.create<tosa::CastOp>(
+          op->getLoc(),
+          RankedTensorType::get(rhsType.getShape(), rhsAlphaMulElemType), rhs);
+      // Reinitialize the right value type to tensor<i32/f32>.
+      rhsType = rhs.getType().dyn_cast<TensorType>();
     }
     auto rhsTensor = rhsType ? rhs : rhsAsTensor;
 
-    // Handle alpha.
+    // Handle the scalar alpha value.
+    // It should be either f32 or i32.
     Value alphaTensor;
-    if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), op.getAlpha(),
-                                      alphaTensor, outElemTy,
+    if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(),
+                                      op.getAlpha(), alphaTensor,
+                                      rhsAlphaMulElemType,
                                       /*checkForUnity=*/false))) {
       return rewriter.notifyMatchFailure(
           op, "Currently only scalar constants are supported for "
               "alpha in conversion to TOSA operation");
     }
 
+    // Make sure the inputs of MulOp have the same datatype, otherwise the
+    // later lowering to the arith dialect breaks.
     auto multTensor = rewriter.create<tosa::MulOp>(
-        op.getLoc(), rhsType ? rhsType : RankedTensorType::get({}, outElemTy),
+        op.getLoc(),
+        rhsType ? rhsType : RankedTensorType::get({}, rhsAlphaMulElemType),
         rhsTensor, alphaTensor, /*shift=*/0);
 
-    if (outElemTy.isa<mlir::FloatType>()) {
-      if (lhsType.getElementType() != outElemTy)
-        lhs = rewriter.create<tosa::CastOp>(op.getLoc(), outType, lhs);
+    if (outElemTy.isa<mlir::FloatType>() || outElemTy.isInteger(32)) {
+      // If outElemTy is tensor<f32>, multTensor must be tensor<f32>; the left
+      // value may be tensor<f32/i32/i64>, so cast the left value to
+      // tensor<f32>.
+      // If outElemTy is tensor<i32>, multTensor must be tensor<i32>; the left
+      // value may be tensor<f32/i32/i64>, so cast the left value to
+      // tensor<i32>.
+      if (lhsType.getElementType() != rhsAlphaMulElemType)
+        lhs = rewriter.create<tosa::CastOp>(
+            op.getLoc(),
+            RankedTensorType::get(lhsType.getShape(), rhsAlphaMulElemType),
+            lhs);
 
       rewriter.replaceOpWithNewOp<TosaOpT>(op, outType, lhs, multTensor);
 
       return success();
+    } else if (outElemTy.isInteger(64)) {
+      // If outElemTy is tensor<i64>, multTensor must be tensor<i32>; the left
+      // value may be tensor<f32/i32/i64>, so cast the left value to
+      // tensor<i32>.
+      if (lhsType.getElementType() != rhsAlphaMulElemType)
+        lhs = rewriter.create<tosa::CastOp>(
+            op.getLoc(),
+            RankedTensorType::get(lhsType.getShape(), rhsAlphaMulElemType),
+            lhs);
+
+      auto tosaOpTOutputTensor = rewriter.create<TosaOpT>(
+          op.getLoc(),
+          RankedTensorType::get(outType.getShape(), rhsAlphaMulElemType), lhs,
+          multTensor);
+      // Cast tensor<i32> back to tensor<i64>.
+      rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType,
+                                                tosaOpTOutputTensor);
+
+      return success();
     } else {
       return rewriter.notifyMatchFailure(
-          op, "Only floating-point datatype legalization supported");
+          op, "Only floating-point, i32, i64 datatype legalization supported");
     }
   }
 }; // namespace
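In short, the updated pattern leaves floating-point and i32 outputs on the direct path, while i64 outputs are computed at 32 bits: the right-hand side and alpha are materialized or cast to i32, multiplied, the left-hand side is cast down to i32, the add/sub is emitted in i32, and the result is cast back to i64. For contrast with the i64 tests in the next hunk, the following is a minimal sketch of the unchanged floating-point case, where rhsAlphaMulElemType equals the output element type and no casts are needed; it is illustrative only and not part of this diff.

// Hypothetical f32 example (not from this patch): the lowering emits only a
// "tosa.const" for alpha, a "tosa.mul" of the right operand with it, and a
// "tosa.add"; no "tosa.cast" ops are inserted because the compute element
// type already matches the output element type.
func.func @torch.aten.add$f32(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> {
  %int1 = torch.constant.int 1
  %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>, !torch.int -> !torch.vtensor<[2,2],f32>
  return %0 : !torch.vtensor<[2,2],f32>
}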
@@ -902,6 +902,48 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vten
   return %0 : !torch.vtensor<[3,5],i1>
 }
 
+// -----
+// CHECK-LABEL: func.func @torch.aten.add$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,2],si32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2],si32>) -> !torch.vtensor<[2,2],si64> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,2],si32> -> tensor<2x2xi32>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2],si32> -> tensor<2x2xi32>
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_3]], %[[VAL_5]]) {shift = 0 : i32} : (tensor<2x2xi32>, tensor<i32>) -> tensor<2x2xi32>
+// CHECK: %[[VAL_7:.*]] = "tosa.add"(%[[VAL_2]], %[[VAL_6]]) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+// CHECK: %[[VAL_8:.*]] = "tosa.cast"(%[[VAL_7]]) : (tensor<2x2xi32>) -> tensor<2x2xi64>
+// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x2xi64> -> !torch.vtensor<[2,2],si64>
+// CHECK: return %[[VAL_9]] : !torch.vtensor<[2,2],si64>
+// CHECK: }
+func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[2, 2],si32>, %arg1: !torch.vtensor<[2, 2],si32>) -> !torch.vtensor<[2, 2],si64> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[2, 2],si32>, !torch.vtensor<[2, 2],si32>, !torch.int -> !torch.vtensor<[2, 2],si64>
+  return %0 : !torch.vtensor<[2, 2],si64>
+}
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.Scalar$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],si64> -> tensor<1x1x128x128xi64>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 256
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<256> : tensor<i32>} : () -> tensor<i32>
+// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_4]], %[[VAL_5]]) {shift = 0 : i32} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK: %[[VAL_7:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32>
+// CHECK: %[[VAL_8:.*]] = "tosa.add"(%[[VAL_7]], %[[VAL_6]]) : (tensor<1x1x128x128xi32>, tensor<i32>) -> tensor<1x1x128x128xi32>
+// CHECK: %[[VAL_9:.*]] = "tosa.cast"(%[[VAL_8]]) : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64>
+// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
+// CHECK: return %[[VAL_10]] : !torch.vtensor<[1,1,128,128],si64>
+// CHECK: }
+func.func @torch.aten.Scalar$basic(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {
+  %int1 = torch.constant.int 1
+  %int256 = torch.constant.int 256
+  %0 = torch.aten.add.Scalar %arg0, %int256, %int1 : !torch.vtensor<[1,1,128,128],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,1,128,128],si64>
+  return %0 : !torch.vtensor<[1,1,128,128],si64>
+}
+
 // -----
 // CHECK-LABEL: func.func @torch.aten.clamp(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {
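The tests above exercise only aten.add; the commit title also covers aten.sub.Scalar/Tensor, which goes through the same ConvertAtenAddSubOp pattern. Below is a minimal sketch of what an analogous si64 sub.Scalar input might look like, assuming the same i32 compute path; this is hypothetical and not one of the hunks shown here.

// Hypothetical si64 sub.Scalar input (not part of this diff). Under the
// pattern above, the scalar and alpha become i32 "tosa.const" ops, the i64
// input is cast to i32, a "tosa.sub" is emitted in i32, and the result is
// cast back to i64 to match the output type.
func.func @torch.aten.sub.Scalar$si64(%arg0: !torch.vtensor<[2,2],si64>) -> !torch.vtensor<[2,2],si64> {
  %int1 = torch.constant.int 1
  %int3 = torch.constant.int 3
  %0 = torch.aten.sub.Scalar %arg0, %int3, %int1 : !torch.vtensor<[2,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,2],si64>
  return %0 : !torch.vtensor<[2,2],si64>
}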