mirror of https://github.com/llvm/torch-mlir
48 lines
2.1 KiB
C++
48 lines
2.1 KiB
C++
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
namespace mlir {
|
|
namespace torch {
|
|
namespace torch_to_linalg {
|
|
|
|
// Helper function to get the padding tensor given the padding int values.
|
|
Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
|
|
SmallVectorImpl<int64_t> &lowPaddingInts,
|
|
SmallVectorImpl<int64_t> &highPaddingInts, Value pad);
|
|
|
|
// Helper function to get the padding tensor given the padding int values.
|
|
// It's assumed that the padding on the low end and high end are the same,
|
|
// and that zero padding is required.
|
|
Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
|
|
SmallVectorImpl<int64_t> &paddingInts);
|
|
|
|
// Helper function to caculate the output tensor dims for convolution-like ops.
|
|
// Along each dim:
|
|
// dim_out =
|
|
// floor((dim_in + 2 * padding - dilation * (kernelSize - 1) - 1) / stride) + 1
|
|
Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in,
|
|
Value paddingInt, Value dilationInt,
|
|
Value kernelSizeInt, Value strideInt);
|
|
|
|
// Create a reduction of `tensorOperand`, reducing along the dimensions
|
|
// in `dimSet`. If `keepDim` is true, the output tensor is the same
|
|
// rank as the `tensorOperand` and reduced dimensions are set to size 1.
|
|
// `initElem` is the element used to initialize the output tensor
|
|
// where the reduction will be stored.
|
|
Value createReductionLinalgGeneric(
|
|
OpBuilder &b, Location loc, Value tensorOperand,
|
|
const DenseSet<int64_t> &dimSet, bool keepDim, Value initElem,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild);
|
|
|
|
} // namespace torch_to_linalg
|
|
} // namespace torch
|
|
} // namespace mlir
|