diff --git a/include/npcomp/Conversion/TCFToTCP/TCFToTCP.h b/include/npcomp/Conversion/TCFToTCP/TCFToTCP.h index 074b26561..1947b74ba 100644 --- a/include/npcomp/Conversion/TCFToTCP/TCFToTCP.h +++ b/include/npcomp/Conversion/TCFToTCP/TCFToTCP.h @@ -16,6 +16,6 @@ namespace mlir { namespace NPCOMP { std::unique_ptr> createConvertTCFToTCPPass(); } -} +} // namespace mlir #endif // NPCOMP_CONVERSION_TCFTOTCP_CONVERTTCFTOTCP_H diff --git a/include/npcomp/Conversion/TCPToLinalg/TCPToLinalg.h b/include/npcomp/Conversion/TCPToLinalg/TCPToLinalg.h index 4d4bae094..d00ed07df 100644 --- a/include/npcomp/Conversion/TCPToLinalg/TCPToLinalg.h +++ b/include/npcomp/Conversion/TCPToLinalg/TCPToLinalg.h @@ -17,6 +17,6 @@ namespace mlir { namespace NPCOMP { std::unique_ptr> createConvertTCPToLinalgPass(); } -} +} // namespace mlir #endif // NPCOMP_CONVERSION_TCPTOLINALG_CONVERTTCPTOLINALG_H diff --git a/include/npcomp/Dialect/NpcompRt/IR/NpcompRtOps.h b/include/npcomp/Dialect/NpcompRt/IR/NpcompRtOps.h index 61755fbf2..7b20259bd 100644 --- a/include/npcomp/Dialect/NpcompRt/IR/NpcompRtOps.h +++ b/include/npcomp/Dialect/NpcompRt/IR/NpcompRtOps.h @@ -20,7 +20,7 @@ namespace npcomp_rt { #define GET_OP_CLASSES #include "npcomp/Dialect/NpcompRt/IR/NpcompRtOps.h.inc" -} // namespace tcf +} // namespace npcomp_rt } // namespace NPCOMP } // namespace mlir diff --git a/include/npcomp/E2E/E2E.h b/include/npcomp/E2E/E2E.h index 3221f85a1..c7e30627b 100644 --- a/include/npcomp/E2E/E2E.h +++ b/include/npcomp/E2E/E2E.h @@ -25,8 +25,7 @@ std::unique_ptr> createLowerBroadcastToToLoopsPass(); std::unique_ptr> createLowerLinalgOnTensorToLinalgOnMemrefPass(); -std::unique_ptr> -createResolveShapeOfOpsPass(); +std::unique_ptr> createResolveShapeOfOpsPass(); std::unique_ptr> createResolveTensorLoadStoreOpsPass(); diff --git a/lib/Conversion/TCFToTCP/TCFToTCP.cpp b/lib/Conversion/TCFToTCP/TCFToTCP.cpp index 74fdef43a..56bc37505 100644 --- a/lib/Conversion/TCFToTCP/TCFToTCP.cpp +++ b/lib/Conversion/TCFToTCP/TCFToTCP.cpp @@ -52,7 +52,7 @@ public: return success(); } }; -} +} // namespace namespace { class ConvertTCFToTCP : public ConvertTCFToTCPBase { diff --git a/lib/Dialect/NpcompRt/IR/NpcompRtDialect.cpp b/lib/Dialect/NpcompRt/IR/NpcompRtDialect.cpp index 8c0ff8894..b2f8ce47c 100644 --- a/lib/Dialect/NpcompRt/IR/NpcompRtDialect.cpp +++ b/lib/Dialect/NpcompRt/IR/NpcompRtDialect.cpp @@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// #include "npcomp/Dialect/NpcompRt/IR/NpcompRtDialect.h" -#include "npcomp/Dialect/NpcompRt/IR/NpcompRtOps.h" #include "mlir/IR/DialectImplementation.h" +#include "npcomp/Dialect/NpcompRt/IR/NpcompRtOps.h" using namespace mlir; using namespace mlir::NPCOMP::npcomp_rt; @@ -44,4 +44,3 @@ void NpcompRtDialect::printType(Type type, DialectAsmPrinter &os) const { llvm_unreachable("unexpected 'npcomp_rt' type kind"); } } - diff --git a/lib/Dialect/NpcompRt/IR/NpcompRtOps.cpp b/lib/Dialect/NpcompRt/IR/NpcompRtOps.cpp index a751cf181..6fc0d7c23 100644 --- a/lib/Dialect/NpcompRt/IR/NpcompRtOps.cpp +++ b/lib/Dialect/NpcompRt/IR/NpcompRtOps.cpp @@ -18,6 +18,6 @@ namespace NPCOMP { namespace npcomp_rt { #define GET_OP_CLASSES #include "npcomp/Dialect/NpcompRt/IR/NpcompRtOps.cpp.inc" -} // namespace tcf +} // namespace npcomp_rt } // namespace NPCOMP } // namespace mlir diff --git a/lib/Dialect/TCP/IR/TCPOps.cpp b/lib/Dialect/TCP/IR/TCPOps.cpp index 44c8fa0a0..a0289b230 100644 --- a/lib/Dialect/TCP/IR/TCPOps.cpp +++ b/lib/Dialect/TCP/IR/TCPOps.cpp @@ -29,10 +29,11 @@ LogicalResult ShapeObserveErrorOp::inferReturnTypes( // GetExtentOp //===----------------------------------------------------------------------===// -LogicalResult GetExtentOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { +LogicalResult +GetExtentOp::inferReturnTypes(MLIRContext *context, Optional location, + ValueRange operands, DictionaryAttr attributes, + RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back(IndexType::get(context)); return success(); } diff --git a/lib/E2E/E2E.cpp b/lib/E2E/E2E.cpp index 5f25f5571..4092b7f77 100644 --- a/lib/E2E/E2E.cpp +++ b/lib/E2E/E2E.cpp @@ -65,21 +65,22 @@ using namespace mlir::NPCOMP; //===----------------------------------------------------------------------===// namespace { -class ResolveShapeOfOpViaAllocMemRefOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(shape::ShapeOfOp op, - PatternRewriter &rewriter) const override { - if (auto tensorLoad = llvm::dyn_cast_or_null( - op.getOperand().getDefiningOp())) { - if (auto allocMemRef = llvm::dyn_cast_or_null( - tensorLoad.getOperand().getDefiningOp())) { - rewriter.replaceOp(op, allocMemRef.getOperand()); - return success(); - } +class ResolveShapeOfOpViaAllocMemRefOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(shape::ShapeOfOp op, + PatternRewriter &rewriter) const override { + if (auto tensorLoad = llvm::dyn_cast_or_null( + op.getOperand().getDefiningOp())) { + if (auto allocMemRef = llvm::dyn_cast_or_null( + tensorLoad.getOperand().getDefiningOp())) { + rewriter.replaceOp(op, allocMemRef.getOperand()); + return success(); } - return failure(); } + return failure(); + } }; } // namespace @@ -92,7 +93,7 @@ class ResolveShapeOfOps : public ResolveShapeOfOpsBase { OwningRewritePatternList patterns; patterns.insert(context); ConversionTarget target(*context); - //target.addIllegalOp(); + // target.addIllegalOp(); target.addDynamicallyLegalOp( [](shape::ShapeOfOp shapeOf) { // Only shape.shape_of on arguments to the entry block are legal at diff --git a/lib/E2E/LowerToHybridTensorMemRef.cpp b/lib/E2E/LowerToHybridTensorMemRef.cpp index e24ef73b4..924d9a4c2 100644 --- a/lib/E2E/LowerToHybridTensorMemRef.cpp +++ b/lib/E2E/LowerToHybridTensorMemRef.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "npcomp/E2E/E2E.h" #include "PassDetail.h" +#include "npcomp/E2E/E2E.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" @@ -66,7 +66,6 @@ public: Value outputExtent = rewriter.create( op.getLoc(), op.shape(), rewriter.getI64IntegerAttr(i)); outputExtents.push_back(outputExtent); - } int rankDiff = resultType.getRank() - inputType.getRank(); for (int i = 0, e = inputType.getRank(); i < e; i++) { @@ -108,7 +107,8 @@ public: } Value load = rewriter.create(op.getLoc(), inputMemref, inputIndices); - rewriter.create(op.getLoc(), load, resultMemref, inductionVariables); + rewriter.create(op.getLoc(), load, resultMemref, + inductionVariables); } rewriter.replaceOpWithNewOp(op, resultMemref); @@ -173,91 +173,87 @@ mlir::NPCOMP::createLowerBroadcastToToLoopsPass() { //===----------------------------------------------------------------------===// namespace { -class LowerLinalgGenericTensorToMemRef : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(linalg::GenericOp op, - PatternRewriter &rewriter) const override { +class LowerLinalgGenericTensorToMemRef + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(linalg::GenericOp op, + PatternRewriter &rewriter) const override { - // TODO: Replace this with more generic code operating on named - // structured ops too. + // TODO: Replace this with more generic code operating on named + // structured ops too. - // Only handle generic ops where all operands and results are tensors. - if (!llvm::all_of(op.getOperandTypes(), [](Type type) { - return type.isa(); - })) { - return rewriter.notifyMatchFailure(op, "all operands must be tensors"); - } - if (!llvm::all_of(op.getResultTypes(), [](Type type) { - return type.isa(); - })) { - return rewriter.notifyMatchFailure(op, "all results must be tensors"); - } - - // TODO: Loosen restrictions on indexing maps. - // This will require more principled handling of shape reification - // earlier in the compilation stack, as in general output shapes of a - // linalg.generic cannot be inferred easily. - // See: - // https://llvm.discourse.group/t/computing-output-shapes-of-structured-ops-on-tensors/866 - if (!llvm::all_of(op.indexing_maps(), [](Attribute map) { - return map.cast().getValue().isIdentity(); - })) { - return rewriter.notifyMatchFailure( - op, "all indexing maps must be identity maps"); - } - if (!llvm::all_of(op.iterator_types(), [](Attribute str) { - return str.cast().getValue() == - getParallelIteratorTypeName(); - })) { - return rewriter.notifyMatchFailure( - op, "all iterator types must be 'parallel'"); - } - - SmallVector memrefs; - SmallVector resultMemrefs; - SmallVector operandShapes; - for (auto tensor : op.getOperands()) { - auto shape = rewriter.create(op.getLoc(), tensor); - auto memref = - allocMemRefForTensor(rewriter, tensor, shape, op.getLoc()); - rewriter.create(op.getLoc(), tensor, memref); - memrefs.push_back(memref); - operandShapes.push_back(shape); - } - auto shapeType = shape::ShapeType::get(rewriter.getContext()); - SmallVector shapeTypes(op.getNumResults(), shapeType); - // TODO: We need more principled handling of output shapes. - // This assumes that all results have the same shape, which is justified - // by checks above, but we really need a better story here. - SmallVector resultShapes(op.getNumResults(), operandShapes[0]); - for (auto t : llvm::zip(op.getResults(), resultShapes)) { - auto tensor = std::get<0>(t); - auto shape = std::get<1>(t); - auto memref = - allocMemRefForTensor(rewriter, tensor, shape, op.getLoc()); - memrefs.push_back(memref); - resultMemrefs.push_back(memref); - } - auto newGeneric = rewriter.create( - op.getLoc(), llvm::None, ValueRange(memrefs), op.getAttrs()); - newGeneric.region().getBlocks().clear(); - BlockAndValueMapping mapper; - op.region().cloneInto(&newGeneric.region(), mapper); - for (auto memref : resultMemrefs) { - newGeneric.region().front().addArgument( - memref.getType().cast().getElementType()); - } - auto newResultTensors = - llvm::to_vector<6>(llvm::map_range(resultMemrefs, [&](Value memref) { - return rewriter.create(op.getLoc(), memref) - .getResult(); - })); - rewriter.replaceOp(op, newResultTensors); - return success(); + // Only handle generic ops where all operands and results are tensors. + if (!llvm::all_of(op.getOperandTypes(), + [](Type type) { return type.isa(); })) { + return rewriter.notifyMatchFailure(op, "all operands must be tensors"); } + if (!llvm::all_of(op.getResultTypes(), + [](Type type) { return type.isa(); })) { + return rewriter.notifyMatchFailure(op, "all results must be tensors"); + } + + // TODO: Loosen restrictions on indexing maps. + // This will require more principled handling of shape reification + // earlier in the compilation stack, as in general output shapes of a + // linalg.generic cannot be inferred easily. + // See: + // https://llvm.discourse.group/t/computing-output-shapes-of-structured-ops-on-tensors/866 + if (!llvm::all_of(op.indexing_maps(), [](Attribute map) { + return map.cast().getValue().isIdentity(); + })) { + return rewriter.notifyMatchFailure( + op, "all indexing maps must be identity maps"); + } + if (!llvm::all_of(op.iterator_types(), [](Attribute str) { + return str.cast().getValue() == + getParallelIteratorTypeName(); + })) { + return rewriter.notifyMatchFailure( + op, "all iterator types must be 'parallel'"); + } + + SmallVector memrefs; + SmallVector resultMemrefs; + SmallVector operandShapes; + for (auto tensor : op.getOperands()) { + auto shape = rewriter.create(op.getLoc(), tensor); + auto memref = allocMemRefForTensor(rewriter, tensor, shape, op.getLoc()); + rewriter.create(op.getLoc(), tensor, memref); + memrefs.push_back(memref); + operandShapes.push_back(shape); + } + auto shapeType = shape::ShapeType::get(rewriter.getContext()); + SmallVector shapeTypes(op.getNumResults(), shapeType); + // TODO: We need more principled handling of output shapes. + // This assumes that all results have the same shape, which is justified + // by checks above, but we really need a better story here. + SmallVector resultShapes(op.getNumResults(), operandShapes[0]); + for (auto t : llvm::zip(op.getResults(), resultShapes)) { + auto tensor = std::get<0>(t); + auto shape = std::get<1>(t); + auto memref = allocMemRefForTensor(rewriter, tensor, shape, op.getLoc()); + memrefs.push_back(memref); + resultMemrefs.push_back(memref); + } + auto newGeneric = rewriter.create( + op.getLoc(), llvm::None, ValueRange(memrefs), op.getAttrs()); + newGeneric.region().getBlocks().clear(); + BlockAndValueMapping mapper; + op.region().cloneInto(&newGeneric.region(), mapper); + for (auto memref : resultMemrefs) { + newGeneric.region().front().addArgument( + memref.getType().cast().getElementType()); + } + auto newResultTensors = + llvm::to_vector<6>(llvm::map_range(resultMemrefs, [&](Value memref) { + return rewriter.create(op.getLoc(), memref).getResult(); + })); + rewriter.replaceOp(op, newResultTensors); + return success(); + } }; -} +} // namespace namespace { class LowerLinalgOnTensorToLinalgOnMemref diff --git a/lib/E2E/LowerToLLVM.cpp b/lib/E2E/LowerToLLVM.cpp index ff2671c7e..815fe4c45 100644 --- a/lib/E2E/LowerToLLVM.cpp +++ b/lib/E2E/LowerToLLVM.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "npcomp/E2E/E2E.h" #include "PassDetail.h" +#include "npcomp/E2E/E2E.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" diff --git a/pytest/Compiler/binary_expressions.py b/pytest/Compiler/binary_expressions.py index 3f8f4ed45..26a71336b 100644 --- a/pytest/Compiler/binary_expressions.py +++ b/pytest/Compiler/binary_expressions.py @@ -22,66 +22,77 @@ def add(): # CHECK: {{.*}} = basicpy.binary_expr %[[A]] "Add" %[[B]] : (i64, i64) -> !basicpy.UnknownType return a + b + # CHECK-LABEL: func @sub @import_global def sub(): # CHECK: basicpy.binary_expr {{.*}} "Sub" return 4 - 2 + # CHECK-LABEL: func @mult @import_global def mult(): # CHECK: basicpy.binary_expr {{.*}} "Mult" return 4 * 2 + # CHECK-LABEL: func @div @import_global def div(): # CHECK: basicpy.binary_expr {{.*}} "Div" return 4 / 2 + # CHECK-LABEL: func @floor_div @import_global def floor_div(): # CHECK: basicpy.binary_expr {{.*}} "FloorDiv" return 4 // 2 + # CHECK-LABEL: func @matmul @import_global def matmul(): # CHECK: basicpy.binary_expr {{.*}} "MatMult" return 4 @ 2 + # CHECK-LABEL: func @modulo @import_global def modulo(): # CHECK: basicpy.binary_expr {{.*}} "Mod" return 4 % 2 + # CHECK-LABEL: func @left_shift @import_global def left_shift(): # CHECK: basicpy.binary_expr {{.*}} "LShift" return 4 << 2 + # CHECK-LABEL: func @right_shift @import_global def right_shift(): # CHECK: basicpy.binary_expr {{.*}} "RShift" return 4 >> 2 + # CHECK-LABEL: func @bit_and @import_global def bit_and(): # CHECK: basicpy.binary_expr {{.*}} "BitAnd" return 4 & 2 + # CHECK-LABEL: func @bit_xor @import_global def bit_xor(): # CHECK: basicpy.binary_expr {{.*}} "BitXor" return 4 ^ 2 + # CHECK-LABEL: func @bit_or @import_global def bit_or(): diff --git a/pytest/Compiler/booleans.py b/pytest/Compiler/booleans.py index 6e330c922..c3e1b2c8d 100644 --- a/pytest/Compiler/booleans.py +++ b/pytest/Compiler/booleans.py @@ -38,6 +38,7 @@ def logical_and(): # CHECK: } return x and y and z + # CHECK-LABEL: func @logical_or @import_global def logical_or(): @@ -65,17 +66,19 @@ def logical_or(): z = 2 return x or y or z + # CHECK-LABEL: func @logical_not @import_global def logical_not(): # CHECK: %[[X:.*]] = constant 1 x = 1 - # CHECK-DAG: %[[TRUE:.*]] = basicpy.bool_constant true + # CHECK-DAG: %[[TRUE:.*]] = basicpy.bool_constant true # CHECK-DAG: %[[FALSE:.*]] = basicpy.bool_constant false # CHECK-DAG: %[[CONDITION:.*]] = basicpy.to_boolean %[[X]] # CHECK-DAG: %{{.*}} = select %[[CONDITION]], %[[FALSE]], %[[TRUE]] : !basicpy.BoolType return not x + # CHECK-LABEL: func @conditional @import_global def conditional(): diff --git a/pytest/Compiler/comparisons.py b/pytest/Compiler/comparisons.py index 82f4985eb..3ea2c040f 100644 --- a/pytest/Compiler/comparisons.py +++ b/pytest/Compiler/comparisons.py @@ -21,6 +21,7 @@ def binary_lt_(): # CHECK: {{.*}} = basicpy.binary_compare %[[A]] "Lt" %[[B]] : i64, i64 return x < y + # CHECK-LABEL: func @binary_gt_ @import_global def binary_gt_(): @@ -29,6 +30,7 @@ def binary_gt_(): # CHECK: {{.*}} = basicpy.binary_compare {{.*}} "Gt" {{.*}} : i64, i64 return x > y + # CHECK-LABEL: func @binary_lte_ @import_global def binary_lte_(): @@ -37,6 +39,7 @@ def binary_lte_(): # CHECK: {{.*}} = basicpy.binary_compare {{.*}} "LtE" {{.*}} : i64, i64 return x <= y + # CHECK-LABEL: func @binary_gte_ @import_global def binary_gte_(): @@ -45,6 +48,7 @@ def binary_gte_(): # CHECK: {{.*}} = basicpy.binary_compare {{.*}} "GtE" {{.*}} : i64, i64 return x >= y + # CHECK-LABEL: func @binary_eq_ @import_global def binary_eq_(): @@ -53,6 +57,7 @@ def binary_eq_(): # CHECK: {{.*}} = basicpy.binary_compare {{.*}} "Eq" {{.*}} : i64, i64 return x == y + # CHECK-LABEL: func @binary_neq_ @import_global def binary_neq_(): @@ -61,6 +66,7 @@ def binary_neq_(): # CHECK: {{.*}} = basicpy.binary_compare {{.*}} "NotEq" {{.*}} : i64, i64 return x != y + # CHECK-LABEL: func @binary_is_ @import_global def binary_is_(): @@ -69,6 +75,7 @@ def binary_is_(): # CHECK: {{.*}} = basicpy.binary_compare {{.*}} "Is" {{.*}} : i64, i64 return x is y + # CHECK-LABEL: func @binary_is_not_ @import_global def binary_is_not_(): @@ -77,6 +84,7 @@ def binary_is_not_(): # CHECK: {{.*}} = basicpy.binary_compare {{.*}} "IsNot" {{.*}} : i64, i64 return x is not y + # CHECK-LABEL: func @binary_in_ @import_global def binary_in_(): @@ -85,6 +93,7 @@ def binary_in_(): # CHECK: {{.*}} = basicpy.binary_compare {{.*}} "In" {{.*}} : i64, i64 return x in y + # CHECK-LABEL: func @binary_not_in_ @import_global def binary_not_in_(): @@ -93,6 +102,7 @@ def binary_not_in_(): # CHECK: {{.*}} = basicpy.binary_compare {{.*}} "NotIn" {{.*}} : i64, i64 return x not in y + @import_global def short_circuit(): # CHECK: %[[X:.*]] = constant 1 : i64 @@ -123,6 +133,7 @@ def short_circuit(): # CHECK: return %[[RESULT]] return x < y == z >= omega + # CHECK-LABEL: nested_short_circuit_expression @import_global def nested_short_circuit_expression(): diff --git a/pytest/Compiler/constants.py b/pytest/Compiler/constants.py index cc308fd1e..cddcf136a 100644 --- a/pytest/Compiler/constants.py +++ b/pytest/Compiler/constants.py @@ -20,6 +20,7 @@ def integer_constants(): # CHECK: return %[[A_CAST]] return a + # CHECK-LABEL: func @float_constants @import_global def float_constants(): @@ -29,6 +30,7 @@ def float_constants(): # CHECK: return %[[A_CAST]] return a + # CHECK-LABEL: func @bool_true_constant @import_global def bool_true_constant(): @@ -37,6 +39,7 @@ def bool_true_constant(): a = True return a + # CHECK-LABEL: func @bool_false_constant @import_global def bool_false_constant(): @@ -45,6 +48,7 @@ def bool_false_constant(): a = False return a + # CHECK-LABEL: func @string_constant @import_global def string_constant(): @@ -53,6 +57,7 @@ def string_constant(): a = "foobar" return a + # CHECK-LABEL: func @joined_string_constant @import_global def joined_string_constant(): @@ -61,6 +66,7 @@ def joined_string_constant(): a = "I am" " still here" return a + # CHECK-LABEL: func @bytes_constant @import_global def bytes_constant(): @@ -69,6 +75,7 @@ def bytes_constant(): a = b"foobar" return a + # CHECK-LABEL: func @ellipsis @import_global def ellipsis(): @@ -77,6 +84,7 @@ def ellipsis(): a = ... return a + # CHECK-LABEL: func @none_constant @import_global def none_constant(): diff --git a/pytest/Compiler/structure.py b/pytest/Compiler/structure.py index f09cb2ea3..090bddcac 100644 --- a/pytest/Compiler/structure.py +++ b/pytest/Compiler/structure.py @@ -10,6 +10,7 @@ def import_global(f): print(fe.ir_module.to_asm()) return f + # CHECK-LABEL: func @positional_args # CHECK-SAME: (%arg0: !basicpy.UnknownType, %arg1: !basicpy.UnknownType) -> !basicpy.UnknownType @import_global @@ -17,6 +18,7 @@ def positional_args(a, b): # CHECK: basicpy.binary_expr %arg0 "Add" %arg1 return a + b + # CHECK-LABEL: func @pass_no_return @import_global def pass_no_return(): diff --git a/pytest/Compiler/type_inference.py b/pytest/Compiler/type_inference.py index b75fdda51..57d5032ed 100644 --- a/pytest/Compiler/type_inference.py +++ b/pytest/Compiler/type_inference.py @@ -10,6 +10,7 @@ def import_global(f): print("// -----") return f + # CHECK-LABEL: func @arithmetic_expression # CHECK-SAME: () -> i64 @import_global @@ -21,6 +22,7 @@ def arithmetic_expression(): # CHECK: return{{.*}} : i64 return 1 + 2 - 3 * 4 + # CHECK-LABEL: func @arg_inference # CHECK-SAME: (%arg0: i64, %arg1: i64) -> i64 @import_global @@ -31,9 +33,10 @@ def arg_inference(a, b): # CHECK: return{{.*}} : i64 return a + 2 * b + # CHECK-LABEL: func @conditional_inference # CHECK-SAME: (%arg0: i64, %arg1: !basicpy.BoolType, %arg2: i64) -> !basicpy.BoolType @import_global def conditional_inference(cond, a, b): # CHECK-NOT: UnknownType - return a if cond + 1 else not(b * 4) + return a if cond + 1 else not (b * 4) diff --git a/pytest/lit.cfg.py b/pytest/lit.cfg.py index dc7386d43..e07c6e620 100644 --- a/pytest/lit.cfg.py +++ b/pytest/lit.cfg.py @@ -39,15 +39,13 @@ llvm_config.with_system_environment( llvm_config.use_default_substitutions() -# excludes: A list of files/directories to exclude from the testsuite. The -# 'Inputs'subdirectories contain auxiliary inputs for various tests in their +# excludes: A list of files/directories to exclude from the testsuite. The +# 'Inputs'subdirectories contain auxiliary inputs for various tests in their # parent directories. config.excludes = [ - 'Inputs', 'Examples', - 'lit.cfg.py', - 'CMakeLists.txt', - 'README.txt', - 'LICENSE.txt'] + 'Inputs', 'Examples', 'lit.cfg.py', 'CMakeLists.txt', 'README.txt', + 'LICENSE.txt' +] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) @@ -56,22 +54,21 @@ config.test_source_root = os.path.dirname(__file__) config.test_exec_root = os.path.join(config.npcomp_obj_root, 'pytest') config.npcomp_tools_dir = os.path.join(config.npcomp_obj_root, 'tools') config.npcomp_runtime_shlib = os.path.join( - config.npcomp_obj_root, - 'runtime', - 'libNPCOMPRuntime' + config.llvm_shlib_ext -) + config.npcomp_obj_root, 'runtime', + 'libNPCOMPRuntime' + config.llvm_shlib_ext) # Tweak the PATH and PYTHONPATH to include the tools dir. llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) llvm_config.with_environment('PYTHONPATH', [ os.path.join(config.npcomp_obj_root, "python"), - os.path.join(config.npcomp_obj_root, "python_native")], - append_path=True) + os.path.join(config.npcomp_obj_root, "python_native") +], + append_path=True) tool_dirs = [ - os.path.join(config.npcomp_tools_dir, 'npcomp-opt'), - os.path.join(config.npcomp_tools_dir, 'npcomp-run-mlir'), - config.llvm_tools_dir, + os.path.join(config.npcomp_tools_dir, 'npcomp-opt'), + os.path.join(config.npcomp_tools_dir, 'npcomp-run-mlir'), + config.llvm_tools_dir, ] tools = [ 'npcomp-opt', diff --git a/python/npcomp/compiler/__init__.py b/python/npcomp/compiler/__init__.py index 84a0f9eef..6f19c9f0a 100644 --- a/python/npcomp/compiler/__init__.py +++ b/python/npcomp/compiler/__init__.py @@ -1,4 +1,3 @@ # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - diff --git a/python/npcomp/decorators.py b/python/npcomp/decorators.py index 84a0f9eef..6f19c9f0a 100644 --- a/python/npcomp/decorators.py +++ b/python/npcomp/decorators.py @@ -1,4 +1,3 @@ # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - diff --git a/python/npcomp/dialect/Numpy.py b/python/npcomp/dialect/Numpy.py index 0d7fbd985..69f936088 100644 --- a/python/npcomp/dialect/Numpy.py +++ b/python/npcomp/dialect/Numpy.py @@ -65,6 +65,7 @@ class DialectHelper(Basicpy.DialectHelper): tensor<*x!numpy.any_dtype> """ + @property def numpy_any_dtype(self): return self.context.parse_type("!numpy.any_dtype") diff --git a/python/npcomp/exporter.py b/python/npcomp/exporter.py index ba08cff41..2d5921c8f 100644 --- a/python/npcomp/exporter.py +++ b/python/npcomp/exporter.py @@ -9,9 +9,9 @@ from typing import Optional from npcomp.types import * __all__ = [ - "Exporter", - "ExportFunction", - "ExportPyFunction", + "Exporter", + "ExportFunction", + "ExportPyFunction", ] @@ -21,40 +21,41 @@ def _value_type_from_annotation(annotation): return ValueType(TypeClass.NdArray) else: return ValueType() - - + + def _signature_from_pyfunc(pyfunc): pysig = inspect.signature(pyfunc) sig = Signature(len(pysig.parameters)) # Arguments for i, param in enumerate(pysig.parameters.values()): - if param.kind not in ( - param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD): + if param.kind not in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD): raise ValueError( - "Currently only positional function signature are supported") - + "Currently only positional function signature are supported") + sig.arg_names[i] = param.name annot = param.annotation - if annot is param.empty: continue + if annot is param.empty: + continue sig.args[i] = _value_type_from_annotation(annot) - + # Result if pysig.return_annotation is not pysig.empty: sig.result = _value_type_from_annotation(pysig.return_annotation) - + return sig - + class ExportFunction: """Base class for functions that can be exported.""" __slots__ = ["_sig"] + def __init__(self, sig=None): self._sig = sig if sig else Signature() - + @property def sig(self): return self._sig - + def __repr__(self): return "def %r" % self._sig @@ -86,21 +87,21 @@ class ExportPyFunction(ExportFunction): pydef mul(a: NdArray[Rank(2)], b: Any) -> NdArray[Shape(1, 2)] """ __slots__ = ExportFunction.__slots__ + ["_pyfunc", "__name__"] - + def __init__(self, pyfunc, name=None): super().__init__(sig=_signature_from_pyfunc(pyfunc)) - assert (hasattr(pyfunc, "__call__") - and hasattr(pyfunc, "__name__")), "Not a python function" + assert (hasattr(pyfunc, "__call__") and + hasattr(pyfunc, "__name__")), "Not a python function" self._pyfunc = pyfunc self.__name__ = name if name else pyfunc.__name__ @property def pyfunc(self): return self._pyfunc - + def __repr__(self): return "pydef %s%r" % (self.__name__, self._sig) - + def __call__(self, *args, **kwargs): return self._pyfunc(*args, **kwargs) @@ -108,53 +109,59 @@ class ExportPyFunction(ExportFunction): class _ExpandoNode: """Expando object that can be indexed into to construct a namespace.""" __slots__ = [ - "_parent", "_services", "_local_name", "_parent_name", - "_children", "_attached"] - def __init__(self, parent: Optional["_ExpandoNode"], - services: "_Services", - local_name: str): + "_parent", "_services", "_local_name", "_parent_name", "_children", + "_attached" + ] + + def __init__(self, parent: Optional["_ExpandoNode"], services: "_Services", + local_name: str): super().__init__() object.__setattr__(self, "_parent", parent) object.__setattr__(self, "_services", services) object.__setattr__(self, "_local_name", local_name) - object.__setattr__(self, "_parent_name", + object.__setattr__(self, "_parent_name", parent._get_full_name() if parent else "") object.__setattr__(self, "_children", {}) object.__setattr__(self, "_attached", parent is None) - + def _attach(self): - if self._attached: return + if self._attached: + return if self._local_name in self._parent._children: raise KeyError("Cannot re-assign '%s'" % (self._get_full_name(),)) self._parent._attach() self._parent._children[self._local_name] = self object.__setattr__(self, "_attached", True) - + def _get_full_name(self): - if not self._parent: return "" # Root is always empty name. - full_name = (self._parent_name + "." + self._local_name - if self._parent_name else self._local_name) + if not self._parent: + return "" # Root is always empty name. + full_name = (self._parent_name + "." + + self._local_name if self._parent_name else self._local_name) return full_name - + def _get_child_name(self, child_local_name): full_name = self._get_full_name() - if not full_name: return child_local_name - else: return full_name + "." + child_local_name - + if not full_name: + return child_local_name + else: + return full_name + "." + child_local_name + def __repr__(self): return "Namespace(\"%s\")" % (self._get_full_name()) - + def __contains__(self, key): return key in self._children - + def __getitem__(self, key): key = str(key) existing = self._children.get(key) - if existing is not None: return existing + if existing is not None: + return existing # Speculatively create a child expando. child = _ExpandoNode(self, self._services, key) return child - + def __setitem__(self, key, value): if not inspect.isfunction(value): raise TypeError("Cannot assign value to an exporter: %r" % (value,)) @@ -164,19 +171,19 @@ class _ExpandoNode: raise KeyError("Cannot re-assign '%s'" % (child_name)) self._attach() self._children[key] = self._services.wrap_function(value, child_name) - + def __getattr__(self, name): return self[name] - + def __setattr__(self, name, value): try: self[name] = value except KeyError as e: raise AttributeError(str(e)) from None - + def __dir__(self): return self._children.keys() - + class _Services: """Services and support for the Exporter. @@ -184,12 +191,14 @@ class _Services: Exporters are user objects, so most of the functional components are contained in the associated _Services object. """ + def wrap_function(self, f, full_name): - if isinstance(f, ExportFunction): return f + if isinstance(f, ExportFunction): + return f # TODO: Need to scan through providers and choose. return ExportPyFunction(f, name=full_name) - - + + class Exporter: """Top-level UI object for assembling a program for export. @@ -244,31 +253,32 @@ class Exporter: AttributeError: "Cannot re-assign 'ns1'" """ __slots__ = ["_root", "_services"] + def __init__(self): super().__init__() services = _Services() object.__setattr__(self, "_root", _ExpandoNode(None, services, "")) object.__setattr__(self, "_services", services) - + def __repr__(self): return "Exporter()" - + def __contains__(self, key): return key in self._root - + def __getitem__(self, key): return self._root[key] - + def __setitem__(self, key, value): self._root[key] = value - + def __getattr__(self, name): return getattr(self._root, name) - + def __setattr__(self, name, value): setattr(self._root, name, value) if __name__ == "__main__": - import doctest - doctest.testmod() + import doctest + doctest.testmod() diff --git a/python/npcomp/mlir_ir_test.py b/python/npcomp/mlir_ir_test.py index e84c221ee..5293988c4 100644 --- a/python/npcomp/mlir_ir_test.py +++ b/python/npcomp/mlir_ir_test.py @@ -1,7 +1,6 @@ # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - """Test for the MLIR IR Python bindings""" from _npcomp.mlir import ir @@ -39,5 +38,4 @@ except ValueError as e: # CHECK: [ERROR]: expected operation name in quotes print(e) - test_utils.end_filecheck_test(__file__) diff --git a/python/npcomp/mlir_pass_test.py b/python/npcomp/mlir_pass_test.py index b3b50e88e..b9c8b2059 100644 --- a/python/npcomp/mlir_pass_test.py +++ b/python/npcomp/mlir_pass_test.py @@ -1,7 +1,6 @@ # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - """Test for the MLIR Pass Python bindings""" from _npcomp.mlir import ir diff --git a/python/npcomp/tracing/context.py b/python/npcomp/tracing/context.py index 05bf56df4..2a0bd05f5 100644 --- a/python/npcomp/tracing/context.py +++ b/python/npcomp/tracing/context.py @@ -48,10 +48,11 @@ class TraceContext: """ _local = threading.local() __slots__ = [ - "_desc", - "_next_id", - "active", + "_desc", + "_next_id", + "active", ] + def __init__(self, desc=None): _check_numpy_version() self._desc = desc @@ -87,11 +88,11 @@ class TraceContext: @classmethod def optional_current(cls) -> Optional["TraceContext"]: s = cls._get_context_stack() - if s: + if s: return s[-1] else: return None - + @classmethod def current(cls) -> "TraceContext": c = cls.optional_current() @@ -120,7 +121,7 @@ class TraceContext: def _assert_active(tc: TraceContext): assert tc.active, ( - "Attempt to trace an action on an inactive trace context: %r" % tc) + "Attempt to trace an action on an inactive trace context: %r" % tc) class TracedArray(np.lib.mixins.NDArrayOperatorsMixin): @@ -134,6 +135,7 @@ class TracedArray(np.lib.mixins.NDArrayOperatorsMixin): >>> TracedArray(tc=tc) """ + def __init__(self, tc: Optional[TraceContext] = None): self._tc = tc if tc is not None else TraceContext.current() self._uid = self._tc.get_next_id() @@ -160,28 +162,28 @@ class TracedArray(np.lib.mixins.NDArrayOperatorsMixin): def __array_function__(self, func, types, args, kwargs): tc = self._tc - _assert_active(tc) + _assert_active(tc) return tc._handle_array_func(func, types, args, kwargs) @property def T(self): """Shortcut for transpose.""" tc = self._tc - _assert_active(tc) + _assert_active(tc) return tc._handle_array_func(np.transpose, [TracedArray], [self], {}) - + def _check_numpy_version(): version = np.lib.NumpyVersion(np.__version__) if version < "1.16.0": raise RuntimeError("Numpy version >= 1.16 is required") - if version > "1.17.0": return + if version > "1.17.0": + return if os.environ.get("NUMPY_EXPERIMENTAL_ARRAY_FUNCTION") != "1": - raise RuntimeError( - "For numpy 1.16, the environment variable " - "NUMPY_EXPERIMENTAL_ARRAY_FUNCTION must equal 1") + raise RuntimeError("For numpy 1.16, the environment variable " + "NUMPY_EXPERIMENTAL_ARRAY_FUNCTION must equal 1") if __name__ == "__main__": - import doctest - doctest.testmod() + import doctest + doctest.testmod() diff --git a/python/npcomp/tracing/mlir_trace_test.py b/python/npcomp/tracing/mlir_trace_test.py index 3accb73ea..5e253006a 100644 --- a/python/npcomp/tracing/mlir_trace_test.py +++ b/python/npcomp/tracing/mlir_trace_test.py @@ -9,9 +9,11 @@ from npcomp.utils import test_utils test_utils.start_filecheck_test() + def simple_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray: return a * b + a + b + # TODO: Implement subclassing and deriving constraints by run exp = Exporter() exp.simple_mul = simple_mul @@ -22,7 +24,7 @@ exp.simple_mul.sig.args["b"] += Shape(1) exp.simple_mul.sig.args["b"] += DType(np.float32) exp.simple_mul.sig.result += Shape(1, 4) exp.simple_mul.sig.result += DynamicDim(0) -exp.simple_mul.sig.result += DType(np.float32) +exp.simple_mul.sig.result += DType(np.float32) mb = ModuleBuilder() mb.trace(exp.simple_mul) diff --git a/python/npcomp/types.py b/python/npcomp/types.py index db592b134..569525272 100644 --- a/python/npcomp/types.py +++ b/python/npcomp/types.py @@ -8,30 +8,30 @@ from enum import Enum import numpy as np __all__ = [ - "Unspec", - "ArrayConstraint", - "ArrayParams", - "DType", - "DimFlag", - "DimFlagEnum", - "DynamicDim", - "Rank", - "Shape", - "Signature", - "TypeClass", - "TypeConstraints", - "ValueType", + "Unspec", + "ArrayConstraint", + "ArrayParams", + "DType", + "DimFlag", + "DimFlagEnum", + "DynamicDim", + "Rank", + "Shape", + "Signature", + "TypeClass", + "TypeConstraints", + "ValueType", ] # TODO: All supported types _DTYPE_TO_ASM_DICT = { - np.bool: "i1", # TODO: May need a custom type to signify 8bit storage - np.int8: "s8", - np.int16: "s16", - np.int32: "s32", - np.int64: "s64", - np.float32: "f32", - np.float64: "f64", + np.bool: "i1", # TODO: May need a custom type to signify 8bit storage + np.int8: "s8", + np.int16: "s16", + np.int32: "s32", + np.int64: "s64", + np.float32: "f32", + np.float64: "f64", } @@ -67,29 +67,39 @@ class _LiterateEnum(Enum): ValueError: Cannot parse SampleEnum 1.0 """ + @classmethod def parse(cls, v): - if isinstance(v, cls): return v + if isinstance(v, cls): + return v if not v or not isinstance(v, str) or v[0] == '_' or not hasattr(cls, v): raise ValueError("Cannot parse %s %r" % ( - cls.__name__.split(".")[-1], v,)) + cls.__name__.split(".")[-1], + v, + )) value = getattr(cls, v) if not isinstance(value, cls): raise ValueError("Cannot parse %s %r" % ( - cls.__name__.split(".")[-1], v,)) + cls.__name__.split(".")[-1], + v, + )) return value - + def __repr__(self): return self.name - - + + # Special "unspecified" value that we use throughout. class _Unspec: __slots__ = [] + def __str__(self): return "Unspec" + def __repr__(self): return "Unspec" + + Unspec = _Unspec() @@ -97,8 +107,8 @@ class TypeClass(_LiterateEnum): """Top level types in the npcomp language.""" Any = 0 NdArray = 1 - - + + class ValueType: """The type a value can take in the npcomp language. @@ -123,32 +133,32 @@ class ValueType: NdArray """ __slots__ = ["_constraints", "_type_class"] - + def __init__(self, type_class=TypeClass.Any, *constraints): super().__init__() self._type_class = TypeClass.parse(type_class) self._constraints = TypeConstraints(constraints) def __iadd__(self, constraint): - assert isinstance(constraint, TypeConstraint), ( - "Can only add constraints to a ValueType") + assert isinstance( + constraint, TypeConstraint), ("Can only add constraints to a ValueType") self._constraints.append(constraint) return self - + def __repr__(self): if not self._constraints: return repr(self._type_class) - return "%r[%s]" % (self._type_class, - ", ".join([repr(c) for c in self._constraints])) - + return "%r[%s]" % (self._type_class, ", ".join( + [repr(c) for c in self._constraints])) + @property def type_class(self): return self._type_class - + @type_class.setter def type_class(self, type_class): self._type_class = TypeClass.parse(type_class) - + @property def constraints(self): return self._constraints @@ -179,6 +189,7 @@ class ValueTypeList: (Any, Any, Any) """ __slots__ = ["_list", "_names"] + def __init__(self, arity=0, names=None): self._list = [ValueType() for _ in range(arity)] self._names = names @@ -188,18 +199,19 @@ class ValueTypeList: # Scan for the index. if self._names: for i, n in enumerate(self._names): - if n == key: return i + if n == key: + return i raise KeyError("Unknown key '%s'" % key) return key - + def __getitem__(self, key): return self._list[self._key_to_index(key)] - + def __setitem__(self, key, value): if not isinstance(value, ValueType): value = ValueType(value) self._list[self._key_to_index(key)] = value - + def __iter__(self): return self._list.__iter__() @@ -233,20 +245,21 @@ class Signature: (a: Any, b: NdArray[Rank(2)]) -> NdArray[Rank(3)] """ __slots__ = ["_args", "_arg_names", "_result"] + def __init__(self, arity=0): super().__init__() self._result = ValueType() self._arg_names = [None] * arity self._args = ValueTypeList(arity, names=self._arg_names) - + @property def args(self): return self._args - + @property def arg_names(self): return self._arg_names - + @property def result(self): return self._result @@ -256,14 +269,14 @@ class Signature: if not isinstance(value, ValueType): value = ValueType(value) self._result = value - + def __repr__(self): - args_repr = "(%s)" % ( - ", ".join( + args_repr = "(%s)" % (", ".join( ((n + ": " + repr(t)) if n else repr(t)) for t, n in zip(self._args, self._arg_names)),) return "%s -> %r" % (args_repr, self._result) + class ArrayParams: """Represents parameters defining how to construct an array. @@ -277,7 +290,7 @@ class ArrayParams: ArrayParams(dtype=float32, shape=(1, 2, 3)) """ __slots__ = ["dtype", "shape"] - + def __init__(self, dtype=Unspec, shape=Unspec, rank=Unspec): self.dtype = dtype if shape is not Unspec: @@ -289,9 +302,10 @@ class ArrayParams: @property def rank(self): - if self.shape is Unspec: return Unspec + if self.shape is Unspec: + return Unspec return len(self.shape) - + @classmethod def from_constraints(cls, constraints): """Constructs params for a TypeConstraints list. @@ -341,42 +355,40 @@ class ArrayParams: shape_c = constraints.one_of(Shape) rank_c = constraints.one_of(Rank) dim_flags = constraints.all_of(DimFlag) - + dtype = dtype_c.dtype if dtype_c else Unspec shape = Unspec - + # Compute shape if shape_c: # TODO: Should be in canonicalizer if rank_c and rank_c.rank != len(shape_c.dims): - raise ValueError("Conflicting shape and rank: %r vs %r" % ( - rank_c, shape_c)) + raise ValueError("Conflicting shape and rank: %r vs %r" % + (rank_c, shape_c)) shape = list(shape_c.dims) elif rank_c: shape = [-1 for _ in range(rank_c.rank)] - + # Apply dim flags if shape is not Unspec and dim_flags: for df in dim_flags: flag, for_dims = df.dim_flag for d in for_dims: if d < 0 or d >= len(shape): - raise ValueError("Out of range %r for shape %r" % ( - df, shape)) + raise ValueError("Out of range %r for shape %r" % (df, shape)) if flag == DimFlagEnum.Dynamic: shape[d] = -1 - + return cls(dtype=dtype, shape=shape) - - + def __repr__(self): try: - s = "ArrayParams(dtype=%s" % ( - self.dtype.__name__ if isinstance(self.dtype, type) else self.dtype,) + s = "ArrayParams(dtype=%s" % (self.dtype.__name__ if isinstance( + self.dtype, type) else self.dtype,) if self.shape is not Unspec: s += ", shape=%r" % (tuple(self.shape),) s += ")" - return s + return s except: return "ArrayParams(ERROR)" @@ -400,7 +412,7 @@ class ArrayParams: if any(d < 0 for d in self.shape): return False return True - + @property def mlir_tensor_type_asm(self): """Get a corresponding MLIR tensor type. @@ -436,15 +448,16 @@ class ArrayParams: else: dtype_asm = _dtype_to_mlir_asm(self.dtype) if not dtype_asm: - raise ValueError( - "Unsupported MLIR tensor element type %r" % (self.dtype,)) + raise ValueError("Unsupported MLIR tensor element type %r" % + (self.dtype,)) if self.shape is Unspec: shape_asm = "*" else: shape_asm = "x".join((str(d) if d >= 0 else "?") for d in self.shape) - if shape_asm: shape_asm += "x" + if shape_asm: + shape_asm += "x" return "tensor<%s%s>" % (shape_asm, dtype_asm) - + def new_ndarray(self): """Creates a new ndarray from these params. @@ -458,7 +471,7 @@ class ArrayParams: if not self.is_concrete: raise ValueError("%r is not concrete" % (self,)) return np.ndarray(dtype=self.dtype, shape=self.shape) - + class TypeConstraint: """Base class for type constraints.""" @@ -481,57 +494,59 @@ class TypeConstraints(list): ... AssertionError """ + def __init__(self, *constraints): - if len(constraints) == 1 and not isinstance( - constraints[0], ArrayConstraint): + if len(constraints) == 1 and not isinstance(constraints[0], + ArrayConstraint): constraints = constraints[0] super().__init__(constraints) - assert(all(isinstance(c, ArrayConstraint) for c in self)) - + assert (all(isinstance(c, ArrayConstraint) for c in self)) + def __repr__(self): - return "TypeConstraints(%s)" % ( - ", ".join([repr(c) for c in self])) + return "TypeConstraints(%s)" % (", ".join([repr(c) for c in self])) def all_of(self, clazz): """Finds all of the given class.""" return [c for c in self if isinstance(c, clazz)] - + def one_of(self, clazz): """Finds at most one constraint of the given class.""" found = [c for c in self if isinstance(c, clazz)] - if not found: return None + if not found: + return None if len(found) > 1: - raise ValueError("Conflicting constraints. Expected one of %r. Got %r" % ( - clazz, found)) + raise ValueError("Conflicting constraints. Expected one of %r. Got %r" % + (clazz, found)) return found[0] class ArrayConstraint(TypeConstraint): """Base class for a constraint on an array's characteristics.""" + def implies_dtype(self): return False - + @property def dtype(self): raise NotImplementedError() - + def implies_rank(self): return False @property def rank(self): raise NotImplementedError() - + def implies_dims(self): return False @property def dims(self): raise NotImplementedError() - + def implies_dim_flag(self): return False - + @property def dim_flag(self): raise NotImplementedError() @@ -550,22 +565,22 @@ class DType(ArrayConstraint): AssertionError """ __slots__ = ["_dtype"] - + def __init__(self, dtype): super().__init__() assert isinstance(dtype, type) self._dtype = dtype - + @property def dtype(self): return self._dtype - + def implies_dtype(self): return True - + def __repr__(self): return "DType(%s)" % (self._dtype.__name__,) - + class Rank(ArrayConstraint): """Establishes a fixed rank for the array. @@ -585,21 +600,21 @@ class Rank(ArrayConstraint): """ __slots__ = ["_rank"] - + def __init__(self, rank): super().__init__() - assert(isinstance(rank, int) and rank >= 0) + assert (isinstance(rank, int) and rank >= 0) self._rank = rank @property def rank(self): return self._rank - + def implies_rank(self): return True - + def __repr__(self): - return "Rank(%d)" % (self._rank) + return "Rank(%d)" % (self._rank) class Shape(ArrayConstraint): @@ -619,29 +634,29 @@ class Shape(ArrayConstraint): AssertionError """ __slots__ = ["_dims"] - + def __init__(self, *dims): super().__init__() - assert(all(d is Unspec or (isinstance(d, int) and d >= 0) for d in dims)) + assert (all(d is Unspec or (isinstance(d, int) and d >= 0) for d in dims)) self._dims = tuple(dims) @property def dims(self): return self._dims - + def implies_dims(self): return True @property def rank(self): return len(self._dims) - + def implies_rank(self): return True - + def __repr__(self): return "Shape(%s)" % (", ".join(str(d) for d in self._dims)) - + class DimFlagEnum(_LiterateEnum): """Flag for the kind of DimFlag constraint.""" @@ -661,35 +676,35 @@ class DimFlag(ArrayConstraint): DimFlag(Dynamic, (0, 1)) """ __slots__ = ["_flag", "_dims"] - + def __init__(self, flag, dims=Unspec): super().__init__() self._flag = DimFlagEnum.parse(flag) if isinstance(dims, int): - assert(dims >= 0) + assert (dims >= 0) self._dims = (dims,) elif dims is Unspec: self._dims = Unspec else: self._dims = tuple(dims) - assert(all(isinstance(d, int) and d >= 0 for d in self._dims)) + assert (all(isinstance(d, int) and d >= 0 for d in self._dims)) def implies_dim_flag(self): return False - + @property def dim_flag(self): return self._flag, self._dims def __repr__(self): return "DimFlag(%r, %r)" % (self._flag, self._dims) - - + + def DynamicDim(dims=Unspec): """Dim flag that signals a dimension should be considered dynamic.""" return DimFlag(DimFlagEnum.Dynamic, dims) - - + + if __name__ == "__main__": import doctest doctest.testmod() diff --git a/python/npcomp/utils/test_utils.py b/python/npcomp/utils/test_utils.py index 0422bc597..206e264bc 100644 --- a/python/npcomp/utils/test_utils.py +++ b/python/npcomp/utils/test_utils.py @@ -13,14 +13,17 @@ _filecheck_binary_var = "FILECHECK_BINARY" _redirect_io = None _redirect_context = None + def is_filecheck_disabled(): return _disable_var in os.environ def start_filecheck_test(): if is_filecheck_disabled(): - print("WARNING:FileCheck disabled due to", _disable_var, - "in the environment", file=sys.stderr) + print("WARNING:FileCheck disabled due to", + _disable_var, + "in the environment", + file=sys.stderr) return global _redirect_io global _redirect_context @@ -30,7 +33,8 @@ def start_filecheck_test(): def end_filecheck_test(main_file): - if is_filecheck_disabled(): return + if is_filecheck_disabled(): + return global _redirect_io global _redirect_context _redirect_context.__exit__(None, None, None) @@ -44,7 +48,7 @@ def end_filecheck_test(main_file): filecheck_args = [filecheck_binary, main_file, "--dump-input=fail"] p = subprocess.Popen(filecheck_args, stdin=subprocess.PIPE) p.communicate(filecheck_input.encode("UTF-8")) - sys.exit(p.returncode) + sys.exit(p.returncode) def run_under_filecheck(main_file, callback, disable_filecheck=False): @@ -60,8 +64,10 @@ def run_under_filecheck(main_file, callback, disable_filecheck=False): disable_filecheck: Whether to disable filecheck. """ if disable_filecheck or is_filecheck_disabled(): - print("WARNING:FileCheck disabled due to", _disable_var, - "in the environment", file=sys.stderr) + print("WARNING:FileCheck disabled due to", + _disable_var, + "in the environment", + file=sys.stderr) callback() sys.exit(0) diff --git a/python/run_tests.py b/python/run_tests.py index 6e178d25f..4c26c588f 100755 --- a/python/run_tests.py +++ b/python/run_tests.py @@ -4,26 +4,26 @@ import os import subprocess import sys - TEST_MODULES = ( - "npcomp.mlir_ir_test", - "npcomp.mlir_pass_test", - "npcomp.dialect.Basicpy", - "npcomp.dialect.Numpy", - "npcomp.tracing.context", - "npcomp.tracing.mlir_trace", - "npcomp.types", - "npcomp.exporter", - "npcomp.tracing.mlir_trace_test", + "npcomp.mlir_ir_test", + "npcomp.mlir_pass_test", + "npcomp.dialect.Basicpy", + "npcomp.dialect.Numpy", + "npcomp.tracing.context", + "npcomp.tracing.mlir_trace", + "npcomp.types", + "npcomp.exporter", + "npcomp.tracing.mlir_trace_test", ) # Compute PYTHONPATH for sub processes. DIRSEP = os.path.pathsep LOCAL_PYTHONPATH_COMPONENTS = [ - # This directory. - os.path.abspath(os.path.dirname(__file__)), - # The parallel python_native directory (assuming in the build tree). - os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "python_native")) + # This directory. + os.path.abspath(os.path.dirname(__file__)), + # The parallel python_native directory (assuming in the build tree). + os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "python_native")) ] PYTHONPATH = DIRSEP.join(LOCAL_PYTHONPATH_COMPONENTS) if "PYTHONPATH" in os.environ: @@ -33,9 +33,8 @@ CHILD_ENVIRON["PYTHONPATH"] = PYTHONPATH # Configure filecheck. FILECHECK_BINARY = os.path.abspath( - os.path.join( - os.path.dirname(__file__), - "..", "..", "..", "bin", "FileCheck")) + os.path.join(os.path.dirname(__file__), "..", "..", "..", "bin", + "FileCheck")) if os.path.exists(FILECHECK_BINARY): CHILD_ENVIRON["FILECHECK_BINARY"] = FILECHECK_BINARY else: @@ -47,14 +46,14 @@ failed = [] for test_module in TEST_MODULES: print("--------====== RUNNING %s ======--------" % test_module) try: - subprocess.check_call([sys.executable, "-Wignore", "-m", test_module], + subprocess.check_call([sys.executable, "-Wignore", "-m", test_module], env=CHILD_ENVIRON) print("--------====== DONE %s ======--------\n" % test_module) passed.append(test_module) except subprocess.CalledProcessError: print("!!!!!!!!====== ERROR %s ======!!!!!!!!\n" % test_module) failed.append(test_module) - + print("Done: %d passed, %d failed" % (len(passed), len(failed))) if failed: for test_module in failed: diff --git a/python/samples/ast_extraction.py b/python/samples/ast_extraction.py index 133d3307f..c3ba763ce 100644 --- a/python/samples/ast_extraction.py +++ b/python/samples/ast_extraction.py @@ -4,6 +4,7 @@ from npcomp.compiler.frontend import * + def binary_expression(): a = 1 b = 100 @@ -11,6 +12,7 @@ def binary_expression(): c = c * 2.0 return c + fe = ImportFrontend() try: f = fe.import_global_function(binary_expression) diff --git a/python/samples/const.py b/python/samples/const.py index fa2256636..e75a10b3c 100644 --- a/python/samples/const.py +++ b/python/samples/const.py @@ -9,9 +9,11 @@ from npcomp.types import * weights = np.random.uniform(size=(16, 4)).astype(np.float32) bias = np.random.uniform(size=(4,)).astype(np.float32) + def constants(a: np.ndarray) -> np.ndarray: return np.dot(a, weights) + bias + # TODO: Implement subclassing and deriving constraints by run exp = npc.Exporter() exp.constants = constants diff --git a/python/samples/dot.py b/python/samples/dot.py index a1ee98d6c..72979aa43 100644 --- a/python/samples/dot.py +++ b/python/samples/dot.py @@ -6,20 +6,22 @@ import numpy as np import npcomp as npc from npcomp.types import * + def dot2d(a: np.ndarray, b: np.ndarray) -> np.ndarray: return np.dot(a, b) + # TODO: Implement subclassing and deriving constraints by run exp = npc.Exporter() exp.dot2d = dot2d exp.dot2d.sig.args["a"] += Shape(4, 16) exp.dot2d.sig.args["a"] += DynamicDim(0) exp.dot2d.sig.args["a"] += DType(np.float32) -exp.dot2d.sig.args["b"] += Shape(16,32) +exp.dot2d.sig.args["b"] += Shape(16, 32) exp.dot2d.sig.args["b"] += DType(np.float32) exp.dot2d.sig.result += Shape(4, 32) exp.dot2d.sig.result += DynamicDim(0) -exp.dot2d.sig.result += DType(np.float32) +exp.dot2d.sig.result += DType(np.float32) mb = npc.tracing.ModuleBuilder() mb.trace(exp.dot2d) diff --git a/python/samples/simple_ufunc.py b/python/samples/simple_ufunc.py index 8c3db3347..a3341f9d7 100644 --- a/python/samples/simple_ufunc.py +++ b/python/samples/simple_ufunc.py @@ -6,9 +6,11 @@ import numpy as np import npcomp as npc from npcomp.types import * + def simple_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray: return a * b + a + b + # TODO: Implement subclassing and deriving constraints by run exp = npc.Exporter() exp.simple_mul = simple_mul @@ -19,7 +21,7 @@ exp.simple_mul.sig.args["b"] += Shape(1) exp.simple_mul.sig.args["b"] += DType(np.float32) exp.simple_mul.sig.result += Shape(1, 4) exp.simple_mul.sig.result += DynamicDim(0) -exp.simple_mul.sig.result += DType(np.float32) +exp.simple_mul.sig.result += DType(np.float32) mb = npc.tracing.ModuleBuilder() mb.trace(exp.simple_mul) diff --git a/python/samples/slice.py b/python/samples/slice.py index 39dd0e840..96238e3aa 100644 --- a/python/samples/slice.py +++ b/python/samples/slice.py @@ -6,9 +6,11 @@ import numpy as np import npcomp as npc from npcomp.types import * + def slice_array1(a: np.ndarray) -> np.ndarray: return a[1, 2:10:2, 3:4, ..., :, 0] + # TODO: Implement subclassing and deriving constraints by run exp = npc.Exporter() exp.slice_array1 = slice_array1 diff --git a/python/samples/transpose.py b/python/samples/transpose.py index e17dda5c5..4e0bf1cd7 100644 --- a/python/samples/transpose.py +++ b/python/samples/transpose.py @@ -6,12 +6,15 @@ import numpy as np import npcomp as npc from npcomp.types import * + def transpose_attribute(a: np.ndarray) -> np.ndarray: return a.T + def transpose(a: np.ndarray) -> np.ndarray: return np.transpose(a) + # TODO: Implement subclassing and deriving constraints by run exp = npc.Exporter() exp.transpose_attribute = transpose_attribute diff --git a/python_native/MlirIr.h b/python_native/MlirIr.h index 37925d816..3a6b663b6 100644 --- a/python_native/MlirIr.h +++ b/python_native/MlirIr.h @@ -187,11 +187,10 @@ private: class PyDialectHelper { public: PyDialectHelper(PyContext &context, PyOpBuilder &builder) - : context(context), pyOpBuilder(builder) {} + : context(context), pyOpBuilder(builder) {} static void bind(py::module m); - MLIRContext *getContext() { - return pyOpBuilder.getContext(); - } + MLIRContext *getContext() { return pyOpBuilder.getContext(); } + protected: PyContext &context; PyOpBuilder &pyOpBuilder; diff --git a/tools/format_sources.sh b/tools/format_sources.sh new file mode 100755 index 000000000..2d3cbef06 --- /dev/null +++ b/tools/format_sources.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Formats all source files. + +set +e +td="$(dirname $0)/.." + +function find_cc_sources() { + local dir="$1" + find "$dir" -name "*.h" + find "$dir" -name "*.cpp" +} +# C/C++ sources. +set -o xtrace +clang-format -i \ + $(find_cc_sources include) \ + $(find_cc_sources lib) \ + $(find_cc_sources python_native) + +# Python sources. +yapf --recursive -i "$td/python" "$td/pytest"