mirror of https://github.com/llvm/torch-mlir
Add script tools/format_source.sh and run it on all python and c++ sources.
parent
c3d4436397
commit
2ba8296151
|
@ -16,6 +16,6 @@ namespace mlir {
|
|||
namespace NPCOMP {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCFToTCPPass();
|
||||
}
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_CONVERSION_TCFTOTCP_CONVERTTCFTOTCP_H
|
||||
|
|
|
@ -17,6 +17,6 @@ namespace mlir {
|
|||
namespace NPCOMP {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCPToLinalgPass();
|
||||
}
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_CONVERSION_TCPTOLINALG_CONVERTTCPTOLINALG_H
|
||||
|
|
|
@ -20,7 +20,7 @@ namespace npcomp_rt {
|
|||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/NpcompRt/IR/NpcompRtOps.h.inc"
|
||||
|
||||
} // namespace tcf
|
||||
} // namespace npcomp_rt
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -25,8 +25,7 @@ std::unique_ptr<OperationPass<FuncOp>> createLowerBroadcastToToLoopsPass();
|
|||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
createLowerLinalgOnTensorToLinalgOnMemrefPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
createResolveShapeOfOpsPass();
|
||||
std::unique_ptr<OperationPass<FuncOp>> createResolveShapeOfOpsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createResolveTensorLoadStoreOpsPass();
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ public:
|
|||
return success();
|
||||
}
|
||||
};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertTCFToTCP : public ConvertTCFToTCPBase<ConvertTCFToTCP> {
|
||||
|
|
|
@ -7,8 +7,8 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Dialect/NpcompRt/IR/NpcompRtDialect.h"
|
||||
#include "npcomp/Dialect/NpcompRt/IR/NpcompRtOps.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "npcomp/Dialect/NpcompRt/IR/NpcompRtOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP::npcomp_rt;
|
||||
|
@ -44,4 +44,3 @@ void NpcompRtDialect::printType(Type type, DialectAsmPrinter &os) const {
|
|||
llvm_unreachable("unexpected 'npcomp_rt' type kind");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,6 @@ namespace NPCOMP {
|
|||
namespace npcomp_rt {
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/NpcompRt/IR/NpcompRtOps.cpp.inc"
|
||||
} // namespace tcf
|
||||
} // namespace npcomp_rt
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
|
|
@ -29,10 +29,11 @@ LogicalResult ShapeObserveErrorOp::inferReturnTypes(
|
|||
// GetExtentOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult GetExtentOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
LogicalResult
|
||||
GetExtentOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
|
||||
ValueRange operands, DictionaryAttr attributes,
|
||||
RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
inferredReturnTypes.push_back(IndexType::get(context));
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -65,21 +65,22 @@ using namespace mlir::NPCOMP;
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class ResolveShapeOfOpViaAllocMemRefOp : public OpRewritePattern<shape::ShapeOfOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (auto tensorLoad = llvm::dyn_cast_or_null<TensorLoadOp>(
|
||||
op.getOperand().getDefiningOp())) {
|
||||
if (auto allocMemRef = llvm::dyn_cast_or_null<tcp::AllocMemRefOp>(
|
||||
tensorLoad.getOperand().getDefiningOp())) {
|
||||
rewriter.replaceOp(op, allocMemRef.getOperand());
|
||||
return success();
|
||||
}
|
||||
class ResolveShapeOfOpViaAllocMemRefOp
|
||||
: public OpRewritePattern<shape::ShapeOfOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (auto tensorLoad = llvm::dyn_cast_or_null<TensorLoadOp>(
|
||||
op.getOperand().getDefiningOp())) {
|
||||
if (auto allocMemRef = llvm::dyn_cast_or_null<tcp::AllocMemRefOp>(
|
||||
tensorLoad.getOperand().getDefiningOp())) {
|
||||
rewriter.replaceOp(op, allocMemRef.getOperand());
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -92,7 +93,7 @@ class ResolveShapeOfOps : public ResolveShapeOfOpsBase<ResolveShapeOfOps> {
|
|||
OwningRewritePatternList patterns;
|
||||
patterns.insert<ResolveShapeOfOpViaAllocMemRefOp>(context);
|
||||
ConversionTarget target(*context);
|
||||
//target.addIllegalOp<shape::ShapeOfOp>();
|
||||
// target.addIllegalOp<shape::ShapeOfOp>();
|
||||
target.addDynamicallyLegalOp<shape::ShapeOfOp>(
|
||||
[](shape::ShapeOfOp shapeOf) {
|
||||
// Only shape.shape_of on arguments to the entry block are legal at
|
||||
|
|
|
@ -6,8 +6,8 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/E2E/E2E.h"
|
||||
#include "PassDetail.h"
|
||||
#include "npcomp/E2E/E2E.h"
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||
|
@ -66,7 +66,6 @@ public:
|
|||
Value outputExtent = rewriter.create<tcp::GetExtentOp>(
|
||||
op.getLoc(), op.shape(), rewriter.getI64IntegerAttr(i));
|
||||
outputExtents.push_back(outputExtent);
|
||||
|
||||
}
|
||||
int rankDiff = resultType.getRank() - inputType.getRank();
|
||||
for (int i = 0, e = inputType.getRank(); i < e; i++) {
|
||||
|
@ -108,7 +107,8 @@ public:
|
|||
}
|
||||
Value load =
|
||||
rewriter.create<LoadOp>(op.getLoc(), inputMemref, inputIndices);
|
||||
rewriter.create<StoreOp>(op.getLoc(), load, resultMemref, inductionVariables);
|
||||
rewriter.create<StoreOp>(op.getLoc(), load, resultMemref,
|
||||
inductionVariables);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<TensorLoadOp>(op, resultMemref);
|
||||
|
@ -173,91 +173,87 @@ mlir::NPCOMP::createLowerBroadcastToToLoopsPass() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class LowerLinalgGenericTensorToMemRef : public OpRewritePattern<linalg::GenericOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(linalg::GenericOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
class LowerLinalgGenericTensorToMemRef
|
||||
: public OpRewritePattern<linalg::GenericOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(linalg::GenericOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
// TODO: Replace this with more generic code operating on named
|
||||
// structured ops too.
|
||||
// TODO: Replace this with more generic code operating on named
|
||||
// structured ops too.
|
||||
|
||||
// Only handle generic ops where all operands and results are tensors.
|
||||
if (!llvm::all_of(op.getOperandTypes(), [](Type type) {
|
||||
return type.isa<RankedTensorType>();
|
||||
})) {
|
||||
return rewriter.notifyMatchFailure(op, "all operands must be tensors");
|
||||
}
|
||||
if (!llvm::all_of(op.getResultTypes(), [](Type type) {
|
||||
return type.isa<RankedTensorType>();
|
||||
})) {
|
||||
return rewriter.notifyMatchFailure(op, "all results must be tensors");
|
||||
}
|
||||
|
||||
// TODO: Loosen restrictions on indexing maps.
|
||||
// This will require more principled handling of shape reification
|
||||
// earlier in the compilation stack, as in general output shapes of a
|
||||
// linalg.generic cannot be inferred easily.
|
||||
// See:
|
||||
// https://llvm.discourse.group/t/computing-output-shapes-of-structured-ops-on-tensors/866
|
||||
if (!llvm::all_of(op.indexing_maps(), [](Attribute map) {
|
||||
return map.cast<AffineMapAttr>().getValue().isIdentity();
|
||||
})) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "all indexing maps must be identity maps");
|
||||
}
|
||||
if (!llvm::all_of(op.iterator_types(), [](Attribute str) {
|
||||
return str.cast<StringAttr>().getValue() ==
|
||||
getParallelIteratorTypeName();
|
||||
})) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "all iterator types must be 'parallel'");
|
||||
}
|
||||
|
||||
SmallVector<Value, 6> memrefs;
|
||||
SmallVector<Value, 6> resultMemrefs;
|
||||
SmallVector<Value, 6> operandShapes;
|
||||
for (auto tensor : op.getOperands()) {
|
||||
auto shape = rewriter.create<shape::ShapeOfOp>(op.getLoc(), tensor);
|
||||
auto memref =
|
||||
allocMemRefForTensor(rewriter, tensor, shape, op.getLoc());
|
||||
rewriter.create<TensorStoreOp>(op.getLoc(), tensor, memref);
|
||||
memrefs.push_back(memref);
|
||||
operandShapes.push_back(shape);
|
||||
}
|
||||
auto shapeType = shape::ShapeType::get(rewriter.getContext());
|
||||
SmallVector<Type, 6> shapeTypes(op.getNumResults(), shapeType);
|
||||
// TODO: We need more principled handling of output shapes.
|
||||
// This assumes that all results have the same shape, which is justified
|
||||
// by checks above, but we really need a better story here.
|
||||
SmallVector<Value, 6> resultShapes(op.getNumResults(), operandShapes[0]);
|
||||
for (auto t : llvm::zip(op.getResults(), resultShapes)) {
|
||||
auto tensor = std::get<0>(t);
|
||||
auto shape = std::get<1>(t);
|
||||
auto memref =
|
||||
allocMemRefForTensor(rewriter, tensor, shape, op.getLoc());
|
||||
memrefs.push_back(memref);
|
||||
resultMemrefs.push_back(memref);
|
||||
}
|
||||
auto newGeneric = rewriter.create<linalg::GenericOp>(
|
||||
op.getLoc(), llvm::None, ValueRange(memrefs), op.getAttrs());
|
||||
newGeneric.region().getBlocks().clear();
|
||||
BlockAndValueMapping mapper;
|
||||
op.region().cloneInto(&newGeneric.region(), mapper);
|
||||
for (auto memref : resultMemrefs) {
|
||||
newGeneric.region().front().addArgument(
|
||||
memref.getType().cast<MemRefType>().getElementType());
|
||||
}
|
||||
auto newResultTensors =
|
||||
llvm::to_vector<6>(llvm::map_range(resultMemrefs, [&](Value memref) {
|
||||
return rewriter.create<TensorLoadOp>(op.getLoc(), memref)
|
||||
.getResult();
|
||||
}));
|
||||
rewriter.replaceOp(op, newResultTensors);
|
||||
return success();
|
||||
// Only handle generic ops where all operands and results are tensors.
|
||||
if (!llvm::all_of(op.getOperandTypes(),
|
||||
[](Type type) { return type.isa<RankedTensorType>(); })) {
|
||||
return rewriter.notifyMatchFailure(op, "all operands must be tensors");
|
||||
}
|
||||
if (!llvm::all_of(op.getResultTypes(),
|
||||
[](Type type) { return type.isa<RankedTensorType>(); })) {
|
||||
return rewriter.notifyMatchFailure(op, "all results must be tensors");
|
||||
}
|
||||
|
||||
// TODO: Loosen restrictions on indexing maps.
|
||||
// This will require more principled handling of shape reification
|
||||
// earlier in the compilation stack, as in general output shapes of a
|
||||
// linalg.generic cannot be inferred easily.
|
||||
// See:
|
||||
// https://llvm.discourse.group/t/computing-output-shapes-of-structured-ops-on-tensors/866
|
||||
if (!llvm::all_of(op.indexing_maps(), [](Attribute map) {
|
||||
return map.cast<AffineMapAttr>().getValue().isIdentity();
|
||||
})) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "all indexing maps must be identity maps");
|
||||
}
|
||||
if (!llvm::all_of(op.iterator_types(), [](Attribute str) {
|
||||
return str.cast<StringAttr>().getValue() ==
|
||||
getParallelIteratorTypeName();
|
||||
})) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "all iterator types must be 'parallel'");
|
||||
}
|
||||
|
||||
SmallVector<Value, 6> memrefs;
|
||||
SmallVector<Value, 6> resultMemrefs;
|
||||
SmallVector<Value, 6> operandShapes;
|
||||
for (auto tensor : op.getOperands()) {
|
||||
auto shape = rewriter.create<shape::ShapeOfOp>(op.getLoc(), tensor);
|
||||
auto memref = allocMemRefForTensor(rewriter, tensor, shape, op.getLoc());
|
||||
rewriter.create<TensorStoreOp>(op.getLoc(), tensor, memref);
|
||||
memrefs.push_back(memref);
|
||||
operandShapes.push_back(shape);
|
||||
}
|
||||
auto shapeType = shape::ShapeType::get(rewriter.getContext());
|
||||
SmallVector<Type, 6> shapeTypes(op.getNumResults(), shapeType);
|
||||
// TODO: We need more principled handling of output shapes.
|
||||
// This assumes that all results have the same shape, which is justified
|
||||
// by checks above, but we really need a better story here.
|
||||
SmallVector<Value, 6> resultShapes(op.getNumResults(), operandShapes[0]);
|
||||
for (auto t : llvm::zip(op.getResults(), resultShapes)) {
|
||||
auto tensor = std::get<0>(t);
|
||||
auto shape = std::get<1>(t);
|
||||
auto memref = allocMemRefForTensor(rewriter, tensor, shape, op.getLoc());
|
||||
memrefs.push_back(memref);
|
||||
resultMemrefs.push_back(memref);
|
||||
}
|
||||
auto newGeneric = rewriter.create<linalg::GenericOp>(
|
||||
op.getLoc(), llvm::None, ValueRange(memrefs), op.getAttrs());
|
||||
newGeneric.region().getBlocks().clear();
|
||||
BlockAndValueMapping mapper;
|
||||
op.region().cloneInto(&newGeneric.region(), mapper);
|
||||
for (auto memref : resultMemrefs) {
|
||||
newGeneric.region().front().addArgument(
|
||||
memref.getType().cast<MemRefType>().getElementType());
|
||||
}
|
||||
auto newResultTensors =
|
||||
llvm::to_vector<6>(llvm::map_range(resultMemrefs, [&](Value memref) {
|
||||
return rewriter.create<TensorLoadOp>(op.getLoc(), memref).getResult();
|
||||
}));
|
||||
rewriter.replaceOp(op, newResultTensors);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class LowerLinalgOnTensorToLinalgOnMemref
|
||||
|
|
|
@ -6,8 +6,8 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/E2E/E2E.h"
|
||||
#include "PassDetail.h"
|
||||
#include "npcomp/E2E/E2E.h"
|
||||
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
|
|
|
@ -22,66 +22,77 @@ def add():
|
|||
# CHECK: {{.*}} = basicpy.binary_expr %[[A]] "Add" %[[B]] : (i64, i64) -> !basicpy.UnknownType
|
||||
return a + b
|
||||
|
||||
|
||||
# CHECK-LABEL: func @sub
|
||||
@import_global
|
||||
def sub():
|
||||
# CHECK: basicpy.binary_expr {{.*}} "Sub"
|
||||
return 4 - 2
|
||||
|
||||
|
||||
# CHECK-LABEL: func @mult
|
||||
@import_global
|
||||
def mult():
|
||||
# CHECK: basicpy.binary_expr {{.*}} "Mult"
|
||||
return 4 * 2
|
||||
|
||||
|
||||
# CHECK-LABEL: func @div
|
||||
@import_global
|
||||
def div():
|
||||
# CHECK: basicpy.binary_expr {{.*}} "Div"
|
||||
return 4 / 2
|
||||
|
||||
|
||||
# CHECK-LABEL: func @floor_div
|
||||
@import_global
|
||||
def floor_div():
|
||||
# CHECK: basicpy.binary_expr {{.*}} "FloorDiv"
|
||||
return 4 // 2
|
||||
|
||||
|
||||
# CHECK-LABEL: func @matmul
|
||||
@import_global
|
||||
def matmul():
|
||||
# CHECK: basicpy.binary_expr {{.*}} "MatMult"
|
||||
return 4 @ 2
|
||||
|
||||
|
||||
# CHECK-LABEL: func @modulo
|
||||
@import_global
|
||||
def modulo():
|
||||
# CHECK: basicpy.binary_expr {{.*}} "Mod"
|
||||
return 4 % 2
|
||||
|
||||
|
||||
# CHECK-LABEL: func @left_shift
|
||||
@import_global
|
||||
def left_shift():
|
||||
# CHECK: basicpy.binary_expr {{.*}} "LShift"
|
||||
return 4 << 2
|
||||
|
||||
|
||||
# CHECK-LABEL: func @right_shift
|
||||
@import_global
|
||||
def right_shift():
|
||||
# CHECK: basicpy.binary_expr {{.*}} "RShift"
|
||||
return 4 >> 2
|
||||
|
||||
|
||||
# CHECK-LABEL: func @bit_and
|
||||
@import_global
|
||||
def bit_and():
|
||||
# CHECK: basicpy.binary_expr {{.*}} "BitAnd"
|
||||
return 4 & 2
|
||||
|
||||
|
||||
# CHECK-LABEL: func @bit_xor
|
||||
@import_global
|
||||
def bit_xor():
|
||||
# CHECK: basicpy.binary_expr {{.*}} "BitXor"
|
||||
return 4 ^ 2
|
||||
|
||||
|
||||
# CHECK-LABEL: func @bit_or
|
||||
@import_global
|
||||
def bit_or():
|
||||
|
|
|
@ -38,6 +38,7 @@ def logical_and():
|
|||
# CHECK: }
|
||||
return x and y and z
|
||||
|
||||
|
||||
# CHECK-LABEL: func @logical_or
|
||||
@import_global
|
||||
def logical_or():
|
||||
|
@ -65,17 +66,19 @@ def logical_or():
|
|||
z = 2
|
||||
return x or y or z
|
||||
|
||||
|
||||
# CHECK-LABEL: func @logical_not
|
||||
@import_global
|
||||
def logical_not():
|
||||
# CHECK: %[[X:.*]] = constant 1
|
||||
x = 1
|
||||
# CHECK-DAG: %[[TRUE:.*]] = basicpy.bool_constant true
|
||||
# CHECK-DAG: %[[TRUE:.*]] = basicpy.bool_constant true
|
||||
# CHECK-DAG: %[[FALSE:.*]] = basicpy.bool_constant false
|
||||
# CHECK-DAG: %[[CONDITION:.*]] = basicpy.to_boolean %[[X]]
|
||||
# CHECK-DAG: %{{.*}} = select %[[CONDITION]], %[[FALSE]], %[[TRUE]] : !basicpy.BoolType
|
||||
return not x
|
||||
|
||||
|
||||
# CHECK-LABEL: func @conditional
|
||||
@import_global
|
||||
def conditional():
|
||||
|
|
|
@ -21,6 +21,7 @@ def binary_lt_():
|
|||
# CHECK: {{.*}} = basicpy.binary_compare %[[A]] "Lt" %[[B]] : i64, i64
|
||||
return x < y
|
||||
|
||||
|
||||
# CHECK-LABEL: func @binary_gt_
|
||||
@import_global
|
||||
def binary_gt_():
|
||||
|
@ -29,6 +30,7 @@ def binary_gt_():
|
|||
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "Gt" {{.*}} : i64, i64
|
||||
return x > y
|
||||
|
||||
|
||||
# CHECK-LABEL: func @binary_lte_
|
||||
@import_global
|
||||
def binary_lte_():
|
||||
|
@ -37,6 +39,7 @@ def binary_lte_():
|
|||
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "LtE" {{.*}} : i64, i64
|
||||
return x <= y
|
||||
|
||||
|
||||
# CHECK-LABEL: func @binary_gte_
|
||||
@import_global
|
||||
def binary_gte_():
|
||||
|
@ -45,6 +48,7 @@ def binary_gte_():
|
|||
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "GtE" {{.*}} : i64, i64
|
||||
return x >= y
|
||||
|
||||
|
||||
# CHECK-LABEL: func @binary_eq_
|
||||
@import_global
|
||||
def binary_eq_():
|
||||
|
@ -53,6 +57,7 @@ def binary_eq_():
|
|||
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "Eq" {{.*}} : i64, i64
|
||||
return x == y
|
||||
|
||||
|
||||
# CHECK-LABEL: func @binary_neq_
|
||||
@import_global
|
||||
def binary_neq_():
|
||||
|
@ -61,6 +66,7 @@ def binary_neq_():
|
|||
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "NotEq" {{.*}} : i64, i64
|
||||
return x != y
|
||||
|
||||
|
||||
# CHECK-LABEL: func @binary_is_
|
||||
@import_global
|
||||
def binary_is_():
|
||||
|
@ -69,6 +75,7 @@ def binary_is_():
|
|||
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "Is" {{.*}} : i64, i64
|
||||
return x is y
|
||||
|
||||
|
||||
# CHECK-LABEL: func @binary_is_not_
|
||||
@import_global
|
||||
def binary_is_not_():
|
||||
|
@ -77,6 +84,7 @@ def binary_is_not_():
|
|||
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "IsNot" {{.*}} : i64, i64
|
||||
return x is not y
|
||||
|
||||
|
||||
# CHECK-LABEL: func @binary_in_
|
||||
@import_global
|
||||
def binary_in_():
|
||||
|
@ -85,6 +93,7 @@ def binary_in_():
|
|||
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "In" {{.*}} : i64, i64
|
||||
return x in y
|
||||
|
||||
|
||||
# CHECK-LABEL: func @binary_not_in_
|
||||
@import_global
|
||||
def binary_not_in_():
|
||||
|
@ -93,6 +102,7 @@ def binary_not_in_():
|
|||
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "NotIn" {{.*}} : i64, i64
|
||||
return x not in y
|
||||
|
||||
|
||||
@import_global
|
||||
def short_circuit():
|
||||
# CHECK: %[[X:.*]] = constant 1 : i64
|
||||
|
@ -123,6 +133,7 @@ def short_circuit():
|
|||
# CHECK: return %[[RESULT]]
|
||||
return x < y == z >= omega
|
||||
|
||||
|
||||
# CHECK-LABEL: nested_short_circuit_expression
|
||||
@import_global
|
||||
def nested_short_circuit_expression():
|
||||
|
|
|
@ -20,6 +20,7 @@ def integer_constants():
|
|||
# CHECK: return %[[A_CAST]]
|
||||
return a
|
||||
|
||||
|
||||
# CHECK-LABEL: func @float_constants
|
||||
@import_global
|
||||
def float_constants():
|
||||
|
@ -29,6 +30,7 @@ def float_constants():
|
|||
# CHECK: return %[[A_CAST]]
|
||||
return a
|
||||
|
||||
|
||||
# CHECK-LABEL: func @bool_true_constant
|
||||
@import_global
|
||||
def bool_true_constant():
|
||||
|
@ -37,6 +39,7 @@ def bool_true_constant():
|
|||
a = True
|
||||
return a
|
||||
|
||||
|
||||
# CHECK-LABEL: func @bool_false_constant
|
||||
@import_global
|
||||
def bool_false_constant():
|
||||
|
@ -45,6 +48,7 @@ def bool_false_constant():
|
|||
a = False
|
||||
return a
|
||||
|
||||
|
||||
# CHECK-LABEL: func @string_constant
|
||||
@import_global
|
||||
def string_constant():
|
||||
|
@ -53,6 +57,7 @@ def string_constant():
|
|||
a = "foobar"
|
||||
return a
|
||||
|
||||
|
||||
# CHECK-LABEL: func @joined_string_constant
|
||||
@import_global
|
||||
def joined_string_constant():
|
||||
|
@ -61,6 +66,7 @@ def joined_string_constant():
|
|||
a = "I am" " still here"
|
||||
return a
|
||||
|
||||
|
||||
# CHECK-LABEL: func @bytes_constant
|
||||
@import_global
|
||||
def bytes_constant():
|
||||
|
@ -69,6 +75,7 @@ def bytes_constant():
|
|||
a = b"foobar"
|
||||
return a
|
||||
|
||||
|
||||
# CHECK-LABEL: func @ellipsis
|
||||
@import_global
|
||||
def ellipsis():
|
||||
|
@ -77,6 +84,7 @@ def ellipsis():
|
|||
a = ...
|
||||
return a
|
||||
|
||||
|
||||
# CHECK-LABEL: func @none_constant
|
||||
@import_global
|
||||
def none_constant():
|
||||
|
|
|
@ -10,6 +10,7 @@ def import_global(f):
|
|||
print(fe.ir_module.to_asm())
|
||||
return f
|
||||
|
||||
|
||||
# CHECK-LABEL: func @positional_args
|
||||
# CHECK-SAME: (%arg0: !basicpy.UnknownType, %arg1: !basicpy.UnknownType) -> !basicpy.UnknownType
|
||||
@import_global
|
||||
|
@ -17,6 +18,7 @@ def positional_args(a, b):
|
|||
# CHECK: basicpy.binary_expr %arg0 "Add" %arg1
|
||||
return a + b
|
||||
|
||||
|
||||
# CHECK-LABEL: func @pass_no_return
|
||||
@import_global
|
||||
def pass_no_return():
|
||||
|
|
|
@ -10,6 +10,7 @@ def import_global(f):
|
|||
print("// -----")
|
||||
return f
|
||||
|
||||
|
||||
# CHECK-LABEL: func @arithmetic_expression
|
||||
# CHECK-SAME: () -> i64
|
||||
@import_global
|
||||
|
@ -21,6 +22,7 @@ def arithmetic_expression():
|
|||
# CHECK: return{{.*}} : i64
|
||||
return 1 + 2 - 3 * 4
|
||||
|
||||
|
||||
# CHECK-LABEL: func @arg_inference
|
||||
# CHECK-SAME: (%arg0: i64, %arg1: i64) -> i64
|
||||
@import_global
|
||||
|
@ -31,9 +33,10 @@ def arg_inference(a, b):
|
|||
# CHECK: return{{.*}} : i64
|
||||
return a + 2 * b
|
||||
|
||||
|
||||
# CHECK-LABEL: func @conditional_inference
|
||||
# CHECK-SAME: (%arg0: i64, %arg1: !basicpy.BoolType, %arg2: i64) -> !basicpy.BoolType
|
||||
@import_global
|
||||
def conditional_inference(cond, a, b):
|
||||
# CHECK-NOT: UnknownType
|
||||
return a if cond + 1 else not(b * 4)
|
||||
return a if cond + 1 else not (b * 4)
|
||||
|
|
|
@ -39,15 +39,13 @@ llvm_config.with_system_environment(
|
|||
|
||||
llvm_config.use_default_substitutions()
|
||||
|
||||
# excludes: A list of files/directories to exclude from the testsuite. The
|
||||
# 'Inputs'subdirectories contain auxiliary inputs for various tests in their
|
||||
# excludes: A list of files/directories to exclude from the testsuite. The
|
||||
# 'Inputs'subdirectories contain auxiliary inputs for various tests in their
|
||||
# parent directories.
|
||||
config.excludes = [
|
||||
'Inputs', 'Examples',
|
||||
'lit.cfg.py',
|
||||
'CMakeLists.txt',
|
||||
'README.txt',
|
||||
'LICENSE.txt']
|
||||
'Inputs', 'Examples', 'lit.cfg.py', 'CMakeLists.txt', 'README.txt',
|
||||
'LICENSE.txt'
|
||||
]
|
||||
|
||||
# test_source_root: The root path where tests are located.
|
||||
config.test_source_root = os.path.dirname(__file__)
|
||||
|
@ -56,22 +54,21 @@ config.test_source_root = os.path.dirname(__file__)
|
|||
config.test_exec_root = os.path.join(config.npcomp_obj_root, 'pytest')
|
||||
config.npcomp_tools_dir = os.path.join(config.npcomp_obj_root, 'tools')
|
||||
config.npcomp_runtime_shlib = os.path.join(
|
||||
config.npcomp_obj_root,
|
||||
'runtime',
|
||||
'libNPCOMPRuntime' + config.llvm_shlib_ext
|
||||
)
|
||||
config.npcomp_obj_root, 'runtime',
|
||||
'libNPCOMPRuntime' + config.llvm_shlib_ext)
|
||||
|
||||
# Tweak the PATH and PYTHONPATH to include the tools dir.
|
||||
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
|
||||
llvm_config.with_environment('PYTHONPATH', [
|
||||
os.path.join(config.npcomp_obj_root, "python"),
|
||||
os.path.join(config.npcomp_obj_root, "python_native")],
|
||||
append_path=True)
|
||||
os.path.join(config.npcomp_obj_root, "python_native")
|
||||
],
|
||||
append_path=True)
|
||||
|
||||
tool_dirs = [
|
||||
os.path.join(config.npcomp_tools_dir, 'npcomp-opt'),
|
||||
os.path.join(config.npcomp_tools_dir, 'npcomp-run-mlir'),
|
||||
config.llvm_tools_dir,
|
||||
os.path.join(config.npcomp_tools_dir, 'npcomp-opt'),
|
||||
os.path.join(config.npcomp_tools_dir, 'npcomp-run-mlir'),
|
||||
config.llvm_tools_dir,
|
||||
]
|
||||
tools = [
|
||||
'npcomp-opt',
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
|
|
|
@ -65,6 +65,7 @@ class DialectHelper(Basicpy.DialectHelper):
|
|||
tensor<*x!numpy.any_dtype>
|
||||
|
||||
"""
|
||||
|
||||
@property
|
||||
def numpy_any_dtype(self):
|
||||
return self.context.parse_type("!numpy.any_dtype")
|
||||
|
|
|
@ -9,9 +9,9 @@ from typing import Optional
|
|||
from npcomp.types import *
|
||||
|
||||
__all__ = [
|
||||
"Exporter",
|
||||
"ExportFunction",
|
||||
"ExportPyFunction",
|
||||
"Exporter",
|
||||
"ExportFunction",
|
||||
"ExportPyFunction",
|
||||
]
|
||||
|
||||
|
||||
|
@ -21,40 +21,41 @@ def _value_type_from_annotation(annotation):
|
|||
return ValueType(TypeClass.NdArray)
|
||||
else:
|
||||
return ValueType()
|
||||
|
||||
|
||||
|
||||
|
||||
def _signature_from_pyfunc(pyfunc):
|
||||
pysig = inspect.signature(pyfunc)
|
||||
sig = Signature(len(pysig.parameters))
|
||||
# Arguments
|
||||
for i, param in enumerate(pysig.parameters.values()):
|
||||
if param.kind not in (
|
||||
param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
|
||||
if param.kind not in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
|
||||
raise ValueError(
|
||||
"Currently only positional function signature are supported")
|
||||
|
||||
"Currently only positional function signature are supported")
|
||||
|
||||
sig.arg_names[i] = param.name
|
||||
annot = param.annotation
|
||||
if annot is param.empty: continue
|
||||
if annot is param.empty:
|
||||
continue
|
||||
sig.args[i] = _value_type_from_annotation(annot)
|
||||
|
||||
|
||||
# Result
|
||||
if pysig.return_annotation is not pysig.empty:
|
||||
sig.result = _value_type_from_annotation(pysig.return_annotation)
|
||||
|
||||
|
||||
return sig
|
||||
|
||||
|
||||
|
||||
class ExportFunction:
|
||||
"""Base class for functions that can be exported."""
|
||||
__slots__ = ["_sig"]
|
||||
|
||||
def __init__(self, sig=None):
|
||||
self._sig = sig if sig else Signature()
|
||||
|
||||
|
||||
@property
|
||||
def sig(self):
|
||||
return self._sig
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return "def %r" % self._sig
|
||||
|
||||
|
@ -86,21 +87,21 @@ class ExportPyFunction(ExportFunction):
|
|||
pydef mul(a: NdArray[Rank(2)], b: Any) -> NdArray[Shape(1, 2)]
|
||||
"""
|
||||
__slots__ = ExportFunction.__slots__ + ["_pyfunc", "__name__"]
|
||||
|
||||
|
||||
def __init__(self, pyfunc, name=None):
|
||||
super().__init__(sig=_signature_from_pyfunc(pyfunc))
|
||||
assert (hasattr(pyfunc, "__call__")
|
||||
and hasattr(pyfunc, "__name__")), "Not a python function"
|
||||
assert (hasattr(pyfunc, "__call__") and
|
||||
hasattr(pyfunc, "__name__")), "Not a python function"
|
||||
self._pyfunc = pyfunc
|
||||
self.__name__ = name if name else pyfunc.__name__
|
||||
|
||||
@property
|
||||
def pyfunc(self):
|
||||
return self._pyfunc
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return "pydef %s%r" % (self.__name__, self._sig)
|
||||
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self._pyfunc(*args, **kwargs)
|
||||
|
||||
|
@ -108,53 +109,59 @@ class ExportPyFunction(ExportFunction):
|
|||
class _ExpandoNode:
|
||||
"""Expando object that can be indexed into to construct a namespace."""
|
||||
__slots__ = [
|
||||
"_parent", "_services", "_local_name", "_parent_name",
|
||||
"_children", "_attached"]
|
||||
def __init__(self, parent: Optional["_ExpandoNode"],
|
||||
services: "_Services",
|
||||
local_name: str):
|
||||
"_parent", "_services", "_local_name", "_parent_name", "_children",
|
||||
"_attached"
|
||||
]
|
||||
|
||||
def __init__(self, parent: Optional["_ExpandoNode"], services: "_Services",
|
||||
local_name: str):
|
||||
super().__init__()
|
||||
object.__setattr__(self, "_parent", parent)
|
||||
object.__setattr__(self, "_services", services)
|
||||
object.__setattr__(self, "_local_name", local_name)
|
||||
object.__setattr__(self, "_parent_name",
|
||||
object.__setattr__(self, "_parent_name",
|
||||
parent._get_full_name() if parent else "")
|
||||
object.__setattr__(self, "_children", {})
|
||||
object.__setattr__(self, "_attached", parent is None)
|
||||
|
||||
|
||||
def _attach(self):
|
||||
if self._attached: return
|
||||
if self._attached:
|
||||
return
|
||||
if self._local_name in self._parent._children:
|
||||
raise KeyError("Cannot re-assign '%s'" % (self._get_full_name(),))
|
||||
self._parent._attach()
|
||||
self._parent._children[self._local_name] = self
|
||||
object.__setattr__(self, "_attached", True)
|
||||
|
||||
|
||||
def _get_full_name(self):
|
||||
if not self._parent: return "" # Root is always empty name.
|
||||
full_name = (self._parent_name + "." + self._local_name
|
||||
if self._parent_name else self._local_name)
|
||||
if not self._parent:
|
||||
return "" # Root is always empty name.
|
||||
full_name = (self._parent_name + "." +
|
||||
self._local_name if self._parent_name else self._local_name)
|
||||
return full_name
|
||||
|
||||
|
||||
def _get_child_name(self, child_local_name):
|
||||
full_name = self._get_full_name()
|
||||
if not full_name: return child_local_name
|
||||
else: return full_name + "." + child_local_name
|
||||
|
||||
if not full_name:
|
||||
return child_local_name
|
||||
else:
|
||||
return full_name + "." + child_local_name
|
||||
|
||||
def __repr__(self):
|
||||
return "Namespace(\"%s\")" % (self._get_full_name())
|
||||
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self._children
|
||||
|
||||
|
||||
def __getitem__(self, key):
|
||||
key = str(key)
|
||||
existing = self._children.get(key)
|
||||
if existing is not None: return existing
|
||||
if existing is not None:
|
||||
return existing
|
||||
# Speculatively create a child expando.
|
||||
child = _ExpandoNode(self, self._services, key)
|
||||
return child
|
||||
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if not inspect.isfunction(value):
|
||||
raise TypeError("Cannot assign value to an exporter: %r" % (value,))
|
||||
|
@ -164,19 +171,19 @@ class _ExpandoNode:
|
|||
raise KeyError("Cannot re-assign '%s'" % (child_name))
|
||||
self._attach()
|
||||
self._children[key] = self._services.wrap_function(value, child_name)
|
||||
|
||||
|
||||
def __getattr__(self, name):
|
||||
return self[name]
|
||||
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
try:
|
||||
self[name] = value
|
||||
except KeyError as e:
|
||||
raise AttributeError(str(e)) from None
|
||||
|
||||
|
||||
def __dir__(self):
|
||||
return self._children.keys()
|
||||
|
||||
|
||||
|
||||
class _Services:
|
||||
"""Services and support for the Exporter.
|
||||
|
@ -184,12 +191,14 @@ class _Services:
|
|||
Exporters are user objects, so most of the functional components are
|
||||
contained in the associated _Services object.
|
||||
"""
|
||||
|
||||
def wrap_function(self, f, full_name):
|
||||
if isinstance(f, ExportFunction): return f
|
||||
if isinstance(f, ExportFunction):
|
||||
return f
|
||||
# TODO: Need to scan through providers and choose.
|
||||
return ExportPyFunction(f, name=full_name)
|
||||
|
||||
|
||||
|
||||
|
||||
class Exporter:
|
||||
"""Top-level UI object for assembling a program for export.
|
||||
|
||||
|
@ -244,31 +253,32 @@ class Exporter:
|
|||
AttributeError: "Cannot re-assign 'ns1'"
|
||||
"""
|
||||
__slots__ = ["_root", "_services"]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
services = _Services()
|
||||
object.__setattr__(self, "_root", _ExpandoNode(None, services, ""))
|
||||
object.__setattr__(self, "_services", services)
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return "Exporter()"
|
||||
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self._root
|
||||
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._root[key]
|
||||
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self._root[key] = value
|
||||
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._root, name)
|
||||
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
setattr(self._root, name, value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
doctest.testmod()
|
||||
import doctest
|
||||
doctest.testmod()
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
"""Test for the MLIR IR Python bindings"""
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
|
@ -39,5 +38,4 @@ except ValueError as e:
|
|||
# CHECK: [ERROR]: expected operation name in quotes
|
||||
print(e)
|
||||
|
||||
|
||||
test_utils.end_filecheck_test(__file__)
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
"""Test for the MLIR Pass Python bindings"""
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
|
|
|
@ -48,10 +48,11 @@ class TraceContext:
|
|||
"""
|
||||
_local = threading.local()
|
||||
__slots__ = [
|
||||
"_desc",
|
||||
"_next_id",
|
||||
"active",
|
||||
"_desc",
|
||||
"_next_id",
|
||||
"active",
|
||||
]
|
||||
|
||||
def __init__(self, desc=None):
|
||||
_check_numpy_version()
|
||||
self._desc = desc
|
||||
|
@ -87,11 +88,11 @@ class TraceContext:
|
|||
@classmethod
|
||||
def optional_current(cls) -> Optional["TraceContext"]:
|
||||
s = cls._get_context_stack()
|
||||
if s:
|
||||
if s:
|
||||
return s[-1]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@classmethod
|
||||
def current(cls) -> "TraceContext":
|
||||
c = cls.optional_current()
|
||||
|
@ -120,7 +121,7 @@ class TraceContext:
|
|||
|
||||
def _assert_active(tc: TraceContext):
|
||||
assert tc.active, (
|
||||
"Attempt to trace an action on an inactive trace context: %r" % tc)
|
||||
"Attempt to trace an action on an inactive trace context: %r" % tc)
|
||||
|
||||
|
||||
class TracedArray(np.lib.mixins.NDArrayOperatorsMixin):
|
||||
|
@ -134,6 +135,7 @@ class TracedArray(np.lib.mixins.NDArrayOperatorsMixin):
|
|||
>>> TracedArray(tc=tc)
|
||||
<TracedArray 2>
|
||||
"""
|
||||
|
||||
def __init__(self, tc: Optional[TraceContext] = None):
|
||||
self._tc = tc if tc is not None else TraceContext.current()
|
||||
self._uid = self._tc.get_next_id()
|
||||
|
@ -160,28 +162,28 @@ class TracedArray(np.lib.mixins.NDArrayOperatorsMixin):
|
|||
|
||||
def __array_function__(self, func, types, args, kwargs):
|
||||
tc = self._tc
|
||||
_assert_active(tc)
|
||||
_assert_active(tc)
|
||||
return tc._handle_array_func(func, types, args, kwargs)
|
||||
|
||||
@property
|
||||
def T(self):
|
||||
"""Shortcut for transpose."""
|
||||
tc = self._tc
|
||||
_assert_active(tc)
|
||||
_assert_active(tc)
|
||||
return tc._handle_array_func(np.transpose, [TracedArray], [self], {})
|
||||
|
||||
|
||||
|
||||
def _check_numpy_version():
|
||||
version = np.lib.NumpyVersion(np.__version__)
|
||||
if version < "1.16.0":
|
||||
raise RuntimeError("Numpy version >= 1.16 is required")
|
||||
if version > "1.17.0": return
|
||||
if version > "1.17.0":
|
||||
return
|
||||
if os.environ.get("NUMPY_EXPERIMENTAL_ARRAY_FUNCTION") != "1":
|
||||
raise RuntimeError(
|
||||
"For numpy 1.16, the environment variable "
|
||||
"NUMPY_EXPERIMENTAL_ARRAY_FUNCTION must equal 1")
|
||||
raise RuntimeError("For numpy 1.16, the environment variable "
|
||||
"NUMPY_EXPERIMENTAL_ARRAY_FUNCTION must equal 1")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
doctest.testmod()
|
||||
import doctest
|
||||
doctest.testmod()
|
||||
|
|
|
@ -9,9 +9,11 @@ from npcomp.utils import test_utils
|
|||
|
||||
test_utils.start_filecheck_test()
|
||||
|
||||
|
||||
def simple_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
||||
return a * b + a + b
|
||||
|
||||
|
||||
# TODO: Implement subclassing and deriving constraints by run
|
||||
exp = Exporter()
|
||||
exp.simple_mul = simple_mul
|
||||
|
@ -22,7 +24,7 @@ exp.simple_mul.sig.args["b"] += Shape(1)
|
|||
exp.simple_mul.sig.args["b"] += DType(np.float32)
|
||||
exp.simple_mul.sig.result += Shape(1, 4)
|
||||
exp.simple_mul.sig.result += DynamicDim(0)
|
||||
exp.simple_mul.sig.result += DType(np.float32)
|
||||
exp.simple_mul.sig.result += DType(np.float32)
|
||||
|
||||
mb = ModuleBuilder()
|
||||
mb.trace(exp.simple_mul)
|
||||
|
|
|
@ -8,30 +8,30 @@ from enum import Enum
|
|||
import numpy as np
|
||||
|
||||
__all__ = [
|
||||
"Unspec",
|
||||
"ArrayConstraint",
|
||||
"ArrayParams",
|
||||
"DType",
|
||||
"DimFlag",
|
||||
"DimFlagEnum",
|
||||
"DynamicDim",
|
||||
"Rank",
|
||||
"Shape",
|
||||
"Signature",
|
||||
"TypeClass",
|
||||
"TypeConstraints",
|
||||
"ValueType",
|
||||
"Unspec",
|
||||
"ArrayConstraint",
|
||||
"ArrayParams",
|
||||
"DType",
|
||||
"DimFlag",
|
||||
"DimFlagEnum",
|
||||
"DynamicDim",
|
||||
"Rank",
|
||||
"Shape",
|
||||
"Signature",
|
||||
"TypeClass",
|
||||
"TypeConstraints",
|
||||
"ValueType",
|
||||
]
|
||||
|
||||
# TODO: All supported types
|
||||
_DTYPE_TO_ASM_DICT = {
|
||||
np.bool: "i1", # TODO: May need a custom type to signify 8bit storage
|
||||
np.int8: "s8",
|
||||
np.int16: "s16",
|
||||
np.int32: "s32",
|
||||
np.int64: "s64",
|
||||
np.float32: "f32",
|
||||
np.float64: "f64",
|
||||
np.bool: "i1", # TODO: May need a custom type to signify 8bit storage
|
||||
np.int8: "s8",
|
||||
np.int16: "s16",
|
||||
np.int32: "s32",
|
||||
np.int64: "s64",
|
||||
np.float32: "f32",
|
||||
np.float64: "f64",
|
||||
}
|
||||
|
||||
|
||||
|
@ -67,29 +67,39 @@ class _LiterateEnum(Enum):
|
|||
ValueError: Cannot parse SampleEnum 1.0
|
||||
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def parse(cls, v):
|
||||
if isinstance(v, cls): return v
|
||||
if isinstance(v, cls):
|
||||
return v
|
||||
if not v or not isinstance(v, str) or v[0] == '_' or not hasattr(cls, v):
|
||||
raise ValueError("Cannot parse %s %r" % (
|
||||
cls.__name__.split(".")[-1], v,))
|
||||
cls.__name__.split(".")[-1],
|
||||
v,
|
||||
))
|
||||
value = getattr(cls, v)
|
||||
if not isinstance(value, cls):
|
||||
raise ValueError("Cannot parse %s %r" % (
|
||||
cls.__name__.split(".")[-1], v,))
|
||||
cls.__name__.split(".")[-1],
|
||||
v,
|
||||
))
|
||||
return value
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
|
||||
|
||||
# Special "unspecified" value that we use throughout.
|
||||
class _Unspec:
|
||||
__slots__ = []
|
||||
|
||||
def __str__(self):
|
||||
return "Unspec"
|
||||
|
||||
def __repr__(self):
|
||||
return "Unspec"
|
||||
|
||||
|
||||
Unspec = _Unspec()
|
||||
|
||||
|
||||
|
@ -97,8 +107,8 @@ class TypeClass(_LiterateEnum):
|
|||
"""Top level types in the npcomp language."""
|
||||
Any = 0
|
||||
NdArray = 1
|
||||
|
||||
|
||||
|
||||
|
||||
class ValueType:
|
||||
"""The type a value can take in the npcomp language.
|
||||
|
||||
|
@ -123,32 +133,32 @@ class ValueType:
|
|||
NdArray
|
||||
"""
|
||||
__slots__ = ["_constraints", "_type_class"]
|
||||
|
||||
|
||||
def __init__(self, type_class=TypeClass.Any, *constraints):
|
||||
super().__init__()
|
||||
self._type_class = TypeClass.parse(type_class)
|
||||
self._constraints = TypeConstraints(constraints)
|
||||
|
||||
def __iadd__(self, constraint):
|
||||
assert isinstance(constraint, TypeConstraint), (
|
||||
"Can only add constraints to a ValueType")
|
||||
assert isinstance(
|
||||
constraint, TypeConstraint), ("Can only add constraints to a ValueType")
|
||||
self._constraints.append(constraint)
|
||||
return self
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
if not self._constraints:
|
||||
return repr(self._type_class)
|
||||
return "%r[%s]" % (self._type_class,
|
||||
", ".join([repr(c) for c in self._constraints]))
|
||||
|
||||
return "%r[%s]" % (self._type_class, ", ".join(
|
||||
[repr(c) for c in self._constraints]))
|
||||
|
||||
@property
|
||||
def type_class(self):
|
||||
return self._type_class
|
||||
|
||||
|
||||
@type_class.setter
|
||||
def type_class(self, type_class):
|
||||
self._type_class = TypeClass.parse(type_class)
|
||||
|
||||
|
||||
@property
|
||||
def constraints(self):
|
||||
return self._constraints
|
||||
|
@ -179,6 +189,7 @@ class ValueTypeList:
|
|||
(Any, Any, Any)
|
||||
"""
|
||||
__slots__ = ["_list", "_names"]
|
||||
|
||||
def __init__(self, arity=0, names=None):
|
||||
self._list = [ValueType() for _ in range(arity)]
|
||||
self._names = names
|
||||
|
@ -188,18 +199,19 @@ class ValueTypeList:
|
|||
# Scan for the index.
|
||||
if self._names:
|
||||
for i, n in enumerate(self._names):
|
||||
if n == key: return i
|
||||
if n == key:
|
||||
return i
|
||||
raise KeyError("Unknown key '%s'" % key)
|
||||
return key
|
||||
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._list[self._key_to_index(key)]
|
||||
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if not isinstance(value, ValueType):
|
||||
value = ValueType(value)
|
||||
self._list[self._key_to_index(key)] = value
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
return self._list.__iter__()
|
||||
|
||||
|
@ -233,20 +245,21 @@ class Signature:
|
|||
(a: Any, b: NdArray[Rank(2)]) -> NdArray[Rank(3)]
|
||||
"""
|
||||
__slots__ = ["_args", "_arg_names", "_result"]
|
||||
|
||||
def __init__(self, arity=0):
|
||||
super().__init__()
|
||||
self._result = ValueType()
|
||||
self._arg_names = [None] * arity
|
||||
self._args = ValueTypeList(arity, names=self._arg_names)
|
||||
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
return self._args
|
||||
|
||||
|
||||
@property
|
||||
def arg_names(self):
|
||||
return self._arg_names
|
||||
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
return self._result
|
||||
|
@ -256,14 +269,14 @@ class Signature:
|
|||
if not isinstance(value, ValueType):
|
||||
value = ValueType(value)
|
||||
self._result = value
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
args_repr = "(%s)" % (
|
||||
", ".join(
|
||||
args_repr = "(%s)" % (", ".join(
|
||||
((n + ": " + repr(t)) if n else repr(t))
|
||||
for t, n in zip(self._args, self._arg_names)),)
|
||||
return "%s -> %r" % (args_repr, self._result)
|
||||
|
||||
|
||||
class ArrayParams:
|
||||
"""Represents parameters defining how to construct an array.
|
||||
|
||||
|
@ -277,7 +290,7 @@ class ArrayParams:
|
|||
ArrayParams(dtype=float32, shape=(1, 2, 3))
|
||||
"""
|
||||
__slots__ = ["dtype", "shape"]
|
||||
|
||||
|
||||
def __init__(self, dtype=Unspec, shape=Unspec, rank=Unspec):
|
||||
self.dtype = dtype
|
||||
if shape is not Unspec:
|
||||
|
@ -289,9 +302,10 @@ class ArrayParams:
|
|||
|
||||
@property
|
||||
def rank(self):
|
||||
if self.shape is Unspec: return Unspec
|
||||
if self.shape is Unspec:
|
||||
return Unspec
|
||||
return len(self.shape)
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_constraints(cls, constraints):
|
||||
"""Constructs params for a TypeConstraints list.
|
||||
|
@ -341,42 +355,40 @@ class ArrayParams:
|
|||
shape_c = constraints.one_of(Shape)
|
||||
rank_c = constraints.one_of(Rank)
|
||||
dim_flags = constraints.all_of(DimFlag)
|
||||
|
||||
|
||||
dtype = dtype_c.dtype if dtype_c else Unspec
|
||||
shape = Unspec
|
||||
|
||||
|
||||
# Compute shape
|
||||
if shape_c:
|
||||
# TODO: Should be in canonicalizer
|
||||
if rank_c and rank_c.rank != len(shape_c.dims):
|
||||
raise ValueError("Conflicting shape and rank: %r vs %r" % (
|
||||
rank_c, shape_c))
|
||||
raise ValueError("Conflicting shape and rank: %r vs %r" %
|
||||
(rank_c, shape_c))
|
||||
shape = list(shape_c.dims)
|
||||
elif rank_c:
|
||||
shape = [-1 for _ in range(rank_c.rank)]
|
||||
|
||||
|
||||
# Apply dim flags
|
||||
if shape is not Unspec and dim_flags:
|
||||
for df in dim_flags:
|
||||
flag, for_dims = df.dim_flag
|
||||
for d in for_dims:
|
||||
if d < 0 or d >= len(shape):
|
||||
raise ValueError("Out of range %r for shape %r" % (
|
||||
df, shape))
|
||||
raise ValueError("Out of range %r for shape %r" % (df, shape))
|
||||
if flag == DimFlagEnum.Dynamic:
|
||||
shape[d] = -1
|
||||
|
||||
|
||||
return cls(dtype=dtype, shape=shape)
|
||||
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
try:
|
||||
s = "ArrayParams(dtype=%s" % (
|
||||
self.dtype.__name__ if isinstance(self.dtype, type) else self.dtype,)
|
||||
s = "ArrayParams(dtype=%s" % (self.dtype.__name__ if isinstance(
|
||||
self.dtype, type) else self.dtype,)
|
||||
if self.shape is not Unspec:
|
||||
s += ", shape=%r" % (tuple(self.shape),)
|
||||
s += ")"
|
||||
return s
|
||||
return s
|
||||
except:
|
||||
return "ArrayParams(ERROR)"
|
||||
|
||||
|
@ -400,7 +412,7 @@ class ArrayParams:
|
|||
if any(d < 0 for d in self.shape):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@property
|
||||
def mlir_tensor_type_asm(self):
|
||||
"""Get a corresponding MLIR tensor type.
|
||||
|
@ -436,15 +448,16 @@ class ArrayParams:
|
|||
else:
|
||||
dtype_asm = _dtype_to_mlir_asm(self.dtype)
|
||||
if not dtype_asm:
|
||||
raise ValueError(
|
||||
"Unsupported MLIR tensor element type %r" % (self.dtype,))
|
||||
raise ValueError("Unsupported MLIR tensor element type %r" %
|
||||
(self.dtype,))
|
||||
if self.shape is Unspec:
|
||||
shape_asm = "*"
|
||||
else:
|
||||
shape_asm = "x".join((str(d) if d >= 0 else "?") for d in self.shape)
|
||||
if shape_asm: shape_asm += "x"
|
||||
if shape_asm:
|
||||
shape_asm += "x"
|
||||
return "tensor<%s%s>" % (shape_asm, dtype_asm)
|
||||
|
||||
|
||||
def new_ndarray(self):
|
||||
"""Creates a new ndarray from these params.
|
||||
|
||||
|
@ -458,7 +471,7 @@ class ArrayParams:
|
|||
if not self.is_concrete:
|
||||
raise ValueError("%r is not concrete" % (self,))
|
||||
return np.ndarray(dtype=self.dtype, shape=self.shape)
|
||||
|
||||
|
||||
|
||||
class TypeConstraint:
|
||||
"""Base class for type constraints."""
|
||||
|
@ -481,57 +494,59 @@ class TypeConstraints(list):
|
|||
...
|
||||
AssertionError
|
||||
"""
|
||||
|
||||
def __init__(self, *constraints):
|
||||
if len(constraints) == 1 and not isinstance(
|
||||
constraints[0], ArrayConstraint):
|
||||
if len(constraints) == 1 and not isinstance(constraints[0],
|
||||
ArrayConstraint):
|
||||
constraints = constraints[0]
|
||||
super().__init__(constraints)
|
||||
assert(all(isinstance(c, ArrayConstraint) for c in self))
|
||||
|
||||
assert (all(isinstance(c, ArrayConstraint) for c in self))
|
||||
|
||||
def __repr__(self):
|
||||
return "TypeConstraints(%s)" % (
|
||||
", ".join([repr(c) for c in self]))
|
||||
return "TypeConstraints(%s)" % (", ".join([repr(c) for c in self]))
|
||||
|
||||
def all_of(self, clazz):
|
||||
"""Finds all of the given class."""
|
||||
return [c for c in self if isinstance(c, clazz)]
|
||||
|
||||
|
||||
def one_of(self, clazz):
|
||||
"""Finds at most one constraint of the given class."""
|
||||
found = [c for c in self if isinstance(c, clazz)]
|
||||
if not found: return None
|
||||
if not found:
|
||||
return None
|
||||
if len(found) > 1:
|
||||
raise ValueError("Conflicting constraints. Expected one of %r. Got %r" % (
|
||||
clazz, found))
|
||||
raise ValueError("Conflicting constraints. Expected one of %r. Got %r" %
|
||||
(clazz, found))
|
||||
return found[0]
|
||||
|
||||
|
||||
class ArrayConstraint(TypeConstraint):
|
||||
"""Base class for a constraint on an array's characteristics."""
|
||||
|
||||
def implies_dtype(self):
|
||||
return False
|
||||
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def implies_rank(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def implies_dims(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def dims(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def implies_dim_flag(self):
|
||||
return False
|
||||
|
||||
|
||||
@property
|
||||
def dim_flag(self):
|
||||
raise NotImplementedError()
|
||||
|
@ -550,22 +565,22 @@ class DType(ArrayConstraint):
|
|||
AssertionError
|
||||
"""
|
||||
__slots__ = ["_dtype"]
|
||||
|
||||
|
||||
def __init__(self, dtype):
|
||||
super().__init__()
|
||||
assert isinstance(dtype, type)
|
||||
self._dtype = dtype
|
||||
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
|
||||
def implies_dtype(self):
|
||||
return True
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return "DType(%s)" % (self._dtype.__name__,)
|
||||
|
||||
|
||||
|
||||
class Rank(ArrayConstraint):
|
||||
"""Establishes a fixed rank for the array.
|
||||
|
@ -585,21 +600,21 @@ class Rank(ArrayConstraint):
|
|||
|
||||
"""
|
||||
__slots__ = ["_rank"]
|
||||
|
||||
|
||||
def __init__(self, rank):
|
||||
super().__init__()
|
||||
assert(isinstance(rank, int) and rank >= 0)
|
||||
assert (isinstance(rank, int) and rank >= 0)
|
||||
self._rank = rank
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return self._rank
|
||||
|
||||
|
||||
def implies_rank(self):
|
||||
return True
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return "Rank(%d)" % (self._rank)
|
||||
return "Rank(%d)" % (self._rank)
|
||||
|
||||
|
||||
class Shape(ArrayConstraint):
|
||||
|
@ -619,29 +634,29 @@ class Shape(ArrayConstraint):
|
|||
AssertionError
|
||||
"""
|
||||
__slots__ = ["_dims"]
|
||||
|
||||
|
||||
def __init__(self, *dims):
|
||||
super().__init__()
|
||||
assert(all(d is Unspec or (isinstance(d, int) and d >= 0) for d in dims))
|
||||
assert (all(d is Unspec or (isinstance(d, int) and d >= 0) for d in dims))
|
||||
self._dims = tuple(dims)
|
||||
|
||||
@property
|
||||
def dims(self):
|
||||
return self._dims
|
||||
|
||||
|
||||
def implies_dims(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return len(self._dims)
|
||||
|
||||
|
||||
def implies_rank(self):
|
||||
return True
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return "Shape(%s)" % (", ".join(str(d) for d in self._dims))
|
||||
|
||||
|
||||
|
||||
class DimFlagEnum(_LiterateEnum):
|
||||
"""Flag for the kind of DimFlag constraint."""
|
||||
|
@ -661,35 +676,35 @@ class DimFlag(ArrayConstraint):
|
|||
DimFlag(Dynamic, (0, 1))
|
||||
"""
|
||||
__slots__ = ["_flag", "_dims"]
|
||||
|
||||
|
||||
def __init__(self, flag, dims=Unspec):
|
||||
super().__init__()
|
||||
self._flag = DimFlagEnum.parse(flag)
|
||||
if isinstance(dims, int):
|
||||
assert(dims >= 0)
|
||||
assert (dims >= 0)
|
||||
self._dims = (dims,)
|
||||
elif dims is Unspec:
|
||||
self._dims = Unspec
|
||||
else:
|
||||
self._dims = tuple(dims)
|
||||
assert(all(isinstance(d, int) and d >= 0 for d in self._dims))
|
||||
assert (all(isinstance(d, int) and d >= 0 for d in self._dims))
|
||||
|
||||
def implies_dim_flag(self):
|
||||
return False
|
||||
|
||||
|
||||
@property
|
||||
def dim_flag(self):
|
||||
return self._flag, self._dims
|
||||
|
||||
def __repr__(self):
|
||||
return "DimFlag(%r, %r)" % (self._flag, self._dims)
|
||||
|
||||
|
||||
|
||||
|
||||
def DynamicDim(dims=Unspec):
|
||||
"""Dim flag that signals a dimension should be considered dynamic."""
|
||||
return DimFlag(DimFlagEnum.Dynamic, dims)
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
doctest.testmod()
|
||||
|
|
|
@ -13,14 +13,17 @@ _filecheck_binary_var = "FILECHECK_BINARY"
|
|||
_redirect_io = None
|
||||
_redirect_context = None
|
||||
|
||||
|
||||
def is_filecheck_disabled():
|
||||
return _disable_var in os.environ
|
||||
|
||||
|
||||
def start_filecheck_test():
|
||||
if is_filecheck_disabled():
|
||||
print("WARNING:FileCheck disabled due to", _disable_var,
|
||||
"in the environment", file=sys.stderr)
|
||||
print("WARNING:FileCheck disabled due to",
|
||||
_disable_var,
|
||||
"in the environment",
|
||||
file=sys.stderr)
|
||||
return
|
||||
global _redirect_io
|
||||
global _redirect_context
|
||||
|
@ -30,7 +33,8 @@ def start_filecheck_test():
|
|||
|
||||
|
||||
def end_filecheck_test(main_file):
|
||||
if is_filecheck_disabled(): return
|
||||
if is_filecheck_disabled():
|
||||
return
|
||||
global _redirect_io
|
||||
global _redirect_context
|
||||
_redirect_context.__exit__(None, None, None)
|
||||
|
@ -44,7 +48,7 @@ def end_filecheck_test(main_file):
|
|||
filecheck_args = [filecheck_binary, main_file, "--dump-input=fail"]
|
||||
p = subprocess.Popen(filecheck_args, stdin=subprocess.PIPE)
|
||||
p.communicate(filecheck_input.encode("UTF-8"))
|
||||
sys.exit(p.returncode)
|
||||
sys.exit(p.returncode)
|
||||
|
||||
|
||||
def run_under_filecheck(main_file, callback, disable_filecheck=False):
|
||||
|
@ -60,8 +64,10 @@ def run_under_filecheck(main_file, callback, disable_filecheck=False):
|
|||
disable_filecheck: Whether to disable filecheck.
|
||||
"""
|
||||
if disable_filecheck or is_filecheck_disabled():
|
||||
print("WARNING:FileCheck disabled due to", _disable_var,
|
||||
"in the environment", file=sys.stderr)
|
||||
print("WARNING:FileCheck disabled due to",
|
||||
_disable_var,
|
||||
"in the environment",
|
||||
file=sys.stderr)
|
||||
callback()
|
||||
sys.exit(0)
|
||||
|
||||
|
|
|
@ -4,26 +4,26 @@ import os
|
|||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
TEST_MODULES = (
|
||||
"npcomp.mlir_ir_test",
|
||||
"npcomp.mlir_pass_test",
|
||||
"npcomp.dialect.Basicpy",
|
||||
"npcomp.dialect.Numpy",
|
||||
"npcomp.tracing.context",
|
||||
"npcomp.tracing.mlir_trace",
|
||||
"npcomp.types",
|
||||
"npcomp.exporter",
|
||||
"npcomp.tracing.mlir_trace_test",
|
||||
"npcomp.mlir_ir_test",
|
||||
"npcomp.mlir_pass_test",
|
||||
"npcomp.dialect.Basicpy",
|
||||
"npcomp.dialect.Numpy",
|
||||
"npcomp.tracing.context",
|
||||
"npcomp.tracing.mlir_trace",
|
||||
"npcomp.types",
|
||||
"npcomp.exporter",
|
||||
"npcomp.tracing.mlir_trace_test",
|
||||
)
|
||||
|
||||
# Compute PYTHONPATH for sub processes.
|
||||
DIRSEP = os.path.pathsep
|
||||
LOCAL_PYTHONPATH_COMPONENTS = [
|
||||
# This directory.
|
||||
os.path.abspath(os.path.dirname(__file__)),
|
||||
# The parallel python_native directory (assuming in the build tree).
|
||||
os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "python_native"))
|
||||
# This directory.
|
||||
os.path.abspath(os.path.dirname(__file__)),
|
||||
# The parallel python_native directory (assuming in the build tree).
|
||||
os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "python_native"))
|
||||
]
|
||||
PYTHONPATH = DIRSEP.join(LOCAL_PYTHONPATH_COMPONENTS)
|
||||
if "PYTHONPATH" in os.environ:
|
||||
|
@ -33,9 +33,8 @@ CHILD_ENVIRON["PYTHONPATH"] = PYTHONPATH
|
|||
|
||||
# Configure filecheck.
|
||||
FILECHECK_BINARY = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"..", "..", "..", "bin", "FileCheck"))
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "..", "bin",
|
||||
"FileCheck"))
|
||||
if os.path.exists(FILECHECK_BINARY):
|
||||
CHILD_ENVIRON["FILECHECK_BINARY"] = FILECHECK_BINARY
|
||||
else:
|
||||
|
@ -47,14 +46,14 @@ failed = []
|
|||
for test_module in TEST_MODULES:
|
||||
print("--------====== RUNNING %s ======--------" % test_module)
|
||||
try:
|
||||
subprocess.check_call([sys.executable, "-Wignore", "-m", test_module],
|
||||
subprocess.check_call([sys.executable, "-Wignore", "-m", test_module],
|
||||
env=CHILD_ENVIRON)
|
||||
print("--------====== DONE %s ======--------\n" % test_module)
|
||||
passed.append(test_module)
|
||||
except subprocess.CalledProcessError:
|
||||
print("!!!!!!!!====== ERROR %s ======!!!!!!!!\n" % test_module)
|
||||
failed.append(test_module)
|
||||
|
||||
|
||||
print("Done: %d passed, %d failed" % (len(passed), len(failed)))
|
||||
if failed:
|
||||
for test_module in failed:
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
from npcomp.compiler.frontend import *
|
||||
|
||||
|
||||
def binary_expression():
|
||||
a = 1
|
||||
b = 100
|
||||
|
@ -11,6 +12,7 @@ def binary_expression():
|
|||
c = c * 2.0
|
||||
return c
|
||||
|
||||
|
||||
fe = ImportFrontend()
|
||||
try:
|
||||
f = fe.import_global_function(binary_expression)
|
||||
|
|
|
@ -9,9 +9,11 @@ from npcomp.types import *
|
|||
weights = np.random.uniform(size=(16, 4)).astype(np.float32)
|
||||
bias = np.random.uniform(size=(4,)).astype(np.float32)
|
||||
|
||||
|
||||
def constants(a: np.ndarray) -> np.ndarray:
|
||||
return np.dot(a, weights) + bias
|
||||
|
||||
|
||||
# TODO: Implement subclassing and deriving constraints by run
|
||||
exp = npc.Exporter()
|
||||
exp.constants = constants
|
||||
|
|
|
@ -6,20 +6,22 @@ import numpy as np
|
|||
import npcomp as npc
|
||||
from npcomp.types import *
|
||||
|
||||
|
||||
def dot2d(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
||||
return np.dot(a, b)
|
||||
|
||||
|
||||
# TODO: Implement subclassing and deriving constraints by run
|
||||
exp = npc.Exporter()
|
||||
exp.dot2d = dot2d
|
||||
exp.dot2d.sig.args["a"] += Shape(4, 16)
|
||||
exp.dot2d.sig.args["a"] += DynamicDim(0)
|
||||
exp.dot2d.sig.args["a"] += DType(np.float32)
|
||||
exp.dot2d.sig.args["b"] += Shape(16,32)
|
||||
exp.dot2d.sig.args["b"] += Shape(16, 32)
|
||||
exp.dot2d.sig.args["b"] += DType(np.float32)
|
||||
exp.dot2d.sig.result += Shape(4, 32)
|
||||
exp.dot2d.sig.result += DynamicDim(0)
|
||||
exp.dot2d.sig.result += DType(np.float32)
|
||||
exp.dot2d.sig.result += DType(np.float32)
|
||||
|
||||
mb = npc.tracing.ModuleBuilder()
|
||||
mb.trace(exp.dot2d)
|
||||
|
|
|
@ -6,9 +6,11 @@ import numpy as np
|
|||
import npcomp as npc
|
||||
from npcomp.types import *
|
||||
|
||||
|
||||
def simple_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
||||
return a * b + a + b
|
||||
|
||||
|
||||
# TODO: Implement subclassing and deriving constraints by run
|
||||
exp = npc.Exporter()
|
||||
exp.simple_mul = simple_mul
|
||||
|
@ -19,7 +21,7 @@ exp.simple_mul.sig.args["b"] += Shape(1)
|
|||
exp.simple_mul.sig.args["b"] += DType(np.float32)
|
||||
exp.simple_mul.sig.result += Shape(1, 4)
|
||||
exp.simple_mul.sig.result += DynamicDim(0)
|
||||
exp.simple_mul.sig.result += DType(np.float32)
|
||||
exp.simple_mul.sig.result += DType(np.float32)
|
||||
|
||||
mb = npc.tracing.ModuleBuilder()
|
||||
mb.trace(exp.simple_mul)
|
||||
|
|
|
@ -6,9 +6,11 @@ import numpy as np
|
|||
import npcomp as npc
|
||||
from npcomp.types import *
|
||||
|
||||
|
||||
def slice_array1(a: np.ndarray) -> np.ndarray:
|
||||
return a[1, 2:10:2, 3:4, ..., :, 0]
|
||||
|
||||
|
||||
# TODO: Implement subclassing and deriving constraints by run
|
||||
exp = npc.Exporter()
|
||||
exp.slice_array1 = slice_array1
|
||||
|
|
|
@ -6,12 +6,15 @@ import numpy as np
|
|||
import npcomp as npc
|
||||
from npcomp.types import *
|
||||
|
||||
|
||||
def transpose_attribute(a: np.ndarray) -> np.ndarray:
|
||||
return a.T
|
||||
|
||||
|
||||
def transpose(a: np.ndarray) -> np.ndarray:
|
||||
return np.transpose(a)
|
||||
|
||||
|
||||
# TODO: Implement subclassing and deriving constraints by run
|
||||
exp = npc.Exporter()
|
||||
exp.transpose_attribute = transpose_attribute
|
||||
|
|
|
@ -187,11 +187,10 @@ private:
|
|||
class PyDialectHelper {
|
||||
public:
|
||||
PyDialectHelper(PyContext &context, PyOpBuilder &builder)
|
||||
: context(context), pyOpBuilder(builder) {}
|
||||
: context(context), pyOpBuilder(builder) {}
|
||||
static void bind(py::module m);
|
||||
MLIRContext *getContext() {
|
||||
return pyOpBuilder.getContext();
|
||||
}
|
||||
MLIRContext *getContext() { return pyOpBuilder.getContext(); }
|
||||
|
||||
protected:
|
||||
PyContext &context;
|
||||
PyOpBuilder &pyOpBuilder;
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
#!/bin/bash
|
||||
# Formats all source files.
|
||||
|
||||
set +e
|
||||
td="$(dirname $0)/.."
|
||||
|
||||
function find_cc_sources() {
|
||||
local dir="$1"
|
||||
find "$dir" -name "*.h"
|
||||
find "$dir" -name "*.cpp"
|
||||
}
|
||||
# C/C++ sources.
|
||||
set -o xtrace
|
||||
clang-format -i \
|
||||
$(find_cc_sources include) \
|
||||
$(find_cc_sources lib) \
|
||||
$(find_cc_sources python_native)
|
||||
|
||||
# Python sources.
|
||||
yapf --recursive -i "$td/python" "$td/pytest"
|
Loading…
Reference in New Issue