mirror of https://github.com/llvm/torch-mlir
[stablehlo] Support aten.any and aten.all lowering (#3217)
parent
7be22bb260
commit
7030eacb76
|
@ -341,10 +341,14 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
|
|||
isa<AtenNormScalarOp>(op))
|
||||
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
|
||||
|
||||
if (isa<AtenAllDimOp>(op)) {
|
||||
if (isa<AtenAllOp, AtenAllDimOp>(op)) {
|
||||
return b.create<arith::ConstantOp>(loc, b.getBoolAttr(true));
|
||||
}
|
||||
|
||||
if (isa<AtenAnyOp>(op)) {
|
||||
return b.create<arith::ConstantOp>(loc, b.getBoolAttr(false));
|
||||
}
|
||||
|
||||
op->emitError("unimplemented lowering in createInitElementForReduceOp");
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -439,11 +443,16 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
|
|||
auto abs = createAbsOpForNormOps(b, loc, elem, resultElementType);
|
||||
auto pow = b.create<math::PowFOp>(loc, abs, ord);
|
||||
return b.create<arith::AddFOp>(loc, pow, result);
|
||||
} else if (isa<AtenAllDimOp>(op)) {
|
||||
} else if (isa<AtenAllOp, AtenAllDimOp>(op)) {
|
||||
Value elem = payloadArgs[0];
|
||||
Value result = payloadArgs[1];
|
||||
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
|
||||
return b.create<arith::MulIOp>(loc, self, result);
|
||||
return b.create<arith::AndIOp>(loc, self, result);
|
||||
} else if (isa<AtenAnyOp>(op)) {
|
||||
Value elem = payloadArgs[0];
|
||||
Value result = payloadArgs[1];
|
||||
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
|
||||
return b.create<arith::OrIOp>(loc, self, result);
|
||||
}
|
||||
op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp");
|
||||
return nullptr;
|
||||
|
@ -510,13 +519,13 @@ private:
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
|
||||
|
||||
if (isa<AtenMaxOp, AtenMinOp, AtenSumOp, AtenProdOp, AtenNormScalarOp>(
|
||||
op)) {
|
||||
if (isa<AtenAnyOp, AtenAllOp, AtenMaxOp, AtenMinOp, AtenSumOp, AtenProdOp,
|
||||
AtenNormScalarOp>(op)) {
|
||||
opInfo.tensorOperand = operands[0];
|
||||
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
|
||||
|
||||
// `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and `AtenMinOp` each reduce
|
||||
// along all the dimensions of the input tensor.
|
||||
// `AtenAny`, `AtenAll`, `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and
|
||||
// `AtenMinOp` each reduce along all the dimensions of the input tensor.
|
||||
for (int64_t i = 0; i < inputType.getRank(); i++)
|
||||
opInfo.dimSet.insert(i);
|
||||
|
||||
|
@ -715,6 +724,8 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
|
|||
target.addIllegalOp<AtenMinDimOp>();
|
||||
patterns.add<ConvertAtenMinMaxDimOp<AtenMinDimOp>>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSumOp>();
|
||||
target.addIllegalOp<AtenAnyOp>();
|
||||
target.addIllegalOp<AtenAllOp>();
|
||||
target.addIllegalOp<AtenSumDimIntListOp>();
|
||||
target.addIllegalOp<AtenProdOp>();
|
||||
target.addIllegalOp<AtenProdDimIntOp>();
|
||||
|
|
|
@ -104,6 +104,18 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
|
|||
}
|
||||
}
|
||||
|
||||
if (isa<AtenAllOp>(op)) {
|
||||
auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 1)});
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
}
|
||||
|
||||
if (isa<AtenAnyOp>(op)) {
|
||||
auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 0)});
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
}
|
||||
|
||||
op->emitError("unimplemented lowering in "
|
||||
"createInitialValueForReduceOp");
|
||||
return nullptr;
|
||||
|
@ -463,6 +475,150 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
|||
}
|
||||
} // namespace
|
||||
|
||||
// AtenAllOp
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenReductionOp<AtenAllOp>::matchAndRewrite(
|
||||
AtenAllOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only Tensor types supported in StableHLO");
|
||||
}
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
|
||||
// Currently, (u)int8 dtype is not supported
|
||||
if (isa<mlir::IntegerType>(inputElemTy) &&
|
||||
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||
"AtenAllOp to StableHLO");
|
||||
}
|
||||
auto outTy = getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template dyn_cast<RankedTensorType>();
|
||||
|
||||
if (inputElemTy != outTy.getElementType()) {
|
||||
// Use output bool type as computation type.
|
||||
auto dstElemTy = outTy.getElementType();
|
||||
input =
|
||||
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||
inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
inputElemTy = inputTy.getElementType();
|
||||
}
|
||||
|
||||
SmallVector<int64_t> dims;
|
||||
for (int64_t i = 0; i < inputTy.getRank(); i++) {
|
||||
dims.push_back(i);
|
||||
}
|
||||
|
||||
Value initValue =
|
||||
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
|
||||
if (!initValue)
|
||||
return failure();
|
||||
llvm::sort(dims.begin(), dims.end());
|
||||
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
|
||||
rewriter.getDenseI64ArrayAttr(dims));
|
||||
|
||||
Block &block = stablehloReduceOp.getBody().emplaceBlock();
|
||||
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
|
||||
|
||||
block.addArgument(blockArgumentTy, op->getLoc());
|
||||
block.addArgument(blockArgumentTy, op->getLoc());
|
||||
|
||||
auto *firstArgument = block.args_begin();
|
||||
auto secondArgument = block.args_rbegin();
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
Value allResult = rewriter.create<stablehlo::AndOp>(
|
||||
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), allResult);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()),
|
||||
stablehloReduceOp.getResults());
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// AtenAnyOp
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenReductionOp<AtenAnyOp>::matchAndRewrite(
|
||||
AtenAnyOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only Tensor types supported in StableHLO");
|
||||
}
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
|
||||
// Currently, (u)int8 dtype is not supported
|
||||
if (isa<mlir::IntegerType>(inputElemTy) &&
|
||||
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||
"AtenAllOp to StableHLO");
|
||||
}
|
||||
auto outTy = getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template dyn_cast<RankedTensorType>();
|
||||
|
||||
if (inputElemTy != outTy.getElementType()) {
|
||||
// Use output bool type as computation type.
|
||||
auto dstElemTy = outTy.getElementType();
|
||||
input =
|
||||
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||
inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
inputElemTy = inputTy.getElementType();
|
||||
}
|
||||
|
||||
SmallVector<int64_t> dims;
|
||||
for (int64_t i = 0; i < inputTy.getRank(); i++) {
|
||||
dims.push_back(i);
|
||||
}
|
||||
|
||||
Value initValue =
|
||||
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
|
||||
if (!initValue)
|
||||
return failure();
|
||||
llvm::sort(dims.begin(), dims.end());
|
||||
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
|
||||
rewriter.getDenseI64ArrayAttr(dims));
|
||||
|
||||
Block &block = stablehloReduceOp.getBody().emplaceBlock();
|
||||
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
|
||||
|
||||
block.addArgument(blockArgumentTy, op->getLoc());
|
||||
block.addArgument(blockArgumentTy, op->getLoc());
|
||||
|
||||
auto *firstArgument = block.args_begin();
|
||||
auto secondArgument = block.args_rbegin();
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
Value anyResult = rewriter.create<stablehlo::OrOp>(
|
||||
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), anyResult);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()),
|
||||
stablehloReduceOp.getResults());
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// AtenProdOp
|
||||
namespace {
|
||||
template <>
|
||||
|
@ -1052,6 +1208,8 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
|
|||
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
|
||||
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
|
||||
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdOp);
|
||||
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAllOp);
|
||||
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAnyOp);
|
||||
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp);
|
||||
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMinOp);
|
||||
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
|
||||
|
|
|
@ -1227,6 +1227,12 @@ STABLEHLO_PASS_SET = {
|
|||
"RandIntLowModule_basic",
|
||||
"RandIntModule_basic",
|
||||
"RandIntPinMemoryModule_basic",
|
||||
"ReduceAllFloatModule_basic",
|
||||
"ReduceAllIntModule_basic",
|
||||
"ReduceAllBoolModule_basic",
|
||||
"ReduceAnyFloatModule_basic",
|
||||
"ReduceAnyIntModule_basic",
|
||||
"ReduceAnyBoolModule_basic",
|
||||
"ReduceAmaxMultiDim_basic",
|
||||
"ReduceAmaxOutOfOrderDim_basic",
|
||||
"ReduceAmaxSingleDim_basic",
|
||||
|
@ -1813,6 +1819,8 @@ TOSA_PASS_SET = {
|
|||
"PrimsSqueezeModule_basic",
|
||||
"PrimsViewOfModule_basic",
|
||||
"PrimsViewOfZeroRankModule_basic",
|
||||
"ReduceAllBoolModule_basic",
|
||||
"ReduceAnyBoolModule_basic",
|
||||
"ReduceSumDimIntListFloatModule_basic",
|
||||
"ReduceSumDimIntListIntModule_basic",
|
||||
"ReduceSumDimIntListKeepDimFloatModule_basic",
|
||||
|
@ -2721,6 +2729,7 @@ ONNX_XFAIL_SET = {
|
|||
"MaskedFillTensorFloatValueModule_basic",
|
||||
"NativeDropoutTrainModule_basic",
|
||||
"NativeDropoutTrainStaticShapeModule_basic",
|
||||
"ReduceAnyFloatModule_basic",
|
||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||
"ReduceMinAlongDimUnsignedInt_basic",
|
||||
}
|
||||
|
|
|
@ -124,6 +124,120 @@ def ReduceProdElementTypeBoolModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceAllFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.all(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceAllFloatModule())
|
||||
def ReduceAllFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceAllIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.all(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceAllIntModule())
|
||||
def ReduceAllIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceAllBoolModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.bool, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.all(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceAllBoolModule())
|
||||
def ReduceAllBoolModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, high=2).to(torch.bool))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceAnyFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.any(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceAnyFloatModule())
|
||||
def ReduceAnyFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceAnyIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.any(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceAnyIntModule())
|
||||
def ReduceAnyIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceAnyBoolModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.bool, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.any(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceAnyBoolModule())
|
||||
def ReduceAnyBoolModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, high=2).to(torch.bool))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceSumDimIntListFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue