[stablehlo] Support aten.any and aten.all lowering (#3217)

pull/3231/head
Xinyu Yang 2024-04-25 11:15:52 +08:00 committed by GitHub
parent 7be22bb260
commit 7030eacb76
4 changed files with 299 additions and 7 deletions
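For context, `aten.all` and `aten.any` with no `dim` argument reduce over every element of the input: nonzero values count as True and the result is a zero-rank bool tensor. A quick PyTorch-level illustration (the sample values below are arbitrary):

import torch

x = torch.tensor([[0, 3], [1, 2]], dtype=torch.int32)
print(torch.ops.aten.all(x))  # tensor(False): one element is zero
print(torch.ops.aten.any(x))  # tensor(True): at least one element is nonzero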


@ -341,10 +341,14 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
isa<AtenNormScalarOp>(op))
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
if (isa<AtenAllDimOp>(op)) {
if (isa<AtenAllOp, AtenAllDimOp>(op)) {
return b.create<arith::ConstantOp>(loc, b.getBoolAttr(true));
}
if (isa<AtenAnyOp>(op)) {
return b.create<arith::ConstantOp>(loc, b.getBoolAttr(false));
}
op->emitError("unimplemented lowering in createInitElementForReduceOp");
return nullptr;
}
@ -439,11 +443,16 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
auto abs = createAbsOpForNormOps(b, loc, elem, resultElementType);
auto pow = b.create<math::PowFOp>(loc, abs, ord);
return b.create<arith::AddFOp>(loc, pow, result);
} else if (isa<AtenAllDimOp>(op)) {
} else if (isa<AtenAllOp, AtenAllDimOp>(op)) {
Value elem = payloadArgs[0];
Value result = payloadArgs[1];
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
return b.create<arith::MulIOp>(loc, self, result);
return b.create<arith::AndIOp>(loc, self, result);
} else if (isa<AtenAnyOp>(op)) {
Value elem = payloadArgs[0];
Value result = payloadArgs[1];
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
return b.create<arith::OrIOp>(loc, self, result);
}
op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp");
return nullptr;
@ -510,13 +519,13 @@ private:
ConversionPatternRewriter &rewriter) const {
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
if (isa<AtenMaxOp, AtenMinOp, AtenSumOp, AtenProdOp, AtenNormScalarOp>(
op)) {
if (isa<AtenAnyOp, AtenAllOp, AtenMaxOp, AtenMinOp, AtenSumOp, AtenProdOp,
AtenNormScalarOp>(op)) {
opInfo.tensorOperand = operands[0];
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
// `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and `AtenMinOp` each reduce
// along all the dimensions of the input tensor.
// `AtenAnyOp`, `AtenAllOp`, `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and
// `AtenMinOp` each reduce along all the dimensions of the input tensor.
for (int64_t i = 0; i < inputType.getRank(); i++)
opInfo.dimSet.insert(i);
@ -715,6 +724,8 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
target.addIllegalOp<AtenMinDimOp>();
patterns.add<ConvertAtenMinMaxDimOp<AtenMinDimOp>>(typeConverter, context);
target.addIllegalOp<AtenSumOp>();
target.addIllegalOp<AtenAnyOp>();
target.addIllegalOp<AtenAllOp>();
target.addIllegalOp<AtenSumDimIntListOp>();
target.addIllegalOp<AtenProdOp>();
target.addIllegalOp<AtenProdDimIntOp>();
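The init element and payload above pair each op with its reduction identity: `aten.all` starts from true and folds with `arith.andi`, while `aten.any` starts from false and folds with `arith.ori`. A rough Python sketch of that fold (the helper names `reduce_all` and `reduce_any` are just for illustration and assume the elements are already booleans):

from functools import reduce

def reduce_all(elems):
    # identity true, combined with AND (mirrors getBoolAttr(true) + arith::AndIOp)
    return reduce(lambda acc, e: acc and e, elems, True)

def reduce_any(elems):
    # identity false, combined with OR (mirrors getBoolAttr(false) + arith::OrIOp)
    return reduce(lambda acc, e: acc or e, elems, False)

assert reduce_all([True, True, False]) is False
assert reduce_any([False, False, True]) is True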


@ -104,6 +104,18 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
}
}
if (isa<AtenAllOp>(op)) {
auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 1)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
}
if (isa<AtenAnyOp>(op)) {
auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 0)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
}
op->emitError("unimplemented lowering in "
"createInitialValueForReduceOp");
return nullptr;
@ -463,6 +475,150 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
}
} // namespace
// AtenAllOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenAllOp>::matchAndRewrite(
AtenAllOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenAllOp to StableHLO");
}
auto outTy = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<RankedTensorType>();
if (inputElemTy != outTy.getElementType()) {
// Use output bool type as computation type.
auto dstElemTy = outTy.getElementType();
input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>();
inputElemTy = inputTy.getElementType();
}
SmallVector<int64_t> dims;
for (int64_t i = 0; i < inputTy.getRank(); i++) {
dims.push_back(i);
}
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
rewriter.getDenseI64ArrayAttr(dims));
Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());
auto *firstArgument = block.args_begin();
auto secondArgument = block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value allResult = rewriter.create<stablehlo::AndOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), allResult);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()),
stablehloReduceOp.getResults());
return success();
}
} // namespace
// AtenAnyOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenAnyOp>::matchAndRewrite(
AtenAnyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenAllOp to StableHLO");
}
auto outTy = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<RankedTensorType>();
if (inputElemTy != outTy.getElementType()) {
// Use output bool type as computation type.
auto dstElemTy = outTy.getElementType();
input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>();
inputElemTy = inputTy.getElementType();
}
SmallVector<int64_t> dims;
for (int64_t i = 0; i < inputTy.getRank(); i++) {
dims.push_back(i);
}
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
rewriter.getDenseI64ArrayAttr(dims));
Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());
auto *firstArgument = block.args_begin();
auto secondArgument = block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value anyResult = rewriter.create<stablehlo::OrOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), anyResult);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()),
stablehloReduceOp.getResults());
return success();
}
} // namespace
// AtenProdOp
namespace {
template <>
@ -1052,6 +1208,8 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAllOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAnyOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMinOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
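Both StableHLO patterns convert a non-bool input to the output's bool element type before reducing with `stablehlo.and` / `stablehlo.or` from the i1 constants 1 and 0. A small PyTorch-level check of the equivalence this relies on (shape and value range chosen arbitrarily):

import torch

x = torch.randint(0, 3, (3, 4, 5), dtype=torch.int32)
# aten.all/any on the integer input match the same reductions on an explicit
# bool cast, which is what the convert-then-reduce lowering depends on.
assert torch.ops.aten.all(x) == torch.ops.aten.all(x.to(torch.bool))
assert torch.ops.aten.any(x) == torch.ops.aten.any(x.to(torch.bool))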


@ -1227,6 +1227,12 @@ STABLEHLO_PASS_SET = {
"RandIntLowModule_basic",
"RandIntModule_basic",
"RandIntPinMemoryModule_basic",
"ReduceAllFloatModule_basic",
"ReduceAllIntModule_basic",
"ReduceAllBoolModule_basic",
"ReduceAnyFloatModule_basic",
"ReduceAnyIntModule_basic",
"ReduceAnyBoolModule_basic",
"ReduceAmaxMultiDim_basic",
"ReduceAmaxOutOfOrderDim_basic",
"ReduceAmaxSingleDim_basic",
@ -1813,6 +1819,8 @@ TOSA_PASS_SET = {
"PrimsSqueezeModule_basic",
"PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic",
"ReduceAllBoolModule_basic",
"ReduceAnyBoolModule_basic",
"ReduceSumDimIntListFloatModule_basic",
"ReduceSumDimIntListIntModule_basic",
"ReduceSumDimIntListKeepDimFloatModule_basic",
@ -2721,6 +2729,7 @@ ONNX_XFAIL_SET = {
"MaskedFillTensorFloatValueModule_basic",
"NativeDropoutTrainModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"ReduceAnyFloatModule_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
}


@ -124,6 +124,120 @@ def ReduceProdElementTypeBoolModule_basic(module, tu: TestUtils):
# ==============================================================================
class ReduceAllFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.all(a)


@register_test_case(module_factory=lambda: ReduceAllFloatModule())
def ReduceAllFloatModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))

# ==============================================================================

class ReduceAllIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.all(a)


@register_test_case(module_factory=lambda: ReduceAllIntModule())
def ReduceAllIntModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32))

# ==============================================================================

class ReduceAllBoolModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.bool, True),
    ])
    def forward(self, a):
        return torch.ops.aten.all(a)


@register_test_case(module_factory=lambda: ReduceAllBoolModule())
def ReduceAllBoolModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, high=2).to(torch.bool))

# ==============================================================================

class ReduceAnyFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.any(a)


@register_test_case(module_factory=lambda: ReduceAnyFloatModule())
def ReduceAnyFloatModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))

# ==============================================================================

class ReduceAnyIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.any(a)


@register_test_case(module_factory=lambda: ReduceAnyIntModule())
def ReduceAnyIntModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32))

# ==============================================================================

class ReduceAnyBoolModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.bool, True),
    ])
    def forward(self, a):
        return torch.ops.aten.any(a)


@register_test_case(module_factory=lambda: ReduceAnyBoolModule())
def ReduceAnyBoolModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, high=2).to(torch.bool))

# ==============================================================================

class ReduceSumDimIntListFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()