mirror of https://github.com/llvm/torch-mlir
Add aten::len.t, aten::size, and aten::gt.int primitive ops
Also add some canonicalizations that finally reduce ResNet down to a single block.pull/213/head
parent
ec6d06aa86
commit
122cae2ee3
|
@ -97,6 +97,11 @@ def generate_ops(g: "OpGenerator"):
|
||||||
has_folder=True)
|
has_folder=True)
|
||||||
g.ordinary_primitive_op("aten::dim(Tensor)", "DimOp", "dim")
|
g.ordinary_primitive_op("aten::dim(Tensor)", "DimOp", "dim")
|
||||||
g.ordinary_primitive_op("aten::ne(int,int)", "NeIntOp", "ne.int")
|
g.ordinary_primitive_op("aten::ne(int,int)", "NeIntOp", "ne.int")
|
||||||
|
g.ordinary_primitive_op("aten::size(Tensor)", "SizeOp", "size",
|
||||||
|
has_canonicalizer=True)
|
||||||
|
g.ordinary_primitive_op("aten::len(t[])", "LenTOp", "len.t",
|
||||||
|
has_canonicalizer=True)
|
||||||
|
g.ordinary_primitive_op("aten::gt(int,int)", "GtIntOp", "gt.int")
|
||||||
|
|
||||||
# Convolution ops. Note that these are special in PyTorch and the importer,
|
# Convolution ops. Note that these are special in PyTorch and the importer,
|
||||||
# and we model them after the signatures of the convolution_overrideable
|
# and we model them after the signatures of the convolution_overrideable
|
||||||
|
@ -314,6 +319,7 @@ class OpGenerator:
|
||||||
"int[]": "AnyTorchIntListType",
|
"int[]": "AnyTorchIntListType",
|
||||||
"bool": "AnyTorchBoolType",
|
"bool": "AnyTorchBoolType",
|
||||||
"bool[]": "AnyTorchBoolListType",
|
"bool[]": "AnyTorchBoolListType",
|
||||||
|
"t[]": "Basicpy_ListType",
|
||||||
"t1": "AnyTorchType",
|
"t1": "AnyTorchType",
|
||||||
"t2": "AnyTorchType",
|
"t2": "AnyTorchType",
|
||||||
},
|
},
|
||||||
|
@ -505,7 +511,8 @@ class InflightOpDef:
|
||||||
override_return_types: Sequence[str] = None,
|
override_return_types: Sequence[str] = None,
|
||||||
drop_arg_indices: Sequence[int] = (),
|
drop_arg_indices: Sequence[int] = (),
|
||||||
drop_return_indices: Sequence[int] = (),
|
drop_return_indices: Sequence[int] = (),
|
||||||
has_folder: bool = False):
|
has_folder: bool = False,
|
||||||
|
has_canonicalizer: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.g = g
|
self.g = g
|
||||||
self.kernel_sig = kernel_sig
|
self.kernel_sig = kernel_sig
|
||||||
|
@ -520,6 +527,7 @@ class InflightOpDef:
|
||||||
self.drop_arg_indices = drop_arg_indices
|
self.drop_arg_indices = drop_arg_indices
|
||||||
self.drop_return_indices = drop_return_indices
|
self.drop_return_indices = drop_return_indices
|
||||||
self.has_folder = has_folder
|
self.has_folder = has_folder
|
||||||
|
self.has_canonicalizer = has_canonicalizer
|
||||||
self.reg_record = g.get_reg_record(self.kernel_sig)
|
self.reg_record = g.get_reg_record(self.kernel_sig)
|
||||||
self._emitted = False
|
self._emitted = False
|
||||||
self._traceback = traceback.extract_stack()[0:-2]
|
self._traceback = traceback.extract_stack()[0:-2]
|
||||||
|
@ -603,7 +611,8 @@ class InflightOpDef:
|
||||||
ods_ins=self.ods_ins,
|
ods_ins=self.ods_ins,
|
||||||
ods_outs=self.ods_outs,
|
ods_outs=self.ods_outs,
|
||||||
traits=self.traits,
|
traits=self.traits,
|
||||||
has_folder=self.has_folder)
|
has_folder=self.has_folder,
|
||||||
|
has_canonicalizer=self.has_canonicalizer)
|
||||||
self.g.impl_emitter.emit_kernel_methods(
|
self.g.impl_emitter.emit_kernel_methods(
|
||||||
self.ods_name,
|
self.ods_name,
|
||||||
self.reg_record,
|
self.reg_record,
|
||||||
|
@ -664,7 +673,8 @@ class OdsEmitter(EmitterBase):
|
||||||
ods_outs: List[Tuple[str, str]],
|
ods_outs: List[Tuple[str, str]],
|
||||||
traits: Sequence[str] = (),
|
traits: Sequence[str] = (),
|
||||||
summary: Optional[str] = None,
|
summary: Optional[str] = None,
|
||||||
has_folder: bool = False):
|
has_folder: bool = False,
|
||||||
|
has_canonicalizer: bool = False):
|
||||||
# Def first-line.
|
# Def first-line.
|
||||||
full_traits = list(traits)
|
full_traits = list(traits)
|
||||||
full_traits.append(
|
full_traits.append(
|
||||||
|
@ -695,6 +705,8 @@ class OdsEmitter(EmitterBase):
|
||||||
|
|
||||||
if has_folder:
|
if has_folder:
|
||||||
self.print("let hasFolder = 1;")
|
self.print("let hasFolder = 1;")
|
||||||
|
if has_canonicalizer:
|
||||||
|
self.print("let hasCanonicalizer = 1;")
|
||||||
|
|
||||||
# Def last-line.
|
# Def last-line.
|
||||||
self.print("}\n")
|
self.print("}\n")
|
||||||
|
|
|
@ -961,6 +961,60 @@ const Torch::BuildKernelMetadata &NeIntOp::getTorchBuildKernelMetadata() {
|
||||||
return metadata;
|
return metadata;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Torch::KernelMetadata SizeOp::getTorchKernelMetadata() {
|
||||||
|
return getTorchBuildKernelMetadata();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Torch::BuildKernelMetadata &SizeOp::getTorchBuildKernelMetadata() {
|
||||||
|
using KVC = Torch::KernelValueConversion::BitMask;
|
||||||
|
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||||
|
Torch::BuildKernelMetadata m;
|
||||||
|
m.kernelName = "aten::size";
|
||||||
|
m.addArgTypes({"Tensor"});
|
||||||
|
m.addArgConversions({KVC::kImmutableTensor});
|
||||||
|
m.addReturnTypes({"int[]"});
|
||||||
|
m.addReturnConversions({KVC::kNone});
|
||||||
|
return m;
|
||||||
|
})();
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
Torch::KernelMetadata LenTOp::getTorchKernelMetadata() {
|
||||||
|
return getTorchBuildKernelMetadata();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Torch::BuildKernelMetadata &LenTOp::getTorchBuildKernelMetadata() {
|
||||||
|
using KVC = Torch::KernelValueConversion::BitMask;
|
||||||
|
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||||
|
Torch::BuildKernelMetadata m;
|
||||||
|
m.kernelName = "aten::len";
|
||||||
|
m.addArgTypes({"t[]"});
|
||||||
|
m.addArgConversions({KVC::kNone});
|
||||||
|
m.addReturnTypes({"int"});
|
||||||
|
m.addReturnConversions({KVC::kNone});
|
||||||
|
return m;
|
||||||
|
})();
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
Torch::KernelMetadata GtIntOp::getTorchKernelMetadata() {
|
||||||
|
return getTorchBuildKernelMetadata();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Torch::BuildKernelMetadata &GtIntOp::getTorchBuildKernelMetadata() {
|
||||||
|
using KVC = Torch::KernelValueConversion::BitMask;
|
||||||
|
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||||
|
Torch::BuildKernelMetadata m;
|
||||||
|
m.kernelName = "aten::gt";
|
||||||
|
m.addArgTypes({"int", "int"});
|
||||||
|
m.addArgConversions({KVC::kNone, KVC::kNone});
|
||||||
|
m.addReturnTypes({"bool"});
|
||||||
|
m.addReturnConversions({KVC::kNone});
|
||||||
|
return m;
|
||||||
|
})();
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// NN ops
|
// NN ops
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
|
|
@ -509,6 +509,39 @@ def aten_NeIntOp: aten_Op<"ne.int", [NoSideEffect, DeclareOpInterfaceMethods<Tor
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def aten_SizeOp: aten_Op<"size", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
|
||||||
|
let summary = "Recognized op for kernel aten::size";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchImmutableTensor:$self
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchIntListType
|
||||||
|
);
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def aten_LenTOp: aten_Op<"len.t", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
|
||||||
|
let summary = "Recognized op for kernel aten::len";
|
||||||
|
let arguments = (ins
|
||||||
|
Basicpy_ListType:$a
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchIntType
|
||||||
|
);
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def aten_GtIntOp: aten_Op<"gt.int", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
|
||||||
|
let summary = "Recognized op for kernel aten::gt";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchIntType:$a,
|
||||||
|
AnyTorchIntType:$b
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchBoolType
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// NN ops
|
// NN ops
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
|
|
@ -170,22 +170,6 @@ def aten_ReluUnderOp: aten_Op<"relu_", [NoSideEffect, StatisticsOpInterface]>,
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def aten_SizeOp: aten_Op<"size", [NoSideEffect, StatisticsOpInterface]>,
|
|
||||||
Results<(outs AnyTensor)> {
|
|
||||||
let arguments = (
|
|
||||||
ins AnyTensor:$self,
|
|
||||||
AnyScalar:$dim
|
|
||||||
);
|
|
||||||
let summary = "aten size operator";
|
|
||||||
let description = [{
|
|
||||||
SizeOp
|
|
||||||
aten size operator
|
|
||||||
}];
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
std::map<std::string, uint64_t> getStatistics();
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def aten_SqueezeOp: aten_Op<"squeeze", [NoSideEffect, StatisticsOpInterface]>,
|
def aten_SqueezeOp: aten_Op<"squeeze", [NoSideEffect, StatisticsOpInterface]>,
|
||||||
Results<(outs AnyTensor)> {
|
Results<(outs AnyTensor)> {
|
||||||
let arguments = (
|
let arguments = (
|
||||||
|
|
|
@ -40,6 +40,13 @@ LogicalResult convertNeIntOp(aten::NeIntOp op, PatternRewriter &rewriter) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult convertGtIntOp(aten::GtIntOp op, PatternRewriter &rewriter) {
|
||||||
|
auto i1 = rewriter.create<CmpIOp>(op->getLoc(), CmpIPredicate::sgt,
|
||||||
|
op->getOperand(0), op->getOperand(1));
|
||||||
|
rewriter.replaceOpWithNewOp<Basicpy::BoolCastOp>(op, op.getType(), i1);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// The pass
|
// The pass
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
@ -60,6 +67,7 @@ public:
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
patterns.add(convertDimOp);
|
patterns.add(convertDimOp);
|
||||||
patterns.add(convertNeIntOp);
|
patterns.add(convertNeIntOp);
|
||||||
|
patterns.add(convertGtIntOp);
|
||||||
return std::move(patterns);
|
return std::move(patterns);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -117,6 +117,7 @@ void ATenDialect::initialize() {
|
||||||
#include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
|
#include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
|
||||||
>();
|
>();
|
||||||
getContext()->getOrLoadDialect("torch");
|
getContext()->getOrLoadDialect("torch");
|
||||||
|
getContext()->getOrLoadDialect("std");
|
||||||
}
|
}
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/IR/ATenOpInterfaces.cpp.inc"
|
#include "npcomp/Dialect/ATen/IR/ATenOpInterfaces.cpp.inc"
|
||||||
|
|
|
@ -522,14 +522,6 @@ std::map<std::string, uint64_t> SumOp::getStatistics() {
|
||||||
return toReturn;
|
return toReturn;
|
||||||
}
|
}
|
||||||
|
|
||||||
// size op can be zero overhead
|
|
||||||
std::map<std::string, uint64_t> SizeOp::getStatistics() {
|
|
||||||
std::map<std::string, uint64_t> toReturn;
|
|
||||||
toReturn["reads"] = toReturn["operand:0:activation_in"] = 0;
|
|
||||||
toReturn["writes"] = toReturn["result:0:activation_out"] = 0;
|
|
||||||
return toReturn;
|
|
||||||
}
|
|
||||||
|
|
||||||
// squeeze can be zero overhead
|
// squeeze can be zero overhead
|
||||||
std::map<std::string, uint64_t> SqueezeOp::getStatistics() {
|
std::map<std::string, uint64_t> SqueezeOp::getStatistics() {
|
||||||
std::map<std::string, uint64_t> toReturn;
|
std::map<std::string, uint64_t> toReturn;
|
||||||
|
|
|
@ -8,8 +8,11 @@
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/DialectImplementation.h"
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||||
|
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||||
|
|
||||||
|
@ -34,6 +37,43 @@ OpFoldResult IsOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// LenTOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void LenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
|
MLIRContext *context) {
|
||||||
|
patterns.add(+[](LenTOp op, PatternRewriter &rewriter) {
|
||||||
|
auto buildList = op.getOperand().getDefiningOp<Basicpy::BuildListOp>();
|
||||||
|
if (!buildList)
|
||||||
|
return rewriter.notifyMatchFailure(op, "operand not basicpy.build_list");
|
||||||
|
rewriter.replaceOpWithNewOp<::mlir::ConstantOp>(
|
||||||
|
op, rewriter.getI64IntegerAttr(buildList.getNumOperands()));
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// SizeOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void SizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
|
MLIRContext *context) {
|
||||||
|
patterns.add(+[](SizeOp op, PatternRewriter &rewriter) {
|
||||||
|
auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!type)
|
||||||
|
return rewriter.notifyMatchFailure(op, "not a ranked tensor");
|
||||||
|
SmallVector<Value> listElements;
|
||||||
|
for (int64_t size : type.getShape()) {
|
||||||
|
listElements.push_back(rewriter.create<::mlir::ConstantOp>(
|
||||||
|
op->getLoc(), rewriter.getI64IntegerAttr(size)));
|
||||||
|
}
|
||||||
|
rewriter.replaceOpWithNewOp<Basicpy::BuildListOp>(
|
||||||
|
op, Basicpy::ListType::get(rewriter.getContext()), listElements);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
|
#include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
|
||||||
|
|
||||||
|
|
|
@ -20,3 +20,14 @@ func @aten.ne.int(%arg0: i64, %arg1: i64) -> !basicpy.BoolType {
|
||||||
%0 = "aten.ne.int"(%arg0, %arg1) : (i64, i64) -> !basicpy.BoolType
|
%0 = "aten.ne.int"(%arg0, %arg1) : (i64, i64) -> !basicpy.BoolType
|
||||||
return %0 : !basicpy.BoolType
|
return %0 : !basicpy.BoolType
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @aten.gt.int(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: i64,
|
||||||
|
// CHECK-SAME: %[[ARG1:.*]]: i64) -> !basicpy.BoolType {
|
||||||
|
// CHECK: %[[I1:.*]] = cmpi sgt, %[[ARG0]], %[[ARG1]] : i64
|
||||||
|
// CHECK: %[[BASICPY_BOOL:.*]] = basicpy.bool_cast %[[I1]] : i1 -> !basicpy.BoolType
|
||||||
|
// CHECK: return %[[BASICPY_BOOL]] : !basicpy.BoolType
|
||||||
|
func @aten.gt.int(%arg0: i64, %arg1: i64) -> !basicpy.BoolType {
|
||||||
|
%0 = "aten.gt.int"(%arg0, %arg1) : (i64, i64) -> !basicpy.BoolType
|
||||||
|
return %0 : !basicpy.BoolType
|
||||||
|
}
|
||||||
|
|
|
@ -7,3 +7,21 @@ func @aten.__is__(%arg0: !basicpy.ListType, %arg1: !basicpy.NoneType) -> !basicp
|
||||||
%0 = "aten.__is__"(%arg0, %arg1) : (!basicpy.ListType, !basicpy.NoneType) -> !basicpy.BoolType
|
%0 = "aten.__is__"(%arg0, %arg1) : (!basicpy.ListType, !basicpy.NoneType) -> !basicpy.BoolType
|
||||||
return %0 : !basicpy.BoolType
|
return %0 : !basicpy.BoolType
|
||||||
}
|
}
|
||||||
|
// CHECK-LABEL: func @aten.size(
|
||||||
|
// CHECK: %[[CM1:.*]] = constant -1 : i64
|
||||||
|
// CHECK: %[[C3:.*]] = constant 3 : i64
|
||||||
|
// CHECK: %[[RET:.*]] = basicpy.build_list %[[CM1]], %[[C3]] : (i64, i64) -> !basicpy.ListType
|
||||||
|
// CHECK: return %[[RET]] : !basicpy.ListType
|
||||||
|
func @aten.size(%arg0: tensor<?x3xf32>) -> !basicpy.ListType {
|
||||||
|
%0 = "aten.size"(%arg0) : (tensor<?x3xf32>) -> !basicpy.ListType
|
||||||
|
return %0 : !basicpy.ListType
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @aten.len.t(
|
||||||
|
// CHECK: %[[LENGTH:.*]] = constant 2 : i64
|
||||||
|
// CHECK: return %[[LENGTH]] : i64
|
||||||
|
func @aten.len.t(%arg0: i64) -> i64 {
|
||||||
|
%0 = basicpy.build_list %arg0, %arg0 : (i64, i64) -> !basicpy.ListType
|
||||||
|
%1 = "aten.len.t"(%0) : (!basicpy.ListType) -> i64
|
||||||
|
return %1 : i64
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue