diff --git a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py
index b9dd66b9c..c8a484644 100644
--- a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py
+++ b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py
@@ -97,6 +97,11 @@ def generate_ops(g: "OpGenerator"):
                           has_folder=True)
   g.ordinary_primitive_op("aten::dim(Tensor)", "DimOp", "dim")
   g.ordinary_primitive_op("aten::ne(int,int)", "NeIntOp", "ne.int")
+  g.ordinary_primitive_op("aten::size(Tensor)", "SizeOp", "size",
+                          has_canonicalizer=True)
+  g.ordinary_primitive_op("aten::len(t[])", "LenTOp", "len.t",
+                          has_canonicalizer=True)
+  g.ordinary_primitive_op("aten::gt(int,int)", "GtIntOp", "gt.int")
 
   # Convolution ops. Note that these are special in PyTorch and the importer,
   # and we model them after the signatures of the convolution_overrideable
@@ -314,6 +319,7 @@ class OpGenerator:
         "int[]": "AnyTorchIntListType",
         "bool": "AnyTorchBoolType",
         "bool[]": "AnyTorchBoolListType",
+        "t[]": "Basicpy_ListType",
         "t1": "AnyTorchType",
         "t2": "AnyTorchType",
     },
@@ -505,7 +511,8 @@ class InflightOpDef:
                override_return_types: Sequence[str] = None,
                drop_arg_indices: Sequence[int] = (),
                drop_return_indices: Sequence[int] = (),
-               has_folder: bool = False):
+               has_folder: bool = False,
+               has_canonicalizer: bool = False):
     super().__init__()
     self.g = g
     self.kernel_sig = kernel_sig
@@ -520,6 +527,7 @@ class InflightOpDef:
     self.drop_arg_indices = drop_arg_indices
     self.drop_return_indices = drop_return_indices
     self.has_folder = has_folder
+    self.has_canonicalizer = has_canonicalizer
     self.reg_record = g.get_reg_record(self.kernel_sig)
     self._emitted = False
     self._traceback = traceback.extract_stack()[0:-2]
@@ -603,7 +611,8 @@ class InflightOpDef:
         ods_ins=self.ods_ins,
         ods_outs=self.ods_outs,
         traits=self.traits,
-        has_folder=self.has_folder)
+        has_folder=self.has_folder,
+        has_canonicalizer=self.has_canonicalizer)
     self.g.impl_emitter.emit_kernel_methods(
         self.ods_name,
         self.reg_record,
@@ -664,7 +673,8 @@ class OdsEmitter(EmitterBase):
                ods_outs: List[Tuple[str, str]],
                traits: Sequence[str] = (),
                summary: Optional[str] = None,
-               has_folder: bool = False):
+               has_folder: bool = False,
+               has_canonicalizer: bool = False):
     # Def first-line.
     full_traits = list(traits)
     full_traits.append(
@@ -695,6 +705,8 @@ class OdsEmitter(EmitterBase):
 
     if has_folder:
       self.print("let hasFolder = 1;")
+    if has_canonicalizer:
+      self.print("let hasCanonicalizer = 1;")
 
     # Def last-line.
     self.print("}\n")
diff --git a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc
index b3cb7f668..4c15ae66b 100644
--- a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc
+++ b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc
@@ -961,6 +961,60 @@ const Torch::BuildKernelMetadata &NeIntOp::getTorchBuildKernelMetadata() {
   return metadata;
 }
 
+Torch::KernelMetadata SizeOp::getTorchKernelMetadata() {
+  return getTorchBuildKernelMetadata();
+}
+
+const Torch::BuildKernelMetadata &SizeOp::getTorchBuildKernelMetadata() {
+  using KVC = Torch::KernelValueConversion::BitMask;
+  static Torch::BuildKernelMetadata metadata = ([]() {
+    Torch::BuildKernelMetadata m;
+    m.kernelName = "aten::size";
+    m.addArgTypes({"Tensor"});
+    m.addArgConversions({KVC::kImmutableTensor});
+    m.addReturnTypes({"int[]"});
+    m.addReturnConversions({KVC::kNone});
+    return m;
+  })();
+  return metadata;
+}
+
+Torch::KernelMetadata LenTOp::getTorchKernelMetadata() {
+  return getTorchBuildKernelMetadata();
+}
+
+const Torch::BuildKernelMetadata &LenTOp::getTorchBuildKernelMetadata() {
+  using KVC = Torch::KernelValueConversion::BitMask;
+  static Torch::BuildKernelMetadata metadata = ([]() {
+    Torch::BuildKernelMetadata m;
+    m.kernelName = "aten::len";
+    m.addArgTypes({"t[]"});
+    m.addArgConversions({KVC::kNone});
+    m.addReturnTypes({"int"});
+    m.addReturnConversions({KVC::kNone});
+    return m;
+  })();
+  return metadata;
+}
+
+Torch::KernelMetadata GtIntOp::getTorchKernelMetadata() {
+  return getTorchBuildKernelMetadata();
+}
+
+const Torch::BuildKernelMetadata &GtIntOp::getTorchBuildKernelMetadata() {
+  using KVC = Torch::KernelValueConversion::BitMask;
+  static Torch::BuildKernelMetadata metadata = ([]() {
+    Torch::BuildKernelMetadata m;
+    m.kernelName = "aten::gt";
+    m.addArgTypes({"int", "int"});
+    m.addArgConversions({KVC::kNone, KVC::kNone});
+    m.addReturnTypes({"bool"});
+    m.addReturnConversions({KVC::kNone});
+    return m;
+  })();
+  return metadata;
+}
+
 // -----------------------------------------------------------------------------
 // NN ops
 // -----------------------------------------------------------------------------
diff --git a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td
index fd7efca2c..b930a1fa1 100644
--- a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td
+++ b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td
@@ -509,6 +509,39 @@ def aten_NeIntOp: aten_Op<"ne.int", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
   );
 }
 
+def aten_SizeOp: aten_Op<"size", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
+  let summary = "Recognized op for kernel aten::size";
+  let arguments = (ins
+    AnyTorchImmutableTensor:$self
+  );
+  let results = (outs
+    AnyTorchIntListType
+  );
+  let hasCanonicalizer = 1;
+}
+
+def aten_LenTOp: aten_Op<"len.t", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
+  let summary = "Recognized op for kernel aten::len";
+  let arguments = (ins
+    Basicpy_ListType:$a
+  );
+  let results = (outs
+    AnyTorchIntType
+  );
+  let hasCanonicalizer = 1;
+}
+
+def aten_GtIntOp: aten_Op<"gt.int", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
+  let summary = "Recognized op for kernel aten::gt";
+  let arguments = (ins
+    AnyTorchIntType:$a,
+    AnyTorchIntType:$b
+  );
+  let results = (outs
+    AnyTorchBoolType
+  );
+}
+
 // -----------------------------------------------------------------------------
 // NN ops
 // -----------------------------------------------------------------------------
diff --git a/include/npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td b/include/npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td
index 3bc548121..4ca4d000d 100644
--- a/include/npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td
+++ b/include/npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td
@@ -170,22 +170,6 @@ def aten_ReluUnderOp: aten_Op<"relu_", [NoSideEffect, StatisticsOpInterface]>,
   }];
 }
 
-def aten_SizeOp: aten_Op<"size", [NoSideEffect, StatisticsOpInterface]>,
-    Results<(outs AnyTensor)> {
-  let arguments = (
-    ins AnyTensor:$self,
-    AnyScalar:$dim
-  );
-  let summary = "aten size operator";
-  let description = [{
-    SizeOp
-    aten size operator
-  }];
-  let extraClassDeclaration = [{
-    std::map<std::string, uint64_t> getStatistics();
-  }];
-}
-
 def aten_SqueezeOp: aten_Op<"squeeze", [NoSideEffect, StatisticsOpInterface]>,
     Results<(outs AnyTensor)> {
   let arguments = (
diff --git a/lib/Conversion/ATenToStd/ATenToStd.cpp b/lib/Conversion/ATenToStd/ATenToStd.cpp
index 0c41ec6ac..4e82ebef4 100644
--- a/lib/Conversion/ATenToStd/ATenToStd.cpp
+++ b/lib/Conversion/ATenToStd/ATenToStd.cpp
@@ -40,6 +40,13 @@ LogicalResult convertNeIntOp(aten::NeIntOp op, PatternRewriter &rewriter) {
   return success();
 }
 
+LogicalResult convertGtIntOp(aten::GtIntOp op, PatternRewriter &rewriter) {
+  auto i1 = rewriter.create<CmpIOp>(op->getLoc(), CmpIPredicate::sgt,
+                                    op->getOperand(0), op->getOperand(1));
+  rewriter.replaceOpWithNewOp<Basicpy::BoolCastOp>(op, op.getType(), i1);
+  return success();
+}
+
 // -----------------------------------------------------------------------------
 // The pass
 // -----------------------------------------------------------------------------
@@ -60,6 +67,7 @@ public:
     RewritePatternSet patterns(context);
     patterns.add(convertDimOp);
    patterns.add(convertNeIntOp);
+    patterns.add(convertGtIntOp);
     return std::move(patterns);
   }
 };
diff --git a/lib/Dialect/ATen/IR/ATenDialect.cpp b/lib/Dialect/ATen/IR/ATenDialect.cpp
index b922d72d8..09c3f4a85 100644
--- a/lib/Dialect/ATen/IR/ATenDialect.cpp
+++ b/lib/Dialect/ATen/IR/ATenDialect.cpp
@@ -117,6 +117,7 @@ void ATenDialect::initialize() {
 #include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
       >();
   getContext()->getOrLoadDialect("torch");
+  getContext()->getOrLoadDialect("std");
 }
 
 #include "npcomp/Dialect/ATen/IR/ATenOpInterfaces.cpp.inc"
diff --git a/lib/Dialect/ATen/IR/ATenDialectOpStats.cpp b/lib/Dialect/ATen/IR/ATenDialectOpStats.cpp
index 42f83fbb3..f9a2a81d5 100644
--- a/lib/Dialect/ATen/IR/ATenDialectOpStats.cpp
+++ b/lib/Dialect/ATen/IR/ATenDialectOpStats.cpp
@@ -522,14 +522,6 @@ std::map<std::string, uint64_t> SumOp::getStatistics() {
   return toReturn;
 }
 
-// size op can be zero overhead
-std::map<std::string, uint64_t> SizeOp::getStatistics() {
-  std::map<std::string, uint64_t> toReturn;
-  toReturn["reads"] = toReturn["operand:0:activation_in"] = 0;
-  toReturn["writes"] = toReturn["result:0:activation_out"] = 0;
-  return toReturn;
-}
-
 // squeeze can be zero overhead
 std::map<std::string, uint64_t> SqueezeOp::getStatistics() {
   std::map<std::string, uint64_t> toReturn;
diff --git a/lib/Dialect/ATen/IR/ATenOps.cpp b/lib/Dialect/ATen/IR/ATenOps.cpp
index 59e793c99..88f210b8f 100644
--- a/lib/Dialect/ATen/IR/ATenOps.cpp
+++ b/lib/Dialect/ATen/IR/ATenOps.cpp
@@ -8,8 +8,11 @@
 
 #include "npcomp/Dialect/ATen/IR/ATenDialect.h"
 
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
+#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
 #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
 #include "npcomp/Dialect/Torch/IR/TorchTypes.h"
 
@@ -34,6 +37,43 @@ OpFoldResult IsOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// LenTOp
+//===----------------------------------------------------------------------===//
+
+void LenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                         MLIRContext *context) {
+  patterns.add(+[](LenTOp op, PatternRewriter &rewriter) {
+    auto buildList = op.getOperand().getDefiningOp<Basicpy::BuildListOp>();
+    if (!buildList)
+      return rewriter.notifyMatchFailure(op, "operand not basicpy.build_list");
+    rewriter.replaceOpWithNewOp<::mlir::ConstantOp>(
+        op, rewriter.getI64IntegerAttr(buildList.getNumOperands()));
+    return success();
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// SizeOp
+//===----------------------------------------------------------------------===//
+
+void SizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                         MLIRContext *context) {
+  patterns.add(+[](SizeOp op, PatternRewriter &rewriter) {
+    auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
+    if (!type)
+      return rewriter.notifyMatchFailure(op, "not a ranked tensor");
+    SmallVector<Value> listElements;
+    for (int64_t size : type.getShape()) {
+      listElements.push_back(rewriter.create<::mlir::ConstantOp>(
+          op->getLoc(), rewriter.getI64IntegerAttr(size)));
+    }
+    rewriter.replaceOpWithNewOp<Basicpy::BuildListOp>(
+        op, Basicpy::ListType::get(rewriter.getContext()), listElements);
+    return success();
+  });
+}
+
 #define GET_OP_CLASSES
 #include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
 
diff --git a/test/Conversion/ATenToStd/basic.mlir b/test/Conversion/ATenToStd/basic.mlir
index 298b39ece..9dabd8b45 100644
--- a/test/Conversion/ATenToStd/basic.mlir
+++ b/test/Conversion/ATenToStd/basic.mlir
@@ -20,3 +20,14 @@ func @aten.ne.int(%arg0: i64, %arg1: i64) -> !basicpy.BoolType {
   %0 = "aten.ne.int"(%arg0, %arg1) : (i64, i64) -> !basicpy.BoolType
   return %0 : !basicpy.BoolType
 }
+
+// CHECK-LABEL:   func @aten.gt.int(
+// CHECK-SAME:        %[[ARG0:.*]]: i64,
+// CHECK-SAME:        %[[ARG1:.*]]: i64) -> !basicpy.BoolType {
+// CHECK:           %[[I1:.*]] = cmpi sgt, %[[ARG0]], %[[ARG1]] : i64
+// CHECK:           %[[BASICPY_BOOL:.*]] = basicpy.bool_cast %[[I1]] : i1 -> !basicpy.BoolType
+// CHECK:           return %[[BASICPY_BOOL]] : !basicpy.BoolType
+func @aten.gt.int(%arg0: i64, %arg1: i64) -> !basicpy.BoolType {
+  %0 = "aten.gt.int"(%arg0, %arg1) : (i64, i64) -> !basicpy.BoolType
+  return %0 : !basicpy.BoolType
+}
diff --git a/test/Dialect/ATen/canonicalize.mlir b/test/Dialect/ATen/canonicalize.mlir
index 5c24ef065..bb1cdf6a4 100644
--- a/test/Dialect/ATen/canonicalize.mlir
+++ b/test/Dialect/ATen/canonicalize.mlir
@@ -7,3 +7,21 @@ func @aten.__is__(%arg0: !basicpy.ListType, %arg1: !basicpy.NoneType) -> !basicpy.BoolType {
   %0 = "aten.__is__"(%arg0, %arg1) : (!basicpy.ListType, !basicpy.NoneType) -> !basicpy.BoolType
   return %0 : !basicpy.BoolType
 }
+// CHECK-LABEL:   func @aten.size(
+// CHECK:           %[[CM1:.*]] = constant -1 : i64
+// CHECK:           %[[C3:.*]] = constant 3 : i64
+// CHECK:           %[[RET:.*]] = basicpy.build_list %[[CM1]], %[[C3]] : (i64, i64) -> !basicpy.ListType
+// CHECK:           return %[[RET]] : !basicpy.ListType
+func @aten.size(%arg0: tensor<?x3xf32>) -> !basicpy.ListType {
+  %0 = "aten.size"(%arg0) : (tensor<?x3xf32>) -> !basicpy.ListType
+  return %0 : !basicpy.ListType
+}
+
+// CHECK-LABEL:   func @aten.len.t(
+// CHECK:           %[[LENGTH:.*]] = constant 2 : i64
+// CHECK:           return %[[LENGTH]] : i64
+func @aten.len.t(%arg0: i64) -> i64 {
+  %0 = basicpy.build_list %arg0, %arg0 : (i64, i64) -> !basicpy.ListType
+  %1 = "aten.len.t"(%0) : (!basicpy.ListType) -> i64
+  return %1 : i64
+}
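
For context beyond the patch itself: the three new ops compose so that a rank check over a statically shaped tensor can be folded away entirely. Below is a hand-written sketch, not part of the commit; the function name `@rank_gt_2` is hypothetical, and the input follows the same `tensor<?x3xf32>` shape used in the canonicalize test above.

```mlir
// Input: compute rank(%t) > %c2 using the newly added ops.
func @rank_gt_2(%t: tensor<?x3xf32>, %c2: i64) -> !basicpy.BoolType {
  %shape = "aten.size"(%t) : (tensor<?x3xf32>) -> !basicpy.ListType
  %rank = "aten.len.t"(%shape) : (!basicpy.ListType) -> i64
  %0 = "aten.gt.int"(%rank, %c2) : (i64, i64) -> !basicpy.BoolType
  return %0 : !basicpy.BoolType
}
```

Running `canonicalize` rewrites `aten.size` on the ranked tensor into a `basicpy.build_list` of `constant` ops (`-1` for the dynamic dimension, `3` for the static one), and `aten.len.t` of that list then folds to `constant 2 : i64`. Running `convert-aten-to-std` afterwards lowers `aten.gt.int` to `cmpi sgt` followed by `basicpy.bool_cast`, matching the expectations in the two test files above.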