mirror of https://github.com/llvm/torch-mlir
[tosa] Implement Argmax support (#485)
Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>

parent d13bb0e5c1
commit 829cf8afc3

@@ -12,6 +12,8 @@
 #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"

 #include "../PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Matchers.h"

@@ -404,6 +406,93 @@ public:
   }
 };

+template <>
+LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
+    AtenArgmaxOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  Value self = adaptor.self();
+  // dyn_cast (rather than cast) so the null check below is meaningful.
+  auto selfTy = self.getType().dyn_cast<RankedTensorType>();
+
+  if (!selfTy)
+    return op.emitError("Only ranked tensor types supported in TOSA argmax");
+
+  int64_t reduceDim;
+  if (!matchPattern(op.dim(), m_TorchConstantInt(&reduceDim))) {
+    // NoneType indicates reduce on all dims
+    reduceDim = -1;
+  }
+
+  bool keepDim = false;
+  if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim)))
+    return rewriter.notifyMatchFailure(
+        op, "non-const keepdim parameter unsupported");
+
+  auto resultTy = getTypeConverter()
+                      ->convertType(op.getResult().getType())
+                      .cast<RankedTensorType>();
+  auto outputETy = resultTy.getElementType();
+
+  // Create a single instance of tosa.argmax.
+  // Multiple dims require a chained construct.
+  auto buildArgmax = [&](int64_t reduceDim, Value input) -> Value {
+    auto inputTy = input.getType().cast<RankedTensorType>();
+    auto inputShape = inputTy.getShape();
+    SmallVector<int64_t> outputShapeArr = {};
+    int32_t i = 0;
+
+    for (auto &dim : inputShape) {
+      if (i++ != reduceDim) {
+        outputShapeArr.push_back(dim);
+      } else {
+        if (keepDim)
+          outputShapeArr.push_back(1);
+      }
+    }
+
+    // Tosa argmax output is i32, while the Torch backend mandates i64.
+    auto outputReduceTy = RankedTensorType::get(
+        ArrayRef<int64_t>(outputShapeArr), rewriter.getI32Type());
+    auto reduceDimAttr =
+        rewriter.getIntegerAttr(rewriter.getI64Type(), reduceDim);
+    return rewriter
+        .create<tosa::ArgMaxOp>(op->getLoc(),
+                                getTypeConverter()->convertType(outputReduceTy),
+                                input, reduceDimAttr)
+        .getResult();
+  };
+
+  // Convert the final index to i64 for backend finalization. However, i64
+  // is not a defined type for tosa.cast, so arith.extsi is used instead.
+  auto castToInt64 = [&](Value result) -> LogicalResult {
+    auto resTy = result.getType().dyn_cast<ShapedType>();
+    if (!resTy)
+      return op.emitError("Argmax: Result is not a shaped type");
+
+    auto resShape = resTy.getShape();
+    auto outTy = RankedTensorType::get(resShape, outputETy);
+
+    rewriter.replaceOpWithNewOp<arith::ExtSIOp>(
+        op, getTypeConverter()->convertType(outTy), result);
+
+    return success();
+  };
+
+  if (reduceDim == -1) { // reducing on all dims
+    Value input = self;
+    for (int dim = 0; dim < selfTy.getRank(); dim++) {
+      // progressively reduce each 0-th dim
+      input = buildArgmax(0, input);
+    }
+    return castToInt64(input);
+  } else {
+    return castToInt64(buildArgmax(reduceDim, self));
+  }
+}

 } // namespace

 // -----------------------------------------------------------------------------
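
The hunk above carries the whole lowering: buildArgmax computes the reduced output shape (the reduced dim is dropped, or kept as size 1 when keepdim is set), and the dim == None case chains one tosa.argmax per rank, reducing the current dim 0 each time. Below is a minimal standalone C++ sketch of both ideas; the names (argmaxOutputShape, the 2x3 example tensor) are hypothetical and not part of the patch.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Mirror of the output-shape logic in buildArgmax above: the reduced dim
// is dropped, or kept as size 1 when keepDim is set. Hypothetical helper.
std::vector<int64_t> argmaxOutputShape(const std::vector<int64_t> &inShape,
                                       int64_t reduceDim, bool keepDim) {
  std::vector<int64_t> out;
  for (int64_t i = 0; i < (int64_t)inShape.size(); ++i) {
    if (i != reduceDim)
      out.push_back(inShape[i]);
    else if (keepDim)
      out.push_back(1);
  }
  return out;
}

int main() {
  // tensor<2x3x4> reduced over dim 1 without keepdim -> tensor<2x4>.
  for (int64_t d : argmaxOutputShape({2, 3, 4}, /*reduceDim=*/1, false))
    std::cout << d << ' '; // prints: 2 4
  std::cout << '\n';

  // The dim == None case chains per-dim argmax. For a 2x3 tensor: argmax
  // over dim 0 gives per-column maxima (and their row indices); argmax of
  // the 3-element result then locates the overall maximum.
  std::vector<std::vector<float>> t = {{1.f, 9.f, 3.f}, {7.f, 2.f, 8.f}};
  std::vector<int64_t> rowIdx(3);
  std::vector<float> colMax(3);
  for (int c = 0; c < 3; ++c) {
    rowIdx[c] = t[0][c] >= t[1][c] ? 0 : 1; // first occurrence wins ties
    colMax[c] = std::max(t[0][c], t[1][c]);
  }
  int64_t col =
      std::max_element(colMax.begin(), colMax.end()) - colMax.begin();
  std::cout << "overall max at (" << rowIdx[col] << ", " << col << ")\n";
  // prints: overall max at (0, 1)
  return 0;
}

The patch itself returns only the final chained index (after the extsi widening); the sketch tracks the row index as well purely to make the chaining visible.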

@@ -415,13 +504,16 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
 public:
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<tosa::TosaDialect>();
+    registry.insert<tensor::TensorDialect>();
+    registry.insert<arith::ArithmeticDialect>();
     TorchConversion::getBackendTypeConversionDependentDialects(registry);
   }

   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
-    target.addLegalDialect<tosa::TosaDialect>();
+    target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
+                           arith::ArithmeticDialect>();

     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });

@@ -491,6 +583,7 @@ public:
     INSERT_ATENOP_PATTERN(AtenReluOp);
     INSERT_ATENOP_PATTERN(AtenMulTensorOp);
     INSERT_ATENOP_PATTERN(AtenDivTensorOp);
+    INSERT_ATENOP_PATTERN(AtenArgmaxOp);
 #undef INSERT_ATENOP_PATTERN

     if (failed(applyPartialConversion(getOperation(), target,
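
For context, INSERT_ATENOP_PATTERN is defined earlier in this file and not shown in the hunk; judging from how it is used, it both marks the aten op illegal on the target and registers the corresponding ConvertAtenOp pattern. A hedged sketch of such a macro, not the verbatim definition:

#define INSERT_ATENOP_PATTERN(AtenOp)                                          \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);

Under that reading, the single new INSERT_ATENOP_PATTERN(AtenArgmaxOp) line is all that is needed to route aten.argmax into ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite during partial conversion.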

@@ -42,6 +42,8 @@ class VerifyTosaBackendContractPass
     target.addDynamicallyLegalOp<ModuleOp, FuncOp, ReturnOp>(opHasLegalTypes);
     // Basic scalar operations.
     target.addLegalDialect<tosa::TosaDialect>();
+    target.addDynamicallyLegalOp<tensor::CastOp>(opHasLegalTypes);
+    target.addDynamicallyLegalOp<arith::ExtSIOp>(opHasLegalTypes);

     RewritePatternSet patterns(context);
     if (failed(applyFullConversion(module, target, std::move(patterns)))) {
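
The two new dynamically legal ops close the backend contract around this lowering: tosa.argmax yields i32 indices, the pattern widens them with arith.extsi (tosa.cast cannot produce i64), so the verifier must accept arith.extsi once its types are backend-legal. The opHasLegalTypes predicate is defined elsewhere in this file and not shown; a plausible sketch of such a type-based predicate, assuming MLIR's TypeConverter::isLegal, would be:

#include "mlir/IR/Operation.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Hypothetical sketch, not the file's actual helper: an op satisfies the
// contract once every operand and result type converts to a legal type.
static bool opHasLegalTypesSketch(Operation *op, TypeConverter &converter) {
  auto typeIsLegal = [&](Type t) { return converter.isLegal(t); };
  return llvm::all_of(op->getOperandTypes(), typeIsLegal) &&
         llvm::all_of(op->getResultTypes(), typeIsLegal);
}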