mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] add E2E support for aten.min op (#2422)
* impl aten.min op * remove extraneous testpull/2430/head
parent
5282324c68
commit
c42d2beb6e
|
@ -792,6 +792,10 @@ STABLEHLO_PASS_SET = {
|
||||||
"ReduceMaxFloatModule_basic",
|
"ReduceMaxFloatModule_basic",
|
||||||
"ReduceMaxSignedIntModule_basic",
|
"ReduceMaxSignedIntModule_basic",
|
||||||
"ReduceMaxUnsignedIntModule_basic",
|
"ReduceMaxUnsignedIntModule_basic",
|
||||||
|
"ReduceMinAllDims_basic",
|
||||||
|
"ReduceMinFloatModule_basic",
|
||||||
|
"ReduceMinSignedIntModule_basic",
|
||||||
|
"ReduceMinUnsignedIntModule_basic",
|
||||||
"ReduceSumDimIntListFloatModule_basic",
|
"ReduceSumDimIntListFloatModule_basic",
|
||||||
"ReduceSumDimIntListIntModule_basic",
|
"ReduceSumDimIntListIntModule_basic",
|
||||||
"ReduceSumFloatModule_basic",
|
"ReduceSumFloatModule_basic",
|
||||||
|
|
|
@ -224,6 +224,22 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
|
||||||
elementType.getIntOrFloatBitWidth())));
|
elementType.getIntOrFloatBitWidth())));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isa<AtenMinOp>(op)) {
|
||||||
|
if (elementType.isa<mlir::FloatType>())
|
||||||
|
return b.create<arith::ConstantOp>(
|
||||||
|
loc, b.getFloatAttr(
|
||||||
|
elementType,
|
||||||
|
APFloat::getInf(
|
||||||
|
elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||||
|
/*Negative=*/false)));
|
||||||
|
else if (elementType.isa<mlir::IntegerType>() &&
|
||||||
|
elementType.getIntOrFloatBitWidth() != 8)
|
||||||
|
return b.create<arith::ConstantOp>(
|
||||||
|
loc, b.getIntegerAttr(elementType,
|
||||||
|
APSInt::getSignedMaxValue(
|
||||||
|
elementType.getIntOrFloatBitWidth())));
|
||||||
|
}
|
||||||
|
|
||||||
if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op))
|
if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op))
|
||||||
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
|
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
|
||||||
|
|
||||||
|
@ -261,6 +277,23 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
|
||||||
if (intType.isSigned())
|
if (intType.isSigned())
|
||||||
return b.create<arith::MaxSIOp>(loc, self, result);
|
return b.create<arith::MaxSIOp>(loc, self, result);
|
||||||
}
|
}
|
||||||
|
} else if (auto min = dyn_cast<AtenMinOp>(op)) {
|
||||||
|
Value self =
|
||||||
|
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
||||||
|
Value result = payloadArgs[1];
|
||||||
|
if (resultElementType.isa<mlir::FloatType>())
|
||||||
|
return b.create<arith::MinFOp>(loc, self, result);
|
||||||
|
else if (resultElementType.isa<mlir::IntegerType>()) {
|
||||||
|
IntegerType intType = min.getSelf()
|
||||||
|
.getType()
|
||||||
|
.cast<BaseTensorType>()
|
||||||
|
.getDtype()
|
||||||
|
.dyn_cast<mlir::IntegerType>();
|
||||||
|
if (intType.isUnsigned())
|
||||||
|
return b.create<arith::MinUIOp>(loc, self, result);
|
||||||
|
if (intType.isSigned())
|
||||||
|
return b.create<arith::MinSIOp>(loc, self, result);
|
||||||
|
}
|
||||||
} else if (isa<AtenLinalgVectorNormOp>(op)) {
|
} else if (isa<AtenLinalgVectorNormOp>(op)) {
|
||||||
// This creates payload for only the first of the two linalg.generic ops.
|
// This creates payload for only the first of the two linalg.generic ops.
|
||||||
// TODO: Short-circuit operations if `ord` is zero or one.
|
// TODO: Short-circuit operations if `ord` is zero or one.
|
||||||
|
@ -340,11 +373,11 @@ private:
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
|
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
|
||||||
|
|
||||||
if (isa<AtenMaxOp, AtenSumOp>(op)) {
|
if (isa<AtenMaxOp, AtenMinOp, AtenSumOp>(op)) {
|
||||||
opInfo.tensorOperand = operands[0];
|
opInfo.tensorOperand = operands[0];
|
||||||
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
|
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
// `AtenSumOp` and `AtenMaxOp` reduces along all the dimensions of the
|
// `AtenSumOp`, `AtenMaxOp`, and `AtenMinOp` each reduce along all the dimensions of the
|
||||||
// input tensor.
|
// input tensor.
|
||||||
for (int64_t i = 0; i < inputType.getRank(); i++)
|
for (int64_t i = 0; i < inputType.getRank(); i++)
|
||||||
opInfo.dimSet.insert(i);
|
opInfo.dimSet.insert(i);
|
||||||
|
@ -520,6 +553,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
|
||||||
target.addIllegalOp<AtenSumOp>();
|
target.addIllegalOp<AtenSumOp>();
|
||||||
target.addIllegalOp<AtenSumDimIntListOp>();
|
target.addIllegalOp<AtenSumDimIntListOp>();
|
||||||
target.addIllegalOp<AtenMaxOp>();
|
target.addIllegalOp<AtenMaxOp>();
|
||||||
|
target.addIllegalOp<AtenMinOp>();
|
||||||
target.addIllegalOp<AtenLinalgVectorNormOp>();
|
target.addIllegalOp<AtenLinalgVectorNormOp>();
|
||||||
target.addIllegalOp<AtenFrobeniusNormDimOp>();
|
target.addIllegalOp<AtenFrobeniusNormDimOp>();
|
||||||
patterns.add<ConvertReductionOp>(typeConverter, context);
|
patterns.add<ConvertReductionOp>(typeConverter, context);
|
||||||
|
|
|
@ -68,6 +68,24 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isa<AtenMinOp>(op)) {
|
||||||
|
if (elementTy.isa<mlir::FloatType>()) {
|
||||||
|
auto constAttr = DenseElementsAttr::get(
|
||||||
|
constType, {APFloat::getInf(
|
||||||
|
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||||
|
/*negative=*/false)});
|
||||||
|
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||||
|
constAttr);
|
||||||
|
} else if (elementTy.isa<mlir::IntegerType>() &&
|
||||||
|
elementTy.getIntOrFloatBitWidth() != 8) {
|
||||||
|
auto constAttr = DenseElementsAttr::get(
|
||||||
|
constType,
|
||||||
|
{APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth())});
|
||||||
|
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||||
|
constAttr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
op->emitError("unimplemented lowering in "
|
op->emitError("unimplemented lowering in "
|
||||||
"createInitialValueForReduceOp");
|
"createInitialValueForReduceOp");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -481,6 +499,68 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// AtenMinOp
|
||||||
|
namespace {
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenReductionOp<AtenMinOp>::matchAndRewrite(
|
||||||
|
AtenMinOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
Value input = adaptor.getSelf();
|
||||||
|
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!inputTy) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "only Tensor types supported in StableHLO");
|
||||||
|
}
|
||||||
|
auto inputElemTy = inputTy.getElementType();
|
||||||
|
if (!inputElemTy.isIntOrFloat()) {
|
||||||
|
return op.emitError(
|
||||||
|
"only floating-point or integer datatype legalization supported");
|
||||||
|
}
|
||||||
|
// Currently, (u)int8 dtype is not supported
|
||||||
|
if (inputElemTy.isa<mlir::IntegerType>() &&
|
||||||
|
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||||
|
"AtenMinOp to StableHLO");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> dims;
|
||||||
|
for (int64_t i = 0; i < inputTy.getRank(); i++) {
|
||||||
|
dims.push_back(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value initValue =
|
||||||
|
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
|
||||||
|
if (!initValue)
|
||||||
|
return failure();
|
||||||
|
llvm::sort(dims.begin(), dims.end());
|
||||||
|
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||||
|
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
|
||||||
|
|
||||||
|
Block &block = stablehloReduceOp.getBody().emplaceBlock();
|
||||||
|
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
|
||||||
|
|
||||||
|
block.addArgument(blockArgumentTy, op->getLoc());
|
||||||
|
block.addArgument(blockArgumentTy, op->getLoc());
|
||||||
|
|
||||||
|
auto *firstArgument = block.args_begin();
|
||||||
|
auto secondArgument = block.args_rbegin();
|
||||||
|
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(&block);
|
||||||
|
Value minResult = rewriter.create<stablehlo::MinOp>(
|
||||||
|
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
||||||
|
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), minResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||||
|
op, getTypeConverter()->convertType(op.getType()),
|
||||||
|
stablehloReduceOp.getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// AtenSumDimIntListOp
|
// AtenSumDimIntListOp
|
||||||
namespace {
|
namespace {
|
||||||
template <>
|
template <>
|
||||||
|
@ -838,6 +918,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
|
||||||
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
|
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
|
||||||
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
|
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
|
||||||
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp);
|
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp);
|
||||||
|
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMinOp);
|
||||||
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
|
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
|
||||||
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp);
|
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp);
|
||||||
#undef INSERT_ATEN_REDUCTION_OP_PATTERN
|
#undef INSERT_ATEN_REDUCTION_OP_PATTERN
|
||||||
|
|
|
@ -6562,6 +6562,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.min\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
|
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||||
|
" return %0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.max\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.max\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
@ -10160,6 +10164,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" }\n"
|
" }\n"
|
||||||
" return %2 : !torch.int\n"
|
" return %2 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.min\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||||
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" return %0#1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.max\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.max\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
" return %0#1 : !torch.int\n"
|
" return %0#1 : !torch.int\n"
|
||||||
|
|
|
@ -300,6 +300,9 @@ def aten〇any〡shape(self: List[int]) -> List[int]:
|
||||||
def aten〇all〡shape(self: List[int]) -> List[int]:
|
def aten〇all〡shape(self: List[int]) -> List[int]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def aten〇min〡shape(self: List[int]) -> List[int]:
|
||||||
|
return []
|
||||||
|
|
||||||
def aten〇max〡shape(self: List[int]) -> List[int]:
|
def aten〇max〡shape(self: List[int]) -> List[int]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@ -2981,6 +2984,11 @@ def aten〇any〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim
|
||||||
return self_dtype
|
return self_dtype
|
||||||
return torch.bool
|
return torch.bool
|
||||||
|
|
||||||
|
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||||
|
def aten〇min〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
return self_dtype
|
||||||
|
|
||||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||||
def aten〇max〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
def aten〇max〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||||
self_rank, self_dtype = self_rank_dtype
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
|
|
@ -572,6 +572,58 @@ def ReduceAmaxKeepDim_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ReduceMinFloatModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.min(a)
|
||||||
|
@register_test_case(module_factory=lambda: ReduceMinFloatModule())
|
||||||
|
def ReduceMinFloatModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, 5))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ReduceMinSignedIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.min(a)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ReduceMinSignedIntModule())
|
||||||
|
def ReduceMinSignedIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 4, 5, low=-100, high=100))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ReduceMinUnsignedIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.min(a)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ReduceMinUnsignedIntModule())
|
||||||
|
def ReduceMinUnsignedIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 4, 5, high=100))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
class ReduceL1NormModule(torch.nn.Module):
|
class ReduceL1NormModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue