mirror of https://github.com/llvm/torch-mlir
[Torch] Decompose AtenMaskedScatterOp (#3353)
Co-authored-by: Yuanqiang Liu <liuyuanqiang.yqliu@bytedance.com>
pull/3085/head
parent
a9edefb3cf
commit
7faba75696
|
@ -72,6 +72,9 @@ bool isBuiltInType(Type type);
|
||||||
// std::nullopt is returned if the tensorRank can't be determined.
|
// std::nullopt is returned if the tensorRank can't be determined.
|
||||||
std::optional<unsigned> getTensorRank(Value tensor);
|
std::optional<unsigned> getTensorRank(Value tensor);
|
||||||
|
|
||||||
|
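// Helper function to get the number of elements in a tensor.
// std::nullopt is returned if the tensor's sizes can't be determined, and
// ShapedType::kDynamic is returned if any of its dimensions is dynamic.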
|
||||||
|
std::optional<int64_t> getTensorNumel(Value tensor);
|
||||||
|
|
||||||
bool isViewLikeOp(Operation *op);
|
bool isViewLikeOp(Operation *op);
|
||||||
|
|
||||||
Value getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc,
|
Value getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc,
|
||||||
|
|
|
@ -3371,6 +3371,104 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// Decompose aten.masked_scatter:
|
||||||
|
// def masked_scatter(self: Tensor, mask: Tensor, source: Tensor) -> Tensor:
|
||||||
|
// mask_int = mask + torch.zeros_like(self)
|
||||||
|
// prefix_sum = torch.cumsum(mask_int.flatten(), dim=0)
|
||||||
|
// mask_prefix = torch.clamp(prefix_sum - 1, min=0)
|
||||||
|
// mask = mask.to(torch.bool)
|
||||||
|
// source = source.flatten()[mask_prefix].reshape(mask.shape)
|
||||||
|
// return torch.where(mask, source, self)
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenMaskedScatterOp
|
||||||
|
: public OpRewritePattern<AtenMaskedScatterOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenMaskedScatterOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
auto context = op.getContext();
|
||||||
|
Value mask = op.getMask();
|
||||||
|
Value source = op.getSource();
|
||||||
|
Value self = op.getSelf();
|
||||||
|
|
||||||
|
auto selfTy = cast<BaseTensorType>(self.getType());
|
||||||
|
auto resTy = cast<BaseTensorType>(op.getType());
|
||||||
|
auto sourceTy = cast<BaseTensorType>(source.getType());
|
||||||
|
|
||||||
|
if (!resTy || !resTy.hasDtype()) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
|
}
|
||||||
|
if (!selfTy || !selfTy.areAllSizesKnown())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Unimplemented: no implementation for rankless tensor");
|
||||||
|
if (!sourceTy || !sourceTy.areAllSizesKnown() || !sourceTy.hasDtype())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Unimplemented: no implementation for rankless tensor");
|
||||||
|
|
||||||
|
int64_t selfNumel = getTensorNumel(self).value(); // as selfTy has sizes
|
||||||
|
int64_t sourceNumel =
|
||||||
|
getTensorNumel(source).value(); // as sourceTy has sizes
|
||||||
|
int64_t selfRank = selfTy.getSizes().size();
|
||||||
|
int64_t sourceRank = sourceTy.getSizes().size();
|
||||||
|
|
||||||
|
Value constZero = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(0));
|
||||||
|
Value constOne = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(1));
|
||||||
|
Value constNone = rewriter.create<ConstantNoneOp>(loc);
|
||||||
|
Value selfLastDim = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(selfRank - 1));
|
||||||
|
Value sourceLastDim = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(sourceRank - 1));
|
||||||
|
|
||||||
|
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
|
||||||
|
auto int64Dtype = getDtypeIntValueForType(
|
||||||
|
rewriter, loc,
|
||||||
|
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
|
||||||
|
auto selfIntType = selfTy.getWithSizesAndDtype(selfTy.getSizes(), si64Type);
|
||||||
|
|
||||||
|
Value zerosLike = rewriter.create<Torch::AtenZerosLikeOp>(
|
||||||
|
loc, selfIntType, self, int64Dtype, constNone, constNone, constNone,
|
||||||
|
constNone);
|
||||||
|
Value maskInt = rewriter.create<Torch::AtenAddTensorOp>(
|
||||||
|
loc, selfIntType, mask, zerosLike, constOne);
|
||||||
|
|
||||||
|
auto flattenMaskedType = selfTy.getWithSizesAndDtype(
|
||||||
|
/*optionalSizes=*/{selfNumel}, si64Type);
|
||||||
|
Value maskIntFlatten = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
|
||||||
|
loc, flattenMaskedType, maskInt, constZero, selfLastDim);
|
||||||
|
Value prefixSum = rewriter.create<Torch::AtenCumsumOp>(
|
||||||
|
loc, flattenMaskedType, maskIntFlatten,
|
||||||
|
/*dim=*/constZero, constNone);
|
||||||
|
Value prefixSumMinusOne = rewriter.create<Torch::AtenSubScalarOp>(
|
||||||
|
loc, flattenMaskedType, prefixSum, constOne, constOne);
|
||||||
|
Value maskPrefix = rewriter.create<Torch::AtenClampOp>(
|
||||||
|
loc, flattenMaskedType, prefixSumMinusOne, /*min=*/constZero,
|
||||||
|
/*max=*/constNone);
|
||||||
|
|
||||||
|
auto sourceFlattenType = sourceTy.getWithSizesAndDtype(
|
||||||
|
/*optionalSizes=*/{sourceNumel}, sourceTy.getDtype());
|
||||||
|
Value sourceFlatten = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
|
||||||
|
loc, sourceFlattenType, source, constZero, sourceLastDim);
|
||||||
|
|
||||||
|
auto selectSourceType = sourceTy.getWithSizesAndDtype(
|
||||||
|
/*optionalSizes=*/{selfNumel}, sourceTy.getDtype());
|
||||||
|
Value selectSource = rewriter.create<Torch::AtenIndexSelectOp>(
|
||||||
|
loc, selectSourceType, sourceFlatten, constZero, maskPrefix);
|
||||||
|
|
||||||
|
// Reshape normalized output back to the original input shape
|
||||||
|
auto selfShape = rewriter.create<AtenSizeOp>(
|
||||||
|
loc, Torch::ListType::get(IntType::get(context)), self);
|
||||||
|
Value sourceReshape = rewriter.create<Torch::AtenViewOp>(
|
||||||
|
loc, selfTy, selectSource, selfShape);
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(op, resTy, mask,
|
||||||
|
sourceReshape, self);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Decompose aten._convolution-like to aten.convolution
|
// Decompose aten._convolution-like to aten.convolution
|
||||||
namespace {
|
namespace {
|
||||||
template <typename ConvolutionLikeOp>
|
template <typename ConvolutionLikeOp>
|
||||||
|
@ -7839,6 +7937,7 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNanToNumOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNanToNumOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedScatterOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(
|
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(
|
||||||
|
|
|
@ -390,6 +390,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
target.addIllegalOp<AtenWhereScalarOtherOp>();
|
target.addIllegalOp<AtenWhereScalarOtherOp>();
|
||||||
target.addIllegalOp<AtenWhereScalarSelfOp>();
|
target.addIllegalOp<AtenWhereScalarSelfOp>();
|
||||||
target.addIllegalOp<AtenMaskedFillScalarOp>();
|
target.addIllegalOp<AtenMaskedFillScalarOp>();
|
||||||
|
target.addIllegalOp<AtenMaskedScatterOp>();
|
||||||
target.addIllegalOp<AtenSizeOp>();
|
target.addIllegalOp<AtenSizeOp>();
|
||||||
target.addIllegalOp<AtenReshapeOp>();
|
target.addIllegalOp<AtenReshapeOp>();
|
||||||
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
|
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
|
||||||
|
|
|
@ -209,6 +209,19 @@ std::optional<unsigned> Torch::getTensorRank(Value tensor) {
|
||||||
return tensorType.getSizes().size();
|
return tensorType.getSizes().size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the number of elements of `tensor`, i.e. the product of its sizes.
// Yields std::nullopt when the tensor type carries no sizes, and
// ShapedType::kDynamic as soon as a dynamic dimension is encountered.
std::optional<int64_t> Torch::getTensorNumel(Value tensor) {
  auto type = cast<BaseTensorType>(tensor.getType());
  if (!type.hasSizes())
    return std::nullopt;

  int64_t elementCount = 1;
  for (int64_t extent : type.getSizes()) {
    // A single dynamic extent makes the whole element count dynamic.
    if (extent == ShapedType::kDynamic)
      return ShapedType::kDynamic;
    elementCount *= extent;
  }
  return elementCount;
}
|
||||||
|
|
||||||
bool Torch::isViewLikeOp(Operation *op) {
|
bool Torch::isViewLikeOp(Operation *op) {
|
||||||
// AtenContiguousOp might return a view, so this is conservatively
|
// AtenContiguousOp might return a view, so this is conservatively
|
||||||
// correct. We could potentially be more precise and identify the cases
|
// correct. We could potentially be more precise and identify the cases
|
||||||
|
|
|
@ -1072,6 +1072,7 @@ STABLEHLO_PASS_SET = {
|
||||||
"LinspaceTwoSizeModule_basic",
|
"LinspaceTwoSizeModule_basic",
|
||||||
"MaskedFillScalarFloatValueStaticModule_basic",
|
"MaskedFillScalarFloatValueStaticModule_basic",
|
||||||
"MaskedFillScalarIntValueStaticModule_basic",
|
"MaskedFillScalarIntValueStaticModule_basic",
|
||||||
|
"MaskedScatterStaticBasic_basic",
|
||||||
"Matmul4dStatic_basic",
|
"Matmul4dStatic_basic",
|
||||||
"Matmul_2d",
|
"Matmul_2d",
|
||||||
"Matmul_dot",
|
"Matmul_dot",
|
||||||
|
@ -2366,6 +2367,7 @@ ONNX_XFAIL_SET = {
|
||||||
"LinalgNormKeepDimComplexModule_basic",
|
"LinalgNormKeepDimComplexModule_basic",
|
||||||
"LinalgVectorNormComplexModule_basic",
|
"LinalgVectorNormComplexModule_basic",
|
||||||
"LogSoftmaxBackwardModule_basic",
|
"LogSoftmaxBackwardModule_basic",
|
||||||
|
"MaskedScatterStaticBasic_basic",
|
||||||
"MaxPool1dCeilModeTrueModule_basic",
|
"MaxPool1dCeilModeTrueModule_basic",
|
||||||
"MaxPool1dEmptyStrideStaticModule_basic",
|
"MaxPool1dEmptyStrideStaticModule_basic",
|
||||||
"MaxPool1dModule_basic",
|
"MaxPool1dModule_basic",
|
||||||
|
|
|
@ -12,6 +12,31 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class MaskedScatterStaticBasic(torch.nn.Module):
    """Module whose forward pass is a single `torch.masked_scatter` call."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([4, 4], torch.float32, True),
            ([4, 4], torch.bool, True),
            ([8, 8], torch.float32, True),
        ]
    )
    def forward(self, x, mask, y):
        # Annotated as a 4x4 float32 `x`, a 4x4 bool `mask` and an 8x8
        # float32 `y`.
        scattered = torch.masked_scatter(x, mask, y)
        return scattered
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: MaskedScatterStaticBasic())
def MaskedScatterStaticBasic_basic(module, tu: TestUtils):
    """Run the module on a random 4x4 input, 4x4 bool mask and 8x8 source."""
    # The three random draws stay in this order so the generated values are
    # the same as before for a given RNG state.
    destination = torch.rand(4, 4)
    selection = torch.rand(4, 4) > 0.5
    replacement = torch.rand(8, 8)
    module.forward(destination, selection, replacement)
|
||||||
|
|
||||||
|
|
||||||
class IndexPutImpl1DFloatNonAccumulateModule(torch.nn.Module):
|
class IndexPutImpl1DFloatNonAccumulateModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue