[TOSA] Add aten.masked_fill.Tensor/Scalar support (#1735)

pull/1743/head
Chi_Liu 2022-12-21 08:56:07 -08:00 committed by GitHub
parent 810473cc03
commit b2cefc0b64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 150 additions and 0 deletions

View File

@ -564,6 +564,8 @@ TOSA_PASS_SET = {
"_LogSoftmaxModuleStable_basic",
"ElementwiseAtenWhereSelfModule_basic",
"ElementwiseUnsqueezeBroadcastModule_basic",
"MaskedFillScalarIntValueStaticModule_basic",
"MaskedFillTensorIntValueStaticModule_basic",
"ElementwiseAddScalarInt64Module_basic",
"TensorLiteralModule_basic",
"TensorOpaqueLiteralModule_basic",
@ -727,6 +729,7 @@ LTC_XFAIL_SET = {
"ElementwiseRemainderScalarModule_Bool_basic",
"AtenIntTensorByteDtypeModule_basic",
"AtenIntTensorCharDtypeModule_basic",
"MaskedFillTensorIntValueStaticModule_basic",
"Fill_TensorFloat32WithFloat32_basic",
"Fill_TensorFloat32WithFloat64_basic",
"Fill_TensorFloat32WithInt64_basic",

View File

@ -3747,6 +3747,67 @@ public:
}
};
template <typename AtenOpT>
class ConvertAtenMaskedFillOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template dyn_cast<TensorType>();
if (!outType || !outType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "Only Tensor types with static shapes are currently supported");
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {
return rewriter.notifyMatchFailure(
op, "Only floating-point or integer datatype legalization supported");
}
// Not a tensor type.
auto selfType = adaptor.getSelf().getType().template dyn_cast<TensorType>();
if (!selfType || !outType.hasStaticShape())
return rewriter.notifyMatchFailure(
op,
"Only tensor types with static shapes input are currently supported");
auto maskType = adaptor.getMask().getType().template dyn_cast<TensorType>();
if (!maskType)
return rewriter.notifyMatchFailure(
op, "Only tensor types mask are currently supported");
Value rhs = adaptor.getValue();
auto rhsType = rhs.getType().template dyn_cast<TensorType>();
Value rhsAsTensor;
if (!rhsType) { // scalar
if (failed(torchScalarToTosaTensor(rewriter, op, op.getValue(),
rhsAsTensor, rhs.getType(), {})))
return rewriter.notifyMatchFailure(
op, "Currently only scalar constants are supported for "
"conversion in TOSA operation");
} else { // tensor
rhsType = rhs.getType().dyn_cast<TensorType>();
}
auto rhsTensor = rhsType ? rhs : rhsAsTensor;
auto rhsTensorType = rhsTensor.getType().template dyn_cast<TensorType>();
if (rhsTensorType.getElementType() != outElemTy)
rhsTensor = rewriter.create<tosa::CastOp>(
op.getLoc(),
RankedTensorType::get(rhsTensorType.getShape(), outElemTy),
rhsTensor);
rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, adaptor.getMask(),
rhsTensor, adaptor.getSelf());
return success();
}
};
// Legalizes the torch.clone op.
template <typename AtenOpT>
class ConvertAtenCloneOp : public OpConversionPattern<AtenOpT> {
@ -3947,6 +4008,13 @@ public:
INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp);
#undef INSERT_FILL_SCALAR_PATTERN
#define INSERT_MASKED_FILL_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMaskedFillOp<AtenOp>>(typeConverter, context);
INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp);
INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp);
#undef INSERT_MASKED_FILL_PATTERN
#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);

View File

@ -1402,3 +1402,46 @@ class MaskedFillTensorFloatValueModule(torch.nn.Module):
def MaskedFillTensorFloatValueModule_basic(module, tu: TestUtils):
module.forward(tu.randint(2, 3, low=-10, high=10),
tu.randint(2, 3, high=2).to(dtype=torch.bool), tu.rand())
class MaskedFillScalarIntValueStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 3], torch.int64, True),
([2, 3], torch.bool, True),
])
def forward(self, x, mask):
return torch.ops.aten.masked_fill(x, mask, value=5)
@register_test_case(module_factory=lambda: MaskedFillScalarIntValueStaticModule())
def MaskedFillScalarIntValueStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(2, 3),
tu.randint(2, 3, high=2).to(dtype=torch.bool))
class MaskedFillTensorIntValueStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 3], torch.int64, True),
([2, 3], torch.bool, True),
([], torch.int64, True),
])
def forward(self, x, mask, value):
return torch.ops.aten.masked_fill(x, mask, value=value)
@register_test_case(module_factory=lambda: MaskedFillTensorIntValueStaticModule())
def MaskedFillTensorIntValueStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(2, 3),
tu.randint(2, 3, high=2).to(dtype=torch.bool), tu.randint())

View File

@ -961,6 +961,42 @@ func.func @torch.aten.clamp(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch
return %0 : !torch.vtensor<[1,1,128,128],si64>
}
// -----
// CHECK-LABEL: func.func @torch.aten.masked_fill.Scalar(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,128,128],i1>) -> !torch.vtensor<[1,12,128,128],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,1,128,128],i1> -> tensor<1x1x128x128xi1>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
// CHECK: %[[VAL_6:.*]] = "tosa.cast"(%[[VAL_5]]) : (tensor<i64>) -> tensor<f32>
// CHECK: %[[VAL_7:.*]] = "tosa.select"(%[[VAL_3]], %[[VAL_6]], %[[VAL_2]]) : (tensor<1x1x128x128xi1>, tensor<f32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[1,12,128,128],f32>
// CHECK: }
func.func @torch.aten.masked_fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1,1,128,128],i1>) -> !torch.vtensor<[1,12,128,128],f32> {
%int0 = torch.constant.int 0
%0 = torch.aten.masked_fill.Scalar %arg0, %arg1, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
return %0 : !torch.vtensor<[1,12,128,128],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.masked_fill.Tensor(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,128,128],i1>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,1,128,128],i1> -> tensor<1x1x128x128xi1>
// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[VAL_6:.*]] = "tosa.select"(%[[VAL_4]], %[[VAL_5]], %[[VAL_3]]) : (tensor<1x1x128x128xi1>, tensor<f32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,12,128,128],f32>
// CHECK: }
func.func @torch.aten.masked_fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1,1,128,128],i1>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
%0 = torch.aten.masked_fill.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
return %0 : !torch.vtensor<[1,12,128,128],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.where.self(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,5,5],i1>,