mirror of https://github.com/llvm/torch-mlir
[TORCH] Add decomposition of `aten.linear` op
This commit adds a decomposition of the `aten.linear` op. Due to limited support for dynamic dimensions in the TOSA backend, this decomposition is currently disabled for that backend.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>

pull/1350/head
parent cc86cc0f02
commit 99093d0623
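For reference, here is a hand-written sketch (not part of the diff below) of the rewrite this decomposition performs at the Torch dialect level, assuming a [2,4] input, an [8,4] weight, and an [8] bias; value names and types are illustrative:

// Before: a single aten.linear op.
%out = torch.aten.linear %input, %weight, %bias
    : !torch.vtensor<[2,4],f32>, !torch.vtensor<[8,4],f32>,
      !torch.vtensor<[8],f32> -> !torch.vtensor<[2,8],f32>

// After DecomposeAtenLinearOp: transpose the weight, matmul, then add bias.
%wt = torch.aten.t %weight
    : !torch.vtensor<[8,4],f32> -> !torch.vtensor<[4,8],f32>
%mm = torch.aten.matmul %input, %wt
    : !torch.vtensor<[2,4],f32>, !torch.vtensor<[4,8],f32>
      -> !torch.vtensor<[2,8],f32>
%alpha = torch.constant.float 1.000000e+00
%out = torch.aten.add.Tensor %mm, %bias, %alpha
    : !torch.vtensor<[2,8],f32>, !torch.vtensor<[8],f32>, !torch.float
      -> !torch.vtensor<[2,8],f32>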
@@ -296,6 +296,21 @@ public:
             op, "unable to perform broadcast operation");
       }
 
+      if (maxRank == 3) {
+        Value zeroTensor = createZeroInitTensor(
+            rewriter, loc,
+            ValueRange{broadcastedBatchShape[0], lhsDim0, rhsDim1},
+            elementType);
+        Value matmul =
+            rewriter
+                .create<linalg::BatchMatmulOp>(
+                    loc, zeroTensor.getType(),
+                    ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor)
+                .getResult(0);
+        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
+        return success();
+      }
+
       // Check if the result of the matrix multiplication has more than one
       // dynamic batch dimensions.
       ArrayRef<int64_t> batchDimsInt = resultType.getShape().drop_back(2);

@@ -454,176 +469,6 @@ public:
 };
 } // namespace
 
-namespace {
-// See comments in convertMmOp and the heading for this section for general
-// considerations. This function needs to be auto-generated.
-class ConvertAtenLinearOp : public OpConversionPattern<AtenLinearOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenLinearOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    MLIRContext *context = op->getContext();
-    Location loc = op->getLoc();
-    Value input = adaptor.input();
-    Value weight = adaptor.weight();
-    Value bias = adaptor.bias();
-    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
-      return failure();
-    auto inputType = input.getType().cast<RankedTensorType>();
-    auto weightType = weight.getType().cast<RankedTensorType>();
-
-    if (inputType.getRank() != 2 && inputType.getRank() != 3) {
-      return rewriter.notifyMatchFailure(
-          op, "expected input to be rank 2 or rank 3");
-    }
-
-    if (!bias.getType().isa<Torch::NoneType>()) {
-      auto biasType = bias.getType().cast<RankedTensorType>();
-      // Only handle the case of rank 2 `weight` for now.
-      // TODO: Insert the appropriate reshape to collapse any leading dimensions.
-      if (weightType.getRank() != 2 || biasType.getRank() != 1) {
-        return rewriter.notifyMatchFailure(
-            op, "expected weight to be rank 2 and bias to be rank 1");
-      }
-      // TODO: Handle type promotion. What are ATen's promotion rules?
-      if (inputType.getElementType() != weightType.getElementType() ||
-          inputType.getElementType() != biasType.getElementType()) {
-        return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
-      }
-      // TODO: We can handle a static size 1 here at some complexity cost, but the
-      // dynamic case is not representable in linalg. We don't handle either for
-      // now. Biases are generally statically shaped for most models (since for
-      // inference they are constants, and for training they don't change shape
-      // typically), so this is not too constraining.
-      auto biasSize = bias.getType().cast<RankedTensorType>().getShape()[0];
-      if (biasSize == 1 || biasSize == ShapedType::kDynamicSize)
-        return rewriter.notifyMatchFailure(
-            op, "unimplemented: size-1 broadcasting for aten::LinearOp");
-    }
-
-
-    Value batchDim = nullptr;
-    int restDim = 0;
-    if (inputType.getRank() == 3) {
-      batchDim = getDimOp(rewriter, loc, input, 0);
-      restDim = 1;
-    }
-
-    Value inputDim0 = getDimOp(rewriter, loc, input, restDim + 0);
-    Value inputDim1 = getDimOp(rewriter, loc, input, restDim + 1);
-    Value weightDim0 = getDimOp(rewriter, loc, weight, 0);
-    Value weightDim1 = getDimOp(rewriter, loc, weight, 1);
-    Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::eq, inputDim1, weightDim1);
-    rewriter.create<cf::AssertOp>(
-        loc, contractingDimEqual,
-        rewriter.getStringAttr(
-            "mismatching contracting dimension for aten.linear"));
-
-    if (!bias.getType().isa<Torch::NoneType>()) {
-      Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
-      // Here we take advantage of ruling out the size-1 case above.
-      // In the static-size-1 case, we will not emit this check at all.
-      Value biasSizeCorrect = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::eq, weightDim0, biasDim0);
-      rewriter.create<cf::AssertOp>(
-          loc, biasSizeCorrect,
-          rewriter.getStringAttr("mismatching bias size for aten.linear"));
-    }
-
-    Value initTensor;
-    SmallVector<AffineMap> broadcastIndexingMaps;
-    Value transposedWeightInitTensor;
-    if (inputType.getRank() > 2) {
-      initTensor = rewriter.create<linalg::InitTensorOp>(
-          loc, ValueRange{batchDim, inputDim0, weightDim0},
-          inputType.getElementType());
-      transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
-          loc, ValueRange{batchDim, weightDim1, weightDim0},
-          weightType.getElementType());
-      broadcastIndexingMaps = {
-          AffineMap::get(
-              /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
-              {rewriter.getAffineDimExpr(1 + restDim)}, context),
-          rewriter.getMultiDimIdentityMap(inputType.getRank())};
-    } else {
-      initTensor = rewriter.create<linalg::InitTensorOp>(
-          loc, ValueRange{inputDim0, weightDim0},
-          inputType.getElementType());
-      transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
-          loc, ValueRange{weightDim1, weightDim0}, weightType.getElementType());
-      broadcastIndexingMaps = {
-          AffineMap::get(
-              /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
-              {rewriter.getAffineDimExpr(1)}, context),
-          rewriter.getMultiDimIdentityMap(inputType.getRank())};
-    }
-
-    SmallVector<StringRef> iteratorTypes(inputType.getRank(), "parallel");
-    Value broadcasted;
-    if (!bias.getType().isa<Torch::NoneType>()) {
-      broadcasted =
-          rewriter
-              .create<linalg::GenericOp>(
-                  loc, initTensor.getType(), bias, initTensor,
-                  /*indexingMaps=*/broadcastIndexingMaps,
-                  /*iteratorTypes=*/iteratorTypes,
-                  [](OpBuilder &b, Location loc, ValueRange args) {
-                    b.create<linalg::YieldOp>(loc, args[0]);
-                  })
-              .getResult(0);
-    } else {
-      Type elementType =
-          input.getType().cast<RankedTensorType>().getElementType();
-      Value c0float = rewriter.create<arith::ConstantOp>(
-          loc, FloatAttr::get(elementType, 0.0));
-      broadcasted = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
-                        .getResult(0);
-    }
-    // We need a matmul with dimension ordering (N, K) * (M, K), so transpose
-    // the weights to fit into linalg::MatmulOp which is (N, K) * (K, M).
-    // TODO: This whole aten.linear lowering should eventually be generated from
-    // a single linalg ODS generator statement. Both the bias and matmul part.
-    SmallVector<AffineMap> transposeIndexingMaps = {
-        AffineMap::get(
-            /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
-            {rewriter.getAffineDimExpr(1 + restDim),
-             rewriter.getAffineDimExpr(0 + restDim)},
-            context),
-        rewriter.getMultiDimIdentityMap(inputType.getRank())};
-    Value transposedWeights =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, transposedWeightInitTensor.getType(), weight,
-                transposedWeightInitTensor,
-                /*indexingMaps=*/transposeIndexingMaps,
-                /*iteratorTypes=*/iteratorTypes,
-                [](OpBuilder &b, Location loc, ValueRange args) {
-                  b.create<linalg::YieldOp>(loc, args[0]);
-                })
-            .getResult(0);
-    Value matmul;
-    if (batchDim)
-      matmul = rewriter
-                   .create<linalg::BatchMatmulOp>(
-                       loc, broadcasted.getType(),
-                       ValueRange{input, transposedWeights}, broadcasted)
-                   .getResult(0);
-    else
-      matmul = rewriter
-                   .create<linalg::MatmulOp>(
-                       loc, broadcasted.getType(),
-                       ValueRange{input, transposedWeights}, broadcasted)
-                   .getResult(0);
-
-    Type newResultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
 public:

@@ -996,8 +841,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
   patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
   target.addIllegalOp<AtenBmmOp>();
   patterns.add<ConvertAtenBmmOp>(typeConverter, context);
-  target.addIllegalOp<AtenLinearOp>();
-  patterns.add<ConvertAtenLinearOp>(typeConverter, context);
   target.addIllegalOp<AtenConvolutionOp>();
   patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
 }

@@ -2060,6 +2060,55 @@ public:
 };
 } // namespace
 
+namespace {
+// Decompose `aten.linear` op into `aten.matmul` and `aten.add` ops.
+class DecomposeAtenLinearOp : public OpRewritePattern<AtenLinearOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenLinearOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = op.input();
+    Value weight = op.weight();
+    Value bias = op.bias();
+
+    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+    if (!inputType.hasSizes() || inputType.getSizes().size() < 2)
+      return rewriter.notifyMatchFailure(
+          op, "expected input to be rank 2 or greater");
+
+    BaseTensorType weightType = weight.getType().cast<BaseTensorType>();
+    // `weight` must be a rank-2 matrix.
+    if (!weightType.hasSizes() || weightType.getSizes().size() != 2)
+      return rewriter.notifyMatchFailure(op, "expected weight to be rank 2");
+
+    SmallVector<int64_t> transposeShape =
+        llvm::to_vector(llvm::reverse(weightType.getSizes()));
+    Type transposeType = weightType.getWithSizesAndDtype(
+        llvm::makeArrayRef(transposeShape), weightType.getDtype());
+    Value transposeWeight =
+        rewriter.create<AtenTOp>(loc, transposeType, weight);
+
+    Value matmul = rewriter.create<AtenMatmulOp>(loc, op.getType(), input,
+                                                 transposeWeight);
+    if (bias.getType().isa<Torch::NoneType>()) {
+      rewriter.replaceOp(op, matmul);
+      return success();
+    }
+
+    BaseTensorType biasType = bias.getType().cast<BaseTensorType>();
+    if (!biasType.hasSizes() || biasType.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
+
+    Value alpha =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
+    rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), matmul,
+                                                 op.bias(), alpha);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.full_like` op into `aten.empty_like` and `aten.fill` ops.
 class DecomposeAtenFullLikeOp : public OpRewritePattern<AtenFullLikeOp> {

@@ -2837,6 +2886,8 @@ public:
   target.addIllegalOp<AtenHardtanhOp>();
   patterns.add<DecomposeAtenFullOp>(context);
   target.addIllegalOp<AtenFullOp>();
+  patterns.add<DecomposeAtenLinearOp>(context);
+  target.addIllegalOp<AtenLinearOp>();
   patterns.add<DecomposeAtenFullLikeOp>(context);
   target.addIllegalOp<AtenFullLikeOp>();
   patterns.add<DecomposeAtenIndexPutOp>(context);

@@ -126,7 +126,7 @@ class TensorPlaceholder:
 # ops in the backend contract, and move these lists somewhere deeper in the
 # compiler where each backend can "own" its set of legal ops.
 BACKEND_LEGAL_OPS = {
-    OutputType.TOSA: ['torch.aten.flatten.using_ints','torch.aten.native_layer_norm'],
+    OutputType.TOSA: ['torch.aten.flatten.using_ints','torch.aten.native_layer_norm','torch.aten.linear'],
     OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints',],
     OutputType.MHLO: [],
 }
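For reference, a minimal usage sketch (not part of the diff; it assumes the `torch_mlir.compile` Python API of this era) showing how the TOSA entry in BACKEND_LEGAL_OPS above keeps `torch.aten.linear` intact when targeting TOSA:

import torch
import torch_mlir

class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 8)

    def forward(self, x):
        return self.fc(x)

# With output_type=TOSA, torch.aten.linear is backend-legal, so it reaches
# the TOSA lowering whole instead of being decomposed into t/matmul/add.
module = torch_mlir.compile(TinyLinear(), torch.ones(2, 4),
                            output_type=torch_mlir.OutputType.TOSA)
print(module)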