mirror of https://github.com/llvm/torch-mlir
461 lines
20 KiB
C++
461 lines
20 KiB
C++
//===----------------------------------------------------------------------===//
|
||
//
|
||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
// See https://llvm.org/LICENSE.txt for license information.
|
||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
//
|
||
//===----------------------------------------------------------------------===//
|
||
|
||
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||
|
||
#include "../PassDetail.h"
|
||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||
#include "mlir/Dialect/Math/IR/Math.h"
|
||
#include "mlir/Dialect/MemRef/IR/MemRef.h" // TODO: For `memref.dim`.
|
||
#include "mlir/Dialect/Traits.h"
|
||
#include "mlir/Transforms/DialectConversion.h"
|
||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||
#include "npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h"
|
||
|
||
using namespace mlir;
|
||
using namespace mlir::NPCOMP;
|
||
using namespace mlir::NPCOMP::Torch;
|
||
|
||
// -----------------------------------------------------------------------------
|
||
// Patterns (as this grows, it should be organized into multiple files)
|
||
// -----------------------------------------------------------------------------
|
||
// This is going to eventually be O(#aten ops), which is in the 100s.
|
||
//
|
||
// Most of these patterns consist of:
|
||
// 1. Checking that the operand/result types and other static properties are
|
||
// good-enough to create a valid linalg op (such as operands being of
|
||
// ranks/dtypes acceptable to the linalg op).
|
||
// 2. Creating dynamic error guards, usually checking a predicate on the
|
||
// compatibility of operand shapes.
|
||
// 3. Creating init tensors for the computation op. Usually this involves
|
||
// reifying IR for a shape transfer function based on the operand shapes.
|
||
// 4. Creating a named linalg op to replace the original op.
|
||
//
|
||
// TODO: Use linalg OpDSL to autogenerate at least 1)/2)/3) such
|
||
// that these patterns become mostly mechanical associations of
|
||
// "aten.foo -> linalg.foo".
|
||
|
||
/// Checks that every operand and result type of `op` can be handled by the
/// linalg lowerings in this file.
///
/// A type is accepted when it is not a `ValueTensorType` at all, or when it is
/// a `ValueTensorType` whose builtin tensor equivalent is a ranked tensor.
/// On rejection, a match failure is reported through `rewriter` and returned.
///
/// TODO: Remove this check but use a separate verification pass to verify the
/// invariants expected by later passes.
static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
                                                 PatternRewriter &rewriter) {
  auto isValidLinalgType = [](Type type) -> bool {
    auto valueTensor = type.dyn_cast<ValueTensorType>();
    // Non-tensor types are not this check's concern.
    if (!valueTensor)
      return true;
    // `toBuiltinTensor()` may yield a null type, hence `dyn_cast_or_null`.
    return static_cast<bool>(
        valueTensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>());
  };

  if (llvm::all_of(op->getOperandTypes(), isValidLinalgType) &&
      llvm::all_of(op->getResultTypes(), isValidLinalgType))
    return success();
  return rewriter.notifyMatchFailure(op, "type cannot be lowered to linalg");
}
|
||
|
||
namespace {
|
||
class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
|
||
public:
|
||
using OpConversionPattern::OpConversionPattern;
|
||
LogicalResult
|
||
matchAndRewrite(AtenBatchNormOp op, ArrayRef<Value> operands,
|
||
ConversionPatternRewriter &rewriter) const override {
|
||
AtenBatchNormOp::Adaptor adaptor(operands);
|
||
MLIRContext *context = op->getContext();
|
||
Location loc = op->getLoc();
|
||
Value input = adaptor.input();
|
||
Value weight = adaptor.weight();
|
||
Value bias = adaptor.bias();
|
||
Value runningMean = adaptor.running_mean();
|
||
Value runningVar = adaptor.running_var();
|
||
Value training = adaptor.training();
|
||
Value eps = adaptor.eps();
|
||
|
||
// TODO: Handle the None cases for the optional parameters:
|
||
// weight, bias, running_mean, running_var.
|
||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||
return failure();
|
||
|
||
auto inputType = input.getType().cast<RankedTensorType>();
|
||
auto weightType = weight.getType().cast<RankedTensorType>();
|
||
auto biasType = bias.getType().cast<RankedTensorType>();
|
||
auto runningMeanType = runningMean.getType().cast<RankedTensorType>();
|
||
auto runningVarType = runningVar.getType().cast<RankedTensorType>();
|
||
|
||
auto inputRank = inputType.getRank();
|
||
if (inputRank <= 2)
|
||
return rewriter.notifyMatchFailure(
|
||
op, "input should have rank larger than 2");
|
||
|
||
if (weightType.getRank() != 1 || biasType.getRank() != 1 ||
|
||
runningMeanType.getRank() != 1 || runningVarType.getRank() != 1) {
|
||
return rewriter.notifyMatchFailure(
|
||
op, "expect weight, bias, running_mean and running_var to be rank 1");
|
||
}
|
||
|
||
// TODO: Add support for training.
|
||
auto constFalse = rewriter.create<ConstantOp>(
|
||
loc, IntegerAttr::get(IntegerType::get(context, 1), 0));
|
||
auto trainingFalse =
|
||
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, training, constFalse);
|
||
rewriter.create<AssertOp>(
|
||
loc, trainingFalse,
|
||
rewriter.getStringAttr("training is not supported for now"));
|
||
|
||
// num_features – C from an expected input of size (N,C,D,H,W ...)
|
||
Value numFeatures = rewriter.create<memref::DimOp>(loc, input, 1);
|
||
auto contractingDim0EqualsNumFeatures = [&](Value v) {
|
||
auto dim0 = rewriter.create<memref::DimOp>(loc, v, 0);
|
||
auto dim0Equal =
|
||
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, numFeatures, dim0);
|
||
rewriter.create<AssertOp>(
|
||
loc, dim0Equal,
|
||
rewriter.getStringAttr(
|
||
"expect the size of dim 0 equal to the number of features"));
|
||
};
|
||
contractingDim0EqualsNumFeatures(weight);
|
||
contractingDim0EqualsNumFeatures(bias);
|
||
contractingDim0EqualsNumFeatures(runningMean);
|
||
contractingDim0EqualsNumFeatures(runningVar);
|
||
|
||
auto indexingMap = AffineMap::get(
|
||
/*dimCount=*/inputRank,
|
||
/*symbolCount=*/0, rewriter.getAffineDimExpr(1), context);
|
||
SmallVector<AffineMap> indexingMaps = {
|
||
rewriter.getMultiDimIdentityMap(inputRank), // input
|
||
indexingMap, // weight
|
||
indexingMap, // bias
|
||
indexingMap, // runningMean
|
||
indexingMap, // runningVar
|
||
rewriter.getMultiDimIdentityMap(inputRank), // output
|
||
};
|
||
SmallVector<StringRef> iteratorTypes(inputRank, "parallel");
|
||
Value batchNorm =
|
||
rewriter
|
||
.create<linalg::GenericOp>(
|
||
loc, input.getType(),
|
||
ValueRange{input, weight, bias, runningMean, runningVar}, input,
|
||
/*indexingMaps=*/indexingMaps,
|
||
/*iteratorTypes=*/iteratorTypes,
|
||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||
Value input = args[0], weight = args[1], bias = args[2],
|
||
mean = args[3], var = args[4];
|
||
// ((input - mean) / sqrt(var + eps)) * weight + bias
|
||
Value inputSubMean = b.create<SubFOp>(loc, input, mean);
|
||
// The eps is always f64.
|
||
Value truncatedEps =
|
||
b.create<FPTruncOp>(loc, var.getType(), eps);
|
||
Value varPlusEps = b.create<AddFOp>(loc, var, truncatedEps);
|
||
Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
|
||
Value temp = b.create<MulFOp>(loc, inputSubMean, rSTD);
|
||
Value timesWeight = b.create<MulFOp>(loc, temp, weight);
|
||
Value plusBias = b.create<AddFOp>(loc, timesWeight, bias);
|
||
b.create<linalg::YieldOp>(loc, plusBias);
|
||
})
|
||
.getResult(0);
|
||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, batchNorm);
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
namespace {
|
||
class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
|
||
public:
|
||
using OpConversionPattern::OpConversionPattern;
|
||
LogicalResult
|
||
matchAndRewrite(AtenMmOp op, ArrayRef<Value> operands,
|
||
ConversionPatternRewriter &rewriter) const override {
|
||
Location loc = op->getLoc();
|
||
Value lhs = operands[0];
|
||
Value rhs = operands[1];
|
||
|
||
// A user can write an errorneous program where `aten.mm` is in fact called
|
||
// with operands of invalid rank or dtype. We cannot convert to linalg in
|
||
// this case or we will get a verifier error, which corresponds to breaking
|
||
// of *internal* compiler invariants, and for a user manifests as a compiler
|
||
// crash in the worst case (such as we try to canonicalize/fold/print the
|
||
// invalid op before the verifier gets to see it -- also release builds of a
|
||
// mature copmiler usually have the verifier turned off for compile time
|
||
// reasons).
|
||
//
|
||
// The compiler cannot crash even if the user wrote an erroneous program!
|
||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||
return failure();
|
||
if (lhs.getType().cast<RankedTensorType>().getRank() != 2 ||
|
||
rhs.getType().cast<RankedTensorType>().getRank() != 2) {
|
||
return rewriter.notifyMatchFailure(
|
||
op, "expected both operands to aten.mm to be rank 2");
|
||
}
|
||
|
||
Value lhsDim0 = rewriter.create<memref::DimOp>(loc, lhs, 0);
|
||
Value lhsDim1 = rewriter.create<memref::DimOp>(loc, lhs, 1);
|
||
Value rhsDim0 = rewriter.create<memref::DimOp>(loc, rhs, 0);
|
||
Value rhsDim1 = rewriter.create<memref::DimOp>(loc, rhs, 1);
|
||
Value contractingDimEqual =
|
||
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsDim1, rhsDim0);
|
||
rewriter.create<AssertOp>(
|
||
loc, contractingDimEqual,
|
||
rewriter.getStringAttr(
|
||
"mismatching contracting dimension for torch.aten.mm"));
|
||
|
||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||
Type elementType = newResultType.cast<TensorType>().getElementType();
|
||
Value initTensor = rewriter.create<linalg::InitTensorOp>(
|
||
loc, ValueRange{lhsDim0, rhsDim1}, elementType);
|
||
Value c0 =
|
||
rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
|
||
Value zeroFill =
|
||
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
|
||
Value matmul = rewriter
|
||
.create<linalg::MatmulOp>(loc, zeroFill.getType(),
|
||
ValueRange{lhs, rhs}, zeroFill)
|
||
.getResult(0);
|
||
// When constructed with just dynamic sizes, InitTensorOp will have a result
|
||
// type which has all `?`'s for dimensions, which might not be the result
|
||
// type of `op`. The constraints on later linalg ops means that the result
|
||
// of the MatmulOp will have this type too. So cast it to the desired type
|
||
// so that in the end we have the original result type.
|
||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
|
||
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
namespace {
|
||
// See comments at in convertMmOp and the heading for this section for general
|
||
// considerations. This function needs to be auto-generated.
|
||
class ConvertAtenLinearOp : public OpConversionPattern<AtenLinearOp> {
|
||
public:
|
||
using OpConversionPattern::OpConversionPattern;
|
||
LogicalResult
|
||
matchAndRewrite(AtenLinearOp op, ArrayRef<Value> operands,
|
||
ConversionPatternRewriter &rewriter) const override {
|
||
AtenLinearOp::Adaptor adaptor(operands);
|
||
MLIRContext *context = op->getContext();
|
||
Location loc = op->getLoc();
|
||
Value input = adaptor.input();
|
||
Value weight = adaptor.weight();
|
||
Value bias = adaptor.bias();
|
||
// TODO: Handle the case of bias being None (bias is optional).
|
||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||
return failure();
|
||
auto inputType = input.getType().cast<RankedTensorType>();
|
||
auto weightType = weight.getType().cast<RankedTensorType>();
|
||
auto biasType = bias.getType().cast<RankedTensorType>();
|
||
// Only handle the case of rank 2 `input` for now.
|
||
// TODO: Insert the appropriate reshape to collapse any leading dimensions.
|
||
if (inputType.getRank() != 2 || weightType.getRank() != 2 ||
|
||
biasType.getRank() != 1) {
|
||
return rewriter.notifyMatchFailure(
|
||
op,
|
||
"expected both input and weight to be rank 2 and bias to be rank 1");
|
||
}
|
||
// TODO: Handle type promotion. What are ATen's promotion rules?
|
||
if (inputType.getElementType() != weightType.getElementType() ||
|
||
inputType.getElementType() != biasType.getElementType()) {
|
||
return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
|
||
}
|
||
|
||
// TODO: We can handle a static size 1 here at some complexity cost, but the
|
||
// dynamic case is not representable in linalg. We don't handle either for
|
||
// now. Biases are generally statically shaped for most models (since for
|
||
// inference they are constants, and for training they don't change shape
|
||
// typically), so this is not too constraining.
|
||
auto biasSize = bias.getType().cast<RankedTensorType>().getShape()[0];
|
||
if (biasSize == 1 || biasSize == ShapedType::kDynamicSize)
|
||
return rewriter.notifyMatchFailure(
|
||
op, "unimplemented: size-1 broadcasting for aten::LinearOp");
|
||
|
||
auto getDimOp = [&](Value v, int dimension) {
|
||
return rewriter.create<memref::DimOp>(loc, v, dimension);
|
||
};
|
||
Value inputDim0 = getDimOp(input, 0);
|
||
Value inputDim1 = getDimOp(input, 1);
|
||
Value weightDim0 = getDimOp(weight, 0);
|
||
Value weightDim1 = getDimOp(weight, 1);
|
||
Value biasDim0 = getDimOp(bias, 0);
|
||
Value contractingDimEqual =
|
||
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, inputDim1, weightDim1);
|
||
rewriter.create<AssertOp>(
|
||
loc, contractingDimEqual,
|
||
rewriter.getStringAttr(
|
||
"mismatching contracting dimension for aten.linear"));
|
||
// Here we take advantage of ruling out the size-1 case above.
|
||
// In the static-size-1 case, we will not emit this check at all.
|
||
Value biasSizeCorrect =
|
||
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, weightDim0, biasDim0);
|
||
rewriter.create<AssertOp>(
|
||
loc, biasSizeCorrect,
|
||
rewriter.getStringAttr("mismatching bias size for aten.linear"));
|
||
|
||
Value initTensor = rewriter.create<linalg::InitTensorOp>(
|
||
loc, ValueRange{inputDim0, weightDim0}, inputType.getElementType());
|
||
SmallVector<AffineMap> broadcastIndexingMaps = {
|
||
AffineMap::get(
|
||
/*dimCount=*/2, /*symbolCount=*/0, rewriter.getAffineDimExpr(1)),
|
||
rewriter.getMultiDimIdentityMap(2)};
|
||
SmallVector<StringRef> iteratorTypes(2, "parallel");
|
||
Value broadcasted =
|
||
rewriter
|
||
.create<linalg::GenericOp>(
|
||
loc, initTensor.getType(), bias, initTensor,
|
||
/*indexingMaps=*/broadcastIndexingMaps,
|
||
/*iteratorTypes=*/iteratorTypes,
|
||
[](OpBuilder &b, Location loc, ValueRange args) {
|
||
b.create<linalg::YieldOp>(loc, args[0]);
|
||
})
|
||
.getResult(0);
|
||
// We need a matmul with dimension ordering (N, K) * (M, K), so transpose
|
||
// the weights to fit into linalg::MatmulOp which is (N, K) * (K, M).
|
||
// TODO: This whole aten.linear lowering should eventually be generated from
|
||
// a single linalg ODS generator statement. Both the bias and matmul part.
|
||
SmallVector<AffineMap> transposeIndexingMaps = {
|
||
AffineMap::get(
|
||
/*dimCount=*/2, /*symbolCount=*/0,
|
||
{rewriter.getAffineDimExpr(1), rewriter.getAffineDimExpr(0)},
|
||
context),
|
||
rewriter.getMultiDimIdentityMap(2)};
|
||
Value transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
|
||
loc, ValueRange{weightDim1, weightDim0}, weightType.getElementType());
|
||
Value transposedWeights =
|
||
rewriter
|
||
.create<linalg::GenericOp>(
|
||
loc, transposedWeightInitTensor.getType(), weight,
|
||
transposedWeightInitTensor,
|
||
/*indexingMaps=*/transposeIndexingMaps,
|
||
/*iteratorTypes=*/iteratorTypes,
|
||
[](OpBuilder &b, Location loc, ValueRange args) {
|
||
b.create<linalg::YieldOp>(loc, args[0]);
|
||
})
|
||
.getResult(0);
|
||
Value matmul = rewriter
|
||
.create<linalg::MatmulOp>(
|
||
loc, broadcasted.getType(),
|
||
ValueRange{input, transposedWeights}, broadcasted)
|
||
.getResult(0);
|
||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
/// Builds the scalar computation of relu for `args[0]`: a compare against a
/// zero constant followed by a select between the value and that zero.
///
/// The compare uses the `UGT` predicate and selects the input itself when it
/// holds. Only floating-point element types are supported; this is asserted.
static Value createScalarRelu(OpBuilder &b, Location loc, ValueRange args) {
  Value scalar = args[0];
  Type scalarType = scalar.getType();
  // TODO: Add support for integer types.
  assert(scalarType.isa<::mlir::FloatType>() &&
         "Only support float case for relu");

  Value zero = b.create<ConstantOp>(loc, FloatAttr::get(scalarType, 0.0));
  Value isAboveZero =
      b.create<CmpFOp>(loc, CmpFPredicate::UGT, scalar, zero);
  return b.create<SelectOp>(loc, isAboveZero, scalar, zero);
}
|
||
|
||
namespace {
|
||
|
||
// Converts a unary op. There is no implicit broadcasting behavior, so these can
|
||
// be trivially lowered to linalg.
|
||
// TODO: For binary ops, we will need a "linalg.generic-like" op that models
|
||
// N-ary broadcasting and allows us to do multiversioning techniques for
|
||
// lowering to linalg. We can trivially handle this as through that
|
||
// abstraction instead.
|
||
struct ConvertUnaryOp : ConversionPattern {
|
||
ConvertUnaryOp(TypeConverter &typeConverter, MLIRContext *context)
|
||
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
||
context) {}
|
||
|
||
LogicalResult
|
||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||
ConversionPatternRewriter &rewriter) const override {
|
||
if (!isa<AtenTanhOp>(op) && !isa<AtenReluOp>(op))
|
||
return rewriter.notifyMatchFailure(op, "not a unary op");
|
||
|
||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||
return failure();
|
||
|
||
Value operand = operands[0];
|
||
auto type = getTypeConverter()
|
||
->convertType(op->getResult(0).getType())
|
||
.cast<RankedTensorType>();
|
||
auto rank = type.getRank();
|
||
|
||
SmallVector<StringRef> iteratorTypes(rank, "parallel");
|
||
SmallVector<AffineMap> indexingMaps = {
|
||
rewriter.getMultiDimIdentityMap(rank),
|
||
rewriter.getMultiDimIdentityMap(rank)};
|
||
|
||
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
|
||
op, type, operand, operand,
|
||
/*indexingMaps=*/indexingMaps,
|
||
/*iteratorTypes=*/iteratorTypes,
|
||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||
Value result;
|
||
if (isa<AtenTanhOp>(op))
|
||
result = b.create<math::TanhOp>(loc, args[0]);
|
||
else if (isa<AtenReluOp>(op))
|
||
result = createScalarRelu(b, loc, args);
|
||
|
||
b.create<linalg::YieldOp>(loc, result);
|
||
});
|
||
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
// -----------------------------------------------------------------------------
|
||
// The pass
|
||
// -----------------------------------------------------------------------------
|
||
|
||
namespace {
|
||
class ConvertTorchToLinalg
|
||
: public ConvertTorchToLinalgBase<ConvertTorchToLinalg> {
|
||
public:
|
||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||
registry.insert<linalg::LinalgDialect>();
|
||
registry.insert<memref::MemRefDialect>();
|
||
registry.insert<math::MathDialect>();
|
||
registry.insert<StandardOpsDialect>();
|
||
registry.insert<tensor::TensorDialect>();
|
||
}
|
||
|
||
void runOnOperation() override {
|
||
MLIRContext *context = &getContext();
|
||
ConversionTarget target(*context);
|
||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||
memref::MemRefDialect, math::MathDialect,
|
||
tensor::TensorDialect>();
|
||
|
||
TypeConverter typeConverter;
|
||
typeConverter.addConversion([](Type type) { return type; });
|
||
setupBackendTypeConversion(target, typeConverter);
|
||
|
||
RewritePatternSet patterns(context);
|
||
target.addIllegalOp<AtenMmOp>();
|
||
patterns.add<ConvertAtenMmOp>(typeConverter, context);
|
||
target.addIllegalOp<AtenLinearOp>();
|
||
patterns.add<ConvertAtenLinearOp>(typeConverter, context);
|
||
target.addIllegalOp<AtenBatchNormOp>();
|
||
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
|
||
target.addIllegalOp<AtenTanhOp>();
|
||
patterns.add<ConvertUnaryOp>(typeConverter, context);
|
||
if (failed(applyPartialConversion(getOperation(), target,
|
||
std::move(patterns))))
|
||
return signalPassFailure();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
/// Creates a new instance of the ConvertTorchToLinalg pass; ownership is
/// transferred to the caller.
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createConvertTorchToLinalgPass() {
  return std::make_unique<ConvertTorchToLinalg>();
}
|