mirror of https://github.com/llvm/torch-mlir
add argmax lowering
Add argmax lowering from torch to linalgpull/364/head snapshot-20211013.20
parent
19e9fc4ef1
commit
7750d2173a
|
@ -0,0 +1,66 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||||
|
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||||
|
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ArgmaxModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.argmax(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ArgmaxModule())
|
||||||
|
def ArgmaxModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ArgmaxWithDimModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.argmax(a, dim=1)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ArgmaxWithDimModule())
|
||||||
|
def ArgmaxModule_with_dim(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, 5))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ArgmaxKeepDimsModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.argmax(a, 0, True)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ArgmaxKeepDimsModule())
|
||||||
|
def ArgmaxModule_keepDim(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4, 6))
|
||||||
|
|
|
@ -34,6 +34,7 @@ from . import batchnorm
|
||||||
from . import quantized_models
|
from . import quantized_models
|
||||||
from . import elementwise
|
from . import elementwise
|
||||||
from . import reduction
|
from . import reduction
|
||||||
|
from . import argmax
|
||||||
|
|
||||||
def _get_argparse():
|
def _get_argparse():
|
||||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
|
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
|
||||||
|
|
|
@ -1349,6 +1349,22 @@ def Torch_AtenArangeStartOp : Torch_Op<"aten.arange.start", [
|
||||||
let assemblyFormat = "$start `,` $end `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($start) `,` type($end) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)";
|
let assemblyFormat = "$start `,` $end `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($start) `,` type($end) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::argmax : (Tensor, int?, bool) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
TorchOptionalIntType:$dim,
|
||||||
|
Torch_BoolType:$keepdim
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = "$self `,` $dim `,` $keepdim attr-dict `:` type($self) `,` type($dim) `,` type($keepdim) `->` type($result)";
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [
|
def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [
|
||||||
AllowsTypeRefinement
|
AllowsTypeRefinement
|
||||||
]> {
|
]> {
|
||||||
|
|
|
@ -800,7 +800,7 @@ public:
|
||||||
// of *internal* compiler invariants, and for a user manifests as a compiler
|
// of *internal* compiler invariants, and for a user manifests as a compiler
|
||||||
// crash in the worst case (such as we try to canonicalize/fold/print the
|
// crash in the worst case (such as we try to canonicalize/fold/print the
|
||||||
// invalid op before the verifier gets to see it -- also release builds of a
|
// invalid op before the verifier gets to see it -- also release builds of a
|
||||||
// mature copmiler usually have the verifier turned off for compile time
|
// mature compiler usually have the verifier turned off for compile time
|
||||||
// reasons).
|
// reasons).
|
||||||
//
|
//
|
||||||
// The compiler cannot crash even if the user wrote an erroneous program!
|
// The compiler cannot crash even if the user wrote an erroneous program!
|
||||||
|
@ -1141,12 +1141,161 @@ static Value createLinalgPayloadCalculationForReduceOp(
|
||||||
if (isa<AtenSumOp, AtenSumDimIntListOp>(op) &&
|
if (isa<AtenSumOp, AtenSumDimIntListOp>(op) &&
|
||||||
elementType.isa<mlir::FloatType>())
|
elementType.isa<mlir::FloatType>())
|
||||||
return b.create<AddFOp>(loc, payloadArgs);
|
return b.create<AddFOp>(loc, payloadArgs);
|
||||||
|
|
||||||
op->emitError("unimplemented lowering in "
|
op->emitError("unimplemented lowering in "
|
||||||
"createLinalgPayloadCalculationForReduceOp");
|
"createLinalgPayloadCalculationForReduceOp");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Aten argmax lowering represents the ArgMax op as an linalg.indexed_generic
|
||||||
|
// op, producing two output buffers.
|
||||||
|
//
|
||||||
|
// The first output buffer contains the index of the found maximum value. It is
|
||||||
|
// initialized to 0 and is resulting integer type.
|
||||||
|
//
|
||||||
|
// The second output buffer contains the maximum value found. It is initialized
|
||||||
|
// to the minimum representable value of the input element type. After being
|
||||||
|
// populated by indexed_generic, this buffer is disgarded as only the index is
|
||||||
|
// requested.
|
||||||
|
//
|
||||||
|
// The indexed_generic op updates both the maximum value and index if the
|
||||||
|
// current value exceeds the running max.
|
||||||
|
class ConvertAtenArgmaxOp : public OpConversionPattern<AtenArgmaxOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<AtenArgmaxOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenArgmaxOp argmaxOp, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
|
Location loc = argmaxOp.getLoc();
|
||||||
|
AtenArgmaxOp::Adaptor adaptor(operands);
|
||||||
|
Value input = adaptor.self();
|
||||||
|
RankedTensorType resultType =
|
||||||
|
getTypeConverter()
|
||||||
|
->convertType(argmaxOp.getResult().getType())
|
||||||
|
.cast<RankedTensorType>();
|
||||||
|
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||||
|
Type outElementType = resultType.getElementType();
|
||||||
|
if (!outElementType.isa<IntegerType>())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
argmaxOp,
|
||||||
|
"aten.arg_max to linalg.* requires integer-like result type");
|
||||||
|
|
||||||
|
bool keepDim = false;
|
||||||
|
if (!matchPattern(argmaxOp.keepdim(), m_TorchConstantBool(&keepDim)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t dim;
|
||||||
|
if (!matchPattern(argmaxOp.dim(), m_TorchConstantInt(&dim))) {
|
||||||
|
if (!argmaxOp.dim().getType().isa<Torch::NoneType>())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
argmaxOp,
|
||||||
|
"aten.arg_max to linalg.* requires int or NoneType value for Dim");
|
||||||
|
// For pytorch, if the value of Dim is None, argmax
|
||||||
|
// returns the index of the max value of the flattened input tensor,
|
||||||
|
// so here we flatten the input tensor.
|
||||||
|
SmallVector<ReassociationIndices> reassociation(1);
|
||||||
|
for (auto i : llvm::seq<int64_t>(0, inputType.getRank()))
|
||||||
|
reassociation[0].push_back(i);
|
||||||
|
input = rewriter.create<linalg::TensorCollapseShapeOp>(
|
||||||
|
argmaxOp->getLoc(), input, reassociation);
|
||||||
|
// Becomes 0 for flattened tensor.
|
||||||
|
dim = 0;
|
||||||
|
// Recast to fix shape.
|
||||||
|
inputType = input.getType().cast<RankedTensorType>();
|
||||||
|
}
|
||||||
|
Type inElementType = inputType.getElementType();
|
||||||
|
if (!inElementType.isa<mlir::FloatType>()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
argmaxOp,
|
||||||
|
"aten.arg_max to linalg.* requires Float input element type");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Constant op to account for the reduction along dim.
|
||||||
|
auto c1 = rewriter.create<mlir::ConstantIndexOp>(loc, /*value=*/1);
|
||||||
|
SmallVector<Value> resultShape;
|
||||||
|
for (int64_t i = 0; i < inputType.getRank(); i++) {
|
||||||
|
if (dim != i) {
|
||||||
|
auto currentDimSize = rewriter.create<tensor::DimOp>(loc, input, i);
|
||||||
|
resultShape.push_back(currentDimSize);
|
||||||
|
} else if (keepDim)
|
||||||
|
resultShape.push_back(c1);
|
||||||
|
}
|
||||||
|
// First fill the output buffer for the index.
|
||||||
|
Value filledTensorIdx =
|
||||||
|
createZeroInitTensor(rewriter, loc, resultShape, outElementType);
|
||||||
|
|
||||||
|
// Second fill the output buffer for the running max.
|
||||||
|
Value initTensorMax =
|
||||||
|
rewriter.create<linalg::InitTensorOp>(loc, resultShape, inElementType)
|
||||||
|
.result();
|
||||||
|
|
||||||
|
FloatAttr fillValueMaxAttr = rewriter.getFloatAttr(
|
||||||
|
inElementType,
|
||||||
|
APFloat::getLargest(
|
||||||
|
inElementType.cast<mlir::FloatType>().getFloatSemantics(), true));
|
||||||
|
|
||||||
|
Value fillValueMax = rewriter.create<ConstantOp>(loc, fillValueMaxAttr);
|
||||||
|
Value filledTensorMax =
|
||||||
|
rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
|
||||||
|
.result();
|
||||||
|
|
||||||
|
// Create the affine expressions that will be used to
|
||||||
|
// iterate over the input and output tensors.
|
||||||
|
// Here we also set the type of iterator: parallel or reduction.
|
||||||
|
SmallVector<AffineExpr> exprs;
|
||||||
|
SmallVector<StringRef> iteratorTypes;
|
||||||
|
SmallVector<AffineExpr> resultExprs;
|
||||||
|
for (auto size : llvm::enumerate(inputType.getShape())) {
|
||||||
|
exprs.push_back(rewriter.getAffineDimExpr(size.index()));
|
||||||
|
|
||||||
|
if (unsigned(dim) == size.index()) {
|
||||||
|
iteratorTypes.push_back(getReductionIteratorTypeName());
|
||||||
|
// If `keepDim`, create affine map to the first element
|
||||||
|
// in the current dimension.
|
||||||
|
if (keepDim)
|
||||||
|
resultExprs.push_back(rewriter.getAffineConstantExpr(0));
|
||||||
|
} else {
|
||||||
|
iteratorTypes.push_back(getParallelIteratorTypeName());
|
||||||
|
resultExprs.push_back(rewriter.getAffineDimExpr(size.index()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs});
|
||||||
|
auto linalgOp = rewriter.create<linalg::GenericOp>(
|
||||||
|
loc,
|
||||||
|
ArrayRef<Type>({filledTensorIdx.getType(), filledTensorMax.getType()}),
|
||||||
|
input, ValueRange({filledTensorIdx, filledTensorMax}), maps,
|
||||||
|
iteratorTypes,
|
||||||
|
[&](OpBuilder &nestedBuilder, Location nestedLoc,
|
||||||
|
ValueRange blockArgs) {
|
||||||
|
Value newValue = blockArgs[0];
|
||||||
|
Value oldIndex = blockArgs[1];
|
||||||
|
Value oldValue = blockArgs[2];
|
||||||
|
|
||||||
|
Value newIndex = rewriter.create<IndexCastOp>(
|
||||||
|
nestedLoc, oldIndex.getType(),
|
||||||
|
rewriter.create<linalg::IndexOp>(loc, dim));
|
||||||
|
|
||||||
|
Value predicate;
|
||||||
|
if (inElementType.isa<mlir::FloatType>())
|
||||||
|
predicate = rewriter.create<mlir::CmpFOp>(
|
||||||
|
nestedLoc, CmpFPredicate::OGT, newValue, oldValue);
|
||||||
|
auto resultMax = rewriter.create<mlir::SelectOp>(nestedLoc, predicate,
|
||||||
|
newValue, oldValue);
|
||||||
|
auto resultIndex = rewriter.create<mlir::SelectOp>(
|
||||||
|
nestedLoc, predicate, newIndex, oldIndex);
|
||||||
|
nestedBuilder.create<linalg::YieldOp>(
|
||||||
|
nestedLoc, ValueRange({resultIndex, resultMax}));
|
||||||
|
});
|
||||||
|
|
||||||
|
// This cast is required to fix the shape in the case of keepDim=True
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(argmaxOp, resultType,
|
||||||
|
linalgOp.getResult(0));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Converts an elementwise op.
|
// Converts an elementwise op.
|
||||||
|
@ -1896,6 +2045,8 @@ public:
|
||||||
patterns.add<ConvertAtenGatherOp>(typeConverter, context);
|
patterns.add<ConvertAtenGatherOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenLayerNormOp>();
|
target.addIllegalOp<AtenLayerNormOp>();
|
||||||
patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);
|
patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);
|
||||||
|
target.addIllegalOp<AtenArgmaxOp>();
|
||||||
|
patterns.add<ConvertAtenArgmaxOp>(typeConverter, context);
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns))))
|
std::move(patterns))))
|
||||||
|
|
|
@ -138,7 +138,6 @@ Type parseTensorType(MLIRContext *context, DialectAsmParser &parser,
|
||||||
sizes.push_back(-1);
|
sizes.push_back(-1);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t size;
|
int64_t size;
|
||||||
auto optionalInt = parser.parseOptionalInteger(size);
|
auto optionalInt = parser.parseOptionalInteger(size);
|
||||||
if (optionalInt.hasValue()) {
|
if (optionalInt.hasValue()) {
|
||||||
|
|
|
@ -270,6 +270,8 @@ public:
|
||||||
} else if (auto meanDim = dyn_cast<AtenMeanDimOp>(op)) {
|
} else if (auto meanDim = dyn_cast<AtenMeanDimOp>(op)) {
|
||||||
return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(),
|
return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(),
|
||||||
meanDim.keepdim(), operands);
|
meanDim.keepdim(), operands);
|
||||||
|
} else if (auto argmax = dyn_cast<AtenArgmaxOp>(op)) {
|
||||||
|
return visitAtenArgmaxOp(argmax, operands);
|
||||||
} else if (auto anyDim = dyn_cast<AtenAnyDimOp>(op)) {
|
} else if (auto anyDim = dyn_cast<AtenAnyDimOp>(op)) {
|
||||||
return visitAtenAnyDimOp(anyDim, operands);
|
return visitAtenAnyDimOp(anyDim, operands);
|
||||||
} else if (auto view = dyn_cast<AtenViewOp>(op)) {
|
} else if (auto view = dyn_cast<AtenViewOp>(op)) {
|
||||||
|
@ -397,6 +399,9 @@ private:
|
||||||
Operation *op, Value dim, Value keepdim,
|
Operation *op, Value dim, Value keepdim,
|
||||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||||
ChangeResult
|
ChangeResult
|
||||||
|
visitAtenArgmaxOp(AtenArgmaxOp op,
|
||||||
|
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||||
|
ChangeResult
|
||||||
visitAtenAnyDimOp(AtenAnyDimOp op,
|
visitAtenAnyDimOp(AtenAnyDimOp op,
|
||||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
|
@ -733,8 +738,8 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
|
||||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||||
knowledge.dtype = input.dtype;
|
knowledge.dtype = input.dtype;
|
||||||
llvm::SmallVector<int64_t> dimList;
|
llvm::SmallVector<int64_t> dimList;
|
||||||
bool keepdimBool;
|
bool keepDim;
|
||||||
if (matchPattern(keepdim, m_TorchConstantBool(&keepdimBool))) {
|
if (matchPattern(keepdim, m_TorchConstantBool(&keepDim))) {
|
||||||
knowledge.hasSizes = true;
|
knowledge.hasSizes = true;
|
||||||
int64_t inputRank = input.sizes.size();
|
int64_t inputRank = input.sizes.size();
|
||||||
// TODO: This is not safe. Need to check the list users and use aliasing
|
// TODO: This is not safe. Need to check the list users and use aliasing
|
||||||
|
@ -745,7 +750,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
|
||||||
DenseSet<int64_t> dimSet(dimList.begin(), dimList.end());
|
DenseSet<int64_t> dimSet(dimList.begin(), dimList.end());
|
||||||
for (auto en : llvm::enumerate(input.sizes)) {
|
for (auto en : llvm::enumerate(input.sizes)) {
|
||||||
if (dimSet.contains(en.index())) {
|
if (dimSet.contains(en.index())) {
|
||||||
if (keepdimBool)
|
if (keepDim)
|
||||||
knowledge.sizes.push_back(1);
|
knowledge.sizes.push_back(1);
|
||||||
} else {
|
} else {
|
||||||
knowledge.sizes.push_back(en.value());
|
knowledge.sizes.push_back(en.value());
|
||||||
|
@ -753,12 +758,39 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
|
||||||
}
|
}
|
||||||
} else if (auto listConstruct = dim.getDefiningOp<PrimListConstructOp>()) {
|
} else if (auto listConstruct = dim.getDefiningOp<PrimListConstructOp>()) {
|
||||||
auto sizes = listConstruct.elements();
|
auto sizes = listConstruct.elements();
|
||||||
knowledge.sizes.resize(keepdimBool ? inputRank : inputRank - sizes.size(),
|
knowledge.sizes.resize(keepDim ? inputRank : inputRank - sizes.size(),
|
||||||
kUnknownSize);
|
kUnknownSize);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||||
}
|
}
|
||||||
|
ChangeResult TypeAnalyzer::visitAtenArgmaxOp(
|
||||||
|
AtenArgmaxOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||||
|
auto input = operands[0]->getValue();
|
||||||
|
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||||
|
knowledge.dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed);
|
||||||
|
int64_t dim;
|
||||||
|
bool keepDim;
|
||||||
|
if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
|
||||||
|
int64_t inputRank = input.sizes.size();
|
||||||
|
knowledge.hasSizes = true;
|
||||||
|
if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
|
||||||
|
knowledge.sizes = input.sizes;
|
||||||
|
dim = toPositiveDim(dim, inputRank);
|
||||||
|
if (isValidDim(dim, inputRank)) {
|
||||||
|
if (keepDim)
|
||||||
|
knowledge.sizes[dim] = 1;
|
||||||
|
else
|
||||||
|
knowledge.sizes.erase(knowledge.sizes.begin() + dim);
|
||||||
|
}
|
||||||
|
} else if (op.dim().getType().isa<IntegerType>())
|
||||||
|
knowledge.sizes.resize(keepDim ? inputRank : inputRank - 1,
|
||||||
|
kUnknownSize);
|
||||||
|
}
|
||||||
|
// If dim is no kind of Integer, keepDim is ignored,
|
||||||
|
// and the result will bea rank-0 tensor
|
||||||
|
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
ChangeResult TypeAnalyzer::visitAtenAnyDimOp(
|
ChangeResult TypeAnalyzer::visitAtenAnyDimOp(
|
||||||
AtenAnyDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
AtenAnyDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||||
|
@ -767,22 +799,21 @@ ChangeResult TypeAnalyzer::visitAtenAnyDimOp(
|
||||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||||
knowledge.dtype = input.dtype;
|
knowledge.dtype = input.dtype;
|
||||||
int64_t dim;
|
int64_t dim;
|
||||||
bool keepdimBool;
|
bool keepDim;
|
||||||
if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepdimBool))) {
|
if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
|
||||||
int64_t inputRank = input.sizes.size();
|
int64_t inputRank = input.sizes.size();
|
||||||
knowledge.hasSizes = true;
|
knowledge.hasSizes = true;
|
||||||
if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
|
if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
|
||||||
knowledge.sizes = input.sizes;
|
knowledge.sizes = input.sizes;
|
||||||
dim = toPositiveDim(dim, inputRank);
|
dim = toPositiveDim(dim, inputRank);
|
||||||
if (isValidDim(dim, inputRank)) {
|
if (isValidDim(dim, inputRank)) {
|
||||||
if (keepdimBool)
|
if (keepDim)
|
||||||
knowledge.sizes[dim] = 1;
|
knowledge.sizes[dim] = 1;
|
||||||
else
|
else
|
||||||
knowledge.sizes.erase(knowledge.sizes.begin() + dim);
|
knowledge.sizes.erase(knowledge.sizes.begin() + dim);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
knowledge.sizes.resize(keepdimBool ? inputRank : inputRank - 1,
|
knowledge.sizes.resize(keepDim ? inputRank : inputRank - 1, kUnknownSize);
|
||||||
kUnknownSize);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||||
|
|
|
@ -510,6 +510,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
||||||
emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
|
emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
|
||||||
emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
|
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
|
||||||
emit("aten::contiguous : (Tensor, int) -> (Tensor)")
|
emit("aten::contiguous : (Tensor, int) -> (Tensor)")
|
||||||
emit("aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)")
|
emit("aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)")
|
||||||
emit("aten::detach : (Tensor) -> (Tensor)")
|
emit("aten::detach : (Tensor) -> (Tensor)")
|
||||||
|
|
Loading…
Reference in New Issue