Add script tools/format_source.sh and run it on all python and c++ sources.

pull/1/head
Stella Laurenzo 2020-06-13 14:53:54 -07:00
parent c3d4436397
commit 2ba8296151
37 changed files with 439 additions and 346 deletions

View File

@ -16,6 +16,6 @@ namespace mlir {
namespace NPCOMP {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCFToTCPPass();
}
}
} // namespace mlir
#endif // NPCOMP_CONVERSION_TCFTOTCP_CONVERTTCFTOTCP_H

View File

@ -17,6 +17,6 @@ namespace mlir {
namespace NPCOMP {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCPToLinalgPass();
}
}
} // namespace mlir
#endif // NPCOMP_CONVERSION_TCPTOLINALG_CONVERTTCPTOLINALG_H

View File

@ -20,7 +20,7 @@ namespace npcomp_rt {
#define GET_OP_CLASSES
#include "npcomp/Dialect/NpcompRt/IR/NpcompRtOps.h.inc"
} // namespace tcf
} // namespace npcomp_rt
} // namespace NPCOMP
} // namespace mlir

View File

@ -25,8 +25,7 @@ std::unique_ptr<OperationPass<FuncOp>> createLowerBroadcastToToLoopsPass();
std::unique_ptr<OperationPass<FuncOp>>
createLowerLinalgOnTensorToLinalgOnMemrefPass();
std::unique_ptr<OperationPass<FuncOp>>
createResolveShapeOfOpsPass();
std::unique_ptr<OperationPass<FuncOp>> createResolveShapeOfOpsPass();
std::unique_ptr<OperationPass<FuncOp>> createResolveTensorLoadStoreOpsPass();

View File

@ -52,7 +52,7 @@ public:
return success();
}
};
}
} // namespace
namespace {
class ConvertTCFToTCP : public ConvertTCFToTCPBase<ConvertTCFToTCP> {

View File

@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/NpcompRt/IR/NpcompRtDialect.h"
#include "npcomp/Dialect/NpcompRt/IR/NpcompRtOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "npcomp/Dialect/NpcompRt/IR/NpcompRtOps.h"
using namespace mlir;
using namespace mlir::NPCOMP::npcomp_rt;
@ -44,4 +44,3 @@ void NpcompRtDialect::printType(Type type, DialectAsmPrinter &os) const {
llvm_unreachable("unexpected 'npcomp_rt' type kind");
}
}

View File

@ -18,6 +18,6 @@ namespace NPCOMP {
namespace npcomp_rt {
#define GET_OP_CLASSES
#include "npcomp/Dialect/NpcompRt/IR/NpcompRtOps.cpp.inc"
} // namespace tcf
} // namespace npcomp_rt
} // namespace NPCOMP
} // namespace mlir

View File

@ -29,10 +29,11 @@ LogicalResult ShapeObserveErrorOp::inferReturnTypes(
// GetExtentOp
//===----------------------------------------------------------------------===//
LogicalResult GetExtentOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
LogicalResult
GetExtentOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
ValueRange operands, DictionaryAttr attributes,
RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(IndexType::get(context));
return success();
}

View File

