mirror of https://github.com/llvm/torch-mlir
Initial TCF/TCP E2E seed.
Very much WIP. This is enough to get tcf.add down to approximately the "linalg.generic on buffers" level of abstraction (but there are nuances).
parent f394e12d86
commit e29aef855b

@@ -1 +1,3 @@
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(E2E)

@@ -0,0 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(NPCOMPConversionPassIncGen)

add_mlir_doc(Passes -gen-pass-doc ConversionPasses ./)

@@ -0,0 +1,32 @@
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef NPCOMP_CONVERSION_PASSES
#define NPCOMP_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

//===----------------------------------------------------------------------===//
// TCFToTCP
//===----------------------------------------------------------------------===//

def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "ModuleOp"> {
  let summary = "Convert TCF to TCP";
  let constructor = "mlir::NPCOMP::createConvertTCFToTCPPass()";
}

//===----------------------------------------------------------------------===//
// TCPToLinalg
//===----------------------------------------------------------------------===//

def ConvertTCPToLinalg : Pass<"convert-tcp-to-linalg", "ModuleOp"> {
  let summary = "Convert TCP to Linalg";
  let constructor = "mlir::NPCOMP::createConvertTCPToLinalgPass()";
}

#endif // NPCOMP_CONVERSION_PASSES

@@ -0,0 +1,21 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef NPCOMP_CONVERSION_TCFTOTCP_CONVERTTCFTOTCP_H
#define NPCOMP_CONVERSION_TCFTOTCP_CONVERTTCFTOTCP_H

#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir {
namespace NPCOMP {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCFToTCPPass();
}
}

#endif // NPCOMP_CONVERSION_TCFTOTCP_CONVERTTCFTOTCP_H

@@ -0,0 +1,22 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef NPCOMP_CONVERSION_TCPTOLINALG_CONVERTTCPTOLINALG_H
#define NPCOMP_CONVERSION_TCPTOLINALG_CONVERTTCPTOLINALG_H

#include "mlir/Pass/Pass.h"

#include <memory>

namespace mlir {
namespace NPCOMP {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCPToLinalgPass();
}
}

#endif // NPCOMP_CONVERSION_TCPTOLINALG_CONVERTTCPTOLINALG_H

@@ -1,2 +1,4 @@
add_subdirectory(Basicpy)
add_subdirectory(Numpy)
add_subdirectory(TCF)
add_subdirectory(TCP)

@@ -0,0 +1 @@
add_subdirectory(IR)

@@ -0,0 +1 @@
add_mlir_dialect(TCFOps tcf)

@@ -0,0 +1,42 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef TCF_BASE
#define TCF_BASE

include "mlir/IR/OpBase.td"

def TCF_Dialect : Dialect {
  let name = "tcf";
  let cppNamespace = "::mlir::NPCOMP::tcf";
  let description = [{
    The `tcf` dialect is a key facilitator for ingesting dynamic frontend
    languages with a "tensor" primitive type into the MLIR ecosystem.

    Some of its key features are:
    - Ops that safely report errors, such as mismatching sizes for a matrix
      multiplication.
    - Parameters controlling op behavior are dynamic operands, such as
      convolution window sizes.
    - Support for a rank-dynamic programming model.
    - Support for implicit broadcasting, following the industry-standard numpy
      broadcasting rules.

    These features make this dialect interoperate well with the highly
    dynamic programming models common in many frontends.

    This dialect is optimized for compiler analysis and transformation,
    especially lowering to lower levels of abstraction in the compiler.
    Tensor programs, as represented in this dialect, are not necessarily
    represented in the most efficient way for op-by-op execution.
    The goal is that most frontend ops are representable in a small, but
    not-necessarily-just-one, set of ops from this dialect.
  }];
}

#endif // #ifndef TCF_BASE

@@ -0,0 +1,24 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef NPCOMP_DIALECT_TCF_IR_TCFDIALECT_H
#define NPCOMP_DIALECT_TCF_IR_TCFDIALECT_H

#include "mlir/IR/Dialect.h"

namespace mlir {
namespace NPCOMP {
namespace tcf {

#include "npcomp/Dialect/TCF/IR/TCFOpsDialect.h.inc"

} // namespace tcf
} // namespace NPCOMP
} // namespace mlir

#endif // NPCOMP_DIALECT_TCF_IR_TCFDIALECT_H

@@ -0,0 +1,27 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef NPCOMP_DIALECT_TCF_IR_TCFOPS_H
#define NPCOMP_DIALECT_TCF_IR_TCFOPS_H

#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"

namespace mlir {
namespace NPCOMP {
namespace tcf {

#define GET_OP_CLASSES
#include "npcomp/Dialect/TCF/IR/TCFOps.h.inc"

} // namespace tcf
} // namespace NPCOMP
} // namespace mlir

#endif // NPCOMP_DIALECT_TCF_IR_TCFOPS_H

@@ -0,0 +1,71 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef TCF_OPS
#define TCF_OPS

include "npcomp/Dialect/TCF/IR/TCFBase.td"

class TCF_Op<string mnemonic, list<OpTrait> traits = []>
    : Op<TCF_Dialect, mnemonic, traits> {
}

// TODO: investigate effects framework for defining error semantics
// TODO: define in a general way across the dialect what "encounters an
// error" means.

// TODO: verify same dtype?
// TODO: what are the allowable dtypes?
def TCF_AddOp : TCF_Op<"add"> {
  let summary = "Add two tensors.";
  let description = [{
    Add two tensors.
  }];
  let arguments = (ins AnyTensor:$lhs, AnyTensor:$rhs);
  let results = (outs AnyTensor:$result);
}

def TCF_BatchMatmulOp : TCF_Op<"batch_matmul"> {
  let summary = "Performs a batch of matrix multiplications.";
  let description = [{
    This op, in its simplest case, performs a matrix multiplication between
    the two operands. Let the operands have the shapes:
    - `lhs`: `[BLHS..., LHSROWS, LHSCOLS]`
    - `rhs`: `[BRHS..., RHSROWS, RHSCOLS]`
    Then `result` will have shape `[broadcast(BLHS, BRHS), LHSROWS, RHSCOLS]`.

    This op encounters an error if `LHSCOLS != RHSROWS` or if
    `broadcast(BLHS, BRHS)` is not possible.
  }];
  let arguments = (ins AnyTensor:$lhs, AnyTensor:$rhs);
  let results = (outs AnyTensor:$result);
}

// TODO: represent more general convolutions (via more parameters and also
// more ops). torch.nn.functional has a good summary of frontend needs:
// https://pytorch.org/docs/stable/nn.functional.html#conv2d
// TODO: describe error conditions
def TCF_Conv2DOp : TCF_Op<"conv_2d"> {
  let summary = "Perform a 2D convolution.";
  let description = [{
    This op performs a 2D convolution in the sense typical in deep learning
    contexts.

    The inputs have the following rank structure:
    - `input`: `[BATCH, Zin, IN0, IN1]`
    - `kernel`: `[Zout, Zin, K0, K1]`
  }];
  let arguments = (ins
    AnyTensor:$input,
    AnyTensor:$kernel
  );
  let results = (outs
    AnyTensor:$result
  );
}

#endif // #ifndef TCF_OPS

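For orientation, here is what `tcf.add` looks like in textual IR, a minimal sketch in the style of the test files later in this commit (the function name is illustrative; note that implicit broadcasting means the operand shapes need not match):

func @example(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = "tcf.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
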
@@ -0,0 +1 @@
add_subdirectory(IR)

@@ -0,0 +1 @@
add_mlir_dialect(TCPOps tcp)

@@ -0,0 +1,40 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef TCP_BASE
#define TCP_BASE

include "mlir/IR/OpBase.td"

def TCP_Dialect : Dialect {
  let name = "tcp";
  let cppNamespace = "::mlir::NPCOMP::tcp";
  let description = [{
    The `tcp` dialect is the gateway to MLIR's code generation infrastructure.
    It is also a great place to do algebraic transformations making use of
    semantically-charged named ops.

    Features:
    - Requires ranked tensors (except for a handful of special ops).
    - No implicit broadcasting.
    - Performance-critical parameters, like convolution window sizes, are
      represented with attributes.
    - Careful modeling of ops that are logically "pure" but have
      preconditions.

    Together these features allow a relatively large class of "common-sense"
    optimizations to be done with only modestly complex considerations.
    // TODO: consider having these ops take a "witness" argument
    // that makes them truly NoSideEffect?
    // Or have a totally pure "tcp.island" op?
    // Figure it out when doing the tcf to tcp lowering.
  }];
}

#endif // TCP_BASE

@@ -0,0 +1,24 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef NPCOMP_DIALECT_TCP_IR_TCPDIALECT_H
#define NPCOMP_DIALECT_TCP_IR_TCPDIALECT_H

#include "mlir/IR/Dialect.h"

namespace mlir {
namespace NPCOMP {
namespace tcp {

#include "npcomp/Dialect/TCP/IR/TCPOpsDialect.h.inc"

} // namespace tcp
} // namespace NPCOMP
} // namespace mlir

#endif // NPCOMP_DIALECT_TCP_IR_TCPDIALECT_H

@@ -0,0 +1,29 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef NPCOMP_DIALECT_TCP_IR_TCPOPS_H
#define NPCOMP_DIALECT_TCP_IR_TCPOPS_H

#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Interfaces/SideEffects.h"

namespace mlir {
namespace NPCOMP {
namespace tcp {

#define GET_OP_CLASSES
#include "npcomp/Dialect/TCP/IR/TCPOps.h.inc"

} // namespace tcp
} // namespace NPCOMP
} // namespace mlir

#endif // NPCOMP_DIALECT_TCP_IR_TCPOPS_H

@@ -0,0 +1,123 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef TCP_OPS
#define TCP_OPS

include "npcomp/Dialect/TCP/IR/TCPBase.td"
include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/Interfaces/SideEffects.td"
include "mlir/Interfaces/InferTypeOpInterface.td"

class TCP_Op<string mnemonic, list<OpTrait> traits = []>
    : Op<TCP_Dialect, mnemonic, traits> {
}

// TODO: Do islands belong inside this dialect?
// It almost feels like they should be in an `error` dialect (containing
// the witness type as well).
// There would be no "shape.abort_if_error" because the aborting happens
// inside the witness ops, with the island operating as a witness sink.
def TCP_IslandOp : TCP_Op<"island"> {
  let summary = "Island of no-side-effect ops.";
  let description = [{
    Most ops in the `tcp` dialect have complex preconditions on their tensor
    arguments (usually on their shapes) so that they can be a good starting
    point for code generation compilation flows.
    We want an efficient way to understand which ops are related to which
    preconditions without cluttering the TCP ops themselves.
    To do this, the `tcp.island` op takes as operands witness values that
    establish use-def edges between precondition assertions and the island,
    and most other `tcp` ops are restricted to reside inside these islands.
    This makes code motion to rearrange `tcp` ops simpler by having the
    witness use-def edges, without requiring every `tcp` op to carry extra
    operands.
    // TODO: Still unclear if this is really that useful. This mainly affects
    // the ability to hoist tcp ops. It should always be safe to sink TCP ops.
  }];
  let arguments = (ins Variadic<NoneType>:$witnesses);
  let regions = (region AnyRegion:$body);
  let results = (outs Variadic<AnyRankedTensor>:$results);
  // TODO: verify that result types match the internal tcp.yield's ValueRange.
}

def TCP_YieldOp
    : TCP_Op<"yield", [NoSideEffect, HasParent<"IslandOp">, Terminator]> {
  let summary = "yield-like terminator for tcp.island op.";
  let description = [{
    Returns control and a variadic list of values to the parent tcp.island op.
  }];
  let arguments = (ins Variadic<AnyRankedTensor>:$operands);
  let results = (outs);
}

// TODO: clarify allowed tensor element types.
// TODO: HasParent is too restrictive? Can't have an island with a loop.for
// that has further ops inside it?
def TCP_AddOp
    : TCP_Op<"add", []> {
  let summary = "Adds two tensors.";
  let description = [{
    Adds two tensors.
  }];
  let arguments = (ins AnyRankedTensor:$lhs, AnyRankedTensor:$rhs);
  let results = (outs AnyRankedTensor:$result);
}

def TCP_BroadcastToOp : TCP_Op<"broadcast_to"> {
  let summary = "Broadcasts an operand to a given shape.";
  let description = [{
    Broadcasts `operand` to the shape `shape`.

    It is undefined behavior if such a broadcast is not legal.
  }];
  let arguments = (ins AnyRankedTensor:$operand, Shape_ShapeType:$shape);
  let results = (outs AnyRankedTensor:$result);
}

// TODO: This probably doesn't belong in the tcp dialect.
def TCP_AllocMemRefOp : TCP_Op<"alloc_memref", []> {
  let summary = "Allocates a memref of the given shape.";
  let description = [{
    Allocates a memref of the given shape.
  }];
  let arguments = (ins Shape_ShapeType:$shape);
  let results = (outs AnyMemRef:$memref);
}

// TODO: Change to a more principled witness-based error handling mechanism.
// This op probably doesn't need to exist eventually.
def TCP_AbortIfErrorOp : TCP_Op<"abort_if_error",
    [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "Aborts the program if the argument is an error shape.";
  let description = [{
    Aborts the program if its `shape` argument is an error shape.
    TODO: can we do something better designed here than just abort?

    Returns `none`, which can be used as a witness value to establish a
    use-def relationship between this op and an op that requires the
    precondition established by this op.
  }];
  let arguments = (ins Shape_ShapeType:$shape);
  let results = (outs NoneType:$result);
}

// TODO: This probably belongs in the shape dialect.
def TCP_GetExtentOp : TCP_Op<"get_extent",
    [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "Gets the specified extent from a shape.";
  let description = [{
    Gets the specified extent from a shape.

    This op has undefined behavior if the shape is an error.
  }];
  let arguments = (ins Shape_ShapeType:$shape, I64Attr:$dim);
  let results = (outs Index:$extent);
}

#endif // TCP_OPS

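To make the island/yield structure concrete, here is a sketch in textual IR (this mirrors the TCP test file later in this commit; in this simple case the island has no witness operands yet):

func @island_example(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  %result = "tcp.island"() ({
    %0 = "tcp.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    "tcp.yield"(%0) : (tensor<?xf32>) -> ()
  }) : () -> tensor<?xf32>
  return %result : tensor<?xf32>
}
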
@@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(NPCOMPE2EPassIncGen)

add_mlir_doc(Passes -gen-pass-doc E2EPasses ./)

@@ -0,0 +1,34 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef NPCOMP_E2E_E2E_H
#define NPCOMP_E2E_E2E_H

#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"

namespace mlir {
namespace NPCOMP {

std::unique_ptr<OperationPass<FuncOp>> createLowerBroadcastToToLoopsPass();

std::unique_ptr<OperationPass<FuncOp>>
createLowerLinalgOnTensorToLinalgOnMemrefPass();

std::unique_ptr<OperationPass<FuncOp>> createResolveShapeOfOpsPass();

void createLowerToHybridTensorMemRefPipeline(OpPassManager &pm);

// The main pipeline that encapsulates the full E2E lowering.
void createE2ELoweringPipeline(OpPassManager &pm);

} // namespace NPCOMP
} // namespace mlir

#endif // NPCOMP_E2E_E2E_H

@@ -0,0 +1,31 @@
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef NPCOMP_E2E_PASSES
#define NPCOMP_E2E_PASSES

include "mlir/Pass/PassBase.td"

def LowerLinalgOnTensorToLinalgOnMemref :
    Pass<"lower-linalg-tensor-to-memref", "FuncOp"> {
  let summary = "Lowers linalg on tensors to linalg on memrefs";
  let constructor = "mlir::NPCOMP::createLowerLinalgOnTensorToLinalgOnMemrefPass()";
}

def LowerBroadcastToToLoops :
    Pass<"lower-broadcast-to-to-loops", "FuncOp"> {
  let summary = "Lower tcp::BroadcastTo to loops.";
  let constructor = "mlir::NPCOMP::createLowerBroadcastToToLoopsPass()";
}

def ResolveShapeOfOps : Pass<"resolve-shape-of-ops", "FuncOp"> {
  let summary = "Resolve shape.shape_of ops to other shapes.";
  let constructor = "mlir::NPCOMP::createResolveShapeOfOpsPass()";
}

#endif // NPCOMP_E2E_PASSES

@@ -1 +1,3 @@
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(E2E)

@@ -0,0 +1,2 @@
add_subdirectory(TCFToTCP)
add_subdirectory(TCPToLinalg)

@@ -0,0 +1,23 @@
//===- PassDetail.h - Conversion Pass class details -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef NPCOMP_CONVERSION_PASSDETAIL_H
#define NPCOMP_CONVERSION_PASSDETAIL_H

#include "mlir/Pass/Pass.h"

namespace mlir {
namespace NPCOMP {

#define GEN_PASS_CLASSES
#include "npcomp/Conversion/Passes.h.inc"

} // namespace NPCOMP
} // end namespace mlir

#endif // NPCOMP_CONVERSION_PASSDETAIL_H

@@ -0,0 +1,18 @@
add_mlir_conversion_library(NPCOMPTCFToTCP
  TCFToTCP.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/TCFToTCP

  DEPENDS
  NPCOMPConversionPassIncGen

  LINK_COMPONENTS
  Core

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRPass
  MLIRTransforms
  MLIRShape
)

@@ -0,0 +1,84 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/TCF/IR/TCFOps.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"

using namespace mlir;
using namespace mlir::NPCOMP;

namespace {
class ConvertAdd : public OpRewritePattern<tcf::AddOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tcf::AddOp op,
                                PatternRewriter &rewriter) const override {
    auto lhsType = op.lhs().getType().dyn_cast<RankedTensorType>();
    auto rhsType = op.rhs().getType().dyn_cast<RankedTensorType>();
    if (!lhsType || !rhsType) {
      return rewriter.notifyMatchFailure(op, "requires ranked tensors");
    }
    Value lhsShape = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.lhs());
    Value rhsShape = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.rhs());
    Value broadcastedShape = rewriter.create<shape::BroadcastOp>(
        op.getLoc(), lhsShape, rhsShape, /*error=*/nullptr);
    Value witness =
        rewriter.create<tcp::AbortIfErrorOp>(op.getLoc(), broadcastedShape);
    tcp::IslandOp island =
        rewriter.create<tcp::IslandOp>(op.getLoc(), op.getType(), witness);
    Region &body = island.body();
    Block *bodyBlock = new Block;
    body.push_back(bodyBlock);
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(bodyBlock, bodyBlock->begin());
    // TODO: It's annoying to do the dynamic broadcast above then
    // do the static transfer function here. Would be nice if they could
    // somehow be unified.
    SmallVector<int64_t, 6> broadcastedStaticShape;
    OpTrait::util::getBroadcastedShape(lhsType.getShape(), rhsType.getShape(),
                                       broadcastedStaticShape);
    auto resultType =
        RankedTensorType::get(broadcastedStaticShape, lhsType.getElementType());
    Value lhsBroadcasted = rewriter.create<tcp::BroadcastToOp>(
        op.getLoc(), resultType, op.lhs(), broadcastedShape);
    Value rhsBroadcasted = rewriter.create<tcp::BroadcastToOp>(
        op.getLoc(), resultType, op.rhs(), broadcastedShape);
    Value add = rewriter.create<tcp::AddOp>(op.getLoc(), op.getType(),
                                            lhsBroadcasted, rhsBroadcasted);
    rewriter.create<tcp::YieldOp>(op.getLoc(), add);

    rewriter.replaceOp(op, island.getResults());
    return success();
  }
};
} // namespace

namespace {
class ConvertTCFToTCP : public ConvertTCFToTCPBase<ConvertTCFToTCP> {
public:
  void runOnOperation() override {
    ModuleOp module = getOperation();
    MLIRContext *context = &getContext();

    OwningRewritePatternList patterns;
    patterns.insert<ConvertAdd>(context);
    (void)applyPatternsAndFoldGreedily(module, patterns);
  }
};
} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::createConvertTCFToTCPPass() {
  return std::make_unique<ConvertTCFToTCP>();
}

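Schematically, the ConvertAdd pattern above rewrites a `tcf.add` into IR of the following shape. This is a hand-written sketch in generic op syntax built from the ops defined in this commit, not verbatim pass output:

func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
  %1 = "shape.shape_of"(%arg1) : (tensor<?xf32>) -> !shape.shape
  %2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
  %3 = "tcp.abort_if_error"(%2) : (!shape.shape) -> none
  %4 = "tcp.island"(%3) ({
    %5 = "tcp.broadcast_to"(%arg0, %2) : (tensor<?xf32>, !shape.shape) -> tensor<?xf32>
    %6 = "tcp.broadcast_to"(%arg1, %2) : (tensor<?xf32>, !shape.shape) -> tensor<?xf32>
    %7 = "tcp.add"(%5, %6) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    "tcp.yield"(%7) : (tensor<?xf32>) -> ()
  }) : (none) -> tensor<?xf32>
  return %4 : tensor<?xf32>
}

The witness value %3 is what ties the island to the shape precondition: the island consumes it as an operand, so the precondition check cannot be reordered past the computation that depends on it.
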
@@ -0,0 +1,18 @@
add_mlir_conversion_library(NPCOMPTCPToLinalg
  TCPToLinalg.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/TCPToLinalg

  DEPENDS
  NPCOMPConversionPassIncGen

  LINK_COMPONENTS
  Core

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRPass
  MLIRTransforms
  MLIRShape
)

@@ -0,0 +1,84 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "npcomp/Conversion/TCPToLinalg/TCPToLinalg.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"

using namespace mlir;
using namespace NPCOMP;

namespace {
class ConvertAdd : public OpRewritePattern<tcp::AddOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tcp::AddOp op,
                                PatternRewriter &rewriter) const override {
    size_t rank = op.getType().cast<RankedTensorType>().getRank();
    SmallVector<StringRef, 6> iterators(rank, getParallelIteratorTypeName());
    SmallVector<AffineMap, 3> accesses(/*args in + args out*/ 3,
                                       rewriter.getMultiDimIdentityMap(rank));
    auto genericOp = rewriter.create<linalg::GenericOp>(
        op.getLoc(), llvm::makeArrayRef({op.getType()}),
        ValueRange({op.lhs(), op.rhs()}),
        /*args_in=*/rewriter.getI64IntegerAttr(2),
        /*args_out=*/rewriter.getI64IntegerAttr(1),
        /*indexing_maps=*/rewriter.getAffineMapArrayAttr(accesses),
        /*iterator_types=*/rewriter.getStrArrayAttr(iterators), /*doc=*/nullptr,
        /*library_call=*/nullptr);

    Region &region = genericOp.region();
    Block *block = rewriter.createBlock(&region, region.begin());
    for (auto operandType : op.getOperandTypes()) {
      block->addArgument(operandType.cast<RankedTensorType>().getElementType());
    }
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(block);
    Value bodyValue = rewriter.create<AddFOp>(
        op.getLoc(), block->getArgument(0), block->getArgument(1));
    rewriter.create<linalg::YieldOp>(op.getLoc(), bodyValue);

    rewriter.replaceOp(op, genericOp.getResults());
    return success();
  }
};
} // namespace

namespace {
class ConvertTCPToLinalg : public ConvertTCPToLinalgBase<ConvertTCPToLinalg> {
public:
  void runOnOperation() override {
    ModuleOp module = getOperation();
    MLIRContext *context = &getContext();

    ConversionTarget target(*context);

    OwningRewritePatternList patterns;

    patterns.insert<ConvertAdd>(context);
    target.addIllegalOp<tcp::AddOp>();

    target.addLegalDialect<linalg::LinalgDialect>();
    target.addLegalDialect<StandardOpsDialect>();

    if (failed(applyPartialConversion(module, target, patterns))) {
      return signalPassFailure();
    }
  }
};
} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::createConvertTCPToLinalgPass() {
  return std::make_unique<ConvertTCPToLinalg>();
}

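The net effect of this ConvertAdd is to expand `tcp.add` into a `linalg.generic` on tensors with identity indexing maps, all-parallel iterators, and an `addf` body, roughly as follows (a sketch using the tensor-form linalg syntax of this vintage of MLIR; the exact printed form may differ):

#map = affine_map<(d0) -> (d0)>
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
      indexing_maps = [#map, #map, #map],
      iterator_types = ["parallel"]} %arg0, %arg1 {
  ^bb0(%lhs: f32, %rhs: f32):
    %sum = addf %lhs, %rhs : f32
    linalg.yield %sum : f32
  }: tensor<?xf32>, tensor<?xf32> -> tensor<?xf32>
  return %0 : tensor<?xf32>
}
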
@@ -1,2 +1,4 @@
add_subdirectory(Basicpy)
add_subdirectory(Numpy)
add_subdirectory(TCF)
add_subdirectory(TCP)

@@ -0,0 +1 @@
add_subdirectory(IR)

@@ -0,0 +1,17 @@
add_mlir_dialect_library(NPCOMPTCF
  TCFDialect.cpp
  TCFOps.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/TCF

  DEPENDS
  MLIRTCFOpsIncGen

  LINK_COMPONENTS
  Core

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRSupport
)

@@ -0,0 +1,21 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "npcomp/Dialect/TCF/IR/TCFDialect.h"
#include "npcomp/Dialect/TCF/IR/TCFOps.h"

using namespace mlir;
using namespace mlir::NPCOMP::tcf;

TCFDialect::TCFDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<
#define GET_OP_LIST
#include "npcomp/Dialect/TCF/IR/TCFOps.cpp.inc"
      >();
}

@@ -0,0 +1,21 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "npcomp/Dialect/TCF/IR/TCFOps.h"

using namespace mlir;
using namespace mlir::NPCOMP::tcf;

namespace mlir {
namespace NPCOMP {
namespace tcf {
#define GET_OP_CLASSES
#include "npcomp/Dialect/TCF/IR/TCFOps.cpp.inc"
} // namespace tcf
} // namespace NPCOMP
} // namespace mlir

@@ -0,0 +1 @@
add_subdirectory(IR)

@@ -0,0 +1,18 @@
add_mlir_dialect_library(NPCOMPTCP
  TCPDialect.cpp
  TCPOps.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/TCP

  DEPENDS
  MLIRTCPOpsIncGen

  LINK_COMPONENTS
  Core

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRSupport
  MLIRSideEffects
)

@@ -0,0 +1,21 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"

using namespace mlir;
using namespace mlir::NPCOMP::tcp;

TCPDialect::TCPDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<
#define GET_OP_LIST
#include "npcomp/Dialect/TCP/IR/TCPOps.cpp.inc"
      >();
}

@@ -0,0 +1,46 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "npcomp/Dialect/TCP/IR/TCPOps.h"

using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::tcp;

//===----------------------------------------------------------------------===//
// AbortIfErrorOp
//===----------------------------------------------------------------------===//

LogicalResult AbortIfErrorOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    ArrayRef<NamedAttribute> attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.push_back(NoneType::get(context));
  return success();
}

//===----------------------------------------------------------------------===//
// GetExtentOp
//===----------------------------------------------------------------------===//

LogicalResult GetExtentOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    ArrayRef<NamedAttribute> attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.push_back(IndexType::get(context));
  return success();
}

namespace mlir {
namespace NPCOMP {
namespace tcp {
#define GET_OP_CLASSES
#include "npcomp/Dialect/TCP/IR/TCPOps.cpp.inc"
} // namespace tcp
} // namespace NPCOMP
} // namespace mlir

@@ -0,0 +1,18 @@
add_mlir_library(NPCOMPE2E
  E2E.cpp
  LowerToHybridTensorMemRef.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/npcomp/E2E

  DEPENDS
  NPCOMPE2EPassIncGen
  MLIRLinalgOps

  LINK_COMPONENTS
  Core

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRLinalgOps
)

@@ -0,0 +1,150 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the base file for our "end-to-end" npcomp lowering pipeline.
// At the moment, the first "end" is TCF ops and the second "end" is the
// `llvm` dialect suitable for jitting.
//
// This is still a work in progress and not even working end-to-end for the
// most trivial examples; see the TODOs in createE2ELoweringPipeline for the
// status.
//
// As a pragmatic matter, I generally tend to drop random passes and stuff
// inside this top-level file and then shard it out to separate files once
// a clear organizing principle arises (to avoid premature organizing).
//
// Once we have end-to-end functionality working, we will throw
// increasingly complex programs at it and augment this pass pipeline,
// likely introducing better structure and clearer principles.
//
// I wish I had a clear view of how this pipeline should perfectly layer
// ahead of time, but unfortunately I don't, since it crosses half a dozen
// abstraction levels / dialects, some of which have no precedent that I'm
// aware of (dynamic-shape-aware, error-aware TCF -> TCP) or very little
// (tensor -> memref/buffer with dynamic shapes, shape -> SSA values for
// ranked shape extents).
//
// Right now there's lots of stuff in this pipeline that is limited to
// special cases where I have an idea of how to elaborate it to the general
// case. The priority is getting an end-to-end flow working that we can
// grow out organically into a curriculum of more complex cases, elaborating
// on the design principles and layering as necessitated by the curriculum.
//
// This should be fun :)
//
//===----------------------------------------------------------------------===//

#include "npcomp/E2E/E2E.h"
#include "PassDetail.h"

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
#include "npcomp/Conversion/TCPToLinalg/TCPToLinalg.h"
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"

using namespace mlir;
using namespace mlir::NPCOMP;

class ResolveShapeOfOpViaAllocMemRefOp
    : public OpRewritePattern<shape::ShapeOfOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    if (auto tensorLoad = llvm::dyn_cast_or_null<TensorLoadOp>(
            op.getOperand().getDefiningOp())) {
      if (auto allocMemRef = llvm::dyn_cast_or_null<tcp::AllocMemRefOp>(
              tensorLoad.getOperand().getDefiningOp())) {
        rewriter.replaceOp(op, allocMemRef.getOperand());
        return success();
      }
    }
    return failure();
  }
};

class ResolveShapeOfOps : public ResolveShapeOfOpsBase<ResolveShapeOfOps> {
  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    OwningRewritePatternList patterns;
    patterns.insert<ResolveShapeOfOpViaAllocMemRefOp>(context);
    ConversionTarget target(*context);
    target.addDynamicallyLegalOp<shape::ShapeOfOp>(
        [](shape::ShapeOfOp shapeOf) {
          // Only shape.shape_of on arguments to the entry block are legal at
          // this point. They are assumed to be resolved eventually via
          // the lowering of the tensor argument to some ABI that has the
          // relevant information available. But this is ABI-dependent.
          // TODO: Convince myself that we never need to deal with general
          // block operands, or implement general handling of block
          // operands (need to add new bb operands of !shape.shape type).
          if (auto blockArg = shapeOf.getOperand().dyn_cast<BlockArgument>()) {
            Block *block = blockArg.getOwner();
            if (&block->getParent()->front() == block) {
              return true;
            }
          }
          return false;
        });
    if (failed(applyPartialConversion(func, target, patterns))) {
      return signalPassFailure();
    }
  }
};

std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createResolveShapeOfOpsPass() {
  return std::make_unique<ResolveShapeOfOps>();
}

//===----------------------------------------------------------------------===//
// createE2ELoweringPipeline
//===----------------------------------------------------------------------===//

void mlir::NPCOMP::createE2ELoweringPipeline(OpPassManager &pm) {
  // Input IR is TCF ops.

  // Convert to TCP.
  pm.addPass(createConvertTCFToTCPPass());
  // Convert tcp ops to Linalg where possible.
  pm.addPass(createConvertTCPToLinalgPass());

  // TODO: legalize `dim` to shape.shape_of + tcp.get_extent

  // --------------------------------------------------------------------------
  // Tensor to buffer (memref) conversion.
  // --------------------------------------------------------------------------

  // Lower to hybrid tensor/memref.
  createLowerToHybridTensorMemRefPipeline(pm);

  // At this point, every tensor in the program is the result of a
  // `tensor_load` of an `alloc_memref` op (or is an argument). Therefore,
  // every shape_of can be resolved by looking at the corresponding
  // alloc_memref of the tensor.
  pm.addPass(createResolveShapeOfOpsPass());

  // TODO:
  // - forward tensor_load/tensor_store (which leaves all tensors with no
  //   uses)
  // - lower linalg to loops: mlir::createConvertLinalgToLoopsPass()
  // - lower shape stuff to rshape?
  // - lower rshape to SSA values?
  // - convert all of it to LLVM?
}

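The ResolveShapeOfOpViaAllocMemRefOp pattern above is easiest to see on a tiny example (a hand-written sketch; the value names are illustrative):

// Before: the shape is recomputed from a tensor that was just laundered
// out of an alloc_memref.
%m = "tcp.alloc_memref"(%shape) : (!shape.shape) -> memref<?xf32>
%t = tensor_load %m : memref<?xf32>
%s = "shape.shape_of"(%t) : (tensor<?xf32>) -> !shape.shape

// After resolve-shape-of-ops: all uses of %s are replaced with %shape,
// the operand of the tcp.alloc_memref, and the shape.shape_of is erased.
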
@@ -0,0 +1,290 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "npcomp/E2E/E2E.h"
#include "PassDetail.h"

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
#include "npcomp/Conversion/TCPToLinalg/TCPToLinalg.h"
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"

using namespace mlir;
using namespace mlir::NPCOMP;

static Value allocMemRefForTensor(OpBuilder &builder, Value tensor, Value shape,
                                  Location loc) {
  auto tensorType = tensor.getType().cast<RankedTensorType>();
  auto memrefType =
      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  return builder.create<tcp::AllocMemRefOp>(loc, memrefType, shape);
}

//===----------------------------------------------------------------------===//
// LowerBroadcastTo
//===----------------------------------------------------------------------===//

// TODO: Lower to linalg.indexed_generic instead and let linalg do the
// expansion to loops?
class LowerBroadcastToToLoopsPattern
    : public OpRewritePattern<tcp::BroadcastToOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tcp::BroadcastToOp op,
                                PatternRewriter &rewriter) const override {
    auto resultType = op.getType().cast<RankedTensorType>();
    auto inputType = op.operand().getType().cast<RankedTensorType>();
    Value resultMemref = rewriter.create<tcp::AllocMemRefOp>(
        op.getLoc(),
        MemRefType::get(resultType.getShape(), resultType.getElementType()),
        op.shape());
    Value inputMemref = allocMemRefForTensor(
        rewriter, op.operand(),
        rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.operand()),
        op.getLoc());
    rewriter.create<TensorStoreOp>(op.getLoc(), op.operand(), inputMemref);
    SmallVector<Value, 6> outputExtents;
    SmallVector<Value, 6> inputDimRequiresBroadcasting;

    // TODO: handle output rank > input rank.
    for (int i = 0, e = resultType.getRank(); i < e; i++) {
      Value outputExtent = rewriter.create<tcp::GetExtentOp>(
          op.getLoc(), op.shape(), rewriter.getI64IntegerAttr(i));
      outputExtents.push_back(outputExtent);
    }
    int rankDiff = resultType.getRank() - inputType.getRank();
    for (int i = 0, e = inputType.getRank(); i < e; i++) {
      // Calculate the relevant extents.
      Value inputExtent = rewriter.create<DimOp>(op.getLoc(), op.operand(), i);
      inputDimRequiresBroadcasting.push_back(
          rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::ne, inputExtent,
                                  outputExtents[rankDiff + i]));
    }

    {
      OpBuilder::InsertionGuard guard(rewriter);
      Value c0 = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
      Value c1 = rewriter.create<ConstantIndexOp>(op.getLoc(), 1);

      SmallVector<Value, 6> inductionVariables;
      // Create the (perfectly nested) loops.
      // Loop invariant: At the start of iteration `i`, the rewriter insertion
      // point is inside `i` nested loops.
      for (int i = 0, e = resultType.getRank(); i < e; i++) {
        auto loop = rewriter.create<loop::ForOp>(
            op.getLoc(), c0, outputExtents[i], c1, ValueRange({}));
        Block *body = loop.getBody();
        inductionVariables.push_back(body->getArgument(0));
        // Leave the insertion point at the beginning of the body.
        rewriter.setInsertionPointToStart(body);
      }

      // Create the inner loop body.
      // When reading from the input, clamp any indices for dimensions that
      // are being broadcast.
      SmallVector<Value, 6> inputIndices;
      for (int i = 0, e = inputType.getRank(); i < e; i++) {
        auto c0 = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
        auto select = rewriter.create<SelectOp>(
            op.getLoc(), inputDimRequiresBroadcasting[i], c0,
            inductionVariables[rankDiff + i]);
        inputIndices.push_back(select);
      }
      Value load =
          rewriter.create<LoadOp>(op.getLoc(), inputMemref, inputIndices);
      rewriter.create<StoreOp>(op.getLoc(), load, resultMemref,
                               inductionVariables);
    }

    rewriter.replaceOpWithNewOp<TensorLoadOp>(op, resultMemref);
    return success();
  }
};

namespace {
class LowerBroadcastToToLoops
    : public LowerBroadcastToToLoopsBase<LowerBroadcastToToLoops> {
  void runOnOperation() override {
    auto func = getOperation();
    MLIRContext *context = &getContext();
    ConversionTarget target(*context);
    target.addLegalDialect<shape::ShapeDialect>();
    target.addLegalDialect<StandardOpsDialect>();
    target.addLegalDialect<loop::LoopOpsDialect>();
    target.addLegalDialect<tcp::TCPDialect>();
    target.addIllegalOp<tcp::BroadcastToOp>();
    OwningRewritePatternList patterns;
    patterns.insert<LowerBroadcastToToLoopsPattern>(context);
    if (failed(applyPartialConversion(func, target, patterns))) {
      return signalPassFailure();
    }
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createLowerBroadcastToToLoopsPass() {
  return std::make_unique<LowerBroadcastToToLoops>();
}

//===----------------------------------------------------------------------===//
// LowerLinalgOnTensorToLinalgOnMemref
//===----------------------------------------------------------------------===//

namespace {
class LowerLinalgGenericTensorToMemRef
    : public OpRewritePattern<linalg::GenericOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {

    // TODO: Replace this with more generic code operating on named
    // structured ops too.

    // Only handle generic ops where all operands and results are tensors.
    if (!llvm::all_of(op.getOperandTypes(), [](Type type) {
          return type.isa<RankedTensorType>();
        })) {
      return rewriter.notifyMatchFailure(op, "all operands must be tensors");
    }
    if (!llvm::all_of(op.getResultTypes(), [](Type type) {
          return type.isa<RankedTensorType>();
        })) {
      return rewriter.notifyMatchFailure(op, "all results must be tensors");
    }

    // TODO: Loosen restrictions on indexing maps.
    // This will require more principled handling of shape reification
    // earlier in the compilation stack, as in general output shapes of a
    // linalg.generic cannot be inferred easily.
    // See:
    // https://llvm.discourse.group/t/computing-output-shapes-of-structured-ops-on-tensors/866
    if (!llvm::all_of(op.indexing_maps(), [](Attribute map) {
          return map.cast<AffineMapAttr>().getValue().isIdentity();
        })) {
      return rewriter.notifyMatchFailure(
          op, "all indexing maps must be identity maps");
    }
    if (!llvm::all_of(op.iterator_types(), [](Attribute str) {
          return str.cast<StringAttr>().getValue() ==
                 getParallelIteratorTypeName();
        })) {
      return rewriter.notifyMatchFailure(
          op, "all iterator types must be 'parallel'");
    }

    SmallVector<Value, 6> memrefs;
    SmallVector<Value, 6> resultMemrefs;
    SmallVector<Value, 6> operandShapes;
    for (auto tensor : op.getOperands()) {
      auto shape = rewriter.create<shape::ShapeOfOp>(op.getLoc(), tensor);
      auto memref = allocMemRefForTensor(rewriter, tensor, shape, op.getLoc());
      rewriter.create<TensorStoreOp>(op.getLoc(), tensor, memref);
      memrefs.push_back(memref);
      operandShapes.push_back(shape);
    }
    auto shapeType = shape::ShapeType::get(rewriter.getContext());
    SmallVector<Type, 6> shapeTypes(op.getNumResults(), shapeType);
    // TODO: We need more principled handling of output shapes.
    // This assumes that all results have the same shape, which is justified
    // by checks above, but we really need a better story here.
    SmallVector<Value, 6> resultShapes(op.getNumResults(), operandShapes[0]);
    for (auto t : llvm::zip(op.getResults(), resultShapes)) {
      auto tensor = std::get<0>(t);
      auto shape = std::get<1>(t);
      auto memref = allocMemRefForTensor(rewriter, tensor, shape, op.getLoc());
      memrefs.push_back(memref);
      resultMemrefs.push_back(memref);
    }
    auto newGeneric = rewriter.create<linalg::GenericOp>(
        op.getLoc(), llvm::None, ValueRange(memrefs), op.getAttrs());
    newGeneric.region().getBlocks().clear();
    BlockAndValueMapping mapper;
    op.region().cloneInto(&newGeneric.region(), mapper);
    for (auto memref : resultMemrefs) {
      newGeneric.region().front().addArgument(
          memref.getType().cast<MemRefType>().getElementType());
    }
    auto newResultTensors =
        llvm::to_vector<6>(llvm::map_range(resultMemrefs, [&](Value memref) {
          return rewriter.create<TensorLoadOp>(op.getLoc(), memref)
              .getResult();
        }));
    rewriter.replaceOp(op, newResultTensors);
    return success();
  }
};
} // namespace

namespace {
class LowerLinalgOnTensorToLinalgOnMemref
    : public LowerLinalgOnTensorToLinalgOnMemrefBase<
          LowerLinalgOnTensorToLinalgOnMemref> {
  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    OwningRewritePatternList patterns;
    ConversionTarget target(*context);
    target.addLegalDialect<shape::ShapeDialect>();
    target.addLegalDialect<StandardOpsDialect>();
    target.addLegalDialect<linalg::LinalgDialect>();
    target.addLegalOp<tcp::AllocMemRefOp>();
    patterns.insert<LowerLinalgGenericTensorToMemRef>(context);
    target.addDynamicallyLegalOp<linalg::GenericOp>([](linalg::GenericOp op) {
      if (llvm::any_of(op.getOperandTypes(), [](Type type) {
            return type.isa<RankedTensorType>();
          })) {
        return false;
      }
      if (llvm::any_of(op.getResultTypes(), [](Type type) {
            return type.isa<RankedTensorType>();
          })) {
        return false;
      }
      return true;
    });

    if (failed(applyPartialConversion(func, target, patterns))) {
      return signalPassFailure();
    }
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createLowerLinalgOnTensorToLinalgOnMemrefPass() {
  return std::make_unique<LowerLinalgOnTensorToLinalgOnMemref>();
}

void mlir::NPCOMP::createLowerToHybridTensorMemRefPipeline(OpPassManager &pm) {
  // Lower to hybrid tensor/memref.
  // The invariant of "hybrid tensor/memref" is that the core computation
  // ops operate on memrefs, but we launder in and out of tensors in such a
  // way that the original SSA tensor values remain and can be traced to
  // their corresponding memrefs (via tensor_load/tensor_store), which are
  // allocated with alloc_memref ops.
  // Thus, shape.shape_of ops on the original tensors in the program can be
  // resolved to the shapes in the alloc_memref calls.
  pm.addPass(createLowerLinalgOnTensorToLinalgOnMemrefPass());
  pm.addPass(createLowerBroadcastToToLoopsPass());
}

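In the "hybrid tensor/memref" form produced by this pipeline, each computation is bracketed by laundering ops, schematically as follows (a sketch matching the CHECK lines of the lower-linalg-tensor-to-memref test below; the linalg.generic body is elided):

func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  %shape = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
  %lhs = "tcp.alloc_memref"(%shape) : (!shape.shape) -> memref<?xf32>
  tensor_store %arg0, %lhs : memref<?xf32>
  %rhs = "tcp.alloc_memref"(%shape) : (!shape.shape) -> memref<?xf32>
  tensor_store %arg1, %rhs : memref<?xf32>
  %dst = "tcp.alloc_memref"(%shape) : (!shape.shape) -> memref<?xf32>
  // ... linalg.generic now operates on %lhs, %rhs, %dst as memrefs ...
  %result = tensor_load %dst : memref<?xf32>
  return %result : tensor<?xf32>
}

Because every tensor is now a tensor_load of an alloc_memref, the subsequent resolve-shape-of-ops pass can replace shape.shape_of with the alloc_memref's shape operand.
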
@ -0,0 +1,23 @@
|
|||
//===- PassDetail.h - E2E Pass class details --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef E2E_PASSDETAIL_H
#define E2E_PASSDETAIL_H

#include "mlir/Pass/Pass.h"

namespace mlir {
namespace NPCOMP {
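
// GEN_PASS_CLASSES makes the tablegen'd Passes.h.inc emit the per-pass
// CRTP base classes declared in E2E/Passes.td, e.g. the
// LowerLinalgOnTensorToLinalgOnMemrefBase used above.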

#define GEN_PASS_CLASSES
#include "npcomp/E2E/Passes.h.inc"

} // namespace NPCOMP
} // end namespace mlir

#endif // E2E_PASSDETAIL_H

@ -0,0 +1,9 @@
// RUN: npcomp-opt <%s -convert-tcf-to-tcp | FileCheck %s --dump-input=fail

// CHECK-LABEL: func @f
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  // Just the lightest sanity check.
  // CHECK: tcp.add
  %0 = "tcf.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}
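
// For reference, the full expansion of tcf.add is expected to look
// roughly like the following (a sketch inferred from the hybrid
// tensor/memref pipeline test in this patch, not verbatim output):
//
//   %0 = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
//   %1 = "shape.shape_of"(%arg1) : (tensor<?xf32>) -> !shape.shape
//   %2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
//   %3 = "shape.abort_if_error"(%2) : (!shape.shape) -> none
//   %4 = "tcp.island"(%3) ({
//     %5 = "tcp.broadcast_to"(%arg0, %2) : (tensor<?xf32>, !shape.shape) -> tensor<?xf32>
//     %6 = "tcp.broadcast_to"(%arg1, %2) : (tensor<?xf32>, !shape.shape) -> tensor<?xf32>
//     %7 = "tcp.add"(%5, %6) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
//     "tcp.yield"(%7) : (tensor<?xf32>) -> ()
//   }) : (none) -> tensor<?xf32>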

@ -0,0 +1,8 @@
// RUN: npcomp-opt <%s -convert-tcp-to-linalg | FileCheck %s --dump-input=fail

// CHECK-LABEL: func @f
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK: linalg.generic
  %0 = "tcp.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

@ -0,0 +1,7 @@
// RUN: npcomp-opt <%s | FileCheck %s --dump-input=fail

func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
  // CHECK: "tcf.add"
  %0 = "tcf.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  return
}

@ -0,0 +1,11 @@
// RUN: npcomp-opt <%s | FileCheck %s --dump-input=fail

func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: i32) {
  // CHECK: tcp.add
  %result = "tcp.island"() ({
    %0 = "tcp.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    "tcp.yield"(%0) : (tensor<?xf32>) -> ()
  }) : () -> tensor<?xf32>

  return
}

@ -0,0 +1,19 @@
// RUN: npcomp-opt <%s -pass-pipeline=e2e-lowering-pipeline | FileCheck %s --dump-input=fail

// CHECK-LABEL: func @rank1
func @rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "tcf.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @rank2
func @rank2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = "tcf.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// CHECK-LABEL: func @rank1and2
func @rank1and2(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = "tcf.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

@ -0,0 +1,15 @@
// RUN: npcomp-opt -lower-linalg-tensor-to-memref <%s | FileCheck %s --dump-input=fail
#map0 = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @f
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK-DAG: %[[LHS:.+]] = "tcp.alloc_memref"
  // CHECK-DAG: %[[RHS:.+]] = "tcp.alloc_memref"
  // CHECK-DAG: %[[DST:.+]] = "tcp.alloc_memref"
  // CHECK: linalg.generic{{.*}} %[[LHS]], %[[RHS]], %[[DST]]
  %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} %arg0, %arg1 {
  ^bb0(%arg2: f32, %arg3: f32):
    %8 = addf %arg2, %arg3 : f32
    linalg.yield %8 : f32
  }: tensor<?xf32>, tensor<?xf32> -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

@ -0,0 +1,10 @@
// RUN: npcomp-opt <%s -pass-pipeline=e2e-lowering-pipeline | FileCheck %s --dump-input=fail

// This is the simplest case, which is easy to stare at for debugging
// purposes.

// CHECK-LABEL: func @rank1
func @rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "tcf.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

@ -0,0 +1,30 @@
// RUN: npcomp-opt -resolve-shape-of-ops <%s -split-input-file -verify-diagnostics | FileCheck %s --dump-input=fail

// CHECK-LABEL: func @basic
func @basic(%arg0: !shape.shape) -> !shape.shape {
  %memref = "tcp.alloc_memref"(%arg0) : (!shape.shape) -> memref<?xf32>
  %tensor = tensor_load %memref : memref<?xf32>
  %shape = "shape.shape_of"(%tensor) : (tensor<?xf32>) -> !shape.shape
  // CHECK: return %arg0
  return %shape : !shape.shape
}

// -----

// CHECK-LABEL: func @arg_unresolved_ok
func @arg_unresolved_ok(%arg0: tensor<?xf32>) -> !shape.shape {
  %0 = "shape.shape_of"(%arg0): (tensor<?xf32>) -> !shape.shape
  return %0 : !shape.shape
}

// -----

// CHECK-LABEL: func @TODO_bb_arg_unresolved_not_ok
// TODO: This should emit a diagnostic, but doesn't. Why?
// addDynamicallyLegalOp isn't working as I expect.
func @TODO_bb_arg_unresolved_not_ok(%arg0: i1, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> !shape.shape {
  cond_br %arg0, ^bb1(%arg1: tensor<?xf32>), ^bb1(%arg2: tensor<?xf32>)
^bb1(%bbarg: tensor<?xf32>):
  %0 = "shape.shape_of"(%bbarg): (tensor<?xf32>) -> !shape.shape
  return %0 : !shape.shape
}

@ -0,0 +1,22 @@
// RUN: npcomp-opt -lower-to-hybrid-tensor-memref-pipeline <%s | FileCheck %s --dump-input=fail

#map0 = affine_map<(d0) -> (d0)>
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
  %1 = "shape.shape_of"(%arg1) : (tensor<?xf32>) -> !shape.shape
  %2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
  %3 = "shape.abort_if_error"(%2) : (!shape.shape) -> none
  %4 = "tcp.island"(%3) ( {
    %5 = "tcp.broadcast_to"(%arg0, %2) : (tensor<?xf32>, !shape.shape) -> tensor<?xf32>
    %6 = "tcp.broadcast_to"(%arg1, %2) : (tensor<?xf32>, !shape.shape) -> tensor<?xf32>
    // CHECK: alloc_memref
    %7 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} %5, %6 {
    ^bb0(%arg2: f32, %arg3: f32): // no predecessors
      %8 = addf %arg2, %arg3 : f32
      linalg.yield %8 : f32
    }: tensor<?xf32>, tensor<?xf32> -> tensor<?xf32>
    "tcp.yield"(%7) : (tensor<?xf32>) -> ()
  }) : (none) -> tensor<?xf32>
  return %4 : tensor<?xf32>
}

@ -4,6 +4,9 @@ set(LIBS
  ${dialect_libs}
  ${conversion_libs}
  MLIROptLib
  NPCOMPE2E
  NPCOMPTCP
  NPCOMPTCF
  NPCOMPNumpyDialect
  NPCOMPBasicpyDialect
)

@ -21,6 +21,12 @@

#include "npcomp/Dialect/Basicpy/BasicpyDialect.h"
#include "npcomp/Dialect/Numpy/NumpyDialect.h"
#include "npcomp/Dialect/TCF/IR/TCFDialect.h"
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"

#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
#include "npcomp/Conversion/TCPToLinalg/TCPToLinalg.h"
#include "npcomp/E2E/E2E.h"

static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
                                                llvm::cl::desc("<input file>"),

@ -58,12 +64,26 @@ static llvm::cl::opt<bool>
                                               llvm::cl::init(false));

int main(int argc, char **argv) {
  // TODO: Move all npcomp registration to a common helper.
  mlir::registerAllDialects();
  mlir::registerAllPasses();

  mlir::registerDialect<mlir::NPCOMP::Basicpy::BasicpyDialect>();
  mlir::registerDialect<mlir::NPCOMP::Numpy::NumpyDialect>();
  // TODO: Register standalone passes here.
  mlir::registerDialect<mlir::NPCOMP::tcf::TCFDialect>();
  mlir::registerDialect<mlir::NPCOMP::tcp::TCPDialect>();

  using mlir::Pass; // The .inc files reference this unqualified.
#define GEN_PASS_REGISTRATION
#include "npcomp/E2E/Passes.h.inc"
  mlir::PassPipelineRegistration<>("e2e-lowering-pipeline",
                                   "E2E lowering pipeline.",
                                   mlir::NPCOMP::createE2ELoweringPipeline);
  mlir::PassPipelineRegistration<>(
      "lower-to-hybrid-tensor-memref-pipeline",
      "Pipeline lowering to hybrid tensor/memref.",
      mlir::NPCOMP::createLowerToHybridTensorMemRefPipeline);
#define GEN_PASS_REGISTRATION
#include "npcomp/Conversion/Passes.h.inc"

  llvm::InitLLVM y(argc, argv);