mirror of https://github.com/llvm/torch-mlir
[RefBackend] Split out RefBackend (refback) dialect from TCP.
This is the first in a patch series that is refactoring the constellation of things variously called or associated with "E2E", "RefE2E", "npcomprt", and "TCP" into a more cleanly layered result. Concretely, this first patch fixes the fact that TCP was basically acting like a dumping ground needed by the reference backend. This splits it out, which is fairly mechanical, but touches a lot of lines of code (basically replacing `tcp` with `refback` and `TCP` with `RefBackend). Now, the RefBackend dialect is that dumping ground, which is slighly better, as it starts allowing TCP to become a nice clean middle layer that is not related per se to the reference backend. The previous name RefE2E or "reference e2e flow" was super confusing. Now that we are seeing more clearly where the "backend" distinction lies, the [RefBackend] commit tag is born :)pull/69/head
parent
3ccc2214a7
commit
5017430dc7
|
@ -2,6 +2,7 @@ add_subdirectory(ATen)
|
|||
add_subdirectory(Basicpy)
|
||||
add_subdirectory(Npcomprt)
|
||||
add_subdirectory(Numpy)
|
||||
add_subdirectory(RefBackend)
|
||||
add_subdirectory(TCF)
|
||||
add_subdirectory(TCP)
|
||||
add_subdirectory(Torch)
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
add_subdirectory(IR)
|
|
@ -0,0 +1 @@
|
|||
add_mlir_dialect(RefBackendOps refback)
|
|
@ -0,0 +1,23 @@
|
|||
//===-------------------------------------------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef REFBACKEND_BASE
|
||||
#define REFBACKEND_BASE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def RefBackend_Dialect : Dialect {
|
||||
let name = "refback";
|
||||
let cppNamespace = "::mlir::NPCOMP::refback";
|
||||
let description = [{
|
||||
Ops used by the reference backend as part of its lowering.
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
#endif // REFBACKEND_BASE
|
|
@ -0,0 +1,16 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_DIALECT_REFBACKEND_IR_REFBACKENDDIALECT_H
|
||||
#define NPCOMP_DIALECT_REFBACKEND_IR_REFBACKENDDIALECT_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendOpsDialect.h.inc"
|
||||
|
||||
#endif // NPCOMP_DIALECT_REFBACKEND_IR_REFBACKENDDIALECT_H
|
|
@ -0,0 +1,22 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_DIALECT_REFBACKEND_IR_REFBACKENDOPS_H
|
||||
#define NPCOMP_DIALECT_REFBACKEND_IR_REFBACKENDOPS_H
|
||||
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h.inc"
|
||||
|
||||
#endif // NPCOMP_DIALECT_REFBACKEND_IR_REFBACKENDOPS_H
|
|
@ -0,0 +1,147 @@
|
|||
//===-------------------------------------------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef REFBACKEND_OPS
|
||||
#define REFBACKEND_OPS
|
||||
|
||||
include "npcomp/Dialect/RefBackend/IR/RefBackendBase.td"
|
||||
include "mlir/Dialect/Shape/IR/ShapeBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
|
||||
class RefBackend_Op<string mnemonic, list<OpTrait> traits = []>
|
||||
: Op<RefBackend_Dialect, mnemonic, traits> {
|
||||
}
|
||||
|
||||
def RefBackend_GlobalOp : RefBackend_Op<"global", [Symbol]> {
|
||||
let summary = "Represents a global variable";
|
||||
let description = [{
|
||||
Represents a global variable.
|
||||
|
||||
Currently, only constant tensors are supported, and they are not
|
||||
considered to be exported.
|
||||
}];
|
||||
let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value);
|
||||
let results = (outs);
|
||||
|
||||
let printer = [{ return ::print$cppClass(p, *this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Ops related to tensor->memref conversion.
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TODO: These ops probably belong in a "RefBackend on memrefs" dialect analogous
|
||||
// to `lmhlo`
|
||||
|
||||
// TODO: Use TypesMatchWith to verify this better.
|
||||
def RefBackend_TensorToMemrefOp : RefBackend_Op<"tensor_to_memref", [NoSideEffect]> {
|
||||
let summary = "Converts a tensor to a memref";
|
||||
let description = [{
|
||||
This op is used to materialize conversions to allow incremental lowering of
|
||||
tensors to memrefs.
|
||||
}];
|
||||
let arguments = (ins AnyRankedTensor:$tensor);
|
||||
let results = (outs AnyMemRef:$memref);
|
||||
let assemblyFormat = "attr-dict $tensor `:` type($tensor) `->` type($memref)";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
// TODO: Use TypesMatchWith to verify this better.
|
||||
def RefBackend_MemrefToTensorOp : RefBackend_Op<"memref_to_tensor", [NoSideEffect]> {
|
||||
let summary = "Converts a memref to a tensor";
|
||||
let description = [{
|
||||
This op is used to materialize conversions to allow incremental lowering of
|
||||
tensors to memrefs.
|
||||
}];
|
||||
let arguments = (ins AnyMemRef:$memref);
|
||||
let results = (outs AnyRankedTensor:$tensor);
|
||||
let assemblyFormat = "attr-dict $memref `:` type($memref) `->` type($tensor)";
|
||||
}
|
||||
|
||||
def RefBackend_AllocMemRefOp : RefBackend_Op<"alloc_memref", []> {
|
||||
let summary = "Allocates a memref of the given shape.";
|
||||
let description = [{
|
||||
Allocates a memref of the given shape.
|
||||
|
||||
This op is a convenience for creating a bunch of
|
||||
shape.get_extent + std.alloc ops.
|
||||
}];
|
||||
let arguments = (ins Shape_ExtentTensorType:$shape);
|
||||
let results = (outs AnyMemRef:$memref);
|
||||
let assemblyFormat = "$shape attr-dict `:` type($memref)";
|
||||
}
|
||||
|
||||
def RefBackend_GetGlobalMemrefOp : RefBackend_Op<"get_global_memref"> {
|
||||
let summary = "Obtain a memref pointing at the given global";
|
||||
let description = [{
|
||||
Obtain a memref pointing at the given global.
|
||||
}];
|
||||
let arguments = (ins FlatSymbolRefAttr:$global);
|
||||
let results = (outs AnyMemRef:$memref);
|
||||
let assemblyFormat = "$global attr-dict `:` type($memref)";
|
||||
let verifier = "return ::verify$cppClass(*this);";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Ops related to shapes.
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TODO: These belong in a shape-related dialect.
|
||||
|
||||
def RefBackend_ShapedResultsOp : RefBackend_Op<"shaped_results", [
|
||||
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
|
||||
SingleBlockImplicitTerminator<"YieldOp">,
|
||||
RecursiveSideEffects,
|
||||
NoRegionArguments
|
||||
]> {
|
||||
let summary = "Result shape annotation";
|
||||
let description = [{
|
||||
Represents a computation whose outputs have a precomputed shape.
|
||||
The i-th result has the shape described by the i-th operand.
|
||||
|
||||
This op is not isolated from above, so if the region needs any inputs,
|
||||
they can simply be captured. Hence, this op is a
|
||||
"this tensor has this shape" annotation with a slightly different set of
|
||||
tradeoffs than the so-called "tie shape" kinds of operations.
|
||||
In particular, this region-based formulation has the opportunity to
|
||||
capture structural invariants.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// sincos is an elementwise operation, so it doesn't change the shape.
|
||||
%x = ...
|
||||
%xShape = ...
|
||||
%sin, %cos = refback.shaped_results %xShape, %xShape {
|
||||
%sin, cos = "some.sincos"(%x)
|
||||
: tensor<?xf32> -> (tensor<?xf32>, tensor<?xf32>)
|
||||
refback.yield %sin, %cos : tensor<?xf32>, tensor<?xf32>
|
||||
}
|
||||
```
|
||||
}];
|
||||
let arguments = (ins
|
||||
Variadic<Shape_ExtentTensorType>:$resultShapes
|
||||
);
|
||||
let results = (outs Variadic<AnyTensor>:$results);
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let printer = [{ return ::print$cppClass(p, *this); }];
|
||||
let verifier = [{ return ::verify$cppClass(*this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
def RefBackend_YieldOp : RefBackend_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
|
||||
ParentOneOf<["ShapedResultsOp"]>]> {
|
||||
let summary = "Yield-like terminator for RefBackend dialect";
|
||||
let description = "See scf.yield";
|
||||
let arguments = (ins Variadic<AnyType>:$operands);
|
||||
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
|
||||
}
|
||||
|
||||
#endif // REFBACKEND_OPS
|
|
@ -105,128 +105,4 @@ It is undefined behavior if such a broadcast is not legal.
|
|||
let assemblyFormat = "$operand `,` $shape attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def TCP_GlobalOp : TCP_Op<"global", [Symbol]> {
|
||||
let summary = "Represents a global variable";
|
||||
let description = [{
|
||||
Represents a global variable.
|
||||
|
||||
Currently, only constant tensors are supported, and they are not
|
||||
considered to be exported.
|
||||
}];
|
||||
let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value);
|
||||
let results = (outs);
|
||||
|
||||
let printer = [{ return ::print$cppClass(p, *this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Ops related to tensor->memref conversion.
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TODO: These ops probably belong in a "TCP on memrefs" dialect analogous
|
||||
// to `lmhlo`
|
||||
|
||||
// TODO: Use TypesMatchWith to verify this better.
|
||||
def TCP_TensorToMemrefOp : TCP_Op<"tensor_to_memref", [NoSideEffect]> {
|
||||
let summary = "Converts a tensor to a memref";
|
||||
let description = [{
|
||||
This op is used to materialize conversions to allow incremental lowering of
|
||||
tensors to memrefs.
|
||||
}];
|
||||
let arguments = (ins AnyRankedTensor:$tensor);
|
||||
let results = (outs AnyMemRef:$memref);
|
||||
let assemblyFormat = "attr-dict $tensor `:` type($tensor) `->` type($memref)";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
// TODO: Use TypesMatchWith to verify this better.
|
||||
def TCP_MemrefToTensorOp : TCP_Op<"memref_to_tensor", [NoSideEffect]> {
|
||||
let summary = "Converts a memref to a tensor";
|
||||
let description = [{
|
||||
This op is used to materialize conversions to allow incremental lowering of
|
||||
tensors to memrefs.
|
||||
}];
|
||||
let arguments = (ins AnyMemRef:$memref);
|
||||
let results = (outs AnyRankedTensor:$tensor);
|
||||
let assemblyFormat = "attr-dict $memref `:` type($memref) `->` type($tensor)";
|
||||
}
|
||||
|
||||
def TCP_AllocMemRefOp : TCP_Op<"alloc_memref", []> {
|
||||
let summary = "Allocates a memref of the given shape.";
|
||||
let description = [{
|
||||
Allocates a memref of the given shape.
|
||||
|
||||
This op is a convenience for creating a bunch of
|
||||
shape.get_extent + std.alloc ops.
|
||||
}];
|
||||
let arguments = (ins Shape_ExtentTensorType:$shape);
|
||||
let results = (outs AnyMemRef:$memref);
|
||||
let assemblyFormat = "$shape attr-dict `:` type($memref)";
|
||||
}
|
||||
|
||||
def TCP_GetGlobalMemrefOp : TCP_Op<"get_global_memref"> {
|
||||
let summary = "Obtain a memref pointing at the given global";
|
||||
let description = [{
|
||||
Obtain a memref pointing at the given global.
|
||||
}];
|
||||
let arguments = (ins FlatSymbolRefAttr:$global);
|
||||
let results = (outs AnyMemRef:$memref);
|
||||
let assemblyFormat = "$global attr-dict `:` type($memref)";
|
||||
let verifier = "return ::verify$cppClass(*this);";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Ops related to shapes.
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TODO: These belong in a shape-related dialect.
|
||||
|
||||
def TCP_ShapedResultsOp : TCP_Op<"shaped_results", [
|
||||
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
|
||||
SingleBlockImplicitTerminator<"YieldOp">,
|
||||
RecursiveSideEffects,
|
||||
NoRegionArguments
|
||||
]> {
|
||||
let summary = "Result shape annotation";
|
||||
let description = [{
|
||||
Represents a computation whose outputs have a precomputed shape.
|
||||
The i-th result has the shape described by the i-th operand.
|
||||
|
||||
This op is not isolated from above, so if the region needs any inputs,
|
||||
they can simply be captured. Hence, this op is a
|
||||
"this tensor has this shape" annotation with a slightly different set of
|
||||
tradeoffs than the so-called "tie shape" kinds of operations.
|
||||
In particular, this region-based formulation has the opportunity to
|
||||
capture structural invariants.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// sincos is an elementwise operation, so it doesn't change the shape.
|
||||
%x = ...
|
||||
%xShape = ...
|
||||
%sin, %cos = tcp.shaped_results %xShape, %xShape {
|
||||
%sin, cos = "some.sincos"(%x)
|
||||
: tensor<?xf32> -> (tensor<?xf32>, tensor<?xf32>)
|
||||
tcp.yield %sin, %cos : tensor<?xf32>, tensor<?xf32>
|
||||
}
|
||||
```
|
||||
}];
|
||||
let arguments = (ins
|
||||
Variadic<Shape_ExtentTensorType>:$resultShapes
|
||||
);
|
||||
let results = (outs Variadic<AnyTensor>:$results);
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let printer = [{ return ::print$cppClass(p, *this); }];
|
||||
let verifier = [{ return ::verify$cppClass(*this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
def TCP_YieldOp : TCP_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
|
||||
ParentOneOf<["ShapedResultsOp"]>]> {
|
||||
let summary = "Yield-like terminator for TCP dialect";
|
||||
let description = "See scf.yield";
|
||||
let arguments = (ins Variadic<AnyType>:$operands);
|
||||
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
|
||||
}
|
||||
|
||||
#endif // TCP_OPS
|
||||
|
|
|
@ -17,7 +17,7 @@ def BypassShapes : Pass<"bypass-shapes", "FuncOp"> {
|
|||
}
|
||||
|
||||
def LowerShapedResultsToMemref : Pass<"lower-shaped-results-to-memref", "FuncOp"> {
|
||||
let summary = "Lower tcp.shaped_results regions";
|
||||
let summary = "Lower refback.shaped_results regions";
|
||||
let constructor = "mlir::NPCOMP::createLowerShapedResultsToMemrefPass()";
|
||||
}
|
||||
|
||||
|
@ -30,7 +30,7 @@ def LowerConstantTensorsToMemref :
|
|||
Pass<"lower-constant-tensors-to-memref", "ModuleOp"> {
|
||||
let summary = "Lower std.constant of tensor type to memref";
|
||||
let description = [{
|
||||
This must be a module pass since it involves creating tcp.global ops.
|
||||
This must be a module pass since it involves creating refback.global ops.
|
||||
}];
|
||||
let constructor = "mlir::NPCOMP::createLowerConstantTensorsToMemrefPass()";
|
||||
}
|
||||
|
|
|
@ -41,6 +41,7 @@ add_mlir_library(NPCOMPInitAll
|
|||
PUBLIC
|
||||
# Local depends
|
||||
NPCOMPE2E
|
||||
NPCOMPRefBackendDialect
|
||||
NPCOMPTCP
|
||||
NPCOMPTCF
|
||||
NPCOMPTorchDialect
|
||||
|
|
|
@ -2,6 +2,7 @@ add_subdirectory(ATen)
|
|||
add_subdirectory(Basicpy)
|
||||
add_subdirectory(Npcomprt)
|
||||
add_subdirectory(Numpy)
|
||||
add_subdirectory(RefBackend)
|
||||
add_subdirectory(TCF)
|
||||
add_subdirectory(TCP)
|
||||
add_subdirectory(Torch)
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
add_subdirectory(IR)
|
|
@ -0,0 +1,19 @@
|
|||
add_mlir_dialect_library(NPCOMPRefBackendDialect
|
||||
RefBackendDialect.cpp
|
||||
RefBackendOps.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/RefBackend
|
||||
|
||||
DEPENDS
|
||||
MLIRRefBackendOpsIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRSupport
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRShape
|
||||
)
|
|
@ -0,0 +1,50 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendDialect.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP::refback;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RefBackendDialect Dialect Interfaces
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
struct RefBackendInlinerInterface : public DialectInlinerInterface {
|
||||
using DialectInlinerInterface::DialectInlinerInterface;
|
||||
bool isLegalToInline(Region *dest, Region *src,
|
||||
BlockAndValueMapping &valueMapping) const final {
|
||||
return true;
|
||||
}
|
||||
bool isLegalToInline(Operation *, Region *,
|
||||
BlockAndValueMapping &) const final {
|
||||
return true;
|
||||
}
|
||||
void handleTerminator(Operation *op,
|
||||
ArrayRef<Value> valuesToRepl) const final {
|
||||
auto retValOp = dyn_cast<YieldOp>(op);
|
||||
if (!retValOp)
|
||||
return;
|
||||
|
||||
for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
|
||||
std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
|
||||
}
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void RefBackendDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.cpp.inc"
|
||||
>();
|
||||
addInterfaces<RefBackendInlinerInterface>();
|
||||
}
|
|
@ -0,0 +1,126 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::refback;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorToMemrefOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult TensorToMemrefOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (auto memrefToTensor = tensor().getDefiningOp<refback::MemrefToTensorOp>())
|
||||
return memrefToTensor.memref();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShapedResultsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verifyShapedResultsOp(ShapedResultsOp op) {
|
||||
if (op.getNumOperands() != op.getNumResults())
|
||||
return op.emitError() << "number of operands must equal number of results";
|
||||
if (op.getNumOperands() == 0)
|
||||
return op.emitError() << "must have at least one operand/result";
|
||||
return RegionBranchOpInterface::verifyTypes(op);
|
||||
}
|
||||
|
||||
static void printShapedResultsOp(OpAsmPrinter &p, ShapedResultsOp &op) {
|
||||
p << "refback.shaped_results ";
|
||||
p.printOptionalAttrDictWithKeyword(op.getAttrs());
|
||||
p.printOperands(op.getOperands());
|
||||
p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
|
||||
p << " : ";
|
||||
interleaveComma(op.getOperandTypes(), p);
|
||||
p << " -> ";
|
||||
interleaveComma(op.getResultTypes(), p);
|
||||
}
|
||||
|
||||
static ParseResult parseShapedResultsOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
|
||||
return failure();
|
||||
SmallVector<OpAsmParser::OperandType, 6> operands;
|
||||
if (parser.parseOperandList(operands))
|
||||
return failure();
|
||||
auto *body = result.addRegion();
|
||||
if (parser.parseRegion(*body, llvm::None, llvm::None))
|
||||
return failure();
|
||||
SmallVector<Type, 6> inputTypes;
|
||||
if (parser.parseColonTypeList(inputTypes))
|
||||
return failure();
|
||||
if (parser.resolveOperands(operands, inputTypes, parser.getNameLoc(),
|
||||
result.operands))
|
||||
return failure();
|
||||
if (parser.parseArrowTypeList(result.types))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
void ShapedResultsOp::getSuccessorRegions(
|
||||
Optional<unsigned> index, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
if (index.hasValue())
|
||||
regions.push_back(RegionSuccessor(getResults()));
|
||||
else
|
||||
regions.push_back(RegionSuccessor(&body()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GlobalOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) {
|
||||
p << "refback.global ";
|
||||
p.printSymbolName(op.sym_name());
|
||||
p << ' ';
|
||||
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
|
||||
/*elidedAttrs=*/{"sym_name", "value"});
|
||||
p.printAttribute(op.valueAttr());
|
||||
}
|
||||
|
||||
static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
|
||||
StringAttr nameAttr;
|
||||
if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
|
||||
result.attributes))
|
||||
return failure();
|
||||
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
|
||||
return failure();
|
||||
Attribute valueAttr;
|
||||
if (parser.parseAttribute(valueAttr, "value", result.attributes))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GetGlobalMemrefOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verifyGetGlobalMemrefOp(GetGlobalMemrefOp op) {
|
||||
auto global = SymbolTable::lookupNearestSymbolFrom<GlobalOp>(op, op.global());
|
||||
if (!global)
|
||||
return op.emitError() << "must reference a valid symbol";
|
||||
auto globalType = global.value().getType();
|
||||
auto resultType = op.getType().cast<ShapedType>();
|
||||
if (failed(
|
||||
verifyCompatibleShape(globalType.getShape(), resultType.getShape())))
|
||||
return op.emitError() << "inconsistent with shape of global";
|
||||
if (globalType.getElementType() != resultType.getElementType())
|
||||
return op.emitError() << "inconsistent with element type of global";
|
||||
return success();
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.cpp.inc"
|
|
@ -28,16 +28,6 @@ struct TCPInlinerInterface : public DialectInlinerInterface {
|
|||
BlockAndValueMapping &) const final {
|
||||
return true;
|
||||
}
|
||||
void handleTerminator(Operation *op,
|
||||
ArrayRef<Value> valuesToRepl) const final {
|
||||
auto retValOp = dyn_cast<YieldOp>(op);
|
||||
if (!retValOp)
|
||||
return;
|
||||
|
||||
for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
|
||||
std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
|
||||
}
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
|
|
@ -15,112 +15,5 @@ using namespace mlir;
|
|||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::tcp;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorToMemrefOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult TensorToMemrefOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (auto memrefToTensor = tensor().getDefiningOp<tcp::MemrefToTensorOp>())
|
||||
return memrefToTensor.memref();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShapedResultsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verifyShapedResultsOp(ShapedResultsOp op) {
|
||||
if (op.getNumOperands() != op.getNumResults())
|
||||
return op.emitError() << "number of operands must equal number of results";
|
||||
if (op.getNumOperands() == 0)
|
||||
return op.emitError() << "must have at least one operand/result";
|
||||
return RegionBranchOpInterface::verifyTypes(op);
|
||||
}
|
||||
|
||||
static void printShapedResultsOp(OpAsmPrinter &p, ShapedResultsOp &op) {
|
||||
p << "tcp.shaped_results ";
|
||||
p.printOptionalAttrDictWithKeyword(op.getAttrs());
|
||||
p.printOperands(op.getOperands());
|
||||
p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
|
||||
p << " : ";
|
||||
interleaveComma(op.getOperandTypes(), p);
|
||||
p << " -> ";
|
||||
interleaveComma(op.getResultTypes(), p);
|
||||
}
|
||||
|
||||
static ParseResult parseShapedResultsOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
|
||||
return failure();
|
||||
SmallVector<OpAsmParser::OperandType, 6> operands;
|
||||
if (parser.parseOperandList(operands))
|
||||
return failure();
|
||||
auto *body = result.addRegion();
|
||||
if (parser.parseRegion(*body, llvm::None, llvm::None))
|
||||
return failure();
|
||||
SmallVector<Type, 6> inputTypes;
|
||||
if (parser.parseColonTypeList(inputTypes))
|
||||
return failure();
|
||||
if (parser.resolveOperands(operands, inputTypes, parser.getNameLoc(),
|
||||
result.operands))
|
||||
return failure();
|
||||
if (parser.parseArrowTypeList(result.types))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
void ShapedResultsOp::getSuccessorRegions(
|
||||
Optional<unsigned> index, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
if (index.hasValue())
|
||||
regions.push_back(RegionSuccessor(getResults()));
|
||||
else
|
||||
regions.push_back(RegionSuccessor(&body()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GlobalOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) {
|
||||
p << "tcp.global ";
|
||||
p.printSymbolName(op.sym_name());
|
||||
p << ' ';
|
||||
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
|
||||
/*elidedAttrs=*/{"sym_name", "value"});
|
||||
p.printAttribute(op.valueAttr());
|
||||
}
|
||||
|
||||
static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
|
||||
StringAttr nameAttr;
|
||||
if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
|
||||
result.attributes))
|
||||
return failure();
|
||||
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
|
||||
return failure();
|
||||
Attribute valueAttr;
|
||||
if (parser.parseAttribute(valueAttr, "value", result.attributes))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GetGlobalMemrefOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verifyGetGlobalMemrefOp(GetGlobalMemrefOp op) {
|
||||
auto global = SymbolTable::lookupNearestSymbolFrom<GlobalOp>(op, op.global());
|
||||
if (!global)
|
||||
return op.emitError() << "must reference a valid symbol";
|
||||
auto globalType = global.value().getType();
|
||||
auto resultType = op.getType().cast<ShapedType>();
|
||||
if (failed(
|
||||
verifyCompatibleShape(globalType.getShape(), resultType.getShape())))
|
||||
return op.emitError() << "inconsistent with shape of global";
|
||||
if (globalType.getElementType() != resultType.getElementType())
|
||||
return op.emitError() << "inconsistent with element type of global";
|
||||
return success();
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/TCP/IR/TCPOps.cpp.inc"
|
||||
|
|
|
@ -10,6 +10,8 @@
|
|||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendDialect.h"
|
||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
|
||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||
#include "npcomp/E2E/E2E.h"
|
||||
|
||||
|
@ -43,11 +45,11 @@ static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
|
|||
|
||||
namespace {
|
||||
// TODO: There is a coupling between this pass and LowerShapedResults.
|
||||
// Any op that is wrapped in tcp.shaped_results here needs to be known how to be
|
||||
// lowered by LowerShapedResults.
|
||||
// Any op that is wrapped in refback.shaped_results here needs to be known how
|
||||
// to be lowered by LowerShapedResults.
|
||||
class BypassShapes : public BypassShapesBase<BypassShapes> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<shape::ShapeDialect>();
|
||||
registry.insert<shape::ShapeDialect, refback::RefBackendDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
|
@ -57,16 +59,16 @@ class BypassShapes : public BypassShapesBase<BypassShapes> {
|
|||
SmallVector<Value, 6> resultShapes = bypassResultShapes(op);
|
||||
if (resultShapes.empty())
|
||||
return;
|
||||
// We have result shapes, so wrap this op in a tcp.shaped_results op.
|
||||
// We have result shapes, so wrap this op in a refback.shaped_results op.
|
||||
OpBuilder builder(&op);
|
||||
auto shapedResults = builder.create<tcp::ShapedResultsOp>(
|
||||
auto shapedResults = builder.create<refback::ShapedResultsOp>(
|
||||
op.getLoc(), op.getResultTypes(), resultShapes);
|
||||
op.replaceAllUsesWith(shapedResults);
|
||||
|
||||
// Move the op into the body and yield the results.
|
||||
Block *body = builder.createBlock(&shapedResults.body());
|
||||
op.moveBefore(body, body->end());
|
||||
builder.create<tcp::YieldOp>(op.getLoc(), op.getResults());
|
||||
builder.create<refback::YieldOp>(op.getLoc(), op.getResults());
|
||||
});
|
||||
}
|
||||
};
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
|
||||
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||
|
||||
|
@ -55,10 +56,10 @@ void mlir::NPCOMP::registerE2EPasses() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class LowerAllocMemRefOp : public OpRewritePattern<tcp::AllocMemRefOp> {
|
||||
class LowerAllocMemRefOp : public OpRewritePattern<refback::AllocMemRefOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(tcp::AllocMemRefOp op,
|
||||
LogicalResult matchAndRewrite(refback::AllocMemRefOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto memrefType = op.getType().cast<MemRefType>();
|
||||
auto shape = op.getOperand();
|
||||
|
@ -91,7 +92,7 @@ class LowerAllocMemRefOps
|
|||
OwningRewritePatternList patterns;
|
||||
patterns.insert<LowerAllocMemRefOp>(context);
|
||||
ConversionTarget target(*context);
|
||||
target.addIllegalOp<tcp::AllocMemRefOp>();
|
||||
target.addIllegalOp<refback::AllocMemRefOp>();
|
||||
target.addLegalOp<shape::GetExtentOp>();
|
||||
target.addLegalOp<AllocOp>();
|
||||
target.addLegalOp<ConstantOp>();
|
||||
|
@ -177,7 +178,7 @@ void mlir::NPCOMP::createE2ELoweringPipeline(
|
|||
pm.addPass(createConvertTCFToTCPPass());
|
||||
|
||||
// For operations with a shape transfer function, explicitly bypass their
|
||||
// shape computations with tcp.shaped_results ops.
|
||||
// shape computations with refback.shaped_results ops.
|
||||
//
|
||||
// Right now, our lowering flow depends heavily on descriptors, so technically
|
||||
// we don't need to bypass shapes -- we can just splat out the shape
|
||||
|
@ -220,9 +221,9 @@ void mlir::NPCOMP::createE2ELoweringPipeline(
|
|||
// rather than a single mega dialect conversion pass.
|
||||
//
|
||||
// This means that intermediate steps have source/target materializations
|
||||
// (tcp.memref_to_tensor / tcp.tensor_to_memref) in the IR.
|
||||
// (refback.memref_to_tensor / refback.tensor_to_memref) in the IR.
|
||||
|
||||
// Lower ops enclosed in tcp.shaped_results regions.
|
||||
// Lower ops enclosed in refback.shaped_results regions.
|
||||
// For now, this is covering the "tensor compute" ops like tcp.add /
|
||||
// tcp.broadcast_to (the former being handled via a special subset of
|
||||
// linalg.generic) -- we only handle those two, so having an isolated pass
|
||||
|
@ -230,10 +231,10 @@ void mlir::NPCOMP::createE2ELoweringPipeline(
|
|||
// more pluggable. The exact interface for this pluggability depends on
|
||||
// what design we want to settle on for bypassing shape computations.
|
||||
pm.addPass(createLowerShapedResultsToMemrefPass());
|
||||
// Lower tensor-valued constants to tcp.global.
|
||||
// Lower tensor-valued constants to refback.global.
|
||||
pm.addPass(createLowerConstantTensorsToMemrefPass());
|
||||
// tcp::AllocMemRefOp takes a shape (i.e. extent tensor) as an argument. We
|
||||
// need to resolve this to std.alloc which takes individual extents.
|
||||
// refback::AllocMemRefOp takes a shape (i.e. extent tensor) as an argument.
|
||||
// We need to resolve this to std.alloc which takes individual extents.
|
||||
pm.addPass(createLowerAllocMemRefOpsPass());
|
||||
// Lower shape ops to std.
|
||||
// TODO: This should in principle be moved before tensor->memref conversion.
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
|
||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
|
||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
@ -75,11 +75,11 @@ static LogicalResult createModuleMetadata(ModuleOp module) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class LowerGlobalOp : public OpConversionPattern<tcp::GlobalOp> {
|
||||
class LowerGlobalOp : public OpConversionPattern<refback::GlobalOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(tcp::GlobalOp op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(refback::GlobalOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<npcomprt::GlobalOp>(op, op.sym_name(),
|
||||
op.value());
|
||||
|
@ -90,11 +90,11 @@ public:
|
|||
|
||||
namespace {
|
||||
class LowerGetGlobalMemrefOp
|
||||
: public OpConversionPattern<tcp::GetGlobalMemrefOp> {
|
||||
: public OpConversionPattern<refback::GetGlobalMemrefOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(tcp::GetGlobalMemrefOp op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(refback::GetGlobalMemrefOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto abiMemref = rewriter.create<npcomprt::GetGlobalOp>(
|
||||
op.getLoc(), getABIMemrefType(op.getType()), op.global());
|
||||
|
@ -217,10 +217,10 @@ static LogicalResult doDialectConversion(ModuleOp module) {
|
|||
[&](ReturnOp op) { return typeConverter.isLegal(op); });
|
||||
|
||||
patterns.insert<LowerGlobalOp>(context);
|
||||
target.addIllegalOp<tcp::GlobalOp>();
|
||||
target.addIllegalOp<refback::GlobalOp>();
|
||||
|
||||
patterns.insert<LowerGetGlobalMemrefOp>(context);
|
||||
target.addIllegalOp<tcp::GetGlobalMemrefOp>();
|
||||
target.addIllegalOp<refback::GetGlobalMemrefOp>();
|
||||
|
||||
patterns.insert<LowerAssertOp>(context);
|
||||
target.addIllegalOp<AssertOp>();
|
||||
|
|
|
@ -17,9 +17,8 @@
|
|||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
||||
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendDialect.h"
|
||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
@ -35,13 +34,13 @@ namespace {
|
|||
class GlobalCreator {
|
||||
public:
|
||||
explicit GlobalCreator(ModuleOp module);
|
||||
tcp::GlobalOp getGlobalFor(Attribute attr) {
|
||||
refback::GlobalOp getGlobalFor(Attribute attr) {
|
||||
assert(globals.find(attr) != globals.end() && "unknown constant attr");
|
||||
return globals[attr];
|
||||
}
|
||||
|
||||
private:
|
||||
DenseMap<Attribute, tcp::GlobalOp> globals;
|
||||
DenseMap<Attribute, refback::GlobalOp> globals;
|
||||
};
|
||||
|
||||
GlobalCreator::GlobalCreator(ModuleOp module) {
|
||||
|
@ -66,7 +65,7 @@ GlobalCreator::GlobalCreator(ModuleOp module) {
|
|||
interleave(type.getShape(), os, "x");
|
||||
os << "x" << type.getElementType();
|
||||
|
||||
auto global = globalBuilder.create<tcp::GlobalOp>(
|
||||
auto global = globalBuilder.create<refback::GlobalOp>(
|
||||
op.getLoc(), (Twine("__constant_") + os.str()).str(),
|
||||
op.getValue().cast<ElementsAttr>());
|
||||
symbolTable.insert(global);
|
||||
|
@ -82,7 +81,7 @@ namespace {
|
|||
class LowerConstantTensorsToMemref
|
||||
: public LowerConstantTensorsToMemrefBase<LowerConstantTensorsToMemref> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<tcp::TCPDialect>();
|
||||
registry.insert<refback::RefBackendDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
|
@ -98,10 +97,10 @@ class LowerConstantTensorsToMemref
|
|||
auto global = globals.getGlobalFor(op.getValue());
|
||||
OpBuilder builder(op);
|
||||
auto memrefType = MemRefType::get(type.getShape(), type.getElementType());
|
||||
auto memref = builder.create<tcp::GetGlobalMemrefOp>(
|
||||
auto memref = builder.create<refback::GetGlobalMemrefOp>(
|
||||
op.getLoc(), memrefType, global.getName());
|
||||
Value tensor =
|
||||
builder.create<tcp::MemrefToTensorOp>(op.getLoc(), type, memref);
|
||||
builder.create<refback::MemrefToTensorOp>(op.getLoc(), type, memref);
|
||||
op.replaceAllUsesWith(tensor);
|
||||
op.erase();
|
||||
});
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
|
||||
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||
|
||||
|
@ -30,13 +31,13 @@ allocateResults(Operation *op, ConversionPatternRewriter &rewriter,
|
|||
Location loc,
|
||||
SmallVectorImpl<Value> *resultShapesOut = nullptr) {
|
||||
// TODO: This is really fragile. Can we have a better story?
|
||||
auto shapedResults = dyn_cast<tcp::ShapedResultsOp>(op->getParentOp());
|
||||
auto shapedResults = dyn_cast<refback::ShapedResultsOp>(op->getParentOp());
|
||||
if (!shapedResults)
|
||||
return rewriter.notifyMatchFailure(op, "parent not tcp.shaped_results");
|
||||
return rewriter.notifyMatchFailure(op, "parent not refback.shaped_results");
|
||||
if (op->getResults() !=
|
||||
shapedResults.getBody()->getTerminator()->getOperands())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only limited forms of tcp.shaped_results allowed");
|
||||
op, "only limited forms of refback.shaped_results allowed");
|
||||
auto resultShapes = shapedResults.resultShapes();
|
||||
SmallVector<Value, 6> results;
|
||||
for (auto t : llvm::zip(op->getResults(), resultShapes)) {
|
||||
|
@ -46,7 +47,7 @@ allocateResults(Operation *op, ConversionPatternRewriter &rewriter,
|
|||
auto memrefType =
|
||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||
auto memref =
|
||||
rewriter.create<tcp::AllocMemRefOp>(loc, memrefType, resultShape);
|
||||
rewriter.create<refback::AllocMemRefOp>(loc, memrefType, resultShape);
|
||||
results.push_back(memref);
|
||||
}
|
||||
if (resultShapesOut)
|
||||
|
@ -251,22 +252,22 @@ public:
|
|||
|
||||
namespace {
|
||||
// This pass is responsible for lowering regions wrapped by
|
||||
// tcp.shaped_results (which operate on tensors) to memrefs.
|
||||
// refback.shaped_results (which operate on tensors) to memrefs.
|
||||
// This includes any ops potentially contained within them.
|
||||
// This is somewhat analogous to IREE's backend compilation of a single dispatch
|
||||
// region, except that for now, we only allow a single op in the
|
||||
// tcp.shaped_results, and we don't have any notion of "backend" layered at all.
|
||||
// Nor is it clear if we really want any of that here.
|
||||
// refback.shaped_results, and we don't have any notion of "backend" layered at
|
||||
// all. Nor is it clear if we really want any of that here.
|
||||
//
|
||||
// The tcp.shaped_results ops provide precisely the information needed to
|
||||
// The refback.shaped_results ops provide precisely the information needed to
|
||||
// allocate output buffers when converting to memref.
|
||||
// For now, this process eliminates the original tcp.shaped_results op since we
|
||||
// don't have any host/device distinction or other structure that would require
|
||||
// retaining that sort of IR structure.
|
||||
// For now, this process eliminates the original refback.shaped_results op since
|
||||
// we don't have any host/device distinction or other structure that would
|
||||
// require retaining that sort of IR structure.
|
||||
//
|
||||
// TODO: Do "shape_of" resolution while still on tensors.
|
||||
// Here we spew out tons of shape_of and rely on dim ops on descriptors to make
|
||||
// it work. The key difference is that we need tcp.shaped_results (or its
|
||||
// it work. The key difference is that we need refback.shaped_results (or its
|
||||
// successor / something it gets lowered to) to not be IsolatedFromAbove, and
|
||||
// explicitly capture all input tensors along with their shapes. That allows
|
||||
// shape_of ops on inputs to be trivially resolved. Unfortunately, this opens up
|
||||
|
@ -300,14 +301,16 @@ class LowerShapedResultsToMemref
|
|||
ValueRange inputs, Location loc) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(inputs[0].getType().isa<MemRefType>());
|
||||
return (Value)builder.create<tcp::MemrefToTensorOp>(loc, type, inputs[0]);
|
||||
return (Value)builder.create<refback::MemrefToTensorOp>(loc, type,
|
||||
inputs[0]);
|
||||
});
|
||||
typeConverter.addTargetMaterialization([](OpBuilder &builder,
|
||||
MemRefType type,
|
||||
ValueRange inputs, Location loc) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(inputs[0].getType().isa<RankedTensorType>());
|
||||
return (Value)builder.create<tcp::TensorToMemrefOp>(loc, type, inputs[0]);
|
||||
return (Value)builder.create<refback::TensorToMemrefOp>(loc, type,
|
||||
inputs[0]);
|
||||
});
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
|
@ -316,14 +319,14 @@ class LowerShapedResultsToMemref
|
|||
|
||||
// The shaped results ops themselves. They have to be legal since we delete
|
||||
// them later after the conversion process.
|
||||
target.addLegalOp<tcp::ShapedResultsOp>();
|
||||
target.addLegalOp<tcp::YieldOp>();
|
||||
// All lowering to buffers involves tcp.alloc_memref ops.
|
||||
target.addLegalOp<tcp::AllocMemRefOp>();
|
||||
target.addLegalOp<refback::ShapedResultsOp>();
|
||||
target.addLegalOp<refback::YieldOp>();
|
||||
// All lowering to buffers involves refback.alloc_memref ops.
|
||||
target.addLegalOp<refback::AllocMemRefOp>();
|
||||
// The casting ops are introduced by the type converter, so we should mark
|
||||
// them legal.
|
||||
target.addLegalOp<tcp::MemrefToTensorOp>();
|
||||
target.addLegalOp<tcp::TensorToMemrefOp>();
|
||||
target.addLegalOp<refback::MemrefToTensorOp>();
|
||||
target.addLegalOp<refback::TensorToMemrefOp>();
|
||||
|
||||
patterns.insert<LowerBroadcastToToLoopsPattern>(typeConverter, context);
|
||||
target.addIllegalOp<tcp::BroadcastToOp>();
|
||||
|
@ -341,18 +344,19 @@ class LowerShapedResultsToMemref
|
|||
target.addLegalOp<shape::GetExtentOp>();
|
||||
|
||||
SmallVector<Operation *, 6> shapedResultsOps;
|
||||
func.walk([&](tcp::ShapedResultsOp op) { shapedResultsOps.push_back(op); });
|
||||
func.walk(
|
||||
[&](refback::ShapedResultsOp op) { shapedResultsOps.push_back(op); });
|
||||
|
||||
if (failed(applyFullConversion(shapedResultsOps, target, patterns)))
|
||||
return signalPassFailure();
|
||||
|
||||
// Now inline the tcp.shaped_results ops.
|
||||
// Now inline the refback.shaped_results ops.
|
||||
// This can't be done as part of the conversion since conversion visits
|
||||
// ops in preorder, and we need the tcp.shaped_results ops to be present
|
||||
// ops in preorder, and we need the refback.shaped_results ops to be present
|
||||
// so that inner ops can get their shape.
|
||||
LocallyOverrideLegalityInlinerInterface interface(context);
|
||||
for (Operation *shapedResultsOp : shapedResultsOps) {
|
||||
auto op = cast<tcp::ShapedResultsOp>(shapedResultsOp);
|
||||
auto op = cast<refback::ShapedResultsOp>(shapedResultsOp);
|
||||
if (failed(inlineRegion(interface, &op.body(), op, ValueRange({}),
|
||||
op.getResults(), /*inlineLoc=*/llvm::None,
|
||||
/*shouldCloneInlinedRegion=*/false))) {
|
||||
|
|
|
@ -13,8 +13,8 @@
|
|||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendDialect.h"
|
||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
@ -88,7 +88,7 @@ namespace {
|
|||
// TODO: Upstream this.
|
||||
class LowerStdToMemref : public LowerStdToMemrefBase<LowerStdToMemref> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<tcp::TCPDialect>();
|
||||
registry.insert<refback::RefBackendDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
|
@ -105,14 +105,16 @@ class LowerStdToMemref : public LowerStdToMemrefBase<LowerStdToMemref> {
|
|||
ValueRange inputs, Location loc) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(inputs[0].getType().isa<MemRefType>());
|
||||
return (Value)builder.create<tcp::MemrefToTensorOp>(loc, type, inputs[0]);
|
||||
return (Value)builder.create<refback::MemrefToTensorOp>(loc, type,
|
||||
inputs[0]);
|
||||
});
|
||||
typeConverter.addTargetMaterialization([](OpBuilder &builder,
|
||||
MemRefType type,
|
||||
ValueRange inputs, Location loc) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(inputs[0].getType().isa<RankedTensorType>());
|
||||
return (Value)builder.create<tcp::TensorToMemrefOp>(loc, type, inputs[0]);
|
||||
return (Value)builder.create<refback::TensorToMemrefOp>(loc, type,
|
||||
inputs[0]);
|
||||
});
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
|
@ -123,8 +125,8 @@ class LowerStdToMemref : public LowerStdToMemrefBase<LowerStdToMemref> {
|
|||
|
||||
// The casting ops are introduced by the type converter, so they must be
|
||||
// legal.
|
||||
target.addLegalOp<tcp::MemrefToTensorOp>();
|
||||
target.addLegalOp<tcp::TensorToMemrefOp>();
|
||||
target.addLegalOp<refback::MemrefToTensorOp>();
|
||||
target.addLegalOp<refback::TensorToMemrefOp>();
|
||||
|
||||
patterns.insert<LowerExtractElementOp>(typeConverter, context);
|
||||
target.addIllegalOp<ExtractElementOp>();
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
@ -99,13 +99,13 @@ public:
|
|||
|
||||
namespace {
|
||||
class LowerTensorToMemrefOp
|
||||
: public OpConversionPattern<tcp::TensorToMemrefOp> {
|
||||
: public OpConversionPattern<refback::TensorToMemrefOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(tcp::TensorToMemrefOp op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(refback::TensorToMemrefOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
tcp::TensorToMemrefOp::Adaptor adaptor(operands);
|
||||
refback::TensorToMemrefOp::Adaptor adaptor(operands);
|
||||
rewriter.replaceOp(op, adaptor.tensor());
|
||||
return success();
|
||||
}
|
||||
|
@ -114,13 +114,13 @@ public:
|
|||
|
||||
namespace {
|
||||
class LowerMemrefToTensorOp
|
||||
: public OpConversionPattern<tcp::MemrefToTensorOp> {
|
||||
: public OpConversionPattern<refback::MemrefToTensorOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(tcp::MemrefToTensorOp op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(refback::MemrefToTensorOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
tcp::MemrefToTensorOp::Adaptor adaptor(operands);
|
||||
refback::MemrefToTensorOp::Adaptor adaptor(operands);
|
||||
rewriter.replaceOp(op, op.memref());
|
||||
return success();
|
||||
}
|
||||
|
@ -150,14 +150,16 @@ class LowerStructuralToMemref
|
|||
ValueRange inputs, Location loc) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(inputs[0].getType().isa<MemRefType>());
|
||||
return (Value)builder.create<tcp::MemrefToTensorOp>(loc, type, inputs[0]);
|
||||
return (Value)builder.create<refback::MemrefToTensorOp>(loc, type,
|
||||
inputs[0]);
|
||||
});
|
||||
typeConverter.addTargetMaterialization([](OpBuilder &builder,
|
||||
MemRefType type,
|
||||
ValueRange inputs, Location loc) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(inputs[0].getType().isa<RankedTensorType>());
|
||||
return (Value)builder.create<tcp::TensorToMemrefOp>(loc, type, inputs[0]);
|
||||
return (Value)builder.create<refback::TensorToMemrefOp>(loc, type,
|
||||
inputs[0]);
|
||||
});
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
|
@ -181,7 +183,7 @@ class LowerStructuralToMemref
|
|||
patterns.insert<LowerForOpTypes>(typeConverter, context);
|
||||
patterns.insert<LowerTensorToMemrefOp>(typeConverter, context);
|
||||
patterns.insert<LowerMemrefToTensorOp>(typeConverter, context);
|
||||
target.addIllegalOp<tcp::TensorToMemrefOp>();
|
||||
target.addIllegalOp<refback::TensorToMemrefOp>();
|
||||
|
||||
if (failed(applyFullConversion(func, target, patterns)))
|
||||
return signalPassFailure();
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/Transforms/Passes.h"
|
||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendDialect.h"
|
||||
#include "npcomp/Dialect/TCF/IR/TCFDialect.h"
|
||||
#include "npcomp/Dialect/TCF/Transforms/Passes.h"
|
||||
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
||||
|
@ -73,6 +74,7 @@ void mlir::NPCOMP::registerAllDialects(mlir::DialectRegistry ®istry) {
|
|||
Basicpy::BasicpyDialect,
|
||||
Numpy::NumpyDialect,
|
||||
npcomprt::NpcomprtDialect,
|
||||
refback::RefBackendDialect,
|
||||
tcf::TCFDialect,
|
||||
tcp::TCPDialect,
|
||||
mlir::NPCOMP::Torch::TorchDialect>();
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
// CHECK-LABEL: func @tensor_to_memref
|
||||
func @tensor_to_memref_fold(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||
// CHECK-NEXT: return %arg0 : memref<?xf32>
|
||||
%0 = tcp.memref_to_tensor %arg0 : memref<?xf32> -> tensor<?xf32>
|
||||
%1 = tcp.tensor_to_memref %0 : tensor<?xf32> -> memref<?xf32>
|
||||
%0 = refback.memref_to_tensor %arg0 : memref<?xf32> -> tensor<?xf32>
|
||||
%1 = refback.tensor_to_memref %0 : tensor<?xf32> -> memref<?xf32>
|
||||
return %1 : memref<?xf32>
|
||||
}
|
|
@ -2,31 +2,31 @@
|
|||
|
||||
// -----
|
||||
|
||||
tcp.global @g dense<0.0> : tensor<2xf32>
|
||||
refback.global @g dense<0.0> : tensor<2xf32>
|
||||
|
||||
func @f() {
|
||||
// expected-error @+1 {{must reference a valid symbol}}
|
||||
tcp.get_global_memref @nonexistent_symbol : memref<3xf32>
|
||||
refback.get_global_memref @nonexistent_symbol : memref<3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
tcp.global @g dense<0.0> : tensor<2xf32>
|
||||
refback.global @g dense<0.0> : tensor<2xf32>
|
||||
|
||||
func @f() {
|
||||
// expected-error @+1 {{inconsistent with shape of global}}
|
||||
tcp.get_global_memref @g : memref<3xf32>
|
||||
refback.get_global_memref @g : memref<3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
tcp.global @g dense<0.0> : tensor<2xf32>
|
||||
refback.global @g dense<0.0> : tensor<2xf32>
|
||||
|
||||
func @f() {
|
||||
// expected-error @+1 {{inconsistent with element type of global}}
|
||||
tcp.get_global_memref @g : memref<2xi8>
|
||||
refback.get_global_memref @g : memref<2xi8>
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -34,9 +34,9 @@ func @f() {
|
|||
|
||||
func @g(%arg0: tensor<?x?xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{number of operands must equal number of results}}
|
||||
%add = tcp.shaped_results %arg1, %arg1 {
|
||||
%add = refback.shaped_results %arg1, %arg1 {
|
||||
%0 = tcp.add %arg0, %arg0 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
tcp.yield %0 : tensor<?x?xf32>
|
||||
refback.yield %0 : tensor<?x?xf32>
|
||||
} : tensor<?xindex>, tensor<?xindex> -> tensor<?x?xf32>
|
||||
return %add : tensor<?x?xf32>
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: refback.global @foo dense<0.0{{.*}}> : tensor<10xf32>
|
||||
refback.global @foo dense<0.0> : tensor<10xf32>
|
||||
|
||||
// CHECK-LABEL: func @global
|
||||
func @global() {
|
||||
// CHECK: refback.get_global_memref @foo : memref<10xf32>
|
||||
%0 = refback.get_global_memref @foo : memref<10xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shaped_results
|
||||
// CHECK-NEXT: %[[RET:.*]] = refback.shaped_results %arg1 {
|
||||
// CHECK-NEXT: %[[VAL:.*]] =
|
||||
// CHECK-NEXT: refback.yield %[[VAL]] : tensor<?x?xf32>
|
||||
// CHECK-NEXT: } : tensor<?xindex> -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: return %[[RET]] : tensor<?x?xf32>
|
||||
// CHECK-NEXT: }
|
||||
func @shaped_results(%arg0: tensor<?x?xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
|
||||
%add = refback.shaped_results %arg1 {
|
||||
%0 = tcp.add %arg0, %arg0 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
refback.yield %0 : tensor<?x?xf32>
|
||||
} : tensor<?xindex> -> tensor<?x?xf32>
|
||||
return %add : tensor<?x?xf32>
|
||||
}
|
|
@ -1,14 +1,4 @@
|
|||
// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s --dump-input=fail
|
||||
|
||||
// CHECK-LABEL: tcp.global @foo dense<0.0{{.*}}> : tensor<10xf32>
|
||||
tcp.global @foo dense<0.0> : tensor<10xf32>
|
||||
|
||||
// CHECK-LABEL: func @global
|
||||
func @global() {
|
||||
// CHECK: tcp.get_global_memref @foo : memref<10xf32>
|
||||
%0 = tcp.get_global_memref @foo : memref<10xf32>
|
||||
return
|
||||
}
|
||||
// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @binary_elementwise
|
||||
func @binary_elementwise(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: i32) {
|
||||
|
@ -27,18 +17,3 @@ func @matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
|
|||
%0 = tcp.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shaped_results
|
||||
// CHECK-NEXT: %[[RET:.*]] = tcp.shaped_results %arg1 {
|
||||
// CHECK-NEXT: %[[VAL:.*]] =
|
||||
// CHECK-NEXT: tcp.yield %[[VAL]] : tensor<?x?xf32>
|
||||
// CHECK-NEXT: } : tensor<?xindex> -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: return %[[RET]] : tensor<?x?xf32>
|
||||
// CHECK-NEXT: }
|
||||
func @shaped_results(%arg0: tensor<?x?xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
|
||||
%add = tcp.shaped_results %arg1 {
|
||||
%0 = tcp.add %arg0, %arg0 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
tcp.yield %0 : tensor<?x?xf32>
|
||||
} : tensor<?xindex> -> tensor<?x?xf32>
|
||||
return %add : tensor<?x?xf32>
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
// CHECK-LABEL: func @tcp_broadcast_to
|
||||
func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) {
|
||||
// CHECK: %0 = tcp.shaped_results %arg1
|
||||
// CHECK: %0 = refback.shaped_results %arg1
|
||||
%0 = tcp.broadcast_to %arg0, %arg1 : (tensor<?xf32>, tensor<?xindex>) -> tensor<?x?xf32>
|
||||
return
|
||||
}
|
||||
|
@ -11,7 +11,7 @@ func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) {
|
|||
// CHECK-SAME: %[[LHS:.*]]: tensor<?xf32>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: %[[LHSSHAPE:.*]] = shape.shape_of %[[LHS]]
|
||||
// CHECK: %[[RET:.*]] = tcp.shaped_results %[[LHSSHAPE]]
|
||||
// CHECK: %[[RET:.*]] = refback.shaped_results %[[LHSSHAPE]]
|
||||
// CHECK: return %[[RET:.*]] : tensor<?xf32>
|
||||
// CHECK: }
|
||||
func @tcp_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
||||
|
@ -24,7 +24,7 @@ func @tcp_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
|||
// CHECK-SAME: %[[LHS:.*]]: tensor<?xf32>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: %[[LHSSHAPE:.*]] = shape.shape_of %[[LHS]]
|
||||
// CHECK: %[[RET:.*]] = tcp.shaped_results %[[LHSSHAPE]]
|
||||
// CHECK: %[[RET:.*]] = refback.shaped_results %[[LHSSHAPE]]
|
||||
// CHECK: return %[[RET:.*]] : tensor<?xf32>
|
||||
// CHECK: }
|
||||
func @tcp_max(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
||||
|
@ -40,7 +40,7 @@ func @tcp_max(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
|||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[RHSROWS:.*]] = dim %[[RHS]], %[[C1]]
|
||||
// CHECK: %[[RESULTSHAPE:.*]] = tensor_from_elements %[[LHSCOLS]], %[[RHSROWS]]
|
||||
// CHECK: %[[RET:.*]] = tcp.shaped_results %[[RESULTSHAPE]] {
|
||||
// CHECK: %[[RET:.*]] = refback.shaped_results %[[RESULTSHAPE]] {
|
||||
// CHECK: return %[[RET:.*]] : tensor<?x?xf32>
|
||||
// CHECK: }
|
||||
func @tcp_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
|
|
|
@ -5,7 +5,7 @@ func @basic(%arg0: tensor<?xindex>) -> memref<?xf32> {
|
|||
// CHECK: %[[I:.*]] = constant 0 : index
|
||||
// CHECK: %[[E:.*]] = shape.get_extent %arg0, %[[I]]
|
||||
// CHECK: alloc(%[[E]])
|
||||
%0 = tcp.alloc_memref %arg0 : memref<?xf32>
|
||||
%0 = refback.alloc_memref %arg0 : memref<?xf32>
|
||||
return %0 : memref<?xf32>
|
||||
}
|
||||
|
||||
|
@ -14,7 +14,7 @@ func @basic(%arg0: tensor<?xindex>) -> memref<?xf32> {
|
|||
func @all_static(%arg0: tensor<?xindex>) -> memref<3x4x5xf32> {
|
||||
// CHECK-NOT: shape.get_extent
|
||||
// CHECK: alloc()
|
||||
%0 = tcp.alloc_memref %arg0 : memref<3x4x5xf32>
|
||||
%0 = refback.alloc_memref %arg0 : memref<3x4x5xf32>
|
||||
return %0 : memref<3x4x5xf32>
|
||||
}
|
||||
|
||||
|
@ -26,6 +26,6 @@ func @some_static(%arg0: tensor<?xindex>) -> memref<3x?x5x?x7xf32> {
|
|||
// CHECK-DAG: %[[I3:.*]] = constant 3 : index
|
||||
// CHECK-DAG: %[[E3:.*]] = shape.get_extent %arg0, %[[I3]]
|
||||
// CHECK: alloc(%[[E1]], %[[E3]])
|
||||
%0 = tcp.alloc_memref %arg0 : memref<3x?x5x?x7xf32>
|
||||
%0 = refback.alloc_memref %arg0 : memref<3x?x5x?x7xf32>
|
||||
return %0 : memref<3x?x5x?x7xf32>
|
||||
}
|
||||
|
|
|
@ -3,11 +3,11 @@
|
|||
// CHECK-LABEL: module {
|
||||
// We check the debug name too since we put some effort into making that readable.
|
||||
// The name isn't load-bearing though.
|
||||
// CHECK: tcp.global @__constant_3x4xf32 dense<7.000000e+00> : tensor<3x4xf32>
|
||||
// CHECK: refback.global @__constant_3x4xf32 dense<7.000000e+00> : tensor<3x4xf32>
|
||||
// CHECK: func @basic
|
||||
func @basic() -> tensor<3x4xf32> {
|
||||
// CHECK: %[[MEMREF:.*]] = tcp.get_global_memref @__constant_3x4xf32 : memref<3x4xf32>
|
||||
// CHECK: %[[TENSOR:.*]] = tcp.memref_to_tensor %[[MEMREF]]
|
||||
// CHECK: %[[MEMREF:.*]] = refback.get_global_memref @__constant_3x4xf32 : memref<3x4xf32>
|
||||
// CHECK: %[[TENSOR:.*]] = refback.memref_to_tensor %[[MEMREF]]
|
||||
%0 = constant dense<7.0> : tensor<3x4xf32>
|
||||
// CHECK: return %[[TENSOR]]
|
||||
return %0 : tensor<3x4xf32>
|
||||
|
@ -20,8 +20,8 @@ func @basic() -> tensor<3x4xf32> {
|
|||
// CHECK-LABEL: module {
|
||||
|
||||
// Only one global is created.
|
||||
// CHECK: tcp.global
|
||||
// CHECK-NOT: tcp.global
|
||||
// CHECK: refback.global
|
||||
// CHECK-NOT: refback.global
|
||||
func @duplicate_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
|
||||
%0 = constant dense<7.0> : tensor<3x4xf32>
|
||||
%1 = constant dense<7.0> : tensor<3x4xf32>
|
||||
|
@ -36,9 +36,9 @@ func @duplicate_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
|
|||
// CHECK-LABEL: module {
|
||||
|
||||
// Two globals are created.
|
||||
// CHECK: tcp.global
|
||||
// CHECK: tcp.global
|
||||
// CHECK-NOT: tcp.global
|
||||
// CHECK: refback.global
|
||||
// CHECK: refback.global
|
||||
// CHECK-NOT: refback.global
|
||||
func @multiple_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
|
||||
%0 = constant dense<7.0> : tensor<3x4xf32>
|
||||
%1 = constant dense<8.0> : tensor<3x4xf32>
|
||||
|
@ -52,7 +52,7 @@ func @multiple_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
|
|||
|
||||
// CHECK-LABEL: module {
|
||||
// We don't convert non-tensor globals.
|
||||
// CHECK-NOT: tcp.global
|
||||
// CHECK-NOT: refback.global
|
||||
func @non_tensor() {
|
||||
%0 = constant 7 : i32
|
||||
return
|
||||
|
|
|
@ -7,10 +7,10 @@ func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) -> tensor<?
|
|||
// buffer version of tcp.broadcast_to
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK-NOT: tcp.shaped_results
|
||||
%0 = tcp.shaped_results %arg1 {
|
||||
// CHECK-NOT: refback.shaped_results
|
||||
%0 = refback.shaped_results %arg1 {
|
||||
%0 = tcp.broadcast_to %arg0, %arg1 : (tensor<?xf32>, tensor<?xindex>) -> tensor<?x?xf32>
|
||||
tcp.yield %0 : tensor<?x?xf32>
|
||||
refback.yield %0 : tensor<?x?xf32>
|
||||
} : tensor<?xindex> -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
@ -20,22 +20,22 @@ func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) -> tensor<?
|
|||
// CHECK-SAME: %arg0: tensor<?xf32>,
|
||||
// CHECK-SAME: %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: %[[LHSSHAPE:.*]] = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
|
||||
// CHECK: %[[LHS:.*]] = tcp.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
|
||||
// CHECK: %[[RHS:.*]] = tcp.tensor_to_memref %arg1 : tensor<?xf32> -> memref<?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = tcp.alloc_memref %[[LHSSHAPE]] : memref<?xf32>
|
||||
// CHECK: %[[LHS:.*]] = refback.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
|
||||
// CHECK: %[[RHS:.*]] = refback.tensor_to_memref %arg1 : tensor<?xf32> -> memref<?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = refback.alloc_memref %[[LHSSHAPE]] : memref<?xf32>
|
||||
// CHECK: linalg.generic {{.*}} ins(%[[LHS]], %[[RHS]] {{.*}}) outs(%[[RESULT]] {{.*}}) {
|
||||
// CHECK: ^bb0(%[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
|
||||
// CHECK: %[[VAL_9:.*]] = addf %[[VAL_6]], %[[VAL_7]] : f32
|
||||
// CHECK: linalg.yield %[[VAL_9]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: %[[RET:.*]] = tcp.memref_to_tensor %[[RESULT]] : memref<?xf32> -> tensor<?xf32>
|
||||
// CHECK: %[[RET:.*]] = refback.memref_to_tensor %[[RESULT]] : memref<?xf32> -> tensor<?xf32>
|
||||
// CHECK: return %[[RET]] : tensor<?xf32>
|
||||
// CHECK: }
|
||||
func @tcp_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
|
||||
%1 = tcp.shaped_results %0 {
|
||||
%1 = refback.shaped_results %0 {
|
||||
%2 = tcp.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
tcp.yield %2 : tensor<?xf32>
|
||||
refback.yield %2 : tensor<?xf32>
|
||||
} : tensor<?xindex> -> tensor<?xf32>
|
||||
return %1 : tensor<?xf32>
|
||||
}
|
||||
|
@ -49,9 +49,9 @@ func @tcp_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
|||
// CHECK: linalg.yield %[[MAX]] : f32
|
||||
func @tcp_max(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
|
||||
%1 = tcp.shaped_results %0 {
|
||||
%1 = refback.shaped_results %0 {
|
||||
%2 = tcp.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
tcp.yield %2 : tensor<?xf32>
|
||||
refback.yield %2 : tensor<?xf32>
|
||||
} : tensor<?xindex> -> tensor<?xf32>
|
||||
return %1 : tensor<?xf32>
|
||||
}
|
||||
|
@ -62,19 +62,19 @@ func @tcp_max(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
|||
// CHECK-SAME: %arg0: tensor<?x?xf32>,
|
||||
// CHECK-SAME: %arg1: tensor<?x?xf32>,
|
||||
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>) -> tensor<?x?xf32> {
|
||||
// CHECK: %[[LHS:.*]] = tcp.tensor_to_memref %arg0 : tensor<?x?xf32> -> memref<?x?xf32>
|
||||
// CHECK: %[[RHS:.*]] = tcp.tensor_to_memref %arg1 : tensor<?x?xf32> -> memref<?x?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = tcp.alloc_memref %[[SHAPE]] : memref<?x?xf32>
|
||||
// CHECK: %[[LHS:.*]] = refback.tensor_to_memref %arg0 : tensor<?x?xf32> -> memref<?x?xf32>
|
||||
// CHECK: %[[RHS:.*]] = refback.tensor_to_memref %arg1 : tensor<?x?xf32> -> memref<?x?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = refback.alloc_memref %[[SHAPE]] : memref<?x?xf32>
|
||||
// CHECK: %[[C0:.*]] = constant 0.000000e+00 : f32
|
||||
// CHECK: linalg.fill(%2, %[[C0]]) : memref<?x?xf32>, f32
|
||||
// CHECK: linalg.matmul ins(%[[LHS]], %[[RHS]] : memref<?x?xf32>, memref<?x?xf32>) outs(%[[RESULT]] : memref<?x?xf32>)
|
||||
// CHECK: %[[RET:.*]] = tcp.memref_to_tensor %[[RESULT]] : memref<?x?xf32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[RET:.*]] = refback.memref_to_tensor %[[RESULT]] : memref<?x?xf32> -> tensor<?x?xf32>
|
||||
// CHECK: return %[[RET]] : tensor<?x?xf32>
|
||||
// CHECK: }
|
||||
func @tcp_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %shape: tensor<?xindex>) -> tensor<?x?xf32> {
|
||||
%0 = tcp.shaped_results %shape {
|
||||
%0 = refback.shaped_results %shape {
|
||||
%matmul = tcp.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
tcp.yield %matmul : tensor<?x?xf32>
|
||||
refback.yield %matmul : tensor<?x?xf32>
|
||||
} : tensor<?xindex> -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
// that to make the test actually check what happens in practice.
|
||||
|
||||
// CHECK-LABEL: func @extract_element
|
||||
// CHECK: %[[MEMREF:.*]] = tcp.tensor_to_memref %arg0
|
||||
// CHECK: %[[MEMREF:.*]] = refback.tensor_to_memref %arg0
|
||||
// CHECK: %[[RET:.*]] = load %[[MEMREF]][%arg1] : memref<?xf32>
|
||||
// CHECK: return %[[RET]] : f32
|
||||
func @extract_element(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
|
||||
|
@ -20,7 +20,7 @@ func @extract_element(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
|
|||
// CHECK: store %[[ARG0]], %[[MEMREF]][%[[C0]]]
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: store %[[ARG1]], %[[MEMREF]][%[[C1]]]
|
||||
// CHECK: %[[RET:.*]] = tcp.memref_to_tensor %[[MEMREF]]
|
||||
// CHECK: %[[RET:.*]] = refback.memref_to_tensor %[[MEMREF]]
|
||||
// CHECK: return %[[RET]] : tensor<2xindex>
|
||||
func @tensor_from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
|
||||
%0 = tensor_from_elements %arg0, %arg1 : tensor<2xindex>
|
||||
|
@ -30,9 +30,9 @@ func @tensor_from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
|
|||
|
||||
// CHECK-LABEL: func @tensor_cast(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xindex>) -> tensor<2xindex> {
|
||||
// CHECK: %[[MEMREF:.*]] = tcp.tensor_to_memref %[[ARG0]] : tensor<?xindex> -> memref<?xindex>
|
||||
// CHECK: %[[MEMREF:.*]] = refback.tensor_to_memref %[[ARG0]] : tensor<?xindex> -> memref<?xindex>
|
||||
// CHECK: %[[CASTED:.*]] = memref_cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
|
||||
// CHECK: %[[RET:.*]] = tcp.memref_to_tensor %[[CASTED]] : memref<2xindex> -> tensor<2xindex>
|
||||
// CHECK: %[[RET:.*]] = refback.memref_to_tensor %[[CASTED]] : memref<2xindex> -> tensor<2xindex>
|
||||
// CHECK: return %[[RET]] : tensor<2xindex>
|
||||
func @tensor_cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
|
||||
%0 = tensor_cast %arg0 : tensor<?xindex> to tensor<2xindex>
|
||||
|
@ -41,10 +41,9 @@ func @tensor_cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
|
|||
|
||||
// CHECK-LABEL: func @tensor_load(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<?xindex>) -> tensor<?xindex> {
|
||||
// CHECK: %[[RET:.*]] = tcp.memref_to_tensor %[[ARG0]] : memref<?xindex> -> tensor<?xindex>
|
||||
// CHECK: %[[RET:.*]] = refback.memref_to_tensor %[[ARG0]] : memref<?xindex> -> tensor<?xindex>
|
||||
// CHECK: return %[[RET]] : tensor<?xindex>
|
||||
func @tensor_load(%arg0: memref<?xindex>) -> tensor<?xindex> {
|
||||
%0 = tensor_load %arg0 : memref<?xindex>
|
||||
return %0 : tensor<?xindex>
|
||||
}
|
||||
|
||||
|
|
|
@ -63,8 +63,8 @@ func @for(%arg0: tensor<f32>, %lb: index, %ub: index, %step: index) -> tensor<f3
|
|||
// CHECK-LABEL: func @identity_materializations(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||
// CHECK-NEXT: return %arg0 : memref<?xf32>
|
||||
func @identity_materializations(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%0 = tcp.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
|
||||
%1 = tcp.memref_to_tensor %0 : memref<?xf32> -> tensor<?xf32>
|
||||
%0 = refback.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
|
||||
%1 = refback.memref_to_tensor %0 : memref<?xf32> -> tensor<?xf32>
|
||||
return %1 : tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -76,7 +76,7 @@ func @identity_materializations(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
|||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[RET]] : memref<?xf32>
|
||||
func @if_materializations(%pred: i1, %true_val_memref: memref<?xf32>, %false_val: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%true_val = tcp.memref_to_tensor %true_val_memref : memref<?xf32> -> tensor<?xf32>
|
||||
%true_val = refback.memref_to_tensor %true_val_memref : memref<?xf32> -> tensor<?xf32>
|
||||
%0 = scf.if %pred -> (tensor<?xf32>) {
|
||||
scf.yield %true_val : tensor<?xf32>
|
||||
} else {
|
||||
|
@ -88,13 +88,13 @@ func @if_materializations(%pred: i1, %true_val_memref: memref<?xf32>, %false_val
|
|||
// CHECK-LABEL: func @elide_memref_to_tensor(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||
// CHECK-NEXT: return %arg0 : memref<?xf32>
|
||||
func @elide_memref_to_tensor(%arg0: memref<?xf32>) -> tensor<?xf32> {
|
||||
%0 = tcp.memref_to_tensor %arg0 : memref<?xf32> -> tensor<?xf32>
|
||||
%0 = refback.memref_to_tensor %arg0 : memref<?xf32> -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @elide_tensor_to_memref(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||
// CHECK-NEXT: return %arg0 : memref<?xf32>
|
||||
func @elide_tensor_to_memref(%arg0: tensor<?xf32>) -> memref<?xf32> {
|
||||
%0 = tcp.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
|
||||
%0 = refback.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
|
||||
return %0 : memref<?xf32>
|
||||
}
|
||||
|
|
|
@ -75,7 +75,7 @@ func @multiple_blocks(%arg0: memref<?xf32>) -> memref<?xf32> {
|
|||
|
||||
|
||||
// CHECK: npcomprt.global @g dense<7.000000e+00> : tensor<10xf32>
|
||||
tcp.global @g dense<7.0> : tensor<10xf32>
|
||||
refback.global @g dense<7.0> : tensor<10xf32>
|
||||
// CHECK-LABEL: func @gets_global() -> !npcomprt.tensor
|
||||
func @gets_global() -> memref<10xf32> {
|
||||
// CHECK: %[[GMEMREF:.*]] = npcomprt.get_global @g : memref<*xf32>
|
||||
|
@ -83,7 +83,7 @@ func @gets_global() -> memref<10xf32> {
|
|||
// CHECK: %[[OUTABIMEMREF:.*]] = memref_cast %[[ORIGMEMREF:.*]] : memref<10xf32> to memref<*xf32>
|
||||
// CHECK: %[[RET:.*]] = npcomprt.from_memref %[[OUTABIMEMREF]] : memref<*xf32>
|
||||
// CHECK: return %[[RET]] : !npcomprt.tensor
|
||||
%0 = tcp.get_global_memref @g : memref<10xf32>
|
||||
%0 = refback.get_global_memref @g : memref<10xf32>
|
||||
return %0 : memref<10xf32>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue