mirror of https://github.com/llvm/torch-mlir
Totally rework RefE2E tensor to memref flow. (#42)
This now gets the overall "RefE2E" compilation stack to a point that I'm fairly happy with. We simplify it by mostly embracing the "descriptor" view of the world. The overall flow is best understood by reading through the createE2ELoweringPipeline function in lib/E2E/E2E.cpp That function creates a pass pipeline that lowers from "TCF" (which is ~numpy level of abstraction) down to LLVM IR. A brief high-level summary of what happens there: 1. TCF to TCP conversion. This involves reifying error handling in the form of shape constraints. See test/Conversion/TCFToTCP/basic.mlir 2. Lowering shape constraints. This converts shape constraints into eager error-handling code. See test/E2E/lower-shape-constraints.mlir This pass will soon go upstream. Because this lowers to std.assert, some later passes like LowerToNpcomprtABI and LowerToLLVM are updated to properly plumb this through e2e. See test/npcomp-run-mlir/invalid-broadcast.mlir for an execution test that properly aborts in case of an error. 3. Lowering tensors to memrefs. This is done via a series of passes rather than an single mega conversion. Unlike the previous code that mixed in the npcomprt ABI stuff here, it's now a very clean "pure memref" conversion. See test/E2E/lower-*-to-memref.mlir and lib/E2E/TensorToMemref/ Most of the changes are concentrated here. 4. As part of the above, we use the upstream ConvertShapeToStandard for lowering shapes. 5. We lower linalg to loops and lower loops to CFG using upstream passes. 6. Rewrite the "ABI" boundaries of the program to npcomprt data structures (LowerToNpcomprtABI). This mainly affects ABI boundaries and how global tensor constants are represented. One of the major improvements in this commit is that now it's a very clean rewrite that just replaces memrefs on ABI boundaries with !npcomprt.tensor (before there was a get_extent function that is not needed). See test/E2E/lower-to-npcomprt-abi.mlir 7. Lower to LLVM with upstream mlir patterns + some patterns for the npcomprt lowerings. One aspect here that is still a remnant of a non-descriptor-based tensor to memref flow is the BypassShapes + LowerShapedResultsToMemref. BypassShapes wraps the "tensor compute" ops in a tcp.shaped_results (basically a "tie_shape" kind of op), and then LowerShapedResultsToMemref uses those annotations to allocate output buffers while lowering the "tensor compute ops". Note that there are very few "tensor compute" ops currently supported (tcp.add + tcp.broadcast_to), so we just hardcode them in both passes. Realistically, I expect this to go away as we fully embrace the descriptor-based approach for simplicity, so don't look too deep into it.pull/47/head
parent
a74a98094b
commit
75f57b461e
|
@ -36,26 +36,14 @@ def Npcomprt_FromMemrefOp : Npcomprt_Op<"from_memref"> {
|
||||||
let assemblyFormat = "$memref attr-dict `:` type($memref)";
|
let assemblyFormat = "$memref attr-dict `:` type($memref)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def Npcomprt_GetExtentOp : Npcomprt_Op<"get_extent"> {
|
|
||||||
let summary = "Gets the specified extent of the tensor";
|
|
||||||
let description = [{
|
|
||||||
Gets the `dim`'th extent of the tensor.
|
|
||||||
}];
|
|
||||||
let arguments = (ins Npcomprt_Tensor:$tensor, I32:$dim);
|
|
||||||
// TODO: Use i32 instead of index so the runtime function
|
|
||||||
// can return std::int32_t.
|
|
||||||
let results = (outs Index:$extent);
|
|
||||||
let assemblyFormat = "$tensor `,` $dim attr-dict";
|
|
||||||
}
|
|
||||||
|
|
||||||
def Npcomprt_AbortIfOp : Npcomprt_Op<"abort_if"> {
|
def Npcomprt_AbortIfOp : Npcomprt_Op<"abort_if"> {
|
||||||
let summary = "Aborts if the predicate is true";
|
let summary = "Aborts if the predicate is true";
|
||||||
let description = [{
|
let description = [{
|
||||||
Aborts if the predicate is true.
|
Aborts if the predicate is true.
|
||||||
}];
|
}];
|
||||||
let arguments = (ins I1:$pred);
|
let arguments = (ins I1:$pred, StrAttr:$msg);
|
||||||
let results = (outs);
|
let results = (outs);
|
||||||
let assemblyFormat = "$pred attr-dict";
|
let assemblyFormat = "$pred `,` $msg attr-dict";
|
||||||
}
|
}
|
||||||
|
|
||||||
def Npcomprt_GlobalOp : Npcomprt_Op<"global", [Symbol]> {
|
def Npcomprt_GlobalOp : Npcomprt_Op<"global", [Symbol]> {
|
||||||
|
|
|
@ -13,6 +13,7 @@ include "npcomp/Dialect/TCP/IR/TCPBase.td"
|
||||||
include "mlir/Dialect/Shape/IR/ShapeBase.td"
|
include "mlir/Dialect/Shape/IR/ShapeBase.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||||
|
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||||
include "mlir/IR/SymbolInterfaces.td"
|
include "mlir/IR/SymbolInterfaces.td"
|
||||||
|
|
||||||
class TCP_Op<string mnemonic, list<OpTrait> traits = []>
|
class TCP_Op<string mnemonic, list<OpTrait> traits = []>
|
||||||
|
@ -43,22 +44,6 @@ It is undefined behavior if such a broadcast is not legal.
|
||||||
let results = (outs AnyRankedTensor:$result);
|
let results = (outs AnyRankedTensor:$result);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Ops that need to be factored to a proper home.
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// TODO: Find a home for these.
|
|
||||||
|
|
||||||
// TODO: This probably doesn't belong in the tcp dialect.
|
|
||||||
def TCP_AllocMemRefOp : TCP_Op<"alloc_memref", []> {
|
|
||||||
let summary = "Allocates a memref of the given shape.";
|
|
||||||
let description = [{
|
|
||||||
Allocates a memref of the given shape.
|
|
||||||
}];
|
|
||||||
let arguments = (ins Shape_ExtentTensorType:$shape);
|
|
||||||
let results = (outs AnyMemRef:$memref);
|
|
||||||
let assemblyFormat = "$shape attr-dict `:` type($memref)";
|
|
||||||
}
|
|
||||||
|
|
||||||
def TCP_GlobalOp : TCP_Op<"global", [Symbol]> {
|
def TCP_GlobalOp : TCP_Op<"global", [Symbol]> {
|
||||||
let summary = "Represents a global variable";
|
let summary = "Represents a global variable";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -74,6 +59,50 @@ def TCP_GlobalOp : TCP_Op<"global", [Symbol]> {
|
||||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Ops related to tensor->memref conversion.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TODO: These ops probably belong in a "TCP on memrefs" dialect analogous
|
||||||
|
// to `lmhlo`
|
||||||
|
|
||||||
|
// TODO: Use TypesMatchWith to verify this better.
|
||||||
|
def TCP_TensorToMemrefOp : TCP_Op<"tensor_to_memref", [NoSideEffect]> {
|
||||||
|
let summary = "Converts a tensor to a memref";
|
||||||
|
let description = [{
|
||||||
|
This op is used to materialize conversions to allow incremental lowering of
|
||||||
|
tensors to memrefs.
|
||||||
|
}];
|
||||||
|
let arguments = (ins AnyRankedTensor:$tensor);
|
||||||
|
let results = (outs AnyMemRef:$memref);
|
||||||
|
let assemblyFormat = "attr-dict $tensor `:` type($tensor) `->` type($memref)";
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Use TypesMatchWith to verify this better.
|
||||||
|
def TCP_MemrefToTensorOp : TCP_Op<"memref_to_tensor", [NoSideEffect]> {
|
||||||
|
let summary = "Converts a memref to a tensor";
|
||||||
|
let description = [{
|
||||||
|
This op is used to materialize conversions to allow incremental lowering of
|
||||||
|
tensors to memrefs.
|
||||||
|
}];
|
||||||
|
let arguments = (ins AnyMemRef:$memref);
|
||||||
|
let results = (outs AnyRankedTensor:$tensor);
|
||||||
|
let assemblyFormat = "attr-dict $memref `:` type($memref) `->` type($tensor)";
|
||||||
|
}
|
||||||
|
|
||||||
|
def TCP_AllocMemRefOp : TCP_Op<"alloc_memref", []> {
|
||||||
|
let summary = "Allocates a memref of the given shape.";
|
||||||
|
let description = [{
|
||||||
|
Allocates a memref of the given shape.
|
||||||
|
|
||||||
|
This op is a convenience for creating a bunch of
|
||||||
|
shape.get_extent + std.alloc ops.
|
||||||
|
}];
|
||||||
|
let arguments = (ins Shape_ExtentTensorType:$shape);
|
||||||
|
let results = (outs AnyMemRef:$memref);
|
||||||
|
let assemblyFormat = "$shape attr-dict `:` type($memref)";
|
||||||
|
}
|
||||||
|
|
||||||
def TCP_GetGlobalMemrefOp : TCP_Op<"get_global_memref"> {
|
def TCP_GetGlobalMemrefOp : TCP_Op<"get_global_memref"> {
|
||||||
let summary = "Obtain a memref pointing at the given global";
|
let summary = "Obtain a memref pointing at the given global";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -85,27 +114,64 @@ def TCP_GetGlobalMemrefOp : TCP_Op<"get_global_memref"> {
|
||||||
let verifier = "return ::verify$cppClass(*this);";
|
let verifier = "return ::verify$cppClass(*this);";
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Change to a more principled error handling mechanism.
|
//===----------------------------------------------------------------------===//
|
||||||
// This op probably doesn't need to exist eventually.
|
// Ops related to shapes.
|
||||||
// This op is also not correctly modeled right now, since it itself doesn't
|
//===----------------------------------------------------------------------===//
|
||||||
// produce the error in practice. The ops like shape.broadcast itself, when
|
// TODO: These belong in a shape-related dialect.
|
||||||
// lowered, immediately produce errors.
|
|
||||||
// TODO: This should eventually be moved to a shape dialect.
|
|
||||||
def TCP_ShapeObserveErrorOp : TCP_Op<"shape_observe_error", []> {
|
|
||||||
let summary = "Observes the fact that a shape might be an error.";
|
|
||||||
let description = [{
|
|
||||||
This op is a structural placeholder that captures a shape such that it
|
|
||||||
is not erased. This will keep around shape computations that are later
|
|
||||||
lowered into eager error handling code.
|
|
||||||
|
|
||||||
The interaction of this op, especially with control flow and side
|
def TCP_ShapedResultsOp : TCP_Op<"shaped_results", [
|
||||||
effecting ops, is not very well-defined, and needs to be worked
|
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
|
||||||
on/redesigned.
|
SingleBlockImplicitTerminator<"YieldOp">,
|
||||||
|
RecursiveSideEffects,
|
||||||
|
NoRegionArguments
|
||||||
|
]> {
|
||||||
|
let summary = "Result shape annotation";
|
||||||
|
let description = [{
|
||||||
|
Represents a computation whose outputs have a precomputed shape.
|
||||||
|
The i-th result has the shape described by the i-th operand.
|
||||||
|
|
||||||
|
This op is not isolated from above, so if the region needs any inputs,
|
||||||
|
they can simply be captured. Hence, this op is a
|
||||||
|
"this tensor has this shape" annotation with a slightly different set of
|
||||||
|
tradeoffs than the so-called "tie shape" kinds of operations.
|
||||||
|
In particular, this region-based formulation has the opportunity to
|
||||||
|
capture structural invariants.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```mlir
|
||||||
|
// sincos is an elementwise operation, so it doesn't change the shape.
|
||||||
|
%x = ...
|
||||||
|
%xShape = ...
|
||||||
|
%sin, %cos = tcp.shaped_results %xShape, %xShape {
|
||||||
|
%sin, cos = "some.sincos"(%x)
|
||||||
|
: tensor<?xf32> -> (tensor<?xf32>, tensor<?xf32>)
|
||||||
|
tcp.yield %sin, %cos : tensor<?xf32>, tensor<?xf32>
|
||||||
|
}
|
||||||
|
```
|
||||||
}];
|
}];
|
||||||
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
|
let arguments = (ins
|
||||||
// TODO: ODS seems to create redeclared class members if we remove this,
|
Variadic<Shape_ExtentTensorType>:$resultShapes
|
||||||
// resulting in C++ compilation errors.
|
);
|
||||||
let results = (outs NoneType:$dummy);
|
let results = (outs Variadic<AnyTensor>:$results);
|
||||||
|
let regions = (region SizedRegion<1>:$body);
|
||||||
|
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<
|
||||||
|
"OpBuilder &builder, OperationState &result, TypeRange resultTypes, "
|
||||||
|
"ValueRange resultShapes">
|
||||||
|
];
|
||||||
|
|
||||||
|
let printer = [{ return ::print$cppClass(p, *this); }];
|
||||||
|
let verifier = [{ return ::verify$cppClass(*this); }];
|
||||||
|
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||||
|
}
|
||||||
|
|
||||||
|
def TCP_YieldOp : TCP_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
|
||||||
|
ParentOneOf<["ShapedResultsOp"]>]> {
|
||||||
|
let summary = "Yield-like terminator for TCP dialect";
|
||||||
|
let description = "See scf.yield";
|
||||||
|
let arguments = (ins Variadic<AnyType>:$operands);
|
||||||
|
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // TCP_OPS
|
#endif // TCP_OPS
|
||||||
|
|
|
@ -23,21 +23,18 @@ void registerE2EPasses();
|
||||||
//
|
//
|
||||||
// Pass summaries are in Passes.td.
|
// Pass summaries are in Passes.td.
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLowerBroadcastToToLoopsPass();
|
std::unique_ptr<OperationPass<FuncOp>> createBypassShapesPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>>
|
std::unique_ptr<OperationPass<FuncOp>> createLowerShapeConstraintsPass();
|
||||||
createLowerLinalgOnTensorToLinalgOnMemrefPass();
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLowerShapedResultsToMemrefPass();
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLowerStdToMemrefPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
createLowerConstantTensorsToMemrefsPass();
|
createLowerConstantTensorsToMemrefPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createResolveShapeOfOpsPass();
|
std::unique_ptr<OperationPass<FuncOp>> createLowerStructuralToMemrefPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createResolveTensorLoadStoreOpsPass();
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLowerLinalgLoopDimOpsPass();
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLowerRankedShapesPass();
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createLowerToNpcomprtABIPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createLowerToNpcomprtABIPass();
|
||||||
|
|
||||||
|
@ -45,8 +42,6 @@ std::unique_ptr<OperationPass<FuncOp>> createLowerAllocMemRefOpsPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createLowerToLLVMPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createLowerToLLVMPass();
|
||||||
|
|
||||||
void createLowerToHybridTensorMemRefPipeline(OpPassManager &pm);
|
|
||||||
|
|
||||||
struct E2ELoweringPipelineOptions
|
struct E2ELoweringPipelineOptions
|
||||||
: public PassPipelineOptions<E2ELoweringPipelineOptions> {
|
: public PassPipelineOptions<E2ELoweringPipelineOptions> {
|
||||||
// If this option is true, then perform optimizations.
|
// If this option is true, then perform optimizations.
|
||||||
|
|
|
@ -11,49 +11,65 @@
|
||||||
|
|
||||||
include "mlir/Pass/PassBase.td"
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
def LowerLinalgOnTensorToLinalgOnMemref :
|
def BypassShapes : Pass<"bypass-shapes", "FuncOp"> {
|
||||||
Pass<"lower-linalg-tensor-to-memref", "FuncOp"> {
|
let summary = "Bypass shape calculations around ops";
|
||||||
let summary = "Lowers linalg on tensors to linalg on memrefs";
|
let constructor = "mlir::NPCOMP::createBypassShapesPass()";
|
||||||
let constructor = "mlir::NPCOMP::createLowerLinalgOnTensorToLinalgOnMemrefPass()";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def LowerBroadcastToToLoops :
|
def LowerShapeConstraints : Pass<"lower-shape-constraints", "FuncOp"> {
|
||||||
Pass<"lower-broadcast-to-to-loops", "FuncOp"> {
|
let summary = "Lower shape dialect constructs related to constraints";
|
||||||
let summary = "Lower tcp::BroadcastTo to loops.";
|
let constructor = "mlir::NPCOMP::createLowerShapeConstraintsPass()";
|
||||||
let constructor = "mlir::NPCOMP::createLowerBroadcastToToLoopsPass()";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def LowerConstantTensorsToMemrefs :
|
def LowerShapedResultsToMemref : Pass<"lower-shaped-results-to-memref", "FuncOp"> {
|
||||||
Pass<"lower-constant-tensors-to-memrefs", "ModuleOp"> {
|
let summary = "Lower tcp.shaped_results regions";
|
||||||
let summary = "Lower std.constant of tensor type to hybrid tensor/memref.";
|
let constructor = "mlir::NPCOMP::createLowerShapedResultsToMemrefPass()";
|
||||||
|
}
|
||||||
|
|
||||||
|
def LowerStdToMemref : Pass<"lower-std-to-memref", "FuncOp"> {
|
||||||
|
let summary = "Lower std ops to memref";
|
||||||
|
let constructor = "mlir::NPCOMP::createLowerStdToMemrefPass()";
|
||||||
|
}
|
||||||
|
|
||||||
|
def LowerConstantTensorsToMemref :
|
||||||
|
Pass<"lower-constant-tensors-to-memref", "ModuleOp"> {
|
||||||
|
let summary = "Lower std.constant of tensor type to memref";
|
||||||
let description = [{
|
let description = [{
|
||||||
This has to be a module pass since it involves creating tcp.global ops.
|
This must be a module pass since it involves creating tcp.global ops.
|
||||||
}];
|
}];
|
||||||
let constructor = "mlir::NPCOMP::createLowerConstantTensorsToMemrefsPass()";
|
let constructor = "mlir::NPCOMP::createLowerConstantTensorsToMemrefPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
def ResolveShapeOfOps : Pass<"resolve-shape-of-ops", "FuncOp"> {
|
def LowerStructuralToMemref :
|
||||||
let summary = "Resolve shape.shape_of ops to other shapes.";
|
Pass<"lower-structural-to-memref", "FuncOp"> {
|
||||||
let constructor = "mlir::NPCOMP::createResolveShapeOfOpsPass()";
|
let summary = "Lower structural IR constructs to memref";
|
||||||
}
|
let description = [{
|
||||||
|
Structural constructs include:
|
||||||
|
- control flow ops (both CFG and SCF)
|
||||||
|
- function signatures
|
||||||
|
- TODO: calls
|
||||||
|
An op is "structural" if it doesn't really care about the types it operates
|
||||||
|
on, but the types just have to converted to be consistent.
|
||||||
|
|
||||||
def ResolveTensorLoadStoreOps : Pass<"resolve-tensor-load-store-ops", "FuncOp"> {
|
This pass also cleans up any previous memref<->tensor materializations,
|
||||||
let summary = "Resolve tensor_load/tensor_store ops";
|
finalizing the conversion from tensor to memref.
|
||||||
let constructor = "mlir::NPCOMP::createResolveTensorLoadStoreOpsPass()";
|
}];
|
||||||
}
|
let constructor = "mlir::NPCOMP::createLowerStructuralToMemrefPass()";
|
||||||
|
|
||||||
def LowerLinalgLoopDimOps : Pass<"lower-linalg-loop-dim-ops", "FuncOp"> {
|
|
||||||
let summary = "Lower dim ops introduced by linalg to loops lowering";
|
|
||||||
let constructor = "mlir::NPCOMP::createLowerLinalgLoopDimOpsPass();";
|
|
||||||
}
|
|
||||||
|
|
||||||
def LowerRankedShapes : Pass<"lower-ranked-shapes", "FuncOp"> {
|
|
||||||
let summary = "Lower ranked !shape.shape types to SSA values";
|
|
||||||
let constructor = "mlir::NPCOMP::createLowerRankedShapesPass()";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def LowerToNpcomprtABI : Pass<"lower-to-npcomprt-abi", "ModuleOp"> {
|
def LowerToNpcomprtABI : Pass<"lower-to-npcomprt-abi", "ModuleOp"> {
|
||||||
let summary = "Lower tensors at ABI boundaries to npcomprt dialect";
|
let summary = "Lower constructs requiring runtime support to `npcomprt`";
|
||||||
|
let description = [{
|
||||||
|
We have a specialized dialect `npcomprt` which models our runtime's data
|
||||||
|
structures, and function signatures (and presumably eventually, other
|
||||||
|
ABI boundaries like external calls if we ever support it) will be
|
||||||
|
converted.
|
||||||
|
|
||||||
|
The constructs requiring runtime support are:
|
||||||
|
- function signatures / module metadata
|
||||||
|
- globals
|
||||||
|
- error handling
|
||||||
|
}];
|
||||||
let constructor = "mlir::NPCOMP::createLowerToNpcomprtABIPass()";
|
let constructor = "mlir::NPCOMP::createLowerToNpcomprtABIPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,13 +38,18 @@ public:
|
||||||
}
|
}
|
||||||
Value lhsShape = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.lhs());
|
Value lhsShape = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.lhs());
|
||||||
Value rhsShape = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.rhs());
|
Value rhsShape = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.rhs());
|
||||||
|
|
||||||
|
// Create the constraints, and the assuming region.
|
||||||
|
Value witness = rewriter.create<shape::CstrBroadcastableOp>(
|
||||||
|
op.getLoc(), lhsShape, rhsShape);
|
||||||
|
auto assuming = rewriter.create<shape::AssumingOp>(
|
||||||
|
op.getLoc(), ArrayRef<Type>{op.getType()}, witness);
|
||||||
|
|
||||||
|
// Start building the region body.
|
||||||
|
rewriter.createBlock(&assuming.doRegion());
|
||||||
Value broadcastedShape = rewriter.create<shape::BroadcastOp>(
|
Value broadcastedShape = rewriter.create<shape::BroadcastOp>(
|
||||||
op.getLoc(), rewriter.getType<mlir::shape::ShapeType>(), lhsShape,
|
op.getLoc(), getExtentTensorType(rewriter), lhsShape, rhsShape,
|
||||||
rhsShape,
|
|
||||||
/*error=*/nullptr);
|
/*error=*/nullptr);
|
||||||
rewriter.create<tcp::ShapeObserveErrorOp>(op.getLoc(), broadcastedShape);
|
|
||||||
Value broadcastedExtents = rewriter.create<shape::ToExtentTensorOp>(
|
|
||||||
op.getLoc(), getExtentTensorType(rewriter), broadcastedShape);
|
|
||||||
|
|
||||||
// TODO: It's annoying to do the dynamic broadcast above then
|
// TODO: It's annoying to do the dynamic broadcast above then
|
||||||
// do the static transfer function here. Would be nice if they could
|
// do the static transfer function here. Would be nice if they could
|
||||||
|
@ -55,12 +60,15 @@ public:
|
||||||
auto resultType =
|
auto resultType =
|
||||||
RankedTensorType::get(broadcastedStaticShape, lhsType.getElementType());
|
RankedTensorType::get(broadcastedStaticShape, lhsType.getElementType());
|
||||||
Value lhsBroadcasted = rewriter.create<tcp::BroadcastToOp>(
|
Value lhsBroadcasted = rewriter.create<tcp::BroadcastToOp>(
|
||||||
op.getLoc(), resultType, op.lhs(), broadcastedExtents);
|
op.getLoc(), resultType, op.lhs(), broadcastedShape);
|
||||||
Value rhsBroadcasted = rewriter.create<tcp::BroadcastToOp>(
|
Value rhsBroadcasted = rewriter.create<tcp::BroadcastToOp>(
|
||||||
op.getLoc(), resultType, op.rhs(), broadcastedExtents);
|
op.getLoc(), resultType, op.rhs(), broadcastedShape);
|
||||||
Value add = rewriter.create<tcp::AddOp>(op.getLoc(), op.getType(),
|
Value add = rewriter.create<tcp::AddOp>(op.getLoc(), op.getType(),
|
||||||
lhsBroadcasted, rhsBroadcasted);
|
lhsBroadcasted, rhsBroadcasted);
|
||||||
rewriter.replaceOp(op, add);
|
rewriter.create<shape::AssumingYieldOp>(op.getLoc(), add);
|
||||||
|
|
||||||
|
// Finally, replace with the results of the shape.assuming
|
||||||
|
rewriter.replaceOp(op, assuming.getResults());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -7,14 +7,44 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
||||||
|
#include "mlir/Transforms/InliningUtils.h"
|
||||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP::tcp;
|
using namespace mlir::NPCOMP::tcp;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TCPDialect Dialect Interfaces
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct TCPInlinerInterface : public DialectInlinerInterface {
|
||||||
|
using DialectInlinerInterface::DialectInlinerInterface;
|
||||||
|
bool isLegalToInline(Region *dest, Region *src,
|
||||||
|
BlockAndValueMapping &valueMapping) const final {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
bool isLegalToInline(Operation *, Region *,
|
||||||
|
BlockAndValueMapping &) const final {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
void handleTerminator(Operation *op,
|
||||||
|
ArrayRef<Value> valuesToRepl) const final {
|
||||||
|
auto retValOp = dyn_cast<YieldOp>(op);
|
||||||
|
if (!retValOp)
|
||||||
|
return;
|
||||||
|
|
||||||
|
for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
|
||||||
|
std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
void TCPDialect::initialize() {
|
void TCPDialect::initialize() {
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
#include "npcomp/Dialect/TCP/IR/TCPOps.cpp.inc"
|
#include "npcomp/Dialect/TCP/IR/TCPOps.cpp.inc"
|
||||||
>();
|
>();
|
||||||
|
addInterfaces<TCPInlinerInterface>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,11 +9,82 @@
|
||||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
#include "mlir/IR/TypeUtilities.h"
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::NPCOMP;
|
||||||
using namespace mlir::NPCOMP::tcp;
|
using namespace mlir::NPCOMP::tcp;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TensorToMemrefOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult TensorToMemrefOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
if (auto memrefToTensor = tensor().getDefiningOp<tcp::MemrefToTensorOp>())
|
||||||
|
return memrefToTensor.memref();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ShapedResultsOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void ShapedResultsOp::build(OpBuilder &builder, OperationState &result,
|
||||||
|
TypeRange resultTypes, ValueRange resultShapes) {
|
||||||
|
result.addOperands(resultShapes);
|
||||||
|
result.addTypes(resultTypes);
|
||||||
|
(void)result.addRegion();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verifyShapedResultsOp(ShapedResultsOp op) {
|
||||||
|
if (op.getNumOperands() != op.getNumResults())
|
||||||
|
return op.emitError() << "number of operands must equal number of results";
|
||||||
|
if (op.getNumOperands() == 0)
|
||||||
|
return op.emitError() << "must have at least one operand/result";
|
||||||
|
return RegionBranchOpInterface::verifyTypes(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printShapedResultsOp(OpAsmPrinter &p, ShapedResultsOp &op) {
|
||||||
|
p << "tcp.shaped_results ";
|
||||||
|
p.printOptionalAttrDictWithKeyword(op.getAttrs());
|
||||||
|
p.printOperands(op.getOperands());
|
||||||
|
p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
|
||||||
|
p << " : ";
|
||||||
|
interleaveComma(op.getOperandTypes(), p);
|
||||||
|
p << " -> ";
|
||||||
|
interleaveComma(op.getResultTypes(), p);
|
||||||
|
}
|
||||||
|
|
||||||
|
static ParseResult parseShapedResultsOp(OpAsmParser &parser,
|
||||||
|
OperationState &result) {
|
||||||
|
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
|
||||||
|
return failure();
|
||||||
|
SmallVector<OpAsmParser::OperandType, 6> operands;
|
||||||
|
if (parser.parseOperandList(operands))
|
||||||
|
return failure();
|
||||||
|
auto *body = result.addRegion();
|
||||||
|
if (parser.parseRegion(*body, llvm::None, llvm::None))
|
||||||
|
return failure();
|
||||||
|
SmallVector<Type, 6> inputTypes;
|
||||||
|
if (parser.parseColonTypeList(inputTypes))
|
||||||
|
return failure();
|
||||||
|
if (parser.resolveOperands(operands, inputTypes, parser.getNameLoc(),
|
||||||
|
result.operands))
|
||||||
|
return failure();
|
||||||
|
if (parser.parseArrowTypeList(result.types))
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void ShapedResultsOp::getSuccessorRegions(
|
||||||
|
Optional<unsigned> index, ArrayRef<Attribute> operands,
|
||||||
|
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||||
|
if (index.hasValue())
|
||||||
|
regions.push_back(RegionSuccessor(getResults()));
|
||||||
|
else
|
||||||
|
regions.push_back(RegionSuccessor(&body()));
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// GlobalOp
|
// GlobalOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -0,0 +1,106 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "PassDetail.h"
|
||||||
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
|
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||||
|
#include "npcomp/E2E/E2E.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP;
|
||||||
|
|
||||||
|
static bool isSimpleElementwiseLinalgGeneric(linalg::GenericOp op) {
|
||||||
|
// Only handle generic ops where all operands and results are tensors.
|
||||||
|
if (!llvm::all_of(op.getOperandTypes(),
|
||||||
|
[](Type type) { return type.isa<RankedTensorType>(); })) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!llvm::all_of(op.getResultTypes(),
|
||||||
|
[](Type type) { return type.isa<RankedTensorType>(); })) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Loosen restrictions on indexing maps.
|
||||||
|
// This will require more principled handling of shape reification
|
||||||
|
// earlier in the compilation stack, as in general output shapes of a
|
||||||
|
// linalg.generic cannot be inferred easily.
|
||||||
|
// See:
|
||||||
|
// https://llvm.discourse.group/t/computing-output-shapes-of-structured-ops-on-tensors/866
|
||||||
|
if (!llvm::all_of(op.indexing_maps(), [](Attribute map) {
|
||||||
|
return map.cast<AffineMapAttr>().getValue().isIdentity();
|
||||||
|
})) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!llvm::all_of(op.iterator_types(), [](Attribute str) {
|
||||||
|
return str.cast<StringAttr>().getValue() ==
|
||||||
|
getParallelIteratorTypeName();
|
||||||
|
})) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Don't just open-code all shape transfer functions here.
|
||||||
|
// Note: for now, we can't just rely on an OpInterface, since OpInterfaces
|
||||||
|
// cannot be "externally applied". E.g. we can't change the definition of
|
||||||
|
// linalg::GenericOp.
|
||||||
|
static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
|
||||||
|
OpBuilder builder(&op);
|
||||||
|
if (auto linalgGeneric = dyn_cast<linalg::GenericOp>(op)) {
|
||||||
|
// TODO: Avoid this excessive restriction.
|
||||||
|
// This will require more principled handling of the lowering to
|
||||||
|
// linalg.generic -- it should generally happen after this pass, becaue in
|
||||||
|
// general output shapes of a linalg.generic cannot be inferred easily. See:
|
||||||
|
// https://llvm.discourse.group/t/computing-output-shapes-of-structured-ops-on-tensors/866
|
||||||
|
if (!isSimpleElementwiseLinalgGeneric(linalgGeneric))
|
||||||
|
return {};
|
||||||
|
// All shapes of all operands and results are the same for now. So
|
||||||
|
// arbitrarily pick the first operand.
|
||||||
|
return {builder.create<shape::ShapeOfOp>(op.getLoc(), op.getOperand(0))};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto broadcastTo = dyn_cast<tcp::BroadcastToOp>(op)) {
|
||||||
|
return {broadcastTo.shape()};
|
||||||
|
}
|
||||||
|
|
||||||
|
// No shape transfer function.
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// TODO: There is a coupling between this pass and LowerShapedResults.
|
||||||
|
// Any op that is wrapped in tcp.shaped_results here needs to be known how to be
|
||||||
|
// lowered by LowerShapedResults.
|
||||||
|
class BypassShapes : public BypassShapesBase<BypassShapes> {
|
||||||
|
void runOnOperation() {
|
||||||
|
auto func = getOperation();
|
||||||
|
func.walk([&](Operation *opPtr) {
|
||||||
|
Operation &op = *opPtr;
|
||||||
|
SmallVector<Value, 6> resultShapes = bypassResultShapes(op);
|
||||||
|
if (resultShapes.empty())
|
||||||
|
return;
|
||||||
|
// We have result shapes, so wrap this op in a tcp.shaped_results op.
|
||||||
|
OpBuilder builder(&op);
|
||||||
|
auto shapedResults = builder.create<tcp::ShapedResultsOp>(
|
||||||
|
op.getLoc(), op.getResultTypes(), resultShapes);
|
||||||
|
op.replaceAllUsesWith(shapedResults);
|
||||||
|
|
||||||
|
// Move the op into the body and yield the results.
|
||||||
|
Block *body = builder.createBlock(&shapedResults.body());
|
||||||
|
op.moveBefore(body, body->end());
|
||||||
|
builder.create<tcp::YieldOp>(op.getLoc(), op.getResults());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> mlir::NPCOMP::createBypassShapesPass() {
|
||||||
|
return std::make_unique<BypassShapes>();
|
||||||
|
}
|
|
@ -1,9 +1,13 @@
|
||||||
add_mlir_library(NPCOMPE2E
|
add_mlir_library(NPCOMPE2E
|
||||||
|
BypassShapes.cpp
|
||||||
E2E.cpp
|
E2E.cpp
|
||||||
LowerRankedShapes.cpp
|
LowerShapeConstraints.cpp
|
||||||
LowerToHybridTensorMemRef.cpp
|
|
||||||
LowerToLLVM.cpp
|
LowerToLLVM.cpp
|
||||||
LowerToNpcomprtABI.cpp
|
LowerToNpcomprtABI.cpp
|
||||||
|
TensorToMemref/LowerConstantTensorsToMemref.cpp
|
||||||
|
TensorToMemref/LowerShapedResultsToMemref.cpp
|
||||||
|
TensorToMemref/LowerStdToMemref.cpp
|
||||||
|
TensorToMemref/LowerStructuralToMemref.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${PROJECT_SRC_DIR}/include/npcomp/E2E
|
${PROJECT_SRC_DIR}/include/npcomp/E2E
|
||||||
|
@ -17,9 +21,10 @@ add_mlir_library(NPCOMPE2E
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRLinalgOps
|
MLIRLinalgOps
|
||||||
|
MLIRSCFToStandard
|
||||||
|
MLIRShapeToStandard
|
||||||
MLIRStandardOps
|
MLIRStandardOps
|
||||||
MLIRStandardToLLVM
|
MLIRStandardToLLVM
|
||||||
MLIRSCFToStandard
|
|
||||||
)
|
)
|
||||||
|
|
||||||
mlir_check_all_link_libraries(NPCOMPE2E)
|
mlir_check_all_link_libraries(NPCOMPE2E)
|
||||||
|
|
404
lib/E2E/E2E.cpp
404
lib/E2E/E2E.cpp
|
@ -10,39 +10,13 @@
|
||||||
// At the moment, the first "end" is TCF ops and the second "end" is `llvm`
|
// At the moment, the first "end" is TCF ops and the second "end" is `llvm`
|
||||||
// dialect suitable for jitting.
|
// dialect suitable for jitting.
|
||||||
//
|
//
|
||||||
// This is still work-in-progress and not even working end-to-end for the
|
|
||||||
// most trivial examples, see TODO's in createE2ELoweringPipeline for the
|
|
||||||
// status.
|
|
||||||
//
|
|
||||||
// As a pragmatic matter, I generally tend to drop random passes and stuff
|
|
||||||
// inside this top-level file and then shard it out to separate files once
|
|
||||||
// a clear organizing principle arises (to avoid premature organizing).
|
|
||||||
//
|
|
||||||
// Once we have end-to-end functionality working, we will throw
|
|
||||||
// increasingly complex programs and augment this pass pipeline, likely
|
|
||||||
// introducing better structure and more clear principles.
|
|
||||||
//
|
|
||||||
// I wish I had a clear view of how this pipeline should perfectly layer
|
|
||||||
// ahead of time, but unfortunately I don't since it crosses half a dozen
|
|
||||||
// abstraction levels / dialects, some of which have no precedent that I'm
|
|
||||||
// aware of (dynamic-shape-aware, error-aware TCF -> TCP) or very little
|
|
||||||
// (tensor -> memref/buffer with dynamic shapes, shape -> SSA values for
|
|
||||||
// ranked shape extents).
|
|
||||||
//
|
|
||||||
// Right now there's lots of stuff in this pipeline that is limited to
|
|
||||||
// special cases where I have an idea of how to elaborate it to the general
|
|
||||||
// case. The priority is getting and end-to-end flow working that we can
|
|
||||||
// grow out organically to a curriculum of more complex cases, elaborating
|
|
||||||
// on the design principles and layering as necessitated by the curriculum.
|
|
||||||
//
|
|
||||||
// This should be fun :)
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/E2E/E2E.h"
|
#include "npcomp/E2E/E2E.h"
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
|
|
||||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||||
|
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||||
#include "mlir/Dialect/Linalg/Passes.h"
|
#include "mlir/Dialect/Linalg/Passes.h"
|
||||||
|
@ -75,193 +49,6 @@ void mlir::NPCOMP::registerE2EPasses() {
|
||||||
mlir::PassPipelineRegistration<E2ELoweringPipelineOptions>(
|
mlir::PassPipelineRegistration<E2ELoweringPipelineOptions>(
|
||||||
"e2e-lowering-pipeline", "E2E lowering pipeline.",
|
"e2e-lowering-pipeline", "E2E lowering pipeline.",
|
||||||
mlir::NPCOMP::createE2ELoweringPipeline);
|
mlir::NPCOMP::createE2ELoweringPipeline);
|
||||||
mlir::PassPipelineRegistration<>(
|
|
||||||
"lower-to-hybrid-tensor-memref-pipeline",
|
|
||||||
"Pipeline lowering to hybrid tensor/memref.",
|
|
||||||
mlir::NPCOMP::createLowerToHybridTensorMemRefPipeline);
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ResolveShapeOfOps
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class ResolveShapeOfOpViaAllocMemRefOp
|
|
||||||
: public OpRewritePattern<shape::ShapeOfOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
if (auto tensorLoad = llvm::dyn_cast_or_null<TensorLoadOp>(
|
|
||||||
op.getOperand().getDefiningOp())) {
|
|
||||||
if (auto allocMemRef = llvm::dyn_cast_or_null<tcp::AllocMemRefOp>(
|
|
||||||
tensorLoad.getOperand().getDefiningOp())) {
|
|
||||||
rewriter.replaceOp(op, allocMemRef.getOperand());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class ResolveShapeOfOps : public ResolveShapeOfOpsBase<ResolveShapeOfOps> {
|
|
||||||
void runOnOperation() {
|
|
||||||
auto func = getOperation();
|
|
||||||
auto *context = &getContext();
|
|
||||||
|
|
||||||
OwningRewritePatternList patterns;
|
|
||||||
patterns.insert<ResolveShapeOfOpViaAllocMemRefOp>(context);
|
|
||||||
ConversionTarget target(*context);
|
|
||||||
// target.addIllegalOp<shape::ShapeOfOp>();
|
|
||||||
target.addDynamicallyLegalOp<shape::ShapeOfOp>(
|
|
||||||
[](shape::ShapeOfOp shapeOf) {
|
|
||||||
// Only shape.shape_of on arguments to the entry block are legal at
|
|
||||||
// this point. They are assumed to be resolved eventually via
|
|
||||||
// the lowering of the tensor argument to some ABI that has the
|
|
||||||
// relevant information available. But this is ABI dependent.
|
|
||||||
// TODO: Convince myself that we never need to deal with general
|
|
||||||
// block operands, or implement general handling of block
|
|
||||||
// operands (need to add new bb operands of !shape.shape type).
|
|
||||||
if (auto blockArg = shapeOf.getOperand().dyn_cast<BlockArgument>()) {
|
|
||||||
Block *block = blockArg.getOwner();
|
|
||||||
if (&block->getParent()->front() == block) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
});
|
|
||||||
if (failed(applyPartialConversion(func, target, patterns))) {
|
|
||||||
return signalPassFailure();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>>
|
|
||||||
mlir::NPCOMP::createResolveShapeOfOpsPass() {
|
|
||||||
return std::make_unique<ResolveShapeOfOps>();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ResolveTensorLoadStoreOps
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class ReplaceTensorStoreWithCopyPattern
|
|
||||||
: public OpRewritePattern<TensorStoreOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
LogicalResult matchAndRewrite(TensorStoreOp op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto tensorLoad =
|
|
||||||
llvm::dyn_cast_or_null<TensorLoadOp>(op.tensor().getDefiningOp());
|
|
||||||
if (!tensorLoad)
|
|
||||||
return rewriter.notifyMatchFailure(op, "not fed by tensor_load op");
|
|
||||||
rewriter.replaceOpWithNewOp<linalg::CopyOp>(op, tensorLoad.memref(),
|
|
||||||
op.memref());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class EraseUnusedTensorLoadOpPattern : public OpRewritePattern<TensorLoadOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
LogicalResult matchAndRewrite(TensorLoadOp op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
if (!op.use_empty())
|
|
||||||
return rewriter.notifyMatchFailure(op, "has uses");
|
|
||||||
rewriter.eraseOp(op);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class ResolveTensorLoadStoreOps
|
|
||||||
: public ResolveTensorLoadStoreOpsBase<ResolveTensorLoadStoreOps> {
|
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
||||||
registry.insert<linalg::LinalgDialect>();
|
|
||||||
}
|
|
||||||
|
|
||||||
void runOnOperation() override {
|
|
||||||
auto func = getOperation();
|
|
||||||
auto *context = &getContext();
|
|
||||||
|
|
||||||
OwningRewritePatternList patterns;
|
|
||||||
patterns.insert<ReplaceTensorStoreWithCopyPattern>(context);
|
|
||||||
patterns.insert<EraseUnusedTensorLoadOpPattern>(context);
|
|
||||||
ConversionTarget target(*context);
|
|
||||||
target.addLegalDialect<linalg::LinalgDialect>();
|
|
||||||
target.addDynamicallyLegalOp<TensorLoadOp>([](TensorLoadOp op) {
|
|
||||||
for (auto user : op.getResult().getUsers())
|
|
||||||
if (!isa<ReturnOp>(user))
|
|
||||||
return false;
|
|
||||||
return true;
|
|
||||||
});
|
|
||||||
target.addDynamicallyLegalOp<TensorStoreOp>(
|
|
||||||
[](TensorStoreOp op) { return op.tensor().isa<BlockArgument>(); });
|
|
||||||
if (failed(applyPartialConversion(func, target, patterns))) {
|
|
||||||
return signalPassFailure();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>>
|
|
||||||
mlir::NPCOMP::createResolveTensorLoadStoreOpsPass() {
|
|
||||||
return std::make_unique<ResolveTensorLoadStoreOps>();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// LowerLinalgLoopDimOps
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class LowerLinalgLoopDimOp : public OpRewritePattern<DimOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
LogicalResult matchAndRewrite(DimOp op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto allocMemRef = op.memrefOrTensor().getDefiningOp<tcp::AllocMemRefOp>();
|
|
||||||
if (!allocMemRef)
|
|
||||||
return rewriter.notifyMatchFailure(op, "could not find alloc_memref");
|
|
||||||
rewriter.replaceOpWithNewOp<shape::GetExtentOp>(
|
|
||||||
op, rewriter.getIndexType(), allocMemRef.shape(), op.index());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class LowerLinalgLoopDimOps
|
|
||||||
: public LowerLinalgLoopDimOpsBase<LowerLinalgLoopDimOps> {
|
|
||||||
void runOnOperation() {
|
|
||||||
auto func = getOperation();
|
|
||||||
auto *context = &getContext();
|
|
||||||
OwningRewritePatternList patterns;
|
|
||||||
patterns.insert<LowerLinalgLoopDimOp>(context);
|
|
||||||
ConversionTarget target(*context);
|
|
||||||
target.addDynamicallyLegalOp<DimOp>([](DimOp op) -> bool {
|
|
||||||
// TODO: We only need this because we use `dim` ops for the memref
|
|
||||||
// ABI. Once we layer that out into our own runtime types, we can
|
|
||||||
// remove this.
|
|
||||||
return !op.memrefOrTensor().getDefiningOp<tcp::AllocMemRefOp>();
|
|
||||||
});
|
|
||||||
target.addLegalOp<shape::GetExtentOp>();
|
|
||||||
if (failed(applyPartialConversion(func, target, patterns))) {
|
|
||||||
return signalPassFailure();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>>
|
|
||||||
mlir::NPCOMP::createLowerLinalgLoopDimOpsPass() {
|
|
||||||
return std::make_unique<LowerLinalgLoopDimOps>();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -327,87 +114,102 @@ mlir::NPCOMP::createLowerAllocMemRefOpsPass() {
|
||||||
|
|
||||||
void mlir::NPCOMP::createE2ELoweringPipeline(
|
void mlir::NPCOMP::createE2ELoweringPipeline(
|
||||||
OpPassManager &pm, const E2ELoweringPipelineOptions &options) {
|
OpPassManager &pm, const E2ELoweringPipelineOptions &options) {
|
||||||
// Input IR is TCF ops.
|
// This "end to end" lowering pipline loewrings from approximately the "numpy"
|
||||||
|
// level of abstraction (which is a dialect we call "TCF", or "Tensor Compute
|
||||||
|
// Frontend") all the way down to LLVM IR.
|
||||||
|
|
||||||
// Convert to TCP.
|
// Convert from TCF to TCP.
|
||||||
|
//
|
||||||
|
// TCF has implicit broadcasting, and issues errors "inside the ops" in the
|
||||||
|
// case of invalid broadcasts.
|
||||||
|
//
|
||||||
|
// TCP does not. So we need to reify the broadcasting and error checking.
|
||||||
pm.addPass(createConvertTCFToTCPPass());
|
pm.addPass(createConvertTCFToTCPPass());
|
||||||
|
|
||||||
// TODO: Do tcp.island coarsening here.
|
|
||||||
|
|
||||||
// TODO: This is approximately the place that we would fork off when
|
|
||||||
// lowering to IREE.
|
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
// Tensor to buffer (memref) conversion.
|
|
||||||
// --------------------------------------------------------------------------
|
|
||||||
|
|
||||||
// Convert tcp ops to Linalg where possible, as we want generic linalg
|
// Convert tcp ops to Linalg where possible, as we want generic linalg
|
||||||
// tensor->memref to do most of the mechanical work of rewriting ops in
|
// tensor->memref to do most of the mechanical work of rewriting ops in
|
||||||
// terms of tensors to ops in terms of memrefs (since it is easy on that
|
// terms of tensors to ops in terms of memrefs (since it is easy on that
|
||||||
// representation).
|
// representation).
|
||||||
|
// TODO: Does this make sense? Should we instead go to an "TCP on buffers" and
|
||||||
|
// only lower to linalg at the buffer level?
|
||||||
pm.addPass(createConvertTCPToLinalgPass());
|
pm.addPass(createConvertTCPToLinalgPass());
|
||||||
|
|
||||||
// Lower to hybrid tensor/memref
|
// For operations with a shape transfer function, explicitly bypass their
|
||||||
|
// shape computations with tcp.shaped_results ops.
|
||||||
//
|
//
|
||||||
// The hybrid tensor/memref representation guarantees:
|
// Right now, our lowering flow depends heavily on descriptors, so technically
|
||||||
// - every use of a tensor is a tensor_store op writing it into a memref
|
// we don't need to bypass shapes -- we can just splat out the shape
|
||||||
// - every def of a tensor is a tensor_load op loading out of some memref.
|
// calculations when lowering the ops themselves. However, this design keeps
|
||||||
// - every memref is allocated by a `tcp.alloc_memref(%shape)` op.
|
// the door open to various future directions, and is an interesting example
|
||||||
// - every memref is only ever writen once, and never mutated
|
// in its own right.
|
||||||
//
|
//
|
||||||
// Exceptions: "boundaries" such as function arguments and island
|
// For example, if we want to lower to command-buffer style API's like Vulkan,
|
||||||
// live-outs.
|
// then we need (for correctness) to bypass the shapes (actually,
|
||||||
|
// something more sophisticated than just that) if we want to do command
|
||||||
|
// buffer formation while we are still on tensors (e.g. to record workgroup
|
||||||
|
// sizes). We might not care about pursuing that direction here though. So
|
||||||
|
// consider this pass as purely advisory now.
|
||||||
//
|
//
|
||||||
// Or, another way to say this: the hybrid tensor/memref representation
|
// One case where we might still be interested in this is dealing with
|
||||||
// doesn't attempt to eliminate the original tensors from the program,
|
// linalg.generic ops and other types of "fusions" that have shape transfer
|
||||||
// but rather locally expands operations on tensors to be small subgraphs
|
// functions that are not easily reconstructible and thus we have to capture
|
||||||
// with tensor_load/tensor_store at the boundaries, leaving enough
|
// the shape transfer functions earlier in the pipeline.
|
||||||
// invariants that we can clean it up later.
|
pm.addPass(createBypassShapesPass());
|
||||||
|
|
||||||
|
// Lower shape constraints before we enter tensor->memref conversion.
|
||||||
|
// That is, we expand witnesses + shape.assuming + shape.cstr_* ops to
|
||||||
|
// eager error handling code that doesn't have witnesses or shape.assuming.
|
||||||
|
pm.addPass(createLowerShapeConstraintsPass());
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Lower the `tensor` type to `memref`.
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// We make a conscious effort here to do this as a sequence of separate passes
|
||||||
|
// rather than a single mega dialect conversion pass.
|
||||||
//
|
//
|
||||||
// The core invariants that are needed for this step are that the
|
// This means that intermediate steps have source/target materializations
|
||||||
// tensor-level ops we receive as input have a way of calculating the
|
// (tcp.memref_to_tensor / tcp.tensor_to_memref) in the IR.
|
||||||
// sizes for their outputs. This is equivalent to saying that
|
|
||||||
// `shape.shape_of` on the result of an op must be calculatable in terms
|
|
||||||
// of the shapes of the inputs to the op.
|
|
||||||
createLowerToHybridTensorMemRefPipeline(pm);
|
|
||||||
|
|
||||||
// At this point, the invariants of the hybrid tensor/memref
|
// Lower ops enclosed in tcp.shaped_results regions.
|
||||||
// representation allow us to resolve `shape.shape_of` ops to shape
|
// For now, this is covering the "tensor compute" ops like tcp.add /
|
||||||
// computations earlier in the program. Specifically, every
|
// tcp.broadcast_to (the former being handled via a special subset of
|
||||||
// `shape.shape_of` can be resolved to the shape argument to the
|
// linalg.generic) -- we only handle those two, so having an isolated pass
|
||||||
// corresponding `tcp.alloc_memref` op of the tensor_load that produced
|
// that hardcodes all of them is fine -- eventually we might want something
|
||||||
// that tensor.
|
// more pluggable. The exact interface for this pluggability depends on
|
||||||
pm.addPass(createResolveShapeOfOpsPass());
|
// what design we want to settle on for bypassing shape computations.
|
||||||
|
pm.addPass(createLowerShapedResultsToMemrefPass());
|
||||||
// Now, we use the hybrid tensor/memref invariants to replace the
|
// Lower tensor-valued constants to tcp.global.
|
||||||
// tensor_store ops with memref copy operations and erase the
|
pm.addPass(createLowerConstantTensorsToMemrefPass());
|
||||||
// tensor_load/tensor_store ops.
|
// tcp::AllocMemRefOp takes a shape (i.e. extent tensor) as an argument. We
|
||||||
pm.addPass(createResolveTensorLoadStoreOpsPass());
|
// need to resolve this to std.alloc which takes individual extents.
|
||||||
|
pm.addPass(createLowerAllocMemRefOpsPass());
|
||||||
// At this point, the IR is in a form where there are no tensor ops
|
// Lower shape ops to std.
|
||||||
// (except tensor_store's of arguments, tensor_load's of returns, and
|
// TODO: This should in principle be moved before tensor->memref conversion.
|
||||||
// constants).
|
// But some of the tensor->memref lowerings above use shape.get_extent. For
|
||||||
|
// example, when lowering a broadcast, we need to get an extent from its shape
|
||||||
|
// operand to allocate the output.
|
||||||
|
pm.addPass(createConvertShapeToStandardPass());
|
||||||
|
// Lower std ops to memref.
|
||||||
|
// This includes ops like extract_element.
|
||||||
|
pm.addPass(createLowerStdToMemrefPass());
|
||||||
|
// Lower control flow and other "structural" ops.
|
||||||
//
|
//
|
||||||
// This is a reasonable representation for doing buffer assignment.
|
// These ops are generally not sensitive to the types that they operate on
|
||||||
// TODO: Do buffer assignment here.
|
// (e.g. the types of block operands, function arguments, etc.). But they all
|
||||||
|
// need to be converted consistently. So it makes sense to do this as the
|
||||||
|
// final step of conversion, which also finalizes the elimination of all
|
||||||
|
// stray source/target materializations introduced by the incremental
|
||||||
|
// tensor->memref lowering.
|
||||||
|
//
|
||||||
|
// This completes conversion to memref. There are no `tensor`'s after
|
||||||
|
// this point.
|
||||||
|
pm.addPass(createLowerStructuralToMemrefPass());
|
||||||
|
|
||||||
// We need to finalize the removal of tensors from the program. To do
|
// TODO: Do buffer assignment. We should be able to just drop in the upstream
|
||||||
// that, we need to interface with a runtime ABI.
|
// pass?
|
||||||
// We have a specialized dialect npcomprt which models the runtime data
|
|
||||||
// structures, and function signatures (and presumably eventually, other
|
|
||||||
// ABI boundaries like external calls if we ever support it) will be
|
|
||||||
// converted.
|
|
||||||
pm.addPass(createLowerToNpcomprtABIPass());
|
|
||||||
|
|
||||||
// TODO: Might want a different kind of island to better represent this.
|
// At this point, we have lots of loose stuff floating around from lowering,
|
||||||
// This island op would explicitly capture all tensors as inputs, and it
|
// so it's a good time to do some general cleanups.
|
||||||
// would establish a more formalized ABI with the interior of the body
|
|
||||||
// region (much like IREE does with dispatch regions). For now, we are
|
|
||||||
// planning on just inlining the islands, so there is little value in
|
|
||||||
// doing this, but we should look at the layering aspects here later.
|
|
||||||
|
|
||||||
// At this point, we have loose shape calculations floating around, so
|
|
||||||
// it's a good time to do some general cleanups.
|
|
||||||
if (options.optimize) {
|
if (options.optimize) {
|
||||||
pm.addPass(createCanonicalizerPass());
|
pm.addPass(createCanonicalizerPass());
|
||||||
pm.addPass(createCSEPass());
|
pm.addPass(createCSEPass());
|
||||||
|
@ -423,47 +225,9 @@ void mlir::NPCOMP::createE2ELoweringPipeline(
|
||||||
// TODO: Do some linalg optimizations like tiling here.
|
// TODO: Do some linalg optimizations like tiling here.
|
||||||
pm.addPass(createConvertLinalgToLoopsPass());
|
pm.addPass(createConvertLinalgToLoopsPass());
|
||||||
|
|
||||||
// Lowering linalg to loops introduces `dim` ops. Here we look through
|
|
||||||
// use-def chains to find `tcp.alloc_memref` ops that we can get a shape
|
|
||||||
// out of.
|
|
||||||
// Currently, this is trivial, but after more aggressive buffer
|
|
||||||
// allocation optimizations or linalg tiling this step will need to look
|
|
||||||
// through slices/views and stuff.
|
|
||||||
// TODO: It seems that "dim on memrefs" is being resolved in a
|
|
||||||
// fundamentally different way from "dim on tensors" is earlier in the
|
|
||||||
// pipeline. Investigate.
|
|
||||||
// We could somewhat unify them by having enough folding patterns for
|
|
||||||
// `shape.shape_of`. Above, we used the pattern
|
|
||||||
// "shape_of(tensor_load(alloc_memref(%shape))) -> %shape". Here we are
|
|
||||||
// doing `shape_of(alloc_memref(%shape)) -> %shape". It seems
|
|
||||||
// dangerous to just have a pile of these patterns and hope that one of
|
|
||||||
// them resolves things at any given point. So what we do is to use a
|
|
||||||
// very narrowly focused set of patterns that exploit just the invariants
|
|
||||||
// at each point.
|
|
||||||
pm.addPass(createLowerLinalgLoopDimOpsPass());
|
|
||||||
|
|
||||||
// AllocMemRefOp's take a `!shape.shape` as an argument. We need to
|
|
||||||
// resolve this to individual extents before we lower ranked shapes.
|
|
||||||
pm.addPass(createLowerAllocMemRefOpsPass());
|
|
||||||
|
|
||||||
// Lower shapes to SSA values.
|
|
||||||
// This replaces all tcf::GetExtentOp's with explicit SSA computations
|
|
||||||
// for the scalar extent. This requires shapes which are ranked. Any
|
|
||||||
// unranked shapes will need to be handled by a runtime shape type,
|
|
||||||
// though we don't currently support that.
|
|
||||||
//
|
|
||||||
// At this point, in the case of programs with only ranked shapes, all
|
|
||||||
// !shape.shape types will be gone.
|
|
||||||
// TODO: Better demarcate the invariants here, such as having a verifier
|
|
||||||
// pass that checks no !shape.shape types left.
|
|
||||||
pm.addPass(createLowerRankedShapesPass());
|
|
||||||
|
|
||||||
// Run a some cleanups.
|
// Run a some cleanups.
|
||||||
// TODO: Some folding and DCE of dangling ops is still needed here. Once the
|
|
||||||
// invariants above are tightened up, the canonicalize should be moved into
|
|
||||||
// the optimize block.
|
|
||||||
pm.addPass(createCanonicalizerPass());
|
|
||||||
if (options.optimize) {
|
if (options.optimize) {
|
||||||
|
pm.addPass(createCanonicalizerPass());
|
||||||
pm.addPass(createCSEPass());
|
pm.addPass(createCSEPass());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -474,9 +238,13 @@ void mlir::NPCOMP::createE2ELoweringPipeline(
|
||||||
// Convert scf to std control flow in preparation for going to LLVM.
|
// Convert scf to std control flow in preparation for going to LLVM.
|
||||||
pm.addPass(createLowerToCFGPass());
|
pm.addPass(createLowerToCFGPass());
|
||||||
|
|
||||||
|
// Convert functions signatures and other constructs that interface with the
|
||||||
|
// runtime to the `npcomprt` dialect.
|
||||||
|
pm.addPass(createLowerToNpcomprtABIPass());
|
||||||
|
|
||||||
// Finally, convert to LLVM dialect using our custom LowerToLLVM pass
|
// Finally, convert to LLVM dialect using our custom LowerToLLVM pass
|
||||||
// which reuses the upstream patterns and gives us a place to add our own
|
// which reuses the upstream patterns and gives us a place to add our own
|
||||||
// patterns for any custom ops and types we wish to lower.
|
// patterns for our own custom ops like the npcomprt ops.
|
||||||
pm.addPass(createLowerToLLVMPass());
|
pm.addPass(createLowerToLLVMPass());
|
||||||
|
|
||||||
// Although LLVM will clean everything up eventually, for the sake of IR
|
// Although LLVM will clean everything up eventually, for the sake of IR
|
||||||
|
|
|
@ -1,273 +0,0 @@
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
//
|
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#include "PassDetail.h"
|
|
||||||
#include "npcomp/E2E/E2E.h"
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
|
|
||||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
|
|
||||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
using namespace mlir::NPCOMP;
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class LowerConstShapeOp : public OpConversionPattern<shape::ConstShapeOp> {
|
|
||||||
public:
|
|
||||||
using OpConversionPattern::OpConversionPattern;
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(shape::ConstShapeOp op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
auto extents = llvm::to_vector<6>(llvm::map_range(
|
|
||||||
op.shape().getValues<int64_t>(), [&](int64_t extent) -> Value {
|
|
||||||
return rewriter.create<ConstantIndexOp>(op.getLoc(), extent);
|
|
||||||
}));
|
|
||||||
rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(
|
|
||||||
op, rewriter.getType<shape::ShapeType>(), extents);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
// Given an operand that is either a Shape or Extent Tensor, returns an
|
|
||||||
// Extent Tensor or nullptr if this cannot be locally determined.
|
|
||||||
// The return value, if !nullptr, will be a 1D RankedTensorType (with possibly
|
|
||||||
// unknown element).
|
|
||||||
Value findExtentsFromShape(Value operand, bool requireKnownRank) {
|
|
||||||
if (auto tensorType = operand.getType().dyn_cast<RankedTensorType>()) {
|
|
||||||
if (tensorType.getRank() == 1 &&
|
|
||||||
(!requireKnownRank || tensorType.hasStaticShape())) {
|
|
||||||
return operand;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
class LowerShapeBroadcastOp : public OpConversionPattern<shape::BroadcastOp> {
|
|
||||||
public:
|
|
||||||
using OpConversionPattern::OpConversionPattern;
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(shape::BroadcastOp op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
shape::BroadcastOp::Adaptor adaptor(operands);
|
|
||||||
// When the ranks are statically known, generate non-branchy code.
|
|
||||||
// TODO: Generate rank-generic code.
|
|
||||||
auto lhsExtents = findExtentsFromShape(adaptor.lhs(), true);
|
|
||||||
auto rhsExtents = findExtentsFromShape(adaptor.rhs(), true);
|
|
||||||
if (!lhsExtents || !rhsExtents)
|
|
||||||
return rewriter.notifyMatchFailure(op, "dynamic extents not supported");
|
|
||||||
|
|
||||||
// Establish invariant that rank(lhs) >= rank(rhs).
|
|
||||||
auto lhsSize = lhsExtents.getType().cast<RankedTensorType>().getDimSize(0);
|
|
||||||
auto rhsSize = rhsExtents.getType().cast<RankedTensorType>().getDimSize(0);
|
|
||||||
if (lhsSize < rhsSize) {
|
|
||||||
std::swap(lhsExtents, rhsExtents);
|
|
||||||
std::swap(lhsSize, rhsSize);
|
|
||||||
}
|
|
||||||
auto rankDiscrepancy = lhsSize - rhsSize;
|
|
||||||
|
|
||||||
// Helper that creates IR
|
|
||||||
// ```
|
|
||||||
// abort_if(extent != resultExtent && extent != 1)
|
|
||||||
// ```
|
|
||||||
// This is the numpy broadcasting legality check.
|
|
||||||
auto createAbortIfIllegalBroadcastExtent = [&](Value extent,
|
|
||||||
Value resultExtent) {
|
|
||||||
auto c1 = rewriter.create<ConstantIndexOp>(op.getLoc(), 1);
|
|
||||||
auto extentNeMax = rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::ne,
|
|
||||||
extent, resultExtent);
|
|
||||||
auto extentNeOne =
|
|
||||||
rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::ne, extent, c1);
|
|
||||||
auto bothTrue =
|
|
||||||
rewriter.create<AndOp>(op.getLoc(), extentNeMax, extentNeOne);
|
|
||||||
// TODO: Should there be a more generic error-handling dialect?
|
|
||||||
// It seems a bit awkward to hardcode npcomprt here.
|
|
||||||
rewriter.create<npcomprt::AbortIfOp>(op.getLoc(), bothTrue);
|
|
||||||
};
|
|
||||||
|
|
||||||
SmallVector<Value, 6> resultExtents;
|
|
||||||
for (int i = 0, e = lhsSize; i < e; i++) {
|
|
||||||
auto lhsDim = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
|
|
||||||
auto lhsExtent = rewriter.create<ExtractElementOp>(
|
|
||||||
op.getLoc(), lhsExtents, ValueRange{lhsDim});
|
|
||||||
if (i < rankDiscrepancy) {
|
|
||||||
// Padded extent.
|
|
||||||
resultExtents.push_back(lhsExtent);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Non-padded extent.
|
|
||||||
auto rhsDim =
|
|
||||||
rewriter.create<ConstantIndexOp>(op.getLoc(), i - rankDiscrepancy);
|
|
||||||
auto rhsExtent = rewriter.create<ExtractElementOp>(
|
|
||||||
op.getLoc(), rhsExtents, ValueRange{rhsDim});
|
|
||||||
auto ugt = rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::ugt,
|
|
||||||
lhsExtent, rhsExtent);
|
|
||||||
auto resultExtent =
|
|
||||||
rewriter.create<SelectOp>(op.getLoc(), ugt, lhsExtent, rhsExtent);
|
|
||||||
createAbortIfIllegalBroadcastExtent(lhsExtent, resultExtent);
|
|
||||||
createAbortIfIllegalBroadcastExtent(rhsExtent, resultExtent);
|
|
||||||
resultExtents.push_back(resultExtent);
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Remove the return type once ODS is fixed to do proper inference.
|
|
||||||
rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(
|
|
||||||
op, shape::ShapeType::get(rewriter.getContext()), resultExtents);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class LowerShapeToExtentTensorOp
|
|
||||||
: public OpConversionPattern<shape::ToExtentTensorOp> {
|
|
||||||
public:
|
|
||||||
using OpConversionPattern::OpConversionPattern;
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
shape::ToExtentTensorOpAdaptor adaptor(operands);
|
|
||||||
if (adaptor.input().getType().isa<shape::ShapeType>()) {
|
|
||||||
// Convert by matching to a producing FromExtentsOp.
|
|
||||||
auto fromExtents = adaptor.input().getDefiningOp<shape::FromExtentsOp>();
|
|
||||||
if (!fromExtents) {
|
|
||||||
return rewriter.notifyMatchFailure(op, "not a from_extents op");
|
|
||||||
}
|
|
||||||
rewriter.replaceOpWithNewOp<TensorFromElementsOp>(op,
|
|
||||||
fromExtents.extents());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assume that it is already an extent tensor.
|
|
||||||
// TODO: Since these ops are all multi-type, there should be a utility
|
|
||||||
// for switching on the allowable types instead of just assuming that it
|
|
||||||
// is an extent tensor.
|
|
||||||
rewriter.replaceOp(op, adaptor.input());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class LowerShapeGetExtentOp : public OpConversionPattern<shape::GetExtentOp> {
|
|
||||||
public:
|
|
||||||
using OpConversionPattern::OpConversionPattern;
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(shape::GetExtentOp op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
shape::GetExtentOp::Adaptor adaptor(operands);
|
|
||||||
rewriter.replaceOpWithNewOp<ExtractElementOp>(op, adaptor.shape(),
|
|
||||||
adaptor.dim());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
// Now that we have lowered ranked shapes, which reifies the eager
|
|
||||||
// error-handling code, the tcp::ShapeObserveErrorOp's are no longer
|
|
||||||
// needed.
|
|
||||||
class EraseShapeObserveErrorOp
|
|
||||||
: public OpConversionPattern<tcp::ShapeObserveErrorOp> {
|
|
||||||
public:
|
|
||||||
using OpConversionPattern::OpConversionPattern;
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(tcp::ShapeObserveErrorOp op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
rewriter.eraseOp(op);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// Basic invariant of this pass:
|
|
||||||
// Every `shape.from_extents` op operating on an extent tensor
|
|
||||||
// (`tensor<?xindex>`) is replaced by corresponding standard ops and folded
|
|
||||||
// away (for the ranked case, it should be possible to eliminate these).
|
|
||||||
//
|
|
||||||
// We expect that previous passes have inserted a "root" set of
|
|
||||||
// shape::FromExtentsOp's that allow this process to get started.
|
|
||||||
//
|
|
||||||
// This is similar to the approach that is used in IREE. It is basically a
|
|
||||||
// combination of the ConvertShapeToShapex pass and the
|
|
||||||
// "ranked_dim(make_ranked_shape(x1, x2), N) -> xN" folding pattern.
|
|
||||||
// These patterns have to be "conversion patterns" since the `operands` argument
|
|
||||||
// gives access to the post-conversion operands from earlier ops.
|
|
||||||
//
|
|
||||||
// This pass depends heavily on ranked shapes, since only ranked shapes can
|
|
||||||
// be statically expanded to a fixed set of SSA extents.
|
|
||||||
//
|
|
||||||
// TODO: This approach doesn't naively work with control flow.
|
|
||||||
// In the presence of non-cyclic control flow, we can just generalize the
|
|
||||||
// `getDefiningOp<shape::FromExtentsOp>()` calls into something that will
|
|
||||||
// look through block arguments and rewrite "phi of shapes -> phi of extents".
|
|
||||||
// In the presence of cyclic control flow, we need to somehow resolve the
|
|
||||||
// ranks of use-def cycles ahead of time or optimistically assume that
|
|
||||||
// backedges will match the rank of forward edges, and somehow be robust
|
|
||||||
// when that assumption fails.
|
|
||||||
//
|
|
||||||
// TODO: Add in a fold of
|
|
||||||
// `extract_element(tensor_from_elements(x0, x1, ...), n) -> xn` to restore
|
|
||||||
// the above invariant without relying on a subsequent canonicalization
|
|
||||||
// step.
|
|
||||||
namespace {
|
|
||||||
class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
|
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
||||||
registry.insert<npcomprt::NpcomprtDialect>();
|
|
||||||
}
|
|
||||||
|
|
||||||
void runOnOperation() override {
|
|
||||||
auto func = getOperation();
|
|
||||||
auto *context = &getContext();
|
|
||||||
|
|
||||||
OwningRewritePatternList patterns;
|
|
||||||
patterns.insert<LowerConstShapeOp>(context);
|
|
||||||
patterns.insert<LowerShapeBroadcastOp>(context);
|
|
||||||
patterns.insert<LowerShapeGetExtentOp>(context);
|
|
||||||
patterns.insert<LowerShapeToExtentTensorOp>(context);
|
|
||||||
patterns.insert<EraseShapeObserveErrorOp>(context);
|
|
||||||
ConversionTarget target(*context);
|
|
||||||
target.addIllegalOp<shape::ShapeOfOp>();
|
|
||||||
target.addIllegalOp<shape::BroadcastOp>();
|
|
||||||
target.addIllegalOp<shape::GetExtentOp>();
|
|
||||||
target.addLegalOp<shape::FromExtentsOp>();
|
|
||||||
target.addIllegalOp<shape::ToExtentTensorOp>();
|
|
||||||
target.addLegalOp<npcomprt::AbortIfOp>();
|
|
||||||
target.addLegalDialect<StandardOpsDialect>();
|
|
||||||
target.addIllegalOp<tcp::ShapeObserveErrorOp>();
|
|
||||||
if (failed(applyPartialConversion(func, target, patterns))) {
|
|
||||||
return signalPassFailure();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Erase some stray shape ops from the program. They can't be
|
|
||||||
// deleted during conversion because they become unused only after
|
|
||||||
// subsequent patterns bypass them.
|
|
||||||
auto walkResult = func.walk([](Operation *op) {
|
|
||||||
if (!isa<shape::FromExtentsOp>(op))
|
|
||||||
return WalkResult::advance();
|
|
||||||
if (op->use_empty()) {
|
|
||||||
op->erase();
|
|
||||||
} else {
|
|
||||||
op->emitError("could not be eliminated");
|
|
||||||
return WalkResult::interrupt();
|
|
||||||
}
|
|
||||||
return WalkResult::advance();
|
|
||||||
});
|
|
||||||
if (walkResult.wasInterrupted())
|
|
||||||
return signalPassFailure();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>>
|
|
||||||
mlir::NPCOMP::createLowerRankedShapesPass() {
|
|
||||||
return std::make_unique<LowerRankedShapes>();
|
|
||||||
}
|
|
|
@ -0,0 +1,189 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "PassDetail.h"
|
||||||
|
#include "npcomp/E2E/E2E.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Pass/PassRegistry.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class LowerCstrBroadcastableOp
|
||||||
|
: public OpRewritePattern<shape::CstrBroadcastableOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// A shape.cstr_* op should be the result of lowering a !shape.shape; it
|
||||||
|
// should not itself ever consume or produce a !shape.shape.
|
||||||
|
//
|
||||||
|
// There is no way to "sink" a !shape.shape type, because one cannot inspect
|
||||||
|
// if it is an error. The only way to use it safely is to lower the op that
|
||||||
|
// produced the value to a set of constraints and then use the witness to
|
||||||
|
// guard a shape.assuming.
|
||||||
|
//
|
||||||
|
// Consider for example what we do when lowering TCF to TCP: we need to do a
|
||||||
|
// shape calculation for the broadcasting. But we create the
|
||||||
|
// shape.cstr_broadcastable and use its witness to guard a `shape.assuming {
|
||||||
|
// ... shape.broadcast ...}`. There's never any need to create a
|
||||||
|
// !shape.shape.
|
||||||
|
//
|
||||||
|
// The use of !shape.shape should be restricted to contexts like
|
||||||
|
// declarations of shape transfer functions, with automatic utilities to
|
||||||
|
// lower !shape.shape types to corresponding constraints + shape.assuming +
|
||||||
|
// tensors. In this (npcomp e2e) lowering flow, we don't have any such
|
||||||
|
// "declarative shape transfer functions" or utilities to expand them to
|
||||||
|
// constraints. So !shape.shape should never exist in our IR.
|
||||||
|
//
|
||||||
|
// Historically, we used !shape.shape type for everything, and
|
||||||
|
// shape.to_extent_tensor would abort in case of an error. But that's not a
|
||||||
|
// useful semantics for lowering, since the error is defined to happen as
|
||||||
|
// part of the shape.to_extent_tensor op, which requires materializing an
|
||||||
|
// "is error" bit in the IR and carrying it around everywhere that the
|
||||||
|
// original !shape.shape value was being used. In practice, nobody respects
|
||||||
|
// that, which opens us up to miscompilations. That is, the lowering
|
||||||
|
// strategy is either "not emit errors at all" or "emit errors as part of
|
||||||
|
// lowering e.g. the shape.broadcast op itself" (which technically puts the
|
||||||
|
// errors in some random location in the IR that is not the
|
||||||
|
// shape.to_extent_tensor op). E.g. the following code would miscompile with
|
||||||
|
// either of those ways that these ops get lowered in practice:
|
||||||
|
// ```
|
||||||
|
// %shape = shape.broadcast %lhs, %rhs : !shape.shape
|
||||||
|
// if %cond:
|
||||||
|
// shape.to_extent_tensor(%shape)
|
||||||
|
// ```
|
||||||
|
// It's not possible to correctly compile this code without significant
|
||||||
|
// contortions (such as carrying an "is error" bit). And to boot, we
|
||||||
|
// shouldn't be getting into that situation in the first place! But the
|
||||||
|
// `shape.to_extent_tensor : !shape.shape -> tensor<?xindex>` abstraction
|
||||||
|
// opens up that possibility.
|
||||||
|
//
|
||||||
|
// shape.to_extent_tensor should not really be a thing, since it creates
|
||||||
|
// these ill-defined situations about where errors are observed. A
|
||||||
|
// !shape.shape type should only exist (for this compilation flow) as part
|
||||||
|
// of a utility, something like "I want to do this shape calculation on
|
||||||
|
// !shape.shape type, create IR that uses tensor<?xindex> and witnesses to
|
||||||
|
// implement it, on the assumption that the error can be
|
||||||
|
// observed anywhere inside the shape calculation".
|
||||||
|
//
|
||||||
|
// !shape.shape type would still be useful for lowerings that actually
|
||||||
|
// result in a runtime type that carries an "is error" bit inside it, though
|
||||||
|
// TBD if such use cases arise.
|
||||||
|
if (op.getType().isa<shape::ShapeType>() ||
|
||||||
|
op.lhs().getType().isa<shape::ShapeType>() ||
|
||||||
|
op.rhs().getType().isa<shape::ShapeType>()) {
|
||||||
|
return op.emitError() << "Error shapes should not exist at this point";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
||||||
|
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
|
||||||
|
|
||||||
|
// Find smaller and greater rank and extent tensor.
|
||||||
|
Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
|
||||||
|
Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
|
||||||
|
Value lhsSmaller =
|
||||||
|
rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
|
||||||
|
Type indexTy = rewriter.getIndexType();
|
||||||
|
Type extentTensorTy = op.lhs().getType();
|
||||||
|
auto ifOp = rewriter.create<scf::IfOp>(
|
||||||
|
loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
|
||||||
|
lhsSmaller,
|
||||||
|
[&](OpBuilder &b, Location loc) {
|
||||||
|
b.create<scf::YieldOp>(
|
||||||
|
loc, ValueRange{lhsRank, op.lhs(), rhsRank, op.rhs()});
|
||||||
|
},
|
||||||
|
[&](OpBuilder &b, Location loc) {
|
||||||
|
b.create<scf::YieldOp>(
|
||||||
|
loc, ValueRange{rhsRank, op.rhs(), lhsRank, op.lhs()});
|
||||||
|
});
|
||||||
|
Value lesserRank = ifOp.getResult(0);
|
||||||
|
Value lesserRankOperand = ifOp.getResult(1);
|
||||||
|
Value greaterRank = ifOp.getResult(2);
|
||||||
|
Value greaterRankOperand = ifOp.getResult(3);
|
||||||
|
|
||||||
|
Value rankDiff =
|
||||||
|
rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
|
||||||
|
|
||||||
|
// Compare the shapes extent by extent, and emit errors for
|
||||||
|
// non-broadcast-compatible shapes.
|
||||||
|
// Two extents are broadcast-compatible if
|
||||||
|
// 1. they are both equal, or
|
||||||
|
// 2. at least one of them is 1.
|
||||||
|
|
||||||
|
rewriter.create<scf::ForOp>(
|
||||||
|
loc, rankDiff, greaterRank, one, llvm::None,
|
||||||
|
[&](OpBuilder &b, Location loc, Value iv, ValueRange) {
|
||||||
|
Value greaterRankOperandExtent = b.create<ExtractElementOp>(
|
||||||
|
loc, greaterRankOperand, ValueRange{iv});
|
||||||
|
Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
|
||||||
|
Value lesserRankOperandExtent = b.create<ExtractElementOp>(
|
||||||
|
loc, lesserRankOperand, ValueRange{ivShifted});
|
||||||
|
|
||||||
|
Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
|
||||||
|
loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
|
||||||
|
Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
|
||||||
|
loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
|
||||||
|
Value extentsAgree =
|
||||||
|
b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
|
||||||
|
lesserRankOperandExtent);
|
||||||
|
auto broadcastIsValid =
|
||||||
|
b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
|
||||||
|
b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
|
||||||
|
lesserRankOperandExtentIsOne));
|
||||||
|
b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
|
||||||
|
b.create<scf::YieldOp>(loc);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Now that we have emitted all the assertions, the witness is trivially
|
||||||
|
// satisfied.
|
||||||
|
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// This pass eliminates shape constraints from the program.
|
||||||
|
//
|
||||||
|
// After this pass finishes, there are no !shape.witness types in the program,
|
||||||
|
// no shape.assuming, no shape.cstr_*.
|
||||||
|
//
|
||||||
|
// TODO: This should move to upstream ShapeToStandard conversions.
|
||||||
|
class LowerShapeConstraints
|
||||||
|
: public LowerShapeConstraintsBase<LowerShapeConstraints> {
|
||||||
|
void runOnOperation() {
|
||||||
|
auto func = getOperation();
|
||||||
|
auto *context = &getContext();
|
||||||
|
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
patterns.insert<LowerCstrBroadcastableOp>(context);
|
||||||
|
// Add in the canonicalization patterns for shape.assuming so that it gets
|
||||||
|
// inlined when its witness becomes a true constant witness.
|
||||||
|
shape::AssumingOp::getCanonicalizationPatterns(patterns, context);
|
||||||
|
|
||||||
|
if (failed(applyPatternsAndFoldGreedily(func, patterns)))
|
||||||
|
return signalPassFailure();
|
||||||
|
|
||||||
|
// TODO: Check that there are no remaining !shape.witness, shape.assuming,
|
||||||
|
// shape.cstr_* ops, etc.
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
|
mlir::NPCOMP::createLowerShapeConstraintsPass() {
|
||||||
|
return std::make_unique<LowerShapeConstraints>();
|
||||||
|
}
|
|
@ -1,406 +0,0 @@
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
//
|
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#include "PassDetail.h"
|
|
||||||
#include "npcomp/E2E/E2E.h"
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
|
||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
|
||||||
#include "mlir/Pass/PassRegistry.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
|
||||||
#include "npcomp/Conversion/TCPToLinalg/TCPToLinalg.h"
|
|
||||||
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
|
||||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
using namespace mlir::NPCOMP;
|
|
||||||
|
|
||||||
static Value allocMemRefForTensor(OpBuilder &builder, Value tensor, Value shape,
|
|
||||||
Location loc) {
|
|
||||||
auto tensorType = tensor.getType().cast<RankedTensorType>();
|
|
||||||
auto memrefType =
|
|
||||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
|
||||||
return builder.create<tcp::AllocMemRefOp>(loc, memrefType, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// LowerBroadcastTo
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
// TODO: Lower to linalg.indexed_generic instead and let linalg do the expansion
|
|
||||||
// to loops?
|
|
||||||
namespace {
|
|
||||||
class LowerBroadcastToToLoopsPattern
|
|
||||||
: public OpRewritePattern<tcp::BroadcastToOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
LogicalResult matchAndRewrite(tcp::BroadcastToOp op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto resultType = op.getType().cast<RankedTensorType>();
|
|
||||||
auto inputType = op.operand().getType().cast<RankedTensorType>();
|
|
||||||
Value resultMemref = rewriter.create<tcp::AllocMemRefOp>(
|
|
||||||
op.getLoc(),
|
|
||||||
MemRefType::get(resultType.getShape(), resultType.getElementType()),
|
|
||||||
op.shape());
|
|
||||||
Value inputMemref = allocMemRefForTensor(
|
|
||||||
rewriter, op.operand(),
|
|
||||||
rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.operand()),
|
|
||||||
op.getLoc());
|
|
||||||
rewriter.create<TensorStoreOp>(op.getLoc(), op.operand(), inputMemref);
|
|
||||||
SmallVector<Value, 6> outputExtents;
|
|
||||||
SmallVector<Value, 6> inputDimRequiresBroadcasting;
|
|
||||||
|
|
||||||
// TODO: handle output rank > input rank.
|
|
||||||
for (int i = 0, e = resultType.getRank(); i < e; i++) {
|
|
||||||
Value dimIndex = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
|
|
||||||
Value outputExtent = rewriter.create<shape::GetExtentOp>(
|
|
||||||
op.getLoc(), rewriter.getIndexType(), op.shape(), dimIndex);
|
|
||||||
outputExtents.push_back(outputExtent);
|
|
||||||
}
|
|
||||||
int rankDiff = resultType.getRank() - inputType.getRank();
|
|
||||||
for (int i = 0, e = inputType.getRank(); i < e; i++) {
|
|
||||||
// Calculate the relevant extents.
|
|
||||||
Value inputExtent = rewriter.create<DimOp>(op.getLoc(), op.operand(), i);
|
|
||||||
inputDimRequiresBroadcasting.push_back(
|
|
||||||
rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::ne, inputExtent,
|
|
||||||
outputExtents[rankDiff + i]));
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
|
||||||
Value c0 = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
|
|
||||||
Value c1 = rewriter.create<ConstantIndexOp>(op.getLoc(), 1);
|
|
||||||
|
|
||||||
SmallVector<Value, 6> inductionVariables;
|
|
||||||
// Create the (perfectly nested) loops.
|
|
||||||
// Loop invariant: At the start of iteration `i`, the rewriter insertion
|
|
||||||
// point is inside `i` nested loops.
|
|
||||||
for (int i = 0, e = resultType.getRank(); i < e; i++) {
|
|
||||||
auto loop = rewriter.create<scf::ForOp>(
|
|
||||||
op.getLoc(), c0, outputExtents[i], c1, ValueRange({}));
|
|
||||||
Block *body = loop.getBody();
|
|
||||||
inductionVariables.push_back(body->getArgument(0));
|
|
||||||
// Leave the insertion point at the beginning of the body.
|
|
||||||
rewriter.setInsertionPointToStart(body);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the inner loop body.
|
|
||||||
// When reading from the input, clamp any indices for dimensions that are
|
|
||||||
// being broadcast.
|
|
||||||
SmallVector<Value, 6> inputIndices;
|
|
||||||
for (int i = 0, e = inputType.getRank(); i < e; i++) {
|
|
||||||
auto c0 = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
|
|
||||||
auto select = rewriter.create<SelectOp>(
|
|
||||||
op.getLoc(), inputDimRequiresBroadcasting[i], c0,
|
|
||||||
inductionVariables[rankDiff + i]);
|
|
||||||
inputIndices.push_back(select);
|
|
||||||
}
|
|
||||||
Value load =
|
|
||||||
rewriter.create<LoadOp>(op.getLoc(), inputMemref, inputIndices);
|
|
||||||
rewriter.create<StoreOp>(op.getLoc(), load, resultMemref,
|
|
||||||
inductionVariables);
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<TensorLoadOp>(op, resultMemref);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// TODO: This should be layered in better somewhere.
|
|
||||||
// We currently only create DimOp's during LowerBroadcastToToLoopsPattern,
|
|
||||||
// so for now just stuff it in here.
|
|
||||||
namespace {
|
|
||||||
class LowerDimOpToShape : public OpRewritePattern<DimOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
LogicalResult matchAndRewrite(DimOp op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
// TODO: Remove this const pattern when lowering to shape.get_extent.
|
|
||||||
auto shape =
|
|
||||||
rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.memrefOrTensor());
|
|
||||||
rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, rewriter.getIndexType(),
|
|
||||||
shape, op.index());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class LowerBroadcastToToLoops
|
|
||||||
: public LowerBroadcastToToLoopsBase<LowerBroadcastToToLoops> {
|
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
||||||
registry.insert<shape::ShapeDialect, tcp::TCPDialect>();
|
|
||||||
}
|
|
||||||
|
|
||||||
void runOnOperation() override {
|
|
||||||
auto func = getOperation();
|
|
||||||
MLIRContext *context = &getContext();
|
|
||||||
ConversionTarget target(*context);
|
|
||||||
target.addLegalDialect<shape::ShapeDialect>();
|
|
||||||
target.addLegalDialect<StandardOpsDialect>();
|
|
||||||
target.addLegalDialect<scf::SCFDialect>();
|
|
||||||
target.addLegalDialect<tcp::TCPDialect>();
|
|
||||||
|
|
||||||
OwningRewritePatternList patterns;
|
|
||||||
target.addIllegalOp<tcp::BroadcastToOp>();
|
|
||||||
patterns.insert<LowerBroadcastToToLoopsPattern>(context);
|
|
||||||
target.addIllegalOp<DimOp>();
|
|
||||||
patterns.insert<LowerDimOpToShape>(context);
|
|
||||||
|
|
||||||
if (failed(applyPartialConversion(func, target, patterns))) {
|
|
||||||
return signalPassFailure();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>>
|
|
||||||
mlir::NPCOMP::createLowerBroadcastToToLoopsPass() {
|
|
||||||
return std::make_unique<LowerBroadcastToToLoops>();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// LowerLinalgOnTensorToLinalgOnMemref
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class LowerLinalgGenericTensorToMemRef
|
|
||||||
: public OpRewritePattern<linalg::GenericOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
LogicalResult matchAndRewrite(linalg::GenericOp op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
|
|
||||||
// TODO: Replace this with more generic code operating on named
|
|
||||||
// structured ops too.
|
|
||||||
|
|
||||||
// Only handle generic ops where all operands and results are tensors.
|
|
||||||
if (!llvm::all_of(op.getOperandTypes(),
|
|
||||||
[](Type type) { return type.isa<RankedTensorType>(); })) {
|
|
||||||
return rewriter.notifyMatchFailure(op, "all operands must be tensors");
|
|
||||||
}
|
|
||||||
if (!llvm::all_of(op.getResultTypes(),
|
|
||||||
[](Type type) { return type.isa<RankedTensorType>(); })) {
|
|
||||||
return rewriter.notifyMatchFailure(op, "all results must be tensors");
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Loosen restrictions on indexing maps.
|
|
||||||
// This will require more principled handling of shape reification
|
|
||||||
// earlier in the compilation stack, as in general output shapes of a
|
|
||||||
// linalg.generic cannot be inferred easily.
|
|
||||||
// See:
|
|
||||||
// https://llvm.discourse.group/t/computing-output-shapes-of-structured-ops-on-tensors/866
|
|
||||||
if (!llvm::all_of(op.indexing_maps(), [](Attribute map) {
|
|
||||||
return map.cast<AffineMapAttr>().getValue().isIdentity();
|
|
||||||
})) {
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "all indexing maps must be identity maps");
|
|
||||||
}
|
|
||||||
if (!llvm::all_of(op.iterator_types(), [](Attribute str) {
|
|
||||||
return str.cast<StringAttr>().getValue() ==
|
|
||||||
getParallelIteratorTypeName();
|
|
||||||
})) {
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "all iterator types must be 'parallel'");
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value, 6> memrefs;
|
|
||||||
SmallVector<Value, 6> resultMemrefs;
|
|
||||||
SmallVector<Value, 6> operandShapes;
|
|
||||||
for (auto tensor : op.getOperands()) {
|
|
||||||
auto shape = rewriter.create<shape::ShapeOfOp>(op.getLoc(), tensor);
|
|
||||||
auto memref = allocMemRefForTensor(rewriter, tensor, shape, op.getLoc());
|
|
||||||
rewriter.create<TensorStoreOp>(op.getLoc(), tensor, memref);
|
|
||||||
memrefs.push_back(memref);
|
|
||||||
operandShapes.push_back(shape);
|
|
||||||
}
|
|
||||||
auto shapeType = shape::ShapeType::get(rewriter.getContext());
|
|
||||||
SmallVector<Type, 6> shapeTypes(op.getNumResults(), shapeType);
|
|
||||||
// TODO: We need more principled handling of output shapes.
|
|
||||||
// This assumes that all results have the same shape, which is justified
|
|
||||||
// by checks above, but we really need a better story here.
|
|
||||||
SmallVector<Value, 6> resultShapes(op.getNumResults(), operandShapes[0]);
|
|
||||||
for (auto t : llvm::zip(op.getResults(), resultShapes)) {
|
|
||||||
auto tensor = std::get<0>(t);
|
|
||||||
auto shape = std::get<1>(t);
|
|
||||||
auto memref = allocMemRefForTensor(rewriter, tensor, shape, op.getLoc());
|
|
||||||
memrefs.push_back(memref);
|
|
||||||
resultMemrefs.push_back(memref);
|
|
||||||
}
|
|
||||||
auto newGeneric = rewriter.create<linalg::GenericOp>(
|
|
||||||
op.getLoc(), llvm::None, ValueRange(memrefs), op.getAttrs());
|
|
||||||
newGeneric.region().getBlocks().clear();
|
|
||||||
BlockAndValueMapping mapper;
|
|
||||||
op.region().cloneInto(&newGeneric.region(), mapper);
|
|
||||||
for (auto memref : resultMemrefs) {
|
|
||||||
newGeneric.region().front().addArgument(
|
|
||||||
memref.getType().cast<MemRefType>().getElementType());
|
|
||||||
}
|
|
||||||
auto newResultTensors =
|
|
||||||
llvm::to_vector<6>(llvm::map_range(resultMemrefs, [&](Value memref) {
|
|
||||||
return rewriter.create<TensorLoadOp>(op.getLoc(), memref).getResult();
|
|
||||||
}));
|
|
||||||
rewriter.replaceOp(op, newResultTensors);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class LowerLinalgOnTensorToLinalgOnMemref
|
|
||||||
: public LowerLinalgOnTensorToLinalgOnMemrefBase<
|
|
||||||
LowerLinalgOnTensorToLinalgOnMemref> {
|
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
||||||
registry.insert<shape::ShapeDialect, tcp::TCPDialect>();
|
|
||||||
}
|
|
||||||
|
|
||||||
void runOnOperation() override {
|
|
||||||
auto func = getOperation();
|
|
||||||
auto *context = &getContext();
|
|
||||||
|
|
||||||
OwningRewritePatternList patterns;
|
|
||||||
ConversionTarget target(*context);
|
|
||||||
target.addLegalDialect<shape::ShapeDialect>();
|
|
||||||
target.addLegalDialect<StandardOpsDialect>();
|
|
||||||
target.addLegalDialect<linalg::LinalgDialect>();
|
|
||||||
target.addLegalOp<tcp::AllocMemRefOp>();
|
|
||||||
patterns.insert<LowerLinalgGenericTensorToMemRef>(context);
|
|
||||||
target.addDynamicallyLegalOp<linalg::GenericOp>([](linalg::GenericOp op) {
|
|
||||||
if (llvm::any_of(op.getOperandTypes(), [](Type type) {
|
|
||||||
return type.isa<RankedTensorType>();
|
|
||||||
})) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (llvm::any_of(op.getResultTypes(), [](Type type) {
|
|
||||||
return type.isa<RankedTensorType>();
|
|
||||||
})) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
});
|
|
||||||
|
|
||||||
if (failed(applyPartialConversion(func, target, patterns))) {
|
|
||||||
return signalPassFailure();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>>
|
|
||||||
mlir::NPCOMP::createLowerLinalgOnTensorToLinalgOnMemrefPass() {
|
|
||||||
return std::make_unique<LowerLinalgOnTensorToLinalgOnMemref>();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// LowerConstantTensorsToMemrefs
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
// This class creates global ops for all tensor-valued constants in the program.
|
|
||||||
// It creates them with pretty names and makes sure that duplicate globals
|
|
||||||
// aren't created.
|
|
||||||
class GlobalCreator {
|
|
||||||
public:
|
|
||||||
explicit GlobalCreator(ModuleOp module);
|
|
||||||
tcp::GlobalOp getGlobalFor(Attribute attr) {
|
|
||||||
assert(globals.find(attr) != globals.end() && "unknown constant attr");
|
|
||||||
return globals[attr];
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
DenseMap<Attribute, tcp::GlobalOp> globals;
|
|
||||||
};
|
|
||||||
|
|
||||||
GlobalCreator::GlobalCreator(ModuleOp module) {
|
|
||||||
// Create a builder without an insertion point. We will insert using the
|
|
||||||
// symbol table to guarantee unique names.
|
|
||||||
OpBuilder globalBuilder(module.getContext());
|
|
||||||
SymbolTable symbolTable(module);
|
|
||||||
module.walk([&](ConstantOp op) {
|
|
||||||
// We only want tensor constants for now.
|
|
||||||
auto type = op.getType().dyn_cast<RankedTensorType>();
|
|
||||||
if (!type)
|
|
||||||
return;
|
|
||||||
// If we already have a global for this constant value, no need to do
|
|
||||||
// anything else.
|
|
||||||
auto it = globals.find(op.getValue());
|
|
||||||
if (it != globals.end())
|
|
||||||
return;
|
|
||||||
|
|
||||||
// Create a pretty name.
|
|
||||||
SmallString<64> buf;
|
|
||||||
llvm::raw_svector_ostream os(buf);
|
|
||||||
interleave(type.getShape(), os, "x");
|
|
||||||
os << "x" << type.getElementType();
|
|
||||||
|
|
||||||
auto global = globalBuilder.create<tcp::GlobalOp>(
|
|
||||||
op.getLoc(), (Twine("__constant_") + os.str()).str(),
|
|
||||||
op.getValue().cast<ElementsAttr>());
|
|
||||||
symbolTable.insert(global);
|
|
||||||
// The symbol table inserts at the end of the module, but globals are a bit
|
|
||||||
// nicer if they are at the beginning.
|
|
||||||
global.getOperation()->moveBefore(&module.front());
|
|
||||||
globals[op.getValue()] = global;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class LowerConstantTensorsToMemrefs
|
|
||||||
: public LowerConstantTensorsToMemrefsBase<LowerConstantTensorsToMemrefs> {
|
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
||||||
registry.insert<tcp::TCPDialect>();
|
|
||||||
}
|
|
||||||
|
|
||||||
void runOnOperation() override {
|
|
||||||
auto module = getOperation();
|
|
||||||
GlobalCreator globals(module);
|
|
||||||
|
|
||||||
// With the global traversal factored into GlobalCreator, this could in
|
|
||||||
// principle be done with a pattern.
|
|
||||||
module.walk([&](ConstantOp op) {
|
|
||||||
auto type = op.getType().dyn_cast<RankedTensorType>();
|
|
||||||
if (!type)
|
|
||||||
return;
|
|
||||||
auto global = globals.getGlobalFor(op.getValue());
|
|
||||||
OpBuilder builder(op);
|
|
||||||
auto memrefType = MemRefType::get(type.getShape(), type.getElementType());
|
|
||||||
auto memref = builder.create<tcp::GetGlobalMemrefOp>(
|
|
||||||
op.getLoc(), memrefType, global.getName());
|
|
||||||
Value tensor = builder.create<TensorLoadOp>(op.getLoc(), type, memref);
|
|
||||||
op.replaceAllUsesWith(tensor);
|
|
||||||
op.erase();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
|
||||||
mlir::NPCOMP::createLowerConstantTensorsToMemrefsPass() {
|
|
||||||
return std::make_unique<LowerConstantTensorsToMemrefs>();
|
|
||||||
}
|
|
||||||
|
|
||||||
void mlir::NPCOMP::createLowerToHybridTensorMemRefPipeline(OpPassManager &pm) {
|
|
||||||
// Lower to hybrid tensor/memref.
|
|
||||||
// The invariant of "hybrid tensor/memref" is that the core computation
|
|
||||||
// ops operate on memref, but we launder in and out of tensors in such a
|
|
||||||
// way that the original SSA tensor values remain and can be traced to
|
|
||||||
// their corresponding memrefs (via tensor_load/tensor_store) which are
|
|
||||||
// allocated with alloc_shape ops.
|
|
||||||
// Thus, shape.shape_of ops on the original tensors in the program can be
|
|
||||||
// resolved to the shapes in the alloc_memref calls.
|
|
||||||
pm.addPass(createLowerConstantTensorsToMemrefsPass());
|
|
||||||
pm.addPass(createLowerLinalgOnTensorToLinalgOnMemrefPass());
|
|
||||||
pm.addPass(createLowerBroadcastToToLoopsPass());
|
|
||||||
}
|
|
|
@ -150,6 +150,59 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
static LLVM::GlobalOp createGlobalString(ModuleOp module, StringAttr msg,
|
||||||
|
OpBuilder &builder, Location loc) {
|
||||||
|
// TODO: Deduplicate strings.
|
||||||
|
auto arrayTy = LLVMType::getArrayTy(LLVMType::getInt8Ty(module.getContext()),
|
||||||
|
msg.getValue().size());
|
||||||
|
OpBuilder::InsertionGuard guard(builder);
|
||||||
|
builder.setInsertionPointToStart(module.getBody());
|
||||||
|
// To get a unique symbol name, use a suffix derived from the current number
|
||||||
|
// of ops in the module.
|
||||||
|
// We can't use the SymbolTable's logic for this because the module
|
||||||
|
// transiently contains a `func` and `llvm.func` with the same name during
|
||||||
|
// conversion, preventing us from instantiating a SymbolTable.
|
||||||
|
std::string symbolName =
|
||||||
|
(Twine("__npcomp_string_") +
|
||||||
|
Twine(llvm::size(llvm::to_vector<6>(module.getOps<LLVM::GlobalOp>()))))
|
||||||
|
.str();
|
||||||
|
auto globalOp =
|
||||||
|
builder.create<LLVM::GlobalOp>(loc, arrayTy, /*isConstant=*/true,
|
||||||
|
LLVM::Linkage::Internal, symbolName, msg);
|
||||||
|
return globalOp;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class AbortIfOpCompilerRuntimeLowering
|
||||||
|
: public OpConversionPattern<npcomprt::AbortIfOp> {
|
||||||
|
public:
|
||||||
|
AbortIfOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
|
||||||
|
: OpConversionPattern<npcomprt::AbortIfOp>(backingFunc.getContext()),
|
||||||
|
backingFunc(backingFunc) {}
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(npcomprt::AbortIfOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
npcomprt::AbortIfOp::Adaptor adaptor(operands);
|
||||||
|
auto *context = op.getContext();
|
||||||
|
|
||||||
|
// Create the global string, take its address, and gep to get an `i8*`.
|
||||||
|
auto globalOp = createGlobalString(op.getParentOfType<ModuleOp>(),
|
||||||
|
op.msgAttr(), rewriter, op.getLoc());
|
||||||
|
auto msgArray = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), globalOp);
|
||||||
|
auto c0 = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
op.getLoc(), LLVMType::getIntNTy(context, 32),
|
||||||
|
rewriter.getI32IntegerAttr(0));
|
||||||
|
auto msg = rewriter.create<LLVM::GEPOp>(op.getLoc(),
|
||||||
|
LLVMType::getInt8PtrTy(context),
|
||||||
|
msgArray, ValueRange({c0, c0}));
|
||||||
|
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
|
||||||
|
op, backingFunc, ValueRange({adaptor.pred(), msg}));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
LLVM::LLVMFuncOp backingFunc;
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Create the LLVM runtime function backing the npcomprt op with name `name`
|
// Create the LLVM runtime function backing the npcomprt op with name `name`
|
||||||
// and requiring `type`.
|
// and requiring `type`.
|
||||||
static LLVMFuncOp createCompilerRuntimeFuncDecl(StringRef name, LLVMType type,
|
static LLVMFuncOp createCompilerRuntimeFuncDecl(StringRef name, LLVMType type,
|
||||||
|
@ -168,24 +221,13 @@ static void populateCompilerRuntimePatterns(ModuleOp module,
|
||||||
OpBuilder builder(module.getBodyRegion());
|
OpBuilder builder(module.getBodyRegion());
|
||||||
|
|
||||||
{
|
{
|
||||||
auto abortIfFuncTy = LLVMType::getFunctionTy(LLVMType::getVoidTy(context),
|
auto abortIfFuncTy = LLVMType::getFunctionTy(
|
||||||
{LLVMType::getInt1Ty(context)},
|
LLVMType::getVoidTy(context),
|
||||||
|
{LLVMType::getInt1Ty(context), LLVMType::getInt8PtrTy(context)},
|
||||||
/*isVarArg=*/false);
|
/*isVarArg=*/false);
|
||||||
LLVMFuncOp abortIfFunc = createCompilerRuntimeFuncDecl(
|
LLVMFuncOp abortIfFunc = createCompilerRuntimeFuncDecl(
|
||||||
"abort_if", abortIfFuncTy, builder, module.getLoc());
|
"abort_if", abortIfFuncTy, builder, module.getLoc());
|
||||||
patterns.insert<TrivialCompilerRuntimeLowering<npcomprt::AbortIfOp>>(
|
patterns.insert<AbortIfOpCompilerRuntimeLowering>(abortIfFunc);
|
||||||
abortIfFunc);
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
auto getExtentFuncTy = LLVMType::getFunctionTy(
|
|
||||||
typeConverter.convertType(builder.getIndexType()).cast<LLVMType>(),
|
|
||||||
{LLVMType::getInt8PtrTy(context), LLVMType::getIntNTy(context, 32)},
|
|
||||||
/*isVarArg=*/false);
|
|
||||||
LLVMFuncOp getExtentFunc = createCompilerRuntimeFuncDecl(
|
|
||||||
"get_extent", getExtentFuncTy, builder, module.getLoc());
|
|
||||||
patterns.insert<TrivialCompilerRuntimeLowering<npcomprt::GetExtentOp>>(
|
|
||||||
getExtentFunc);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto convertFunctionType = [&](FunctionType type) {
|
auto convertFunctionType = [&](FunctionType type) {
|
||||||
|
|
|
@ -9,8 +9,6 @@
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
#include "npcomp/E2E/E2E.h"
|
#include "npcomp/E2E/E2E.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
|
||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
#include "mlir/IR/Verifier.h"
|
#include "mlir/IR/Verifier.h"
|
||||||
|
@ -23,6 +21,13 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::NPCOMP;
|
||||||
|
|
||||||
|
// Get the type used to represent MemRefType `type` on ABI boundaries.
|
||||||
|
// For convenience we do a cast to MemRefType internally.
|
||||||
|
static Type getABIMemrefType(Type type) {
|
||||||
|
return UnrankedMemRefType::get(type.cast<MemRefType>().getElementType(),
|
||||||
|
/*memorySpace=*/0);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Creating module metadata.
|
// Creating module metadata.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -30,10 +35,10 @@ using namespace mlir::NPCOMP;
|
||||||
// Returns true if the function signature can be expressed with the npcomprt
|
// Returns true if the function signature can be expressed with the npcomprt
|
||||||
// ABI.
|
// ABI.
|
||||||
static bool expressibleWithNpcomprtABI(FunctionType type) {
|
static bool expressibleWithNpcomprtABI(FunctionType type) {
|
||||||
// Currently, only tensor types can be exposed at npcomprt ABI boundaries.
|
// Currently, only memref types can be exposed at npcomprt ABI boundaries.
|
||||||
return llvm::all_of(
|
return llvm::all_of(
|
||||||
llvm::concat<const Type>(type.getInputs(), type.getResults()),
|
llvm::concat<const Type>(type.getInputs(), type.getResults()),
|
||||||
[](Type t) { return t.isa<TensorType>(); });
|
[](Type t) { return t.isa<MemRefType>(); });
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult createModuleMetadata(ModuleOp module) {
|
static LogicalResult createModuleMetadata(ModuleOp module) {
|
||||||
|
@ -69,82 +74,6 @@ static LogicalResult createModuleMetadata(ModuleOp module) {
|
||||||
// Dialect conversion.
|
// Dialect conversion.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
namespace {
|
|
||||||
class LowerTensorStoreOp : public OpConversionPattern<TensorStoreOp> {
|
|
||||||
public:
|
|
||||||
using OpConversionPattern::OpConversionPattern;
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(TensorStoreOp op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
TensorStoreOp::Adaptor adaptor(operands);
|
|
||||||
auto memrefType = op.memref().getType().cast<MemRefType>();
|
|
||||||
Value abiMemref = rewriter.create<npcomprt::ToMemrefOp>(
|
|
||||||
op.getLoc(),
|
|
||||||
UnrankedMemRefType::get(memrefType.getElementType(), /*memorySpace=*/0),
|
|
||||||
adaptor.tensor());
|
|
||||||
auto memref =
|
|
||||||
rewriter.create<MemRefCastOp>(op.getLoc(), abiMemref, memrefType);
|
|
||||||
rewriter.replaceOpWithNewOp<linalg::CopyOp>(op, memref, adaptor.memref());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class LowerTensorLoadOp : public OpConversionPattern<TensorLoadOp> {
|
|
||||||
public:
|
|
||||||
using OpConversionPattern::OpConversionPattern;
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(TensorLoadOp op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
TensorLoadOp::Adaptor adaptor(operands);
|
|
||||||
auto abiMemref = rewriter.create<MemRefCastOp>(
|
|
||||||
op.getLoc(), adaptor.memref(),
|
|
||||||
UnrankedMemRefType::get(
|
|
||||||
adaptor.memref().getType().cast<MemRefType>().getElementType(),
|
|
||||||
/*memorySpace=*/0));
|
|
||||||
rewriter.replaceOpWithNewOp<npcomprt::FromMemrefOp>(
|
|
||||||
op, rewriter.getType<npcomprt::TensorType>(), abiMemref);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class LowerShapeOfOp : public OpConversionPattern<shape::ShapeOfOp> {
|
|
||||||
public:
|
|
||||||
using OpConversionPattern::OpConversionPattern;
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(shape::ShapeOfOp op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
shape::ShapeOfOp::Adaptor adaptor(operands);
|
|
||||||
// TODO: For now npcomp only supports ranked tensor types for its shape
|
|
||||||
// lowering, since we don't have a runtime shape struct and lower all shapes
|
|
||||||
// to individual SSA values.
|
|
||||||
auto tensorType = op.arg().getType().cast<RankedTensorType>();
|
|
||||||
SmallVector<Value, 6> extents;
|
|
||||||
for (int i = 0, e = tensorType.getRank(); i < e; i++) {
|
|
||||||
auto ci = rewriter.create<ConstantOp>(op.getLoc(),
|
|
||||||
rewriter.getI32IntegerAttr(i));
|
|
||||||
// TODO: Shouldn't the index type for the output be inferred since
|
|
||||||
// https://reviews.llvm.org/rG31f40f603d0c00b313397196124c5f39090badf0
|
|
||||||
// ?
|
|
||||||
extents.push_back(rewriter.create<npcomprt::GetExtentOp>(
|
|
||||||
op.getLoc(), rewriter.getIndexType(), adaptor.arg(), ci));
|
|
||||||
}
|
|
||||||
auto newShape = rewriter.create<shape::FromExtentsOp>(
|
|
||||||
op.getLoc(), rewriter.getType<shape::ShapeType>(), extents);
|
|
||||||
// TODO: Provide a builder that doesn't require the result type.
|
|
||||||
rewriter.replaceOpWithNewOp<shape::ToExtentTensorOp>(
|
|
||||||
op,
|
|
||||||
RankedTensorType::get({ShapedType::kDynamicSize},
|
|
||||||
rewriter.getIndexType()),
|
|
||||||
newShape);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class LowerGlobalOp : public OpConversionPattern<tcp::GlobalOp> {
|
class LowerGlobalOp : public OpConversionPattern<tcp::GlobalOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -167,10 +96,8 @@ public:
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(tcp::GetGlobalMemrefOp op, ArrayRef<Value> operands,
|
matchAndRewrite(tcp::GetGlobalMemrefOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto abiMemrefType = UnrankedMemRefType::get(
|
|
||||||
op.getType().cast<ShapedType>().getElementType(), /*memorySpace=*/0);
|
|
||||||
auto abiMemref = rewriter.create<npcomprt::GetGlobalOp>(
|
auto abiMemref = rewriter.create<npcomprt::GetGlobalOp>(
|
||||||
op.getLoc(), abiMemrefType, op.global());
|
op.getLoc(), getABIMemrefType(op.getType()), op.global());
|
||||||
// Cast back to the original type.
|
// Cast back to the original type.
|
||||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(op, abiMemref, op.getType());
|
rewriter.replaceOpWithNewOp<MemRefCastOp>(op, abiMemref, op.getType());
|
||||||
return success();
|
return success();
|
||||||
|
@ -178,47 +105,126 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class LowerAssertOp : public OpConversionPattern<AssertOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
AssertOp::Adaptor adaptor(operands);
|
||||||
|
// The npcomprt runtime function aborts if the argument is true, rather than
|
||||||
|
// when it is false as an `assert` does. So negate the predicate (by xor'ing
|
||||||
|
// with 1).
|
||||||
|
auto c1 = rewriter.create<ConstantOp>(
|
||||||
|
op.getLoc(), rewriter.getIntegerAttr(rewriter.getI1Type(),
|
||||||
|
APInt(/*numBits=*/1, /*val=*/1)));
|
||||||
|
Value assertFailed = rewriter.create<XOrOp>(op.getLoc(), adaptor.arg(), c1);
|
||||||
|
rewriter.replaceOpWithNewOp<npcomprt::AbortIfOp>(op, assertFailed,
|
||||||
|
op.msgAttr());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// At ABI bondaries, use !npcomprt.tensor instead of memref.
|
||||||
|
class FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(FuncOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
FunctionType type = op.getType();
|
||||||
|
|
||||||
|
TypeConverter::SignatureConversion entryConversion(type.getNumInputs());
|
||||||
|
if (failed(typeConverter->convertSignatureArgs(type.getInputs(),
|
||||||
|
entryConversion)))
|
||||||
|
return rewriter.notifyMatchFailure(op, "could not convert inputs");
|
||||||
|
SmallVector<Type, 1> newResultTypes;
|
||||||
|
if (failed(typeConverter->convertTypes(type.getResults(), newResultTypes)))
|
||||||
|
return rewriter.notifyMatchFailure(op, "could not convert outputs");
|
||||||
|
|
||||||
|
rewriter.updateRootInPlace(op, [&] {
|
||||||
|
// Update the function type.
|
||||||
|
op.setType(FunctionType::get(entryConversion.getConvertedTypes(),
|
||||||
|
newResultTypes, op.getContext()));
|
||||||
|
// Rewrite the entry block.
|
||||||
|
Block &oldEntry = op.getBody().front();
|
||||||
|
Block &newEntry =
|
||||||
|
*rewriter.applySignatureConversion(&op.getBody(), entryConversion);
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(&newEntry);
|
||||||
|
BlockArgument newArg, oldArg;
|
||||||
|
for (auto newAndOldArg :
|
||||||
|
llvm::zip(newEntry.getArguments(), oldEntry.getArguments())) {
|
||||||
|
std::tie(newArg, oldArg) = newAndOldArg;
|
||||||
|
auto abiMemref = rewriter.create<npcomprt::ToMemrefOp>(
|
||||||
|
op.getLoc(), getABIMemrefType(oldArg.getType()), newArg);
|
||||||
|
auto memref = rewriter.create<MemRefCastOp>(op.getLoc(), abiMemref,
|
||||||
|
oldArg.getType());
|
||||||
|
rewriter.replaceUsesOfBlockArgument(oldArg, memref);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// At the return ABI boundaries, convert to !npcomprt.tensor type.
|
||||||
|
// This pattern is needed to trigger the type conversion mechanics to do a
|
||||||
|
// target materialization.
|
||||||
|
class RewriteReturnOp : public OpConversionPattern<ReturnOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
op.getParentOfType<FuncOp>().dump();
|
||||||
|
rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
static LogicalResult doDialectConversion(ModuleOp module) {
|
static LogicalResult doDialectConversion(ModuleOp module) {
|
||||||
auto *context = module.getContext();
|
auto *context = module.getContext();
|
||||||
|
|
||||||
TypeConverter converter;
|
TypeConverter typeConverter;
|
||||||
converter.addConversion([](TensorType type) {
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
|
typeConverter.addConversion([](MemRefType type) {
|
||||||
return npcomprt::TensorType::get(type.getContext());
|
return npcomprt::TensorType::get(type.getContext());
|
||||||
});
|
});
|
||||||
converter.addConversion([](npcomprt::TensorType type) { return type; });
|
typeConverter.addTargetMaterialization(
|
||||||
|
[](OpBuilder &builder, npcomprt::TensorType type, ValueRange inputs,
|
||||||
|
Location loc) -> Value {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto abiMemref = builder.create<MemRefCastOp>(
|
||||||
|
loc, inputs[0], getABIMemrefType(inputs[0].getType()));
|
||||||
|
return builder.create<npcomprt::FromMemrefOp>(loc, type, abiMemref);
|
||||||
|
});
|
||||||
|
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
|
target.addLegalDialect<npcomprt::NpcomprtDialect>();
|
||||||
|
target.addLegalDialect<StandardOpsDialect>();
|
||||||
|
|
||||||
populateFuncOpTypeConversionPattern(patterns, context, converter);
|
patterns.insert<FuncOpSignatureConversion>(typeConverter, context);
|
||||||
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp op) {
|
target.addDynamicallyLegalOp<FuncOp>(
|
||||||
return converter.isSignatureLegal(op.getType());
|
[&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
|
||||||
});
|
patterns.insert<RewriteReturnOp>(typeConverter, context);
|
||||||
|
target.addDynamicallyLegalOp<ReturnOp>(
|
||||||
patterns.insert<LowerTensorStoreOp>(context);
|
[&](ReturnOp op) { return typeConverter.isLegal(op); });
|
||||||
target.addIllegalOp<TensorStoreOp>();
|
|
||||||
target.addLegalOp<npcomprt::ToMemrefOp>();
|
|
||||||
target.addLegalOp<linalg::CopyOp>();
|
|
||||||
target.addLegalOp<MemRefCastOp>();
|
|
||||||
|
|
||||||
patterns.insert<LowerTensorLoadOp>(context);
|
|
||||||
target.addIllegalOp<TensorLoadOp>();
|
|
||||||
target.addLegalOp<npcomprt::FromMemrefOp>();
|
|
||||||
|
|
||||||
patterns.insert<LowerShapeOfOp>(context);
|
|
||||||
target.addIllegalOp<shape::ShapeOfOp>();
|
|
||||||
target.addLegalOp<ConstantOp>();
|
|
||||||
target.addLegalOp<shape::FromExtentsOp>();
|
|
||||||
target.addLegalOp<shape::ToExtentTensorOp>();
|
|
||||||
target.addLegalOp<npcomprt::GetExtentOp>();
|
|
||||||
|
|
||||||
patterns.insert<LowerGlobalOp>(context);
|
patterns.insert<LowerGlobalOp>(context);
|
||||||
target.addIllegalOp<tcp::GlobalOp>();
|
target.addIllegalOp<tcp::GlobalOp>();
|
||||||
target.addLegalOp<npcomprt::GlobalOp>();
|
|
||||||
|
|
||||||
patterns.insert<LowerGetGlobalMemrefOp>(context);
|
patterns.insert<LowerGetGlobalMemrefOp>(context);
|
||||||
target.addIllegalOp<tcp::GetGlobalMemrefOp>();
|
target.addIllegalOp<tcp::GetGlobalMemrefOp>();
|
||||||
target.addLegalOp<npcomprt::GetGlobalOp>();
|
|
||||||
|
patterns.insert<LowerAssertOp>(context);
|
||||||
|
target.addIllegalOp<AssertOp>();
|
||||||
|
|
||||||
return applyPartialConversion(module, target, patterns);
|
return applyPartialConversion(module, target, patterns);
|
||||||
}
|
}
|
||||||
|
@ -228,7 +234,7 @@ namespace {
|
||||||
// the npcomprt dialect.
|
// the npcomprt dialect.
|
||||||
class LowerToNpcomprtABI : public LowerToNpcomprtABIBase<LowerToNpcomprtABI> {
|
class LowerToNpcomprtABI : public LowerToNpcomprtABIBase<LowerToNpcomprtABI> {
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<linalg::LinalgDialect, npcomprt::NpcomprtDialect>();
|
registry.insert<npcomprt::NpcomprtDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
|
|
|
@ -0,0 +1,112 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "../PassDetail.h"
|
||||||
|
#include "npcomp/E2E/E2E.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
|
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||||
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Pass/PassRegistry.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
||||||
|
#include "npcomp/Conversion/TCPToLinalg/TCPToLinalg.h"
|
||||||
|
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
||||||
|
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// LowerConstantTensorsToMemref
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// This class creates global ops for all tensor-valued constants in the program.
|
||||||
|
// It creates them with pretty names and makes sure that duplicate globals
|
||||||
|
// aren't created.
|
||||||
|
class GlobalCreator {
|
||||||
|
public:
|
||||||
|
explicit GlobalCreator(ModuleOp module);
|
||||||
|
tcp::GlobalOp getGlobalFor(Attribute attr) {
|
||||||
|
assert(globals.find(attr) != globals.end() && "unknown constant attr");
|
||||||
|
return globals[attr];
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
DenseMap<Attribute, tcp::GlobalOp> globals;
|
||||||
|
};
|
||||||
|
|
||||||
|
GlobalCreator::GlobalCreator(ModuleOp module) {
|
||||||
|
// Create a builder without an insertion point. We will insert using the
|
||||||
|
// symbol table to guarantee unique names.
|
||||||
|
OpBuilder globalBuilder(module.getContext());
|
||||||
|
SymbolTable symbolTable(module);
|
||||||
|
module.walk([&](ConstantOp op) {
|
||||||
|
// We only want tensor constants for now.
|
||||||
|
auto type = op.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!type)
|
||||||
|
return;
|
||||||
|
// If we already have a global for this constant value, no need to do
|
||||||
|
// anything else.
|
||||||
|
auto it = globals.find(op.getValue());
|
||||||
|
if (it != globals.end())
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Create a pretty name.
|
||||||
|
SmallString<64> buf;
|
||||||
|
llvm::raw_svector_ostream os(buf);
|
||||||
|
interleave(type.getShape(), os, "x");
|
||||||
|
os << "x" << type.getElementType();
|
||||||
|
|
||||||
|
auto global = globalBuilder.create<tcp::GlobalOp>(
|
||||||
|
op.getLoc(), (Twine("__constant_") + os.str()).str(),
|
||||||
|
op.getValue().cast<ElementsAttr>());
|
||||||
|
symbolTable.insert(global);
|
||||||
|
// The symbol table inserts at the end of the module, but globals are a bit
|
||||||
|
// nicer if they are at the beginning.
|
||||||
|
global.getOperation()->moveBefore(&module.front());
|
||||||
|
globals[op.getValue()] = global;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class LowerConstantTensorsToMemref
|
||||||
|
: public LowerConstantTensorsToMemrefBase<LowerConstantTensorsToMemref> {
|
||||||
|
void runOnOperation() {
|
||||||
|
auto module = getOperation();
|
||||||
|
GlobalCreator globals(module);
|
||||||
|
|
||||||
|
// With the global traversal factored into GlobalCreator, this could in
|
||||||
|
// principle be done with a pattern.
|
||||||
|
module.walk([&](ConstantOp op) {
|
||||||
|
auto type = op.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!type)
|
||||||
|
return;
|
||||||
|
auto global = globals.getGlobalFor(op.getValue());
|
||||||
|
OpBuilder builder(op);
|
||||||
|
auto memrefType = MemRefType::get(type.getShape(), type.getElementType());
|
||||||
|
auto memref = builder.create<tcp::GetGlobalMemrefOp>(
|
||||||
|
op.getLoc(), memrefType, global.getName());
|
||||||
|
Value tensor =
|
||||||
|
builder.create<tcp::MemrefToTensorOp>(op.getLoc(), type, memref);
|
||||||
|
op.replaceAllUsesWith(tensor);
|
||||||
|
op.erase();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
mlir::NPCOMP::createLowerConstantTensorsToMemrefPass() {
|
||||||
|
return std::make_unique<LowerConstantTensorsToMemref>();
|
||||||
|
}
|
|
@ -0,0 +1,327 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "../PassDetail.h"
|
||||||
|
#include "npcomp/E2E/E2E.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
|
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||||
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Pass/PassRegistry.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "mlir/Transforms/InliningUtils.h"
|
||||||
|
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
||||||
|
#include "npcomp/Conversion/TCPToLinalg/TCPToLinalg.h"
|
||||||
|
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
||||||
|
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP;
|
||||||
|
|
||||||
|
static Value allocMemRefForTensor(OpBuilder &builder, Value tensor, Value shape,
|
||||||
|
Location loc) {
|
||||||
|
auto tensorType = tensor.getType().cast<RankedTensorType>();
|
||||||
|
auto memrefType =
|
||||||
|
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||||
|
return builder.create<tcp::AllocMemRefOp>(loc, memrefType, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// TODO: Lower to a "buffer version" of tcp::BroadcastTo instead of directly to
|
||||||
|
// loops.
|
||||||
|
class LowerBroadcastToToLoopsPattern
|
||||||
|
: public OpConversionPattern<tcp::BroadcastToOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(tcp::BroadcastToOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto resultType = op.getType().cast<RankedTensorType>();
|
||||||
|
auto inputType = op.operand().getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
|
auto shapedResults = dyn_cast<tcp::ShapedResultsOp>(op.getParentOp());
|
||||||
|
if (!shapedResults)
|
||||||
|
return rewriter.notifyMatchFailure(op, "parent not tcp.shaped_results");
|
||||||
|
if (op.getOperation()->getResults() !=
|
||||||
|
shapedResults.getBody()->getTerminator()->getOperands())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "only limited forms of tcp.shaped_results allowed");
|
||||||
|
auto resultShape = shapedResults.resultShapes()[0];
|
||||||
|
Value resultMemref =
|
||||||
|
allocMemRefForTensor(rewriter, op.result(), resultShape, op.getLoc());
|
||||||
|
Value inputMemref = operands[0];
|
||||||
|
|
||||||
|
SmallVector<Value, 6> outputExtents;
|
||||||
|
for (int i = 0, e = resultType.getRank(); i < e; i++) {
|
||||||
|
Value dimIndex = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
|
||||||
|
Value outputExtent = rewriter.create<shape::GetExtentOp>(
|
||||||
|
op.getLoc(), rewriter.getIndexType(), resultShape, dimIndex);
|
||||||
|
outputExtents.push_back(outputExtent);
|
||||||
|
}
|
||||||
|
int rankDiff = resultType.getRank() - inputType.getRank();
|
||||||
|
SmallVector<Value, 6> inputDimRequiresBroadcasting;
|
||||||
|
for (int i = 0, e = inputType.getRank(); i < e; i++) {
|
||||||
|
// Calculate the relevant extents.
|
||||||
|
Value inputExtent = rewriter.create<DimOp>(op.getLoc(), op.operand(), i);
|
||||||
|
inputDimRequiresBroadcasting.push_back(
|
||||||
|
rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::ne, inputExtent,
|
||||||
|
outputExtents[rankDiff + i]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
Value c0 = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
|
||||||
|
Value c1 = rewriter.create<ConstantIndexOp>(op.getLoc(), 1);
|
||||||
|
|
||||||
|
SmallVector<Value, 6> inductionVariables;
|
||||||
|
// Create the (perfectly nested) loops.
|
||||||
|
// Loop invariant: At the start of iteration `i`, the rewriter insertion
|
||||||
|
// point is inside `i` nested loops.
|
||||||
|
for (int i = 0, e = resultType.getRank(); i < e; i++) {
|
||||||
|
auto loop = rewriter.create<scf::ForOp>(
|
||||||
|
op.getLoc(), c0, outputExtents[i], c1, ValueRange({}));
|
||||||
|
Block *body = loop.getBody();
|
||||||
|
inductionVariables.push_back(body->getArgument(0));
|
||||||
|
// Leave the insertion point at the beginning of the body.
|
||||||
|
rewriter.setInsertionPointToStart(body);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the inner loop body.
|
||||||
|
// When reading from the input, clamp any indices for dimensions that are
|
||||||
|
// being broadcast.
|
||||||
|
SmallVector<Value, 6> inputIndices;
|
||||||
|
for (int i = 0, e = inputType.getRank(); i < e; i++) {
|
||||||
|
auto c0 = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
|
||||||
|
auto select = rewriter.create<SelectOp>(
|
||||||
|
op.getLoc(), inputDimRequiresBroadcasting[i], c0,
|
||||||
|
inductionVariables[rankDiff + i]);
|
||||||
|
inputIndices.push_back(select);
|
||||||
|
}
|
||||||
|
Value load =
|
||||||
|
rewriter.create<LoadOp>(op.getLoc(), inputMemref, inputIndices);
|
||||||
|
rewriter.create<StoreOp>(op.getLoc(), load, resultMemref,
|
||||||
|
inductionVariables);
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, resultMemref);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class LowerLinalgGenericTensorToMemRef
|
||||||
|
: public OpConversionPattern<linalg::GenericOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(linalg::GenericOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
|
// TODO: Replace this with more generic code operating on named
|
||||||
|
// structured ops too.
|
||||||
|
|
||||||
|
// These checks mirror those in BypassShapes.
|
||||||
|
if (!llvm::all_of(op.getOperandTypes(),
|
||||||
|
[](Type type) { return type.isa<RankedTensorType>(); })) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "all operands must be tensors");
|
||||||
|
}
|
||||||
|
if (!llvm::all_of(op.getResultTypes(),
|
||||||
|
[](Type type) { return type.isa<RankedTensorType>(); })) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "all results must be tensors");
|
||||||
|
}
|
||||||
|
if (!llvm::all_of(op.indexing_maps(), [](Attribute map) {
|
||||||
|
return map.cast<AffineMapAttr>().getValue().isIdentity();
|
||||||
|
})) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "all indexing maps must be identity maps");
|
||||||
|
}
|
||||||
|
if (!llvm::all_of(op.iterator_types(), [](Attribute str) {
|
||||||
|
return str.cast<StringAttr>().getValue() ==
|
||||||
|
getParallelIteratorTypeName();
|
||||||
|
})) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "all iterator types must be 'parallel'");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value, 6> memrefs(operands.begin(), operands.end());
|
||||||
|
SmallVector<Value, 6> resultMemrefs;
|
||||||
|
SmallVector<Value, 6> operandShapes;
|
||||||
|
|
||||||
|
auto shapedResults = dyn_cast<tcp::ShapedResultsOp>(op.getParentOp());
|
||||||
|
if (!shapedResults)
|
||||||
|
return rewriter.notifyMatchFailure(op, "parent not tcp.shaped_results");
|
||||||
|
// TODO: What if there are multiple ops in the tcp.shaped_results region?
|
||||||
|
// The IREE solution is "they have to be fused and create no allocations
|
||||||
|
// ultimately". The non-IREE solution is to just not bypass shapes in the
|
||||||
|
// first place.
|
||||||
|
if (op.getResults() !=
|
||||||
|
shapedResults.getBody()->getTerminator()->getOperands())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "only limited forms of tcp.shaped_results allowed");
|
||||||
|
|
||||||
|
for (auto t : llvm::zip(op.getResults(), shapedResults.resultShapes())) {
|
||||||
|
auto tensor = std::get<0>(t);
|
||||||
|
auto shape = std::get<1>(t);
|
||||||
|
auto memref = allocMemRefForTensor(rewriter, tensor, shape, op.getLoc());
|
||||||
|
memrefs.push_back(memref);
|
||||||
|
resultMemrefs.push_back(memref);
|
||||||
|
}
|
||||||
|
auto newGeneric = rewriter.create<linalg::GenericOp>(
|
||||||
|
op.getLoc(), llvm::None, ValueRange(memrefs), op.getAttrs());
|
||||||
|
newGeneric.region().getBlocks().clear();
|
||||||
|
BlockAndValueMapping mapper;
|
||||||
|
op.region().cloneInto(&newGeneric.region(), mapper);
|
||||||
|
for (auto memref : resultMemrefs) {
|
||||||
|
newGeneric.region().front().addArgument(
|
||||||
|
memref.getType().cast<MemRefType>().getElementType());
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, resultMemrefs);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// TODO: Linalg and shape don't implement the inliner interface, which blocks us
|
||||||
|
// from using mlir::inlineRegion. Locally override it here.
|
||||||
|
class LocallyOverrideLegalityInlinerInterface : public InlinerInterface {
|
||||||
|
public:
|
||||||
|
using InlinerInterface::InlinerInterface;
|
||||||
|
bool isLegalToInline(Operation *op, Region *dest,
|
||||||
|
BlockAndValueMapping &valueMapping) const final {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isLegalToInline(Region *dest, Region *src,
|
||||||
|
BlockAndValueMapping &valueMapping) const final {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// This pass is responsible for lowering regions wrapped by
|
||||||
|
// tcp.shaped_results (which operate on tensors) to memrefs.
|
||||||
|
// This includes any ops potentially contained within them.
|
||||||
|
// This is somewhat analogous to IREE's backend compilation of a single dispatch
|
||||||
|
// region, except that for now, we only allow a single op in the
|
||||||
|
// tcp.shaped_results, and we don't have any notion of "backend" layered at all.
|
||||||
|
// Nor is it clear if we really want any of that here.
|
||||||
|
//
|
||||||
|
// The tcp.shaped_results ops provide precisely the information needed to
|
||||||
|
// allocate output buffers when converting to memref.
|
||||||
|
// For now, this process eliminates the original tcp.shaped_results op since we
|
||||||
|
// don't have any host/device distinction or other structure that would require
|
||||||
|
// retaining that sort of IR structure.
|
||||||
|
//
|
||||||
|
// TODO: Do "shape_of" resolution while still on tensors.
|
||||||
|
// Here we spew out tons of shape_of and rely on dim ops on descriptors to make
|
||||||
|
// it work. The key difference is that we need tcp.shaped_results (or its
|
||||||
|
// successor / something it gets lowered to) to not be IsolatedFromAbove, and
|
||||||
|
// explicitly capture all input tensors along with their shapes. That allows
|
||||||
|
// shape_of ops on inputs to be trivially resolved. Unfortunately, this opens up
|
||||||
|
// the whole "dispatch region formation" can of worms like exists in IREE --
|
||||||
|
// once you have multiple ops inside a "dispatch region", you need to somehow
|
||||||
|
// lower them without allocating intermediate buffers.
|
||||||
|
//
|
||||||
|
// TODO: Don't hardcode the lowering for every op in this one pass.
|
||||||
|
class LowerShapedResultsToMemref
|
||||||
|
: public LowerShapedResultsToMemrefBase<LowerShapedResultsToMemref> {
|
||||||
|
void runOnOperation() {
|
||||||
|
auto func = getOperation();
|
||||||
|
auto *context = &getContext();
|
||||||
|
|
||||||
|
TypeConverter typeConverter;
|
||||||
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
|
typeConverter.addConversion([](RankedTensorType type) -> Type {
|
||||||
|
return MemRefType::get(type.getShape(), type.getElementType());
|
||||||
|
});
|
||||||
|
|
||||||
|
typeConverter.addSourceMaterialization([](OpBuilder &builder,
|
||||||
|
RankedTensorType type,
|
||||||
|
ValueRange inputs, Location loc) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(inputs[0].getType().isa<MemRefType>());
|
||||||
|
return (Value)builder.create<tcp::MemrefToTensorOp>(loc, type, inputs[0]);
|
||||||
|
});
|
||||||
|
typeConverter.addTargetMaterialization([](OpBuilder &builder,
|
||||||
|
MemRefType type,
|
||||||
|
ValueRange inputs, Location loc) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(inputs[0].getType().isa<RankedTensorType>());
|
||||||
|
return (Value)builder.create<tcp::TensorToMemrefOp>(loc, type, inputs[0]);
|
||||||
|
});
|
||||||
|
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
|
||||||
|
ConversionTarget target(*context);
|
||||||
|
|
||||||
|
// The shaped results ops themselves. They have to be legal since we delete
|
||||||
|
// them later after the conversion process.
|
||||||
|
target.addLegalOp<tcp::ShapedResultsOp>();
|
||||||
|
target.addLegalOp<tcp::YieldOp>();
|
||||||
|
// All lowering to buffers involves tcp.alloc_memref ops.
|
||||||
|
target.addLegalOp<tcp::AllocMemRefOp>();
|
||||||
|
// The casting ops are introduced by the type converter, so we should mark
|
||||||
|
// them legal.
|
||||||
|
target.addLegalOp<tcp::MemrefToTensorOp>();
|
||||||
|
target.addLegalOp<tcp::TensorToMemrefOp>();
|
||||||
|
|
||||||
|
patterns.insert<LowerLinalgGenericTensorToMemRef>(typeConverter, context);
|
||||||
|
target.addDynamicallyLegalOp<linalg::GenericOp>([](linalg::GenericOp op) {
|
||||||
|
if (llvm::any_of(op.getOperandTypes(), [](Type type) {
|
||||||
|
return type.isa<RankedTensorType>();
|
||||||
|
})) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (llvm::any_of(op.getResultTypes(), [](Type type) {
|
||||||
|
return type.isa<RankedTensorType>();
|
||||||
|
})) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
|
||||||
|
patterns.insert<LowerBroadcastToToLoopsPattern>(typeConverter, context);
|
||||||
|
target.addIllegalOp<tcp::BroadcastToOp>();
|
||||||
|
target.addLegalDialect<StandardOpsDialect>();
|
||||||
|
target.addLegalDialect<scf::SCFDialect>();
|
||||||
|
target.addLegalOp<shape::GetExtentOp>();
|
||||||
|
|
||||||
|
SmallVector<Operation *, 6> shapedResultsOps;
|
||||||
|
func.walk([&](tcp::ShapedResultsOp op) { shapedResultsOps.push_back(op); });
|
||||||
|
|
||||||
|
if (failed(applyFullConversion(shapedResultsOps, target, patterns)))
|
||||||
|
return signalPassFailure();
|
||||||
|
|
||||||
|
// Now inline the tcp.shaped_results ops.
|
||||||
|
// This can't be done as part of the conversion since conversion visits
|
||||||
|
// ops in preorder, and we need the tcp.shaped_results ops to be present
|
||||||
|
// so that inner ops can get their shape.
|
||||||
|
LocallyOverrideLegalityInlinerInterface interface(context);
|
||||||
|
for (Operation *shapedResultsOp : shapedResultsOps) {
|
||||||
|
auto op = cast<tcp::ShapedResultsOp>(shapedResultsOp);
|
||||||
|
if (failed(inlineRegion(interface, &op.body(), op, ValueRange({}),
|
||||||
|
op.getResults(), /*inlineLoc=*/llvm::None,
|
||||||
|
/*shouldCloneInlinedRegion=*/false))) {
|
||||||
|
op.emitError() << "could not inline body";
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
op.erase();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
|
mlir::NPCOMP::createLowerShapedResultsToMemrefPass() {
|
||||||
|
return std::make_unique<LowerShapedResultsToMemref>();
|
||||||
|
}
|
|
@ -0,0 +1,143 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "../PassDetail.h"
|
||||||
|
#include "npcomp/E2E/E2E.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Pass/PassRegistry.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
||||||
|
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class LowerExtractElementOp : public OpConversionPattern<ExtractElementOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(ExtractElementOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
ExtractElementOp::Adaptor adaptor(operands);
|
||||||
|
rewriter.replaceOpWithNewOp<LoadOp>(op, adaptor.aggregate(),
|
||||||
|
adaptor.indices());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class LowerTensorFromElementsOp
|
||||||
|
: public OpConversionPattern<TensorFromElementsOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(TensorFromElementsOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
int numberOfElements = op.elements().size();
|
||||||
|
auto resultType = MemRefType::get(
|
||||||
|
{numberOfElements}, op.getType().cast<TensorType>().getElementType());
|
||||||
|
Value result = rewriter.create<AllocOp>(op.getLoc(), resultType);
|
||||||
|
for (auto element : llvm::enumerate(op.elements())) {
|
||||||
|
Value index =
|
||||||
|
rewriter.create<ConstantIndexOp>(op.getLoc(), element.index());
|
||||||
|
rewriter.create<StoreOp>(op.getLoc(), element.value(), result, index);
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, {result});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class LowerTensorCastOp : public OpConversionPattern<TensorCastOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto resultType = typeConverter->convertType(op.getType());
|
||||||
|
rewriter.replaceOpWithNewOp<MemRefCastOp>(op, resultType, operands[0]);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class LowerTensorLoadOp : public OpConversionPattern<TensorLoadOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(TensorLoadOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
rewriter.replaceOp(op, operands[0]);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// TODO: Upstream this.
|
||||||
|
class LowerStdToMemref : public LowerStdToMemrefBase<LowerStdToMemref> {
|
||||||
|
void runOnOperation() {
|
||||||
|
auto func = getOperation();
|
||||||
|
auto *context = &getContext();
|
||||||
|
|
||||||
|
TypeConverter typeConverter;
|
||||||
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
|
typeConverter.addConversion([](RankedTensorType type) -> Type {
|
||||||
|
return MemRefType::get(type.getShape(), type.getElementType());
|
||||||
|
});
|
||||||
|
typeConverter.addSourceMaterialization([](OpBuilder &builder,
|
||||||
|
RankedTensorType type,
|
||||||
|
ValueRange inputs, Location loc) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(inputs[0].getType().isa<MemRefType>());
|
||||||
|
return (Value)builder.create<tcp::MemrefToTensorOp>(loc, type, inputs[0]);
|
||||||
|
});
|
||||||
|
typeConverter.addTargetMaterialization([](OpBuilder &builder,
|
||||||
|
MemRefType type,
|
||||||
|
ValueRange inputs, Location loc) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(inputs[0].getType().isa<RankedTensorType>());
|
||||||
|
return (Value)builder.create<tcp::TensorToMemrefOp>(loc, type, inputs[0]);
|
||||||
|
});
|
||||||
|
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
|
||||||
|
ConversionTarget target(*context);
|
||||||
|
|
||||||
|
target.addLegalDialect<StandardOpsDialect>();
|
||||||
|
|
||||||
|
// The casting ops are introduced by the type converter, so they must be
|
||||||
|
// legal.
|
||||||
|
target.addLegalOp<tcp::MemrefToTensorOp>();
|
||||||
|
target.addLegalOp<tcp::TensorToMemrefOp>();
|
||||||
|
|
||||||
|
patterns.insert<LowerExtractElementOp>(typeConverter, context);
|
||||||
|
target.addIllegalOp<ExtractElementOp>();
|
||||||
|
patterns.insert<LowerTensorFromElementsOp>(typeConverter, context);
|
||||||
|
target.addIllegalOp<TensorFromElementsOp>();
|
||||||
|
patterns.insert<LowerTensorCastOp>(typeConverter, context);
|
||||||
|
target.addIllegalOp<TensorCastOp>();
|
||||||
|
patterns.insert<LowerTensorLoadOp>(typeConverter, context);
|
||||||
|
target.addIllegalOp<TensorLoadOp>();
|
||||||
|
|
||||||
|
if (failed(applyPartialConversion(func, target, patterns)))
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
|
mlir::NPCOMP::createLowerStdToMemrefPass() {
|
||||||
|
return std::make_unique<LowerStdToMemref>();
|
||||||
|
}
|
|
@ -0,0 +1,167 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "../PassDetail.h"
|
||||||
|
#include "npcomp/E2E/E2E.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Pass/PassRegistry.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Generic "update the types according to the type converter" patterns.
|
||||||
|
//
|
||||||
|
// TODO: These should be upstreamed. There's nothing specific to memref type
|
||||||
|
// conversion about them.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// This is a type conversion similar to CallOpSignatureConversion.
|
||||||
|
class LowerIfOpTypes : public OpConversionPattern<scf::IfOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(scf::IfOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
SmallVector<Type, 6> newResultTypes;
|
||||||
|
for (auto type : op.getResultTypes()) {
|
||||||
|
Type newType = typeConverter->convertType(type);
|
||||||
|
if (!newType)
|
||||||
|
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
|
||||||
|
newResultTypes.push_back(newType);
|
||||||
|
}
|
||||||
|
rewriter.updateRootInPlace(op, [&] {
|
||||||
|
for (auto t : llvm::zip(op.getResults(), newResultTypes))
|
||||||
|
std::get<0>(t).setType(std::get<1>(t));
|
||||||
|
});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// This is a type conversion similar to CallOpSignatureConversion.
|
||||||
|
class LowerSelectOpTypes : public OpConversionPattern<SelectOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
SelectOp::Adaptor adaptor(operands);
|
||||||
|
rewriter.updateRootInPlace(
|
||||||
|
op, [&] { op.getResult().setType(adaptor.true_value().getType()); });
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Further lowerings.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class LowerTensorToMemrefOp
|
||||||
|
: public OpConversionPattern<tcp::TensorToMemrefOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(tcp::TensorToMemrefOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
tcp::TensorToMemrefOp::Adaptor adaptor(operands);
|
||||||
|
rewriter.replaceOp(op, adaptor.tensor());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class LowerMemrefToTensorOp
|
||||||
|
: public OpConversionPattern<tcp::MemrefToTensorOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(tcp::MemrefToTensorOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
tcp::MemrefToTensorOp::Adaptor adaptor(operands);
|
||||||
|
rewriter.replaceOp(op, op.memref());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// The pass.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class LowerStructuralToMemref
|
||||||
|
: public LowerStructuralToMemrefBase<LowerStructuralToMemref> {
|
||||||
|
void runOnOperation() {
|
||||||
|
auto func = getOperation();
|
||||||
|
auto *context = &getContext();
|
||||||
|
|
||||||
|
// TODO: move these common type conversions to somewhere common.
|
||||||
|
TypeConverter typeConverter;
|
||||||
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
|
typeConverter.addConversion([](RankedTensorType type) -> Type {
|
||||||
|
return MemRefType::get(type.getShape(), type.getElementType());
|
||||||
|
});
|
||||||
|
|
||||||
|
typeConverter.addSourceMaterialization([](OpBuilder &builder,
|
||||||
|
RankedTensorType type,
|
||||||
|
ValueRange inputs, Location loc) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(inputs[0].getType().isa<MemRefType>());
|
||||||
|
return (Value)builder.create<tcp::MemrefToTensorOp>(loc, type, inputs[0]);
|
||||||
|
});
|
||||||
|
typeConverter.addTargetMaterialization([](OpBuilder &builder,
|
||||||
|
MemRefType type,
|
||||||
|
ValueRange inputs, Location loc) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(inputs[0].getType().isa<RankedTensorType>());
|
||||||
|
return (Value)builder.create<tcp::TensorToMemrefOp>(loc, type, inputs[0]);
|
||||||
|
});
|
||||||
|
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
|
||||||
|
ConversionTarget target(*context);
|
||||||
|
|
||||||
|
// All ops whose results are not tensor types are legal.
|
||||||
|
target.markUnknownOpDynamicallyLegal([](Operation *op) {
|
||||||
|
return llvm::all_of(op->getResultTypes(),
|
||||||
|
[](Type type) { return !type.isa<TensorType>(); });
|
||||||
|
});
|
||||||
|
|
||||||
|
populateFuncOpTypeConversionPattern(patterns, context, typeConverter);
|
||||||
|
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp op) {
|
||||||
|
return typeConverter.isSignatureLegal(op.getType()) &&
|
||||||
|
typeConverter.isLegal(&op.getBody());
|
||||||
|
});
|
||||||
|
|
||||||
|
patterns.insert<LowerSelectOpTypes>(typeConverter, context);
|
||||||
|
patterns.insert<LowerIfOpTypes>(typeConverter, context);
|
||||||
|
patterns.insert<LowerTensorToMemrefOp>(typeConverter, context);
|
||||||
|
patterns.insert<LowerMemrefToTensorOp>(typeConverter, context);
|
||||||
|
target.addIllegalOp<tcp::TensorToMemrefOp>();
|
||||||
|
|
||||||
|
if (failed(applyFullConversion(func, target, patterns)))
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
|
mlir::NPCOMP::createLowerStructuralToMemrefPass() {
|
||||||
|
return std::make_unique<LowerStructuralToMemref>();
|
||||||
|
}
|
|
@ -20,19 +20,13 @@
|
||||||
|
|
||||||
using namespace npcomprt;
|
using namespace npcomprt;
|
||||||
|
|
||||||
extern "C" void __npcomp_compiler_rt_abort_if(bool b) {
|
extern "C" void __npcomp_compiler_rt_abort_if(bool b, const char *msg) {
|
||||||
if (b) {
|
if (b) {
|
||||||
std::fprintf(stderr, "NPCOMP: aborting!\n");
|
std::fprintf(stderr, "NPCOMP: aborting: %s\n", msg);
|
||||||
std::exit(1);
|
std::exit(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" std::size_t __npcomp_compiler_rt_get_extent(Tensor *tensor,
|
|
||||||
std::int32_t dim) {
|
|
||||||
assert(dim < tensor->getRank() && "dim out of bounds!");
|
|
||||||
return tensor->getExtent(dim);
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// These definitions are based on the ones in
|
// These definitions are based on the ones in
|
||||||
// `mlir/ExecutionEngine/CRunnerUtils.h` and the layouts need to be kept in
|
// `mlir/ExecutionEngine/CRunnerUtils.h` and the layouts need to be kept in
|
||||||
|
|
|
@ -1,9 +1,21 @@
|
||||||
// RUN: npcomp-opt <%s -convert-tcf-to-tcp | FileCheck %s --dump-input=fail
|
// RUN: npcomp-opt <%s -convert-tcf-to-tcp | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
// CHECK-LABEL: func @f
|
// CHECK-LABEL: func @tcf_add(
|
||||||
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
// CHECK-SAME: %[[LHS:.*]]: tensor<?xf32>,
|
||||||
// Just the lightest sanity check.
|
// CHECK-SAME: %[[RHS:.*]]: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
// CHECK: tcp.add
|
// CHECK: %[[LHSSHAPE:.*]] = shape.shape_of %[[LHS]]
|
||||||
|
// CHECK: %[[RHSSHAPE:.*]] = shape.shape_of %[[RHS]]
|
||||||
|
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[LHSSHAPE]], %[[RHSSHAPE]]
|
||||||
|
// CHECK: %[[RET:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
|
||||||
|
// CHECK: %[[RESULTSHAPE:.*]] = shape.broadcast %[[LHSSHAPE]], %[[RHSSHAPE]]
|
||||||
|
// CHECK: %[[LHSBCAST:.*]] = "tcp.broadcast_to"(%[[LHS]], %[[RESULTSHAPE]])
|
||||||
|
// CHECK: %[[RHSBCAST:.*]] = "tcp.broadcast_to"(%[[RHS]], %[[RESULTSHAPE]])
|
||||||
|
// CHECK: %[[ADD:.*]] = "tcp.add"(%[[LHSBCAST]], %[[RHSBCAST]])
|
||||||
|
// CHECK: shape.assuming_yield %[[ADD]] : tensor<?xf32>
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: return %[[RET:.*]] : tensor<?xf32>
|
||||||
|
// CHECK: }
|
||||||
|
func @tcf_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
%0 = "tcf.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
%0 = "tcf.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
return %0 : tensor<?xf32>
|
return %0 : tensor<?xf32>
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
// RUN: npcomp-opt -canonicalize <%s | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @tensor_to_memref
|
||||||
|
func @tensor_to_memref_fold(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||||
|
// CHECK-NEXT: return %arg0 : memref<?xf32>
|
||||||
|
%0 = tcp.memref_to_tensor %arg0 : memref<?xf32> -> tensor<?xf32>
|
||||||
|
%1 = tcp.tensor_to_memref %0 : tensor<?xf32> -> memref<?xf32>
|
||||||
|
return %1 : memref<?xf32>
|
||||||
|
}
|
|
@ -29,3 +29,14 @@ func @f() {
|
||||||
tcp.get_global_memref @g : memref<2xi8>
|
tcp.get_global_memref @g : memref<2xi8>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @g(%arg0: tensor<?x?xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
|
||||||
|
// expected-error @+1 {{number of operands must equal number of results}}
|
||||||
|
%add = tcp.shaped_results %arg1, %arg1 {
|
||||||
|
%0 = "tcp.add"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
tcp.yield %0 : tensor<?x?xf32>
|
||||||
|
} : tensor<?xindex>, tensor<?xindex> -> tensor<?x?xf32>
|
||||||
|
return %add : tensor<?x?xf32>
|
||||||
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt <%s | FileCheck %s --dump-input=fail
|
// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
// CHECK-LABEL: tcp.global @foo dense<0.0{{.*}}> : tensor<10xf32>
|
// CHECK-LABEL: tcp.global @foo dense<0.0{{.*}}> : tensor<10xf32>
|
||||||
tcp.global @foo dense<0.0> : tensor<10xf32>
|
tcp.global @foo dense<0.0> : tensor<10xf32>
|
||||||
|
@ -9,3 +9,18 @@ func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: i32) {
|
||||||
%1 = tcp.get_global_memref @foo : memref<10xf32>
|
%1 = tcp.get_global_memref @foo : memref<10xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @g
|
||||||
|
// CHECK-NEXT: %[[RET:.*]] = tcp.shaped_results %arg1 {
|
||||||
|
// CHECK-NEXT: %[[VAL:.*]] =
|
||||||
|
// CHECK-NEXT: tcp.yield %[[VAL]] : tensor<?x?xf32>
|
||||||
|
// CHECK-NEXT: } : tensor<?xindex> -> tensor<?x?xf32>
|
||||||
|
// CHECK-NEXT: return %[[RET]] : tensor<?x?xf32>
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
func @g(%arg0: tensor<?x?xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
|
||||||
|
%add = tcp.shaped_results %arg1 {
|
||||||
|
%0 = "tcp.add"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
tcp.yield %0 : tensor<?x?xf32>
|
||||||
|
} : tensor<?xindex> -> tensor<?x?xf32>
|
||||||
|
return %add : tensor<?x?xf32>
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
// RUN: npcomp-opt -bypass-shapes <%s | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
#map0 = affine_map<(d0) -> (d0)>
|
||||||
|
// CHECK-LABEL: func @linalg_generic
|
||||||
|
func @linalg_generic(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
|
// This is an elementwise linalg op, so output shape is equal to input shape.
|
||||||
|
// CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0
|
||||||
|
// CHECK: tcp.shaped_results %[[SHAPE]]
|
||||||
|
%0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} %arg0, %arg1 {
|
||||||
|
^bb0(%arg2: f32, %arg3: f32):
|
||||||
|
%8 = addf %arg2, %arg3 : f32
|
||||||
|
linalg.yield %8 : f32
|
||||||
|
}: tensor<?xf32>, tensor<?xf32> -> tensor<?xf32>
|
||||||
|
return %0 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @tcp_broadcast_to
|
||||||
|
func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) {
|
||||||
|
// CHECK: %0 = tcp.shaped_results %arg1
|
||||||
|
%0 = "tcp.broadcast_to"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xindex>) -> tensor<?x?xf32>
|
||||||
|
return
|
||||||
|
}
|
|
@ -10,10 +10,3 @@ func @rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
return %0 : tensor<?xf32>
|
return %0 : tensor<?xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
|
||||||
// CHECK-LABEL: func @multiple_ops
|
|
||||||
func @multiple_ops(%arg0: tensor<f32>, %arg1: tensor<?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
|
||||||
%0 = "tcf.add"(%arg1, %arg2) : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
|
||||||
%1 = "tcf.add"(%arg0, %0) : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
|
||||||
return %1 : tensor<?x?xf32>
|
|
||||||
}
|
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: func @global_add
|
// CHECK-LABEL: func @global_add
|
||||||
func @global_add() -> tensor<2xf32> attributes {iree.module.export} {
|
func @global_add() -> tensor<2xf32> {
|
||||||
%cst = constant dense<[3.000000e+00, 4.000000e+00]> : tensor<2xf32>
|
%cst = constant dense<[3.000000e+00, 4.000000e+00]> : tensor<2xf32>
|
||||||
%cst_0 = constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>
|
%cst_0 = constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>
|
||||||
%0 = "tcf.add"(%cst, %cst_0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
%0 = "tcf.add"(%cst, %cst_0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -split-input-file -lower-constant-tensors-to-memrefs <%s | FileCheck %s
|
// RUN: npcomp-opt -split-input-file -lower-constant-tensors-to-memref <%s | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: module {
|
// CHECK-LABEL: module {
|
||||||
// We check the debug name too since we put some effort into making that readable.
|
// We check the debug name too since we put some effort into making that readable.
|
||||||
|
@ -7,7 +7,7 @@
|
||||||
// CHECK: func @basic
|
// CHECK: func @basic
|
||||||
func @basic() -> tensor<3x4xf32> {
|
func @basic() -> tensor<3x4xf32> {
|
||||||
// CHECK: %[[MEMREF:.*]] = tcp.get_global_memref @__constant_3x4xf32 : memref<3x4xf32>
|
// CHECK: %[[MEMREF:.*]] = tcp.get_global_memref @__constant_3x4xf32 : memref<3x4xf32>
|
||||||
// CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF]]
|
// CHECK: %[[TENSOR:.*]] = tcp.memref_to_tensor %[[MEMREF]]
|
||||||
%0 = constant dense<7.0> : tensor<3x4xf32>
|
%0 = constant dense<7.0> : tensor<3x4xf32>
|
||||||
// CHECK: return %[[TENSOR]]
|
// CHECK: return %[[TENSOR]]
|
||||||
return %0 : tensor<3x4xf32>
|
return %0 : tensor<3x4xf32>
|
|
@ -1,15 +0,0 @@
|
||||||
// RUN: npcomp-opt -lower-linalg-tensor-to-memref <%s | FileCheck %s --dump-input=fail
|
|
||||||
#map0 = affine_map<(d0) -> (d0)>
|
|
||||||
// CHECK-LABEL: func @f
|
|
||||||
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
|
||||||
// CHECK-DAG: %[[LHS:.+]] = tcp.alloc_memref
|
|
||||||
// CHECK-DAG: %[[RHS:.+]] = tcp.alloc_memref
|
|
||||||
// CHECK-DAG: %[[DST:.+]] = tcp.alloc_memref
|
|
||||||
// CHECK: linalg.generic{{.*}} %[[LHS]], %[[RHS]], %[[DST]]
|
|
||||||
%0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} %arg0, %arg1 {
|
|
||||||
^bb0(%arg2: f32, %arg3: f32):
|
|
||||||
%8 = addf %arg2, %arg3 : f32
|
|
||||||
linalg.yield %8 : f32
|
|
||||||
}: tensor<?xf32>, tensor<?xf32> -> tensor<?xf32>
|
|
||||||
return %0 : tensor<?xf32>
|
|
||||||
}
|
|
|
@ -1,50 +0,0 @@
|
||||||
// RUN: npcomp-opt -lower-ranked-shapes <%s -split-input-file -verify-diagnostics | FileCheck %s --dump-input=fail
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @broadcast_rank2_rank1
|
|
||||||
func @broadcast_rank2_rank1(%arg0: index, %arg1: index, %arg2: index) -> (index, index) {
|
|
||||||
// CHECK-NOT: shape.broadcast
|
|
||||||
// CHECK-NOT: tcp.get_extent
|
|
||||||
// CHECK-NOT: shape.from_extents
|
|
||||||
%0 = shape.from_extents %arg0, %arg1
|
|
||||||
%1 = shape.to_extent_tensor %0 : !shape.shape -> tensor<?xindex>
|
|
||||||
%2 = shape.from_extents %arg2
|
|
||||||
%3 = shape.to_extent_tensor %2 : !shape.shape -> tensor<?xindex>
|
|
||||||
%4 = "shape.broadcast"(%1, %3) : (tensor<?xindex>, tensor<?xindex>) -> !shape.shape
|
|
||||||
%5 = shape.to_extent_tensor %4 : !shape.shape -> tensor<?xindex>
|
|
||||||
%c0 = constant 0 : index
|
|
||||||
%c1 = constant 1 : index
|
|
||||||
%e0 = shape.get_extent %5, %c0 : tensor<?xindex>, index -> index
|
|
||||||
%e1 = shape.get_extent %5, %c1 : tensor<?xindex>, index -> index
|
|
||||||
return %e0, %e1 : index, index
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
// CHECK-LABEL: func @erase_stray_shape_ops
|
|
||||||
func @erase_stray_shape_ops(%arg0: index) {
|
|
||||||
// CHECK-NOT: tcp.shape_observe_error
|
|
||||||
// CHECK-NOT: shape.from_extents
|
|
||||||
%0 = shape.from_extents %arg0
|
|
||||||
"tcp.shape_observe_error"(%0) : (!shape.shape) -> none
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
func @cannot_erase_stray_shape_ops() -> !shape.shape {
|
|
||||||
// expected-error @+1 {{could not be eliminated}}
|
|
||||||
%0 = shape.from_extents
|
|
||||||
return %0 : !shape.shape
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
// TODO: Remove this as it is now just testing shape and std ops.
|
|
||||||
// CHECK-LABEL: func @const_shape
|
|
||||||
func @const_shape() -> index {
|
|
||||||
// CHECK-NOT: shape.const_shape
|
|
||||||
%0 = shape.const_shape [] : tensor<?xindex>
|
|
||||||
%1 = shape.const_shape [7] : tensor<?xindex>
|
|
||||||
%2 = constant 0 : index
|
|
||||||
%3 = shape.get_extent %1, %2 : tensor<?xindex>, index -> index
|
|
||||||
// CHECK: %[[C7:.*]] = constant 7 : index
|
|
||||||
// CHECK: return %[[C7]]
|
|
||||||
return %3 : index
|
|
||||||
}
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
// RUN: npcomp-opt -lower-shape-constraints <%s | FileCheck %s
|
||||||
|
|
||||||
|
func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
|
||||||
|
%witness = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
|
||||||
|
return %witness : !shape.witness
|
||||||
|
}
|
||||||
|
// There's not very much useful to check here other than pasting the output.
|
||||||
|
// CHECK-LABEL: func @cstr_broadcastable(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?xindex>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?xindex>) -> !shape.witness {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = constant 0 : index
|
||||||
|
// CHECK: %[[VAL_3:.*]] = constant 1 : index
|
||||||
|
// CHECK: %[[VAL_4:.*]] = shape.const_witness true
|
||||||
|
// CHECK: %[[VAL_5:.*]] = dim %[[VAL_0]], %[[VAL_2]] : tensor<?xindex>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = dim %[[VAL_1]], %[[VAL_2]] : tensor<?xindex>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = cmpi "ule", %[[VAL_5]], %[[VAL_6]] : index
|
||||||
|
// CHECK: %[[VAL_8:.*]]:4 = scf.if %[[VAL_7]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
|
||||||
|
// CHECK: scf.yield %[[VAL_5]], %[[VAL_0]], %[[VAL_6]], %[[VAL_1]] : index, tensor<?xindex>, index, tensor<?xindex>
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: scf.yield %[[VAL_6]], %[[VAL_1]], %[[VAL_5]], %[[VAL_0]] : index, tensor<?xindex>, index, tensor<?xindex>
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: %[[VAL_9:.*]] = subi %[[VAL_10:.*]]#2, %[[VAL_10]]#0 : index
|
||||||
|
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]]#2 step %[[VAL_3]] {
|
||||||
|
// CHECK: %[[VAL_12:.*]] = extract_element %[[VAL_10]]#3{{\[}}%[[VAL_11]]] : tensor<?xindex>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = subi %[[VAL_11]], %[[VAL_9]] : index
|
||||||
|
// CHECK: %[[VAL_14:.*]] = extract_element %[[VAL_10]]#1{{\[}}%[[VAL_13]]] : tensor<?xindex>
|
||||||
|
// CHECK: %[[VAL_15:.*]] = cmpi "eq", %[[VAL_12]], %[[VAL_3]] : index
|
||||||
|
// CHECK: %[[VAL_16:.*]] = cmpi "eq", %[[VAL_14]], %[[VAL_3]] : index
|
||||||
|
// CHECK: %[[VAL_17:.*]] = cmpi "eq", %[[VAL_12]], %[[VAL_14]] : index
|
||||||
|
// CHECK: %[[VAL_18:.*]] = or %[[VAL_15]], %[[VAL_16]] : i1
|
||||||
|
// CHECK: %[[VAL_19:.*]] = or %[[VAL_17]], %[[VAL_18]] : i1
|
||||||
|
// CHECK: assert %[[VAL_19]], "invalid broadcast"
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: return %[[VAL_4]] : !shape.witness
|
||||||
|
// CHECK: }
|
||||||
|
|
||||||
|
// Check that `shape.assuming` is eliminated after we create the error handling code.
|
||||||
|
// CHECK-LABEL: func @assuming
|
||||||
|
func @assuming(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> tensor<2xf32> {
|
||||||
|
%witness = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
|
||||||
|
// CHECK-NOT: shape.assuming
|
||||||
|
// CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<2xf32>
|
||||||
|
%0 = shape.assuming %witness -> tensor<2xf32> {
|
||||||
|
%c = constant dense<0.0> : tensor<2xf32>
|
||||||
|
shape.assuming_yield %c : tensor<2xf32>
|
||||||
|
}
|
||||||
|
// CHECK: return %[[CST]]
|
||||||
|
return %0 : tensor<2xf32>
|
||||||
|
}
|
|
@ -0,0 +1,37 @@
|
||||||
|
// RUN: npcomp-opt -lower-shaped-results-to-memref <%s -split-input-file | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
#map0 = affine_map<(d0) -> (d0)>
|
||||||
|
// CHECK-LABEL: func @linalg_generic
|
||||||
|
func @linalg_generic(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xindex>) -> tensor<?xf32> {
|
||||||
|
// CHECK: %[[LHS:.*]] = tcp.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
|
||||||
|
// CHECK: %[[RHS:.*]] = tcp.tensor_to_memref %arg1 : tensor<?xf32> -> memref<?xf32>
|
||||||
|
// CHECK: %[[DST:.*]] = tcp.alloc_memref %arg2 : memref<?xf32>
|
||||||
|
// CHECK: linalg.generic {{.*}} %[[LHS]], %[[RHS]], %[[DST]]
|
||||||
|
// CHECK-NOT: tcp.shaped_results
|
||||||
|
%0 = tcp.shaped_results %arg2 {
|
||||||
|
%0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} %arg0, %arg1 {
|
||||||
|
^bb0(%arg3: f32, %arg4: f32):
|
||||||
|
%8 = addf %arg3, %arg4 : f32
|
||||||
|
linalg.yield %8 : f32
|
||||||
|
} : tensor<?xf32>, tensor<?xf32> -> tensor<?xf32>
|
||||||
|
tcp.yield %0 : tensor<?xf32>
|
||||||
|
} : tensor<?xindex> -> tensor<?xf32>
|
||||||
|
return %0 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @tcp_broadcast_to
|
||||||
|
func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
|
||||||
|
// Check for two nested loops, but don't look at more detail for now.
|
||||||
|
// TODO: This pass should not create loops. Instead it should create a
|
||||||
|
// buffer version of tcp.broadcast_to
|
||||||
|
// CHECK: scf.for
|
||||||
|
// CHECK: scf.for
|
||||||
|
// CHECK-NOT: tcp.shaped_results
|
||||||
|
%0 = tcp.shaped_results %arg1 {
|
||||||
|
%0 = "tcp.broadcast_to"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xindex>) -> tensor<?x?xf32>
|
||||||
|
tcp.yield %0 : tensor<?x?xf32>
|
||||||
|
} : tensor<?xindex> -> tensor<?x?xf32>
|
||||||
|
return %0 : tensor<?x?xf32>
|
||||||
|
}
|
|
@ -0,0 +1,50 @@
|
||||||
|
// RUN: npcomp-opt -lower-std-to-memref <%s -split-input-file | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
// If we also ran -lower-structural-to-memref, we could avoid all this casting
|
||||||
|
// stuff and make the output of the test cases cleaner, but we choose not to do
|
||||||
|
// that to make the test actually check what happens in practice.
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @extract_element
|
||||||
|
// CHECK: %[[MEMREF:.*]] = tcp.tensor_to_memref %arg0
|
||||||
|
// CHECK: %[[RET:.*]] = load %[[MEMREF]][%arg1] : memref<?xf32>
|
||||||
|
// CHECK: return %[[RET]] : f32
|
||||||
|
func @extract_element(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
|
||||||
|
%0 = extract_element %arg0[%arg1] : tensor<?xf32>
|
||||||
|
return %0 : f32
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @tensor_from_elements(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: index,
|
||||||
|
// CHECK-SAME: %[[ARG1:.*]]: index) -> tensor<2xindex> {
|
||||||
|
// CHECK: %[[MEMREF:.*]] = alloc()
|
||||||
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
|
// CHECK: store %[[ARG0]], %[[MEMREF]][%[[C0]]]
|
||||||
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
|
// CHECK: store %[[ARG1]], %[[MEMREF]][%[[C1]]]
|
||||||
|
// CHECK: %[[RET:.*]] = tcp.memref_to_tensor %[[MEMREF]]
|
||||||
|
// CHECK: return %[[RET]] : tensor<2xindex>
|
||||||
|
func @tensor_from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
|
||||||
|
%0 = tensor_from_elements %arg0, %arg1 : tensor<2xindex>
|
||||||
|
return %0 : tensor<2xindex>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @tensor_cast(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xindex>) -> tensor<2xindex> {
|
||||||
|
// CHECK: %[[MEMREF:.*]] = tcp.tensor_to_memref %[[ARG0]] : tensor<?xindex> -> memref<?xindex>
|
||||||
|
// CHECK: %[[CASTED:.*]] = memref_cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
|
||||||
|
// CHECK: %[[RET:.*]] = tcp.memref_to_tensor %[[CASTED]] : memref<2xindex> -> tensor<2xindex>
|
||||||
|
// CHECK: return %[[RET]] : tensor<2xindex>
|
||||||
|
func @tensor_cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
|
||||||
|
%0 = tensor_cast %arg0 : tensor<?xindex> to tensor<2xindex>
|
||||||
|
return %0 : tensor<2xindex>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @tensor_load(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: memref<?xindex>) -> tensor<?xindex> {
|
||||||
|
// CHECK: %[[RET:.*]] = tcp.memref_to_tensor %[[ARG0]] : memref<?xindex> -> tensor<?xindex>
|
||||||
|
// CHECK: return %[[RET]] : tensor<?xindex>
|
||||||
|
func @tensor_load(%arg0: memref<?xindex>) -> tensor<?xindex> {
|
||||||
|
%0 = tensor_load %arg0 : memref<?xindex>
|
||||||
|
return %0 : tensor<?xindex>
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,87 @@
|
||||||
|
// RUN: npcomp-opt -lower-structural-to-memref <%s | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
// Basic cases.
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @identity(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||||
|
// CHECK-NEXT: return %arg0 : memref<?xf32>
|
||||||
|
func @identity(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
|
return %arg0 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @bb_arg(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||||
|
// CHECK-NEXT: br ^bb1(%arg0 : memref<?xf32>)
|
||||||
|
// CHECK-NEXT: ^bb1(%[[BBARG:.*]]: memref<?xf32>):
|
||||||
|
// CHECK-NEXT: return %[[BBARG]] : memref<?xf32>
|
||||||
|
func @bb_arg(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
|
br ^bb1(%arg0: tensor<?xf32>)
|
||||||
|
^bb1(%bbarg: tensor<?xf32>):
|
||||||
|
return %bbarg : tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @select(%arg0: i1, %arg1: memref<?xf32>, %arg2: memref<?xf32>) -> memref<?xf32> {
|
||||||
|
// CHECK-NEXT: %[[RET:.*]] = select %arg0, %arg1, %arg2 : memref<?xf32>
|
||||||
|
// CHECK-NEXT: return %[[RET]] : memref<?xf32>
|
||||||
|
func @select(%pred: i1, %true_val: tensor<?xf32>, %false_val: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
|
%0 = std.select %pred, %true_val, %false_val : tensor<?xf32>
|
||||||
|
return %0 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @if(%arg0: i1, %arg1: memref<?xf32>, %arg2: memref<?xf32>) -> memref<?xf32> {
|
||||||
|
// CHECK-NEXT: %[[RET:.*]] = scf.if %arg0 -> (memref<?xf32>) {
|
||||||
|
// CHECK-NEXT: scf.yield %arg1 : memref<?xf32>
|
||||||
|
// CHECK-NEXT: } else {
|
||||||
|
// CHECK-NEXT: scf.yield %arg2 : memref<?xf32>
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: return %[[RET]] : memref<?xf32>
|
||||||
|
func @if(%pred: i1, %true_val: tensor<?xf32>, %false_val: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
|
%0 = scf.if %pred -> (tensor<?xf32>) {
|
||||||
|
scf.yield %true_val : tensor<?xf32>
|
||||||
|
} else {
|
||||||
|
scf.yield %false_val : tensor<?xf32>
|
||||||
|
}
|
||||||
|
return %0 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Test the interactions with materializations.
|
||||||
|
// Note: this pass never actually expects IR with memref argument types.
|
||||||
|
// We use memref-typed arguments purely for testing convenience.
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @identity_materializations(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||||
|
// CHECK-NEXT: return %arg0 : memref<?xf32>
|
||||||
|
func @identity_materializations(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
|
%0 = tcp.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
|
||||||
|
%1 = tcp.memref_to_tensor %0 : memref<?xf32> -> tensor<?xf32>
|
||||||
|
return %1 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @if_materializations(%arg0: i1, %arg1: memref<?xf32>, %arg2: memref<?xf32>) -> memref<?xf32> {
|
||||||
|
// CHECK-NEXT: %[[RET:.*]] = scf.if %arg0 -> (memref<?xf32>) {
|
||||||
|
// CHECK-NEXT: scf.yield %arg1 : memref<?xf32>
|
||||||
|
// CHECK-NEXT: } else {
|
||||||
|
// CHECK-NEXT: scf.yield %arg2 : memref<?xf32>
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: return %[[RET]] : memref<?xf32>
|
||||||
|
func @if_materializations(%pred: i1, %true_val_memref: memref<?xf32>, %false_val: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
|
%true_val = tcp.memref_to_tensor %true_val_memref : memref<?xf32> -> tensor<?xf32>
|
||||||
|
%0 = scf.if %pred -> (tensor<?xf32>) {
|
||||||
|
scf.yield %true_val : tensor<?xf32>
|
||||||
|
} else {
|
||||||
|
scf.yield %false_val : tensor<?xf32>
|
||||||
|
}
|
||||||
|
return %0 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @elide_memref_to_tensor(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||||
|
// CHECK-NEXT: return %arg0 : memref<?xf32>
|
||||||
|
func @elide_memref_to_tensor(%arg0: memref<?xf32>) -> tensor<?xf32> {
|
||||||
|
%0 = tcp.memref_to_tensor %arg0 : memref<?xf32> -> tensor<?xf32>
|
||||||
|
return %0 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @elide_tensor_to_memref(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||||
|
// CHECK-NEXT: return %arg0 : memref<?xf32>
|
||||||
|
func @elide_tensor_to_memref(%arg0: tensor<?xf32>) -> memref<?xf32> {
|
||||||
|
%0 = tcp.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
|
||||||
|
return %0 : memref<?xf32>
|
||||||
|
}
|
|
@ -1,8 +1,7 @@
|
||||||
// RUN: npcomp-opt -e2e-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail
|
// RUN: npcomp-opt -e2e-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.func @malloc(!llvm.i64) -> !llvm.ptr<i8>
|
// CHECK-LABEL: llvm.func @malloc(!llvm.i64) -> !llvm.ptr<i8>
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1)
|
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr<i8>)
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_get_extent(!llvm.ptr<i8>, !llvm.i32) -> !llvm.i64
|
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
|
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
|
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
|
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||||
|
|
|
@ -14,8 +14,7 @@
|
||||||
// CHECK: llvm.store %[[VAL_6]], %[[VAL_9]] : !llvm.ptr<ptr<i8>>
|
// CHECK: llvm.store %[[VAL_6]], %[[VAL_9]] : !llvm.ptr<ptr<i8>>
|
||||||
// CHECK: llvm.return
|
// CHECK: llvm.return
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1)
|
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr<i8>)
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_get_extent(!llvm.ptr<i8>, !llvm.i32) -> !llvm.i64
|
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
|
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
|
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
|
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||||
|
@ -112,8 +111,7 @@ func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
|
||||||
// CHECK: llvm.call @inputs1results0(%[[VAL_5]]) : (!llvm.ptr<i8>) -> ()
|
// CHECK: llvm.call @inputs1results0(%[[VAL_5]]) : (!llvm.ptr<i8>) -> ()
|
||||||
// CHECK: llvm.return
|
// CHECK: llvm.return
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1)
|
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr<i8>)
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_get_extent(!llvm.ptr<i8>, !llvm.i32) -> !llvm.i64
|
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
|
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
|
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
|
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||||
|
@ -213,32 +211,23 @@ func @inputs1results2(%arg0: !npcomprt.tensor) -> (!npcomprt.tensor, !npcomprt.t
|
||||||
|
|
||||||
// Test emission of compiler runtime functions.
|
// Test emission of compiler runtime functions.
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1)
|
// CHECK: llvm.mlir.global internal constant @[[STRSYM:.*]]("msg")
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_get_extent(!llvm.ptr<i8>, !llvm.i32) -> !llvm.i64
|
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr<i8>)
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
|
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
|
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
|
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.func @calls_abort_if(
|
// CHECK-LABEL: llvm.func @calls_abort_if(
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.i1) {
|
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.i1) {
|
||||||
// CHECK: llvm.call @__npcomp_compiler_rt_abort_if(%[[VAL_0]]) : (!llvm.i1) -> ()
|
// CHECK: %[[VAL_0:.*]] = llvm.mlir.addressof @[[STRSYM]] : !llvm.ptr<array<3 x i8>>
|
||||||
|
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
|
||||||
|
// CHECK: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<3 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8>
|
||||||
|
// CHECK: llvm.call @__npcomp_compiler_rt_abort_if(%[[VAL_3:.*]], %[[VAL_2]]) : (!llvm.i1, !llvm.ptr<i8>) -> ()
|
||||||
// CHECK: llvm.return
|
// CHECK: llvm.return
|
||||||
// CHECK: }
|
|
||||||
func @calls_abort_if(%arg0: i1) {
|
|
||||||
npcomprt.abort_if %arg0
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.func @calls_get_extent(
|
func @calls_abort_if(%arg0: i1) {
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.i64 {
|
npcomprt.abort_if %arg0, "msg"
|
||||||
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
return
|
||||||
// CHECK: %[[VAL_2:.*]] = llvm.call @__npcomp_compiler_rt_get_extent(%[[VAL_0]], %[[VAL_1]]) : (!llvm.ptr<i8>, !llvm.i32) -> !llvm.i64
|
|
||||||
// CHECK: llvm.return %[[VAL_2]] : !llvm.i64
|
|
||||||
// CHECK: }
|
|
||||||
func @calls_get_extent(%arg0: !npcomprt.tensor) -> index {
|
|
||||||
%c1 = constant 1 : i32
|
|
||||||
%0 = npcomprt.get_extent %arg0, %c1
|
|
||||||
return %0 : index
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.func @calls_to_memref(
|
// CHECK-LABEL: llvm.func @calls_to_memref(
|
||||||
|
|
|
@ -1,39 +1,74 @@
|
||||||
// RUN: npcomp-opt -lower-to-npcomprt-abi -split-input-file -verify-diagnostics <%s | FileCheck %s --dump-input=fail
|
// RUN: npcomp-opt -lower-to-npcomprt-abi -split-input-file -verify-diagnostics <%s | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
// Test module metadata.
|
||||||
|
|
||||||
// CHECK: npcomprt.module_metadata
|
// CHECK: npcomprt.module_metadata
|
||||||
// CHECK-NEXT: npcomprt.func_metadata {funcName = @identity, numInputs = 1 : i32, numOutputs = 1 : i32}
|
// CHECK-NEXT: npcomprt.func_metadata {funcName = @f_2inputs_0outputs, numInputs = 2 : i32, numOutputs = 0 : i32}
|
||||||
// CHECK-NEXT: npcomprt.func_metadata {funcName = @basic, numInputs = 1 : i32, numOutputs = 1 : i32}
|
// CHECK-NEXT: npcomprt.func_metadata {funcName = @f_1input_2outputs, numInputs = 1 : i32, numOutputs = 2 : i32}
|
||||||
|
|
||||||
|
// This function only exists to test its metadata above.
|
||||||
// CHECK-LABEL: func @identity(
|
func @f_2inputs_0outputs(%arg0: memref<?xf32>, %arg1: memref<?xf32>) {
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !npcomprt.tensor) -> !npcomprt.tensor {
|
return
|
||||||
// CHECK: return %[[VAL_0]] : !npcomprt.tensor
|
|
||||||
// CHECK: }
|
|
||||||
func @identity(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
|
||||||
return %arg0 : tensor<?xf32>
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @basic(
|
// This function only exists to test its metadata above.
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !npcomprt.tensor) -> !npcomprt.tensor {
|
func @f_1input_2outputs(%arg0: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>) {
|
||||||
// CHECK: %[[VAL_1:.*]] = constant 0 : i32
|
return %arg0, %arg0 : memref<?xf32>, memref<?xf32>
|
||||||
// CHECK: %[[VAL_2:.*]] = npcomprt.get_extent %[[VAL_0]], %[[VAL_1]]
|
}
|
||||||
// CHECK: %[[VAL_3:.*]] = shape.from_extents %[[VAL_2]]
|
|
||||||
// CHECK: %[[VAL_4:.*]] = shape.to_extent_tensor %[[VAL_3]]
|
|
||||||
// CHECK: %[[VAL_5:.*]] = tcp.alloc_memref %[[VAL_4]] : memref<?xf32>
|
|
||||||
// CHECK: %[[VAL_6:.*]] = npcomprt.to_memref %[[VAL_0]] : memref<*xf32>
|
|
||||||
// CHECK: %[[VAL_7:.*]] = memref_cast %[[VAL_6]] : memref<*xf32> to memref<?xf32>
|
|
||||||
// CHECK: linalg.copy(%[[VAL_7]], %[[VAL_5]]) : memref<?xf32>, memref<?xf32>
|
|
||||||
// CHECK: %[[VAL_8:.*]] = memref_cast %[[VAL_5]] : memref<?xf32> to memref<*xf32>
|
|
||||||
// CHECK: %[[VAL_9:.*]] = npcomprt.from_memref %[[VAL_8]] : memref<*xf32>
|
|
||||||
// CHECK: return %[[VAL_9]] : !npcomprt.tensor
|
|
||||||
// CHECK: }
|
|
||||||
|
|
||||||
func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
// -----
|
||||||
%shape = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
|
|
||||||
%memref = tcp.alloc_memref %shape : memref<?xf32>
|
// Test ABI conversions.
|
||||||
tensor_store %arg0, %memref : memref<?xf32>
|
|
||||||
%ret = tensor_load %memref : memref<?xf32>
|
// CHECK-LABEL: func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor
|
||||||
return %ret: tensor<?xf32>
|
func @identity(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||||
|
// The argument materialization.
|
||||||
|
// In this test case, these go unused since, as described below, the new
|
||||||
|
// argument value is seen immediately by the return op for some reason.
|
||||||
|
// CHECK-NEXT: %[[INABIMEMREF:.*]] = npcomprt.to_memref %arg0 : memref<*xf32>
|
||||||
|
// CHECK-NEXT: %[[MEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32>
|
||||||
|
|
||||||
|
// TODO: Why do these target materializations not happen in this particular
|
||||||
|
// test?
|
||||||
|
// Somehow, the return op rewrite sees the new argument value immediately,
|
||||||
|
// rather than the result of replaceUsesOfBlockArgument from
|
||||||
|
// FuncOpSignatureConversion
|
||||||
|
// Cxxxx-NEXT: %[[OUTABIMEMREF:.*]] = memref_cast %[[MEMREF]] : memref<?xf32> to memref<*xf32>
|
||||||
|
// Cxxxx-NEXT: %[[RET:.*]] = npcomprt.from_memref %[[OUTABIMEMREF]] : memref<*xf32>
|
||||||
|
// Cxxxx-NEXT: return %[[RET]]
|
||||||
|
|
||||||
|
// CHECK-NEXT: return %arg0
|
||||||
|
return %arg0 : memref<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @use_of_arg(%arg0: !npcomprt.tensor)
|
||||||
|
func @use_of_arg(%arg0: memref<?xf32>) {
|
||||||
|
// CHECK-NEXT: %[[INABIMEMREF:.*]] = npcomprt.to_memref %arg0 : memref<*xf32>
|
||||||
|
// CHECK-NEXT: %[[MEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32>
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
%0 = dim %arg0, %c0 : memref<?xf32>
|
||||||
|
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
|
||||||
|
// CHECK-NEXT: dim %[[MEMREF]], %[[C0]] : memref<?xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @multiple_blocks(%arg0: !npcomprt.tensor) -> !npcomprt.tensor
|
||||||
|
func @multiple_blocks(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||||
|
// CHECK-NEXT: %[[INABIMEMREF:.*]] = npcomprt.to_memref %arg0 : memref<*xf32>
|
||||||
|
// CHECK-NEXT: %[[INMEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32>
|
||||||
|
// CHECK-NEXT: br ^bb1(%[[INMEMREF]] : memref<?xf32>)
|
||||||
|
br ^bb1(%arg0: memref<?xf32>)
|
||||||
|
// CHECK-NEXT: ^bb1(%[[BBARG:.*]]: memref<?xf32>):
|
||||||
|
^bb1(%bbarg: memref<?xf32>):
|
||||||
|
// CHECK-NEXT: %[[OUTMEMREF:.*]] = memref_cast %[[BBARG]] : memref<?xf32> to memref<*xf32>
|
||||||
|
// CHECK-NEXT: %[[OUTABIMEMREF:.*]] = npcomprt.from_memref %[[OUTMEMREF]] : memref<*xf32>
|
||||||
|
// CHECK-NEXT: return %[[OUTABIMEMREF]] : !npcomprt.tensor
|
||||||
|
return %bbarg : memref<?xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -42,19 +77,20 @@ func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
// CHECK: npcomprt.global @g dense<7.000000e+00> : tensor<10xf32>
|
// CHECK: npcomprt.global @g dense<7.000000e+00> : tensor<10xf32>
|
||||||
tcp.global @g dense<7.0> : tensor<10xf32>
|
tcp.global @g dense<7.0> : tensor<10xf32>
|
||||||
// CHECK-LABEL: func @gets_global() -> !npcomprt.tensor
|
// CHECK-LABEL: func @gets_global() -> !npcomprt.tensor
|
||||||
func @gets_global() -> tensor<10xf32> {
|
func @gets_global() -> memref<10xf32> {
|
||||||
// CHECK: %[[GMEMREF:.*]] = npcomprt.get_global @g : memref<*xf32>
|
// CHECK: %[[GMEMREF:.*]] = npcomprt.get_global @g : memref<*xf32>
|
||||||
// CHECK: %[[ORIGMEMREF:.*]] = memref_cast %[[GMEMREF]] : memref<*xf32> to memref<10xf32>
|
// CHECK: %[[ORIGMEMREF:.*]] = memref_cast %[[GMEMREF]] : memref<*xf32> to memref<10xf32>
|
||||||
// CHECK: %[[RETMEMREF:.*]] = memref_cast %[[ORIGMEMREF:.*]] : memref<10xf32> to memref<*xf32>
|
// CHECK: %[[OUTABIMEMREF:.*]] = memref_cast %[[ORIGMEMREF:.*]] : memref<10xf32> to memref<*xf32>
|
||||||
// CHECK: %[[RET:.*]] = npcomprt.from_memref %[[RETMEMREF]] : memref<*xf32>
|
// CHECK: %[[RET:.*]] = npcomprt.from_memref %[[OUTABIMEMREF]] : memref<*xf32>
|
||||||
// CHECK: return %[[RET]] : !npcomprt.tensor
|
// CHECK: return %[[RET]] : !npcomprt.tensor
|
||||||
%0 = tcp.get_global_memref @g : memref<10xf32>
|
%0 = tcp.get_global_memref @g : memref<10xf32>
|
||||||
%1 = tensor_load %0 : memref<10xf32>
|
return %0 : memref<10xf32>
|
||||||
return %1 : tensor<10xf32>
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// Test diagnostics.
|
||||||
|
|
||||||
// expected-error @+1 {{func not expressible with npcomprt ABI}}
|
// expected-error @+1 {{func not expressible with npcomprt ABI}}
|
||||||
func @unhandled_abi_type_on_public_func(%arg0: i32) {
|
func @unhandled_abi_type_on_public_func(%arg0: i32) {
|
||||||
return
|
return
|
||||||
|
|
|
@ -1,30 +0,0 @@
|
||||||
// RUN: npcomp-opt -resolve-shape-of-ops <%s -split-input-file -verify-diagnostics | FileCheck %s --dump-input=fail
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @basic
|
|
||||||
func @basic(%arg0: tensor<?xindex>) -> tensor<?xindex> {
|
|
||||||
%memref = tcp.alloc_memref %arg0 : memref<?xf32>
|
|
||||||
%tensor = tensor_load %memref : memref<?xf32>
|
|
||||||
%shape = "shape.shape_of"(%tensor) : (tensor<?xf32>) -> tensor<?xindex>
|
|
||||||
// CHECK: return %arg0
|
|
||||||
return %shape : tensor<?xindex>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @arg_unresolved_ok
|
|
||||||
func @arg_unresolved_ok(%arg0: tensor<?xf32>) -> tensor<?xindex> {
|
|
||||||
%0 = "shape.shape_of"(%arg0): (tensor<?xf32>) -> tensor<?xindex>
|
|
||||||
return %0 : tensor<?xindex>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @TODO_bb_arg_unresolved_not_ok
|
|
||||||
// TODO: This should emit a diagnostic, but doesn't. Why?
|
|
||||||
// addDynamicallyLegalOp isn't working as I expect.
|
|
||||||
func @TODO_bb_arg_unresolved_not_ok(%arg0: i1, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xindex> {
|
|
||||||
cond_br %arg0, ^bb1(%arg1: tensor<?xf32>), ^bb1(%arg2: tensor<?xf32>)
|
|
||||||
^bb1(%bbarg: tensor<?xf32>):
|
|
||||||
%0 = "shape.shape_of"(%bbarg): (tensor<?xf32>) -> tensor<?xindex>
|
|
||||||
return %0 : tensor<?xindex>
|
|
||||||
}
|
|
|
@ -1,27 +0,0 @@
|
||||||
// RUN: npcomp-opt -resolve-tensor-load-store-ops <%s | FileCheck %s --dump-input=fail
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @basic
|
|
||||||
func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
|
||||||
|
|
||||||
%shape = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> tensor<?xindex>
|
|
||||||
|
|
||||||
// CHECK: %[[SRCMEMREF:.+]] = tcp.alloc_memref
|
|
||||||
%src_memref = tcp.alloc_memref %shape : memref<?xf32>
|
|
||||||
// tensor_store of argument remains.
|
|
||||||
// CHECK: tensor_store %arg0, %[[SRCMEMREF]]
|
|
||||||
tensor_store %arg0, %src_memref : memref<?xf32>
|
|
||||||
%src = tensor_load %src_memref : memref<?xf32>
|
|
||||||
|
|
||||||
// CHECK: %[[DSTMEMREF:.+]] = tcp.alloc_memref
|
|
||||||
%dst_memref = tcp.alloc_memref %shape : memref<?xf32>
|
|
||||||
// tensor_store of internally created tensor is eliminated.
|
|
||||||
// CHECK-NOT: tensor_store
|
|
||||||
// CHECK: linalg.copy(%[[SRCMEMREF]], %[[DSTMEMREF]])
|
|
||||||
tensor_store %src, %dst_memref : memref<?xf32>
|
|
||||||
%ret = tensor_load %dst_memref : memref<?xf32>
|
|
||||||
|
|
||||||
// The tensor_load feeding into the return remains.
|
|
||||||
// %[[RET:.+]] = tensor_load %[[DSTMEMREF]]
|
|
||||||
// return %[[RET]]
|
|
||||||
return %ret : tensor<?xf32>
|
|
||||||
}
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
// RUN: not npcomp-run-mlir %s \
|
||||||
|
// RUN: -invoke invalid_broadcast \
|
||||||
|
// RUN: -arg-value="dense<[1.0, 2.0]> : tensor<2xf32>" \
|
||||||
|
// RUN: -arg-value="dense<[3.0, 4.0, 5.0]> : tensor<3xf32>" \
|
||||||
|
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||||
|
// RUN: | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK: NPCOMP: aborting: invalid broadcast
|
||||||
|
func @invalid_broadcast(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
|
%0 = "tcf.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
|
return %0 : tensor<?xf32>
|
||||||
|
}
|
|
@ -1,12 +1,12 @@
|
||||||
// RUN: npcomp-run-mlir %s \
|
// RUN: npcomp-run-mlir %s \
|
||||||
// RUN: -invoke multi_output \
|
// RUN: -invoke multi_output \
|
||||||
// RUN: -arg-value="dense<1.0> : tensor<f32>" \
|
// RUN: -arg-value="dense<1.0> : tensor<1xf32>" \
|
||||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||||
// RUN: | FileCheck %s
|
// RUN: | FileCheck %s
|
||||||
|
|
||||||
// CHECK: output #0: dense<2.000000e+00> : tensor<f32>
|
// CHECK: output #0: dense<2.000000e+00> : tensor<1xf32>
|
||||||
// CHECK: output #1: dense<2.000000e+00> : tensor<f32>
|
// CHECK: output #1: dense<2.000000e+00> : tensor<1xf32>
|
||||||
func @multi_output(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
|
func @multi_output(%arg0: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
|
||||||
%0 = "tcf.add"(%arg0, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
%0 = "tcf.add"(%arg0, %arg0) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
return %0, %0 : tensor<f32>, tensor<f32>
|
return %0, %0 : tensor<?xf32>, tensor<?xf32>
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue