mirror of https://github.com/llvm/torch-mlir
Add aten::len.t, aten::size, and aten::gt.int primitive ops
Also add some canonicalizations that finally reduce ResNet down to a single block.pull/213/head
parent
ec6d06aa86
commit
122cae2ee3
|
@ -97,6 +97,11 @@ def generate_ops(g: "OpGenerator"):
|
|||
has_folder=True)
|
||||
g.ordinary_primitive_op("aten::dim(Tensor)", "DimOp", "dim")
|
||||
g.ordinary_primitive_op("aten::ne(int,int)", "NeIntOp", "ne.int")
|
||||
g.ordinary_primitive_op("aten::size(Tensor)", "SizeOp", "size",
|
||||
has_canonicalizer=True)
|
||||
g.ordinary_primitive_op("aten::len(t[])", "LenTOp", "len.t",
|
||||
has_canonicalizer=True)
|
||||
g.ordinary_primitive_op("aten::gt(int,int)", "GtIntOp", "gt.int")
|
||||
|
||||
# Convolution ops. Note that these are special in PyTorch and the importer,
|
||||
# and we model them after the signatures of the convolution_overrideable
|
||||
|
@ -314,6 +319,7 @@ class OpGenerator:
|
|||
"int[]": "AnyTorchIntListType",
|
||||
"bool": "AnyTorchBoolType",
|
||||
"bool[]": "AnyTorchBoolListType",
|
||||
"t[]": "Basicpy_ListType",
|
||||
"t1": "AnyTorchType",
|
||||
"t2": "AnyTorchType",
|
||||
},
|
||||
|
@ -505,7 +511,8 @@ class InflightOpDef:
|
|||
override_return_types: Sequence[str] = None,
|
||||
drop_arg_indices: Sequence[int] = (),
|
||||
drop_return_indices: Sequence[int] = (),
|
||||
has_folder: bool = False):
|
||||
has_folder: bool = False,
|
||||
has_canonicalizer: bool = False):
|
||||
super().__init__()
|
||||
self.g = g
|
||||
self.kernel_sig = kernel_sig
|
||||
|
@ -520,6 +527,7 @@ class InflightOpDef:
|
|||
self.drop_arg_indices = drop_arg_indices
|
||||
self.drop_return_indices = drop_return_indices
|
||||
self.has_folder = has_folder
|
||||
self.has_canonicalizer = has_canonicalizer
|
||||
self.reg_record = g.get_reg_record(self.kernel_sig)
|
||||
self._emitted = False
|
||||
self._traceback = traceback.extract_stack()[0:-2]
|
||||
|
@ -603,7 +611,8 @@ class InflightOpDef:
|
|||
ods_ins=self.ods_ins,
|
||||
ods_outs=self.ods_outs,
|
||||
traits=self.traits,
|
||||
has_folder=self.has_folder)
|
||||
has_folder=self.has_folder,
|
||||
has_canonicalizer=self.has_canonicalizer)
|
||||
self.g.impl_emitter.emit_kernel_methods(
|
||||
self.ods_name,
|
||||
self.reg_record,
|
||||
|
@ -664,7 +673,8 @@ class OdsEmitter(EmitterBase):
|
|||
ods_outs: List[Tuple[str, str]],
|
||||
traits: Sequence[str] = (),
|
||||
summary: Optional[str] = None,
|
||||
has_folder: bool = False):
|
||||
has_folder: bool = False,
|
||||
has_canonicalizer: bool = False):
|
||||
# Def first-line.
|
||||
full_traits = list(traits)
|
||||
full_traits.append(
|
||||
|
@ -695,6 +705,8 @@ class OdsEmitter(EmitterBase):
|
|||
|
||||
if has_folder:
|
||||
self.print("let hasFolder = 1;")
|
||||
if has_canonicalizer:
|
||||
self.print("let hasCanonicalizer = 1;")
|
||||
|
||||
# Def last-line.
|
||||
self.print("}\n")
|
||||
|
|
|
@ -961,6 +961,60 @@ const Torch::BuildKernelMetadata &NeIntOp::getTorchBuildKernelMetadata() {
|
|||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata SizeOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &SizeOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::size";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"int[]"});
|
||||
m.addReturnConversions({KVC::kNone});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata LenTOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &LenTOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::len";
|
||||
m.addArgTypes({"t[]"});
|
||||
m.addArgConversions({KVC::kNone});
|
||||
m.addReturnTypes({"int"});
|
||||
m.addReturnConversions({KVC::kNone});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata GtIntOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &GtIntOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::gt";
|
||||
m.addArgTypes({"int", "int"});
|
||||
m.addArgConversions({KVC::kNone, KVC::kNone});
|
||||
m.addReturnTypes({"bool"});
|
||||
m.addReturnConversions({KVC::kNone});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// NN ops
|
||||
// -----------------------------------------------------------------------------
|
||||
|
|
|
@ -509,6 +509,39 @@ def aten_NeIntOp: aten_Op<"ne.int", [NoSideEffect, DeclareOpInterfaceMethods<Tor
|
|||
);
|
||||
}
|
||||
|
||||
def aten_SizeOp: aten_Op<"size", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
|
||||
let summary = "Recognized op for kernel aten::size";
|
||||
let arguments = (ins
|
||||
AnyTorchImmutableTensor:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchIntListType
|
||||
);
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def aten_LenTOp: aten_Op<"len.t", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
|
||||
let summary = "Recognized op for kernel aten::len";
|
||||
let arguments = (ins
|
||||
Basicpy_ListType:$a
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchIntType
|
||||
);
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def aten_GtIntOp: aten_Op<"gt.int", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
|
||||
let summary = "Recognized op for kernel aten::gt";
|
||||
let arguments = (ins
|
||||
AnyTorchIntType:$a,
|
||||
AnyTorchIntType:$b
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchBoolType
|
||||
);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// NN ops
|
||||
// -----------------------------------------------------------------------------
|
||||
|
|
|
@ -170,22 +170,6 @@ def aten_ReluUnderOp: aten_Op<"relu_", [NoSideEffect, StatisticsOpInterface]>,
|
|||
}];
|
||||
}
|
||||
|
||||
def aten_SizeOp: aten_Op<"size", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyScalar:$dim
|
||||
);
|
||||
let summary = "aten size operator";
|
||||
let description = [{
|
||||
SizeOp
|
||||
aten size operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_SqueezeOp: aten_Op<"squeeze", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
|
|
|
@ -40,6 +40,13 @@ LogicalResult convertNeIntOp(aten::NeIntOp op, PatternRewriter &rewriter) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult convertGtIntOp(aten::GtIntOp op, PatternRewriter &rewriter) {
|
||||
auto i1 = rewriter.create<CmpIOp>(op->getLoc(), CmpIPredicate::sgt,
|
||||
op->getOperand(0), op->getOperand(1));
|
||||
rewriter.replaceOpWithNewOp<Basicpy::BoolCastOp>(op, op.getType(), i1);
|
||||
return success();
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// The pass
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -60,6 +67,7 @@ public:
|
|||
RewritePatternSet patterns(context);
|
||||
patterns.add(convertDimOp);
|
||||
patterns.add(convertNeIntOp);
|
||||
patterns.add(convertGtIntOp);
|
||||
return std::move(patterns);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -117,6 +117,7 @@ void ATenDialect::initialize() {
|
|||
#include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
|
||||
>();
|
||||
getContext()->getOrLoadDialect("torch");
|
||||
getContext()->getOrLoadDialect("std");
|
||||
}
|
||||
|
||||
#include "npcomp/Dialect/ATen/IR/ATenOpInterfaces.cpp.inc"
|
||||
|
|
|
@ -522,14 +522,6 @@ std::map<std::string, uint64_t> SumOp::getStatistics() {
|
|||
return toReturn;
|
||||
}
|
||||
|
||||
// size op can be zero overhead
|
||||
std::map<std::string, uint64_t> SizeOp::getStatistics() {
|
||||
std::map<std::string, uint64_t> toReturn;
|
||||
toReturn["reads"] = toReturn["operand:0:activation_in"] = 0;
|
||||
toReturn["writes"] = toReturn["result:0:activation_out"] = 0;
|
||||
return toReturn;
|
||||
}
|
||||
|
||||
// squeeze can be zero overhead
|
||||
std::map<std::string, uint64_t> SqueezeOp::getStatistics() {
|
||||
std::map<std::string, uint64_t> toReturn;
|
||||
|
|
|
@ -8,8 +8,11 @@
|
|||
|
||||
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||
|
||||
|
@ -34,6 +37,43 @@ OpFoldResult IsOp::fold(ArrayRef<Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LenTOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void LenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add(+[](LenTOp op, PatternRewriter &rewriter) {
|
||||
auto buildList = op.getOperand().getDefiningOp<Basicpy::BuildListOp>();
|
||||
if (!buildList)
|
||||
return rewriter.notifyMatchFailure(op, "operand not basicpy.build_list");
|
||||
rewriter.replaceOpWithNewOp<::mlir::ConstantOp>(
|
||||
op, rewriter.getI64IntegerAttr(buildList.getNumOperands()));
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SizeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void SizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add(+[](SizeOp op, PatternRewriter &rewriter) {
|
||||
auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
|
||||
if (!type)
|
||||
return rewriter.notifyMatchFailure(op, "not a ranked tensor");
|
||||
SmallVector<Value> listElements;
|
||||
for (int64_t size : type.getShape()) {
|
||||
listElements.push_back(rewriter.create<::mlir::ConstantOp>(
|
||||
op->getLoc(), rewriter.getI64IntegerAttr(size)));
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<Basicpy::BuildListOp>(
|
||||
op, Basicpy::ListType::get(rewriter.getContext()), listElements);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
|
||||
|
||||
|
|
|
@ -20,3 +20,14 @@ func @aten.ne.int(%arg0: i64, %arg1: i64) -> !basicpy.BoolType {
|
|||
%0 = "aten.ne.int"(%arg0, %arg1) : (i64, i64) -> !basicpy.BoolType
|
||||
return %0 : !basicpy.BoolType
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @aten.gt.int(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: i64,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: i64) -> !basicpy.BoolType {
|
||||
// CHECK: %[[I1:.*]] = cmpi sgt, %[[ARG0]], %[[ARG1]] : i64
|
||||
// CHECK: %[[BASICPY_BOOL:.*]] = basicpy.bool_cast %[[I1]] : i1 -> !basicpy.BoolType
|
||||
// CHECK: return %[[BASICPY_BOOL]] : !basicpy.BoolType
|
||||
func @aten.gt.int(%arg0: i64, %arg1: i64) -> !basicpy.BoolType {
|
||||
%0 = "aten.gt.int"(%arg0, %arg1) : (i64, i64) -> !basicpy.BoolType
|
||||
return %0 : !basicpy.BoolType
|
||||
}
|
||||
|
|
|
@ -7,3 +7,21 @@ func @aten.__is__(%arg0: !basicpy.ListType, %arg1: !basicpy.NoneType) -> !basicp
|
|||
%0 = "aten.__is__"(%arg0, %arg1) : (!basicpy.ListType, !basicpy.NoneType) -> !basicpy.BoolType
|
||||
return %0 : !basicpy.BoolType
|
||||
}
|
||||
// CHECK-LABEL: func @aten.size(
|
||||
// CHECK: %[[CM1:.*]] = constant -1 : i64
|
||||
// CHECK: %[[C3:.*]] = constant 3 : i64
|
||||
// CHECK: %[[RET:.*]] = basicpy.build_list %[[CM1]], %[[C3]] : (i64, i64) -> !basicpy.ListType
|
||||
// CHECK: return %[[RET]] : !basicpy.ListType
|
||||
func @aten.size(%arg0: tensor<?x3xf32>) -> !basicpy.ListType {
|
||||
%0 = "aten.size"(%arg0) : (tensor<?x3xf32>) -> !basicpy.ListType
|
||||
return %0 : !basicpy.ListType
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @aten.len.t(
|
||||
// CHECK: %[[LENGTH:.*]] = constant 2 : i64
|
||||
// CHECK: return %[[LENGTH]] : i64
|
||||
func @aten.len.t(%arg0: i64) -> i64 {
|
||||
%0 = basicpy.build_list %arg0, %arg0 : (i64, i64) -> !basicpy.ListType
|
||||
%1 = "aten.len.t"(%0) : (!basicpy.ListType) -> i64
|
||||
return %1 : i64
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue