Add aten::len.t, aten::size, and aten::gt.int primitive ops

Also add some canonicalizations that finally reduce ResNet down to a
single basic block.
Sean Silva 2021-04-29 15:38:11 -07:00
parent ec6d06aa86
commit 122cae2ee3
10 changed files with 180 additions and 27 deletions
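
For orientation, here is why these three small ops are enough to collapse ResNet's control flow. Shape guards in imported TorchScript typically query a tensor's shape, take the list's length, and compare it against a constant. A minimal sketch of that chain (illustrative IR only; the types and values are assumptions, not taken from this commit):

  // As imported: a rank check guarding a branch.
  %size = "aten.size"(%arg0) : (tensor<1x3x224x224xf32>) -> !basicpy.ListType
  %len = "aten.len.t"(%size) : (!basicpy.ListType) -> i64
  %pred = "aten.gt.int"(%len, %c2) : (i64, i64) -> !basicpy.BoolType

The new SizeOp canonicalization turns %size into a basicpy.build_list of constants, the LenTOp canonicalization then turns %len into a constant 4 : i64, and once aten.gt.int is lowered to a cmpi between two constants the predicate folds. With every such guard folded, the branches disappear and the function body becomes a single basic block.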


@@ -97,6 +97,11 @@ def generate_ops(g: "OpGenerator"):
                            has_folder=True)
   g.ordinary_primitive_op("aten::dim(Tensor)", "DimOp", "dim")
   g.ordinary_primitive_op("aten::ne(int,int)", "NeIntOp", "ne.int")
+  g.ordinary_primitive_op("aten::size(Tensor)", "SizeOp", "size",
+                          has_canonicalizer=True)
+  g.ordinary_primitive_op("aten::len(t[])", "LenTOp", "len.t",
+                          has_canonicalizer=True)
+  g.ordinary_primitive_op("aten::gt(int,int)", "GtIntOp", "gt.int")
 
   # Convolution ops. Note that these are special in PyTorch and the importer,
   # and we model them after the signatures of the convolution_overrideable
@@ -314,6 +319,7 @@ class OpGenerator:
       "int[]": "AnyTorchIntListType",
       "bool": "AnyTorchBoolType",
       "bool[]": "AnyTorchBoolListType",
+      "t[]": "Basicpy_ListType",
       "t1": "AnyTorchType",
       "t2": "AnyTorchType",
   },
@@ -505,7 +511,8 @@ class InflightOpDef:
                override_return_types: Sequence[str] = None,
                drop_arg_indices: Sequence[int] = (),
                drop_return_indices: Sequence[int] = (),
-               has_folder: bool = False):
+               has_folder: bool = False,
+               has_canonicalizer: bool = False):
     super().__init__()
     self.g = g
     self.kernel_sig = kernel_sig
@@ -520,6 +527,7 @@ class InflightOpDef:
     self.drop_arg_indices = drop_arg_indices
     self.drop_return_indices = drop_return_indices
     self.has_folder = has_folder
+    self.has_canonicalizer = has_canonicalizer
     self.reg_record = g.get_reg_record(self.kernel_sig)
     self._emitted = False
     self._traceback = traceback.extract_stack()[0:-2]
@@ -603,7 +611,8 @@ class InflightOpDef:
         ods_ins=self.ods_ins,
         ods_outs=self.ods_outs,
         traits=self.traits,
-        has_folder=self.has_folder)
+        has_folder=self.has_folder,
+        has_canonicalizer=self.has_canonicalizer)
     self.g.impl_emitter.emit_kernel_methods(
         self.ods_name,
         self.reg_record,
@@ -664,7 +673,8 @@ class OdsEmitter(EmitterBase):
                  ods_outs: List[Tuple[str, str]],
                  traits: Sequence[str] = (),
                  summary: Optional[str] = None,
-                 has_folder: bool = False):
+                 has_folder: bool = False,
+                 has_canonicalizer: bool = False):
     # Def first-line.
     full_traits = list(traits)
     full_traits.append(
@@ -695,6 +705,8 @@ class OdsEmitter(EmitterBase):
 
     if has_folder:
       self.print("let hasFolder = 1;")
+    if has_canonicalizer:
+      self.print("let hasCanonicalizer = 1;")
 
     # Def last-line.
     self.print("}\n")


@@ -961,6 +961,60 @@ const Torch::BuildKernelMetadata &NeIntOp::getTorchBuildKernelMetadata() {
   return metadata;
 }
 
+Torch::KernelMetadata SizeOp::getTorchKernelMetadata() {
+  return getTorchBuildKernelMetadata();
+}
+
+const Torch::BuildKernelMetadata &SizeOp::getTorchBuildKernelMetadata() {
+  using KVC = Torch::KernelValueConversion::BitMask;
+  static Torch::BuildKernelMetadata metadata = ([]() {
+    Torch::BuildKernelMetadata m;
+    m.kernelName = "aten::size";
+    m.addArgTypes({"Tensor"});
+    m.addArgConversions({KVC::kImmutableTensor});
+    m.addReturnTypes({"int[]"});
+    m.addReturnConversions({KVC::kNone});
+    return m;
+  })();
+  return metadata;
+}
+
+Torch::KernelMetadata LenTOp::getTorchKernelMetadata() {
+  return getTorchBuildKernelMetadata();
+}
+
+const Torch::BuildKernelMetadata &LenTOp::getTorchBuildKernelMetadata() {
+  using KVC = Torch::KernelValueConversion::BitMask;
+  static Torch::BuildKernelMetadata metadata = ([]() {
+    Torch::BuildKernelMetadata m;
+    m.kernelName = "aten::len";
+    m.addArgTypes({"t[]"});
+    m.addArgConversions({KVC::kNone});
+    m.addReturnTypes({"int"});
+    m.addReturnConversions({KVC::kNone});
+    return m;
+  })();
+  return metadata;
+}
+
+Torch::KernelMetadata GtIntOp::getTorchKernelMetadata() {
+  return getTorchBuildKernelMetadata();
+}
+
+const Torch::BuildKernelMetadata &GtIntOp::getTorchBuildKernelMetadata() {
+  using KVC = Torch::KernelValueConversion::BitMask;
+  static Torch::BuildKernelMetadata metadata = ([]() {
+    Torch::BuildKernelMetadata m;
+    m.kernelName = "aten::gt";
+    m.addArgTypes({"int", "int"});
+    m.addArgConversions({KVC::kNone, KVC::kNone});
+    m.addReturnTypes({"bool"});
+    m.addReturnConversions({KVC::kNone});
+    return m;
+  })();
+  return metadata;
+}
+
 // -----------------------------------------------------------------------------
 // NN ops
 // -----------------------------------------------------------------------------


@@ -509,6 +509,39 @@ def aten_NeIntOp: aten_Op<"ne.int", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
   );
 }
 
+def aten_SizeOp: aten_Op<"size", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
+  let summary = "Recognized op for kernel aten::size";
+  let arguments = (ins
+    AnyTorchImmutableTensor:$self
+  );
+  let results = (outs
+    AnyTorchIntListType
+  );
+  let hasCanonicalizer = 1;
+}
+
+def aten_LenTOp: aten_Op<"len.t", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
+  let summary = "Recognized op for kernel aten::len";
+  let arguments = (ins
+    Basicpy_ListType:$a
+  );
+  let results = (outs
+    AnyTorchIntType
+  );
+  let hasCanonicalizer = 1;
+}
+
+def aten_GtIntOp: aten_Op<"gt.int", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
+  let summary = "Recognized op for kernel aten::gt";
+  let arguments = (ins
+    AnyTorchIntType:$a,
+    AnyTorchIntType:$b
+  );
+  let results = (outs
+    AnyTorchBoolType
+  );
+}
+
 // -----------------------------------------------------------------------------
 // NN ops
 // -----------------------------------------------------------------------------

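Note that aten_GtIntOp, unlike the other two, carries no hasCanonicalizer: constant comparisons are instead handled after the ATen-to-Std lowering below, since std's cmpi op already folds when both operands are constants. A hypothetical instance of that constant case (values assumed for illustration):

  %c4 = constant 4 : i64
  %c2 = constant 2 : i64
  // After convertGtIntOp rewrites aten.gt.int into this cmpi, the existing
  // cmpi folder reduces it to a constant true predicate.
  %pred = cmpi sgt, %c4, %c2 : i64
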

@@ -170,22 +170,6 @@ def aten_ReluUnderOp: aten_Op<"relu_", [NoSideEffect, StatisticsOpInterface]>,
   }];
 }
 
-def aten_SizeOp: aten_Op<"size", [NoSideEffect, StatisticsOpInterface]>,
-    Results<(outs AnyTensor)> {
-  let arguments = (
-    ins AnyTensor:$self,
-        AnyScalar:$dim
-  );
-  let summary = "aten size operator";
-  let description = [{
-    SizeOp
-    aten size operator
-  }];
-  let extraClassDeclaration = [{
-    std::map<std::string, uint64_t> getStatistics();
-  }];
-}
-
 def aten_SqueezeOp: aten_Op<"squeeze", [NoSideEffect, StatisticsOpInterface]>,
     Results<(outs AnyTensor)> {
   let arguments = (


@@ -40,6 +40,13 @@ LogicalResult convertNeIntOp(aten::NeIntOp op, PatternRewriter &rewriter) {
   return success();
 }
 
+LogicalResult convertGtIntOp(aten::GtIntOp op, PatternRewriter &rewriter) {
+  auto i1 = rewriter.create<CmpIOp>(op->getLoc(), CmpIPredicate::sgt,
+                                    op->getOperand(0), op->getOperand(1));
+  rewriter.replaceOpWithNewOp<Basicpy::BoolCastOp>(op, op.getType(), i1);
+  return success();
+}
+
 // -----------------------------------------------------------------------------
 // The pass
 // -----------------------------------------------------------------------------
@@ -60,6 +67,7 @@ public:
     RewritePatternSet patterns(context);
    patterns.add(convertDimOp);
     patterns.add(convertNeIntOp);
+    patterns.add(convertGtIntOp);
     return std::move(patterns);
   }
 };


@@ -117,6 +117,7 @@ void ATenDialect::initialize() {
 #include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
       >();
   getContext()->getOrLoadDialect("torch");
+  getContext()->getOrLoadDialect("std");
 }
 
 #include "npcomp/Dialect/ATen/IR/ATenOpInterfaces.cpp.inc"


@@ -522,14 +522,6 @@ std::map<std::string, uint64_t> SumOp::getStatistics() {
   return toReturn;
 }
 
-// size op can be zero overhead
-std::map<std::string, uint64_t> SizeOp::getStatistics() {
-  std::map<std::string, uint64_t> toReturn;
-  toReturn["reads"] = toReturn["operand:0:activation_in"] = 0;
-  toReturn["writes"] = toReturn["result:0:activation_out"] = 0;
-  return toReturn;
-}
-
 // squeeze can be zero overhead
 std::map<std::string, uint64_t> SqueezeOp::getStatistics() {
   std::map<std::string, uint64_t> toReturn;


@@ -8,8 +8,11 @@
 #include "npcomp/Dialect/ATen/IR/ATenDialect.h"
 
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
+#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
 #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
 #include "npcomp/Dialect/Torch/IR/TorchTypes.h"
 
@@ -34,6 +37,43 @@ OpFoldResult IsOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// LenTOp
+//===----------------------------------------------------------------------===//
+
+void LenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                         MLIRContext *context) {
+  patterns.add(+[](LenTOp op, PatternRewriter &rewriter) {
+    auto buildList = op.getOperand().getDefiningOp<Basicpy::BuildListOp>();
+    if (!buildList)
+      return rewriter.notifyMatchFailure(op, "operand not basicpy.build_list");
+    rewriter.replaceOpWithNewOp<::mlir::ConstantOp>(
+        op, rewriter.getI64IntegerAttr(buildList.getNumOperands()));
+    return success();
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// SizeOp
+//===----------------------------------------------------------------------===//
+
+void SizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                         MLIRContext *context) {
+  patterns.add(+[](SizeOp op, PatternRewriter &rewriter) {
+    auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
+    if (!type)
+      return rewriter.notifyMatchFailure(op, "not a ranked tensor");
+    SmallVector<Value> listElements;
+    for (int64_t size : type.getShape()) {
+      listElements.push_back(rewriter.create<::mlir::ConstantOp>(
+          op->getLoc(), rewriter.getI64IntegerAttr(size)));
+    }
+    rewriter.replaceOpWithNewOp<Basicpy::BuildListOp>(
+        op, Basicpy::ListType::get(rewriter.getContext()), listElements);
+    return success();
+  });
+}
+
 #define GET_OP_CLASSES
 #include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"

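Two details in these patterns are worth calling out. First, the SizeOp pattern only matches ranked tensors, and it emits dynamic extents as constant -1 (the test below exercises this with tensor<?x3xf32>). Second, both patterns create std constant ops, which is why ATenDialect::initialize above now loads the "std" dialect. A hypothetical input the SizeOp pattern declines to rewrite:

  // Unranked operand: dyn_cast<RankedTensorType> fails, the pattern reports
  // "not a ranked tensor", and the op is left in place for other passes.
  %0 = "aten.size"(%arg0) : (tensor<*xf32>) -> !basicpy.ListType
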

@@ -20,3 +20,14 @@ func @aten.ne.int(%arg0: i64, %arg1: i64) -> !basicpy.BoolType {
   %0 = "aten.ne.int"(%arg0, %arg1) : (i64, i64) -> !basicpy.BoolType
   return %0 : !basicpy.BoolType
 }
+
+// CHECK-LABEL: func @aten.gt.int(
+// CHECK-SAME:    %[[ARG0:.*]]: i64,
+// CHECK-SAME:    %[[ARG1:.*]]: i64) -> !basicpy.BoolType {
+// CHECK:         %[[I1:.*]] = cmpi sgt, %[[ARG0]], %[[ARG1]] : i64
+// CHECK:         %[[BASICPY_BOOL:.*]] = basicpy.bool_cast %[[I1]] : i1 -> !basicpy.BoolType
+// CHECK:         return %[[BASICPY_BOOL]] : !basicpy.BoolType
+func @aten.gt.int(%arg0: i64, %arg1: i64) -> !basicpy.BoolType {
+  %0 = "aten.gt.int"(%arg0, %arg1) : (i64, i64) -> !basicpy.BoolType
+  return %0 : !basicpy.BoolType
+}


@@ -7,3 +7,21 @@ func @aten.__is__(%arg0: !basicpy.ListType, %arg1: !basicpy.NoneType) -> !basicpy.BoolType {
   %0 = "aten.__is__"(%arg0, %arg1) : (!basicpy.ListType, !basicpy.NoneType) -> !basicpy.BoolType
   return %0 : !basicpy.BoolType
 }
+
+// CHECK-LABEL: func @aten.size(
+// CHECK:         %[[CM1:.*]] = constant -1 : i64
+// CHECK:         %[[C3:.*]] = constant 3 : i64
+// CHECK:         %[[RET:.*]] = basicpy.build_list %[[CM1]], %[[C3]] : (i64, i64) -> !basicpy.ListType
+// CHECK:         return %[[RET]] : !basicpy.ListType
+func @aten.size(%arg0: tensor<?x3xf32>) -> !basicpy.ListType {
+  %0 = "aten.size"(%arg0) : (tensor<?x3xf32>) -> !basicpy.ListType
+  return %0 : !basicpy.ListType
+}
+
+// CHECK-LABEL: func @aten.len.t(
+// CHECK:         %[[LENGTH:.*]] = constant 2 : i64
+// CHECK:         return %[[LENGTH]] : i64
+func @aten.len.t(%arg0: i64) -> i64 {
+  %0 = basicpy.build_list %arg0, %arg0 : (i64, i64) -> !basicpy.ListType
+  %1 = "aten.len.t"(%0) : (!basicpy.ListType) -> i64
+  return %1 : i64
+}
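
For contrast, a case the new patterns deliberately leave alone (a sketch only, not a test from this commit): when the operand of aten.len.t is a block argument rather than a visible basicpy.build_list, the pattern has no operand count to read, so the op must remain as-is.

  // Hypothetical: no defining basicpy.build_list, so no canonicalization fires.
  func @aten.len.t$opaque_list(%arg0: !basicpy.ListType) -> i64 {
    %0 = "aten.len.t"(%arg0) : (!basicpy.ListType) -> i64
    return %0 : i64
  }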