diff --git a/e2e_testing/torchscript/batchnorm.py b/e2e_testing/torchscript/batchnorm.py
index e9cd77702..44298f6c5 100644
--- a/e2e_testing/torchscript/batchnorm.py
+++ b/e2e_testing/torchscript/batchnorm.py
@@ -9,6 +9,7 @@
 from torch_mlir_e2e_test.torchscript.framework import TestUtils
 from torch_mlir_e2e_test.torchscript.registry import register_test_case
 from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
 # ==============================================================================
 class BatchNorm1DModule(torch.nn.Module):
     def __init__(self):
@@ -17,8 +18,10 @@ class BatchNorm1DModule(torch.nn.Module):
         self.bn1d.eval()
         self.bn1d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6])
         self.bn1d.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0])
-        self.bn1d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 5.0]))
+        self.bn1d.weight = torch.nn.Parameter(
+            torch.tensor([3.0, 2.0, 4.0, 5.0]))
         self.bn1d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6]))
+
     @export
     @annotate_args([
         None,
@@ -27,10 +30,12 @@ class BatchNorm1DModule(torch.nn.Module):
     def forward(self, x):
         return self.bn1d(x)
 
+
 @register_test_case(module_factory=lambda: BatchNorm1DModule())
 def BatchNorm1DModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(10, 4, 3))
 
+
 # ==============================================================================
 class BatchNorm2DModule(torch.nn.Module):
     def __init__(self):
@@ -41,6 +46,7 @@ class BatchNorm2DModule(torch.nn.Module):
         self.bn2d.running_var = torch.tensor([3.0, 2.0])
         self.bn2d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0]))
         self.bn2d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4]))
+
     @export
     @annotate_args([
         None,
@@ -49,10 +55,12 @@ class BatchNorm2DModule(torch.nn.Module):
     def forward(self, x):
         return self.bn2d(x)
 
+
 @register_test_case(module_factory=lambda: BatchNorm2DModule())
 def BatchNorm2DModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(10, 2, 3, 3))
 
+
 # ==============================================================================
 class BatchNorm3DModule(torch.nn.Module):
     def __init__(self):
@@ -61,8 +69,11 @@ class BatchNorm3DModule(torch.nn.Module):
         self.bn3d.eval()
         self.bn3d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4])
         self.bn3d.running_var = torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0])
-        self.bn3d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0]))
-        self.bn3d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4]))
+        self.bn3d.weight = torch.nn.Parameter(
+            torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0]))
+        self.bn3d.bias = torch.nn.Parameter(
+            torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4]))
+
     @export
     @annotate_args([
         None,
@@ -71,6 +82,83 @@ class BatchNorm3DModule(torch.nn.Module):
     def forward(self, x):
         return self.bn3d(x)
 
+
 @register_test_case(module_factory=lambda: BatchNorm3DModule())
 def BatchNorm3DModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 5, 3, 6, 4))
+
+
+# ==============================================================================
+class LayerNormModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ly = torch.nn.LayerNorm([2, 2, 3])
+        self.ly.eval()
+        self.ly.weight = torch.nn.Parameter(
+            torch.tensor([[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]],
+                          [[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]]]))
+        self.ly.bias = torch.nn.Parameter(
+            torch.tensor([[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]],
+                          [[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]]]))
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 5, 2, 2, 3], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.ly(x)
+
+
+@register_test_case(module_factory=lambda: LayerNormModule())
+def LayerNormModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5, 2, 2, 3))
+
+
+# ==============================================================================
+class LayerNormLastDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ly = torch.nn.LayerNorm([3])
+        self.ly.eval()
+        self.ly.weight = torch.nn.Parameter(torch.tensor([2.0, 3.0, 2.0]))
+        self.ly.bias = torch.nn.Parameter(torch.tensor([0.2, 0.4, 0.3]))
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 5, 2, 2, 3], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.ly(x)
+
+
+@register_test_case(module_factory=lambda: LayerNormLastDimModule())
+def LayerNormLastDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5, 2, 2, 3))
+
+# ==============================================================================
+class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ly = torch.nn.LayerNorm([2, 2, 3])
+        self.ly.eval()
+        self.ly.weight = torch.nn.Parameter(
+            torch.tensor([[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]],
+                          [[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]]]))
+        self.ly.bias = torch.nn.Parameter(
+            torch.tensor([[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]],
+                          [[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]]]))
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 2, 3], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.ly(x)
+
+
+@register_test_case(module_factory=lambda: LayerNormNormalizeOverAllDimsModule())
+def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 2, 3))
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index fbab9b5d6..26bd3aebc 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -899,6 +899,25 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
   let assemblyFormat = "$input `,` $weight `,` $bias `,` $running_mean `,` $running_var `,` $training `,` $momentum `,` $eps `,` $cudnn_enabled attr-dict `:` type($input) `,` type($weight) `,` type($bias) `,` type($running_mean) `,` type($running_var) `,` type($training) `,` type($momentum) `,` type($eps) `,` type($cudnn_enabled) `->` type($result)";
 }
 
+def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    TorchIntListType:$normalized_shape,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    Torch_FloatType:$eps,
+    Torch_BoolType:$cudnn_enable
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps `,` $cudnn_enable attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `,` type($cudnn_enable) `->` type($result)";
+}
+
 def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
     AllowsTypeRefinement,
     HasValueSemantics
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index d2f0391bf..31a8ff75c 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -62,6 +62,15 @@ static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
   return success();
 }
 
+static LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op,
+                                  Value v) {
+  Type type = v.getType();
+  if (type.isa<OptionalType>() || type.isa<Torch::NoneType>() ||
+      type.isa<mlir::NoneType>())
+    return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
+  return success();
+}
+
 // Hack to deal with the Torch list type arguments which is not supported end
 // to end. Constant values can be be extracted directly and non constant
 // list values are not supported.
@@ -96,23 +105,40 @@ static Value getDimOp(OpBuilder &b, Location loc, Value v, int dimension) {
   return b.create<tensor::DimOp>(loc, v, dimension);
 }
 
-static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDimIndex,
-                                Value rhsDimIndex) {
-  Value lhsDimInt = castIndexToInt(b, loc, lhsDimIndex);
-  Value rhsDimInt = castIndexToInt(b, loc, rhsDimIndex);
+static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
+                                Value rhsDim) {
+  Type lhsType = lhsDim.getType();
+  Type rhsType = rhsDim.getType();
+  auto checkIntOrIndex = [](Type type) {
+    assert(type.isa<IntegerType>() ||
+           type.isa<IndexType>() && "must be either integer or index type");
+  };
+  checkIntOrIndex(lhsType);
+  checkIntOrIndex(rhsType);
+  Value lhsDimInt = lhsType.isIndex() ? castIndexToInt(b, loc, lhsDim) : lhsDim;
+  Value rhsDimInt = rhsType.isIndex() ? castIndexToInt(b, loc, rhsDim) : rhsDim;
   Value contractingDimEqual = b.create<CmpIOp>(
       loc, CmpIPredicate::eq, lhsDimInt, rhsDimInt);
   b.create<AssertOp>(loc, contractingDimEqual,
                      b.getStringAttr("mismatching contracting dimension"));
 }
 
+static SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
+                                                 Value tensor, int dim) {
+  RankedTensorType type = tensor.getType().cast<RankedTensorType>();
+  assert(dim < type.getRank() &&
+         "The given dim must be smaller than tensor rank");
+  (void)type;
+  SmallVector<Value> sizes;
+  for (int i = 0; i <= dim; i++)
+    sizes.push_back(getDimOp(b, loc, tensor, i));
+  return sizes;
+}
+
 static SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc,
                                          Value tensor) {
   RankedTensorType type = tensor.getType().cast<RankedTensorType>();
-  SmallVector<Value> sizes;
-  for (int i = 0; i < type.getRank(); i++)
-    sizes.push_back(getDimOp(b, loc, tensor, i));
-  return sizes;
+  return getTensorSizesUntilDim(b, loc, tensor, type.getRank() - 1);
 }
 
 static Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
@@ -173,6 +199,19 @@ getAsOpFoldResult(OpBuilder &b, Location loc, SmallVectorImpl<int64_t> &ints) {
       ints, [&](int64_t val) -> OpFoldResult { return b.getIndexAttr(val); }));
 }
 
+// This is a temporary solution to deal with types that are not fully supported
+// like list, dict. For those container types, this helper can be used to
+// convert their elements to valid target type.
+// TODO: remove this when list gets full support.
+static SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
+                                                 TypeConverter *converter,
+                                                 SmallVectorImpl<Value> &vs) {
+  return llvm::to_vector<4>(llvm::map_range(vs, [&](Value v) {
+    return converter->materializeTargetConversion(
+        b, loc, converter->convertType(v.getType()), v);
+  }));
+}
+
 // Helper function to get the padding tensor given the padding int values.
 // It's assumed that the padding on the low end and high end are the same.
 static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
@@ -192,6 +231,14 @@ static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
   return paddedInput;
 }
 
+static bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
+  auto listConstruct = v.getDefiningOp<PrimListConstructOp>();
+  if (!listConstruct)
+    return false;
+  elems = llvm::to_vector<4>(listConstruct.elements());
+  return true;
+}
+
 namespace {
 class ConvertAtenAdaptiveAvgPool2dOp
     : public OpConversionPattern<AtenAdaptiveAvgPool2dOp> {
@@ -393,6 +440,22 @@ public:
 };
 } // namespace
 
+// Normalization formula:
+//   ((input - mean) / sqrt(var + eps)) * weight + bias
+static Value createLinalgPayloadCalculationForNormOps(
+    OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var,
+    Value eps, Value weight, Value bias) {
+  Value inputSubMean = b.create<SubFOp>(loc, input, mean);
+  // The eps is always f64.
+  Value truncatedEps = b.create<FPTruncOp>(loc, elemTy, eps);
+  Value varPlusEps = b.create<AddFOp>(loc, var, truncatedEps);
+  Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
+  Value temp = b.create<MulFOp>(loc, inputSubMean, rSTD);
+  Value timesWeight = b.create<MulFOp>(loc, temp, weight);
+  Value plusBias = b.create<AddFOp>(loc, timesWeight, bias);
+  return plusBias;
+}
+
 namespace {
 class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
 public:
@@ -411,11 +474,17 @@ public:
     Value training = adaptor.training();
     Value eps = adaptor.eps();
 
-    // TODO: Handle the None cases for the optional parameters:
-    // weight, bias, running_mean, running_var.
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
 
+    // TODO: Handle the None cases for the optional parameters:
+    // weight, bias.
+    if (failed(checkNotNone(rewriter, op, weight)) ||
+        failed(checkNotNone(rewriter, op, bias)) ||
+        failed(checkNotNone(rewriter, op, runningMean)) ||
+        failed(checkNotNone(rewriter, op, runningVar)))
+      return failure();
+
     auto inputType = input.getType().cast<RankedTensorType>();
     auto weightType = weight.getType().cast<RankedTensorType>();
     auto biasType = bias.getType().cast<RankedTensorType>();
@@ -480,17 +549,10 @@ public:
                 [&](OpBuilder &b, Location loc, ValueRange args) {
                   Value input = args[0], weight = args[1], bias = args[2],
                         mean = args[3], var = args[4];
-                  // ((input - mean) / sqrt(var + eps)) * weight + bias
-                  Value inputSubMean = b.create<SubFOp>(loc, input, mean);
-                  // The eps is always f64.
-                  Value truncatedEps =
-                      b.create<FPTruncOp>(loc, var.getType(), eps);
-                  Value varPlusEps = b.create<AddFOp>(loc, var, truncatedEps);
-                  Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
-                  Value temp = b.create<MulFOp>(loc, inputSubMean, rSTD);
-                  Value timesWeight = b.create<MulFOp>(loc, temp, weight);
-                  Value plusBias = b.create<AddFOp>(loc, timesWeight, bias);
-                  b.create<linalg::YieldOp>(loc, plusBias);
+                  Value result = createLinalgPayloadCalculationForNormOps(
+                      b, loc, var.getType(), input, mean, var, eps, weight,
+                      bias);
+                  b.create<linalg::YieldOp>(loc, result);
                 })
             .getResult(0);
     Type newResultType = getTypeConverter()->convertType(op.getType());
@@ -500,6 +562,228 @@ public:
 };
 } // namespace
 
