[TORCH] Add decomposition of `aten.linear` op

This commit adds a decomposition of the `aten.linear` op into `aten.t`,
`aten.matmul`, and `aten.add` ops. Since the TOSA backend has limited
support for dynamic dimensions, this decomposition is currently disabled
for TOSA by keeping `torch.aten.linear` backend-legal there.
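
For reference, the rewrite relies on the standard definition of `aten.linear`.
A minimal PyTorch sketch of the equivalence the decomposition uses
(illustrative only, not part of this commit):

import torch

# aten.linear(input, weight, bias) computes input @ weight.T + bias, where
# weight has shape (out_features, in_features) and bias (out_features,).
input = torch.randn(4, 3)    # (N, in_features)
weight = torch.randn(5, 3)   # (out_features, in_features)
bias = torch.randn(5)        # (out_features,)

expected = torch.nn.functional.linear(input, weight, bias)
# Mirrors the emitted op sequence: aten.t -> aten.matmul -> aten.add.Tensor.
decomposed = torch.matmul(input, weight.t()) + bias

assert torch.allclose(expected, decomposed)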

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/1350/head
Gaurav Shukla 2022-05-18 21:59:04 +05:30
parent cc86cc0f02
commit 99093d0623
3 changed files with 67 additions and 173 deletions


@@ -296,6 +296,21 @@ public:
          op, "unable to perform broadcast operation");
    }
    if (maxRank == 3) {
      Value zeroTensor = createZeroInitTensor(
          rewriter, loc,
          ValueRange{broadcastedBatchShape[0], lhsDim0, rhsDim1},
          elementType);
      Value matmul =
          rewriter
              .create<linalg::BatchMatmulOp>(
                  loc, zeroTensor.getType(),
                  ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor)
              .getResult(0);
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
      return success();
    }
    // Check if the result of the matrix multiplication has more than one
    // dynamic batch dimension.
    ArrayRef<int64_t> batchDimsInt = resultType.getShape().drop_back(2);
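
The added rank-3 path zero-initializes the result and then accumulates a
batched matrix product into it; that accumulate-into-output behavior of
`linalg.batch_matmul` is why `createZeroInitTensor` is needed. A rough
PyTorch analogue (illustrative only, static shapes assumed):

import torch

# linalg.batch_matmul semantics: out[b] += lhs[b] @ rhs[b], so the output
# buffer must be zero-initialized first.
lhs = torch.randn(2, 4, 3)   # (batch, lhsDim0, K) after broadcasting
rhs = torch.randn(2, 3, 5)   # (batch, K, rhsDim1) after broadcasting

out = torch.zeros(2, 4, 5)   # the zero init tensor
out += torch.bmm(lhs, rhs)   # accumulate the batched product

assert torch.allclose(out, torch.bmm(lhs, rhs))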
@@ -454,176 +469,6 @@ public:
};
} // namespace
namespace {
// See comments in convertMmOp and the heading for this section for general
// considerations. This function needs to be auto-generated.
class ConvertAtenLinearOp : public OpConversionPattern<AtenLinearOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenLinearOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *context = op->getContext();
    Location loc = op->getLoc();
    Value input = adaptor.input();
    Value weight = adaptor.weight();
    Value bias = adaptor.bias();
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    auto inputType = input.getType().cast<RankedTensorType>();
    auto weightType = weight.getType().cast<RankedTensorType>();
    if (inputType.getRank() != 2 && inputType.getRank() != 3) {
      return rewriter.notifyMatchFailure(
          op, "expected input to be rank 2 or rank 3");
    }
    if (!bias.getType().isa<Torch::NoneType>()) {
      auto biasType = bias.getType().cast<RankedTensorType>();
      // Only handle the case of rank 2 `weight` for now.
      // TODO: Insert the appropriate reshape to collapse any leading
      // dimensions.
      if (weightType.getRank() != 2 || biasType.getRank() != 1) {
        return rewriter.notifyMatchFailure(
            op, "expected weight to be rank 2 and bias to be rank 1");
      }
      // TODO: Handle type promotion. What are ATen's promotion rules?
      if (inputType.getElementType() != weightType.getElementType() ||
          inputType.getElementType() != biasType.getElementType()) {
        return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
      }
      // TODO: We can handle a static size 1 here at some complexity cost, but
      // the dynamic case is not representable in linalg. We don't handle
      // either for now. Biases are generally statically shaped for most
      // models (since for inference they are constants, and for training they
      // don't change shape typically), so this is not too constraining.
      auto biasSize = bias.getType().cast<RankedTensorType>().getShape()[0];
      if (biasSize == 1 || biasSize == ShapedType::kDynamicSize)
        return rewriter.notifyMatchFailure(
            op, "unimplemented: size-1 broadcasting for aten::LinearOp");
    }
    Value batchDim = nullptr;
    int restDim = 0;
    if (inputType.getRank() == 3) {
      batchDim = getDimOp(rewriter, loc, input, 0);
      restDim = 1;
    }
    Value inputDim0 = getDimOp(rewriter, loc, input, restDim + 0);
    Value inputDim1 = getDimOp(rewriter, loc, input, restDim + 1);
    Value weightDim0 = getDimOp(rewriter, loc, weight, 0);
    Value weightDim1 = getDimOp(rewriter, loc, weight, 1);
    Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, inputDim1, weightDim1);
    rewriter.create<cf::AssertOp>(
        loc, contractingDimEqual,
        rewriter.getStringAttr(
            "mismatching contracting dimension for aten.linear"));
    if (!bias.getType().isa<Torch::NoneType>()) {
      Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
      // Here we take advantage of ruling out the size-1 case above.
      // In the static-size-1 case, we will not emit this check at all.
      Value biasSizeCorrect = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, weightDim0, biasDim0);
      rewriter.create<cf::AssertOp>(
          loc, biasSizeCorrect,
          rewriter.getStringAttr("mismatching bias size for aten.linear"));
    }
    Value initTensor;
    SmallVector<AffineMap> broadcastIndexingMaps;
    Value transposedWeightInitTensor;
    if (inputType.getRank() > 2) {
      initTensor = rewriter.create<linalg::InitTensorOp>(
          loc, ValueRange{batchDim, inputDim0, weightDim0},
          inputType.getElementType());
      transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
          loc, ValueRange{batchDim, weightDim1, weightDim0},
          weightType.getElementType());
      broadcastIndexingMaps = {
          AffineMap::get(
              /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
              {rewriter.getAffineDimExpr(1 + restDim)}, context),
          rewriter.getMultiDimIdentityMap(inputType.getRank())};
    } else {
      initTensor = rewriter.create<linalg::InitTensorOp>(
          loc, ValueRange{inputDim0, weightDim0}, inputType.getElementType());
      transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
          loc, ValueRange{weightDim1, weightDim0},
          weightType.getElementType());
      broadcastIndexingMaps = {
          AffineMap::get(
              /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
              {rewriter.getAffineDimExpr(1)}, context),
          rewriter.getMultiDimIdentityMap(inputType.getRank())};
    }
    SmallVector<StringRef> iteratorTypes(inputType.getRank(), "parallel");
    Value broadcasted;
    if (!bias.getType().isa<Torch::NoneType>()) {
      broadcasted =
          rewriter
              .create<linalg::GenericOp>(
                  loc, initTensor.getType(), bias, initTensor,
                  /*indexingMaps=*/broadcastIndexingMaps,
                  /*iteratorTypes=*/iteratorTypes,
                  [](OpBuilder &b, Location loc, ValueRange args) {
                    b.create<linalg::YieldOp>(loc, args[0]);
                  })
              .getResult(0);
    } else {
      Type elementType =
          input.getType().cast<RankedTensorType>().getElementType();
      Value c0float = rewriter.create<arith::ConstantOp>(
          loc, FloatAttr::get(elementType, 0.0));
      broadcasted = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
                        .getResult(0);
    }
    // We need a matmul with dimension ordering (N, K) * (M, K), so transpose
    // the weights to fit into linalg::MatmulOp which is (N, K) * (K, M).
    // TODO: This whole aten.linear lowering should eventually be generated
    // from a single linalg ODS generator statement. Both the bias and matmul
    // part.
    SmallVector<AffineMap> transposeIndexingMaps = {
        AffineMap::get(
            /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(1 + restDim),
             rewriter.getAffineDimExpr(0 + restDim)},
            context),
        rewriter.getMultiDimIdentityMap(inputType.getRank())};
    Value transposedWeights =
        rewriter
            .create<linalg::GenericOp>(
                loc, transposedWeightInitTensor.getType(), weight,
                transposedWeightInitTensor,
                /*indexingMaps=*/transposeIndexingMaps,
                /*iteratorTypes=*/iteratorTypes,
                [](OpBuilder &b, Location loc, ValueRange args) {
                  b.create<linalg::YieldOp>(loc, args[0]);
                })
            .getResult(0);
    Value matmul;
    if (batchDim)
      matmul = rewriter
                   .create<linalg::BatchMatmulOp>(
                       loc, broadcasted.getType(),
                       ValueRange{input, transposedWeights}, broadcasted)
                   .getResult(0);
    else
      matmul = rewriter
                   .create<linalg::MatmulOp>(
                       loc, broadcasted.getType(),
                       ValueRange{input, transposedWeights}, broadcasted)
                   .getResult(0);
    Type newResultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
    return success();
  }
};
} // namespace
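
For context, the pattern removed above materialized the bias-add by
broadcasting the bias into the matmul's output buffer and letting
`linalg.matmul` accumulate on top of it. Roughly, in PyTorch terms
(illustrative only):

import torch

input = torch.randn(4, 3)    # (N, K)
weight = torch.randn(5, 3)   # (M, K), aten.linear's weight layout
bias = torch.randn(5)        # (M,)

# Step 1: broadcast bias across rows into the init tensor (the
# linalg.generic with broadcastIndexingMaps above).
out = bias.expand(4, 5).clone()
# Step 2: transpose (M, K) -> (K, M) so linalg.matmul's (N, K) * (K, M)
# contraction applies (the transpose linalg.generic above).
w_t = weight.t()
# Step 3: linalg.matmul accumulates into its output operand.
out += input @ w_t

assert torch.allclose(out, torch.nn.functional.linear(input, weight, bias))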
namespace {
class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
public:
@@ -996,8 +841,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
  patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
  target.addIllegalOp<AtenBmmOp>();
  patterns.add<ConvertAtenBmmOp>(typeConverter, context);
  target.addIllegalOp<AtenLinearOp>();
  patterns.add<ConvertAtenLinearOp>(typeConverter, context);
  target.addIllegalOp<AtenConvolutionOp>();
  patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
}


@@ -2060,6 +2060,55 @@ public:
};
} // namespace
namespace {
// Decompose `aten.linear` op into `aten.matmul` and `aten.add` ops.
class DecomposeAtenLinearOp : public OpRewritePattern<AtenLinearOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenLinearOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.input();
    Value weight = op.weight();
    Value bias = op.bias();
    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
    if (!inputType.hasSizes() || inputType.getSizes().size() < 2)
      return rewriter.notifyMatchFailure(
          op, "expected input to be rank 2 or greater");
    BaseTensorType weightType = weight.getType().cast<BaseTensorType>();
    // `weight` must be a rank-2 matrix.
    if (!weightType.hasSizes() || weightType.getSizes().size() != 2)
      return rewriter.notifyMatchFailure(op, "expected weight to be rank 2");
    SmallVector<int64_t> transposeShape =
        llvm::to_vector(llvm::reverse(weightType.getSizes()));
    Type transposeType = weightType.getWithSizesAndDtype(
        llvm::makeArrayRef(transposeShape), weightType.getDtype());
    Value transposeWeight =
        rewriter.create<AtenTOp>(loc, transposeType, weight);
    Value matmul = rewriter.create<AtenMatmulOp>(loc, op.getType(), input,
                                                 transposeWeight);
    if (bias.getType().isa<Torch::NoneType>()) {
      rewriter.replaceOp(op, matmul);
      return success();
    }
    BaseTensorType biasType = bias.getType().cast<BaseTensorType>();
    if (!biasType.hasSizes() || biasType.getSizes().size() != 1)
      return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
    Value alpha =
        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
    rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), matmul,
                                                 op.bias(), alpha);
    return success();
  }
};
} // namespace
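
One practical difference from the removed linalg pattern: this decomposition
accepts any input of rank >= 2, since `aten.matmul` broadcasts leading batch
dimensions, and `aten.add.Tensor` computes `self + alpha * other` with
`alpha = 1` here. A hedged PyTorch mirror of the rewrite (illustrative only):

import torch

input = torch.randn(2, 3, 4, 6)   # rank-4 input, in_features = 6
weight = torch.randn(5, 6)        # (out_features, in_features)
bias = torch.randn(5)

# aten.t -> aten.matmul -> aten.add.Tensor(self, other, alpha=1).
alpha = 1
decomposed = torch.matmul(input, weight.t()) + alpha * bias

assert torch.allclose(decomposed,
                      torch.nn.functional.linear(input, weight, bias))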
namespace {
// Decompose `aten.full_like` op into `aten.empty_like` and `aten.fill` ops.
class DecomposeAtenFullLikeOp : public OpRewritePattern<AtenFullLikeOp> {
@@ -2837,6 +2886,8 @@ public:
  target.addIllegalOp<AtenHardtanhOp>();
  patterns.add<DecomposeAtenFullOp>(context);
  target.addIllegalOp<AtenFullOp>();
  patterns.add<DecomposeAtenLinearOp>(context);
  target.addIllegalOp<AtenLinearOp>();
  patterns.add<DecomposeAtenFullLikeOp>(context);
  target.addIllegalOp<AtenFullLikeOp>();
  patterns.add<DecomposeAtenIndexPutOp>(context);


@@ -126,7 +126,7 @@ class TensorPlaceholder:
# ops in the backend contract, and move these lists somewhere deeper in the
# compiler where each backend can "own" its set of legal ops.
BACKEND_LEGAL_OPS = {
    OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm'],
    OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
    OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints'],
    OutputType.MHLO: [],
}
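
Marking `torch.aten.linear` backend-legal for TOSA means the new
DecomposeAtenLinearOp pattern is skipped on that path, so the op reaches the
TOSA lowering intact. A sketch of how this surfaces through the
`torch_mlir.compile` entry point defined in this file (assuming its signature
at this revision):

import torch
import torch_mlir

model = torch.nn.Linear(3, 5)
example = torch.ones(4, 3)

# For OutputType.TOSA, BACKEND_LEGAL_OPS keeps torch.aten.linear legal, so
# the intermediate Torch IR should still contain torch.aten.linear rather
# than the decomposed transpose/matmul/add sequence.
module = torch_mlir.compile(model, example,
                            output_type=torch_mlir.OutputType.TOSA)
print(module)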