E2e support for layernorm.

pull/348/head
Yi Zhang 2021-09-24 11:44:16 -04:00
parent b01f579687
commit 98ba255288
5 changed files with 422 additions and 25 deletions


@@ -9,6 +9,7 @@ from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

# ==============================================================================

class BatchNorm1DModule(torch.nn.Module):
    def __init__(self):
@@ -17,8 +18,10 @@ class BatchNorm1DModule(torch.nn.Module):
        self.bn1d.eval()
        self.bn1d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6])
        self.bn1d.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0])
        self.bn1d.weight = torch.nn.Parameter(
            torch.tensor([3.0, 2.0, 4.0, 5.0]))
        self.bn1d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6]))

    @export
    @annotate_args([
        None,
@@ -27,10 +30,12 @@ class BatchNorm1DModule(torch.nn.Module):
    def forward(self, x):
        return self.bn1d(x)

@register_test_case(module_factory=lambda: BatchNorm1DModule())
def BatchNorm1DModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(10, 4, 3))

# ==============================================================================

class BatchNorm2DModule(torch.nn.Module):
    def __init__(self):
@@ -41,6 +46,7 @@ class BatchNorm2DModule(torch.nn.Module):
        self.bn2d.running_var = torch.tensor([3.0, 2.0])
        self.bn2d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0]))
        self.bn2d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4]))

    @export
    @annotate_args([
        None,
@@ -49,10 +55,12 @@ class BatchNorm2DModule(torch.nn.Module):
    def forward(self, x):
        return self.bn2d(x)

@register_test_case(module_factory=lambda: BatchNorm2DModule())
def BatchNorm2DModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(10, 2, 3, 3))

# ==============================================================================

class BatchNorm3DModule(torch.nn.Module):
    def __init__(self):
@@ -61,8 +69,11 @@ class BatchNorm3DModule(torch.nn.Module):
        self.bn3d.eval()
        self.bn3d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4])
        self.bn3d.running_var = torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0])
        self.bn3d.weight = torch.nn.Parameter(
            torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0]))
        self.bn3d.bias = torch.nn.Parameter(
            torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4]))

    @export
    @annotate_args([
        None,
@@ -71,6 +82,83 @@ class BatchNorm3DModule(torch.nn.Module):
    def forward(self, x):
        return self.bn3d(x)

@register_test_case(module_factory=lambda: BatchNorm3DModule())
def BatchNorm3DModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 5, 3, 6, 4))

# ==============================================================================

class LayerNormModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ly = torch.nn.LayerNorm([2, 2, 3])
        self.ly.eval()
        self.ly.weight = torch.nn.Parameter(
            torch.tensor([[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]],
                          [[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]]]))
        self.ly.bias = torch.nn.Parameter(
            torch.tensor([[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]],
                          [[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]]]))

    @export
    @annotate_args([
        None,
        ([2, 5, 2, 2, 3], torch.float32, True),
    ])
    def forward(self, x):
        return self.ly(x)

@register_test_case(module_factory=lambda: LayerNormModule())
def LayerNormModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 5, 2, 2, 3))

# ==============================================================================

class LayerNormLastDimModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ly = torch.nn.LayerNorm([3])
        self.ly.eval()
        self.ly.weight = torch.nn.Parameter(torch.tensor([2.0, 3.0, 2.0]))
        self.ly.bias = torch.nn.Parameter(torch.tensor([0.2, 0.4, 0.3]))

    @export
    @annotate_args([
        None,
        ([2, 5, 2, 2, 3], torch.float32, True),
    ])
    def forward(self, x):
        return self.ly(x)

@register_test_case(module_factory=lambda: LayerNormLastDimModule())
def LayerNormLastDimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 5, 2, 2, 3))

# ==============================================================================

class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ly = torch.nn.LayerNorm([2, 2, 3])
        self.ly.eval()
        self.ly.weight = torch.nn.Parameter(
            torch.tensor([[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]],
                          [[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]]]))
        self.ly.bias = torch.nn.Parameter(
            torch.tensor([[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]],
                          [[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]]]))

    @export
    @annotate_args([
        None,
        ([2, 2, 3], torch.float32, True),
    ])
    def forward(self, x):
        return self.ly(x)

@register_test_case(module_factory=lambda: LayerNormNormalizeOverAllDimsModule())
def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 2, 3))
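The three test modules above all normalize over the trailing dimensions listed in normalized_shape. As a quick illustration (not part of this commit), the following sketch reproduces what they compute with plain tensor ops, assuming torch.nn.LayerNorm's default eps of 1e-5:

import torch

x = torch.rand(2, 5, 2, 2, 3)
normalized_shape = [2, 2, 3]
# Mean and (biased) variance are taken over the trailing normalized_shape dims.
dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
mean = x.mean(dim=dims, keepdim=True)
var = x.var(dim=dims, unbiased=False, keepdim=True)
weight = torch.ones(normalized_shape)
bias = torch.zeros(normalized_shape)
eps = 1e-5
ref = (x - mean) / torch.sqrt(var + eps) * weight + bias
expected = torch.nn.functional.layer_norm(x, normalized_shape, weight, bias, eps)
assert torch.allclose(ref, expected, atol=1e-5)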


@@ -899,6 +899,25 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
  let assemblyFormat = "$input `,` $weight `,` $bias `,` $running_mean `,` $running_var `,` $training `,` $momentum `,` $eps `,` $cudnn_enabled attr-dict `:` type($input) `,` type($weight) `,` type($bias) `,` type($running_mean) `,` type($running_var) `,` type($training) `,` type($momentum) `,` type($eps) `,` type($cudnn_enabled) `->` type($result)";
}

def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$input,
    TorchIntListType:$normalized_shape,
    AnyTorchOptionalTensorType:$weight,
    AnyTorchOptionalTensorType:$bias,
    Torch_FloatType:$eps,
    Torch_BoolType:$cudnn_enable
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps `,` $cudnn_enable attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `,` type($cudnn_enable) `->` type($result)";
}

def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
    AllowsTypeRefinement,
    HasValueSemantics
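For reference (not part of this commit), the new Torch_AtenLayerNormOp mirrors the ATen signature named in its summary. The same operand order can be seen by calling the dispatcher directly, with arbitrary example values:

import torch

x = torch.rand(2, 5, 2, 2, 3)
weight = torch.ones(2, 2, 3)
bias = torch.zeros(2, 2, 3)
# aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)
y = torch.ops.aten.layer_norm(x, [2, 2, 3], weight, bias, 1e-5, False)
assert y.shape == x.shape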


@@ -62,6 +62,15 @@ static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
  return success();
}

static LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op,
                                  Value v) {
  Type type = v.getType();
  if (type.isa<OptionalType>() || type.isa<Torch::NoneType>() ||
      type.isa<mlir::NoneType>())
    return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
  return success();
}

// Hack to deal with Torch list type arguments, which are not supported end
// to end. Constant values can be extracted directly; non-constant
// list values are not supported.
@@ -96,23 +105,40 @@ static Value getDimOp(OpBuilder &b, Location loc, Value v, int dimension) {
  return b.create<tensor::DimOp>(loc, v, dimension);
}

static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
                                Value rhsDim) {
  Type lhsType = lhsDim.getType();
  Type rhsType = rhsDim.getType();
  auto checkIntOrIndex = [](Type type) {
    assert(type.isa<IntegerType>() ||
           type.isa<IndexType>() && "must be either integer or index type");
  };
  checkIntOrIndex(lhsType);
  checkIntOrIndex(rhsType);
  Value lhsDimInt = lhsType.isIndex() ? castIndexToInt(b, loc, lhsDim) : lhsDim;
  Value rhsDimInt = rhsType.isIndex() ? castIndexToInt(b, loc, rhsDim) : rhsDim;
  Value contractingDimEqual =
      b.create<CmpIOp>(loc, CmpIPredicate::eq, lhsDimInt, rhsDimInt);
  b.create<AssertOp>(loc, contractingDimEqual,
                     b.getStringAttr("mismatching contracting dimension"));
}

static SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
                                                 Value tensor, int dim) {
  RankedTensorType type = tensor.getType().cast<RankedTensorType>();
  assert(dim < type.getRank() &&
         "The given dim must be smaller than tensor rank");
  (void)type;
  SmallVector<Value> sizes;
  for (int i = 0; i <= dim; i++)
    sizes.push_back(getDimOp(b, loc, tensor, i));
  return sizes;
}

static SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc,
                                         Value tensor) {
  RankedTensorType type = tensor.getType().cast<RankedTensorType>();
  return getTensorSizesUntilDim(b, loc, tensor, type.getRank() - 1);
}

static Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
@@ -173,6 +199,19 @@ getAsOpFoldResult(OpBuilder &b, Location loc, SmallVectorImpl<int64_t> &ints) {
      ints, [&](int64_t val) -> OpFoldResult { return b.getIndexAttr(val); }));
}

// This is a temporary solution to deal with types that are not fully supported,
// like list and dict. For those container types, this helper can be used to
// convert their elements to valid target types.
// TODO: remove this when list gets full support.
static SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
                                                 TypeConverter *converter,
                                                 SmallVectorImpl<Value> &vs) {
  return llvm::to_vector<4>(llvm::map_range(vs, [&](Value v) {
    return converter->materializeTargetConversion(
        b, loc, converter->convertType(v.getType()), v);
  }));
}

// Helper function to get the padding tensor given the padding int values.
// It's assumed that the padding on the low end and high end are the same.
static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
@@ -192,6 +231,14 @@ static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
  return paddedInput;
}

static bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
  auto listConstruct = v.getDefiningOp<PrimListConstructOp>();
  if (!listConstruct)
    return false;
  elems = llvm::to_vector<4>(listConstruct.elements());
  return true;
}

namespace {
class ConvertAtenAdaptiveAvgPool2dOp
    : public OpConversionPattern<AtenAdaptiveAvgPool2dOp> {
@@ -393,6 +440,22 @@ public:
};
} // namespace

// Normalization formula:
//   ((input - mean) / sqrt(var + eps)) * weight + bias
static Value createLinalgPayloadCalculationForNormOps(
    OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var,
    Value eps, Value weight, Value bias) {
  Value inputSubMean = b.create<SubFOp>(loc, input, mean);
  // The eps is always f64.
  Value truncatedEps = b.create<FPTruncOp>(loc, elemTy, eps);
  Value varPlusEps = b.create<AddFOp>(loc, var, truncatedEps);
  Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
  Value temp = b.create<MulFOp>(loc, inputSubMean, rSTD);
  Value timesWeight = b.create<MulFOp>(loc, temp, weight);
  Value plusBias = b.create<AddFOp>(loc, timesWeight, bias);
  return plusBias;
}

namespace {
class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
public:
@@ -411,11 +474,17 @@ public:
    Value training = adaptor.training();
    Value eps = adaptor.eps();

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    // TODO: Handle the None cases for the optional parameters:
    // weight, bias.
    if (failed(checkNotNone(rewriter, op, weight)) ||
        failed(checkNotNone(rewriter, op, bias)) ||
        failed(checkNotNone(rewriter, op, runningMean)) ||
        failed(checkNotNone(rewriter, op, runningVar)))
      return failure();

    auto inputType = input.getType().cast<RankedTensorType>();
    auto weightType = weight.getType().cast<RankedTensorType>();
    auto biasType = bias.getType().cast<RankedTensorType>();
@@ -480,17 +549,10 @@ public:
            [&](OpBuilder &b, Location loc, ValueRange args) {
              Value input = args[0], weight = args[1], bias = args[2],
                    mean = args[3], var = args[4];
              Value result = createLinalgPayloadCalculationForNormOps(
                  b, loc, var.getType(), input, mean, var, eps, weight,
                  bias);
              b.create<linalg::YieldOp>(loc, result);
            })
            .getResult(0);
    Type newResultType = getTypeConverter()->convertType(op.getType());
@@ -500,6 +562,228 @@ public:
  };
} // namespace

// For layernorm, the mean and standard deviation are calculated separately
// over the last dimensions, which must have the shape specified by
// normalized_shape.
//
// The shapes of the different parts are as follows:
// +-------------------+--------------------+
// |  meanAndVarShape  |   normalizedShape  |
// +-------------------+--------------------+
// <----------------+ inputShape +---------->
//
// The lowering proceeds in the following steps:
// Step 1. Check if all the arguments meet the requirements.
// Step 2. Build the common parts used for getting mean and var.
//         This includes the element count, affineMaps, and iteratorTypes.
// Step 3. Get mean.
// Step 4. Get var.
// Step 5. Get the layernorm result.
namespace {
class ConvertAtenLayerNormOp : public OpConversionPattern<AtenLayerNormOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenLayerNormOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    AtenLayerNormOp::Adaptor adaptor(operands);
    MLIRContext *context = op->getContext();
    Location loc = op->getLoc();
    Value input = adaptor.input();
    Value weight = adaptor.weight();
    Value bias = adaptor.bias();
    Value eps = adaptor.eps();
    Value normalizedShape = op.normalized_shape();

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    // TODO: Handle the None cases for the optional parameters:
    // weight, bias.
    if (failed(checkNotNone(rewriter, op, weight)) ||
        failed(checkNotNone(rewriter, op, bias)))
      return failure();

    auto inputType = input.getType().cast<RankedTensorType>();
    auto weightType = weight.getType().cast<RankedTensorType>();
    auto biasType = bias.getType().cast<RankedTensorType>();
    int64_t inputRank = inputType.getRank();
    Type elemTy = inputType.getElementType();

    // Step 1. Check if all the arguments meet the requirements.
    SmallVector<Value> normalizedShapeSizesTorchInt;
    if (!getListConstructElements(normalizedShape,
                                  normalizedShapeSizesTorchInt)) {
      return rewriter.notifyMatchFailure(op,
                                         "Unimplemented: normalized_shape not "
                                         "constructed from ListConstruct");
    }
    SmallVector<Value> normalizedShapeSizesInt = getTypeConvertedValues(
        rewriter, loc, getTypeConverter(), normalizedShapeSizesTorchInt);
    int64_t normalizedShapeRank = normalizedShapeSizesInt.size();
    if (weightType.getRank() != normalizedShapeRank ||
        biasType.getRank() != normalizedShapeRank ||
        inputRank < normalizedShapeRank || normalizedShapeRank < 1)
      return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or "
                                             "normalized shape not compatible");

    // Check all the dimensions match the normalized_shape.
    int64_t meanAndVarShapeRank = inputRank - normalizedShapeSizesInt.size();
    for (auto en : enumerate((normalizedShapeSizesInt))) {
      auto index = en.index();
      auto inputDim =
          getDimOp(rewriter, loc, input, index + meanAndVarShapeRank);
      auto weightDim = getDimOp(rewriter, loc, weight, index);
      auto biasDim = getDimOp(rewriter, loc, bias, index);
      auto expectedSize = en.value();
      checkDimEqualHelper(rewriter, loc, inputDim, expectedSize);
      checkDimEqualHelper(rewriter, loc, weightDim, expectedSize);
      checkDimEqualHelper(rewriter, loc, biasDim, expectedSize);
    }

    // Get iterator types for input shape.
    SmallVector<StringRef> normalizedShapeIteratorTypes(
        normalizedShapeRank, getReductionIteratorTypeName());
    SmallVector<StringRef> meanAndVarIterationTypes(
        meanAndVarShapeRank, getParallelIteratorTypeName());
    SmallVector<StringRef> inputShapeIteratorTypes = meanAndVarIterationTypes;
    inputShapeIteratorTypes.append(normalizedShapeIteratorTypes);

    // Step 2. Common parts to be used for getting mean and var.

    // Get sizes and affineMaps needed for mean and var.
    AffineMap inputShapeAffineMap = rewriter.getMultiDimIdentityMap(inputRank);
    SmallVector<AffineExpr> meanAndVarShapeExprs;
    for (int i = 0; i < meanAndVarShapeRank; i++)
      meanAndVarShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
    auto meanAndVarShapeAffineMap = AffineMap::get(
        /*dimCount=*/inputRank,
        /*symbolCount=*/0, meanAndVarShapeExprs, context);
    SmallVector<Value> meanAndVarShapeSizes =
        getTensorSizesUntilDim(rewriter, loc, input, meanAndVarShapeRank - 1);

    // Get number of elements to be used for calculating mean and var.
    Value elemCnts = normalizedShapeSizesInt[0];
    for (int i = 1; i < normalizedShapeRank; i++) {
      elemCnts =
          rewriter.create<MulIOp>(loc, elemCnts, normalizedShapeSizesInt[i]);
    }
    Value elemCntsFloat = rewriter.create<SIToFPOp>(loc, elemTy, elemCnts);

    // Helper to calculate mean and var.
    auto genMeanOrVarCalculation = [&](Value sumOrSquareSum) {
      SmallVector<AffineMap> indexingMaps(
          2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
      Value initShapeTensor = rewriter.create<linalg::InitTensorOp>(
          loc, meanAndVarShapeSizes, elemTy);
      return rewriter
          .create<linalg::GenericOp>(
              loc, initShapeTensor.getType(), sumOrSquareSum, initShapeTensor,
              /*indexingMaps=*/indexingMaps,
              /*iteratorTypes=*/meanAndVarIterationTypes,
              [&](OpBuilder &b, Location loc, ValueRange args) {
                Value sumOrSqureSum = args[0];
                Value result =
                    b.create<DivFOp>(loc, sumOrSqureSum, elemCntsFloat);
                b.create<linalg::YieldOp>(loc, result);
              })
          .getResult(0);
    };

    // Step 3. Get mean.

    // Get sum to be used for calculating mean.
    SmallVector<AffineMap, 2> sumIndexingMaps = {
        inputShapeAffineMap,      // input
        meanAndVarShapeAffineMap, // output
    };
    auto initSumTensor =
        createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
    Value sum = rewriter
                    .create<linalg::GenericOp>(
                        loc, initSumTensor.getType(), input, initSumTensor,
                        /*indexingMaps=*/sumIndexingMaps,
                        /*iteratorTypes=*/inputShapeIteratorTypes,
                        [&](OpBuilder &b, Location loc, ValueRange args) {
                          Value input = args[0], sum = args[1];
                          Value result =
                              rewriter.create<AddFOp>(loc, sum, input);
                          b.create<linalg::YieldOp>(loc, result);
                        })
                    .getResult(0);
    Value mean = genMeanOrVarCalculation(sum);

    // Step 4. Get var.

    // Calculate squareSum for the layer.
    SmallVector<AffineMap> squareSumIndexingMaps{
        inputShapeAffineMap,
        meanAndVarShapeAffineMap,
        meanAndVarShapeAffineMap,
    };
    auto initSquareSumTensor =
        createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
    Value squareSum =
        rewriter
            .create<linalg::GenericOp>(
                loc, initSquareSumTensor.getType(), ValueRange{input, mean},
                initSquareSumTensor,
                /*indexingMaps=*/squareSumIndexingMaps,
                /*iteratorTypes=*/inputShapeIteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value input = args[0], mean = args[1], squareSum = args[2];
                  Value sub = rewriter.create<SubFOp>(loc, input, mean);
                  Value square = rewriter.create<MulFOp>(loc, sub, sub);
                  Value result =
                      rewriter.create<AddFOp>(loc, squareSum, square);
                  b.create<linalg::YieldOp>(loc, result);
                })
            .getResult(0);
    Value var = genMeanOrVarCalculation(squareSum);

    // Step 5. Get layernorm.

    // Get affineMap for normalized shape.
    SmallVector<AffineExpr> normalizedShapeExprs;
    for (int i = meanAndVarShapeRank; i < inputRank; i++)
      normalizedShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
    auto normalizedShapeAffineMap = AffineMap::get(
        /*dimCount=*/inputRank,
        /*symbolCount=*/0, normalizedShapeExprs, context);
    auto inputSizes = getTensorSizes(rewriter, loc, input);
    Value initLayerNormTensor =
        rewriter.create<linalg::InitTensorOp>(loc, inputSizes, elemTy);
    SmallVector<AffineMap> indexingMaps(1, inputShapeAffineMap);
    indexingMaps.resize(3, meanAndVarShapeAffineMap);
    indexingMaps.resize(5, normalizedShapeAffineMap);
    indexingMaps.push_back(inputShapeAffineMap);
    SmallVector<StringRef> layerNormIterationTypes(
        inputRank, getParallelIteratorTypeName());
    Value layerNorm =
        rewriter
            .create<linalg::GenericOp>(
                loc, initLayerNormTensor.getType(),
                ValueRange{input, mean, var, weight, bias}, initLayerNormTensor,
                /*indexingMaps=*/indexingMaps,
                /*iteratorTypes=*/layerNormIterationTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value input = args[0], mean = args[1], var = args[2],
                        weight = args[3], bias = args[4];
                  Value result = createLinalgPayloadCalculationForNormOps(
                      b, loc, elemTy, input, mean, var, eps, weight, bias);
                  b.create<linalg::YieldOp>(loc, result);
                })
            .getResult(0);
    Type newResultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, layerNorm);
    return success();
  }
};
} // namespace

namespace {
class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
public:
@@ -1611,6 +1895,8 @@ public:
    patterns.add<ConvertAtenCatOp>(typeConverter, context);
    target.addIllegalOp<AtenGatherOp>();
    patterns.add<ConvertAtenGatherOp>(typeConverter, context);
    target.addIllegalOp<AtenLayerNormOp>();
    patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
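To summarize the new ConvertAtenLayerNormOp pattern, here is a rough Python restatement (not part of this commit) of the computation its linalg.generic ops perform. The helper name is made up for illustration, and the shapes follow the meanAndVarShape/normalizedShape split described in the pattern's comment:

import torch

def layer_norm_like_lowering(x, normalized_shape, weight, bias, eps):
    mean_var_rank = x.dim() - len(normalized_shape)
    reduce_dims = tuple(range(mean_var_rank, x.dim()))
    elem_cnt = 1
    for s in normalized_shape:
        elem_cnt *= s
    # Step 3: accumulate a sum over the normalized dims, divide by elem_cnt.
    mean = x.sum(dim=reduce_dims) / elem_cnt
    mean_b = mean.reshape(tuple(mean.shape) + (1,) * len(normalized_shape))
    # Step 4: accumulate the squared differences, divide by elem_cnt.
    var = ((x - mean_b) ** 2).sum(dim=reduce_dims) / elem_cnt
    var_b = var.reshape(tuple(var.shape) + (1,) * len(normalized_shape))
    # Step 5: the shared norm payload: ((x - mean) / sqrt(var + eps)) * w + b.
    return (x - mean_b) / torch.sqrt(var_b + eps) * weight + bias

x = torch.rand(2, 5, 2, 2, 3)
w, b = torch.ones(2, 2, 3), torch.zeros(2, 2, 3)
out = layer_norm_like_lowering(x, [2, 2, 3], w, b, 1e-5)
assert torch.allclose(
    out, torch.nn.functional.layer_norm(x, [2, 2, 3], w, b, 1e-5), atol=1e-5)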


@@ -196,7 +196,7 @@ public:
            AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp,
            AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
            AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
            AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp>(op)) {
      return getLatticeElement(op->getResult(0)).join(*operands[0]);
    }


@@ -476,6 +476,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit(
        "aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
    )
    emit(
        "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
    )
    emit(
        "aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
    )
@@ -593,6 +596,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit("aten::div : (Scalar, Scalar) -> (float)")
    emit("aten::eq.device : (Device, Device) -> (bool)")


def emit_quantized_ops(torch_ir_dir: str, registry: Registry):
    td_file = os.path.join(torch_ir_dir, "GeneratedQuantizedOps.td")
    with open(td_file, "w") as f: