mirror of https://github.com/llvm/torch-mlir
E2e support for layernorm.
parent: b01f579687
commit: 98ba255288
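What the new lowering computes: aten::layer_norm normalizes over the trailing dimensions given by normalized_shape and then applies the affine weight/bias, i.e. ((input - mean) / sqrt(var + eps)) * weight + bias, matching the payload helper added below. A minimal eager PyTorch sketch of that semantics (illustrative only, not part of the patch):

import torch

def layer_norm_reference(x, normalized_shape, weight, bias, eps=1e-5):
    # Mean/var are taken over the trailing normalized_shape dims; the leading
    # dims (the "meanAndVarShape" in the conversion below) stay parallel.
    dims = list(range(x.dim() - len(normalized_shape), x.dim()))
    mean = x.mean(dim=dims, keepdim=True)
    var = x.var(dim=dims, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps) * weight + bias

x = torch.rand(2, 5, 2, 2, 3)
w, b = torch.rand(2, 2, 3), torch.rand(2, 2, 3)
expected = torch.nn.functional.layer_norm(x, [2, 2, 3], w, b)
assert torch.allclose(layer_norm_reference(x, [2, 2, 3], w, b), expected, atol=1e-5)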
@@ -9,6 +9,7 @@ from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export


# ==============================================================================
class BatchNorm1DModule(torch.nn.Module):
    def __init__(self):
@@ -17,8 +18,10 @@ class BatchNorm1DModule(torch.nn.Module):
        self.bn1d.eval()
        self.bn1d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6])
        self.bn1d.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0])
        self.bn1d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 5.0]))
        self.bn1d.weight = torch.nn.Parameter(
            torch.tensor([3.0, 2.0, 4.0, 5.0]))
        self.bn1d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6]))

    @export
    @annotate_args([
        None,
@@ -27,10 +30,12 @@ class BatchNorm1DModule(torch.nn.Module):
    def forward(self, x):
        return self.bn1d(x)


@register_test_case(module_factory=lambda: BatchNorm1DModule())
def BatchNorm1DModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(10, 4, 3))


# ==============================================================================
class BatchNorm2DModule(torch.nn.Module):
    def __init__(self):
@@ -41,6 +46,7 @@ class BatchNorm2DModule(torch.nn.Module):
        self.bn2d.running_var = torch.tensor([3.0, 2.0])
        self.bn2d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0]))
        self.bn2d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4]))

    @export
    @annotate_args([
        None,
@@ -49,10 +55,12 @@ class BatchNorm2DModule(torch.nn.Module):
    def forward(self, x):
        return self.bn2d(x)


@register_test_case(module_factory=lambda: BatchNorm2DModule())
def BatchNorm2DModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(10, 2, 3, 3))


# ==============================================================================
class BatchNorm3DModule(torch.nn.Module):
    def __init__(self):
@@ -61,8 +69,11 @@ class BatchNorm3DModule(torch.nn.Module):
        self.bn3d.eval()
        self.bn3d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4])
        self.bn3d.running_var = torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0])
        self.bn3d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0]))
        self.bn3d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4]))
        self.bn3d.weight = torch.nn.Parameter(
            torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0]))
        self.bn3d.bias = torch.nn.Parameter(
            torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4]))

    @export
    @annotate_args([
        None,
@@ -71,6 +82,83 @@ class BatchNorm3DModule(torch.nn.Module):
    def forward(self, x):
        return self.bn3d(x)


@register_test_case(module_factory=lambda: BatchNorm3DModule())
def BatchNorm3DModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 5, 3, 6, 4))


# ==============================================================================
class LayerNormModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ly = torch.nn.LayerNorm([2, 2, 3])
        self.ly.eval()
        self.ly.weight = torch.nn.Parameter(
            torch.tensor([[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]],
                          [[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]]]))
        self.ly.bias = torch.nn.Parameter(
            torch.tensor([[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]],
                          [[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]]]))

    @export
    @annotate_args([
        None,
        ([2, 5, 2, 2, 3], torch.float32, True),
    ])
    def forward(self, x):
        return self.ly(x)


@register_test_case(module_factory=lambda: LayerNormModule())
def LayerNormModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 5, 2, 2, 3))


# ==============================================================================
class LayerNormLastDimModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ly = torch.nn.LayerNorm([3])
        self.ly.eval()
        self.ly.weight = torch.nn.Parameter(torch.tensor([2.0, 3.0, 2.0]))
        self.ly.bias = torch.nn.Parameter(torch.tensor([0.2, 0.4, 0.3]))

    @export
    @annotate_args([
        None,
        ([2, 5, 2, 2, 3], torch.float32, True),
    ])
    def forward(self, x):
        return self.ly(x)


@register_test_case(module_factory=lambda: LayerNormLastDimModule())
def LayerNormLastDimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 5, 2, 2, 3))

# ==============================================================================
class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ly = torch.nn.LayerNorm([2, 2, 3])
        self.ly.eval()
        self.ly.weight = torch.nn.Parameter(
            torch.tensor([[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]],
                          [[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]]]))
        self.ly.bias = torch.nn.Parameter(
            torch.tensor([[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]],
                          [[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]]]))

    @export
    @annotate_args([
        None,
        ([2, 2, 3], torch.float32, True),
    ])
    def forward(self, x):
        return self.ly(x)


@register_test_case(module_factory=lambda: LayerNormNormalizeOverAllDimsModule())
def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 2, 3))
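Each @register_test_case entry is exercised end to end: the e2e harness runs the same forward() call both through the torch-mlir compilation path and through eager PyTorch and compares the results. A hypothetical eager-only sanity check, not part of the suite:

m = LayerNormModule()
x = torch.rand(2, 5, 2, 2, 3)
print(m.forward(x).shape)  # torch.Size([2, 5, 2, 2, 3])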
@@ -899,6 +899,25 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
  let assemblyFormat = "$input `,` $weight `,` $bias `,` $running_mean `,` $running_var `,` $training `,` $momentum `,` $eps `,` $cudnn_enabled attr-dict `:` type($input) `,` type($weight) `,` type($bias) `,` type($running_mean) `,` type($running_var) `,` type($training) `,` type($momentum) `,` type($eps) `,` type($cudnn_enabled) `->` type($result)";
}

def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$input,
    TorchIntListType:$normalized_shape,
    AnyTorchOptionalTensorType:$weight,
    AnyTorchOptionalTensorType:$bias,
    Torch_FloatType:$eps,
    Torch_BoolType:$cudnn_enable
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps `,` $cudnn_enable attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `,` type($cudnn_enable) `->` type($result)";
}

def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
    AllowsTypeRefinement,
    HasValueSemantics
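The generated op mirrors the ATen signature registered in torch_ods_gen.py further down in this commit. For orientation, the same operand order in an eager-mode call (illustrative only, not part of the patch):

import torch

x, w, b = torch.rand(2, 4, 3), torch.rand(3), torch.rand(3)
# aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)
y = torch.ops.aten.layer_norm(x, [3], w, b, 1e-5, False)
print(y.shape)  # torch.Size([2, 4, 3])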
@@ -62,6 +62,15 @@ static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
  return success();
}

static LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op,
                                  Value v) {
  Type type = v.getType();
  if (type.isa<OptionalType>() || type.isa<Torch::NoneType>() ||
      type.isa<mlir::NoneType>())
    return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
  return success();
}

// Hack to deal with the Torch list type arguments which is not supported end
// to end. Constant values can be extracted directly and non constant
// list values are not supported.
@@ -96,23 +105,40 @@ static Value getDimOp(OpBuilder &b, Location loc, Value v, int dimension) {
  return b.create<tensor::DimOp>(loc, v, dimension);
}

static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDimIndex,
                                Value rhsDimIndex) {
  Value lhsDimInt = castIndexToInt(b, loc, lhsDimIndex);
  Value rhsDimInt = castIndexToInt(b, loc, rhsDimIndex);
static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
                                Value rhsDim) {
  Type lhsType = lhsDim.getType();
  Type rhsType = rhsDim.getType();
  auto checkIntOrIndex = [](Type type) {
    assert(type.isa<IntegerType>() ||
           type.isa<IndexType>() && "must be either integer or index type");
  };
  checkIntOrIndex(lhsType);
  checkIntOrIndex(rhsType);
  Value lhsDimInt = lhsType.isIndex() ? castIndexToInt(b, loc, lhsDim) : lhsDim;
  Value rhsDimInt = rhsType.isIndex() ? castIndexToInt(b, loc, rhsDim) : rhsDim;
  Value contractingDimEqual =
      b.create<CmpIOp>(loc, CmpIPredicate::eq, lhsDimInt, rhsDimInt);
  b.create<AssertOp>(loc, contractingDimEqual,
                     b.getStringAttr("mismatching contracting dimension"));
}

static SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
                                                 Value tensor, int dim) {
  RankedTensorType type = tensor.getType().cast<RankedTensorType>();
  assert(dim < type.getRank() &&
         "The given dim must be smaller than tensor rank");
  (void)type;
  SmallVector<Value> sizes;
  for (int i = 0; i <= dim; i++)
    sizes.push_back(getDimOp(b, loc, tensor, i));
  return sizes;
}

static SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc,
                                         Value tensor) {
  RankedTensorType type = tensor.getType().cast<RankedTensorType>();
  SmallVector<Value> sizes;
  for (int i = 0; i < type.getRank(); i++)
    sizes.push_back(getDimOp(b, loc, tensor, i));
  return sizes;
  return getTensorSizesUntilDim(b, loc, tensor, type.getRank() - 1);
}

static Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
@@ -173,6 +199,19 @@ getAsOpFoldResult(OpBuilder &b, Location loc, SmallVectorImpl<int64_t> &ints) {
      ints, [&](int64_t val) -> OpFoldResult { return b.getIndexAttr(val); }));
}

// This is a temporary solution to deal with types that are not fully supported
// like list, dict. For those container types, this helper can be used to
// convert their elements to valid target type.
// TODO: remove this when list gets full support.
static SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
                                                 TypeConverter *converter,
                                                 SmallVectorImpl<Value> &vs) {
  return llvm::to_vector<4>(llvm::map_range(vs, [&](Value v) {
    return converter->materializeTargetConversion(
        b, loc, converter->convertType(v.getType()), v);
  }));
}

// Helper function to get the padding tensor given the padding int values.
// It's assumed that the padding on the low end and high end are the same.
static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
@@ -192,6 +231,14 @@ static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
  return paddedInput;
}

static bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
  auto listConstruct = v.getDefiningOp<PrimListConstructOp>();
  if (!listConstruct)
    return false;
  elems = llvm::to_vector<4>(listConstruct.elements());
  return true;
}

namespace {
class ConvertAtenAdaptiveAvgPool2dOp
    : public OpConversionPattern<AtenAdaptiveAvgPool2dOp> {
@@ -393,6 +440,22 @@ public:
};
} // namespace

// Normalization formula:
//   ((input - mean) / sqrt(var + eps)) * weight + bias
static Value createLinalgPayloadCalculationForNormOps(
    OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var,
    Value eps, Value weight, Value bias) {
  Value inputSubMean = b.create<SubFOp>(loc, input, mean);
  // The eps is always f64.
  Value truncatedEps = b.create<FPTruncOp>(loc, elemTy, eps);
  Value varPlusEps = b.create<AddFOp>(loc, var, truncatedEps);
  Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
  Value temp = b.create<MulFOp>(loc, inputSubMean, rSTD);
  Value timesWeight = b.create<MulFOp>(loc, temp, weight);
  Value plusBias = b.create<AddFOp>(loc, timesWeight, bias);
  return plusBias;
}

namespace {
class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
public:
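The same payload is now shared by the batch_norm and the new layer_norm lowerings. A quick eager-mode check of the formula against PyTorch's inference-mode batch_norm (hypothetical, not part of the patch):

import torch

x = torch.rand(4, 3, 8)
mean, var = torch.rand(3), torch.rand(3) + 1.0
w, b, eps = torch.rand(3), torch.rand(3), 1e-5
ref = torch.nn.functional.batch_norm(x, mean, var, w, b, training=False, eps=eps)
mine = (x - mean[:, None]) / torch.sqrt(var[:, None] + eps) * w[:, None] + b[:, None]
assert torch.allclose(ref, mine, atol=1e-5)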
@@ -411,11 +474,17 @@ public:
    Value training = adaptor.training();
    Value eps = adaptor.eps();

    // TODO: Handle the None cases for the optional parameters:
    // weight, bias, running_mean, running_var.
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    // TODO: Handle the None cases for the optional parameters:
    // weight, bias.
    if (failed(checkNotNone(rewriter, op, weight)) ||
        failed(checkNotNone(rewriter, op, bias)) ||
        failed(checkNotNone(rewriter, op, runningMean)) ||
        failed(checkNotNone(rewriter, op, runningVar)))
      return failure();

    auto inputType = input.getType().cast<RankedTensorType>();
    auto weightType = weight.getType().cast<RankedTensorType>();
    auto biasType = bias.getType().cast<RankedTensorType>();
@@ -480,17 +549,10 @@ public:
            [&](OpBuilder &b, Location loc, ValueRange args) {
              Value input = args[0], weight = args[1], bias = args[2],
                    mean = args[3], var = args[4];
              // ((input - mean) / sqrt(var + eps)) * weight + bias
              Value inputSubMean = b.create<SubFOp>(loc, input, mean);
              // The eps is always f64.
              Value truncatedEps =
                  b.create<FPTruncOp>(loc, var.getType(), eps);
              Value varPlusEps = b.create<AddFOp>(loc, var, truncatedEps);
              Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
              Value temp = b.create<MulFOp>(loc, inputSubMean, rSTD);
              Value timesWeight = b.create<MulFOp>(loc, temp, weight);
              Value plusBias = b.create<AddFOp>(loc, timesWeight, bias);
              b.create<linalg::YieldOp>(loc, plusBias);
              Value result = createLinalgPayloadCalculationForNormOps(
                  b, loc, var.getType(), input, mean, var, eps, weight,
                  bias);
              b.create<linalg::YieldOp>(loc, result);
            })
            .getResult(0);
    Type newResultType = getTypeConverter()->convertType(op.getType());
@@ -500,6 +562,228 @@ public:
};
} // namespace

// For layernorm, the mean and standard-deviation are calculated separately over
// the last certain number dimensions which have to be of the shape specified by
// normalized_shape.
//
// The shapes of different parts are as the following:
// +-------------------+--------------------+
// |  meanAndVarShape  |   normalizedShape  |
// +-------------------+---------------------
// <------------+ inputShape +-------------->
//
// There are the following steps:
// Step 1. Check if all the arguments meet the requirements.
// Step 2. Common parts to be used for getting mean and var.
//         This includes elements count, affineMap and iteratorTypes.
// Step 3. Get mean.
// Step 4. Get var.
// Step 5. Get layernorm.
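Read as pseudocode, steps 2-5 amount to a per-slice mean/variance reduction followed by an elementwise normalization. A hypothetical PyTorch sketch of the same computation (not part of the patch), using the e2e test's shapes:

import torch

x = torch.rand(2, 5, 2, 2, 3)        # inputShape
norm_dims = (2, 3, 4)                # dims covered by normalized_shape [2, 2, 3]
elem_cnt = 2 * 2 * 3                 # Step 2: element count per normalized slice

mean = x.sum(dim=norm_dims, keepdim=True) / elem_cnt                  # Step 3
var = ((x - mean) ** 2).sum(dim=norm_dims, keepdim=True) / elem_cnt   # Step 4
weight, bias, eps = torch.ones(2, 2, 3), torch.zeros(2, 2, 3), 1e-5
out = (x - mean) / torch.sqrt(var + eps) * weight + bias              # Step 5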
namespace {
class ConvertAtenLayerNormOp : public OpConversionPattern<AtenLayerNormOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenLayerNormOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    AtenLayerNormOp::Adaptor adaptor(operands);
    MLIRContext *context = op->getContext();
    Location loc = op->getLoc();
    Value input = adaptor.input();
    Value weight = adaptor.weight();
    Value bias = adaptor.bias();
    Value eps = adaptor.eps();
    Value normalizedShape = op.normalized_shape();

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    // TODO: Handle the None cases for the optional parameters:
    // weight, bias.
    if (failed(checkNotNone(rewriter, op, weight)) ||
        failed(checkNotNone(rewriter, op, bias)))
      return failure();

    auto inputType = input.getType().cast<RankedTensorType>();
    auto weightType = weight.getType().cast<RankedTensorType>();
    auto biasType = bias.getType().cast<RankedTensorType>();
    int64_t inputRank = inputType.getRank();
    Type elemTy = inputType.getElementType();

    // Step 1. Check if all the arguments meet the requirements.
    SmallVector<Value> normalizedShapeSizesTorchInt;
    if (!getListConstructElements(normalizedShape,
                                  normalizedShapeSizesTorchInt)) {
      return rewriter.notifyMatchFailure(op,
                                         "Unimplemented normalized_shape not"
                                         "constructed from ListConstruct");
    }
    SmallVector<Value> normalizedShapeSizesInt = getTypeConvertedValues(
        rewriter, loc, getTypeConverter(), normalizedShapeSizesTorchInt);
    int64_t normalizedShapeRank = normalizedShapeSizesInt.size();
    if (weightType.getRank() != normalizedShapeRank ||
        biasType.getRank() != normalizedShapeRank ||
        inputRank < normalizedShapeRank || normalizedShapeRank < 1)
      return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or"
                                             "normalized shape not compatible");

    // Check all the dimensions match the normalized_shape
    int64_t meanAndVarShapeRank = inputRank - normalizedShapeSizesInt.size();
    for (auto en : enumerate((normalizedShapeSizesInt))) {
      auto index = en.index();
      auto inputDim =
          getDimOp(rewriter, loc, input, index + meanAndVarShapeRank);
      auto weightDim = getDimOp(rewriter, loc, weight, index);
      auto biasDim = getDimOp(rewriter, loc, bias, index);

      auto expectedSize = en.value();
      checkDimEqualHelper(rewriter, loc, inputDim, expectedSize);
      checkDimEqualHelper(rewriter, loc, weightDim, expectedSize);
      checkDimEqualHelper(rewriter, loc, biasDim, expectedSize);
    }

    // Get iterator types for input shape.
    SmallVector<StringRef> normalizedShapeIteratorTypes(
        normalizedShapeRank, getReductionIteratorTypeName());
    SmallVector<StringRef> meanAndVarIterationTypes(
        meanAndVarShapeRank, getParallelIteratorTypeName());
    SmallVector<StringRef> inputShapeIteratorTypes = meanAndVarIterationTypes;
    inputShapeIteratorTypes.append(normalizedShapeIteratorTypes);

    // Step 2. Common parts to be used for getting mean and var.

    // Get sizes and affineMaps needed for mean and var.
    AffineMap inputShapeAffineMap = rewriter.getMultiDimIdentityMap(inputRank);
    SmallVector<AffineExpr> meanAndVarShapeExprs;
    for (int i = 0; i < meanAndVarShapeRank; i++)
      meanAndVarShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
    auto meanAndVarShapeAffineMap = AffineMap::get(
        /*dimCount=*/inputRank,
        /*symbolCount=*/0, meanAndVarShapeExprs, context);
    SmallVector<Value> meanAndVarShapeSizes =
        getTensorSizesUntilDim(rewriter, loc, input, meanAndVarShapeRank - 1);

    // Get number of elements to be used for calculating mean and var.
    Value elemCnts = normalizedShapeSizesInt[0];
    for (int i = 1; i < normalizedShapeRank; i++) {
      elemCnts =
          rewriter.create<MulIOp>(loc, elemCnts, normalizedShapeSizesInt[i]);
    }
    Value elemCntsFloat = rewriter.create<SIToFPOp>(loc, elemTy, elemCnts);

    // Helper to calculate mean and var.
    auto genMeanOrVarCalculation = [&](Value sumOrSquareSum) {
      SmallVector<AffineMap> indexingMaps(
          2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
      Value initShapeTensor = rewriter.create<linalg::InitTensorOp>(
          loc, meanAndVarShapeSizes, elemTy);
      return rewriter
          .create<linalg::GenericOp>(
              loc, initShapeTensor.getType(), sumOrSquareSum, initShapeTensor,
              /*indexingMaps=*/indexingMaps,
              /*iteratorTypes=*/meanAndVarIterationTypes,
              [&](OpBuilder &b, Location loc, ValueRange args) {
                Value sumOrSqureSum = args[0];
                Value result =
                    b.create<DivFOp>(loc, sumOrSqureSum, elemCntsFloat);
                b.create<linalg::YieldOp>(loc, result);
              })
          .getResult(0);
    };

    // Step 3. Get mean.

    // Get sum to be used for calculating mean.
    SmallVector<AffineMap, 2> sumIndexingMaps = {
        inputShapeAffineMap,      // input
        meanAndVarShapeAffineMap, // output
    };
    auto initSumTensor =
        createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
    Value sum = rewriter
                    .create<linalg::GenericOp>(
                        loc, initSumTensor.getType(), input, initSumTensor,
                        /*indexingMaps=*/sumIndexingMaps,
                        /*iteratorTypes=*/inputShapeIteratorTypes,
                        [&](OpBuilder &b, Location loc, ValueRange args) {
                          Value input = args[0], sum = args[1];
                          Value result =
                              rewriter.create<AddFOp>(loc, sum, input);
                          b.create<linalg::YieldOp>(loc, result);
                        })
                    .getResult(0);
    Value mean = genMeanOrVarCalculation(sum);

    // Step 4. Get var.

    // Calculate squareSum for the layer.
    SmallVector<AffineMap> squareSumIndexingMaps{
        inputShapeAffineMap,
        meanAndVarShapeAffineMap,
        meanAndVarShapeAffineMap,
    };
    auto initSquareSumTensor =
        createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
    Value squareSum =
        rewriter
            .create<linalg::GenericOp>(
                loc, initSquareSumTensor.getType(), ValueRange{input, mean},
                initSquareSumTensor,
                /*indexingMaps=*/squareSumIndexingMaps,
                /*iteratorTypes=*/inputShapeIteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value input = args[0], mean = args[1], squareSum = args[2];
                  Value sub = rewriter.create<SubFOp>(loc, input, mean);
                  Value square = rewriter.create<MulFOp>(loc, sub, sub);
                  Value result =
                      rewriter.create<AddFOp>(loc, squareSum, square);
                  b.create<linalg::YieldOp>(loc, result);
                })
            .getResult(0);
    Value var = genMeanOrVarCalculation(squareSum);

    // Step 5. Get layernorm.

    // Get affineMap for normalized shape.
    SmallVector<AffineExpr> normalizedShapeExprs;
    for (int i = meanAndVarShapeRank; i < inputRank; i++)
      normalizedShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
    auto normalizedShapeAffineMap = AffineMap::get(
        /*dimCount=*/inputRank,
        /*symbolCount=*/0, normalizedShapeExprs, context);

    auto inputSizes = getTensorSizes(rewriter, loc, input);
    Value initLayerNormTensor =
        rewriter.create<linalg::InitTensorOp>(loc, inputSizes, elemTy);
    SmallVector<AffineMap> indexingMaps(1, inputShapeAffineMap);
    indexingMaps.resize(3, meanAndVarShapeAffineMap);
    indexingMaps.resize(5, normalizedShapeAffineMap);
    indexingMaps.push_back(inputShapeAffineMap);
    SmallVector<StringRef> layerNormIterationTypes(
        inputRank, getParallelIteratorTypeName());
    Value layerNorm =
        rewriter
            .create<linalg::GenericOp>(
                loc, initLayerNormTensor.getType(),
                ValueRange{input, mean, var, weight, bias}, initLayerNormTensor,
                /*indexingMaps=*/indexingMaps,
                /*iteratorTypes=*/layerNormIterationTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value input = args[0], mean = args[1], var = args[2],
                        weight = args[3], bias = args[4];
                  Value result = createLinalgPayloadCalculationForNormOps(
                      b, loc, elemTy, input, mean, var, eps, weight, bias);
                  b.create<linalg::YieldOp>(loc, result);
                })
            .getResult(0);

    Type newResultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, layerNorm);
    return success();
  }
};
} // namespace

namespace {
class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
public:
@@ -1611,6 +1895,8 @@ public:
    patterns.add<ConvertAtenCatOp>(typeConverter, context);
    target.addIllegalOp<AtenGatherOp>();
    patterns.add<ConvertAtenGatherOp>(typeConverter, context);
    target.addIllegalOp<AtenLayerNormOp>();
    patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
@@ -196,7 +196,7 @@ public:
            AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp,
            AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
            AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
            AtenCopy_Op, AtenCumsumOp>(op)) {
            AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp>(op)) {
      return getLatticeElement(op->getResult(0)).join(*operands[0]);
    }
@@ -476,6 +476,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit(
        "aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
    )
    emit(
        "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
    )
    emit(
        "aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
    )

@@ -593,6 +596,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit("aten::div : (Scalar, Scalar) -> (float)")
    emit("aten::eq.device : (Device, Device) -> (bool)")


def emit_quantized_ops(torch_ir_dir: str, registry: Registry):
    td_file = os.path.join(torch_ir_dir, "GeneratedQuantizedOps.td")
    with open(td_file, "w") as f: