//===- mlir_utils.h ---------------------------------------------*- C++ -*-===// // // This file is licensed under a pytorch-style license // See frontends/pytorch/LICENSE for license information. // //===----------------------------------------------------------------------===// #ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_MLIR_UTILS_H #define NPCOMP_FRONTENDS_PYTORCH_CSRC_MLIR_UTILS_H #include #include #include #include "mlir-c/IR.h" #include "mlir-c/Support.h" #include "c10/util/ArrayRef.h" #include "c10/util/Optional.h" namespace torch_mlir { inline MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } inline MlirStringRef toMlirStringRef(const char *s) { return mlirStringRefCreate(s, std::strlen(s)); } inline MlirNamedAttribute toMlirNamedAttribute(const char *s, MlirAttribute attr) { MlirContext context = mlirAttributeGetContext(attr); MlirIdentifier ident = mlirIdentifierGet(context, toMlirStringRef(s)); return mlirNamedAttributeGet(ident, attr); } inline void addToMlirOperationState(MlirOperationState &state, MlirNamedAttribute namedAttr) { mlirOperationStateAddAttributes(&state, 1, &namedAttr); } inline void addToMlirOperationState(MlirOperationState &state, MlirRegion region) { mlirOperationStateAddOwnedRegions(&state, 1, ®ion); } inline void addToMlirOperationState(MlirOperationState &state, MlirValue value) { mlirOperationStateAddOperands(&state, 1, &value); } inline void addToMlirOperationState(MlirOperationState &state, const std::vector &values) { mlirOperationStateAddOperands(&state, values.size(), values.data()); } inline void addToMlirOperationState(MlirOperationState &state, c10::ArrayRef values) { mlirOperationStateAddOperands(&state, values.size(), values.data()); } inline void addToMlirOperationState(MlirOperationState &state, MlirType resultType) { mlirOperationStateAddResults(&state, 1, &resultType); } inline void addToMlirOperationState(MlirOperationState &state, const std::vector &resultTypes) { mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data()); } inline void addToMlirOperationState(MlirOperationState &state, c10::ArrayRef resultTypes) { mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data()); } template void addToMlirOperationState(MlirOperationState &state, c10::optional o) { if (o.has_value()) { addToMlirOperationState(state, o.value()); } } inline void addToMlirOperationState(MlirOperationState &state) {} template void addToMlirOperationState(MlirOperationState &state, T &&t, U &&u, Ts &&...ts) { addToMlirOperationState(state, std::forward(t)); addToMlirOperationState(state, std::forward(u), std::forward(ts)...); } template MlirOperation createMlirOperation(std::string name, MlirLocation loc, Ts &&...ts) { MlirOperationState state = mlirOperationStateGet(toMlirStringRef(name), loc); addToMlirOperationState(state, std::forward(ts)...); return mlirOperationCreate(&state); } template MlirOperation createMlirOperationAtEnd(MlirBlock block, std::string name, MlirLocation loc, Ts &&...ts) { MlirOperation operation = createMlirOperation(name, loc, std::forward(ts)...); mlirBlockInsertOwnedOperationBefore(block, mlirBlockGetTerminator(block), operation); return operation; } } // namespace torch_mlir #endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_MLIR_UTILS_H