mirror of https://github.com/llvm/torch-mlir
Add lowering of `aten.max.dim` op.
Lowering of `aten.max.dim` op has been added.pull/553/head
parent
454fa9d123
commit
e58b66bc3b
|
@ -122,3 +122,97 @@ class ReduceMeanDtypeModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ReduceMeanDtypeModule())
|
||||
def ReduceMeanDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float64))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceMaxAlongDim(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.max(a, 1)[0]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceMaxAlongDim())
|
||||
def ReduceMaxAlongDim_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float64))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceMaxAlongDimNegative(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.max(a, 1)[0]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceMaxAlongDimNegative())
|
||||
def ReduceMaxAlongDimNegative_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5, low=-10, high=10).to(torch.float64))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceMaxKeepDim(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.max(a, 1, keepdim=True)[1]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceMaxKeepDim())
|
||||
def ReduceMaxKeepDim_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float64))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceMaxKeepDimReturnBoth(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.max(a, 1, keepdim=True)
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceMaxKeepDimReturnBoth())
|
||||
def ReduceMaxKeepDimReturnBoth_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5, low=-10, high=-5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceMaxAllDims(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.max(a)
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceMaxAllDims())
|
||||
def ReduceMaxAllDims_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5, low=-10, high=-5))
|
||||
|
|
|
@ -2506,6 +2506,37 @@ def Torch_AtenSumDimIntListOp : Torch_Op<"aten.sum.dim_IntList", [
|
|||
let assemblyFormat = "$self `,` $dim `,` $keepdim `,` $dtype attr-dict `:` qualified(type($self)) `,` qualified(type($dim)) `,` qualified(type($keepdim)) `,` qualified(type($dtype)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenMaxOp : Torch_Op<"aten.max", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::max : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenMaxDimOp : Torch_Op<"aten.max.dim", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
Torch_BoolType:$keepdim
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$values,
|
||||
AnyTorchTensorType:$indices
|
||||
);
|
||||
let assemblyFormat = "$self `,` $dim `,` $keepdim attr-dict `:` qualified(type($self)) `,` qualified(type($dim)) `,` qualified(type($keepdim)) `->` qualified(type($values)) `,` qualified(type($indices))";
|
||||
}
|
||||
|
||||
def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
|
|
|
@ -2126,6 +2126,13 @@ static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc,
|
|||
if (isa<AtenSumOp, AtenSumDimIntListOp>(op) &&
|
||||
elementType.isa<mlir::FloatType>())
|
||||
return b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, 0.0));
|
||||
if (isa<AtenMaxOp>(op) && elementType.isa<mlir::FloatType>())
|
||||
return b.create<arith::ConstantOp>(
|
||||
loc, b.getFloatAttr(
|
||||
elementType,
|
||||
APFloat::getLargest(
|
||||
elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*Negative=*/true)));
|
||||
|
||||
op->emitError("unimplemented lowering in "
|
||||
"createLinalgNeutralElementForReduceOp");
|
||||
|
@ -2141,6 +2148,11 @@ static Value createLinalgPayloadCalculationForReduceOp(
|
|||
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
||||
Value result = payloadArgs[1];
|
||||
return b.create<arith::AddFOp>(loc, self, result);
|
||||
} else if (isa<AtenMaxOp>(op) && resultElementType.isa<mlir::FloatType>()) {
|
||||
Value self =
|
||||
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
||||
Value result = payloadArgs[1];
|
||||
return b.create<arith::MaxFOp>(loc, self, result);
|
||||
}
|
||||
op->emitError("unimplemented lowering in "
|
||||
"createLinalgPayloadCalculationForReduceOp");
|
||||
|
@ -2148,68 +2160,56 @@ static Value createLinalgPayloadCalculationForReduceOp(
|
|||
}
|
||||
|
||||
namespace {
|
||||
// Aten argmax lowering represents the ArgMax op as an linalg.indexed_generic
|
||||
// Aten maxdim lowering represents the MaxDim op as an linalg.indexed_generic
|
||||
// op, producing two output buffers.
|
||||
//
|
||||
// The first output buffer contains the index of the found maximum value. It is
|
||||
// initialized to 0 and is resulting integer type.
|
||||
// The first output buffer contains the maximum value found. It is initialized
|
||||
// to the minimum representable value of the input element type.
|
||||
//
|
||||
// The second output buffer contains the maximum value found. It is initialized
|
||||
// to the minimum representable value of the input element type. After being
|
||||
// populated by indexed_generic, this buffer is disgarded as only the index is
|
||||
// requested.
|
||||
// The second output buffer contains the index of the found maximum value. It is
|
||||
// initialized to 0 and is resulting integer type.
|
||||
//
|
||||
// The indexed_generic op updates both the maximum value and index if the
|
||||
// current value exceeds the running max.
|
||||
class ConvertAtenArgmaxOp : public OpConversionPattern<AtenArgmaxOp> {
|
||||
class ConvertAtenMaxDimOp : public OpConversionPattern<AtenMaxDimOp> {
|
||||
public:
|
||||
using OpConversionPattern<AtenArgmaxOp>::OpConversionPattern;
|
||||
using OpConversionPattern<AtenMaxDimOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenArgmaxOp argmaxOp, OpAdaptor adaptor,
|
||||
matchAndRewrite(AtenMaxDimOp maxDimOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
Location loc = argmaxOp.getLoc();
|
||||
Location loc = maxDimOp.getLoc();
|
||||
Value input = adaptor.self();
|
||||
RankedTensorType resultType =
|
||||
RankedTensorType valResultType =
|
||||
getTypeConverter()
|
||||
->convertType(argmaxOp.getResult().getType())
|
||||
->convertType(maxDimOp.getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType idxResultType =
|
||||
getTypeConverter()
|
||||
->convertType(maxDimOp.getResult(1).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
Type outElementType = resultType.getElementType();
|
||||
if (!outElementType.isa<IntegerType>())
|
||||
Type idxElementType = idxResultType.getElementType();
|
||||
if (!idxElementType.isa<IntegerType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
argmaxOp,
|
||||
"aten.arg_max to linalg.* requires integer-like result type");
|
||||
maxDimOp,
|
||||
"aten.max_dim to linalg.* requires integer-like result type");
|
||||
|
||||
bool keepDim = false;
|
||||
if (!matchPattern(argmaxOp.keepdim(), m_TorchConstantBool(&keepDim)))
|
||||
if (!matchPattern(maxDimOp.keepdim(), m_TorchConstantBool(&keepDim)))
|
||||
return failure();
|
||||
|
||||
int64_t dim;
|
||||
if (!matchPattern(argmaxOp.dim(), m_TorchConstantInt(&dim))) {
|
||||
if (!argmaxOp.dim().getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
argmaxOp,
|
||||
"aten.arg_max to linalg.* requires int or NoneType value for Dim");
|
||||
// For pytorch, if the value of Dim is None, argmax
|
||||
// returns the index of the max value of the flattened input tensor,
|
||||
// so here we flatten the input tensor.
|
||||
SmallVector<ReassociationIndices> reassociation(1);
|
||||
for (auto i : llvm::seq<int64_t>(0, inputType.getRank()))
|
||||
reassociation[0].push_back(i);
|
||||
input = rewriter.create<tensor::CollapseShapeOp>(argmaxOp->getLoc(),
|
||||
input, reassociation);
|
||||
// Becomes 0 for flattened tensor.
|
||||
dim = 0;
|
||||
// Recast to fix shape.
|
||||
inputType = input.getType().cast<RankedTensorType>();
|
||||
}
|
||||
if (!matchPattern(maxDimOp.dim(), m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
maxDimOp, "aten.max_dim to linalg.* requires int value for Dim");
|
||||
|
||||
Type inElementType = inputType.getElementType();
|
||||
if (!inElementType.isa<mlir::FloatType>()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
argmaxOp,
|
||||
"aten.arg_max to linalg.* requires Float input element type");
|
||||
maxDimOp,
|
||||
"aten.max_dim to linalg.* requires Float input element type");
|
||||
}
|
||||
|
||||
// Constant op to account for the reduction along dim.
|
||||
|
@ -2224,7 +2224,7 @@ public:
|
|||
}
|
||||
// First fill the output buffer for the index.
|
||||
Value filledTensorIdx =
|
||||
createZeroInitTensor(rewriter, loc, resultShape, outElementType);
|
||||
createZeroInitTensor(rewriter, loc, resultShape, idxElementType);
|
||||
|
||||
// Second fill the output buffer for the running max.
|
||||
Value initTensorMax =
|
||||
|
@ -2265,14 +2265,14 @@ public:
|
|||
auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs});
|
||||
auto linalgOp = rewriter.create<linalg::GenericOp>(
|
||||
loc,
|
||||
ArrayRef<Type>({filledTensorIdx.getType(), filledTensorMax.getType()}),
|
||||
input, ValueRange({filledTensorIdx, filledTensorMax}), maps,
|
||||
ArrayRef<Type>({filledTensorMax.getType(), filledTensorIdx.getType()}),
|
||||
input, ValueRange({filledTensorMax, filledTensorIdx}), maps,
|
||||
iteratorTypes,
|
||||
[&](OpBuilder &nestedBuilder, Location nestedLoc,
|
||||
ValueRange blockArgs) {
|
||||
Value newValue = blockArgs[0];
|
||||
Value oldIndex = blockArgs[1];
|
||||
Value oldValue = blockArgs[2];
|
||||
Value oldValue = blockArgs[1];
|
||||
Value oldIndex = blockArgs[2];
|
||||
|
||||
Value newIndex = rewriter.create<arith::IndexCastOp>(
|
||||
nestedLoc, oldIndex.getType(),
|
||||
|
@ -2287,12 +2287,15 @@ public:
|
|||
auto resultIndex = rewriter.create<mlir::SelectOp>(
|
||||
nestedLoc, predicate, newIndex, oldIndex);
|
||||
nestedBuilder.create<linalg::YieldOp>(
|
||||
nestedLoc, ValueRange({resultIndex, resultMax}));
|
||||
nestedLoc, ValueRange({resultMax, resultIndex}));
|
||||
});
|
||||
|
||||
// This cast is required to fix the shape in the case of keepDim=True
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(argmaxOp, resultType,
|
||||
linalgOp.getResult(0));
|
||||
Value maxValuesCast = rewriter.create<tensor::CastOp>(
|
||||
loc, valResultType, linalgOp.getResult(0));
|
||||
Value maxIdxCast = rewriter.create<tensor::CastOp>(loc, idxResultType,
|
||||
linalgOp.getResult(1));
|
||||
rewriter.replaceOp(maxDimOp, {maxValuesCast, maxIdxCast});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -2558,11 +2561,12 @@ struct ConvertReductionOp : ConversionPattern {
|
|||
// `keepDim` in accordance with their specification.
|
||||
DenseSet<int64_t> dimSet;
|
||||
bool keepDim = false;
|
||||
if (isa<AtenSumOp>(op)) {
|
||||
if (isa<AtenSumOp>(op) || isa<AtenMaxOp>(op)) {
|
||||
auto tensorOperand = operands[0];
|
||||
auto inputType = tensorOperand.getType().cast<RankedTensorType>();
|
||||
|
||||
// `AtenSumOp` reduces along all the dimensiosn of the input tensor.
|
||||
// `AtenSumOp` and `AtenMaxOp` reduces along all the dimensions of the
|
||||
// input tensor.
|
||||
for (int64_t i = 0; i < inputType.getRank(); i++)
|
||||
dimSet.insert(i);
|
||||
} else if (auto sumDimIntListOp = dyn_cast<AtenSumDimIntListOp>(op)) {
|
||||
|
@ -2587,7 +2591,6 @@ struct ConvertReductionOp : ConversionPattern {
|
|||
} else {
|
||||
return rewriter.notifyMatchFailure(op, "not a supported reduce op");
|
||||
}
|
||||
|
||||
return createReductionLinalgGeneric(op, operands, dimSet, keepDim,
|
||||
rewriter);
|
||||
}
|
||||
|
@ -4378,6 +4381,8 @@ public:
|
|||
target.addIllegalOp<AtenConstantPadNdOp>();
|
||||
patterns.add<ConvertAtenConstantPadNdOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSumOp>();
|
||||
target.addIllegalOp<AtenSumDimIntListOp>();
|
||||
target.addIllegalOp<AtenMaxOp>();
|
||||
patterns.add<ConvertReductionOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenTransposeIntOp>();
|
||||
patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);
|
||||
|
@ -4391,8 +4396,8 @@ public:
|
|||
patterns.add<ConvertAtenNativeLayerNormOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenBroadcastToOp>();
|
||||
patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenArgmaxOp>();
|
||||
patterns.add<ConvertAtenArgmaxOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenMaxDimOp>();
|
||||
patterns.add<ConvertAtenMaxDimOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSizeIntOp>();
|
||||
patterns.add<ConvertAtenSizeIntOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenEmbeddingOp>();
|
||||
|
|
|
@ -325,6 +325,56 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose `AtenArgMaxOp` into `AtenMaxDimOp`.
|
||||
namespace {
|
||||
class DecomposeAtenArgMaxOp : public OpRewritePattern<AtenArgmaxOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenArgmaxOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.self();
|
||||
Value dim = op.dim();
|
||||
Value keepDim = op.keepdim();
|
||||
Value result = op.result();
|
||||
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
BaseTensorType indicesTensorType = result.getType().cast<BaseTensorType>();
|
||||
|
||||
if (!indicesTensorType.hasSizes())
|
||||
return failure();
|
||||
BaseTensorType valueTensorType =
|
||||
inputType
|
||||
.getWithSizesAndDtype(indicesTensorType.getSizes(),
|
||||
inputType.getDtype())
|
||||
.cast<BaseTensorType>();
|
||||
|
||||
// If the dim type is `NoneType` i.e. reduce along all the dimensions.
|
||||
// `AtenMaxDimOp` doesn't support dim as `NoneType` so first the input
|
||||
// tensor is flattened to 1d tensor and then the reduction happens on the
|
||||
// 0th dimension.
|
||||
if (dim.getType().isa<Torch::NoneType>()) {
|
||||
BaseTensorType flattenType =
|
||||
inputType.getWithSizesAndDtype({kUnknownSize}, inputType.getDtype())
|
||||
.cast<BaseTensorType>();
|
||||
dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
Value end = rewriter.create<ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(getTensorRank(input) - 1));
|
||||
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
|
||||
dim, end);
|
||||
}
|
||||
Value maxResult =
|
||||
rewriter
|
||||
.create<AtenMaxDimOp>(loc, valueTensorType, indicesTensorType,
|
||||
input, dim, keepDim)
|
||||
.indices();
|
||||
|
||||
rewriter.replaceOp(op, maxResult);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.log_softmax op into: log(softmax(x))
|
||||
namespace {
|
||||
class DecomposeAtenLogSoftmaxIntOp
|
||||
|
@ -678,6 +728,8 @@ class DecomposeComplexOpsPass
|
|||
target.addIllegalOp<AtenArangeOp>();
|
||||
patterns.add<DecomposeAtenArangeStartOp>(context);
|
||||
target.addIllegalOp<AtenArangeStartOp>();
|
||||
patterns.add<DecomposeAtenArgMaxOp>(context);
|
||||
target.addIllegalOp<AtenArgmaxOp>();
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns)))) {
|
||||
|
|
|
@ -354,6 +354,15 @@ public:
|
|||
Type dtype = operands[0]->getValue().dtype;
|
||||
return visitReductionAlongDimIntOp(anyDim, anyDim.dim(), anyDim.keepdim(),
|
||||
dtype, operands);
|
||||
} else if (auto maxDim = dyn_cast<AtenMaxDimOp>(op)) {
|
||||
Type firstResDtype = operands[0]->getValue().dtype;
|
||||
Type secondResDtype =
|
||||
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
|
||||
ChangeResult firstRes = visitReductionAlongDimIntOp(
|
||||
maxDim, maxDim.dim(), maxDim.keepdim(), firstResDtype, operands);
|
||||
return firstRes | visitReductionAlongDimIntOp(
|
||||
maxDim, maxDim.dim(), maxDim.keepdim(),
|
||||
secondResDtype, operands, /*resNum=*/1);
|
||||
} else if (auto view = dyn_cast<AtenViewOp>(op)) {
|
||||
return visitReshapeLikeOp(view, operands);
|
||||
} else if (auto resize = dyn_cast<AtenResize_Op>(op)) {
|
||||
|
@ -457,6 +466,9 @@ public:
|
|||
Type dtype =
|
||||
getDtypeOrDefault(mean.getContext(), mean.dtype(), defaultDtype);
|
||||
return visitReductionAlongAllDimsOp(mean, dtype, operands);
|
||||
} else if (auto max = dyn_cast<AtenMaxOp>(op)) {
|
||||
Type dtype = operands[0]->getValue().dtype;
|
||||
return visitReductionAlongAllDimsOp(max, dtype, operands);
|
||||
} else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
|
||||
return visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
|
||||
} else if (auto _softmaxOp = dyn_cast<Aten_SoftmaxOp>(op)) {
|
||||
|
@ -547,7 +559,7 @@ private:
|
|||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
ChangeResult visitReductionAlongDimIntOp(
|
||||
Operation *op, Value dim, Value keepdim, Type dtype,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands, int resNum = 0);
|
||||
template <typename OpTy>
|
||||
ChangeResult
|
||||
visitReshapeLikeOp(OpTy op,
|
||||
|
@ -1307,7 +1319,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
|
|||
|
||||
ChangeResult TypeAnalyzer::visitReductionAlongDimIntOp(
|
||||
Operation *op, Value dim, Value keepdim, Type dtype,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands, int resNum) {
|
||||
assert(dim.getType().isa<Torch::IntType>() && "dim must be int type");
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge =
|
||||
|
@ -1331,7 +1343,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntOp(
|
|||
knowledge.sizes.resize(keepDim ? inputRank : inputRank - 1, kUnknownSize);
|
||||
}
|
||||
}
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
return getLatticeElement(op->getResult(resNum)).join(knowledge);
|
||||
}
|
||||
|
||||
// Reshape like ops are given a size list which specify the shape of the
|
||||
|
|
|
@ -587,6 +587,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::stack : (Tensor[], int) -> (Tensor)")
|
||||
emit("aten::sum : (Tensor, int?) -> (Tensor)")
|
||||
emit("aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)")
|
||||
emit("aten::max : (Tensor) -> (Tensor)")
|
||||
emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
|
||||
emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)
|
||||
emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)")
|
||||
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
|
||||
|
|
|
@ -154,3 +154,34 @@ func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
|
|||
%0 = torch.aten.arange.start %int0, %int10, %none, %none, %none, %none : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
|
||||
return %0 : !torch.vtensor<[?],si64>
|
||||
}
|
||||
|
||||
// ----
|
||||
// CHECK-LABEL: func @torch.aten.argmax(
|
||||
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> {
|
||||
// CHECK: %[[CST0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[INP]], %[[CST0]], %[[TRUE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],f32>, !torch.vtensor<[1,?],si64>
|
||||
// CHECK: return %[[IND]] : !torch.vtensor<[1,?],si64>
|
||||
func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> {
|
||||
%int0 = torch.constant.int 0
|
||||
%true = torch.constant.bool true
|
||||
%0 = torch.aten.argmax %arg0, %int0, %true : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],si64>
|
||||
return %0 : !torch.vtensor<[1,?],si64>
|
||||
}
|
||||
|
||||
// ----
|
||||
// CHECK-LABEL: func @torch.aten.argmax$reduceall(
|
||||
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[CST0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[CST1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[FLATTEN:.*]] = torch.aten.flatten.using_ints %[[INP]], %[[CST0]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
|
||||
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[FLATTEN]], %[[CST0]], %[[FALSE]] : !torch.vtensor<[?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[],f32>, !torch.vtensor<[],si64>
|
||||
// CHECK: return %[[IND]] : !torch.vtensor<[],si64>
|
||||
func @torch.aten.argmax$reduceall(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> {
|
||||
%none = torch.constant.none
|
||||
%false = torch.constant.bool false
|
||||
%0 = torch.aten.argmax %arg0, %none, %false : !torch.vtensor<[?,?],f32>, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
|
||||
return %0 : !torch.vtensor<[],si64>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue