torch-mlir/lib/Conversion/TorchToLinalg/PopulatePatterns.h

70 lines
3.6 KiB
C++

//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace torch {
namespace torch_to_linalg {
// -----------------------------------------------------------------------------
// TorchToLinalg conversion patterns
// -----------------------------------------------------------------------------
//
// Most of these patterns consist of:
// 1. Checking that the operand/result types and other static properties are
// good-enough to create a valid linalg op (such as operands being of
// ranks/dtypes acceptable to the linalg op).
// 2. Creating dynamic error guards, usually checking a predicate on the
// compatibility of operand shapes.
// 3. Creating init tensors for the computation op. Usually this involves
// reifying IR for a shape transfer function based on the operand shapes.
// 4. Creating a named linalg op to replace the original op.
//
// TODO: This should be more automated, and ideally written as fully-executable
// Python code that is tested for correctness at the Python level and then
// just serialized (e.g. as a TorchScript function) and inlined to provide
// the implementation of the op.
// TODO: Linalg is not a stable format. As much as possible of this should
// lower through a more stable format, such as TOSA. However, TOSA currently
// lacks enough support for dynamic shapes and error assertions to be used
// for this purpose.
void populateTensorScalarInteropPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target);
void populateLinearPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
void populatePoolingPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
void populateRandomPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
void populateUncategorizedPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
void populateReductionPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
void populateDataMovementPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
void populateIndirectDataMovementPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target);
void populateTensorConstructorsPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
} // namespace torch_to_linalg
} // namespace torch
} // namespace mlir