+// For layernorm, the mean and standard-deviation are calculated separately over
+// the last certain number of dimensions which have to be of the shape specified
+// by normalized_shape.
+//
+// The shapes of different parts are as the following:
+// +-------------------+--------------------+
+// |  meanAndVarShape  |   normalizedShape  |
+// +-------------------+---------------------
+// <------------+ inputShape +-------------->
+//
+// There are the following steps:
+// Step 1. Check if all the arguments meet the requirements.
+// Step 2. Common parts to be used for getting mean and var.
+//         This includes elements count, affineMap and iteratorTypes.
+// Step 3. Get mean.
+// Step 4. Get var.
+// Step 5. Get layernorm.
+namespace {
+class ConvertAtenLayerNormOp : public OpConversionPattern<AtenLayerNormOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenLayerNormOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    AtenLayerNormOp::Adaptor adaptor(operands);
+    MLIRContext *context = op->getContext();
+    Location loc = op->getLoc();
+    Value input = adaptor.input();
+    Value weight = adaptor.weight();
+    Value bias = adaptor.bias();
+    Value eps = adaptor.eps();
+    Value normalizedShape = op.normalized_shape();
+
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    // TODO: Handle the None cases for the optional parameters:
+    // weight, bias.
+    if (failed(checkNotNone(rewriter, op, weight)) ||
+        failed(checkNotNone(rewriter, op, bias)))
+      return failure();
+
+    auto inputType = input.getType().cast<RankedTensorType>();
+    auto weightType = weight.getType().cast<RankedTensorType>();
+    auto biasType = bias.getType().cast<RankedTensorType>();
+    int64_t inputRank = inputType.getRank();
+    Type elemTy = inputType.getElementType();
+
+    // Step 1. Check if all the arguments meet the requirements.
+    SmallVector<Value> normalizedShapeSizesTorchInt;
+    if (!getListConstructElements(normalizedShape,
+                                  normalizedShapeSizesTorchInt)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "Unimplemented normalized_shape not "
+                                         "constructed from ListConstruct");
+    }
+    SmallVector<Value> normalizedShapeSizesInt = getTypeConvertedValues(
+        rewriter, loc, getTypeConverter(), normalizedShapeSizesTorchInt);
+    int64_t normalizedShapeRank = normalizedShapeSizesInt.size();
+    if (weightType.getRank() != normalizedShapeRank ||
+        biasType.getRank() != normalizedShapeRank ||
+        inputRank < normalizedShapeRank || normalizedShapeRank < 1)
+      return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or "
+                                             "normalized shape not compatible");
+
+    // Check all the dimensions match the normalized_shape
+    int64_t meanAndVarShapeRank = inputRank - normalizedShapeSizesInt.size();
+    for (auto en : enumerate((normalizedShapeSizesInt))) {
+      auto index = en.index();
+      auto inputDim =
+          getDimOp(rewriter, loc, input, index + meanAndVarShapeRank);
+      auto weightDim = getDimOp(rewriter, loc, weight, index);
+      auto biasDim = getDimOp(rewriter, loc, bias, index);
+
+      auto expectedSize = en.value();
+      checkDimEqualHelper(rewriter, loc, inputDim, expectedSize);
+      checkDimEqualHelper(rewriter, loc, weightDim, expectedSize);
+      checkDimEqualHelper(rewriter, loc, biasDim, expectedSize);
+    }
+
+    // Get iterator types for input shape.
+    SmallVector<StringRef> normalizedShapeIteratorTypes(
+        normalizedShapeRank, getReductionIteratorTypeName());
+    SmallVector<StringRef> meanAndVarIterationTypes(
+        meanAndVarShapeRank, getParallelIteratorTypeName());
+    SmallVector<StringRef> inputShapeIteratorTypes = meanAndVarIterationTypes;
+    inputShapeIteratorTypes.append(normalizedShapeIteratorTypes);
+
+    // Step 2. Common parts to be used for getting mean and var.
+
+    // Get sizes and affineMaps needed for mean and var.
+    AffineMap inputShapeAffineMap = rewriter.getMultiDimIdentityMap(inputRank);
+    SmallVector<AffineExpr> meanAndVarShapeExprs;
+    for (int i = 0; i < meanAndVarShapeRank; i++)
+      meanAndVarShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
+    auto meanAndVarShapeAffineMap = AffineMap::get(
+        /*dimCount=*/inputRank,
+        /*symbolCount=*/0, meanAndVarShapeExprs, context);
+    SmallVector<Value> meanAndVarShapeSizes =
+        getTensorSizesUntilDim(rewriter, loc, input, meanAndVarShapeRank - 1);
+
+    // Get number of elements to be used for calculating mean and var.
+    Value elemCnts = normalizedShapeSizesInt[0];
+    for (int i = 1; i < normalizedShapeRank; i++) {
+      elemCnts =
+          rewriter.create<MulIOp>(loc, elemCnts, normalizedShapeSizesInt[i]);
+    }
+    Value elemCntsFloat = rewriter.create<SIToFPOp>(loc, elemTy, elemCnts);
+
+    // Helper to calculate mean and var.
+    auto genMeanOrVarCalculation = [&](Value sumOrSquareSum) {
+      SmallVector<AffineMap> indexingMaps(
+          2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
+      Value initShapeTensor = rewriter.create<linalg::InitTensorOp>(
+          loc, meanAndVarShapeSizes, elemTy);
+      return rewriter
+          .create<linalg::GenericOp>(
+              loc, initShapeTensor.getType(), sumOrSquareSum, initShapeTensor,
+              /*indexingMaps=*/indexingMaps,
+              /*iteratorTypes=*/meanAndVarIterationTypes,
+              [&](OpBuilder &b, Location loc, ValueRange args) {
+                Value sumOrSqureSum = args[0];
+                Value result =
+                    b.create<DivFOp>(loc, sumOrSqureSum, elemCntsFloat);
+                b.create<linalg::YieldOp>(loc, result);
+              })
+          .getResult(0);
+    };
+
+    // Step 3. Get mean.
+
+    // Get sum to be used for calculating mean.
+    SmallVector<AffineMap> sumIndexingMaps = {
+        inputShapeAffineMap,      // input
+        meanAndVarShapeAffineMap, // output
+    };
+    auto initSumTensor =
+        createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
+    Value sum = rewriter
+                    .create<linalg::GenericOp>(
+                        loc, initSumTensor.getType(), input, initSumTensor,
+                        /*indexingMaps=*/sumIndexingMaps,
+                        /*iteratorTypes=*/inputShapeIteratorTypes,
+                        [&](OpBuilder &b, Location loc, ValueRange args) {
+                          Value input = args[0], sum = args[1];
+                          Value result =
+                              rewriter.create<AddFOp>(loc, sum, input);
+                          b.create<linalg::YieldOp>(loc, result);
+                        })
+                    .getResult(0);
+    Value mean = genMeanOrVarCalculation(sum);
+
+    // Step 4. Get var.
+
+    // Calculate squareSum for the layer.
+    SmallVector<AffineMap> squareSumIndexingMaps{
+        inputShapeAffineMap,
+        meanAndVarShapeAffineMap,
+        meanAndVarShapeAffineMap,
+    };
+    auto initSquareSumTensor =
+        createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
+    Value squareSum =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, initSquareSumTensor.getType(), ValueRange{input, mean},
+                initSquareSumTensor,
+                /*indexingMaps=*/squareSumIndexingMaps,
+                /*iteratorTypes=*/inputShapeIteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value input = args[0], mean = args[1], squareSum = args[2];
+                  Value sub = rewriter.create<SubFOp>(loc, input, mean);
+                  Value square = rewriter.create<MulFOp>(loc, sub, sub);
+                  Value result =
+                      rewriter.create<AddFOp>(loc, squareSum, square);
+                  b.create<linalg::YieldOp>(loc, result);
+                })
+            .getResult(0);
+    Value var = genMeanOrVarCalculation(squareSum);
+
+    // Step 5. Get layernorm.
+
+    // Get affineMap for normalized shape.
+    SmallVector<AffineExpr> normalizedShapeExprs;
+    for (int i = meanAndVarShapeRank; i < inputRank; i++)
+      normalizedShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
+    auto normalizedShapeAffineMap = AffineMap::get(
+        /*dimCount=*/inputRank,
+        /*symbolCount=*/0, normalizedShapeExprs, context);
+
+    auto inputSizes = getTensorSizes(rewriter, loc, input);
+    Value initLayerNormTensor =
+        rewriter.create<linalg::InitTensorOp>(loc, inputSizes, elemTy);
+    SmallVector<AffineMap> indexingMaps(1, inputShapeAffineMap);
+    indexingMaps.resize(3, meanAndVarShapeAffineMap);
+    indexingMaps.resize(5, normalizedShapeAffineMap);
+    indexingMaps.push_back(inputShapeAffineMap);
+    SmallVector<StringRef> layerNormIterationTypes(
+        inputRank, getParallelIteratorTypeName());
+    Value layerNorm =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, initLayerNormTensor.getType(),
+                ValueRange{input, mean, var, weight, bias}, initLayerNormTensor,
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/layerNormIterationTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value input = args[0], mean = args[1], var = args[2],
+                        weight = args[3], bias = args[4];
+                  Value result = createLinalgPayloadCalculationForNormOps(
+                      b, loc, elemTy, input, mean, var, eps, weight, bias);
+                  b.create<linalg::YieldOp>(loc, result);
+                })
+            .getResult(0);
+
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, layerNorm);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
 public:
@@ -1611,6 +1895,8 @@ public:
     patterns.add(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
+    target.addIllegalOp<AtenLayerNormOp>();
+    patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 6555c2563..70702a741 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -196,7 +196,7 @@ public:
             AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp,
             AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
             AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
-            AtenCopy_Op, AtenCumsumOp>(op)) {
+            AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 0afff9e33..2c3d0bb4f 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -476,6 +476,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit(
         "aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
     )
+    emit(
+        "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
+    )
     emit(
         "aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
     )
@@ -593,6 +596,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::div : (Scalar, Scalar) -> (float)")
     emit("aten::eq.device : (Device, Device) -> (bool)")
 
+
 def emit_quantized_ops(torch_ir_dir: str, registry: Registry):
     td_file = os.path.join(torch_ir_dir, "GeneratedQuantizedOps.td")
     with open(td_file, "w") as f: