Mirror of https://github.com/llvm/torch-mlir
Commit 98ba255288: E2e support for layernorm.
Parent: b01f579687
@@ -9,6 +9,7 @@ from torch_mlir_e2e_test.torchscript.framework import TestUtils
 from torch_mlir_e2e_test.torchscript.registry import register_test_case
 from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export


 # ==============================================================================
 class BatchNorm1DModule(torch.nn.Module):
     def __init__(self):
@@ -17,8 +18,10 @@ class BatchNorm1DModule(torch.nn.Module):
         self.bn1d.eval()
         self.bn1d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6])
         self.bn1d.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0])
-        self.bn1d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 5.0]))
+        self.bn1d.weight = torch.nn.Parameter(
+            torch.tensor([3.0, 2.0, 4.0, 5.0]))
         self.bn1d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6]))

     @export
     @annotate_args([
         None,
@@ -27,10 +30,12 @@ class BatchNorm1DModule(torch.nn.Module):
     def forward(self, x):
         return self.bn1d(x)


 @register_test_case(module_factory=lambda: BatchNorm1DModule())
 def BatchNorm1DModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(10, 4, 3))


 # ==============================================================================
 class BatchNorm2DModule(torch.nn.Module):
     def __init__(self):
@@ -41,6 +46,7 @@ class BatchNorm2DModule(torch.nn.Module):
         self.bn2d.running_var = torch.tensor([3.0, 2.0])
         self.bn2d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0]))
         self.bn2d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4]))

     @export
     @annotate_args([
         None,
@@ -49,10 +55,12 @@ class BatchNorm2DModule(torch.nn.Module):
     def forward(self, x):
         return self.bn2d(x)


 @register_test_case(module_factory=lambda: BatchNorm2DModule())
 def BatchNorm2DModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(10, 2, 3, 3))


 # ==============================================================================
 class BatchNorm3DModule(torch.nn.Module):
     def __init__(self):
@@ -61,8 +69,11 @@ class BatchNorm3DModule(torch.nn.Module):
         self.bn3d.eval()
         self.bn3d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4])
         self.bn3d.running_var = torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0])
-        self.bn3d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0]))
-        self.bn3d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4]))
+        self.bn3d.weight = torch.nn.Parameter(
+            torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0]))
+        self.bn3d.bias = torch.nn.Parameter(
+            torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4]))

     @export
     @annotate_args([
         None,
@@ -71,6 +82,83 @@ class BatchNorm3DModule(torch.nn.Module):
     def forward(self, x):
         return self.bn3d(x)


 @register_test_case(module_factory=lambda: BatchNorm3DModule())
 def BatchNorm3DModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 5, 3, 6, 4))


+# ==============================================================================
+class LayerNormModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ly = torch.nn.LayerNorm([2, 2, 3])
+        self.ly.eval()
+        self.ly.weight = torch.nn.Parameter(
+            torch.tensor([[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]],
+                          [[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]]]))
+        self.ly.bias = torch.nn.Parameter(
+            torch.tensor([[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]],
+                          [[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]]]))
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 5, 2, 2, 3], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.ly(x)
+
+
+@register_test_case(module_factory=lambda: LayerNormModule())
+def LayerNormModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5, 2, 2, 3))
+
+
+# ==============================================================================
+class LayerNormLastDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ly = torch.nn.LayerNorm([3])
+        self.ly.eval()
+        self.ly.weight = torch.nn.Parameter(torch.tensor([2.0, 3.0, 2.0]))
+        self.ly.bias = torch.nn.Parameter(torch.tensor([0.2, 0.4, 0.3]))
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 5, 2, 2, 3], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.ly(x)
+
+
+@register_test_case(module_factory=lambda: LayerNormLastDimModule())
+def LayerNormLastDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5, 2, 2, 3))
+
+
+# ==============================================================================
+class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ly = torch.nn.LayerNorm([2, 2, 3])
+        self.ly.eval()
+        self.ly.weight = torch.nn.Parameter(
+            torch.tensor([[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]],
+                          [[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]]]))
+        self.ly.bias = torch.nn.Parameter(
+            torch.tensor([[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]],
+                          [[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]]]))
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 2, 3], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.ly(x)
+
+
+@register_test_case(module_factory=lambda: LayerNormNormalizeOverAllDimsModule())
+def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 2, 3))
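
Note: for readers checking what the new LayerNorm test modules compute, the semantics can be reproduced with plain tensor ops. This is an illustrative sketch only, not part of the commit; the helper name layer_norm_reference is hypothetical.

import torch

# Reference computation over the trailing `normalized_shape` dims:
# ((input - mean) / sqrt(var + eps)) * weight + bias, with biased variance.
def layer_norm_reference(x, normalized_shape, weight, bias, eps):
    dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
    mean = x.mean(dim=dims, keepdim=True)
    var = x.var(dim=dims, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps) * weight + bias

x = torch.rand(2, 5, 2, 2, 3)
ln = torch.nn.LayerNorm([2, 2, 3]).eval()
expected = ln(x)
actual = layer_norm_reference(x, [2, 2, 3], ln.weight, ln.bias, ln.eps)
assert torch.allclose(actual, expected, atol=1e-6)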
@@ -899,6 +899,25 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
   let assemblyFormat = "$input `,` $weight `,` $bias `,` $running_mean `,` $running_var `,` $training `,` $momentum `,` $eps `,` $cudnn_enabled attr-dict `:` type($input) `,` type($weight) `,` type($bias) `,` type($running_mean) `,` type($running_var) `,` type($training) `,` type($momentum) `,` type($eps) `,` type($cudnn_enabled) `->` type($result)";
 }

+def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    TorchIntListType:$normalized_shape,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    Torch_FloatType:$eps,
+    Torch_BoolType:$cudnn_enable
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps `,` $cudnn_enable attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `,` type($cudnn_enable) `->` type($result)";
+}
+
 def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
     AllowsTypeRefinement,
     HasValueSemantics
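
Note: the generated op mirrors the aten::layer_norm signature registered above. As a hedged illustration (not part of the diff), the same ATen operator can be called directly from Python with arguments in that order:

import torch

# Arguments follow (Tensor, int[], Tensor?, Tensor?, float, bool).
x = torch.rand(2, 5, 2, 2, 3)
weight = torch.ones(2, 2, 3)
bias = torch.zeros(2, 2, 3)
out = torch.ops.aten.layer_norm(x, [2, 2, 3], weight, bias, 1e-5, False)
print(out.shape)  # torch.Size([2, 5, 2, 2, 3])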
@@ -62,6 +62,15 @@ static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
   return success();
 }

+static LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op,
+                                  Value v) {
+  Type type = v.getType();
+  if (type.isa<OptionalType>() || type.isa<Torch::NoneType>() ||
+      type.isa<mlir::NoneType>())
+    return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
+  return success();
+}
+
 // Hack to deal with the Torch list type arguments which is not supported end
 // to end. Constant values can be be extracted directly and non constant
 // list values are not supported.
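
Note: the None operands that checkNotNone rejects arise, for example, when layer_norm is used without affine parameters. A small illustrative sketch in plain PyTorch, not the conversion code itself:

import torch

# A LayerNorm without affine parameters passes None for weight and bias,
# which is exactly the case the lowering does not handle yet.
ln = torch.nn.LayerNorm([3], elementwise_affine=False)
print(ln.weight, ln.bias)  # None None
out = torch.ops.aten.layer_norm(torch.rand(2, 3), [3], None, None, 1e-5, False)
print(out.shape)  # torch.Size([2, 3])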
@@ -96,23 +105,40 @@ static Value getDimOp(OpBuilder &b, Location loc, Value v, int dimension) {
   return b.create<tensor::DimOp>(loc, v, dimension);
 }

-static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDimIndex,
-                                Value rhsDimIndex) {
-  Value lhsDimInt = castIndexToInt(b, loc, lhsDimIndex);
-  Value rhsDimInt = castIndexToInt(b, loc, rhsDimIndex);
+static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
+                                Value rhsDim) {
+  Type lhsType = lhsDim.getType();
+  Type rhsType = rhsDim.getType();
+  auto checkIntOrIndex = [](Type type) {
+    assert(type.isa<IntegerType>() ||
+           type.isa<IndexType>() && "must be either integer or index type");
+  };
+  checkIntOrIndex(lhsType);
+  checkIntOrIndex(rhsType);
+  Value lhsDimInt = lhsType.isIndex() ? castIndexToInt(b, loc, lhsDim) : lhsDim;
+  Value rhsDimInt = rhsType.isIndex() ? castIndexToInt(b, loc, rhsDim) : rhsDim;
   Value contractingDimEqual =
       b.create<CmpIOp>(loc, CmpIPredicate::eq, lhsDimInt, rhsDimInt);
   b.create<AssertOp>(loc, contractingDimEqual,
                      b.getStringAttr("mismatching contracting dimension"));
 }

+static SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
+                                                 Value tensor, int dim) {
+  RankedTensorType type = tensor.getType().cast<RankedTensorType>();
+  assert(dim < type.getRank() &&
+         "The given dim must be smaller than tensor rank");
+  (void)type;
+  SmallVector<Value> sizes;
+  for (int i = 0; i <= dim; i++)
+    sizes.push_back(getDimOp(b, loc, tensor, i));
+  return sizes;
+}
+
 static SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc,
                                          Value tensor) {
   RankedTensorType type = tensor.getType().cast<RankedTensorType>();
-  SmallVector<Value> sizes;
-  for (int i = 0; i < type.getRank(); i++)
-    sizes.push_back(getDimOp(b, loc, tensor, i));
-  return sizes;
+  return getTensorSizesUntilDim(b, loc, tensor, type.getRank() - 1);
 }

 static Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
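
Note: in Python terms, getTensorSizesUntilDim collects the sizes of dims 0..dim inclusive, and getTensorSizes becomes the special case dim = rank - 1. A rough sketch for intuition; the helper tensor_sizes_until_dim is hypothetical and not part of the codebase:

import torch

def tensor_sizes_until_dim(t, dim):
    assert dim < t.dim(), "The given dim must be smaller than tensor rank"
    return [t.size(i) for i in range(dim + 1)]

t = torch.rand(2, 5, 2, 2, 3)
print(tensor_sizes_until_dim(t, 1))            # [2, 5] (the meanAndVarShape part)
print(tensor_sizes_until_dim(t, t.dim() - 1))  # [2, 5, 2, 2, 3] (all sizes)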
@@ -173,6 +199,19 @@ getAsOpFoldResult(OpBuilder &b, Location loc, SmallVectorImpl<int64_t> &ints) {
       ints, [&](int64_t val) -> OpFoldResult { return b.getIndexAttr(val); }));
 }

+// This is a temporary solution to deal with types that are not fully supported
+// like list, dict. For those container tyes, this helper can be used to
+// convert their elements to valid target type.
+// TODO: remove this when list gets full support.
+static SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
+                                                 TypeConverter *converter,
+                                                 SmallVectorImpl<Value> &vs) {
+  return llvm::to_vector<4>(llvm::map_range(vs, [&](Value v) {
+    return converter->materializeTargetConversion(
+        b, loc, converter->convertType(v.getType()), v);
+  }));
+}
+
 // Helper function to get the padding tensor given the padding int values.
 // It's assumed that the padding on the low end and high end are the same.
 static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
@@ -192,6 +231,14 @@ static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
   return paddedInput;
 }

+static bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
+  auto listConstruct = v.getDefiningOp<PrimListConstructOp>();
+  if (!listConstruct)
+    return false;
+  elems = llvm::to_vector<4>(listConstruct.elements());
+  return true;
+}
+
 namespace {
 class ConvertAtenAdaptiveAvgPool2dOp
     : public OpConversionPattern<AtenAdaptiveAvgPool2dOp> {
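
Note: getListConstructElements unpacks a normalized_shape that reaches the lowering as a prim::ListConstruct of ints in the TorchScript graph. A hedged way to see such a graph from Python (the exact printed IR may differ between versions):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.layer_norm(x, [2, 2, 3])

# The scripted graph contains a prim::ListConstruct of int constants that
# feeds the normalized_shape operand of aten::layer_norm.
print(torch.jit.script(M()).graph)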
@@ -393,6 +440,22 @@ public:
 };
 } // namespace

+// Normalization formula:
+//   ((input - mean) / sqrt(var + eps)) * weight + bias
+static Value createLinalgPayloadCalculationForNormOps(
+    OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var,
+    Value eps, Value weight, Value bias) {
+  Value inputSubMean = b.create<SubFOp>(loc, input, mean);
+  // The eps is always f64.
+  Value truncatedEps = b.create<FPTruncOp>(loc, elemTy, eps);
+  Value varPlusEps = b.create<AddFOp>(loc, var, truncatedEps);
+  Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
+  Value temp = b.create<MulFOp>(loc, inputSubMean, rSTD);
+  Value timesWeight = b.create<MulFOp>(loc, temp, weight);
+  Value plusBias = b.create<AddFOp>(loc, timesWeight, bias);
+  return plusBias;
+}
+
 namespace {
 class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
 public:
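
Note: the shared payload applies the formula in the comment element by element. A scalar sketch of the same arithmetic, for illustration only:

import math

def norm_payload(x, mean, var, eps, weight, bias):
    # ((x - mean) / sqrt(var + eps)) * weight + bias
    rstd = 1.0 / math.sqrt(var + eps)
    return (x - mean) * rstd * weight + bias

print(norm_payload(1.0, 0.5, 2.0, 1e-5, 3.0, 0.4))  # approximately 1.46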
@@ -411,11 +474,17 @@ public:
     Value training = adaptor.training();
     Value eps = adaptor.eps();

-    // TODO: Handle the None cases for the optional parameters:
-    // weight, bias, running_mean, running_var.
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();

+    // TODO: Handle the None cases for the optional parameters:
+    // weight, bias.
+    if (failed(checkNotNone(rewriter, op, weight)) ||
+        failed(checkNotNone(rewriter, op, bias)) ||
+        failed(checkNotNone(rewriter, op, runningMean)) ||
+        failed(checkNotNone(rewriter, op, runningVar)))
+      return failure();
+
     auto inputType = input.getType().cast<RankedTensorType>();
     auto weightType = weight.getType().cast<RankedTensorType>();
     auto biasType = bias.getType().cast<RankedTensorType>();
@@ -480,17 +549,10 @@ public:
             [&](OpBuilder &b, Location loc, ValueRange args) {
               Value input = args[0], weight = args[1], bias = args[2],
                     mean = args[3], var = args[4];
-              // ((input - mean) / sqrt(var + eps)) * weight + bias
-              Value inputSubMean = b.create<SubFOp>(loc, input, mean);
-              // The eps is always f64.
-              Value truncatedEps =
-                  b.create<FPTruncOp>(loc, var.getType(), eps);
-              Value varPlusEps = b.create<AddFOp>(loc, var, truncatedEps);
-              Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
-              Value temp = b.create<MulFOp>(loc, inputSubMean, rSTD);
-              Value timesWeight = b.create<MulFOp>(loc, temp, weight);
-              Value plusBias = b.create<AddFOp>(loc, timesWeight, bias);
-              b.create<linalg::YieldOp>(loc, plusBias);
+              Value result = createLinalgPayloadCalculationForNormOps(
+                  b, loc, var.getType(), input, mean, var, eps, weight,
+                  bias);
+              b.create<linalg::YieldOp>(loc, result);
             })
             .getResult(0);
     Type newResultType = getTypeConverter()->convertType(op.getType());
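
Note: after this refactor, the batch norm payload is the shared norm formula applied with per-channel running statistics. A quick PyTorch cross-check of that inference-mode formula, as an illustration rather than the lowering itself:

import torch

x = torch.rand(10, 4, 3)
bn = torch.nn.BatchNorm1d(4).eval()
mean = bn.running_mean.view(1, 4, 1)
var = bn.running_var.view(1, 4, 1)
weight = bn.weight.view(1, 4, 1)
bias = bn.bias.view(1, 4, 1)
ref = (x - mean) / torch.sqrt(var + bn.eps) * weight + bias
assert torch.allclose(ref, bn(x), atol=1e-6)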
@@ -500,6 +562,228 @@ public:
 };
 } // namespace

+// For layernorm, the mean and standard-deviation are calculated separately over
+// the last certain number dimensions which have to be of the shape specified by
+// normalized_shape.
+//
+// The shapes of different parts are as the following:
+// +-------------------+--------------------+
+// | meanAndVarShape   |  normalizedShape   |
+// +-------------------+---------------------
+// <------------+ inputShape +-------------->
+
+// There are the following steps:
+// Step 1. Check if all the arguments meet the requirements.
+// Step 2. Common parts to be used for getting mean and var.
+//         This includes elements count, affineMap and iteratorTypes.
+// Step 3. Get mean.
+// Step 4. Get var.
+// Step 5. Get layernorm.
+namespace {
+class ConvertAtenLayerNormOp : public OpConversionPattern<AtenLayerNormOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenLayerNormOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    AtenLayerNormOp::Adaptor adaptor(operands);
+    MLIRContext *context = op->getContext();
+    Location loc = op->getLoc();
+    Value input = adaptor.input();
+    Value weight = adaptor.weight();
+    Value bias = adaptor.bias();
+    Value eps = adaptor.eps();
+    Value normalizedShape = op.normalized_shape();
+
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    // TODO: Handle the None cases for the optional parameters:
+    // weight, bias.
+    if (failed(checkNotNone(rewriter, op, weight)) ||
+        failed(checkNotNone(rewriter, op, bias)))
+      return failure();
+
+    auto inputType = input.getType().cast<RankedTensorType>();
+    auto weightType = weight.getType().cast<RankedTensorType>();
+    auto biasType = bias.getType().cast<RankedTensorType>();
+    int64_t inputRank = inputType.getRank();
+    Type elemTy = inputType.getElementType();
+
+    // Step 1. Check if all the arguments meet the requirements.
+    SmallVector<Value> normalizedShapeSizesTorchInt;
+    if (!getListConstructElements(normalizedShape,
+                                  normalizedShapeSizesTorchInt)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "Unimplemented normalized_shape not"
+                                         "constructed from ListConstruct");
+    }
+    SmallVector<Value> normalizedShapeSizesInt = getTypeConvertedValues(
+        rewriter, loc, getTypeConverter(), normalizedShapeSizesTorchInt);
+    int64_t normalizedShapeRank = normalizedShapeSizesInt.size();
+    if (weightType.getRank() != normalizedShapeRank ||
+        biasType.getRank() != normalizedShapeRank ||
+        inputRank < normalizedShapeRank || normalizedShapeRank < 1)
+      return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or"
+                                             "normalized shape not compatible");
+
+    // Check all the dimensions match the normalized_shape
+    int64_t meanAndVarShapeRank = inputRank - normalizedShapeSizesInt.size();
+    for (auto en : enumerate((normalizedShapeSizesInt))) {
+      auto index = en.index();
+      auto inputDim =
+          getDimOp(rewriter, loc, input, index + meanAndVarShapeRank);
+      auto weightDim = getDimOp(rewriter, loc, weight, index);
+      auto biasDim = getDimOp(rewriter, loc, bias, index);
+
+      auto expectedSize = en.value();
+      checkDimEqualHelper(rewriter, loc, inputDim, expectedSize);
+      checkDimEqualHelper(rewriter, loc, weightDim, expectedSize);
+      checkDimEqualHelper(rewriter, loc, biasDim, expectedSize);
+    }
+
+    // Get iterator types for input shape.
+    SmallVector<StringRef> normalizedShapeIteratorTypes(
+        normalizedShapeRank, getReductionIteratorTypeName());
+    SmallVector<StringRef> meanAndVarIterationTypes(
+        meanAndVarShapeRank, getParallelIteratorTypeName());
+    SmallVector<StringRef> inputShapeIteratorTypes = meanAndVarIterationTypes;
+    inputShapeIteratorTypes.append(normalizedShapeIteratorTypes);
+
+    // Step 2. Common parts to be used for getting mean and var.
+
+    // Get sizes and affineMaps needed for mean and var.
+    AffineMap inputShapeAffineMap = rewriter.getMultiDimIdentityMap(inputRank);
+    SmallVector<AffineExpr> meanAndVarShapeExprs;
+    for (int i = 0; i < meanAndVarShapeRank; i++)
+      meanAndVarShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
+    auto meanAndVarShapeAffineMap = AffineMap::get(
+        /*dimCount=*/inputRank,
+        /*symbolCount=*/0, meanAndVarShapeExprs, context);
+    SmallVector<Value> meanAndVarShapeSizes =
+        getTensorSizesUntilDim(rewriter, loc, input, meanAndVarShapeRank - 1);
+
+    // Get number of elements to be used for calculating mean and var.
+    Value elemCnts = normalizedShapeSizesInt[0];
+    for (int i = 1; i < normalizedShapeRank; i++) {
+      elemCnts =
+          rewriter.create<MulIOp>(loc, elemCnts, normalizedShapeSizesInt[i]);
+    }
+    Value elemCntsFloat = rewriter.create<SIToFPOp>(loc, elemTy, elemCnts);
+
+    // Helper to calculate mean and var.
+    auto genMeanOrVarCalculation = [&](Value sumOrSquareSum) {
+      SmallVector<AffineMap> indexingMaps(
+          2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
+      Value initShapeTensor = rewriter.create<linalg::InitTensorOp>(
+          loc, meanAndVarShapeSizes, elemTy);
+      return rewriter
+          .create<linalg::GenericOp>(
+              loc, initShapeTensor.getType(), sumOrSquareSum, initShapeTensor,
+              /*indexingMaps=*/indexingMaps,
+              /*iteratorTypes=*/meanAndVarIterationTypes,
+              [&](OpBuilder &b, Location loc, ValueRange args) {
+                Value sumOrSqureSum = args[0];
+                Value result =
+                    b.create<DivFOp>(loc, sumOrSqureSum, elemCntsFloat);
+                b.create<linalg::YieldOp>(loc, result);
+              })
+          .getResult(0);
+    };
+
+    // Step 3. Get mean.
+
+    // Get sum to be used for calculating mean.
+    SmallVector<AffineMap, 2> sumIndexingMaps = {
+        inputShapeAffineMap,      // input
+        meanAndVarShapeAffineMap, // output
+    };
+    auto initSumTensor =
+        createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
+    Value sum = rewriter
+                    .create<linalg::GenericOp>(
+                        loc, initSumTensor.getType(), input, initSumTensor,
+                        /*indexingMaps=*/sumIndexingMaps,
+                        /*iteratorTypes=*/inputShapeIteratorTypes,
+                        [&](OpBuilder &b, Location loc, ValueRange args) {
+                          Value input = args[0], sum = args[1];
+                          Value result =
+                              rewriter.create<AddFOp>(loc, sum, input);
+                          b.create<linalg::YieldOp>(loc, result);
+                        })
+                    .getResult(0);
+    Value mean = genMeanOrVarCalculation(sum);
+
+    // Step 4. Get var.
+
+    // Calculate squareSum for the layer.
+    SmallVector<AffineMap> squareSumIndexingMaps{
+        inputShapeAffineMap,
+        meanAndVarShapeAffineMap,
+        meanAndVarShapeAffineMap,
+    };
+    auto initSquareSumTensor =
+        createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
+    Value squareSum =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, initSquareSumTensor.getType(), ValueRange{input, mean},
+                initSquareSumTensor,
+                /*indexingMaps=*/squareSumIndexingMaps,
+                /*iteratorTypes=*/inputShapeIteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value input = args[0], mean = args[1], squareSum = args[2];
+                  Value sub = rewriter.create<SubFOp>(loc, input, mean);
+                  Value square = rewriter.create<MulFOp>(loc, sub, sub);
+                  Value result =
+                      rewriter.create<AddFOp>(loc, squareSum, square);
+                  b.create<linalg::YieldOp>(loc, result);
+                })
+            .getResult(0);
+    Value var = genMeanOrVarCalculation(squareSum);
+
+    // Step 5. Get layernorm.
+
+    // Get affineMap for normalized shape.
+    SmallVector<AffineExpr> normalizedShapeExprs;
+    for (int i = meanAndVarShapeRank; i < inputRank; i++)
+      normalizedShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
+    auto normalizedShapeAffineMap = AffineMap::get(
+        /*dimCount=*/inputRank,
+        /*symbolCount=*/0, normalizedShapeExprs, context);
+
+    auto inputSizes = getTensorSizes(rewriter, loc, input);
+    Value initLayerNormTensor =
+        rewriter.create<linalg::InitTensorOp>(loc, inputSizes, elemTy);
+    SmallVector<AffineMap> indexingMaps(1, inputShapeAffineMap);
+    indexingMaps.resize(3, meanAndVarShapeAffineMap);
+    indexingMaps.resize(5, normalizedShapeAffineMap);
+    indexingMaps.push_back(inputShapeAffineMap);
+    SmallVector<StringRef> layerNormIterationTypes(
+        inputRank, getParallelIteratorTypeName());
+    Value layerNorm =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, initLayerNormTensor.getType(),
+                ValueRange{input, mean, var, weight, bias}, initLayerNormTensor,
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/layerNormIterationTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value input = args[0], mean = args[1], var = args[2],
+                        weight = args[3], bias = args[4];
+                  Value result = createLinalgPayloadCalculationForNormOps(
+                      b, loc, elemTy, input, mean, var, eps, weight, bias);
+                  b.create<linalg::YieldOp>(loc, result);
+                })
+            .getResult(0);
+
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, layerNorm);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
 public:
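
Note: the steps above (sum then divide for the mean, sum of squared differences then divide for the var, then the shared norm payload) can be mirrored in NumPy to check the intended numerics. This is a sketch of the math only, with a hypothetical helper name; it is not the MLIR lowering:

import numpy as np

# Assumed layout: inputShape = meanAndVarShape + normalizedShape.
def layer_norm_like_lowering(x, normalized_rank, weight, bias, eps):
    mean_var_rank = x.ndim - normalized_rank
    red_axes = tuple(range(mean_var_rank, x.ndim))   # reduction dims
    elem_cnt = np.prod(x.shape[mean_var_rank:])      # elements per slice
    # Step 3: sum over the normalizedShape dims, then divide -> mean.
    mean = x.sum(axis=red_axes) / elem_cnt
    mean_b = mean.reshape(mean.shape + (1,) * normalized_rank)
    # Step 4: sum of squared differences, then divide -> var.
    var = ((x - mean_b) ** 2).sum(axis=red_axes) / elem_cnt
    var_b = var.reshape(var.shape + (1,) * normalized_rank)
    # Step 5: ((x - mean) / sqrt(var + eps)) * weight + bias.
    return (x - mean_b) / np.sqrt(var_b + eps) * weight + bias

x = np.random.rand(2, 5, 2, 2, 3).astype(np.float32)
w = np.ones((2, 2, 3), np.float32)
b = np.zeros((2, 2, 3), np.float32)
out = layer_norm_like_lowering(x, 3, w, b, 1e-5)
print(out.shape)  # (2, 5, 2, 2, 3)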
@@ -1611,6 +1895,8 @@ public:
     patterns.add<ConvertAtenCatOp>(typeConverter, context);
     target.addIllegalOp<AtenGatherOp>();
     patterns.add<ConvertAtenGatherOp>(typeConverter, context);
+    target.addIllegalOp<AtenLayerNormOp>();
+    patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);

     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -196,7 +196,7 @@ public:
             AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp,
             AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
             AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
-            AtenCopy_Op, AtenCumsumOp>(op)) {
+            AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
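
Note: adding AtenLayerNormOp to this list works because layer_norm returns a tensor with the same shape and dtype as its input, so the result type can reuse the operand-0 lattice join. A quick sanity check in PyTorch, for illustration:

import torch

x = torch.rand(2, 5, 2, 2, 3, dtype=torch.float32)
out = torch.nn.functional.layer_norm(x, [2, 2, 3])
assert out.shape == x.shape and out.dtype == x.dtype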
@@ -476,6 +476,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit(
         "aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
     )
+    emit(
+        "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
+    )
     emit(
         "aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
     )
@@ -593,6 +596,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::div : (Scalar, Scalar) -> (float)")
     emit("aten::eq.device : (Device, Device) -> (bool)")


 def emit_quantized_ops(torch_ir_dir: str, registry: Registry):
     td_file = os.path.join(torch_ir_dir, "GeneratedQuantizedOps.td")
     with open(td_file, "w") as f: