mirror of https://github.com/llvm/torch-mlir
70 lines
3.6 KiB
C++
70 lines
3.6 KiB
C++
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
namespace mlir {
|
|
namespace torch {
|
|
namespace torch_to_linalg {
|
|
|
|
// -----------------------------------------------------------------------------
|
|
// TorchToLinalg conversion patterns
|
|
// -----------------------------------------------------------------------------
|
|
//
|
|
// Most of these patterns consist of:
|
|
// 1. Checking that the operand/result types and other static properties are
|
|
// good-enough to create a valid linalg op (such as operands being of
|
|
// ranks/dtypes acceptable to the linalg op).
|
|
// 2. Creating dynamic error guards, usually checking a predicate on the
|
|
// compatibility of operand shapes.
|
|
// 3. Creating init tensors for the computation op. Usually this involves
|
|
// reifying IR for a shape transfer function based on the operand shapes.
|
|
// 4. Creating a named linalg op to replace the original op.
|
|
//
|
|
// TODO: This should be more automated, and ideally written as fully-executable
|
|
// Python code that is tested for correctness at the Python level and then
|
|
// just serialized (e.g. as a TorchScript function) and inlined to provide
|
|
// the implementation of the op.
|
|
// TODO: Linalg is not a stable format. As much as possible of this should
|
|
// lower through a more stable format, such as TOSA. However, TOSA currently
|
|
// lacks enough support for dynamic shapes and error assertions to be used
|
|
// for this purpose.
|
|
|
|
void populateTensorScalarInteropPatternsAndLegality(
|
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
|
ConversionTarget &target);
|
|
void populateLinearPatternsAndLegality(TypeConverter &typeConverter,
|
|
RewritePatternSet &patterns,
|
|
ConversionTarget &target);
|
|
void populatePoolingPatternsAndLegality(TypeConverter &typeConverter,
|
|
RewritePatternSet &patterns,
|
|
ConversionTarget &target);
|
|
void populateRandomPatternsAndLegality(TypeConverter &typeConverter,
|
|
RewritePatternSet &patterns,
|
|
ConversionTarget &target);
|
|
void populateUncategorizedPatternsAndLegality(TypeConverter &typeConverter,
|
|
RewritePatternSet &patterns,
|
|
ConversionTarget &target);
|
|
void populateReductionPatternsAndLegality(TypeConverter &typeConverter,
|
|
RewritePatternSet &patterns,
|
|
ConversionTarget &target);
|
|
void populateDataMovementPatternsAndLegality(TypeConverter &typeConverter,
|
|
RewritePatternSet &patterns,
|
|
ConversionTarget &target);
|
|
void populateIndirectDataMovementPatternsAndLegality(
|
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
|
ConversionTarget &target);
|
|
void populateTensorConstructorsPatternsAndLegality(TypeConverter &typeConverter,
|
|
RewritePatternSet &patterns,
|
|
ConversionTarget &target);
|
|
|
|
} // namespace torch_to_linalg
|
|
} // namespace torch
|
|
} // namespace mlir
|