Add lowering of `aten.max.dim` op.

Lowering of the `aten.max.dim` op to Linalg has been added. The `aten.max` (all-dimensions) overload is also lowered, and `aten.argmax` is now decomposed into `aten.max.dim`.
pull/553/head
Prashant Kumar 2022-01-25 14:23:55 +05:30 committed by Prashant Kumar
parent 454fa9d123
commit e58b66bc3b
7 changed files with 281 additions and 54 deletions
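For context, `aten::max.dim` returns both the maximum values and their indices along one dimension, while `aten::max` reduces across all dimensions. A minimal PyTorch-level sketch of the semantics being lowered (illustrative only, not part of the commit):

import torch

a = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 2.0, 6.0]])

# aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)
values, indices = torch.ops.aten.max(a, 1)   # values: [5., 6.], indices: [1, 2]

# aten::max : (Tensor) -> (Tensor), reduction over all dimensions
overall = torch.ops.aten.max(a)              # tensor(6.)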

View File

@@ -122,3 +122,97 @@ class ReduceMeanDtypeModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ReduceMeanDtypeModule())
def ReduceMeanDtypeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float64))
# ==============================================================================
class ReduceMaxAlongDim(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, a):
return torch.ops.aten.max(a, 1)[0]
@register_test_case(module_factory=lambda: ReduceMaxAlongDim())
def ReduceMaxAlongDim_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float64))
# ==============================================================================
class ReduceMaxAlongDimNegative(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, a):
return torch.ops.aten.max(a, 1)[0]
@register_test_case(module_factory=lambda: ReduceMaxAlongDimNegative())
def ReduceMaxAlongDimNegative_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5, low=-10, high=10).to(torch.float64))
# ==============================================================================
class ReduceMaxKeepDim(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, a):
return torch.ops.aten.max(a, 1, keepdim=True)[1]
@register_test_case(module_factory=lambda: ReduceMaxKeepDim())
def ReduceMaxKeepDim_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float64))
# ==============================================================================
class ReduceMaxKeepDimReturnBoth(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.max(a, 1, keepdim=True)
@register_test_case(module_factory=lambda: ReduceMaxKeepDimReturnBoth())
def ReduceMaxKeepDimReturnBoth_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5, low=-10, high=-5))
# ==============================================================================
class ReduceMaxAllDims(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.max(a)
@register_test_case(module_factory=lambda: ReduceMaxAllDims())
def ReduceMaxAllDims_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5, low=-10, high=-5))

View File

@@ -2506,6 +2506,37 @@ def Torch_AtenSumDimIntListOp : Torch_Op<"aten.sum.dim_IntList", [
let assemblyFormat = "$self `,` $dim `,` $keepdim `,` $dtype attr-dict `:` qualified(type($self)) `,` qualified(type($dim)) `,` qualified(type($keepdim)) `,` qualified(type($dtype)) `->` qualified(type($result))";
}
def Torch_AtenMaxOp : Torch_Op<"aten.max", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::max : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
}
def Torch_AtenMaxDimOp : Torch_Op<"aten.max.dim", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$values,
AnyTorchTensorType:$indices
);
let assemblyFormat = "$self `,` $dim `,` $keepdim attr-dict `:` qualified(type($self)) `,` qualified(type($dim)) `,` qualified(type($keepdim)) `->` qualified(type($values)) `,` qualified(type($indices))";
}
def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [
AllowsTypeRefinement
]> {

View File

@@ -2126,6 +2126,13 @@ static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc,
if (isa<AtenSumOp, AtenSumDimIntListOp>(op) &&
elementType.isa<mlir::FloatType>())
return b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, 0.0));
if (isa<AtenMaxOp>(op) && elementType.isa<mlir::FloatType>())
return b.create<arith::ConstantOp>(
loc, b.getFloatAttr(
elementType,
APFloat::getLargest(
elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/true)));
op->emitError("unimplemented lowering in "
"createLinalgNeutralElementForReduceOp");
@@ -2141,6 +2148,11 @@ static Value createLinalgPayloadCalculationForReduceOp(
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
Value result = payloadArgs[1];
return b.create<arith::AddFOp>(loc, self, result);
} else if (isa<AtenMaxOp>(op) && resultElementType.isa<mlir::FloatType>()) {
Value self =
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
Value result = payloadArgs[1];
return b.create<arith::MaxFOp>(loc, self, result);
}
op->emitError("unimplemented lowering in "
"createLinalgPayloadCalculationForReduceOp");
@@ -2148,68 +2160,56 @@ static Value createLinalgPayloadCalculationForReduceOp(
}
namespace {
// Aten argmax lowering represents the ArgMax op as an linalg.indexed_generic
// Aten max.dim lowering represents the MaxDim op as a linalg.generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and is resulting integer type.
// The first output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by indexed_generic, this buffer is disgarded as only the index is
// requested.
// The second output buffer contains the index of the found maximum value. It is
// initialized to 0 and is of the result's integer type.
//
// The indexed_generic op updates both the maximum value and index if the
// current value exceeds the running max.
class ConvertAtenArgmaxOp : public OpConversionPattern<AtenArgmaxOp> {
class ConvertAtenMaxDimOp : public OpConversionPattern<AtenMaxDimOp> {
public:
using OpConversionPattern<AtenArgmaxOp>::OpConversionPattern;
using OpConversionPattern<AtenMaxDimOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenArgmaxOp argmaxOp, OpAdaptor adaptor,
matchAndRewrite(AtenMaxDimOp maxDimOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = argmaxOp.getLoc();
Location loc = maxDimOp.getLoc();
Value input = adaptor.self();
RankedTensorType resultType =
RankedTensorType valResultType =
getTypeConverter()
->convertType(argmaxOp.getResult().getType())
->convertType(maxDimOp.getResult(0).getType())
.cast<RankedTensorType>();
RankedTensorType idxResultType =
getTypeConverter()
->convertType(maxDimOp.getResult(1).getType())
.cast<RankedTensorType>();
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
Type outElementType = resultType.getElementType();
if (!outElementType.isa<IntegerType>())
Type idxElementType = idxResultType.getElementType();
if (!idxElementType.isa<IntegerType>())
return rewriter.notifyMatchFailure(
argmaxOp,
"aten.arg_max to linalg.* requires integer-like result type");
maxDimOp,
"aten.max_dim to linalg.* requires integer-like result type");
bool keepDim = false;
if (!matchPattern(argmaxOp.keepdim(), m_TorchConstantBool(&keepDim)))
if (!matchPattern(maxDimOp.keepdim(), m_TorchConstantBool(&keepDim)))
return failure();
int64_t dim;
if (!matchPattern(argmaxOp.dim(), m_TorchConstantInt(&dim))) {
if (!argmaxOp.dim().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
argmaxOp,
"aten.arg_max to linalg.* requires int or NoneType value for Dim");
// For pytorch, if the value of Dim is None, argmax
// returns the index of the max value of the flattened input tensor,
// so here we flatten the input tensor.
SmallVector<ReassociationIndices> reassociation(1);
for (auto i : llvm::seq<int64_t>(0, inputType.getRank()))
reassociation[0].push_back(i);
input = rewriter.create<tensor::CollapseShapeOp>(argmaxOp->getLoc(),
input, reassociation);
// Becomes 0 for flattened tensor.
dim = 0;
// Recast to fix shape.
inputType = input.getType().cast<RankedTensorType>();
}
if (!matchPattern(maxDimOp.dim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
maxDimOp, "aten.max_dim to linalg.* requires int value for Dim");
Type inElementType = inputType.getElementType();
if (!inElementType.isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(
argmaxOp,
"aten.arg_max to linalg.* requires Float input element type");
maxDimOp,
"aten.max_dim to linalg.* requires Float input element type");
}
// Constant op to account for the reduction along dim.
@@ -2224,7 +2224,7 @@ public:
}
// First fill the output buffer for the index.
Value filledTensorIdx =
createZeroInitTensor(rewriter, loc, resultShape, outElementType);
createZeroInitTensor(rewriter, loc, resultShape, idxElementType);
// Second fill the output buffer for the running max.
Value initTensorMax =
@@ -2265,14 +2265,14 @@ public:
auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs});
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc,
ArrayRef<Type>({filledTensorIdx.getType(), filledTensorMax.getType()}),
input, ValueRange({filledTensorIdx, filledTensorMax}), maps,
ArrayRef<Type>({filledTensorMax.getType(), filledTensorIdx.getType()}),
input, ValueRange({filledTensorMax, filledTensorIdx}), maps,
iteratorTypes,
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange blockArgs) {
Value newValue = blockArgs[0];
Value oldIndex = blockArgs[1];
Value oldValue = blockArgs[2];
Value oldValue = blockArgs[1];
Value oldIndex = blockArgs[2];
Value newIndex = rewriter.create<arith::IndexCastOp>(
nestedLoc, oldIndex.getType(),
@@ -2287,12 +2287,15 @@ public:
auto resultIndex = rewriter.create<mlir::SelectOp>(
nestedLoc, predicate, newIndex, oldIndex);
nestedBuilder.create<linalg::YieldOp>(
nestedLoc, ValueRange({resultIndex, resultMax}));
nestedLoc, ValueRange({resultMax, resultIndex}));
});
// This cast is required to fix the shape in the case of keepDim=True
rewriter.replaceOpWithNewOp<tensor::CastOp>(argmaxOp, resultType,
linalgOp.getResult(0));
Value maxValuesCast = rewriter.create<tensor::CastOp>(
loc, valResultType, linalgOp.getResult(0));
Value maxIdxCast = rewriter.create<tensor::CastOp>(loc, idxResultType,
linalgOp.getResult(1));
rewriter.replaceOp(maxDimOp, {maxValuesCast, maxIdxCast});
return success();
}
};
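As a reference for what the pattern above computes, here is a NumPy-level sketch of the two-output reduction built by `ConvertAtenMaxDimOp` (an illustration under the same assumptions as the pattern: float input, integer index output; the helper name is hypothetical and not part of the commit):

import numpy as np

def max_dim_reference(x, dim, keepdim):
    # One accumulator holds the running max, seeded with the most negative
    # finite value (mirroring APFloat::getLargest(..., /*Negative=*/true));
    # the other holds the running index, seeded with 0. Both are updated
    # whenever a strictly larger value is seen.
    out_shape = tuple(np.delete(np.array(x.shape), dim))
    values = np.full(out_shape, np.finfo(x.dtype).min, dtype=x.dtype)
    indices = np.zeros(out_shape, dtype=np.int64)
    for idx in np.ndindex(*x.shape):
        out_idx = idx[:dim] + idx[dim + 1:]
        if x[idx] > values[out_idx]:
            values[out_idx] = x[idx]
            indices[out_idx] = idx[dim]
    if keepdim:
        values = np.expand_dims(values, dim)
        indices = np.expand_dims(indices, dim)
    return values, indices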
@@ -2558,11 +2561,12 @@ struct ConvertReductionOp : ConversionPattern {
// `keepDim` in accordance with their specification.
DenseSet<int64_t> dimSet;
bool keepDim = false;
if (isa<AtenSumOp>(op)) {
if (isa<AtenSumOp>(op) || isa<AtenMaxOp>(op)) {
auto tensorOperand = operands[0];
auto inputType = tensorOperand.getType().cast<RankedTensorType>();
// `AtenSumOp` reduces along all the dimensiosn of the input tensor.
// `AtenSumOp` and `AtenMaxOp` reduce along all the dimensions of the
// input tensor.
for (int64_t i = 0; i < inputType.getRank(); i++)
dimSet.insert(i);
} else if (auto sumDimIntListOp = dyn_cast<AtenSumDimIntListOp>(op)) {
@@ -2587,7 +2591,6 @@ struct ConvertReductionOp : ConversionPattern {
} else {
return rewriter.notifyMatchFailure(op, "not a supported reduce op");
}
return createReductionLinalgGeneric(op, operands, dimSet, keepDim,
rewriter);
}
@@ -4378,6 +4381,8 @@ public:
target.addIllegalOp<AtenConstantPadNdOp>();
patterns.add<ConvertAtenConstantPadNdOp>(typeConverter, context);
target.addIllegalOp<AtenSumOp>();
target.addIllegalOp<AtenSumDimIntListOp>();
target.addIllegalOp<AtenMaxOp>();
patterns.add<ConvertReductionOp>(typeConverter, context);
target.addIllegalOp<AtenTransposeIntOp>();
patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);
@@ -4391,8 +4396,8 @@ public:
patterns.add<ConvertAtenNativeLayerNormOp>(typeConverter, context);
target.addIllegalOp<AtenBroadcastToOp>();
patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
target.addIllegalOp<AtenArgmaxOp>();
patterns.add<ConvertAtenArgmaxOp>(typeConverter, context);
target.addIllegalOp<AtenMaxDimOp>();
patterns.add<ConvertAtenMaxDimOp>(typeConverter, context);
target.addIllegalOp<AtenSizeIntOp>();
patterns.add<ConvertAtenSizeIntOp>(typeConverter, context);
target.addIllegalOp<AtenEmbeddingOp>();

View File

@@ -325,6 +325,56 @@ public:
};
} // namespace
// Decompose `AtenArgMaxOp` into `AtenMaxDimOp`.
namespace {
class DecomposeAtenArgMaxOp : public OpRewritePattern<AtenArgmaxOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenArgmaxOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.self();
Value dim = op.dim();
Value keepDim = op.keepdim();
Value result = op.result();
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
BaseTensorType indicesTensorType = result.getType().cast<BaseTensorType>();
if (!indicesTensorType.hasSizes())
return failure();
BaseTensorType valueTensorType =
inputType
.getWithSizesAndDtype(indicesTensorType.getSizes(),
inputType.getDtype())
.cast<BaseTensorType>();
// If the dim type is `NoneType`, the reduction is along all the dimensions.
// `AtenMaxDimOp` doesn't support a `NoneType` dim, so the input tensor is
// first flattened to a 1-d tensor and the reduction then happens along the
// 0th dimension.
if (dim.getType().isa<Torch::NoneType>()) {
BaseTensorType flattenType =
inputType.getWithSizesAndDtype({kUnknownSize}, inputType.getDtype())
.cast<BaseTensorType>();
dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value end = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(getTensorRank(input) - 1));
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
dim, end);
}
Value maxResult =
rewriter
.create<AtenMaxDimOp>(loc, valueTensorType, indicesTensorType,
input, dim, keepDim)
.indices();
rewriter.replaceOp(op, maxResult);
return success();
}
};
} // namespace
// Decompose aten.log_softmax op into: log(softmax(x))
namespace {
class DecomposeAtenLogSoftmaxIntOp
@@ -678,6 +728,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenArangeOp>();
patterns.add<DecomposeAtenArangeStartOp>(context);
target.addIllegalOp<AtenArangeStartOp>();
patterns.add<DecomposeAtenArgMaxOp>(context);
target.addIllegalOp<AtenArgmaxOp>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {

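The decomposition above is, at the PyTorch level, equivalent to rewriting `argmax` as the index result of `max.dim`, flattening first when `dim` is `None`. A hedged Python rendering (the function name is illustrative, not part of the commit):

import torch

def decomposed_argmax(x, dim=None, keepdim=False):
    # When dim is None, flatten to 1-d and reduce along dimension 0,
    # matching the AtenFlattenUsingIntsOp inserted by the pattern.
    if dim is None:
        x = torch.flatten(x, 0, x.dim() - 1)
        dim = 0
    # argmax is the indices result of aten::max.dim.
    return torch.ops.aten.max(x, dim, keepdim)[1]

x = torch.rand(3, 4)
assert torch.equal(decomposed_argmax(x, 1, True), torch.argmax(x, dim=1, keepdim=True))
assert torch.equal(decomposed_argmax(x), torch.argmax(x))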
View File

@@ -354,6 +354,15 @@ public:
Type dtype = operands[0]->getValue().dtype;
return visitReductionAlongDimIntOp(anyDim, anyDim.dim(), anyDim.keepdim(),
dtype, operands);
} else if (auto maxDim = dyn_cast<AtenMaxDimOp>(op)) {
Type firstResDtype = operands[0]->getValue().dtype;
Type secondResDtype =
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
ChangeResult firstRes = visitReductionAlongDimIntOp(
maxDim, maxDim.dim(), maxDim.keepdim(), firstResDtype, operands);
return firstRes | visitReductionAlongDimIntOp(
maxDim, maxDim.dim(), maxDim.keepdim(),
secondResDtype, operands, /*resNum=*/1);
} else if (auto view = dyn_cast<AtenViewOp>(op)) {
return visitReshapeLikeOp(view, operands);
} else if (auto resize = dyn_cast<AtenResize_Op>(op)) {
@@ -457,6 +466,9 @@ public:
Type dtype =
getDtypeOrDefault(mean.getContext(), mean.dtype(), defaultDtype);
return visitReductionAlongAllDimsOp(mean, dtype, operands);
} else if (auto max = dyn_cast<AtenMaxOp>(op)) {
Type dtype = operands[0]->getValue().dtype;
return visitReductionAlongAllDimsOp(max, dtype, operands);
} else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
return visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
} else if (auto _softmaxOp = dyn_cast<Aten_SoftmaxOp>(op)) {
@@ -547,7 +559,7 @@ private:
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult visitReductionAlongDimIntOp(
Operation *op, Value dim, Value keepdim, Type dtype,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ArrayRef<LatticeElement<ValueKnowledge> *> operands, int resNum = 0);
template <typename OpTy>
ChangeResult
visitReshapeLikeOp(OpTy op,
@@ -1307,7 +1319,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
ChangeResult TypeAnalyzer::visitReductionAlongDimIntOp(
Operation *op, Value dim, Value keepdim, Type dtype,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
ArrayRef<LatticeElement<ValueKnowledge> *> operands, int resNum) {
assert(dim.getType().isa<Torch::IntType>() && "dim must be int type");
auto input = operands[0]->getValue();
auto knowledge =
@@ -1331,7 +1343,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntOp(
knowledge.sizes.resize(keepDim ? inputRank : inputRank - 1, kUnknownSize);
}
}
return getLatticeElement(op->getResult(0)).join(knowledge);
return getLatticeElement(op->getResult(resNum)).join(knowledge);
}
// Reshape like ops are given a size list which specify the shape of the

View File

@@ -587,6 +587,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::stack : (Tensor[], int) -> (Tensor)")
emit("aten::sum : (Tensor, int?) -> (Tensor)")
emit("aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)")
emit("aten::max : (Tensor) -> (Tensor)")
emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)
emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)")
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")

View File

@@ -154,3 +154,34 @@ func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
%0 = torch.aten.arange.start %int0, %int10, %none, %none, %none, %none : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
return %0 : !torch.vtensor<[?],si64>
}
// -----
// CHECK-LABEL: func @torch.aten.argmax(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> {
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[INP]], %[[CST0]], %[[TRUE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],f32>, !torch.vtensor<[1,?],si64>
// CHECK: return %[[IND]] : !torch.vtensor<[1,?],si64>
func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> {
%int0 = torch.constant.int 0
%true = torch.constant.bool true
%0 = torch.aten.argmax %arg0, %int0, %true : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],si64>
return %0 : !torch.vtensor<[1,?],si64>
}
// -----
// CHECK-LABEL: func @torch.aten.argmax$reduceall(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[FLATTEN:.*]] = torch.aten.flatten.using_ints %[[INP]], %[[CST0]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[FLATTEN]], %[[CST0]], %[[FALSE]] : !torch.vtensor<[?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[],f32>, !torch.vtensor<[],si64>
// CHECK: return %[[IND]] : !torch.vtensor<[],si64>
func @torch.aten.argmax$reduceall(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> {
%none = torch.constant.none
%false = torch.constant.bool false
%0 = torch.aten.argmax %arg0, %none, %false : !torch.vtensor<[?,?],f32>, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
return %0 : !torch.vtensor<[],si64>
}