Add flatten op recognition + shape refinement.

This op has complex aliasing semantics, so it is kept mutable for now.

With this, ResNet18 is reduced to a single basic block in which all aten
operators have a known rank and dtype:
https://gist.github.com/silvasean/2fcb1c6e4d4ae27461204a43ae9c5031
Branch: pull/214/head
Author: Sean Silva, 2021-04-30 11:16:14 -07:00
Parent: 122cae2ee3
Commit: 3d08c83580
9 changed files with 222 additions and 40 deletions
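For orientation, here is a minimal standalone C++ sketch of the shape-refinement
rule the pass changes below implement (the helper name refineFlattenShape, its
signature, and the kUnknownSize = -1 convention are illustrative only, not part
of the patch): dimensions in [start_dim, end_dim] collapse into a single
dimension whose extent is conservatively left unknown, a rank-0 input flattens
to rank 1, and out-of-bounds dimensions produce no refinement.

#include <cstdint>
#include <optional>
#include <vector>

constexpr int64_t kUnknownSize = -1;

// Returns the refined output sizes for aten::flatten, or std::nullopt if the
// shape cannot be refined (e.g. out-of-bounds start/end dims).
std::optional<std::vector<int64_t>>
refineFlattenShape(const std::vector<int64_t> &inSizes, int64_t startDim,
                   int64_t endDim) {
  const int64_t rank = static_cast<int64_t>(inSizes.size());
  if (rank == 0)
    return std::vector<int64_t>{kUnknownSize}; // Rank 0 flattens to rank 1.
  if (startDim < 0)
    startDim += rank;
  if (endDim < 0)
    endDim += rank;
  // Careful: dimension numbers might be out of bounds.
  if (startDim < 0 || startDim > rank - 1 || endDim < 0 || endDim > rank - 1 ||
      startDim > endDim)
    return std::nullopt;
  // Keep dims before startDim, collapse [startDim, endDim] into one unknown
  // dim, keep dims after endDim.
  std::vector<int64_t> outSizes(inSizes.begin(), inSizes.begin() + startDim);
  outSizes.push_back(kUnknownSize);
  outSizes.insert(outSizes.end(), inSizes.begin() + endDim + 1, inSizes.end());
  return outSizes;
}

For example, refineFlattenShape({3, 2, kUnknownSize, 5}, 1, -2) yields
{3, kUnknownSize, 5}, which matches the @flatten_some test case added below.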


@@ -164,6 +164,10 @@ def generate_ops(g: "OpGenerator"):
      "aten::nll_loss2d_backward(Tensor,Tensor,Tensor,Tensor?,int,int,Tensor)",
      "NllLoss2dBackwardOp", "nll_loss2d_backward")

  g.print_banner("Mutable/view-like ops")
  g.ordinary_mutable_op("aten::flatten(Tensor,int,int)",
                        "FlattenOp",
                        "flatten")
  # One-off in-place ops (note that many in-place arithmetic ops are handled
  # as a transformation from their immutable forms).
  g.ordinary_inplace_op("aten::copy_(Tensor,Tensor,bool)",
@@ -298,6 +302,36 @@ class OpGenerator:
    )
    opdef.emit()

  def ordinary_mutable_op(self,
                          kernel_sig: str,
                          ods_name: str,
                          op_name: str,
                          traits: Sequence[str] = (),
                          **kwargs):
    """An ordinary mutable-tensor based op."""
    opdef = self.define_op(
        kernel_sig=kernel_sig,
        ods_name=ods_name,
        op_name=op_name,
        traits=list(traits),
        **kwargs)
    opdef.transforms(
        type_transforms={
            "Tensor": "AnyTorchMutableTensor",
            "Tensor?": "AnyTorchOptionalMutableTensor",
            "int": "AnyTorchIntType",
            "int[]": "AnyTorchIntListType",
            "bool": "AnyTorchBoolType",
            "bool[]": "AnyTorchBoolListType",
            "float": "AnyFloat",
        },
        flag_transforms={
            "Tensor": ["kMutableTensor"],
            "Tensor?": ["kMutableTensor"],
        },
    )
    opdef.emit()

  def ordinary_primitive_op(self,
                            kernel_sig: str,
                            ods_name: str,


@@ -39,23 +39,6 @@ def aten_ConstantOp: aten_Op<"constant", [NoSideEffect]>,
}

def aten_FlattenOp: aten_Op<"flatten", [NoSideEffect, StatisticsOpInterface]>,
                    Results<(outs AnyTensor)> {
  let arguments = (
    ins AnyType:$arg0,
        AnyType:$arg1,
        AnyType:$arg2
  );
  let summary = "Flatten operator";
  let description = [{
    Flatten operator
  }];
  let extraClassDeclaration = [{
    std::map<std::string, uint64_t> getStatistics();
  }];
}

def aten_TypeCastOp : aten_Op<"type_cast", [NoSideEffect]>,
    Results<(outs AnyType)> {
  let summary = "TypeCast operator";


@@ -1291,6 +1291,28 @@ const Torch::BuildKernelMetadata &NllLoss2dBackwardOp::getTorchBuildKernelMetada
  return metadata;
}

// -----------------------------------------------------------------------------
// Mutable/view-like ops
// -----------------------------------------------------------------------------

Torch::KernelMetadata FlattenOp::getTorchKernelMetadata() {
  return getTorchBuildKernelMetadata();
}

const Torch::BuildKernelMetadata &FlattenOp::getTorchBuildKernelMetadata() {
  using KVC = Torch::KernelValueConversion::BitMask;
  static Torch::BuildKernelMetadata metadata = ([]() {
    Torch::BuildKernelMetadata m;
    m.kernelName = "aten::flatten";
    m.addArgTypes({"Tensor", "int", "int"});
    m.addArgConversions({KVC::kMutableTensor, KVC::kNone, KVC::kNone});
    m.addReturnTypes({"Tensor"});
    m.addReturnConversions({KVC::kMutableTensor});
    return m;
  })();
  return metadata;
}

Torch::KernelMetadata CopyInplaceOp::getTorchKernelMetadata() {
  return getTorchBuildKernelMetadata();
}


@@ -759,6 +759,22 @@ def aten_NllLoss2dBackwardOp: aten_Op<"nll_loss2d_backward", [NoSideEffect, Decl
  );
}

// -----------------------------------------------------------------------------
// Mutable/view-like ops
// -----------------------------------------------------------------------------

def aten_FlattenOp: aten_Op<"flatten", [DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
  let summary = "Recognized op for kernel aten::flatten";
  let arguments = (ins
    AnyTorchMutableTensor:$self,
    AnyTorchIntType:$start_dim,
    AnyTorchIntType:$end_dim
  );
  let results = (outs
    AnyTorchMutableTensor
  );
}

def aten_CopyInplaceOp: aten_Op<"copy.inplace", [DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
  let summary = "Recognized op for kernel aten::copy_";
  let arguments = (ins


@@ -157,14 +157,6 @@ std::map<std::string, uint64_t> ExpandOp::getStatistics() {
  return toReturn;
}

// flatten can be zero overhead
std::map<std::string, uint64_t> FlattenOp::getStatistics() {
  std::map<std::string, uint64_t> toReturn;
  toReturn["reads"] = toReturn["operand:0:activation_in"] = 0;
  toReturn["writes"] = toReturn["result:0:activation_out"] = 0;
  return toReturn;
}

std::map<std::string, uint64_t> GatherOp::getStatistics() {
  std::map<std::string, uint64_t> toReturn;
  // FIXME: unimplemented


@@ -85,6 +85,15 @@ convertTorchArgType(StringRef sourceTorchType, StringRef targetTorchType,
    return None;
  }

  if (flag & KVC::kMutableTensor) {
    if (!isTorchTensorType(sourceTorchType) ||
        !isTorchTensorType(targetTorchType))
      return None;
    // If the type is already mutable, passthrough.
    if (sourceMlirType.isa<Numpy::NdArrayType>())
      return TypeConversion{sourceMlirType, nullptr};
  }

  // TODO: Special case promotions and conversions.
  return None;
}

@@ -145,6 +154,20 @@ convertTorchReturnType(StringRef sourceTorchType, StringRef targetTorchType,
    }
  }

  if (flag & KVC::kMutableTensor) {
    if (!isTorchTensorType(sourceTorchType) ||
        !isTorchTensorType(targetTorchType)) {
      LLVM_DEBUG(llvm::dbgs()
                 << " * Source or target not a Tensor type\n");
      return None;
    }
    // If the type is already mutable, passthrough.
    if (sourceMlirType.isa<Numpy::NdArrayType>()) {
      LLVM_DEBUG(llvm::dbgs() << " * Source is already mutable\n");
      return TypeConversion{sourceMlirType, nullptr};
    }
  }

  LLVM_DEBUG(llvm::dbgs() << " * Return type conversion fallthrough\n");
  return None;
}


@@ -13,6 +13,7 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
@@ -76,6 +77,9 @@ struct ValueKnowledge {
      result.elementType = tensorType.getElementType();
      return result;
    }
    if (auto ndArrayType = type.dyn_cast<Numpy::NdArrayType>()) {
      return getKnowledgeFromType(ndArrayType.toTensorType());
    }
    return result;
  }
@@ -262,6 +266,39 @@ public:
      knowledge.elementType =
          joinElementTypes(lhs.elementType, rhs.elementType);
      return getLatticeElement(op->getResult(0)).join(knowledge);
    } else if (auto flatten = dyn_cast<aten::FlattenOp>(op)) {
      APInt startDimAP, endDimAP;
      auto operand = operands[0]->getValue();
      auto knowledge =
          ValueKnowledge::getPessimisticValueState(op->getContext());
      knowledge.elementType = operand.elementType;
      if (operand.hasRank && operand.sizes.size() == 0) {
        // Rank 0 is special and flattens to rank 1.
        knowledge.hasRank = true;
        knowledge.sizes.push_back(kUnknownSize);
      } else if (operand.hasRank &&
                 matchPattern(flatten.start_dim(),
                              m_ConstantInt(&startDimAP)) &&
                 matchPattern(flatten.end_dim(), m_ConstantInt(&endDimAP))) {
        int64_t inputRank = operand.sizes.size();
        int64_t startDim = startDimAP.getSExtValue();
        int64_t endDim = endDimAP.getSExtValue();
        if (startDim < 0)
          startDim += inputRank;
        if (endDim < 0)
          endDim += inputRank;
        // Careful: dimension numbers might be out of bounds.
        if (0 <= startDim && startDim <= (inputRank - 1) && 0 <= endDim &&
            endDim <= (inputRank - 1) && startDim <= endDim) {
          knowledge.hasRank = true;
          for (auto i = 0; i < startDim; i++)
            knowledge.sizes.push_back(operand.sizes[i]);
          knowledge.sizes.push_back(kUnknownSize);
          for (auto i = endDim + 1; i < inputRank; i++)
            knowledge.sizes.push_back(operand.sizes[i]);
        }
      }
      return getLatticeElement(op->getResult(0)).join(knowledge);
    }
    // Otherwise, this is an unknown operation. Just mark all results as having
    // reached a pessimistic fixpoint.
@ -287,13 +324,29 @@ getTensorTypeFromKnowledge(MLIRContext *context,
  return RankedTensorType::get(value.sizes, value.elementType);
}

// Get the most refined Numpy::NdArrayType compatible with ValueKnowledge.
static Numpy::NdArrayType
getNdArrayTypeFromKnowledge(MLIRContext *context,
                            LatticeElement<ValueKnowledge> *knowledge) {
  if (!knowledge)
    return Numpy::NdArrayType::get(Numpy::AnyDtypeType::get(context));
  const ValueKnowledge &value = knowledge->getValue();
  if (!value.hasRank)
    return Numpy::NdArrayType::get(value.elementType);
  return Numpy::NdArrayType::get(value.elementType,
                                 llvm::makeArrayRef(value.sizes));
}

// Get the most refined type compatible with ValueKnowledge, or null if that
// is not possible.
static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
  if (v.getType().isa<TensorType>())
    return getTensorTypeFromKnowledge(v.getContext(),
                                      analyzer.lookupLatticeElement(v));
  // TODO: Support !numpy.ndarray type.
  if (v.getType().isa<Numpy::NdArrayType>())
    return getNdArrayTypeFromKnowledge(v.getContext(),
                                       analyzer.lookupLatticeElement(v));
  return nullptr;
}
@@ -308,15 +361,28 @@ void optimize(FuncOp func, TypeAnalyzer &analyzer) {
      // Type is same as existing one? Nothing to do.
      if (refinedType == originalType)
        continue;
      // If the type is a TensorType, then we have numpy.tensor_static_info_cast
      // to add/remove the static information. We make sure to always embed the
      // static information in the IR, and insert the minimal number of casts
      // needed to do so.
      // TODO: This logic should generalize easily to other types. We just
      // need to know which op allows us to add static information and which op
      // allows us to remove static information (in this case, one op allows
      // both).
      // If we have an op that allows adding/removing static information from
      // this type, then we can rewrite. We make sure to always embed the static
      // information in the IR, and insert the minimal number of casts needed to
      // do so.
      // TODO: For some types, we will need 2 ops here: one to add static
      // information, and the other to remove static information.
      // (for example, torch.unchecked_cast / torch.derefine for torch.optional
      // types).
      std::function<Value(Location, Type, Value)> createStaticInfoCast;
      OpBuilder b(op->getBlock(), std::next(op->getIterator()));
      if (originalType.isa<TensorType>()) {
        createStaticInfoCast = [&](Location loc, Type newType,
                                   Value v) -> Value {
          return b.create<Numpy::TensorStaticInfoCastOp>(loc, newType, v);
        };
      } else if (originalType.isa<Numpy::NdArrayType>()) {
        createStaticInfoCast = [&](Location loc, Type newType,
                                   Value v) -> Value {
          return b.create<Numpy::StaticInfoCastOp>(loc, newType, v);
        };
      }
      if (createStaticInfoCast) {
        // Save off the original uses to avoid iterator invalidation issues
        // or other unexpected behavior since we are creating new ops here that
        // use the value.
@ -325,14 +391,13 @@ void optimize(FuncOp func, TypeAnalyzer &analyzer) {
        OpBuilder b(op->getBlock(), std::next(op->getIterator()));
        Value newTypedValue;
        // Always make sure that the new static information is reflected in the
        // IR, either by updating the type in place, or inserting a
        // numpy.tensor_static_info_cast op.
        // IR, either by updating the type in place, or inserting a static info
        // cast.
        if (allowsTypeRefinement(op)) {
          newTypedValue = v;
          v.setType(refinedType);
        } else {
          newTypedValue = b.create<Numpy::TensorStaticInfoCastOp>(
              op->getLoc(), refinedType, v);
          newTypedValue = createStaticInfoCast(op->getLoc(), refinedType, v);
        }
        Value oldTypedValue;
@@ -345,8 +410,8 @@ void optimize(FuncOp func, TypeAnalyzer &analyzer) {
          // If needed, create a value of the original type to appease users
          // that cannot accept the new type.
          if (!oldTypedValue) {
            oldTypedValue = b.create<Numpy::TensorStaticInfoCastOp>(
                op->getLoc(), originalType, newTypedValue);
            oldTypedValue =
                createStaticInfoCast(op->getLoc(), originalType, newTypedValue);
          }
          use->set(oldTypedValue);
        }


@@ -147,3 +147,18 @@ func @inplace_variant(%arg0: !numpy.ndarray<[2,2]:f32>, %arg1: !numpy.ndarray<[2
  // CHECK: return %[[LHS_OUT]], %[[LHS_OUT]] : !numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>
  return %0, %arg0 : !numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>
}

// -----

// CHECK-LABEL: func @mutable_tensor(
// CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
// CHECK: %[[CM1:.*]] = constant -1 : i64
// CHECK: %[[C1:.*]] = constant 1 : i64
// CHECK: %[[RET:.*]] = "aten.flatten"(%[[ARG]], %[[C1]], %[[CM1]]) : (!numpy.ndarray<*:!numpy.any_dtype>, i64, i64) -> !numpy.ndarray<*:!numpy.any_dtype>
// CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
func @mutable_tensor(%arg0: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
  %c-1_i64 = constant -1 : i64
  %c1_i64 = constant 1 : i64
  %0 = torch.kernel_call "aten::flatten" %arg0, %c1_i64, %c-1_i64 : (!numpy.ndarray<*:!numpy.any_dtype>, i64, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "int", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
  return %0 : !numpy.ndarray<*:!numpy.any_dtype>
}


@@ -109,6 +109,38 @@ func @f(%arg0: tensor<?x?x?x?xf32>) -> tensor<*x!numpy.any_dtype> {
// -----

// Also test cast insertion for array types.
// CHECK-LABEL: func @flatten_all(
// CHECK: %[[FLATTENED:.*]] = "aten.flatten"{{.*}}-> !numpy.ndarray<[?]:f32>
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.static_info_cast %[[FLATTENED]] : !numpy.ndarray<[?]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
// CHECK: return %[[SHAPE_ERASED]]
func @flatten_all(%arg0: !numpy.ndarray<[3,2,?,5]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
  %end = constant -1 : i64
  %start = constant 0 : i64
  %0 = "aten.flatten"(%arg0, %start, %end) : (!numpy.ndarray<[3,2,?,5]:f32>, i64, i64) -> !numpy.ndarray<*:!numpy.any_dtype>
  return %0 : !numpy.ndarray<*:!numpy.any_dtype>
}

// CHECK-LABEL: func @flatten_some(
// CHECK: "aten.flatten"{{.*}}-> !numpy.ndarray<[3,?,5]:f32>
func @flatten_some(%arg0: !numpy.ndarray<[3,2,?,5]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
  %end = constant -2 : i64
  %start = constant 1 : i64
  %0 = "aten.flatten"(%arg0, %start, %end) : (!numpy.ndarray<[3,2,?,5]:f32>, i64, i64) -> !numpy.ndarray<*:!numpy.any_dtype>
  return %0 : !numpy.ndarray<*:!numpy.any_dtype>
}

// CHECK-LABEL: func @flatten_rank0(
// CHECK: "aten.flatten"{{.*}}-> !numpy.ndarray<[?]:f32>
func @flatten_rank0(%arg0: !numpy.ndarray<[]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
  %end = constant -1 : i64
  %start = constant 0 : i64
  %0 = "aten.flatten"(%arg0, %start, %end) : (!numpy.ndarray<[]:f32>, i64, i64) -> !numpy.ndarray<*:!numpy.any_dtype>
  return %0 : !numpy.ndarray<*:!numpy.any_dtype>
}

// -----

// CHECK-LABEL: func @f
func @f(%arg0: tensor<4x6x3xf32>, %arg1: tensor<1x1x3xf32>, %arg2: tensor<?x3xf32>) {
  %c1_i64 = constant 1 : i64