[tosa] Implement Argmax support (#485)

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
pull/486/head
Suraj Sudhir 2021-12-15 11:01:01 -08:00 committed by GitHub
parent d13bb0e5c1
commit 829cf8afc3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 96 additions and 1 deletions

View File

@ -12,6 +12,8 @@
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
@ -404,6 +406,93 @@ public:
}
};
template <>
LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
AtenArgmaxOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.self();
auto selfTy = self.getType().template cast<RankedTensorType>();
if (!selfTy)
return op.emitError("Only ranked tensor types supported in TOSA argmax");
int64_t reduceDim;
if (!matchPattern(op.dim(), m_TorchConstantInt(&reduceDim))) {
// NoneType indicates reduce on all dims
reduceDim = -1;
}
bool keepDim = false;
if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim)))
return rewriter.notifyMatchFailure(
op, "non-const keepdim parameter unsupported");
auto resultTy = getTypeConverter()
->convertType(op.getResult().getType())
.cast<RankedTensorType>();
auto outputETy = resultTy.getElementType();
// Create a single instance of tosa.argmax.
// Multiple dims require chained construct.
auto buildArgmax = [&](int64_t reduceDim, Value input) -> Value {
auto inputTy = input.getType().cast<RankedTensorType>();
auto inputShape = inputTy.getShape();
SmallVector<int64_t> outputShapeArr = {};
int32_t i = 0;
for (auto &dim : inputShape) {
if (i++ != reduceDim) {
outputShapeArr.push_back(dim);
} else {
if (keepDim)
outputShapeArr.push_back(1);
}
}
// Tosa argmax output is i32, while Torch backend mandates i64.
auto outputReduceTy = RankedTensorType::get(
ArrayRef<int64_t>(outputShapeArr), rewriter.getI32Type());
auto reduceDimAttr =
rewriter.getIntegerAttr(rewriter.getI64Type(), reduceDim);
return rewriter
.create<tosa::ArgMaxOp>(op->getLoc(),
getTypeConverter()->convertType(outputReduceTy),
input, reduceDimAttr)
.getResult();
};
// Convert the final index to i64 for backend finalization, However, i64
// is not a defined type for tosa.cast, so using arith.extsi instead.
auto castToInt64 = [&](Value result) -> LogicalResult {
auto resTy = result.getType().cast<ShapedType>();
if (!resTy)
return op.emitError("Argmax: Result is not a shaped type");
auto resShape = resTy.getShape();
auto outTy =
RankedTensorType::get(resShape, outputETy); // rewriter.getI64Type());
rewriter.replaceOpWithNewOp<arith::ExtSIOp>(
op, getTypeConverter()->convertType(outTy), result);
return success();
};
if (reduceDim == -1) { // reducing on all dims
Value input = self;
for (int dim = 0; dim < selfTy.getRank(); dim++) {
// progressively reduce each 0-th dim
input = buildArgmax(0, input);
}
return castToInt64(input);
} else {
return castToInt64(buildArgmax(reduceDim, self));
}
return success();
}
} // namespace
// -----------------------------------------------------------------------------
@ -415,13 +504,16 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tosa::TosaDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithmeticDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<tosa::TosaDialect>();
target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
arith::ArithmeticDialect>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
@ -491,6 +583,7 @@ public:
INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenMulTensorOp);
INSERT_ATENOP_PATTERN(AtenDivTensorOp);
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
#undef INSERT_ATENOP_PATTERN
if (failed(applyPartialConversion(getOperation(), target,

View File

@ -42,6 +42,8 @@ class VerifyTosaBackendContractPass
target.addDynamicallyLegalOp<ModuleOp, FuncOp, ReturnOp>(opHasLegalTypes);
// Basic scalar operations.
target.addLegalDialect<tosa::TosaDialect>();
target.addDynamicallyLegalOp<tensor::CastOp>(opHasLegalTypes);
target.addDynamicallyLegalOp<arith::ExtSIOp>(opHasLegalTypes);
RewritePatternSet patterns(context);
if (failed(applyFullConversion(module, target, std::move(patterns)))) {