@ -65,21 +65,22 @@ using namespace mlir::NPCOMP;
//===----------------------------------------------------------------------===//
namespace {
class ResolveShapeOfOpViaAllocMemRefOp : public OpRewritePattern<shape::ShapeOfOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
PatternRewriter &rewriter) const override {
if (auto tensorLoad = llvm::dyn_cast_or_null<TensorLoadOp>(
op.getOperand().getDefiningOp())) {
if (auto allocMemRef = llvm::dyn_cast_or_null<tcp::AllocMemRefOp>(
tensorLoad.getOperand().getDefiningOp())) {
rewriter.replaceOp(op, allocMemRef.getOperand());
return success();
}
class ResolveShapeOfOpViaAllocMemRefOp
: public OpRewritePattern<shape::ShapeOfOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
PatternRewriter &rewriter) const override {
if (auto tensorLoad = llvm::dyn_cast_or_null<TensorLoadOp>(
op.getOperand().getDefiningOp())) {
if (auto allocMemRef = llvm::dyn_cast_or_null<tcp::AllocMemRefOp>(
tensorLoad.getOperand().getDefiningOp())) {
rewriter.replaceOp(op, allocMemRef.getOperand());
return success();
}
return failure();
}
return failure();
}
};
} // namespace
@ -92,7 +93,7 @@ class ResolveShapeOfOps : public ResolveShapeOfOpsBase<ResolveShapeOfOps> {
OwningRewritePatternList patterns;
patterns.insert<ResolveShapeOfOpViaAllocMemRefOp>(context);
ConversionTarget target(*context);
//target.addIllegalOp<shape::ShapeOfOp>();
// target.addIllegalOp<shape::ShapeOfOp>();
target.addDynamicallyLegalOp<shape::ShapeOfOp>(
[](shape::ShapeOfOp shapeOf) {
// Only shape.shape_of on arguments to the entry block are legal at

View File

@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
#include "npcomp/E2E/E2E.h"
#include "PassDetail.h"
#include "npcomp/E2E/E2E.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
@ -66,7 +66,6 @@ public:
Value outputExtent = rewriter.create<tcp::GetExtentOp>(
op.getLoc(), op.shape(), rewriter.getI64IntegerAttr(i));
outputExtents.push_back(outputExtent);
}
int rankDiff = resultType.getRank() - inputType.getRank();
for (int i = 0, e = inputType.getRank(); i < e; i++) {
@ -108,7 +107,8 @@ public:
}
Value load =
rewriter.create<LoadOp>(op.getLoc(), inputMemref, inputIndices);
rewriter.create<StoreOp>(op.getLoc(), load, resultMemref, inductionVariables);
rewriter.create<StoreOp>(op.getLoc(), load, resultMemref,
inductionVariables);
}
rewriter.replaceOpWithNewOp<TensorLoadOp>(op, resultMemref);
@ -173,91 +173,87 @@ mlir::NPCOMP::createLowerBroadcastToToLoopsPass() {
//===----------------------------------------------------------------------===//
namespace {
class LowerLinalgGenericTensorToMemRef : public OpRewritePattern<linalg::GenericOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::GenericOp op,
PatternRewriter &rewriter) const override {
class LowerLinalgGenericTensorToMemRef
: public OpRewritePattern<linalg::GenericOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::GenericOp op,
PatternRewriter &rewriter) const override {
// TODO: Replace this with more generic code operating on named
// structured ops too.
// TODO: Replace this with more generic code operating on named
// structured ops too.
// Only handle generic ops where all operands and results are tensors.
if (!llvm::all_of(op.getOperandTypes(), [](Type type) {
return type.isa<RankedTensorType>();
})) {
return rewriter.notifyMatchFailure(op, "all operands must be tensors");
}
if (!llvm::all_of(op.getResultTypes(), [](Type type) {
return type.isa<RankedTensorType>();
})) {
return rewriter.notifyMatchFailure(op, "all results must be tensors");
}
// TODO: Loosen restrictions on indexing maps.
// This will require more principled handling of shape reification
// earlier in the compilation stack, as in general output shapes of a
// linalg.generic cannot be inferred easily.
// See:
// https://llvm.discourse.group/t/computing-output-shapes-of-structured-ops-on-tensors/866
if (!llvm::all_of(op.indexing_maps(), [](Attribute map) {
return map.cast<AffineMapAttr>().getValue().isIdentity();
})) {
return rewriter.notifyMatchFailure(
op, "all indexing maps must be identity maps");
}
if (!llvm::all_of(op.iterator_types(), [](Attribute str) {
return str.cast<StringAttr>().getValue() ==
getParallelIteratorTypeName();
})) {
return rewriter.notifyMatchFailure(
op, "all iterator types must be 'parallel'");
}
SmallVector<Value, 6> memrefs;
SmallVector<Value, 6> resultMemrefs;
SmallVector<Value, 6> operandShapes;
for (auto tensor : op.getOperands()) {
auto shape = rewriter.create<shape::ShapeOfOp>(op.getLoc(), tensor);
auto memref =
allocMemRefForTensor(rewriter, tensor, shape, op.getLoc());
rewriter.create<TensorStoreOp>(op.getLoc(), tensor, memref);
memrefs.push_back(memref);
operandShapes.push_back(shape);
}
auto shapeType = shape::ShapeType::get(rewriter.getContext());
SmallVector<Type, 6> shapeTypes(op.getNumResults(), shapeType);
// TODO: We need more principled handling of output shapes.
// This assumes that all results have the same shape, which is justified
// by checks above, but we really need a better story here.
SmallVector<Value, 6> resultShapes(op.getNumResults(), operandShapes[0]);
for (auto t : llvm::zip(op.getResults(), resultShapes)) {
auto tensor = std::get<0>(t);
auto shape = std::get<1>(t);
auto memref =
allocMemRefForTensor(rewriter, tensor, shape, op.getLoc());
memrefs.push_back(memref);
resultMemrefs.push_back(memref);
}
auto newGeneric = rewriter.create<linalg::GenericOp>(
op.getLoc(), llvm::None, ValueRange(memrefs), op.getAttrs());
newGeneric.region().getBlocks().clear();
BlockAndValueMapping mapper;
op.region().cloneInto(&newGeneric.region(), mapper);
for (auto memref : resultMemrefs) {
newGeneric.region().front().addArgument(
memref.getType().cast<MemRefType>().getElementType());
}
auto newResultTensors =
llvm::to_vector<6>(llvm::map_range(resultMemrefs, [&](Value memref) {
return rewriter.create<TensorLoadOp>(op.getLoc(), memref)
.getResult();
}));
rewriter.replaceOp(op, newResultTensors);
return success();
// Only handle generic ops where all operands and results are tensors.
if (!llvm::all_of(op.getOperandTypes(),
[](Type type) { return type.isa<RankedTensorType>(); })) {
return rewriter.notifyMatchFailure(op, "all operands must be tensors");
}
if (!llvm::all_of(op.getResultTypes(),
[](Type type) { return type.isa<RankedTensorType>(); })) {
return rewriter.notifyMatchFailure(op, "all results must be tensors");
}
// TODO: Loosen restrictions on indexing maps.
// This will require more principled handling of shape reification
// earlier in the compilation stack, as in general output shapes of a
// linalg.generic cannot be inferred easily.
// See:
// https://llvm.discourse.group/t/computing-output-shapes-of-structured-ops-on-tensors/866
if (!llvm::all_of(op.indexing_maps(), [](Attribute map) {
return map.cast<AffineMapAttr>().getValue().isIdentity();
})) {
return rewriter.notifyMatchFailure(
op, "all indexing maps must be identity maps");
}
if (!llvm::all_of(op.iterator_types(), [](Attribute str) {
return str.cast<StringAttr>().getValue() ==
getParallelIteratorTypeName();
})) {
return rewriter.notifyMatchFailure(
op, "all iterator types must be 'parallel'");
}
SmallVector<Value, 6> memrefs;
SmallVector<Value, 6> resultMemrefs;
SmallVector<Value, 6> operandShapes;
for (auto tensor : op.getOperands()) {
auto shape = rewriter.create<shape::ShapeOfOp>(op.getLoc(), tensor);
auto memref = allocMemRefForTensor(rewriter, tensor, shape, op.getLoc());
rewriter.create<TensorStoreOp>(op.getLoc(), tensor, memref);
memrefs.push_back(memref);
operandShapes.push_back(shape);
}
auto shapeType = shape::ShapeType::get(rewriter.getContext());
SmallVector<Type, 6> shapeTypes(op.getNumResults(), shapeType);
// TODO: We need more principled handling of output shapes.
// This assumes that all results have the same shape, which is justified
// by checks above, but we really need a better story here.
SmallVector<Value, 6> resultShapes(op.getNumResults(), operandShapes[0]);
for (auto t : llvm::zip(op.getResults(), resultShapes)) {
auto tensor = std::get<0>(t);
auto shape = std::get<1>(t);
auto memref = allocMemRefForTensor(rewriter, tensor, shape, op.getLoc());
memrefs.push_back(memref);
resultMemrefs.push_back(memref);
}
auto newGeneric = rewriter.create<linalg::GenericOp>(
op.getLoc(), llvm::None, ValueRange(memrefs), op.getAttrs());
newGeneric.region().getBlocks().clear();
BlockAndValueMapping mapper;
op.region().cloneInto(&newGeneric.region(), mapper);
for (auto memref : resultMemrefs) {
newGeneric.region().front().addArgument(
memref.getType().cast<MemRefType>().getElementType());
}
auto newResultTensors =
llvm::to_vector<6>(llvm::map_range(resultMemrefs, [&](Value memref) {
return rewriter.create<TensorLoadOp>(op.getLoc(), memref).getResult();
}));
rewriter.replaceOp(op, newResultTensors);
return success();
}
};
}
} // namespace
namespace {
class LowerLinalgOnTensorToLinalgOnMemref

View File

@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
#include "npcomp/E2E/E2E.h"
#include "PassDetail.h"
#include "npcomp/E2E/E2E.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"

View File

@ -22,66 +22,77 @@ def add():
# CHECK: {{.*}} = basicpy.binary_expr %[[A]] "Add" %[[B]] : (i64, i64) -> !basicpy.UnknownType
return a + b
# CHECK-LABEL: func @sub
@import_global
def sub():
# CHECK: basicpy.binary_expr {{.*}} "Sub"
return 4 - 2
# CHECK-LABEL: func @mult
@import_global
def mult():
# CHECK: basicpy.binary_expr {{.*}} "Mult"
return 4 * 2
# CHECK-LABEL: func @div
@import_global
def div():
# CHECK: basicpy.binary_expr {{.*}} "Div"
return 4 / 2
# CHECK-LABEL: func @floor_div
@import_global
def floor_div():
# CHECK: basicpy.binary_expr {{.*}} "FloorDiv"
return 4 // 2
# CHECK-LABEL: func @matmul
@import_global
def matmul():
# CHECK: basicpy.binary_expr {{.*}} "MatMult"
return 4 @ 2
# CHECK-LABEL: func @modulo
@import_global
def modulo():
# CHECK: basicpy.binary_expr {{.*}} "Mod"
return 4 % 2
# CHECK-LABEL: func @left_shift
@import_global
def left_shift():
# CHECK: basicpy.binary_expr {{.*}} "LShift"
return 4 << 2
# CHECK-LABEL: func @right_shift
@import_global
def right_shift():
# CHECK: basicpy.binary_expr {{.*}} "RShift"
return 4 >> 2
# CHECK-LABEL: func @bit_and
@import_global
def bit_and():
# CHECK: basicpy.binary_expr {{.*}} "BitAnd"
return 4 & 2
# CHECK-LABEL: func @bit_xor
@import_global
def bit_xor():
# CHECK: basicpy.binary_expr {{.*}} "BitXor"
return 4 ^ 2
# CHECK-LABEL: func @bit_or
@import_global
def bit_or():

View File

@ -38,6 +38,7 @@ def logical_and():
# CHECK: }
return x and y and z
# CHECK-LABEL: func @logical_or
@import_global
def logical_or():
@ -65,17 +66,19 @@ def logical_or():
z = 2
return x or y or z
# CHECK-LABEL: func @logical_not
@import_global
def logical_not():
# CHECK: %[[X:.*]] = constant 1
x = 1
# CHECK-DAG: %[[TRUE:.*]] = basicpy.bool_constant true
# CHECK-DAG: %[[TRUE:.*]] = basicpy.bool_constant true
# CHECK-DAG: %[[FALSE:.*]] = basicpy.bool_constant false
# CHECK-DAG: %[[CONDITION:.*]] = basicpy.to_boolean %[[X]]
# CHECK-DAG: %{{.*}} = select %[[CONDITION]], %[[FALSE]], %[[TRUE]] : !basicpy.BoolType
return not x
# CHECK-LABEL: func @conditional
@import_global
def conditional():

View File

@ -21,6 +21,7 @@ def binary_lt_():
# CHECK: {{.*}} = basicpy.binary_compare %[[A]] "Lt" %[[B]] : i64, i64
return x < y
# CHECK-LABEL: func @binary_gt_
@import_global
def binary_gt_():
@ -29,6 +30,7 @@ def binary_gt_():
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "Gt" {{.*}} : i64, i64
return x > y
# CHECK-LABEL: func @binary_lte_
@import_global
def binary_lte_():
@ -37,6 +39,7 @@ def binary_lte_():
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "LtE" {{.*}} : i64, i64
return x <= y
# CHECK-LABEL: func @binary_gte_
@import_global
def binary_gte_():
@ -45,6 +48,7 @@ def binary_gte_():
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "GtE" {{.*}} : i64, i64
return x >= y
# CHECK-LABEL: func @binary_eq_
@import_global
def binary_eq_():
@ -53,6 +57,7 @@ def binary_eq_():
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "Eq" {{.*}} : i64, i64
return x == y
# CHECK-LABEL: func @binary_neq_
@import_global
def binary_neq_():
@ -61,6 +66,7 @@ def binary_neq_():
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "NotEq" {{.*}} : i64, i64
return x != y
# CHECK-LABEL: func @binary_is_
@import_global
def binary_is_():
@ -69,6 +75,7 @@ def binary_is_():
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "Is" {{.*}} : i64, i64
return x is y
# CHECK-LABEL: func @binary_is_not_
@import_global
def binary_is_not_():
@ -77,6 +84,7 @@ def binary_is_not_():
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "IsNot" {{.*}} : i64, i64
return x is not y
# CHECK-LABEL: func @binary_in_
@import_global
def binary_in_():
@ -85,6 +93,7 @@ def binary_in_():
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "In" {{.*}} : i64, i64
return x in y
# CHECK-LABEL: func @binary_not_in_
@import_global
def binary_not_in_():
@ -93,6 +102,7 @@ def binary_not_in_():
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "NotIn" {{.*}} : i64, i64
return x not in y
@import_global
def short_circuit():
# CHECK: %[[X:.*]] = constant 1 : i64
@ -123,6 +133,7 @@ def short_circuit():
# CHECK: return %[[RESULT]]
return x < y == z >= omega
# CHECK-LABEL: nested_short_circuit_expression
@import_global
def nested_short_circuit_expression():

View File

@ -20,6 +20,7 @@ def integer_constants():
# CHECK: return %[[A_CAST]]
return a
# CHECK-LABEL: func @float_constants
@import_global
def float_constants():
@ -29,6 +30,7 @@ def float_constants():
# CHECK: return %[[A_CAST]]
return a
# CHECK-LABEL: func @bool_true_constant
@import_global
def bool_true_constant():
@ -37,6 +39,7 @@ def bool_true_constant():
a = True
return a
# CHECK-LABEL: func @bool_false_constant
@import_global
def bool_false_constant():
@ -45,6 +48,7 @@ def bool_false_constant():
a = False
return a
# CHECK-LABEL: func @string_constant
@import_global
def string_constant():
@ -53,6 +57,7 @@ def string_constant():
a = "foobar"
return a
# CHECK-LABEL: func @joined_string_constant
@import_global
def joined_string_constant():
@ -61,6 +66,7 @@ def joined_string_constant():
a = "I am" " still here"
return a
# CHECK-LABEL: func @bytes_constant
@import_global
def bytes_constant():
@ -69,6 +75,7 @@ def bytes_constant():
a = b"foobar"
return a
# CHECK-LABEL: func @ellipsis
@import_global
def ellipsis():
@ -77,6 +84,7 @@ def ellipsis():
a = ...
return a
# CHECK-LABEL: func @none_constant
@import_global
def none_constant():

View File

@ -10,6 +10,7 @@ def import_global(f):
print(fe.ir_module.to_asm())
return f
# CHECK-LABEL: func @positional_args
# CHECK-SAME: (%arg0: !basicpy.UnknownType, %arg1: !basicpy.UnknownType) -> !basicpy.UnknownType
@import_global
@ -17,6 +18,7 @@ def positional_args(a, b):
# CHECK: basicpy.binary_expr %arg0 "Add" %arg1
return a + b
# CHECK-LABEL: func @pass_no_return
@import_global
def pass_no_return():

View File

@ -10,6 +10,7 @@ def import_global(f):
print("// -----")
return f
# CHECK-LABEL: func @arithmetic_expression
# CHECK-SAME: () -> i64
@import_global
@ -21,6 +22,7 @@ def arithmetic_expression():
# CHECK: return{{.*}} : i64
return 1 + 2 - 3 * 4
# CHECK-LABEL: func @arg_inference
# CHECK-SAME: (%arg0: i64, %arg1: i64) -> i64
@import_global
@ -31,9 +33,10 @@ def arg_inference(a, b):
# CHECK: return{{.*}} : i64
return a + 2 * b
# CHECK-LABEL: func @conditional_inference
# CHECK-SAME: (%arg0: i64, %arg1: !basicpy.BoolType, %arg2: i64) -> !basicpy.BoolType
@import_global
def conditional_inference(cond, a, b):
# CHECK-NOT: UnknownType
return a if cond + 1 else not(b * 4)
return a if cond + 1 else not (b * 4)

View File

@ -39,15 +39,13 @@ llvm_config.with_system_environment(
llvm_config.use_default_substitutions()
# excludes: A list of files/directories to exclude from the testsuite. The
# 'Inputs'subdirectories contain auxiliary inputs for various tests in their
# excludes: A list of files/directories to exclude from the testsuite. The
# 'Inputs'subdirectories contain auxiliary inputs for various tests in their
# parent directories.
config.excludes = [
'Inputs', 'Examples',
'lit.cfg.py',
'CMakeLists.txt',
'README.txt',
'LICENSE.txt']
'Inputs', 'Examples', 'lit.cfg.py', 'CMakeLists.txt', 'README.txt',
'LICENSE.txt'
]
# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
@ -56,22 +54,21 @@ config.test_source_root = os.path.dirname(__file__)
config.test_exec_root = os.path.join(config.npcomp_obj_root, 'pytest')
config.npcomp_tools_dir = os.path.join(config.npcomp_obj_root, 'tools')
config.npcomp_runtime_shlib = os.path.join(
config.npcomp_obj_root,
'runtime',
'libNPCOMPRuntime' + config.llvm_shlib_ext
)
config.npcomp_obj_root, 'runtime',
'libNPCOMPRuntime' + config.llvm_shlib_ext)
# Tweak the PATH and PYTHONPATH to include the tools dir.
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
llvm_config.with_environment('PYTHONPATH', [
os.path.join(config.npcomp_obj_root, "python"),
os.path.join(config.npcomp_obj_root, "python_native")],
append_path=True)
os.path.join(config.npcomp_obj_root, "python_native")
],
append_path=True)
tool_dirs = [
os.path.join(config.npcomp_tools_dir, 'npcomp-opt'),
os.path.join(config.npcomp_tools_dir, 'npcomp-run-mlir'),
config.llvm_tools_dir,
os.path.join(config.npcomp_tools_dir, 'npcomp-opt'),
os.path.join(config.npcomp_tools_dir, 'npcomp-run-mlir'),
config.llvm_tools_dir,
]
tools = [
'npcomp-opt',

View File

@ -1,4 +1,3 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

View File

@ -1,4 +1,3 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

View File

@ -65,6 +65,7 @@ class DialectHelper(Basicpy.DialectHelper):
tensor<*x!numpy.any_dtype>
"""
@property
def numpy_any_dtype(self):
return self.context.parse_type("!numpy.any_dtype")

View File

@ -9,9 +9,9 @@ from typing import Optional
from npcomp.types import *
__all__ = [
"Exporter",
"ExportFunction",
"ExportPyFunction",
"Exporter",
"ExportFunction",
"ExportPyFunction",
]
@ -21,40 +21,41 @@ def _value_type_from_annotation(annotation):
return ValueType(TypeClass.NdArray)
else:
return ValueType()
def _signature_from_pyfunc(pyfunc):
pysig = inspect.signature(pyfunc)
sig = Signature(len(pysig.parameters))
# Arguments
for i, param in enumerate(pysig.parameters.values()):
if param.kind not in (
param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
if param.kind not in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
raise ValueError(
"Currently only positional function signature are supported")
"Currently only positional function signature are supported")
sig.arg_names[i] = param.name
annot = param.annotation
if annot is param.empty: continue
if annot is param.empty:
continue
sig.args[i] = _value_type_from_annotation(annot)
# Result
if pysig.return_annotation is not pysig.empty:
sig.result = _value_type_from_annotation(pysig.return_annotation)
return sig
class ExportFunction:
"""Base class for functions that can be exported."""
__slots__ = ["_sig"]
def __init__(self, sig=None):
self._sig = sig if sig else Signature()
@property
def sig(self):
return self._sig
def __repr__(self):
return "def %r" % self._sig
@ -86,21 +87,21 @@ class ExportPyFunction(ExportFunction):
pydef mul(a: NdArray[Rank(2)], b: Any) -> NdArray[Shape(1, 2)]
"""
__slots__ = ExportFunction.__slots__ + ["_pyfunc", "__name__"]
def __init__(self, pyfunc, name=None):
super().__init__(sig=_signature_from_pyfunc(pyfunc))
assert (hasattr(pyfunc, "__call__")
and hasattr(pyfunc, "__name__")), "Not a python function"
assert (hasattr(pyfunc, "__call__") and
hasattr(pyfunc, "__name__")), "Not a python function"
self._pyfunc = pyfunc
self.__name__ = name if name else pyfunc.__name__
@property
def pyfunc(self):
return self._pyfunc
def __repr__(self):
return "pydef %s%r" % (self.__name__, self._sig)
def __call__(self, *args, **kwargs):
return self._pyfunc(*args, **kwargs)
@ -108,53 +109,59 @@ class ExportPyFunction(ExportFunction):
class _ExpandoNode:
"""Expando object that can be indexed into to construct a namespace."""
__slots__ = [
"_parent", "_services", "_local_name", "_parent_name",
"_children", "_attached"]
def __init__(self, parent: Optional["_ExpandoNode"],
services: "_Services",
local_name: str):
"_parent", "_services", "_local_name", "_parent_name", "_children",
"_attached"
]
def __init__(self, parent: Optional["_ExpandoNode"], services: "_Services",
local_name: str):
super().__init__()
object.__setattr__(self, "_parent", parent)
object.__setattr__(self, "_services", services)
object.__setattr__(self, "_local_name", local_name)
object.__setattr__(self, "_parent_name",
object.__setattr__(self, "_parent_name",
parent._get_full_name() if parent else "")
object.__setattr__(self, "_children", {})
object.__setattr__(self, "_attached", parent is None)
def _attach(self):
if self._attached: return
if self._attached:
return
if self._local_name in self._parent._children:
raise KeyError("Cannot re-assign '%s'" % (self._get_full_name(),))
self._parent._attach()
self._parent._children[self._local_name] = self
object.__setattr__(self, "_attached", True)
def _get_full_name(self):
if not self._parent: return "" # Root is always empty name.
full_name = (self._parent_name + "." + self._local_name
if self._parent_name else self._local_name)
if not self._parent:
return "" # Root is always empty name.
full_name = (self._parent_name + "." +
self._local_name if self._parent_name else self._local_name)
return full_name
def _get_child_name(self, child_local_name):
full_name = self._get_full_name()
if not full_name: return child_local_name
else: return full_name + "." + child_local_name
if not full_name:
return child_local_name
else:
return full_name + "." + child_local_name
def __repr__(self):
return "Namespace(\"%s\")" % (self._get_full_name())
def __contains__(self, key):
return key in self._children
def __getitem__(self, key):
key = str(key)
existing = self._children.get(key)
if existing is not None: return existing
if existing is not None:
return existing
# Speculatively create a child expando.
child = _ExpandoNode(self, self._services, key)
return child
def __setitem__(self, key, value):
if not inspect.isfunction(value):
raise TypeError("Cannot assign value to an exporter: %r" % (value,))
@ -164,19 +171,19 @@ class _ExpandoNode:
raise KeyError("Cannot re-assign '%s'" % (child_name))
self._attach()
self._children[key] = self._services.wrap_function(value, child_name)
def __getattr__(self, name):
return self[name]
def __setattr__(self, name, value):
try:
self[name] = value
except KeyError as e:
raise AttributeError(str(e)) from None
def __dir__(self):
return self._children.keys()
class _Services:
"""Services and support for the Exporter.
@ -184,12 +191,14 @@ class _Services:
Exporters are user objects, so most of the functional components are
contained in the associated _Services object.
"""
def wrap_function(self, f, full_name):
if isinstance(f, ExportFunction): return f
if isinstance(f, ExportFunction):
return f
# TODO: Need to scan through providers and choose.
return ExportPyFunction(f, name=full_name)
class Exporter:
"""Top-level UI object for assembling a program for export.
@ -244,31 +253,32 @@ class Exporter:
AttributeError: "Cannot re-assign 'ns1'"
"""
__slots__ = ["_root", "_services"]
def __init__(self):
super().__init__()
services = _Services()
object.__setattr__(self, "_root", _ExpandoNode(None, services, ""))
object.__setattr__(self, "_services", services)
def __repr__(self):
return "Exporter()"
def __contains__(self, key):
return key in self._root
def __getitem__(self, key):
return self._root[key]
def __setitem__(self, key, value):
self._root[key] = value
def __getattr__(self, name):
return getattr(self._root, name)
def __setattr__(self, name, value):
setattr(self._root, name, value)
if __name__ == "__main__":
import doctest
doctest.testmod()
import doctest
doctest.testmod()

View File

@ -1,7 +1,6 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Test for the MLIR IR Python bindings"""
from _npcomp.mlir import ir
@ -39,5 +38,4 @@ except ValueError as e:
# CHECK: [ERROR]: expected operation name in quotes
print(e)
test_utils.end_filecheck_test(__file__)

View File

@ -1,7 +1,6 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Test for the MLIR Pass Python bindings"""
from _npcomp.mlir import ir

View File

@ -48,10 +48,11 @@ class TraceContext:
"""
_local = threading.local()
__slots__ = [
"_desc",
"_next_id",
"active",
"_desc",
"_next_id",
"active",
]
def __init__(self, desc=None):
_check_numpy_version()
self._desc = desc
@ -87,11 +88,11 @@ class TraceContext:
@classmethod
def optional_current(cls) -> Optional["TraceContext"]:
s = cls._get_context_stack()
if s:
if s:
return s[-1]
else:
return None
@classmethod
def current(cls) -> "TraceContext":
c = cls.optional_current()
@ -120,7 +121,7 @@ class TraceContext:
def _assert_active(tc: TraceContext):
assert tc.active, (
"Attempt to trace an action on an inactive trace context: %r" % tc)
"Attempt to trace an action on an inactive trace context: %r" % tc)
class TracedArray(np.lib.mixins.NDArrayOperatorsMixin):
@ -134,6 +135,7 @@ class TracedArray(np.lib.mixins.NDArrayOperatorsMixin):
>>> TracedArray(tc=tc)
<TracedArray 2>
"""
def __init__(self, tc: Optional[TraceContext] = None):
self._tc = tc if tc is not None else TraceContext.current()
self._uid = self._tc.get_next_id()
@ -160,28 +162,28 @@ class TracedArray(np.lib.mixins.NDArrayOperatorsMixin):
def __array_function__(self, func, types, args, kwargs):
tc = self._tc
_assert_active(tc)
_assert_active(tc)
return tc._handle_array_func(func, types, args, kwargs)
@property
def T(self):
"""Shortcut for transpose."""
tc = self._tc
_assert_active(tc)
_assert_active(tc)
return tc._handle_array_func(np.transpose, [TracedArray], [self], {})
def _check_numpy_version():
version = np.lib.NumpyVersion(np.__version__)
if version < "1.16.0":
raise RuntimeError("Numpy version >= 1.16 is required")
if version > "1.17.0": return
if version > "1.17.0":
return
if os.environ.get("NUMPY_EXPERIMENTAL_ARRAY_FUNCTION") != "1":
raise RuntimeError(
"For numpy 1.16, the environment variable "
"NUMPY_EXPERIMENTAL_ARRAY_FUNCTION must equal 1")
raise RuntimeError("For numpy 1.16, the environment variable "
"NUMPY_EXPERIMENTAL_ARRAY_FUNCTION must equal 1")
if __name__ == "__main__":
import doctest
doctest.testmod()
import doctest
doctest.testmod()

View File

@ -9,9 +9,11 @@ from npcomp.utils import test_utils
test_utils.start_filecheck_test()
def simple_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
return a * b + a + b
# TODO: Implement subclassing and deriving constraints by run
exp = Exporter()
exp.simple_mul = simple_mul
@ -22,7 +24,7 @@ exp.simple_mul.sig.args["b"] += Shape(1)
exp.simple_mul.sig.args["b"] += DType(np.float32)
exp.simple_mul.sig.result += Shape(1, 4)
exp.simple_mul.sig.result += DynamicDim(0)
exp.simple_mul.sig.result += DType(np.float32)
exp.simple_mul.sig.result += DType(np.float32)
mb = ModuleBuilder()
mb.trace(exp.simple_mul)

View File

@ -8,30 +8,30 @@ from enum import Enum
import numpy as np
__all__ = [
"Unspec",
"ArrayConstraint",
"ArrayParams",
"DType",
"DimFlag",
"DimFlagEnum",
"DynamicDim",
"Rank",
"Shape",
"Signature",
"TypeClass",
"TypeConstraints",
"ValueType",
"Unspec",
"ArrayConstraint",
"ArrayParams",
"DType",
"DimFlag",
"DimFlagEnum",
"DynamicDim",
"Rank",
"Shape",
"Signature",
"TypeClass",
"TypeConstraints",
"ValueType",
]
# TODO: All supported types
_DTYPE_TO_ASM_DICT = {
np.bool: "i1", # TODO: May need a custom type to signify 8bit storage
np.int8: "s8",
np.int16: "s16",
np.int32: "s32",
np.int64: "s64",
np.float32: "f32",
np.float64: "f64",
np.bool: "i1", # TODO: May need a custom type to signify 8bit storage
np.int8: "s8",
np.int16: "s16",
np.int32: "s32",
np.int64: "s64",
np.float32: "f32",
np.float64: "f64",
}
@ -67,29 +67,39 @@ class _LiterateEnum(Enum):
ValueError: Cannot parse SampleEnum 1.0
"""
@classmethod
def parse(cls, v):
if isinstance(v, cls): return v
if isinstance(v, cls):
return v
if not v or not isinstance(v, str) or v[0] == '_' or not hasattr(cls, v):
raise ValueError("Cannot parse %s %r" % (
cls.__name__.split(".")[-1], v,))
cls.__name__.split(".")[-1],
v,
))
value = getattr(cls, v)
if not isinstance(value, cls):
raise ValueError("Cannot parse %s %r" % (
cls.__name__.split(".")[-1], v,))
cls.__name__.split(".")[-1],
v,
))
return value
def __repr__(self):
return self.name
# Special "unspecified" value that we use throughout.
class _Unspec:
__slots__ = []
def __str__(self):
return "Unspec"
def __repr__(self):
return "Unspec"
Unspec = _Unspec()
@ -97,8 +107,8 @@ class TypeClass(_LiterateEnum):
"""Top level types in the npcomp language."""
Any = 0
NdArray = 1
class ValueType:
"""The type a value can take in the npcomp language.
@ -123,32 +133,32 @@ class ValueType:
NdArray
"""
__slots__ = ["_constraints", "_type_class"]
def __init__(self, type_class=TypeClass.Any, *constraints):
super().__init__()
self._type_class = TypeClass.parse(type_class)
self._constraints = TypeConstraints(constraints)
def __iadd__(self, constraint):
assert isinstance(constraint, TypeConstraint), (
"Can only add constraints to a ValueType")
assert isinstance(
constraint, TypeConstraint), ("Can only add constraints to a ValueType")
self._constraints.append(constraint)
return self
def __repr__(self):
if not self._constraints:
return repr(self._type_class)
return "%r[%s]" % (self._type_class,
", ".join([repr(c) for c in self._constraints]))
return "%r[%s]" % (self._type_class, ", ".join(
[repr(c) for c in self._constraints]))
@property
def type_class(self):
return self._type_class
@type_class.setter
def type_class(self, type_class):
self._type_class = TypeClass.parse(type_class)
@property
def constraints(self):
return self._constraints
@ -179,6 +189,7 @@ class ValueTypeList:
(Any, Any, Any)
"""
__slots__ = ["_list", "_names"]
def __init__(self, arity=0, names=None):
self._list = [ValueType() for _ in range(arity)]
self._names = names
@ -188,18 +199,19 @@ class ValueTypeList:
# Scan for the index.
if self._names:
for i, n in enumerate(self._names):
if n == key: return i
if n == key:
return i
raise KeyError("Unknown key '%s'" % key)
return key
def __getitem__(self, key):
return self._list[self._key_to_index(key)]
def __setitem__(self, key, value):
if not isinstance(value, ValueType):
value = ValueType(value)
self._list[self._key_to_index(key)] = value
def __iter__(self):
return self._list.__iter__()
@ -233,20 +245,21 @@ class Signature:
(a: Any, b: NdArray[Rank(2)]) -> NdArray[Rank(3)]
"""
__slots__ = ["_args", "_arg_names", "_result"]
def __init__(self, arity=0):
super().__init__()
self._result = ValueType()
self._arg_names = [None] * arity
self._args = ValueTypeList(arity, names=self._arg_names)
@property
def args(self):
return self._args
@property
def arg_names(self):
return self._arg_names
@property
def result(self):
return self._result
@ -256,14 +269,14 @@ class Signature:
if not isinstance(value, ValueType):
value = ValueType(value)
self._result = value
def __repr__(self):
args_repr = "(%s)" % (
", ".join(
args_repr = "(%s)" % (", ".join(
((n + ": " + repr(t)) if n else repr(t))
for t, n in zip(self._args, self._arg_names)),)
return "%s -> %r" % (args_repr, self._result)
class ArrayParams:
"""Represents parameters defining how to construct an array.
@ -277,7 +290,7 @@ class ArrayParams:
ArrayParams(dtype=float32, shape=(1, 2, 3))
"""
__slots__ = ["dtype", "shape"]
def __init__(self, dtype=Unspec, shape=Unspec, rank=Unspec):
self.dtype = dtype
if shape is not Unspec:
@ -289,9 +302,10 @@ class ArrayParams:
@property
def rank(self):
if self.shape is Unspec: return Unspec
if self.shape is Unspec:
return Unspec
return len(self.shape)
@classmethod
def from_constraints(cls, constraints):
"""Constructs params for a TypeConstraints list.
@ -341,42 +355,40 @@ class ArrayParams:
shape_c = constraints.one_of(Shape)
rank_c = constraints.one_of(Rank)
dim_flags = constraints.all_of(DimFlag)
dtype = dtype_c.dtype if dtype_c else Unspec
shape = Unspec
# Compute shape
if shape_c:
# TODO: Should be in canonicalizer
if rank_c and rank_c.rank != len(shape_c.dims):
raise ValueError("Conflicting shape and rank: %r vs %r" % (
rank_c, shape_c))
raise ValueError("Conflicting shape and rank: %r vs %r" %
(rank_c, shape_c))
shape = list(shape_c.dims)
elif rank_c:
shape = [-1 for _ in range(rank_c.rank)]
# Apply dim flags
if shape is not Unspec and dim_flags:
for df in dim_flags:
flag, for_dims = df.dim_flag
for d in for_dims:
if d < 0 or d >= len(shape):
raise ValueError("Out of range %r for shape %r" % (
df, shape))
raise ValueError("Out of range %r for shape %r" % (df, shape))
if flag == DimFlagEnum.Dynamic:
shape[d] = -1
return cls(dtype=dtype, shape=shape)
def __repr__(self):
try:
s = "ArrayParams(dtype=%s" % (
self.dtype.__name__ if isinstance(self.dtype, type) else self.dtype,)
s = "ArrayParams(dtype=%s" % (self.dtype.__name__ if isinstance(
self.dtype, type) else self.dtype,)
if self.shape is not Unspec:
s += ", shape=%r" % (tuple(self.shape),)
s += ")"
return s
return s
except:
return "ArrayParams(ERROR)"
@ -400,7 +412,7 @@ class ArrayParams:
if any(d < 0 for d in self.shape):
return False
return True
@property
def mlir_tensor_type_asm(self):
"""Get a corresponding MLIR tensor type.
@ -436,15 +448,16 @@ class ArrayParams:
else:
dtype_asm = _dtype_to_mlir_asm(self.dtype)
if not dtype_asm:
raise ValueError(
"Unsupported MLIR tensor element type %r" % (self.dtype,))
raise ValueError("Unsupported MLIR tensor element type %r" %
(self.dtype,))
if self.shape is Unspec:
shape_asm = "*"
else:
shape_asm = "x".join((str(d) if d >= 0 else "?") for d in self.shape)
if shape_asm: shape_asm += "x"
if shape_asm:
shape_asm += "x"
return "tensor<%s%s>" % (shape_asm, dtype_asm)
def new_ndarray(self):
"""Creates a new ndarray from these params.
@ -458,7 +471,7 @@ class ArrayParams:
if not self.is_concrete:
raise ValueError("%r is not concrete" % (self,))
return np.ndarray(dtype=self.dtype, shape=self.shape)
class TypeConstraint:
"""Base class for type constraints."""
@ -481,57 +494,59 @@ class TypeConstraints(list):
...
AssertionError
"""
def __init__(self, *constraints):
if len(constraints) == 1 and not isinstance(
constraints[0], ArrayConstraint):
if len(constraints) == 1 and not isinstance(constraints[0],
ArrayConstraint):
constraints = constraints[0]
super().__init__(constraints)
assert(all(isinstance(c, ArrayConstraint) for c in self))
assert (all(isinstance(c, ArrayConstraint) for c in self))
def __repr__(self):
return "TypeConstraints(%s)" % (
", ".join([repr(c) for c in self]))
return "TypeConstraints(%s)" % (", ".join([repr(c) for c in self]))
def all_of(self, clazz):
"""Finds all of the given class."""
return [c for c in self if isinstance(c, clazz)]
def one_of(self, clazz):
"""Finds at most one constraint of the given class."""
found = [c for c in self if isinstance(c, clazz)]
if not found: return None
if not found:
return None
if len(found) > 1:
raise ValueError("Conflicting constraints. Expected one of %r. Got %r" % (
clazz, found))
raise ValueError("Conflicting constraints. Expected one of %r. Got %r" %
(clazz, found))
return found[0]
class ArrayConstraint(TypeConstraint):
"""Base class for a constraint on an array's characteristics."""
def implies_dtype(self):
return False
@property
def dtype(self):
raise NotImplementedError()
def implies_rank(self):
return False
@property
def rank(self):
raise NotImplementedError()
def implies_dims(self):
return False
@property
def dims(self):
raise NotImplementedError()
def implies_dim_flag(self):
return False
@property
def dim_flag(self):
raise NotImplementedError()
@ -550,22 +565,22 @@ class DType(ArrayConstraint):
AssertionError
"""
__slots__ = ["_dtype"]
def __init__(self, dtype):
super().__init__()
assert isinstance(dtype, type)
self._dtype = dtype
@property
def dtype(self):
return self._dtype
def implies_dtype(self):
return True
def __repr__(self):
return "DType(%s)" % (self._dtype.__name__,)
class Rank(ArrayConstraint):
"""Establishes a fixed rank for the array.
@ -585,21 +600,21 @@ class Rank(ArrayConstraint):
"""
__slots__ = ["_rank"]
def __init__(self, rank):
super().__init__()
assert(isinstance(rank, int) and rank >= 0)
assert (isinstance(rank, int) and rank >= 0)
self._rank = rank
@property
def rank(self):
return self._rank
def implies_rank(self):
return True
def __repr__(self):
return "Rank(%d)" % (self._rank)
return "Rank(%d)" % (self._rank)
class Shape(ArrayConstraint):
@ -619,29 +634,29 @@ class Shape(ArrayConstraint):
AssertionError
"""
__slots__ = ["_dims"]
def __init__(self, *dims):
super().__init__()
assert(all(d is Unspec or (isinstance(d, int) and d >= 0) for d in dims))
assert (all(d is Unspec or (isinstance(d, int) and d >= 0) for d in dims))
self._dims = tuple(dims)
@property
def dims(self):
return self._dims
def implies_dims(self):
return True
@property
def rank(self):
return len(self._dims)
def implies_rank(self):
return True
def __repr__(self):
return "Shape(%s)" % (", ".join(str(d) for d in self._dims))
class DimFlagEnum(_LiterateEnum):
"""Flag for the kind of DimFlag constraint."""
@ -661,35 +676,35 @@ class DimFlag(ArrayConstraint):
DimFlag(Dynamic, (0, 1))
"""
__slots__ = ["_flag", "_dims"]
def __init__(self, flag, dims=Unspec):
super().__init__()
self._flag = DimFlagEnum.parse(flag)
if isinstance(dims, int):
assert(dims >= 0)
assert (dims >= 0)
self._dims = (dims,)
elif dims is Unspec:
self._dims = Unspec
else:
self._dims = tuple(dims)
assert(all(isinstance(d, int) and d >= 0 for d in self._dims))
assert (all(isinstance(d, int) and d >= 0 for d in self._dims))
def implies_dim_flag(self):
return False
@property
def dim_flag(self):
return self._flag, self._dims
def __repr__(self):
return "DimFlag(%r, %r)" % (self._flag, self._dims)
def DynamicDim(dims=Unspec):
"""Dim flag that signals a dimension should be considered dynamic."""
return DimFlag(DimFlagEnum.Dynamic, dims)
if __name__ == "__main__":
import doctest
doctest.testmod()

View File

@ -13,14 +13,17 @@ _filecheck_binary_var = "FILECHECK_BINARY"
_redirect_io = None
_redirect_context = None
def is_filecheck_disabled():
return _disable_var in os.environ
def start_filecheck_test():
if is_filecheck_disabled():
print("WARNING:FileCheck disabled due to", _disable_var,
"in the environment", file=sys.stderr)
print("WARNING:FileCheck disabled due to",
_disable_var,
"in the environment",
file=sys.stderr)
return
global _redirect_io
global _redirect_context
@ -30,7 +33,8 @@ def start_filecheck_test():
def end_filecheck_test(main_file):
if is_filecheck_disabled(): return
if is_filecheck_disabled():
return
global _redirect_io
global _redirect_context
_redirect_context.__exit__(None, None, None)
@ -44,7 +48,7 @@ def end_filecheck_test(main_file):
filecheck_args = [filecheck_binary, main_file, "--dump-input=fail"]
p = subprocess.Popen(filecheck_args, stdin=subprocess.PIPE)
p.communicate(filecheck_input.encode("UTF-8"))
sys.exit(p.returncode)
sys.exit(p.returncode)
def run_under_filecheck(main_file, callback, disable_filecheck=False):
@ -60,8 +64,10 @@ def run_under_filecheck(main_file, callback, disable_filecheck=False):
disable_filecheck: Whether to disable filecheck.
"""
if disable_filecheck or is_filecheck_disabled():
print("WARNING:FileCheck disabled due to", _disable_var,
"in the environment", file=sys.stderr)
print("WARNING:FileCheck disabled due to",
_disable_var,
"in the environment",
file=sys.stderr)
callback()
sys.exit(0)

View File

@ -4,26 +4,26 @@ import os
import subprocess
import sys
TEST_MODULES = (
"npcomp.mlir_ir_test",
"npcomp.mlir_pass_test",
"npcomp.dialect.Basicpy",
"npcomp.dialect.Numpy",
"npcomp.tracing.context",
"npcomp.tracing.mlir_trace",
"npcomp.types",
"npcomp.exporter",
"npcomp.tracing.mlir_trace_test",
"npcomp.mlir_ir_test",
"npcomp.mlir_pass_test",
"npcomp.dialect.Basicpy",
"npcomp.dialect.Numpy",
"npcomp.tracing.context",
"npcomp.tracing.mlir_trace",
"npcomp.types",
"npcomp.exporter",
"npcomp.tracing.mlir_trace_test",
)
# Compute PYTHONPATH for sub processes.
DIRSEP = os.path.pathsep
LOCAL_PYTHONPATH_COMPONENTS = [
# This directory.
os.path.abspath(os.path.dirname(__file__)),
# The parallel python_native directory (assuming in the build tree).
os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "python_native"))
# This directory.
os.path.abspath(os.path.dirname(__file__)),
# The parallel python_native directory (assuming in the build tree).
os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "python_native"))
]
PYTHONPATH = DIRSEP.join(LOCAL_PYTHONPATH_COMPONENTS)
if "PYTHONPATH" in os.environ:
@ -33,9 +33,8 @@ CHILD_ENVIRON["PYTHONPATH"] = PYTHONPATH
# Configure filecheck.
FILECHECK_BINARY = os.path.abspath(
os.path.join(
os.path.dirname(__file__),
"..", "..", "..", "bin", "FileCheck"))
os.path.join(os.path.dirname(__file__), "..", "..", "..", "bin",
"FileCheck"))
if os.path.exists(FILECHECK_BINARY):
CHILD_ENVIRON["FILECHECK_BINARY"] = FILECHECK_BINARY
else:
@ -47,14 +46,14 @@ failed = []
for test_module in TEST_MODULES:
print("--------====== RUNNING %s ======--------" % test_module)
try:
subprocess.check_call([sys.executable, "-Wignore", "-m", test_module],
subprocess.check_call([sys.executable, "-Wignore", "-m", test_module],
env=CHILD_ENVIRON)
print("--------====== DONE %s ======--------\n" % test_module)
passed.append(test_module)
except subprocess.CalledProcessError:
print("!!!!!!!!====== ERROR %s ======!!!!!!!!\n" % test_module)
failed.append(test_module)
print("Done: %d passed, %d failed" % (len(passed), len(failed)))
if failed:
for test_module in failed:

View File

@ -4,6 +4,7 @@
from npcomp.compiler.frontend import *
def binary_expression():
a = 1
b = 100
@ -11,6 +12,7 @@ def binary_expression():
c = c * 2.0
return c
fe = ImportFrontend()
try:
f = fe.import_global_function(binary_expression)

View File

@ -9,9 +9,11 @@ from npcomp.types import *
weights = np.random.uniform(size=(16, 4)).astype(np.float32)
bias = np.random.uniform(size=(4,)).astype(np.float32)
def constants(a: np.ndarray) -> np.ndarray:
return np.dot(a, weights) + bias
# TODO: Implement subclassing and deriving constraints by run
exp = npc.Exporter()
exp.constants = constants

View File

@ -6,20 +6,22 @@ import numpy as np
import npcomp as npc
from npcomp.types import *
def dot2d(a: np.ndarray, b: np.ndarray) -> np.ndarray:
return np.dot(a, b)
# TODO: Implement subclassing and deriving constraints by run
exp = npc.Exporter()
exp.dot2d = dot2d
exp.dot2d.sig.args["a"] += Shape(4, 16)
exp.dot2d.sig.args["a"] += DynamicDim(0)
exp.dot2d.sig.args["a"] += DType(np.float32)
exp.dot2d.sig.args["b"] += Shape(16,32)
exp.dot2d.sig.args["b"] += Shape(16, 32)
exp.dot2d.sig.args["b"] += DType(np.float32)
exp.dot2d.sig.result += Shape(4, 32)
exp.dot2d.sig.result += DynamicDim(0)
exp.dot2d.sig.result += DType(np.float32)
exp.dot2d.sig.result += DType(np.float32)
mb = npc.tracing.ModuleBuilder()
mb.trace(exp.dot2d)

View File

@ -6,9 +6,11 @@ import numpy as np
import npcomp as npc
from npcomp.types import *
def simple_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
return a * b + a + b
# TODO: Implement subclassing and deriving constraints by run
exp = npc.Exporter()
exp.simple_mul = simple_mul
@ -19,7 +21,7 @@ exp.simple_mul.sig.args["b"] += Shape(1)
exp.simple_mul.sig.args["b"] += DType(np.float32)
exp.simple_mul.sig.result += Shape(1, 4)
exp.simple_mul.sig.result += DynamicDim(0)
exp.simple_mul.sig.result += DType(np.float32)
exp.simple_mul.sig.result += DType(np.float32)
mb = npc.tracing.ModuleBuilder()
mb.trace(exp.simple_mul)

View File

@ -6,9 +6,11 @@ import numpy as np
import npcomp as npc
from npcomp.types import *
def slice_array1(a: np.ndarray) -> np.ndarray:
return a[1, 2:10:2, 3:4, ..., :, 0]
# TODO: Implement subclassing and deriving constraints by run
exp = npc.Exporter()
exp.slice_array1 = slice_array1

View File

@ -6,12 +6,15 @@ import numpy as np
import npcomp as npc
from npcomp.types import *
def transpose_attribute(a: np.ndarray) -> np.ndarray:
return a.T
def transpose(a: np.ndarray) -> np.ndarray:
return np.transpose(a)
# TODO: Implement subclassing and deriving constraints by run
exp = npc.Exporter()
exp.transpose_attribute = transpose_attribute

View File

@ -187,11 +187,10 @@ private:
class PyDialectHelper {
public:
PyDialectHelper(PyContext &context, PyOpBuilder &builder)
: context(context), pyOpBuilder(builder) {}
: context(context), pyOpBuilder(builder) {}
static void bind(py::module m);
MLIRContext *getContext() {
return pyOpBuilder.getContext();
}
MLIRContext *getContext() { return pyOpBuilder.getContext(); }
protected:
PyContext &context;
PyOpBuilder &pyOpBuilder;

View File

@ -0,0 +1,20 @@
#!/bin/bash
# Formats all source files.
set +e
td="$(dirname $0)/.."
function find_cc_sources() {
local dir="$1"
find "$dir" -name "*.h"
find "$dir" -name "*.cpp"
}
# C/C++ sources.
set -o xtrace
clang-format -i \
$(find_cc_sources include) \
$(find_cc_sources lib) \
$(find_cc_sources python_native)
# Python sources.
yapf --recursive -i "$td/python" "$td/pytest"