mirror of https://github.com/llvm/torch-mlir
Add flatten op recognition + shape refinement.
This op has complex aliasing semantics, so it is kept mutable for now. With this, we reduce ResNet18 to a single BB with all aten operators having rank + dtype: https://gist.github.com/silvasean/2fcb1c6e4d4ae27461204a43ae9c5031
parent 122cae2ee3
commit 3d08c83580
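For orientation, here is a rough Python sketch (not part of the patch; the helper name and the -1-for-unknown-size convention are illustrative) of the shape-transfer rule the new aten.flatten handling in the refine-types pass implements: dims before start_dim and after end_dim are preserved, the collapsed range becomes a single unknown dim, and a rank-0 input flattens to rank 1.

# Illustrative sketch only (not part of this patch): the shape-transfer rule
# for aten::flatten that the RefineTypes changes below implement. -1 stands in
# for kUnknownSize; None means "rank unknown". Helper name is hypothetical.
def refine_flatten_shape(in_sizes, start_dim, end_dim, has_rank=True):
    if not has_rank:
        return None  # No rank information to propagate.
    rank = len(in_sizes)
    if rank == 0:
        # Rank 0 is special and flattens to a rank-1 tensor of unknown size.
        return [-1]
    # Normalize negative dimension indices.
    if start_dim < 0:
        start_dim += rank
    if end_dim < 0:
        end_dim += rank
    # Out-of-bounds or inverted dims: give up on rank refinement.
    if not (0 <= start_dim <= end_dim <= rank - 1):
        return None
    # Keep dims before start_dim, collapse [start_dim, end_dim] into a single
    # unknown dim (the pass does not multiply known extents), keep the rest.
    return in_sizes[:start_dim] + [-1] + in_sizes[end_dim + 1:]

# These mirror the expectations in the new refine-types tests:
assert refine_flatten_shape([3, 2, -1, 5], 0, -1) == [-1]        # @flatten_all
assert refine_flatten_shape([3, 2, -1, 5], 1, -2) == [3, -1, 5]  # @flatten_some
assert refine_flatten_shape([], 0, -1) == [-1]                   # @flatten_rank0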
@@ -164,6 +164,10 @@ def generate_ops(g: "OpGenerator"):
      "aten::nll_loss2d_backward(Tensor,Tensor,Tensor,Tensor?,int,int,Tensor)",
      "NllLoss2dBackwardOp", "nll_loss2d_backward")

  g.print_banner("Mutable/view-like ops")
  g.ordinary_mutable_op("aten::flatten(Tensor,int,int)",
                        "FlattenOp",
                        "flatten")
  # One-off in-place ops (note that many in-place arithmetic ops are handled
  # as a transformation from their immutable forms).
  g.ordinary_inplace_op("aten::copy_(Tensor,Tensor,bool)",
@@ -298,6 +302,36 @@ class OpGenerator:
    )
    opdef.emit()

  def ordinary_mutable_op(self,
                          kernel_sig: str,
                          ods_name: str,
                          op_name: str,
                          traits: Sequence[str] = (),
                          **kwargs):
    """An ordinary mutable-tensor based op."""
    opdef = self.define_op(
        kernel_sig=kernel_sig,
        ods_name=ods_name,
        op_name=op_name,
        traits=list(traits),
        **kwargs)
    opdef.transforms(
        type_transforms={
            "Tensor": "AnyTorchMutableTensor",
            "Tensor?": "AnyTorchOptionalMutableTensor",
            "int": "AnyTorchIntType",
            "int[]": "AnyTorchIntListType",
            "bool": "AnyTorchBoolType",
            "bool[]": "AnyTorchBoolListType",
            "float": "AnyFloat",
        },
        flag_transforms={
            "Tensor": ["kMutableTensor"],
            "Tensor?": ["kMutableTensor"],
        },
    )
    opdef.emit()

  def ordinary_primitive_op(self,
                            kernel_sig: str,
                            ods_name: str,
@@ -39,23 +39,6 @@ def aten_ConstantOp: aten_Op<"constant", [NoSideEffect]>,
}

def aten_FlattenOp: aten_Op<"flatten", [NoSideEffect, StatisticsOpInterface]>,
    Results<(outs AnyTensor)> {
  let arguments = (
    ins AnyType:$arg0,
        AnyType:$arg1,
        AnyType:$arg2
  );

  let summary = "Flatten operator";
  let description = [{
    Flatten operator
  }];
  let extraClassDeclaration = [{
    std::map<std::string, uint64_t> getStatistics();
  }];
}

def aten_TypeCastOp : aten_Op<"type_cast", [NoSideEffect]>,
    Results<(outs AnyType)> {
  let summary = "TypeCast operator";
@@ -1291,6 +1291,28 @@ const Torch::BuildKernelMetadata &NllLoss2dBackwardOp::getTorchBuildKernelMetadata
  return metadata;
}

// -----------------------------------------------------------------------------
// Mutable/view-like ops
// -----------------------------------------------------------------------------

Torch::KernelMetadata FlattenOp::getTorchKernelMetadata() {
  return getTorchBuildKernelMetadata();
}

const Torch::BuildKernelMetadata &FlattenOp::getTorchBuildKernelMetadata() {
  using KVC = Torch::KernelValueConversion::BitMask;
  static Torch::BuildKernelMetadata metadata = ([]() {
    Torch::BuildKernelMetadata m;
    m.kernelName = "aten::flatten";
    m.addArgTypes({"Tensor", "int", "int"});
    m.addArgConversions({KVC::kMutableTensor, KVC::kNone, KVC::kNone});
    m.addReturnTypes({"Tensor"});
    m.addReturnConversions({KVC::kMutableTensor});
    return m;
  })();
  return metadata;
}

Torch::KernelMetadata CopyInplaceOp::getTorchKernelMetadata() {
  return getTorchBuildKernelMetadata();
}
@@ -759,6 +759,22 @@ def aten_NllLoss2dBackwardOp: aten_Op<"nll_loss2d_backward", [NoSideEffect, Decl
  );
}

// -----------------------------------------------------------------------------
// Mutable/view-like ops
// -----------------------------------------------------------------------------

def aten_FlattenOp: aten_Op<"flatten", [DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
  let summary = "Recognized op for kernel aten::flatten";
  let arguments = (ins
    AnyTorchMutableTensor:$self,
    AnyTorchIntType:$start_dim,
    AnyTorchIntType:$end_dim
  );
  let results = (outs
    AnyTorchMutableTensor
  );
}

def aten_CopyInplaceOp: aten_Op<"copy.inplace", [DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
  let summary = "Recognized op for kernel aten::copy_";
  let arguments = (ins
@@ -157,14 +157,6 @@ std::map<std::string, uint64_t> ExpandOp::getStatistics() {
  return toReturn;
}

// flatten can be zero overhead
std::map<std::string, uint64_t> FlattenOp::getStatistics() {
  std::map<std::string, uint64_t> toReturn;
  toReturn["reads"] = toReturn["operand:0:activation_in"] = 0;
  toReturn["writes"] = toReturn["result:0:activation_out"] = 0;
  return toReturn;
}

std::map<std::string, uint64_t> GatherOp::getStatistics() {
  std::map<std::string, uint64_t> toReturn;
  // FIXME: unimplemented
@@ -85,6 +85,15 @@ convertTorchArgType(StringRef sourceTorchType, StringRef targetTorchType,
    return None;
  }

  if (flag & KVC::kMutableTensor) {
    if (!isTorchTensorType(sourceTorchType) ||
        !isTorchTensorType(targetTorchType))
      return None;
    // If the type is already mutable, passthrough.
    if (sourceMlirType.isa<Numpy::NdArrayType>())
      return TypeConversion{sourceMlirType, nullptr};
  }

  // TODO: Special case promotions and conversions.
  return None;
}
@@ -145,6 +154,20 @@ convertTorchReturnType(StringRef sourceTorchType, StringRef targetTorchType,
    }
  }

  if (flag & KVC::kMutableTensor) {
    if (!isTorchTensorType(sourceTorchType) ||
        !isTorchTensorType(targetTorchType)) {
      LLVM_DEBUG(llvm::dbgs()
                 << " * Source or target not a Tensor type\n");
      return None;
    }
    // If the type is already mutable, passthrough.
    if (sourceMlirType.isa<Numpy::NdArrayType>()) {
      LLVM_DEBUG(llvm::dbgs() << " * Source is already mutable\n");
      return TypeConversion{sourceMlirType, nullptr};
    }
  }

  LLVM_DEBUG(llvm::dbgs() << " * Return type conversion fallthrough\n");
  return None;
}
@@ -13,6 +13,7 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
@@ -76,6 +77,9 @@ struct ValueKnowledge {
      result.elementType = tensorType.getElementType();
      return result;
    }
    if (auto ndArrayType = type.dyn_cast<Numpy::NdArrayType>()) {
      return getKnowledgeFromType(ndArrayType.toTensorType());
    }
    return result;
  }
@@ -262,6 +266,39 @@ public:
      knowledge.elementType =
          joinElementTypes(lhs.elementType, rhs.elementType);
      return getLatticeElement(op->getResult(0)).join(knowledge);
    } else if (auto flatten = dyn_cast<aten::FlattenOp>(op)) {
      APInt startDimAP, endDimAP;
      auto operand = operands[0]->getValue();
      auto knowledge =
          ValueKnowledge::getPessimisticValueState(op->getContext());
      knowledge.elementType = operand.elementType;
      if (operand.hasRank && operand.sizes.size() == 0) {
        // Rank 0 is special and flattens to rank 1.
        knowledge.hasRank = true;
        knowledge.sizes.push_back(kUnknownSize);
      } else if (operand.hasRank &&
                 matchPattern(flatten.start_dim(),
                              m_ConstantInt(&startDimAP)) &&
                 matchPattern(flatten.end_dim(), m_ConstantInt(&endDimAP))) {
        int64_t inputRank = operand.sizes.size();
        int64_t startDim = startDimAP.getSExtValue();
        int64_t endDim = endDimAP.getSExtValue();
        if (startDim < 0)
          startDim += inputRank;
        if (endDim < 0)
          endDim += inputRank;
        // Careful: dimension numbers might be out of bounds.
        if (0 <= startDim && startDim <= (inputRank - 1) && 0 <= endDim &&
            endDim <= (inputRank - 1) && startDim <= endDim) {
          knowledge.hasRank = true;
          for (auto i = 0; i < startDim; i++)
            knowledge.sizes.push_back(operand.sizes[i]);
          knowledge.sizes.push_back(kUnknownSize);
          for (auto i = endDim + 1; i < inputRank; i++)
            knowledge.sizes.push_back(operand.sizes[i]);
        }
      }
      return getLatticeElement(op->getResult(0)).join(knowledge);
    }
    // Otherwise, this is an unknown operation. Just mark all results as having
    // reached a pessimistic fixpoint.
@@ -287,13 +324,29 @@ getTensorTypeFromKnowledge(MLIRContext *context,
  return RankedTensorType::get(value.sizes, value.elementType);
}

// Get the most refined Numpy::NdArrayType compatible with ValueKnowledge.
static Numpy::NdArrayType
getNdArrayTypeFromKnowledge(MLIRContext *context,
                            LatticeElement<ValueKnowledge> *knowledge) {
  if (!knowledge)
    return Numpy::NdArrayType::get(Numpy::AnyDtypeType::get(context));

  const ValueKnowledge &value = knowledge->getValue();
  if (!value.hasRank)
    return Numpy::NdArrayType::get(value.elementType);
  return Numpy::NdArrayType::get(value.elementType,
                                 llvm::makeArrayRef(value.sizes));
}

// Get the most refined type compatible with ValueKnowledge, or null if that
// is not possible.
static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
  if (v.getType().isa<TensorType>())
    return getTensorTypeFromKnowledge(v.getContext(),
                                      analyzer.lookupLatticeElement(v));
  // TODO: Support !numpy.ndarray type.
  if (v.getType().isa<Numpy::NdArrayType>())
    return getNdArrayTypeFromKnowledge(v.getContext(),
                                       analyzer.lookupLatticeElement(v));
  return nullptr;
}
@@ -308,15 +361,28 @@ void optimize(FuncOp func, TypeAnalyzer &analyzer) {
      // Type is same as existing one? Nothing to do.
      if (refinedType == originalType)
        continue;
      // If the type is a TensorType, then we have numpy.tensor_static_info_cast
      // to add/remove the static information. We make sure to always embed the
      // static information in the IR, and insert the minimal number of casts
      // needed to do so.
      // TODO: This logic should generalize easily to other types. We just
      // need to know which op allows us to add static information and which op
      // allows us to remove static information (in this case, one op allows
      // both).
      // If we have an op that allows adding/removing static information from
      // this type, then we can rewrite. We make sure to always embed the static
      // information in the IR, and insert the minimal number of casts needed to
      // do so.
      // TODO: For some types, we will need 2 ops here: one to add static
      // information, and the other to remove static information.
      // (for example, torch.unchecked_cast / torch.derefine for torch.optional
      // types).
      std::function<Value(Location, Type, Value)> createStaticInfoCast;
      OpBuilder b(op->getBlock(), std::next(op->getIterator()));
      if (originalType.isa<TensorType>()) {
        createStaticInfoCast = [&](Location loc, Type newType,
                                   Value v) -> Value {
          return b.create<Numpy::TensorStaticInfoCastOp>(loc, newType, v);
        };
      } else if (originalType.isa<Numpy::NdArrayType>()) {
        createStaticInfoCast = [&](Location loc, Type newType,
                                   Value v) -> Value {
          return b.create<Numpy::StaticInfoCastOp>(loc, newType, v);
        };
      }
      if (createStaticInfoCast) {
        // Save off the original uses to avoid iterator invalidation issues
        // or other unexpected behavior since we are creating new ops here that
        // use the value.
@@ -325,14 +391,13 @@ void optimize(FuncOp func, TypeAnalyzer &analyzer) {
        OpBuilder b(op->getBlock(), std::next(op->getIterator()));
        Value newTypedValue;
        // Always make sure that the new static information is reflected in the
        // IR, either by updating the type in place, or inserting a
        // numpy.tensor_static_info_cast op.
        // IR, either by updating the type in place, or inserting a static info
        // cast.
        if (allowsTypeRefinement(op)) {
          newTypedValue = v;
          v.setType(refinedType);
        } else {
          newTypedValue = b.create<Numpy::TensorStaticInfoCastOp>(
              op->getLoc(), refinedType, v);
          newTypedValue = createStaticInfoCast(op->getLoc(), refinedType, v);
        }

        Value oldTypedValue;
@@ -345,8 +410,8 @@ void optimize(FuncOp func, TypeAnalyzer &analyzer) {
        // If needed, create a value of the original type to appease users
        // that cannot accept the new type.
        if (!oldTypedValue) {
          oldTypedValue = b.create<Numpy::TensorStaticInfoCastOp>(
              op->getLoc(), originalType, newTypedValue);
          oldTypedValue =
              createStaticInfoCast(op->getLoc(), originalType, newTypedValue);
        }
        use->set(oldTypedValue);
      }
@@ -147,3 +147,18 @@ func @inplace_variant(%arg0: !numpy.ndarray<[2,2]:f32>, %arg1: !numpy.ndarray<[2
  // CHECK: return %[[LHS_OUT]], %[[LHS_OUT]] : !numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>
  return %0, %arg0 : !numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>
}

// -----

// CHECK-LABEL: func @mutable_tensor(
// CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
// CHECK: %[[CM1:.*]] = constant -1 : i64
// CHECK: %[[C1:.*]] = constant 1 : i64
// CHECK: %[[RET:.*]] = "aten.flatten"(%[[ARG]], %[[C1]], %[[CM1]]) : (!numpy.ndarray<*:!numpy.any_dtype>, i64, i64) -> !numpy.ndarray<*:!numpy.any_dtype>
// CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
func @mutable_tensor(%arg0: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
  %c-1_i64 = constant -1 : i64
  %c1_i64 = constant 1 : i64
  %0 = torch.kernel_call "aten::flatten" %arg0, %c1_i64, %c-1_i64 : (!numpy.ndarray<*:!numpy.any_dtype>, i64, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "int", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
  return %0 : !numpy.ndarray<*:!numpy.any_dtype>
}
@@ -109,6 +109,38 @@ func @f(%arg0: tensor<?x?x?x?xf32>) -> tensor<*x!numpy.any_dtype> {

// -----

// Also test cast insertion for array types.
// CHECK-LABEL: func @flatten_all(
// CHECK: %[[FLATTENED:.*]] = "aten.flatten"{{.*}}-> !numpy.ndarray<[?]:f32>
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.static_info_cast %[[FLATTENED]] : !numpy.ndarray<[?]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
// CHECK: return %[[SHAPE_ERASED]]
func @flatten_all(%arg0: !numpy.ndarray<[3,2,?,5]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
  %end = constant -1 : i64
  %start = constant 0 : i64
  %0 = "aten.flatten"(%arg0, %start, %end) : (!numpy.ndarray<[3,2,?,5]:f32>, i64, i64) -> !numpy.ndarray<*:!numpy.any_dtype>
  return %0 : !numpy.ndarray<*:!numpy.any_dtype>
}

// CHECK-LABEL: func @flatten_some(
// CHECK: "aten.flatten"{{.*}}-> !numpy.ndarray<[3,?,5]:f32>
func @flatten_some(%arg0: !numpy.ndarray<[3,2,?,5]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
  %end = constant -2 : i64
  %start = constant 1 : i64
  %0 = "aten.flatten"(%arg0, %start, %end) : (!numpy.ndarray<[3,2,?,5]:f32>, i64, i64) -> !numpy.ndarray<*:!numpy.any_dtype>
  return %0 : !numpy.ndarray<*:!numpy.any_dtype>
}

// CHECK-LABEL: func @flatten_rank0(
// CHECK: "aten.flatten"{{.*}}-> !numpy.ndarray<[?]:f32>
func @flatten_rank0(%arg0: !numpy.ndarray<[]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
  %end = constant -1 : i64
  %start = constant 0 : i64
  %0 = "aten.flatten"(%arg0, %start, %end) : (!numpy.ndarray<[]:f32>, i64, i64) -> !numpy.ndarray<*:!numpy.any_dtype>
  return %0 : !numpy.ndarray<*:!numpy.any_dtype>
}

// -----

// CHECK-LABEL: func @f
func @f(%arg0: tensor<4x6x3xf32>, %arg1: tensor<1x1x3xf32>, %arg2: tensor<?x3xf32>) {
  %c1_i64 = constant 1 : i64