[RefBackend] Split out RefBackend (refback) dialect from TCP.

This is the first in a patch series that is refactoring the
constellation of things variously called or associated with "E2E",
"RefE2E", "npcomprt", and "TCP" into a more cleanly layered result.

Concretely, this first patch fixes the fact that TCP was basically
acting as a dumping ground for ops needed by the reference backend. This
patch splits them out, which is fairly mechanical, but touches a lot of
lines of code (basically replacing `tcp` with `refback` and `TCP` with
`RefBackend`).

Now, the RefBackend dialect is that dumping ground, which
is slightly better, as it starts allowing TCP to become a nice clean
middle layer that is not related per se to the reference backend.
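For a flavor of the mechanical rename, here is a representative sketch
(not taken verbatim from the diff; `%shape` stands in for some extent
tensor, and the op's semantics are unchanged):

```mlir
// Before this patch, the op lived in the TCP dialect:
%buf = tcp.alloc_memref %shape : memref<?xf32>
// After this patch, the same op lives in the RefBackend dialect:
%buf = refback.alloc_memref %shape : memref<?xf32>
```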

The previous name RefE2E or "reference e2e flow" was super confusing.
Now that we are seeing more clearly where the "backend" distinction
lies, the [RefBackend] commit tag is born :)
pull/69/head
Sean Silva 2020-10-06 15:44:18 -07:00
parent 3ccc2214a7
commit 5017430dc7
36 changed files with 577 additions and 397 deletions

View File

@@ -2,6 +2,7 @@ add_subdirectory(ATen)
add_subdirectory(Basicpy)
add_subdirectory(Npcomprt)
add_subdirectory(Numpy)
add_subdirectory(RefBackend)
add_subdirectory(TCF)
add_subdirectory(TCP)
add_subdirectory(Torch)

View File

@@ -0,0 +1 @@
add_subdirectory(IR)

View File

@@ -0,0 +1 @@
add_mlir_dialect(RefBackendOps refback)

View File

@@ -0,0 +1,23 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef REFBACKEND_BASE
#define REFBACKEND_BASE
include "mlir/IR/OpBase.td"
def RefBackend_Dialect : Dialect {
let name = "refback";
let cppNamespace = "::mlir::NPCOMP::refback";
let description = [{
Ops used by the reference backend as part of its lowering.
}];
}
#endif // REFBACKEND_BASE

View File

@@ -0,0 +1,16 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_REFBACKEND_IR_REFBACKENDDIALECT_H
#define NPCOMP_DIALECT_REFBACKEND_IR_REFBACKENDDIALECT_H
#include "mlir/IR/Dialect.h"
#include "npcomp/Dialect/RefBackend/IR/RefBackendOpsDialect.h.inc"
#endif // NPCOMP_DIALECT_REFBACKEND_IR_REFBACKENDDIALECT_H

View File

@@ -0,0 +1,22 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_REFBACKEND_IR_REFBACKENDOPS_H
#define NPCOMP_DIALECT_REFBACKEND_IR_REFBACKENDOPS_H
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#define GET_OP_CLASSES
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h.inc"
#endif // NPCOMP_DIALECT_REFBACKEND_IR_REFBACKENDOPS_H

View File

@@ -0,0 +1,147 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef REFBACKEND_OPS
#define REFBACKEND_OPS
include "npcomp/Dialect/RefBackend/IR/RefBackendBase.td"
include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
class RefBackend_Op<string mnemonic, list<OpTrait> traits = []>
: Op<RefBackend_Dialect, mnemonic, traits> {
}
def RefBackend_GlobalOp : RefBackend_Op<"global", [Symbol]> {
let summary = "Represents a global variable";
let description = [{
Represents a global variable.
Currently, only constant tensors are supported, and they are not
considered to be exported.
}];
let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value);
let results = (outs);
let printer = [{ return ::print$cppClass(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
//===----------------------------------------------------------------------===//
// Ops related to tensor->memref conversion.
//===----------------------------------------------------------------------===//
// TODO: These ops probably belong in a "RefBackend on memrefs" dialect analogous
// to `lmhlo`
// TODO: Use TypesMatchWith to verify this better.
def RefBackend_TensorToMemrefOp : RefBackend_Op<"tensor_to_memref", [NoSideEffect]> {
let summary = "Converts a tensor to a memref";
let description = [{
This op is used to materialize conversions to allow incremental lowering of
tensors to memrefs.
}];
let arguments = (ins AnyRankedTensor:$tensor);
let results = (outs AnyMemRef:$memref);
let assemblyFormat = "attr-dict $tensor `:` type($tensor) `->` type($memref)";
let hasFolder = 1;
}
// TODO: Use TypesMatchWith to verify this better.
def RefBackend_MemrefToTensorOp : RefBackend_Op<"memref_to_tensor", [NoSideEffect]> {
let summary = "Converts a memref to a tensor";
let description = [{
This op is used to materialize conversions to allow incremental lowering of
tensors to memrefs.
}];
let arguments = (ins AnyMemRef:$memref);
let results = (outs AnyRankedTensor:$tensor);
let assemblyFormat = "attr-dict $memref `:` type($memref) `->` type($tensor)";
}
def RefBackend_AllocMemRefOp : RefBackend_Op<"alloc_memref", []> {
let summary = "Allocates a memref of the given shape.";
let description = [{
Allocates a memref of the given shape.
This op is a convenience for creating a bunch of
shape.get_extent + std.alloc ops.
}];
let arguments = (ins Shape_ExtentTensorType:$shape);
let results = (outs AnyMemRef:$memref);
let assemblyFormat = "$shape attr-dict `:` type($memref)";
}
def RefBackend_GetGlobalMemrefOp : RefBackend_Op<"get_global_memref"> {
let summary = "Obtain a memref pointing at the given global";
let description = [{
Obtain a memref pointing at the given global.
}];
let arguments = (ins FlatSymbolRefAttr:$global);
let results = (outs AnyMemRef:$memref);
let assemblyFormat = "$global attr-dict `:` type($memref)";
let verifier = "return ::verify$cppClass(*this);";
}
//===----------------------------------------------------------------------===//
// Ops related to shapes.
//===----------------------------------------------------------------------===//
// TODO: These belong in a shape-related dialect.
def RefBackend_ShapedResultsOp : RefBackend_Op<"shaped_results", [
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
SingleBlockImplicitTerminator<"YieldOp">,
RecursiveSideEffects,
NoRegionArguments
]> {
let summary = "Result shape annotation";
let description = [{
Represents a computation whose outputs have a precomputed shape.
The i-th result has the shape described by the i-th operand.
This op is not isolated from above, so if the region needs any inputs,
they can simply be captured. Hence, this op is a
"this tensor has this shape" annotation with a slightly different set of
tradeoffs than the so-called "tie shape" kinds of operations.
In particular, this region-based formulation has the opportunity to
capture structural invariants.
Example:
```mlir
// sincos is an elementwise operation, so it doesn't change the shape.
%x = ...
%xShape = ...
%sin, %cos = refback.shaped_results %xShape, %xShape {
%sin, %cos = "some.sincos"(%x)
    : (tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>)
refback.yield %sin, %cos : tensor<?xf32>, tensor<?xf32>
}
```
}];
let arguments = (ins
Variadic<Shape_ExtentTensorType>:$resultShapes
);
let results = (outs Variadic<AnyTensor>:$results);
let regions = (region SizedRegion<1>:$body);
let printer = [{ return ::print$cppClass(p, *this); }];
let verifier = [{ return ::verify$cppClass(*this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
def RefBackend_YieldOp : RefBackend_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
ParentOneOf<["ShapedResultsOp"]>]> {
let summary = "Yield-like terminator for RefBackend dialect";
let description = "See scf.yield";
let arguments = (ins Variadic<AnyType>:$operands);
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}
#endif // REFBACKEND_OPS
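As a concrete illustration of `refback.alloc_memref` above, here is a
sketch of the expansion it is a convenience for, mirroring the
lower-alloc-memref-ops test later in this commit (`%shape` is a
hypothetical rank-1 extent tensor; the exact `shape.get_extent` type
syntax is assumed):

```mlir
%buf = refback.alloc_memref %shape : memref<?xf32>
// ... is later resolved to individual extents plus a std.alloc, roughly:
%c0 = constant 0 : index
%e0 = shape.get_extent %shape, %c0 : tensor<?xindex>, index -> index
%buf2 = alloc(%e0) : memref<?xf32>
```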

View File

@@ -105,128 +105,4 @@ It is undefined behavior if such a broadcast is not legal.
let assemblyFormat = "$operand `,` $shape attr-dict `:` functional-type(operands, results)";
}
def TCP_GlobalOp : TCP_Op<"global", [Symbol]> {
let summary = "Represents a global variable";
let description = [{
Represents a global variable.
Currently, only constant tensors are supported, and they are not
considered to be exported.
}];
let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value);
let results = (outs);
let printer = [{ return ::print$cppClass(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
//===----------------------------------------------------------------------===//
// Ops related to tensor->memref conversion.
//===----------------------------------------------------------------------===//
// TODO: These ops probably belong in a "TCP on memrefs" dialect analogous
// to `lmhlo`
// TODO: Use TypesMatchWith to verify this better.
def TCP_TensorToMemrefOp : TCP_Op<"tensor_to_memref", [NoSideEffect]> {
let summary = "Converts a tensor to a memref";
let description = [{
This op is used to materialize conversions to allow incremental lowering of
tensors to memrefs.
}];
let arguments = (ins AnyRankedTensor:$tensor);
let results = (outs AnyMemRef:$memref);
let assemblyFormat = "attr-dict $tensor `:` type($tensor) `->` type($memref)";
let hasFolder = 1;
}
// TODO: Use TypesMatchWith to verify this better.
def TCP_MemrefToTensorOp : TCP_Op<"memref_to_tensor", [NoSideEffect]> {
let summary = "Converts a memref to a tensor";
let description = [{
This op is used to materialize conversions to allow incremental lowering of
tensors to memrefs.
}];
let arguments = (ins AnyMemRef:$memref);
let results = (outs AnyRankedTensor:$tensor);
let assemblyFormat = "attr-dict $memref `:` type($memref) `->` type($tensor)";
}
def TCP_AllocMemRefOp : TCP_Op<"alloc_memref", []> {
let summary = "Allocates a memref of the given shape.";
let description = [{
Allocates a memref of the given shape.
This op is a convenience for creating a bunch of
shape.get_extent + std.alloc ops.
}];
let arguments = (ins Shape_ExtentTensorType:$shape);
let results = (outs AnyMemRef:$memref);
let assemblyFormat = "$shape attr-dict `:` type($memref)";
}
def TCP_GetGlobalMemrefOp : TCP_Op<"get_global_memref"> {
let summary = "Obtain a memref pointing at the given global";
let description = [{
Obtain a memref pointing at the given global.
}];
let arguments = (ins FlatSymbolRefAttr:$global);
let results = (outs AnyMemRef:$memref);
let assemblyFormat = "$global attr-dict `:` type($memref)";
let verifier = "return ::verify$cppClass(*this);";
}
//===----------------------------------------------------------------------===//
// Ops related to shapes.
//===----------------------------------------------------------------------===//
// TODO: These belong in a shape-related dialect.
def TCP_ShapedResultsOp : TCP_Op<"shaped_results", [
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
SingleBlockImplicitTerminator<"YieldOp">,
RecursiveSideEffects,
NoRegionArguments
]> {
let summary = "Result shape annotation";
let description = [{
Represents a computation whose outputs have a precomputed shape.
The i-th result has the shape described by the i-th operand.
This op is not isolated from above, so if the region needs any inputs,
they can simply be captured. Hence, this op is a
"this tensor has this shape" annotation with a slightly different set of
tradeoffs than the so-called "tie shape" kinds of operations.
In particular, this region-based formulation has the opportunity to
capture structural invariants.
Example:
```mlir
// sincos is an elementwise operation, so it doesn't change the shape.
%x = ...
%xShape = ...
%sin, %cos = tcp.shaped_results %xShape, %xShape {
%sin, %cos = "some.sincos"(%x)
    : (tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>)
tcp.yield %sin, %cos : tensor<?xf32>, tensor<?xf32>
}
```
}];
let arguments = (ins
Variadic<Shape_ExtentTensorType>:$resultShapes
);
let results = (outs Variadic<AnyTensor>:$results);
let regions = (region SizedRegion<1>:$body);
let printer = [{ return ::print$cppClass(p, *this); }];
let verifier = [{ return ::verify$cppClass(*this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
def TCP_YieldOp : TCP_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
ParentOneOf<["ShapedResultsOp"]>]> {
let summary = "Yield-like terminator for TCP dialect";
let description = "See scf.yield";
let arguments = (ins Variadic<AnyType>:$operands);
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}
#endif // TCP_OPS

View File

@@ -17,7 +17,7 @@ def BypassShapes : Pass<"bypass-shapes", "FuncOp"> {
}
def LowerShapedResultsToMemref : Pass<"lower-shaped-results-to-memref", "FuncOp"> {
let summary = "Lower tcp.shaped_results regions";
let summary = "Lower refback.shaped_results regions";
let constructor = "mlir::NPCOMP::createLowerShapedResultsToMemrefPass()";
}
@@ -30,7 +30,7 @@ def LowerConstantTensorsToMemref :
Pass<"lower-constant-tensors-to-memref", "ModuleOp"> {
let summary = "Lower std.constant of tensor type to memref";
let description = [{
This must be a module pass since it involves creating tcp.global ops.
This must be a module pass since it involves creating refback.global ops.
}];
let constructor = "mlir::NPCOMP::createLowerConstantTensorsToMemrefPass()";
}

View File

@@ -41,6 +41,7 @@ add_mlir_library(NPCOMPInitAll
PUBLIC
# Local depends
NPCOMPE2E
NPCOMPRefBackendDialect
NPCOMPTCP
NPCOMPTCF
NPCOMPTorchDialect

View File

@@ -2,6 +2,7 @@ add_subdirectory(ATen)
add_subdirectory(Basicpy)
add_subdirectory(Npcomprt)
add_subdirectory(Numpy)
add_subdirectory(RefBackend)
add_subdirectory(TCF)
add_subdirectory(TCP)
add_subdirectory(Torch)

View File

@@ -0,0 +1 @@
add_subdirectory(IR)

View File

@@ -0,0 +1,19 @@
add_mlir_dialect_library(NPCOMPRefBackendDialect
RefBackendDialect.cpp
RefBackendOps.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/RefBackend
DEPENDS
MLIRRefBackendOpsIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRSupport
MLIRSideEffectInterfaces
MLIRShape
)

View File

@@ -0,0 +1,50 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/RefBackend/IR/RefBackendDialect.h"
#include "mlir/Transforms/InliningUtils.h"
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
using namespace mlir;
using namespace mlir::NPCOMP::refback;
//===----------------------------------------------------------------------===//
// RefBackendDialect Dialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
struct RefBackendInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
bool isLegalToInline(Region *dest, Region *src,
BlockAndValueMapping &valueMapping) const final {
return true;
}
bool isLegalToInline(Operation *, Region *,
BlockAndValueMapping &) const final {
return true;
}
void handleTerminator(Operation *op,
ArrayRef<Value> valuesToRepl) const final {
auto retValOp = dyn_cast<YieldOp>(op);
if (!retValOp)
return;
for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
}
}
};
} // end anonymous namespace
void RefBackendDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.cpp.inc"
>();
addInterfaces<RefBackendInlinerInterface>();
}
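The `handleTerminator` hook above is what lets the inliner splice a
`refback.shaped_results` region into its parent: uses of the op's
results get rewired to the values the region yielded. Schematically
(`"some.op"` and `%shape` are placeholders):

```mlir
// Before inlining:
%0 = refback.shaped_results %shape {
  %1 = "some.op"() : () -> tensor<?xf32>
  refback.yield %1 : tensor<?xf32>
} : tensor<?xindex> -> tensor<?xf32>
// After inlining: the body is spliced in, uses of %0 are replaced
// with %1, and the refback.yield terminator is dropped.
```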

View File

@@ -0,0 +1,126 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::refback;
//===----------------------------------------------------------------------===//
// TensorToMemrefOp
//===----------------------------------------------------------------------===//
OpFoldResult TensorToMemrefOp::fold(ArrayRef<Attribute> operands) {
if (auto memrefToTensor = tensor().getDefiningOp<refback::MemrefToTensorOp>())
return memrefToTensor.memref();
return nullptr;
}
//===----------------------------------------------------------------------===//
// ShapedResultsOp
//===----------------------------------------------------------------------===//
static LogicalResult verifyShapedResultsOp(ShapedResultsOp op) {
if (op.getNumOperands() != op.getNumResults())
return op.emitError() << "number of operands must equal number of results";
if (op.getNumOperands() == 0)
return op.emitError() << "must have at least one operand/result";
return RegionBranchOpInterface::verifyTypes(op);
}
static void printShapedResultsOp(OpAsmPrinter &p, ShapedResultsOp &op) {
p << "refback.shaped_results ";
p.printOptionalAttrDictWithKeyword(op.getAttrs());
p.printOperands(op.getOperands());
p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
p << " : ";
interleaveComma(op.getOperandTypes(), p);
p << " -> ";
interleaveComma(op.getResultTypes(), p);
}
static ParseResult parseShapedResultsOp(OpAsmParser &parser,
OperationState &result) {
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();
SmallVector<OpAsmParser::OperandType, 6> operands;
if (parser.parseOperandList(operands))
return failure();
auto *body = result.addRegion();
if (parser.parseRegion(*body, llvm::None, llvm::None))
return failure();
SmallVector<Type, 6> inputTypes;
if (parser.parseColonTypeList(inputTypes))
return failure();
if (parser.resolveOperands(operands, inputTypes, parser.getNameLoc(),
result.operands))
return failure();
if (parser.parseArrowTypeList(result.types))
return failure();
return success();
}
void ShapedResultsOp::getSuccessorRegions(
Optional<unsigned> index, ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {
if (index.hasValue())
regions.push_back(RegionSuccessor(getResults()));
else
regions.push_back(RegionSuccessor(&body()));
}
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) {
p << "refback.global ";
p.printSymbolName(op.sym_name());
p << ' ';
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
/*elidedAttrs=*/{"sym_name", "value"});
p.printAttribute(op.valueAttr());
}
static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
StringAttr nameAttr;
if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
result.attributes))
return failure();
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();
Attribute valueAttr;
if (parser.parseAttribute(valueAttr, "value", result.attributes))
return failure();
return success();
}
//===----------------------------------------------------------------------===//
// GetGlobalMemrefOp
//===----------------------------------------------------------------------===//
static LogicalResult verifyGetGlobalMemrefOp(GetGlobalMemrefOp op) {
auto global = SymbolTable::lookupNearestSymbolFrom<GlobalOp>(op, op.global());
if (!global)
return op.emitError() << "must reference a valid symbol";
auto globalType = global.value().getType();
auto resultType = op.getType().cast<ShapedType>();
if (failed(
verifyCompatibleShape(globalType.getShape(), resultType.getShape())))
return op.emitError() << "inconsistent with shape of global";
if (globalType.getElementType() != resultType.getElementType())
return op.emitError() << "inconsistent with element type of global";
return success();
}
#define GET_OP_CLASSES
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.cpp.inc"

View File

@@ -28,16 +28,6 @@ struct TCPInlinerInterface : public DialectInlinerInterface {
BlockAndValueMapping &) const final {
return true;
}
void handleTerminator(Operation *op,
ArrayRef<Value> valuesToRepl) const final {
auto retValOp = dyn_cast<YieldOp>(op);
if (!retValOp)
return;
for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
}
}
};
} // end anonymous namespace

View File

@@ -15,112 +15,5 @@ using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::tcp;
//===----------------------------------------------------------------------===//
// TensorToMemrefOp
//===----------------------------------------------------------------------===//
OpFoldResult TensorToMemrefOp::fold(ArrayRef<Attribute> operands) {
if (auto memrefToTensor = tensor().getDefiningOp<tcp::MemrefToTensorOp>())
return memrefToTensor.memref();
return nullptr;
}
//===----------------------------------------------------------------------===//
// ShapedResultsOp
//===----------------------------------------------------------------------===//
static LogicalResult verifyShapedResultsOp(ShapedResultsOp op) {
if (op.getNumOperands() != op.getNumResults())
return op.emitError() << "number of operands must equal number of results";
if (op.getNumOperands() == 0)
return op.emitError() << "must have at least one operand/result";
return RegionBranchOpInterface::verifyTypes(op);
}
static void printShapedResultsOp(OpAsmPrinter &p, ShapedResultsOp &op) {
p << "tcp.shaped_results ";
p.printOptionalAttrDictWithKeyword(op.getAttrs());
p.printOperands(op.getOperands());
p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
p << " : ";
interleaveComma(op.getOperandTypes(), p);
p << " -> ";
interleaveComma(op.getResultTypes(), p);
}
static ParseResult parseShapedResultsOp(OpAsmParser &parser,
OperationState &result) {
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();
SmallVector<OpAsmParser::OperandType, 6> operands;
if (parser.parseOperandList(operands))
return failure();
auto *body = result.addRegion();
if (parser.parseRegion(*body, llvm::None, llvm::None))
return failure();
SmallVector<Type, 6> inputTypes;
if (parser.parseColonTypeList(inputTypes))
return failure();
if (parser.resolveOperands(operands, inputTypes, parser.getNameLoc(),
result.operands))
return failure();
if (parser.parseArrowTypeList(result.types))
return failure();
return success();
}
void ShapedResultsOp::getSuccessorRegions(
Optional<unsigned> index, ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {
if (index.hasValue())
regions.push_back(RegionSuccessor(getResults()));
else
regions.push_back(RegionSuccessor(&body()));
}
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) {
p << "tcp.global ";
p.printSymbolName(op.sym_name());
p << ' ';
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
/*elidedAttrs=*/{"sym_name", "value"});
p.printAttribute(op.valueAttr());
}
static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
StringAttr nameAttr;
if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
result.attributes))
return failure();
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();
Attribute valueAttr;
if (parser.parseAttribute(valueAttr, "value", result.attributes))
return failure();
return success();
}
//===----------------------------------------------------------------------===//
// GetGlobalMemrefOp
//===----------------------------------------------------------------------===//
static LogicalResult verifyGetGlobalMemrefOp(GetGlobalMemrefOp op) {
auto global = SymbolTable::lookupNearestSymbolFrom<GlobalOp>(op, op.global());
if (!global)
return op.emitError() << "must reference a valid symbol";
auto globalType = global.value().getType();
auto resultType = op.getType().cast<ShapedType>();
if (failed(
verifyCompatibleShape(globalType.getShape(), resultType.getShape())))
return op.emitError() << "inconsistent with shape of global";
if (globalType.getElementType() != resultType.getElementType())
return op.emitError() << "inconsistent with element type of global";
return success();
}
#define GET_OP_CLASSES
#include "npcomp/Dialect/TCP/IR/TCPOps.cpp.inc"

View File

@@ -10,6 +10,8 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "npcomp/Dialect/RefBackend/IR/RefBackendDialect.h"
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
#include "npcomp/E2E/E2E.h"
@@ -43,11 +45,11 @@ static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
namespace {
// TODO: There is a coupling between this pass and LowerShapedResults.
// Any op that is wrapped in tcp.shaped_results here needs to be known how to be
// lowered by LowerShapedResults.
// Any op that is wrapped in refback.shaped_results here needs to be known how
// to be lowered by LowerShapedResults.
class BypassShapes : public BypassShapesBase<BypassShapes> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect>();
registry.insert<shape::ShapeDialect, refback::RefBackendDialect>();
}
void runOnOperation() override {
@@ -57,16 +59,16 @@ class BypassShapes : public BypassShapesBase<BypassShapes> {
SmallVector<Value, 6> resultShapes = bypassResultShapes(op);
if (resultShapes.empty())
return;
// We have result shapes, so wrap this op in a tcp.shaped_results op.
// We have result shapes, so wrap this op in a refback.shaped_results op.
OpBuilder builder(&op);
auto shapedResults = builder.create<tcp::ShapedResultsOp>(
auto shapedResults = builder.create<refback::ShapedResultsOp>(
op.getLoc(), op.getResultTypes(), resultShapes);
op.replaceAllUsesWith(shapedResults);
// Move the op into the body and yield the results.
Block *body = builder.createBlock(&shapedResults.body());
op.moveBefore(body, body->end());
builder.create<tcp::YieldOp>(op.getLoc(), op.getResults());
builder.create<refback::YieldOp>(op.getLoc(), op.getResults());
});
}
};
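Concretely, the wrapping this pass performs looks as follows, per the
bypass-shapes test later in this commit (sketch):

```mlir
// Before: a tcp op with a known shape transfer function.
%0 = tcp.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// After: the result shape is computed separately and the op is wrapped.
%shape = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
%1 = refback.shaped_results %shape {
  %2 = tcp.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  refback.yield %2 : tensor<?xf32>
} : tensor<?xindex> -> tensor<?xf32>
```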

View File

@@ -27,6 +27,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
@@ -55,10 +56,10 @@ void mlir::NPCOMP::registerE2EPasses() {
//===----------------------------------------------------------------------===//
namespace {
class LowerAllocMemRefOp : public OpRewritePattern<tcp::AllocMemRefOp> {
class LowerAllocMemRefOp : public OpRewritePattern<refback::AllocMemRefOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tcp::AllocMemRefOp op,
LogicalResult matchAndRewrite(refback::AllocMemRefOp op,
PatternRewriter &rewriter) const override {
auto memrefType = op.getType().cast<MemRefType>();
auto shape = op.getOperand();
@@ -91,7 +92,7 @@ class LowerAllocMemRefOps
OwningRewritePatternList patterns;
patterns.insert<LowerAllocMemRefOp>(context);
ConversionTarget target(*context);
target.addIllegalOp<tcp::AllocMemRefOp>();
target.addIllegalOp<refback::AllocMemRefOp>();
target.addLegalOp<shape::GetExtentOp>();
target.addLegalOp<AllocOp>();
target.addLegalOp<ConstantOp>();
@@ -177,7 +178,7 @@ void mlir::NPCOMP::createE2ELoweringPipeline(
pm.addPass(createConvertTCFToTCPPass());
// For operations with a shape transfer function, explicitly bypass their
// shape computations with tcp.shaped_results ops.
// shape computations with refback.shaped_results ops.
//
// Right now, our lowering flow depends heavily on descriptors, so technically
// we don't need to bypass shapes -- we can just splat out the shape
@@ -220,9 +221,9 @@ void mlir::NPCOMP::createE2ELoweringPipeline(
// rather than a single mega dialect conversion pass.
//
// This means that intermediate steps have source/target materializations
// (tcp.memref_to_tensor / tcp.tensor_to_memref) in the IR.
// (refback.memref_to_tensor / refback.tensor_to_memref) in the IR.
// Lower ops enclosed in tcp.shaped_results regions.
// Lower ops enclosed in refback.shaped_results regions.
// For now, this is covering the "tensor compute" ops like tcp.add /
// tcp.broadcast_to (the former being handled via a special subset of
// linalg.generic) -- we only handle those two, so having an isolated pass
@@ -230,10 +231,10 @@
// more pluggable. The exact interface for this pluggability depends on
// what design we want to settle on for bypassing shape computations.
pm.addPass(createLowerShapedResultsToMemrefPass());
// Lower tensor-valued constants to tcp.global.
// Lower tensor-valued constants to refback.global.
pm.addPass(createLowerConstantTensorsToMemrefPass());
// tcp::AllocMemRefOp takes a shape (i.e. extent tensor) as an argument. We
// need to resolve this to std.alloc which takes individual extents.
// refback::AllocMemRefOp takes a shape (i.e. extent tensor) as an argument.
// We need to resolve this to std.alloc which takes individual extents.
pm.addPass(createLowerAllocMemRefOpsPass());
// Lower shape ops to std.
// TODO: This should in principle be moved before tensor->memref conversion.

View File

@@ -16,7 +16,7 @@
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
@@ -75,11 +75,11 @@ static LogicalResult createModuleMetadata(ModuleOp module) {
//===----------------------------------------------------------------------===//
namespace {
class LowerGlobalOp : public OpConversionPattern<tcp::GlobalOp> {
class LowerGlobalOp : public OpConversionPattern<refback::GlobalOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tcp::GlobalOp op, ArrayRef<Value> operands,
matchAndRewrite(refback::GlobalOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<npcomprt::GlobalOp>(op, op.sym_name(),
op.value());
@@ -90,11 +90,11 @@ public:
namespace {
class LowerGetGlobalMemrefOp
: public OpConversionPattern<tcp::GetGlobalMemrefOp> {
: public OpConversionPattern<refback::GetGlobalMemrefOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tcp::GetGlobalMemrefOp op, ArrayRef<Value> operands,
matchAndRewrite(refback::GetGlobalMemrefOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto abiMemref = rewriter.create<npcomprt::GetGlobalOp>(
op.getLoc(), getABIMemrefType(op.getType()), op.global());
@@ -217,10 +217,10 @@ static LogicalResult doDialectConversion(ModuleOp module) {
[&](ReturnOp op) { return typeConverter.isLegal(op); });
patterns.insert<LowerGlobalOp>(context);
target.addIllegalOp<tcp::GlobalOp>();
target.addIllegalOp<refback::GlobalOp>();
patterns.insert<LowerGetGlobalMemrefOp>(context);
target.addIllegalOp<tcp::GetGlobalMemrefOp>();
target.addIllegalOp<refback::GetGlobalMemrefOp>();
patterns.insert<LowerAssertOp>(context);
target.addIllegalOp<AssertOp>();
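The net effect of the two patterns above is a one-for-one replacement of
the refback globals with their npcomprt equivalents, e.g. (a sketch; the
npcomprt.global assembly syntax shown here is assumed, not confirmed by
this diff):

```mlir
// Before:
refback.global @g dense<0.0> : tensor<2xf32>
// After LowerGlobalOp:
npcomprt.global @g dense<0.0> : tensor<2xf32>
```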

View File

@@ -17,9 +17,8 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
#include "npcomp/Dialect/RefBackend/IR/RefBackendDialect.h"
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
@@ -35,13 +34,13 @@ namespace {
class GlobalCreator {
public:
explicit GlobalCreator(ModuleOp module);
tcp::GlobalOp getGlobalFor(Attribute attr) {
refback::GlobalOp getGlobalFor(Attribute attr) {
assert(globals.find(attr) != globals.end() && "unknown constant attr");
return globals[attr];
}
private:
DenseMap<Attribute, tcp::GlobalOp> globals;
DenseMap<Attribute, refback::GlobalOp> globals;
};
GlobalCreator::GlobalCreator(ModuleOp module) {
@@ -66,7 +65,7 @@ GlobalCreator::GlobalCreator(ModuleOp module) {
interleave(type.getShape(), os, "x");
os << "x" << type.getElementType();
auto global = globalBuilder.create<tcp::GlobalOp>(
auto global = globalBuilder.create<refback::GlobalOp>(
op.getLoc(), (Twine("__constant_") + os.str()).str(),
op.getValue().cast<ElementsAttr>());
symbolTable.insert(global);
@@ -82,7 +81,7 @@ namespace {
class LowerConstantTensorsToMemref
: public LowerConstantTensorsToMemrefBase<LowerConstantTensorsToMemref> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tcp::TCPDialect>();
registry.insert<refback::RefBackendDialect>();
}
void runOnOperation() override {
@@ -98,10 +97,10 @@ class LowerConstantTensorsToMemref
auto global = globals.getGlobalFor(op.getValue());
OpBuilder builder(op);
auto memrefType = MemRefType::get(type.getShape(), type.getElementType());
auto memref = builder.create<tcp::GetGlobalMemrefOp>(
auto memref = builder.create<refback::GetGlobalMemrefOp>(
op.getLoc(), memrefType, global.getName());
Value tensor =
builder.create<tcp::MemrefToTensorOp>(op.getLoc(), type, memref);
builder.create<refback::MemrefToTensorOp>(op.getLoc(), type, memref);
op.replaceAllUsesWith(tensor);
op.erase();
});
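Putting the pieces together, a tensor-valued constant lowers as in the
lower-constant-tensors-to-memref test later in this commit:

```mlir
// Before:
%0 = constant dense<7.000000e+00> : tensor<3x4xf32>
// After: a module-level global plus a load-back into tensor land.
refback.global @__constant_3x4xf32 dense<7.000000e+00> : tensor<3x4xf32>
// ... and inside the function:
%m = refback.get_global_memref @__constant_3x4xf32 : memref<3x4xf32>
%0 = refback.memref_to_tensor %m : memref<3x4xf32> -> tensor<3x4xf32>
```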

View File

@@ -19,6 +19,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/InliningUtils.h"
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
@@ -30,13 +31,13 @@ allocateResults(Operation *op, ConversionPatternRewriter &rewriter,
Location loc,
SmallVectorImpl<Value> *resultShapesOut = nullptr) {
// TODO: This is really fragile. Can we have a better story?
auto shapedResults = dyn_cast<tcp::ShapedResultsOp>(op->getParentOp());
auto shapedResults = dyn_cast<refback::ShapedResultsOp>(op->getParentOp());
if (!shapedResults)
return rewriter.notifyMatchFailure(op, "parent not tcp.shaped_results");
return rewriter.notifyMatchFailure(op, "parent not refback.shaped_results");
if (op->getResults() !=
shapedResults.getBody()->getTerminator()->getOperands())
return rewriter.notifyMatchFailure(
op, "only limited forms of tcp.shaped_results allowed");
op, "only limited forms of refback.shaped_results allowed");
auto resultShapes = shapedResults.resultShapes();
SmallVector<Value, 6> results;
for (auto t : llvm::zip(op->getResults(), resultShapes)) {
@@ -46,7 +47,7 @@ allocateResults(Operation *op, ConversionPatternRewriter &rewriter,
auto memrefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
auto memref =
rewriter.create<tcp::AllocMemRefOp>(loc, memrefType, resultShape);
rewriter.create<refback::AllocMemRefOp>(loc, memrefType, resultShape);
results.push_back(memref);
}
if (resultShapesOut)
@@ -251,22 +252,22 @@
namespace {
// This pass is responsible for lowering regions wrapped by
// tcp.shaped_results (which operate on tensors) to memrefs.
// refback.shaped_results (which operate on tensors) to memrefs.
// This includes any ops potentially contained within them.
// This is somewhat analogous to IREE's backend compilation of a single dispatch
// region, except that for now, we only allow a single op in the
// tcp.shaped_results, and we don't have any notion of "backend" layered at all.
// Nor is it clear if we really want any of that here.
// refback.shaped_results, and we don't have any notion of "backend" layered at
// all. Nor is it clear if we really want any of that here.
//
// The tcp.shaped_results ops provide precisely the information needed to
// The refback.shaped_results ops provide precisely the information needed to
// allocate output buffers when converting to memref.
// For now, this process eliminates the original tcp.shaped_results op since we
// don't have any host/device distinction or other structure that would require
// retaining that sort of IR structure.
// For now, this process eliminates the original refback.shaped_results op since
// we don't have any host/device distinction or other structure that would
// require retaining that sort of IR structure.
//
// TODO: Do "shape_of" resolution while still on tensors.
// Here we spew out tons of shape_of and rely on dim ops on descriptors to make
// it work. The key difference is that we need tcp.shaped_results (or its
// it work. The key difference is that we need refback.shaped_results (or its
// successor / something it gets lowered to) to not be IsolatedFromAbove, and
// explicitly capture all input tensors along with their shapes. That allows
// shape_of ops on inputs to be trivially resolved. Unfortunately, this opens up
@@ -300,14 +301,16 @@ class LowerShapedResultsToMemref
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<MemRefType>());
return (Value)builder.create<tcp::MemrefToTensorOp>(loc, type, inputs[0]);
return (Value)builder.create<refback::MemrefToTensorOp>(loc, type,
inputs[0]);
});
typeConverter.addTargetMaterialization([](OpBuilder &builder,
MemRefType type,
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<RankedTensorType>());
return (Value)builder.create<tcp::TensorToMemrefOp>(loc, type, inputs[0]);
return (Value)builder.create<refback::TensorToMemrefOp>(loc, type,
inputs[0]);
});
OwningRewritePatternList patterns;
@@ -316,14 +319,14 @@ class LowerShapedResultsToMemref
// The shaped results ops themselves. They have to be legal since we delete
// them later after the conversion process.
target.addLegalOp<tcp::ShapedResultsOp>();
target.addLegalOp<tcp::YieldOp>();
// All lowering to buffers involves tcp.alloc_memref ops.
target.addLegalOp<tcp::AllocMemRefOp>();
target.addLegalOp<refback::ShapedResultsOp>();
target.addLegalOp<refback::YieldOp>();
// All lowering to buffers involves refback.alloc_memref ops.
target.addLegalOp<refback::AllocMemRefOp>();
// The casting ops are introduced by the type converter, so we should mark
// them legal.
target.addLegalOp<tcp::MemrefToTensorOp>();
target.addLegalOp<tcp::TensorToMemrefOp>();
target.addLegalOp<refback::MemrefToTensorOp>();
target.addLegalOp<refback::TensorToMemrefOp>();
patterns.insert<LowerBroadcastToToLoopsPattern>(typeConverter, context);
target.addIllegalOp<tcp::BroadcastToOp>();
@@ -341,18 +344,19 @@ class LowerShapedResultsToMemref
target.addLegalOp<shape::GetExtentOp>();
SmallVector<Operation *, 6> shapedResultsOps;
func.walk([&](tcp::ShapedResultsOp op) { shapedResultsOps.push_back(op); });
func.walk(
[&](refback::ShapedResultsOp op) { shapedResultsOps.push_back(op); });
if (failed(applyFullConversion(shapedResultsOps, target, patterns)))
return signalPassFailure();
// Now inline the tcp.shaped_results ops.
// Now inline the refback.shaped_results ops.
// This can't be done as part of the conversion since conversion visits
// ops in preorder, and we need the tcp.shaped_results ops to be present
// ops in preorder, and we need the refback.shaped_results ops to be present
// so that inner ops can get their shape.
LocallyOverrideLegalityInlinerInterface interface(context);
for (Operation *shapedResultsOp : shapedResultsOps) {
auto op = cast<tcp::ShapedResultsOp>(shapedResultsOp);
auto op = cast<refback::ShapedResultsOp>(shapedResultsOp);
if (failed(inlineRegion(interface, &op.body(), op, ValueRange({}),
op.getResults(), /*inlineLoc=*/llvm::None,
/*shouldCloneInlinedRegion=*/false))) {
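The source/target materializations registered above are what bracket
each lowered op with refback casts. Schematically, the body of a wrapped
tcp.add comes out of this pass as follows (cf. the
lower-shaped-results-to-memref test later in this commit; `%shape` is
the precomputed result shape):

```mlir
%lhs = refback.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
%rhs = refback.tensor_to_memref %arg1 : tensor<?xf32> -> memref<?xf32>
%out = refback.alloc_memref %shape : memref<?xf32>
// ... a linalg.generic computes the addition on the memrefs ...
%ret = refback.memref_to_tensor %out : memref<?xf32> -> tensor<?xf32>
```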

View File

@@ -13,8 +13,8 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
#include "npcomp/Dialect/RefBackend/IR/RefBackendDialect.h"
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
@@ -88,7 +88,7 @@ namespace {
// TODO: Upstream this.
class LowerStdToMemref : public LowerStdToMemrefBase<LowerStdToMemref> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tcp::TCPDialect>();
registry.insert<refback::RefBackendDialect>();
}
void runOnOperation() override {
@@ -105,14 +105,16 @@ class LowerStdToMemref : public LowerStdToMemrefBase<LowerStdToMemref> {
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<MemRefType>());
return (Value)builder.create<tcp::MemrefToTensorOp>(loc, type, inputs[0]);
return (Value)builder.create<refback::MemrefToTensorOp>(loc, type,
inputs[0]);
});
typeConverter.addTargetMaterialization([](OpBuilder &builder,
MemRefType type,
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<RankedTensorType>());
return (Value)builder.create<tcp::TensorToMemrefOp>(loc, type, inputs[0]);
return (Value)builder.create<refback::TensorToMemrefOp>(loc, type,
inputs[0]);
});
OwningRewritePatternList patterns;
@@ -123,8 +125,8 @@ class LowerStdToMemref : public LowerStdToMemrefBase<LowerStdToMemref> {
// The casting ops are introduced by the type converter, so they must be
// legal.
target.addLegalOp<tcp::MemrefToTensorOp>();
target.addLegalOp<tcp::TensorToMemrefOp>();
target.addLegalOp<refback::MemrefToTensorOp>();
target.addLegalOp<refback::TensorToMemrefOp>();
patterns.insert<LowerExtractElementOp>(typeConverter, context);
target.addIllegalOp<ExtractElementOp>();
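For example, std.extract_element on a tensor becomes a plain load
through a materialized memref, mirroring the lower-std-to-memref test at
the end of this commit:

```mlir
%m = refback.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
%v = load %m[%arg1] : memref<?xf32>
```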

View File

@@ -14,7 +14,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
@@ -99,13 +99,13 @@ public:
namespace {
class LowerTensorToMemrefOp
: public OpConversionPattern<tcp::TensorToMemrefOp> {
: public OpConversionPattern<refback::TensorToMemrefOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tcp::TensorToMemrefOp op, ArrayRef<Value> operands,
matchAndRewrite(refback::TensorToMemrefOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
tcp::TensorToMemrefOp::Adaptor adaptor(operands);
refback::TensorToMemrefOp::Adaptor adaptor(operands);
rewriter.replaceOp(op, adaptor.tensor());
return success();
}
@@ -114,13 +114,13 @@ public:
namespace {
class LowerMemrefToTensorOp
: public OpConversionPattern<tcp::MemrefToTensorOp> {
: public OpConversionPattern<refback::MemrefToTensorOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tcp::MemrefToTensorOp op, ArrayRef<Value> operands,
matchAndRewrite(refback::MemrefToTensorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
tcp::MemrefToTensorOp::Adaptor adaptor(operands);
refback::MemrefToTensorOp::Adaptor adaptor(operands);
rewriter.replaceOp(op, op.memref());
return success();
}
@@ -150,14 +150,16 @@ class LowerStructuralToMemref
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<MemRefType>());
return (Value)builder.create<tcp::MemrefToTensorOp>(loc, type, inputs[0]);
return (Value)builder.create<refback::MemrefToTensorOp>(loc, type,
inputs[0]);
});
typeConverter.addTargetMaterialization([](OpBuilder &builder,
MemRefType type,
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<RankedTensorType>());
return (Value)builder.create<tcp::TensorToMemrefOp>(loc, type, inputs[0]);
return (Value)builder.create<refback::TensorToMemrefOp>(loc, type,
inputs[0]);
});
OwningRewritePatternList patterns;
@@ -181,7 +183,7 @@ class LowerStructuralToMemref
patterns.insert<LowerForOpTypes>(typeConverter, context);
patterns.insert<LowerTensorToMemrefOp>(typeConverter, context);
patterns.insert<LowerMemrefToTensorOp>(typeConverter, context);
target.addIllegalOp<tcp::TensorToMemrefOp>();
target.addIllegalOp<refback::TensorToMemrefOp>();
if (failed(applyFullConversion(func, target, patterns)))
return signalPassFailure();

View File

@@ -15,6 +15,7 @@
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
#include "npcomp/Dialect/Numpy/Transforms/Passes.h"
#include "npcomp/Dialect/RefBackend/IR/RefBackendDialect.h"
#include "npcomp/Dialect/TCF/IR/TCFDialect.h"
#include "npcomp/Dialect/TCF/Transforms/Passes.h"
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
@@ -73,6 +74,7 @@ void mlir::NPCOMP::registerAllDialects(mlir::DialectRegistry &registry) {
Basicpy::BasicpyDialect,
Numpy::NumpyDialect,
npcomprt::NpcomprtDialect,
refback::RefBackendDialect,
tcf::TCFDialect,
tcp::TCPDialect,
mlir::NPCOMP::Torch::TorchDialect>();

View File

@@ -3,7 +3,7 @@
// CHECK-LABEL: func @tensor_to_memref
func @tensor_to_memref_fold(%arg0: memref<?xf32>) -> memref<?xf32> {
// CHECK-NEXT: return %arg0 : memref<?xf32>
%0 = tcp.memref_to_tensor %arg0 : memref<?xf32> -> tensor<?xf32>
%1 = tcp.tensor_to_memref %0 : tensor<?xf32> -> memref<?xf32>
%0 = refback.memref_to_tensor %arg0 : memref<?xf32> -> tensor<?xf32>
%1 = refback.tensor_to_memref %0 : tensor<?xf32> -> memref<?xf32>
return %1 : memref<?xf32>
}

View File

@@ -2,31 +2,31 @@
// -----
tcp.global @g dense<0.0> : tensor<2xf32>
refback.global @g dense<0.0> : tensor<2xf32>
func @f() {
// expected-error @+1 {{must reference a valid symbol}}
tcp.get_global_memref @nonexistent_symbol : memref<3xf32>
refback.get_global_memref @nonexistent_symbol : memref<3xf32>
return
}
// -----
tcp.global @g dense<0.0> : tensor<2xf32>
refback.global @g dense<0.0> : tensor<2xf32>
func @f() {
// expected-error @+1 {{inconsistent with shape of global}}
tcp.get_global_memref @g : memref<3xf32>
refback.get_global_memref @g : memref<3xf32>
return
}
// -----
tcp.global @g dense<0.0> : tensor<2xf32>
refback.global @g dense<0.0> : tensor<2xf32>
func @f() {
// expected-error @+1 {{inconsistent with element type of global}}
tcp.get_global_memref @g : memref<2xi8>
refback.get_global_memref @g : memref<2xi8>
return
}
@@ -34,9 +34,9 @@ func @f() {
func @g(%arg0: tensor<?x?xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
// expected-error @+1 {{number of operands must equal number of results}}
%add = tcp.shaped_results %arg1, %arg1 {
%add = refback.shaped_results %arg1, %arg1 {
%0 = tcp.add %arg0, %arg0 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
tcp.yield %0 : tensor<?x?xf32>
refback.yield %0 : tensor<?x?xf32>
} : tensor<?xindex>, tensor<?xindex> -> tensor<?x?xf32>
return %add : tensor<?x?xf32>
}

View File

@@ -0,0 +1,26 @@
// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s
// CHECK-LABEL: refback.global @foo dense<0.0{{.*}}> : tensor<10xf32>
refback.global @foo dense<0.0> : tensor<10xf32>
// CHECK-LABEL: func @global
func @global() {
// CHECK: refback.get_global_memref @foo : memref<10xf32>
%0 = refback.get_global_memref @foo : memref<10xf32>
return
}
// CHECK-LABEL: func @shaped_results
// CHECK-NEXT: %[[RET:.*]] = refback.shaped_results %arg1 {
// CHECK-NEXT: %[[VAL:.*]] =
// CHECK-NEXT: refback.yield %[[VAL]] : tensor<?x?xf32>
// CHECK-NEXT: } : tensor<?xindex> -> tensor<?x?xf32>
// CHECK-NEXT: return %[[RET]] : tensor<?x?xf32>
// CHECK-NEXT: }
func @shaped_results(%arg0: tensor<?x?xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
%add = refback.shaped_results %arg1 {
%0 = tcp.add %arg0, %arg0 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
refback.yield %0 : tensor<?x?xf32>
} : tensor<?xindex> -> tensor<?x?xf32>
return %add : tensor<?x?xf32>
}

View File

@@ -1,14 +1,4 @@
// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s --dump-input=fail
// CHECK-LABEL: tcp.global @foo dense<0.0{{.*}}> : tensor<10xf32>
tcp.global @foo dense<0.0> : tensor<10xf32>
// CHECK-LABEL: func @global
func @global() {
// CHECK: tcp.get_global_memref @foo : memref<10xf32>
%0 = tcp.get_global_memref @foo : memref<10xf32>
return
}
// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s
// CHECK-LABEL: func @binary_elementwise
func @binary_elementwise(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: i32) {
@@ -27,18 +17,3 @@ func @matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
%0 = tcp.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: func @shaped_results
// CHECK-NEXT: %[[RET:.*]] = tcp.shaped_results %arg1 {
// CHECK-NEXT: %[[VAL:.*]] =
// CHECK-NEXT: tcp.yield %[[VAL]] : tensor<?x?xf32>
// CHECK-NEXT: } : tensor<?xindex> -> tensor<?x?xf32>
// CHECK-NEXT: return %[[RET]] : tensor<?x?xf32>
// CHECK-NEXT: }
func @shaped_results(%arg0: tensor<?x?xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
%add = tcp.shaped_results %arg1 {
%0 = tcp.add %arg0, %arg0 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
tcp.yield %0 : tensor<?x?xf32>
} : tensor<?xindex> -> tensor<?x?xf32>
return %add : tensor<?x?xf32>
}

View File

@@ -2,7 +2,7 @@
// CHECK-LABEL: func @tcp_broadcast_to
func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) {
// CHECK: %0 = tcp.shaped_results %arg1
// CHECK: %0 = refback.shaped_results %arg1
%0 = tcp.broadcast_to %arg0, %arg1 : (tensor<?xf32>, tensor<?xindex>) -> tensor<?x?xf32>
return
}
@@ -11,7 +11,7 @@ func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) {
// CHECK-SAME: %[[LHS:.*]]: tensor<?xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: %[[LHSSHAPE:.*]] = shape.shape_of %[[LHS]]
// CHECK: %[[RET:.*]] = tcp.shaped_results %[[LHSSHAPE]]
// CHECK: %[[RET:.*]] = refback.shaped_results %[[LHSSHAPE]]
// CHECK: return %[[RET:.*]] : tensor<?xf32>
// CHECK: }
func @tcp_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
@@ -24,7 +24,7 @@ func @tcp_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-SAME: %[[LHS:.*]]: tensor<?xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: %[[LHSSHAPE:.*]] = shape.shape_of %[[LHS]]
// CHECK: %[[RET:.*]] = tcp.shaped_results %[[LHSSHAPE]]
// CHECK: %[[RET:.*]] = refback.shaped_results %[[LHSSHAPE]]
// CHECK: return %[[RET:.*]] : tensor<?xf32>
// CHECK: }
func @tcp_max(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
@@ -40,7 +40,7 @@ func @tcp_max(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[RHSROWS:.*]] = dim %[[RHS]], %[[C1]]
// CHECK: %[[RESULTSHAPE:.*]] = tensor_from_elements %[[LHSCOLS]], %[[RHSROWS]]
// CHECK: %[[RET:.*]] = tcp.shaped_results %[[RESULTSHAPE]] {
// CHECK: %[[RET:.*]] = refback.shaped_results %[[RESULTSHAPE]] {
// CHECK: return %[[RET:.*]] : tensor<?x?xf32>
// CHECK: }
func @tcp_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {

View File

@@ -5,7 +5,7 @@ func @basic(%arg0: tensor<?xindex>) -> memref<?xf32> {
// CHECK: %[[I:.*]] = constant 0 : index
// CHECK: %[[E:.*]] = shape.get_extent %arg0, %[[I]]
// CHECK: alloc(%[[E]])
%0 = tcp.alloc_memref %arg0 : memref<?xf32>
%0 = refback.alloc_memref %arg0 : memref<?xf32>
return %0 : memref<?xf32>
}
@@ -14,7 +14,7 @@ func @basic(%arg0: tensor<?xindex>) -> memref<?xf32> {
func @all_static(%arg0: tensor<?xindex>) -> memref<3x4x5xf32> {
// CHECK-NOT: shape.get_extent
// CHECK: alloc()
%0 = tcp.alloc_memref %arg0 : memref<3x4x5xf32>
%0 = refback.alloc_memref %arg0 : memref<3x4x5xf32>
return %0 : memref<3x4x5xf32>
}
@@ -26,6 +26,6 @@ func @some_static(%arg0: tensor<?xindex>) -> memref<3x?x5x?x7xf32> {
// CHECK-DAG: %[[I3:.*]] = constant 3 : index
// CHECK-DAG: %[[E3:.*]] = shape.get_extent %arg0, %[[I3]]
// CHECK: alloc(%[[E1]], %[[E3]])
%0 = tcp.alloc_memref %arg0 : memref<3x?x5x?x7xf32>
%0 = refback.alloc_memref %arg0 : memref<3x?x5x?x7xf32>
return %0 : memref<3x?x5x?x7xf32>
}

View File

@@ -3,11 +3,11 @@
// CHECK-LABEL: module {
// We check the debug name too since we put some effort into making that readable.
// The name isn't load-bearing though.
// CHECK: tcp.global @__constant_3x4xf32 dense<7.000000e+00> : tensor<3x4xf32>
// CHECK: refback.global @__constant_3x4xf32 dense<7.000000e+00> : tensor<3x4xf32>
// CHECK: func @basic
func @basic() -> tensor<3x4xf32> {
// CHECK: %[[MEMREF:.*]] = tcp.get_global_memref @__constant_3x4xf32 : memref<3x4xf32>
// CHECK: %[[TENSOR:.*]] = tcp.memref_to_tensor %[[MEMREF]]
// CHECK: %[[MEMREF:.*]] = refback.get_global_memref @__constant_3x4xf32 : memref<3x4xf32>
// CHECK: %[[TENSOR:.*]] = refback.memref_to_tensor %[[MEMREF]]
%0 = constant dense<7.0> : tensor<3x4xf32>
// CHECK: return %[[TENSOR]]
return %0 : tensor<3x4xf32>
@@ -20,8 +20,8 @@ func @basic() -> tensor<3x4xf32> {
// CHECK-LABEL: module {
// Only one global is created.
// CHECK: tcp.global
// CHECK-NOT: tcp.global
// CHECK: refback.global
// CHECK-NOT: refback.global
func @duplicate_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
%0 = constant dense<7.0> : tensor<3x4xf32>
%1 = constant dense<7.0> : tensor<3x4xf32>
@@ -36,9 +36,9 @@ func @duplicate_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
// CHECK-LABEL: module {
// Two globals are created.
// CHECK: tcp.global
// CHECK: tcp.global
// CHECK-NOT: tcp.global
// CHECK: refback.global
// CHECK: refback.global
// CHECK-NOT: refback.global
func @multiple_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
%0 = constant dense<7.0> : tensor<3x4xf32>
%1 = constant dense<8.0> : tensor<3x4xf32>
@@ -52,7 +52,7 @@ func @multiple_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
// CHECK-LABEL: module {
// We don't convert non-tensor globals.
// CHECK-NOT: tcp.global
// CHECK-NOT: refback.global
func @non_tensor() {
%0 = constant 7 : i32
return

View File

@@ -7,10 +7,10 @@ func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) -> tensor<?
// buffer version of tcp.broadcast_to
// CHECK: scf.for
// CHECK: scf.for
// CHECK-NOT: tcp.shaped_results
%0 = tcp.shaped_results %arg1 {
// CHECK-NOT: refback.shaped_results
%0 = refback.shaped_results %arg1 {
%0 = tcp.broadcast_to %arg0, %arg1 : (tensor<?xf32>, tensor<?xindex>) -> tensor<?x?xf32>
tcp.yield %0 : tensor<?x?xf32>
refback.yield %0 : tensor<?x?xf32>
} : tensor<?xindex> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
@@ -20,22 +20,22 @@ func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) -> tensor<?
// CHECK-SAME: %arg0: tensor<?xf32>,
// CHECK-SAME: %arg1: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: %[[LHSSHAPE:.*]] = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
// CHECK: %[[LHS:.*]] = tcp.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
// CHECK: %[[RHS:.*]] = tcp.tensor_to_memref %arg1 : tensor<?xf32> -> memref<?xf32>
// CHECK: %[[RESULT:.*]] = tcp.alloc_memref %[[LHSSHAPE]] : memref<?xf32>
// CHECK: %[[LHS:.*]] = refback.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
// CHECK: %[[RHS:.*]] = refback.tensor_to_memref %arg1 : tensor<?xf32> -> memref<?xf32>
// CHECK: %[[RESULT:.*]] = refback.alloc_memref %[[LHSSHAPE]] : memref<?xf32>
// CHECK: linalg.generic {{.*}} ins(%[[LHS]], %[[RHS]] {{.*}}) outs(%[[RESULT]] {{.*}}) {
// CHECK: ^bb0(%[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
// CHECK: %[[VAL_9:.*]] = addf %[[VAL_6]], %[[VAL_7]] : f32
// CHECK: linalg.yield %[[VAL_9]] : f32
// CHECK: }
// CHECK: %[[RET:.*]] = tcp.memref_to_tensor %[[RESULT]] : memref<?xf32> -> tensor<?xf32>
// CHECK: %[[RET:.*]] = refback.memref_to_tensor %[[RESULT]] : memref<?xf32> -> tensor<?xf32>
// CHECK: return %[[RET]] : tensor<?xf32>
// CHECK: }
func @tcp_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
%0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
%1 = tcp.shaped_results %0 {
%1 = refback.shaped_results %0 {
%2 = tcp.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
tcp.yield %2 : tensor<?xf32>
refback.yield %2 : tensor<?xf32>
} : tensor<?xindex> -> tensor<?xf32>
return %1 : tensor<?xf32>
}
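// The recurring recipe above: refback.tensor_to_memref each tensor operand,
// refback.alloc_memref a result buffer from the computed shape, do the
// actual work with linalg on memrefs, then refback.memref_to_tensor the
// result back. The tcp.max and tcp.matmul tests below follow the same
// pattern.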
@ -49,9 +49,9 @@ func @tcp_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: linalg.yield %[[MAX]] : f32
func @tcp_max(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
%0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
%1 = tcp.shaped_results %0 {
%1 = refback.shaped_results %0 {
%2 = tcp.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
tcp.yield %2 : tensor<?xf32>
refback.yield %2 : tensor<?xf32>
} : tensor<?xindex> -> tensor<?xf32>
return %1 : tensor<?xf32>
}
@ -62,19 +62,19 @@ func @tcp_max(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-SAME: %arg0: tensor<?x?xf32>,
// CHECK-SAME: %arg1: tensor<?x?xf32>,
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>) -> tensor<?x?xf32> {
// CHECK: %[[LHS:.*]] = tcp.tensor_to_memref %arg0 : tensor<?x?xf32> -> memref<?x?xf32>
// CHECK: %[[RHS:.*]] = tcp.tensor_to_memref %arg1 : tensor<?x?xf32> -> memref<?x?xf32>
// CHECK: %[[RESULT:.*]] = tcp.alloc_memref %[[SHAPE]] : memref<?x?xf32>
// CHECK: %[[LHS:.*]] = refback.tensor_to_memref %arg0 : tensor<?x?xf32> -> memref<?x?xf32>
// CHECK: %[[RHS:.*]] = refback.tensor_to_memref %arg1 : tensor<?x?xf32> -> memref<?x?xf32>
// CHECK: %[[RESULT:.*]] = refback.alloc_memref %[[SHAPE]] : memref<?x?xf32>
// CHECK: %[[C0:.*]] = constant 0.000000e+00 : f32
// CHECK: linalg.fill(%[[RESULT]], %[[C0]]) : memref<?x?xf32>, f32
// CHECK: linalg.matmul ins(%[[LHS]], %[[RHS]] : memref<?x?xf32>, memref<?x?xf32>) outs(%[[RESULT]] : memref<?x?xf32>)
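// Note: linalg.matmul accumulates into its output buffer, which is why the
// freshly allocated result is zero-filled first.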
// CHECK: %[[RET:.*]] = tcp.memref_to_tensor %[[RESULT]] : memref<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[RET:.*]] = refback.memref_to_tensor %[[RESULT]] : memref<?x?xf32> -> tensor<?x?xf32>
// CHECK: return %[[RET]] : tensor<?x?xf32>
// CHECK: }
func @tcp_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %shape: tensor<?xindex>) -> tensor<?x?xf32> {
%0 = tcp.shaped_results %shape {
%0 = refback.shaped_results %shape {
%matmul = tcp.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
tcp.yield %matmul : tensor<?x?xf32>
refback.yield %matmul : tensor<?x?xf32>
} : tensor<?xindex> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}

View File

@ -5,7 +5,7 @@
// that to make the test actually check what happens in practice.
// CHECK-LABEL: func @extract_element
// CHECK: %[[MEMREF:.*]] = tcp.tensor_to_memref %arg0
// CHECK: %[[MEMREF:.*]] = refback.tensor_to_memref %arg0
// CHECK: %[[RET:.*]] = load %[[MEMREF]][%arg1] : memref<?xf32>
// CHECK: return %[[RET]] : f32
func @extract_element(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
@ -20,7 +20,7 @@ func @extract_element(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
// CHECK: store %[[ARG0]], %[[MEMREF]][%[[C0]]]
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: store %[[ARG1]], %[[MEMREF]][%[[C1]]]
// CHECK: %[[RET:.*]] = tcp.memref_to_tensor %[[MEMREF]]
// CHECK: %[[RET:.*]] = refback.memref_to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<2xindex>
func @tensor_from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
%0 = tensor_from_elements %arg0, %arg1 : tensor<2xindex>
@ -30,9 +30,9 @@ func @tensor_from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
// CHECK-LABEL: func @tensor_cast(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK: %[[MEMREF:.*]] = tcp.tensor_to_memref %[[ARG0]] : tensor<?xindex> -> memref<?xindex>
// CHECK: %[[MEMREF:.*]] = refback.tensor_to_memref %[[ARG0]] : tensor<?xindex> -> memref<?xindex>
// CHECK: %[[CASTED:.*]] = memref_cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
// CHECK: %[[RET:.*]] = tcp.memref_to_tensor %[[CASTED]] : memref<2xindex> -> tensor<2xindex>
// CHECK: %[[RET:.*]] = refback.memref_to_tensor %[[CASTED]] : memref<2xindex> -> tensor<2xindex>
// CHECK: return %[[RET]] : tensor<2xindex>
func @tensor_cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
%0 = tensor_cast %arg0 : tensor<?xindex> to tensor<2xindex>
@ -41,10 +41,9 @@ func @tensor_cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
// CHECK-LABEL: func @tensor_load(
// CHECK-SAME: %[[ARG0:.*]]: memref<?xindex>) -> tensor<?xindex> {
// CHECK: %[[RET:.*]] = tcp.memref_to_tensor %[[ARG0]] : memref<?xindex> -> tensor<?xindex>
// CHECK: %[[RET:.*]] = refback.memref_to_tensor %[[ARG0]] : memref<?xindex> -> tensor<?xindex>
// CHECK: return %[[RET]] : tensor<?xindex>
func @tensor_load(%arg0: memref<?xindex>) -> tensor<?xindex> {
%0 = tensor_load %arg0 : memref<?xindex>
return %0 : tensor<?xindex>
}
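// Taken together, these tests show the std ops (extract_element,
// tensor_from_elements, tensor_cast, tensor_load) being bridged through the
// refback.tensor_to_memref / refback.memref_to_tensor materializations,
// which cancel out later when they meet (see the materialization folding
// tests).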

View File

@ -63,8 +63,8 @@ func @for(%arg0: tensor<f32>, %lb: index, %ub: index, %step: index) -> tensor<f3
// CHECK-LABEL: func @identity_materializations(%arg0: memref<?xf32>) -> memref<?xf32> {
// CHECK-NEXT: return %arg0 : memref<?xf32>
func @identity_materializations(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = tcp.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
%1 = tcp.memref_to_tensor %0 : memref<?xf32> -> tensor<?xf32>
%0 = refback.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
%1 = refback.memref_to_tensor %0 : memref<?xf32> -> tensor<?xf32>
return %1 : tensor<?xf32>
}
@ -76,7 +76,7 @@ func @identity_materializations(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-NEXT: }
// CHECK-NEXT: return %[[RET]] : memref<?xf32>
func @if_materializations(%pred: i1, %true_val_memref: memref<?xf32>, %false_val: tensor<?xf32>) -> tensor<?xf32> {
%true_val = tcp.memref_to_tensor %true_val_memref : memref<?xf32> -> tensor<?xf32>
%true_val = refback.memref_to_tensor %true_val_memref : memref<?xf32> -> tensor<?xf32>
%0 = scf.if %pred -> (tensor<?xf32>) {
scf.yield %true_val : tensor<?xf32>
} else {
@ -88,13 +88,13 @@ func @if_materializations(%pred: i1, %true_val_memref: memref<?xf32>, %false_val: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-LABEL: func @elide_memref_to_tensor(%arg0: memref<?xf32>) -> memref<?xf32> {
// CHECK-NEXT: return %arg0 : memref<?xf32>
func @elide_memref_to_tensor(%arg0: memref<?xf32>) -> tensor<?xf32> {
%0 = tcp.memref_to_tensor %arg0 : memref<?xf32> -> tensor<?xf32>
%0 = refback.memref_to_tensor %arg0 : memref<?xf32> -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// CHECK-LABEL: func @elide_tensor_to_memref(%arg0: memref<?xf32>) -> memref<?xf32> {
// CHECK-NEXT: return %arg0 : memref<?xf32>
func @elide_tensor_to_memref(%arg0: tensor<?xf32>) -> memref<?xf32> {
%0 = tcp.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
%0 = refback.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
return %0 : memref<?xf32>
}
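// In summary: a tensor_to_memref / memref_to_tensor pair over the same
// value cancels (as in @identity_materializations above), and a lone
// materialization against a converted function boundary is simply elided.
// That is what lets bufferization insert these ops freely and rely on
// canonicalization to clean up.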

View File

@ -75,7 +75,7 @@ func @multiple_blocks(%arg0: memref<?xf32>) -> memref<?xf32> {
// CHECK: npcomprt.global @g dense<7.000000e+00> : tensor<10xf32>
tcp.global @g dense<7.0> : tensor<10xf32>
refback.global @g dense<7.0> : tensor<10xf32>
// CHECK-LABEL: func @gets_global() -> !npcomprt.tensor
func @gets_global() -> memref<10xf32> {
// CHECK: %[[GMEMREF:.*]] = npcomprt.get_global @g : memref<*xf32>
@ -83,7 +83,7 @@ func @gets_global() -> memref<10xf32> {
// CHECK: %[[OUTABIMEMREF:.*]] = memref_cast %[[ORIGMEMREF:.*]] : memref<10xf32> to memref<*xf32>
// CHECK: %[[RET:.*]] = npcomprt.from_memref %[[OUTABIMEMREF]] : memref<*xf32>
// CHECK: return %[[RET]] : !npcomprt.tensor
%0 = tcp.get_global_memref @g : memref<10xf32>
%0 = refback.get_global_memref @g : memref<10xf32>
return %0 : memref<10xf32>
}
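// The runtime ABI is shape-erased: npcomprt.get_global yields an unranked
// memref<*xf32>, so the lowering is expected to memref_cast it down to the
// static memref<10xf32> internally, and to cast back to memref<*xf32>
// before npcomprt.from_memref produces the !npcomprt.tensor return value.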