2020-10-23 14:31:34 +08:00
|
|
|
//===- OpInterfaces.h -------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#ifndef NPCOMP_DIALECT_TORCH_IR_OPINTERFACES_H
|
|
|
|
#define NPCOMP_DIALECT_TORCH_IR_OPINTERFACES_H
|
|
|
|
|
|
|
|
#include "mlir/IR/Types.h"
|
|
|
|
#include "llvm/ADT/BitmaskEnum.h"
|
|
|
|
|
|
|
|
namespace mlir {
|
|
|
|
namespace NPCOMP {
|
|
|
|
namespace Torch {
|
|
|
|
|
|
|
|
/// Conversion rule to apply to a value (argument or return).
|
|
|
|
namespace KernelValueConversion {
|
|
|
|
enum BitMask {
|
|
|
|
// No coercion.
|
|
|
|
kNone = 0,
|
|
|
|
|
|
|
|
// Coerce/require an immutable tensor value.
|
|
|
|
kImmutableTensor = 2,
|
|
|
|
|
|
|
|
// Coerce/require a mutable tensor value.
|
|
|
|
kMutableTensor = 4,
|
|
|
|
|
2020-10-28 09:13:23 +08:00
|
|
|
// If the source is a Scalar and the target is a Tensor, promotes
|
|
|
|
// to a 0d tensor.
|
|
|
|
kPromoteScalar = 8,
|
|
|
|
|
|
|
|
LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ kPromoteScalar)
|
2020-10-23 14:31:34 +08:00
|
|
|
};
|
|
|
|
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
|
|
|
|
} // namespace KernelValueConversion
|
|
|
|
|
|
|
|
/// Metadata for building a recognized/named kernel from a generic
|
|
|
|
/// torch.kernel_call.
|
|
|
|
struct KernelMetadata {
|
|
|
|
/// The Torch kernel name.
|
|
|
|
StringRef kernelName;
|
|
|
|
|
|
|
|
/// Whether the kernel supports variable args or returns (extending the
|
|
|
|
/// formal args and returns).
|
|
|
|
bool isVararg = false;
|
|
|
|
bool isVarret = false;
|
|
|
|
|
|
|
|
/// Formal argument and return types in Torch signature syntax.
|
|
|
|
SmallVector<StringRef, 4> argTypes;
|
|
|
|
SmallVector<StringRef, 4> returnTypes;
|
|
|
|
|
|
|
|
void addArgTypes(std::initializer_list<StringRef> ilist) {
|
|
|
|
argTypes.insert(argTypes.end(), ilist);
|
|
|
|
}
|
|
|
|
|
|
|
|
void addReturnTypes(std::initializer_list<StringRef> ilist) {
|
|
|
|
returnTypes.insert(returnTypes.end(), ilist);
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
/// Extended metadata for constructing a kernel op.
|
|
|
|
struct BuildKernelMetadata : public KernelMetadata {
|
|
|
|
/// Many ops have a variant with a single trailing parameter one argument
|
|
|
|
/// past the formal argument list of the buildable op. If this flag is true,
|
|
|
|
/// it activates a heuristic that will recognize this situation and promote
|
|
|
|
/// it to an appropriate store of the result. It is deemed of little value
|
|
|
|
/// to support each one of these variants as a dedicated op when they can
|
|
|
|
/// all be handled with this flag.
|
|
|
|
bool promoteTrailingOutTensor = false;
|
|
|
|
|
|
|
|
SmallVector<KernelValueConversion::BitMask, 4> argConversions;
|
|
|
|
SmallVector<KernelValueConversion::BitMask, 4> returnConversions;
|
|
|
|
|
|
|
|
void addArgConversions(
|
|
|
|
std::initializer_list<KernelValueConversion::BitMask> ilist) {
|
|
|
|
argConversions.insert(argConversions.end(), ilist);
|
|
|
|
}
|
|
|
|
void addReturnConversions(
|
|
|
|
std::initializer_list<KernelValueConversion::BitMask> ilist) {
|
|
|
|
returnConversions.insert(returnConversions.end(), ilist);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Gets the arg conversion flag for arg 'i'. Returns kNone if conversion
|
|
|
|
/// flags are not defined that far.
|
|
|
|
KernelValueConversion::BitMask getArgConversion(size_t i) const {
|
|
|
|
if (i >= argConversions.size())
|
|
|
|
return KernelValueConversion::BitMask::kNone;
|
|
|
|
return argConversions[i];
|
|
|
|
}
|
|
|
|
/// Gets the return conversion flag for arg 'i'. Returns kNone if conversion
|
|
|
|
/// flags are not defined that far.
|
|
|
|
KernelValueConversion::BitMask getReturnConversion(size_t i) const {
|
|
|
|
if (i >= returnConversions.size())
|
|
|
|
return KernelValueConversion::BitMask::kNone;
|
|
|
|
return returnConversions[i];
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
} // namespace Torch
|
|
|
|
|
|
|
|
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h.inc"
|
|
|
|
} // namespace NPCOMP
|
|
|
|
} // namespace mlir
|
|
|
|
|
|
|
|
#endif // NPCOMP_DIALECT_TORCH_IR_OPINTERFACES_H
|