mirror of https://github.com/llvm/torch-mlir
[TOSA] Add aten.masked_fill.Tensor/Scalar support (#1735)
parent
810473cc03
commit
b2cefc0b64
|
@ -564,6 +564,8 @@ TOSA_PASS_SET = {
|
||||||
"_LogSoftmaxModuleStable_basic",
|
"_LogSoftmaxModuleStable_basic",
|
||||||
"ElementwiseAtenWhereSelfModule_basic",
|
"ElementwiseAtenWhereSelfModule_basic",
|
||||||
"ElementwiseUnsqueezeBroadcastModule_basic",
|
"ElementwiseUnsqueezeBroadcastModule_basic",
|
||||||
|
"MaskedFillScalarIntValueStaticModule_basic",
|
||||||
|
"MaskedFillTensorIntValueStaticModule_basic",
|
||||||
"ElementwiseAddScalarInt64Module_basic",
|
"ElementwiseAddScalarInt64Module_basic",
|
||||||
"TensorLiteralModule_basic",
|
"TensorLiteralModule_basic",
|
||||||
"TensorOpaqueLiteralModule_basic",
|
"TensorOpaqueLiteralModule_basic",
|
||||||
|
@ -727,6 +729,7 @@ LTC_XFAIL_SET = {
|
||||||
"ElementwiseRemainderScalarModule_Bool_basic",
|
"ElementwiseRemainderScalarModule_Bool_basic",
|
||||||
"AtenIntTensorByteDtypeModule_basic",
|
"AtenIntTensorByteDtypeModule_basic",
|
||||||
"AtenIntTensorCharDtypeModule_basic",
|
"AtenIntTensorCharDtypeModule_basic",
|
||||||
|
"MaskedFillTensorIntValueStaticModule_basic",
|
||||||
"Fill_TensorFloat32WithFloat32_basic",
|
"Fill_TensorFloat32WithFloat32_basic",
|
||||||
"Fill_TensorFloat32WithFloat64_basic",
|
"Fill_TensorFloat32WithFloat64_basic",
|
||||||
"Fill_TensorFloat32WithInt64_basic",
|
"Fill_TensorFloat32WithInt64_basic",
|
||||||
|
|
|
@ -3747,6 +3747,67 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename AtenOpT>
|
||||||
|
class ConvertAtenMaskedFillOp : public OpConversionPattern<AtenOpT> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||||
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||||
|
->convertType(op.getType())
|
||||||
|
.template dyn_cast<TensorType>();
|
||||||
|
|
||||||
|
if (!outType || !outType.hasStaticShape())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only Tensor types with static shapes are currently supported");
|
||||||
|
|
||||||
|
Type outElemTy = outType.getElementType();
|
||||||
|
if (!outElemTy.isIntOrFloat()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only floating-point or integer datatype legalization supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not a tensor type.
|
||||||
|
auto selfType = adaptor.getSelf().getType().template dyn_cast<TensorType>();
|
||||||
|
if (!selfType || !outType.hasStaticShape())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op,
|
||||||
|
"Only tensor types with static shapes input are currently supported");
|
||||||
|
|
||||||
|
auto maskType = adaptor.getMask().getType().template dyn_cast<TensorType>();
|
||||||
|
if (!maskType)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only tensor types mask are currently supported");
|
||||||
|
|
||||||
|
Value rhs = adaptor.getValue();
|
||||||
|
auto rhsType = rhs.getType().template dyn_cast<TensorType>();
|
||||||
|
Value rhsAsTensor;
|
||||||
|
if (!rhsType) { // scalar
|
||||||
|
if (failed(torchScalarToTosaTensor(rewriter, op, op.getValue(),
|
||||||
|
rhsAsTensor, rhs.getType(), {})))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Currently only scalar constants are supported for "
|
||||||
|
"conversion in TOSA operation");
|
||||||
|
} else { // tensor
|
||||||
|
rhsType = rhs.getType().dyn_cast<TensorType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto rhsTensor = rhsType ? rhs : rhsAsTensor;
|
||||||
|
auto rhsTensorType = rhsTensor.getType().template dyn_cast<TensorType>();
|
||||||
|
if (rhsTensorType.getElementType() != outElemTy)
|
||||||
|
rhsTensor = rewriter.create<tosa::CastOp>(
|
||||||
|
op.getLoc(),
|
||||||
|
RankedTensorType::get(rhsTensorType.getShape(), outElemTy),
|
||||||
|
rhsTensor);
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, adaptor.getMask(),
|
||||||
|
rhsTensor, adaptor.getSelf());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Legalizes the torch.clone op.
|
// Legalizes the torch.clone op.
|
||||||
template <typename AtenOpT>
|
template <typename AtenOpT>
|
||||||
class ConvertAtenCloneOp : public OpConversionPattern<AtenOpT> {
|
class ConvertAtenCloneOp : public OpConversionPattern<AtenOpT> {
|
||||||
|
@ -3947,6 +4008,13 @@ public:
|
||||||
INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp);
|
INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp);
|
||||||
#undef INSERT_FILL_SCALAR_PATTERN
|
#undef INSERT_FILL_SCALAR_PATTERN
|
||||||
|
|
||||||
|
#define INSERT_MASKED_FILL_PATTERN(AtenOp) \
|
||||||
|
target.addIllegalOp<AtenOp>(); \
|
||||||
|
patterns.add<ConvertAtenMaskedFillOp<AtenOp>>(typeConverter, context);
|
||||||
|
INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp);
|
||||||
|
INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp);
|
||||||
|
#undef INSERT_MASKED_FILL_PATTERN
|
||||||
|
|
||||||
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
||||||
target.addIllegalOp<AtenOp>(); \
|
target.addIllegalOp<AtenOp>(); \
|
||||||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
|
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
|
||||||
|
|
|
@ -1402,3 +1402,46 @@ class MaskedFillTensorFloatValueModule(torch.nn.Module):
|
||||||
def MaskedFillTensorFloatValueModule_basic(module, tu: TestUtils):
|
def MaskedFillTensorFloatValueModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(2, 3, low=-10, high=10),
|
module.forward(tu.randint(2, 3, low=-10, high=10),
|
||||||
tu.randint(2, 3, high=2).to(dtype=torch.bool), tu.rand())
|
tu.randint(2, 3, high=2).to(dtype=torch.bool), tu.rand())
|
||||||
|
|
||||||
|
|
||||||
|
class MaskedFillScalarIntValueStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([2, 3], torch.int64, True),
|
||||||
|
([2, 3], torch.bool, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, mask):
|
||||||
|
return torch.ops.aten.masked_fill(x, mask, value=5)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: MaskedFillScalarIntValueStaticModule())
|
||||||
|
def MaskedFillScalarIntValueStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(2, 3),
|
||||||
|
tu.randint(2, 3, high=2).to(dtype=torch.bool))
|
||||||
|
|
||||||
|
|
||||||
|
class MaskedFillTensorIntValueStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([2, 3], torch.int64, True),
|
||||||
|
([2, 3], torch.bool, True),
|
||||||
|
([], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, mask, value):
|
||||||
|
return torch.ops.aten.masked_fill(x, mask, value=value)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: MaskedFillTensorIntValueStaticModule())
|
||||||
|
def MaskedFillTensorIntValueStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(2, 3),
|
||||||
|
tu.randint(2, 3, high=2).to(dtype=torch.bool), tu.randint())
|
||||||
|
|
|
@ -961,6 +961,42 @@ func.func @torch.aten.clamp(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch
|
||||||
return %0 : !torch.vtensor<[1,1,128,128],si64>
|
return %0 : !torch.vtensor<[1,1,128,128],si64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.masked_fill.Scalar(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,128,128],i1>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,1,128,128],i1> -> tensor<1x1x128x128xi1>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = "tosa.cast"(%[[VAL_5]]) : (tensor<i64>) -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = "tosa.select"(%[[VAL_3]], %[[VAL_6]], %[[VAL_2]]) : (tensor<1x1x128x128xi1>, tensor<f32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
// CHECK: return %[[VAL_8]] : !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.masked_fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1,1,128,128],i1>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%0 = torch.aten.masked_fill.Scalar %arg0, %arg1, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
return %0 : !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.masked_fill.Tensor(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,128,128],i1>,
|
||||||
|
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,1,128,128],i1> -> tensor<1x1x128x128xi1>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = "tosa.select"(%[[VAL_4]], %[[VAL_5]], %[[VAL_3]]) : (tensor<1x1x128x128xi1>, tensor<f32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.masked_fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1,1,128,128],i1>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||||
|
%0 = torch.aten.masked_fill.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
return %0 : !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: func.func @torch.aten.where.self(
|
// CHECK-LABEL: func.func @torch.aten.where.self(
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,5,5],i1>,
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,5,5],i1>,
|
||||||
|
|
Loading…
Reference in New Issue