diff --git a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py
index 22e6d870c..b9dd66b9c 100644
--- a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py
+++ b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py
@@ -106,6 +106,18 @@ def generate_ops(g: "OpGenerator"):
   g.ordinary_immutable_op("aten::linear(Tensor,Tensor,Tensor?)", "LinearOp", "linear")
+  g.ordinary_immutable_op(
+      "aten::batch_norm(Tensor,Tensor?,Tensor?,Tensor?,Tensor?,bool,float,float,bool)",
+      "BatchNormOp",
+      "batch_norm")
+  g.ordinary_immutable_op(
+      "aten::max_pool2d(Tensor,int[],int[],int[],int[],bool)",
+      "MaxPool2dOp",
+      "max_pool2d")
+  g.ordinary_immutable_op(
+      "aten::adaptive_avg_pool2d(Tensor,int[])",
+      "AdaptiveAvgPool2dOp",
+      "adaptive_avg_pool2d")
   g.ordinary_immutable_op(
       "aten::convolution_overrideable(Tensor,Tensor,Tensor?,int[],int[],int[],bool,int[],int)",
       "ConvolutionOp",
@@ -272,6 +284,7 @@ class OpGenerator:
           "int[]": "AnyTorchIntListType",
           "bool": "AnyTorchBoolType",
           "bool[]": "AnyTorchBoolListType",
+          "float": "AnyFloat",
       },
       flag_transforms={
           "Tensor": ["kImmutableTensor"],
@@ -363,11 +376,13 @@ class OpGenerator:
     These take and return a tensor and typically have an out and inplace
     variant (they may not but we generate patterns to match anyway).
     """
+    kernel_name = kernel_sig.partition("(")[0]
     opdef = self.define_op(
         kernel_sig=kernel_sig,
         ods_name=ods_name,
         op_name=op_name,
         promote_trailing_out_tensor=promote_trailing_out_tensor,
+        inplace_variant_kernel_name=kernel_name + "_",
         traits=list(traits) + ["NoSideEffect"],
         **kwargs)
     opdef.arg_transforms(
diff --git a/include/npcomp/Dialect/ATen/IR/ATenOps.td b/include/npcomp/Dialect/ATen/IR/ATenOps.td
index 080892532..f07a1bc9d 100644
--- a/include/npcomp/Dialect/ATen/IR/ATenOps.td
+++ b/include/npcomp/Dialect/ATen/IR/ATenOps.td
@@ -28,29 +28,6 @@ class aten_Op<string mnemonic, list<OpTrait> traits = [StatisticsOpInterface]> :
 include "npcomp/Dialect/ATen/IR/GeneratedATenOps.td"
 include "npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td"
 
-def aten_BatchNormOp: aten_Op<"batch_norm", [NoSideEffect, StatisticsOpInterface]>,
-    Results<(outs AnyTensor:$output, AnyTensor:$save_mean, AnyTensor:$save_invstd)> {
-  let arguments = (
-    ins AnyType:$arg0,
-    AnyType:$arg1,
-    AnyType:$arg2,
-    AnyType:$arg3,
-    AnyType:$arg4,
-    AnyType:$arg5,
-    AnyType:$arg6,
-    AnyType:$arg7,
-    AnyType:$arg8
-  );
-
-  let summary = "BatchNorm operator";
-  let description = [{
-    BatchNorm operator
-  }];
-  let extraClassDeclaration = [{
-    std::map<std::string, uint64_t> getStatistics();
-  }];
-}
-
 // We have list constants, which come out of pytorch. Represent them using
 // our own constant-like type, which gets lowered to std_ConstantOp later.
 def aten_ConstantOp: aten_Op<"constant", [NoSideEffect]>,
@@ -79,26 +56,6 @@ def aten_FlattenOp: aten_Op<"flatten", [NoSideEffect, StatisticsOpInterface]>,
   }];
 }
 
-def aten_MaxPool2dOp: aten_Op<"max_pool2d", [NoSideEffect, StatisticsOpInterface]>,
-    Results<(outs AnyTensor)> {
-  let arguments = (
-    ins AnyType:$arg0,
-    AnyType:$arg1,
-    AnyType:$arg2,
-    AnyType:$arg3,
-    AnyType:$arg4,
-    AnyType:$arg5
-  );
-
-  let summary = "MaxPool2d operator";
-  let description = [{
-    MaxPool2d operator
-  }];
-  let extraClassDeclaration = [{
-    std::map<std::string, uint64_t> getStatistics();
-  }];
-}
-
 def aten_TypeCastOp : aten_Op<"type_cast", [NoSideEffect]>,
     Results<(outs AnyType)> {
   let summary = "TypeCast operator";
diff --git a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc
index 2215e050d..b3cb7f668 100644
--- a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc
+++ b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc
@@ -213,6 +213,7 @@ const Torch::BuildKernelMetadata &AbsOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::abs";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::abs_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -232,6 +233,7 @@ const Torch::BuildKernelMetadata &AcosOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::acos";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::acos_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -251,6 +253,7 @@ const Torch::BuildKernelMetadata &AngleOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::angle";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::angle_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -270,6 +273,7 @@ const Torch::BuildKernelMetadata &AsinOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::asin";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::asin_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -289,6 +293,7 @@ const Torch::BuildKernelMetadata &AtanOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::atan";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::atan_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -308,6 +313,7 @@ const Torch::BuildKernelMetadata &CeilOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::ceil";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::ceil_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -327,6 +333,7 @@ const Torch::BuildKernelMetadata &ConjOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::conj";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::conj_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -346,6 +353,7 @@ const Torch::BuildKernelMetadata &CosOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::cos";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::cos_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -365,6 +373,7 @@ const Torch::BuildKernelMetadata &CoshOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::cosh";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::cosh_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -384,6 +393,7 @@ const Torch::BuildKernelMetadata &DigammaOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::digamma";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::digamma_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -403,6 +413,7 @@ const Torch::BuildKernelMetadata &ErfOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::erf";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::erf_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -422,6 +433,7 @@ const Torch::BuildKernelMetadata &ErfcOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::erfc";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::erfc_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -441,6 +453,7 @@ const Torch::BuildKernelMetadata &ErfinvOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::erfinv";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::erfinv_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -460,6 +473,7 @@ const Torch::BuildKernelMetadata &ExpOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::exp";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::exp_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -479,6 +493,7 @@ const Torch::BuildKernelMetadata &Expm1Op::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::expm1";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::expm1_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -498,6 +513,7 @@ const Torch::BuildKernelMetadata &FloorOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::floor";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::floor_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -517,6 +533,7 @@ const Torch::BuildKernelMetadata &FracOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::frac";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::frac_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -536,6 +553,7 @@ const Torch::BuildKernelMetadata &LgammaOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::lgamma";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::lgamma_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -555,6 +573,7 @@ const Torch::BuildKernelMetadata &LogOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::log";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::log_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -574,6 +593,7 @@ const Torch::BuildKernelMetadata &Log10Op::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::log10";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::log10_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -593,6 +613,7 @@ const Torch::BuildKernelMetadata &Log1pOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::log1p";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::log1p_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -612,6 +633,7 @@ const Torch::BuildKernelMetadata &Log2Op::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::log2";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::log2_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -631,6 +653,7 @@ const Torch::BuildKernelMetadata &NegOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::neg";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::neg_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -650,6 +673,7 @@ const Torch::BuildKernelMetadata &ReluOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::relu";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::relu_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -669,6 +693,7 @@ const Torch::BuildKernelMetadata &ReciprocalOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::reciprocal";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::reciprocal_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -688,6 +713,7 @@ const Torch::BuildKernelMetadata &RoundOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::round";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::round_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -707,6 +733,7 @@ const Torch::BuildKernelMetadata &RsqrtOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::rsqrt";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::rsqrt_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -726,6 +753,7 @@ const Torch::BuildKernelMetadata &SigmoidOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::sigmoid";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::sigmoid_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -745,6 +773,7 @@ const Torch::BuildKernelMetadata &SignOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::sign";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::sign_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -764,6 +793,7 @@ const Torch::BuildKernelMetadata &SinOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::sin";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::sin_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -783,6 +813,7 @@ const Torch::BuildKernelMetadata &SinhOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::sinh";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::sinh_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -802,6 +833,7 @@ const Torch::BuildKernelMetadata &SqrtOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::sqrt";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::sqrt_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -821,6 +853,7 @@ const Torch::BuildKernelMetadata &TanOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::tan";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::tan_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -840,6 +873,7 @@ const Torch::BuildKernelMetadata &TanhOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::tanh";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::tanh_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -859,6 +893,7 @@ const Torch::BuildKernelMetadata &TruncOp::getTorchBuildKernelMetadata() {
     Torch::BuildKernelMetadata m;
     m.kernelName = "aten::trunc";
     m.promoteTrailingOutTensor = true;
+    m.inplaceVariantKernelName = "aten::trunc_";
     m.addArgTypes({"Tensor"});
     m.addArgConversions({KVC::kImmutableTensor});
     m.addReturnTypes({"Tensor"});
@@ -949,6 +984,63 @@ const Torch::BuildKernelMetadata &LinearOp::getTorchBuildKernelMetadata() {
   return metadata;
 }
 
+Torch::KernelMetadata BatchNormOp::getTorchKernelMetadata() {
+  return getTorchBuildKernelMetadata();
+}
+
+const Torch::BuildKernelMetadata &BatchNormOp::getTorchBuildKernelMetadata() {
+  using KVC = Torch::KernelValueConversion::BitMask;
+  static Torch::BuildKernelMetadata metadata = ([]() {
+    Torch::BuildKernelMetadata m;
+    m.kernelName = "aten::batch_norm";
+    m.promoteTrailingOutTensor = true;
+    m.addArgTypes({"Tensor", "Tensor?", "Tensor?", "Tensor?", "Tensor?", "bool", "float", "float", "bool"});
+    m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone});
+    m.addReturnTypes({"Tensor"});
+    m.addReturnConversions({KVC::kImmutableTensor});
+    return m;
+  })();
+  return metadata;
+}
+
+Torch::KernelMetadata MaxPool2dOp::getTorchKernelMetadata() {
+  return getTorchBuildKernelMetadata();
+}
+
+const Torch::BuildKernelMetadata &MaxPool2dOp::getTorchBuildKernelMetadata() {
+  using KVC = Torch::KernelValueConversion::BitMask;
+  static Torch::BuildKernelMetadata metadata = ([]() {
+    Torch::BuildKernelMetadata m;
+    m.kernelName = "aten::max_pool2d";
+    m.promoteTrailingOutTensor = true;
+    m.addArgTypes({"Tensor", "int[]", "int[]", "int[]", "int[]", "bool"});
+    m.addArgConversions({KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone});
+    m.addReturnTypes({"Tensor"});
+    m.addReturnConversions({KVC::kImmutableTensor});
+    return m;
+  })();
+  return metadata;
+}
+
+Torch::KernelMetadata AdaptiveAvgPool2dOp::getTorchKernelMetadata() {
+  return getTorchBuildKernelMetadata();
+}
+
+const Torch::BuildKernelMetadata &AdaptiveAvgPool2dOp::getTorchBuildKernelMetadata() {
+  using KVC = Torch::KernelValueConversion::BitMask;
+  static Torch::BuildKernelMetadata metadata = ([]() {
+    Torch::BuildKernelMetadata m;
+    m.kernelName = "aten::adaptive_avg_pool2d";
+    m.promoteTrailingOutTensor = true;
+    m.addArgTypes({"Tensor", "int[]"});
+    m.addArgConversions({KVC::kImmutableTensor, KVC::kNone});
+    m.addReturnTypes({"Tensor"});
+    m.addReturnConversions({KVC::kImmutableTensor});
+    return m;
+  })();
+  return metadata;
+}
+
 Torch::KernelMetadata ConvolutionOp::getTorchKernelMetadata() {
   return getTorchBuildKernelMetadata();
 }
diff --git a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td
index 925d4c64e..fd7efca2c 100644
--- a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td
+++ b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td
@@ -525,6 +525,50 @@ def aten_LinearOp: aten_Op<"linear", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, AllowsTypeRefinement]> {
   );
 }
 
+def aten_BatchNormOp: aten_Op<"batch_norm", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, AllowsTypeRefinement]> {
+  let summary = "Recognized op for kernel aten::batch_norm";
+  let arguments = (ins
+    AnyTorchImmutableTensor:$input,
+    AnyTorchOptionalImmutableTensor:$weight,
+    AnyTorchOptionalImmutableTensor:$bias,
+    AnyTorchOptionalImmutableTensor:$running_mean,
+    AnyTorchOptionalImmutableTensor:$running_var,
+    AnyTorchBoolType:$training,
+    AnyFloat:$momentum,
+    AnyFloat:$eps,
+    AnyTorchBoolType:$cudnn_enabled
+  );
+  let results = (outs
+    AnyTorchImmutableTensor
+  );
+}
+
+def aten_MaxPool2dOp: aten_Op<"max_pool2d", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, AllowsTypeRefinement]> {
+  let summary = "Recognized op for kernel aten::max_pool2d";
+  let arguments = (ins
+    AnyTorchImmutableTensor:$self,
+    AnyTorchIntListType:$kernel_size,
+    AnyTorchIntListType:$stride,
+    AnyTorchIntListType:$padding,
+    AnyTorchIntListType:$dilation,
+    AnyTorchBoolType:$ceil_mode
+  );
+  let results = (outs
+    AnyTorchImmutableTensor
+  );
+}
+
+def aten_AdaptiveAvgPool2dOp: aten_Op<"adaptive_avg_pool2d", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, AllowsTypeRefinement]> {
+  let summary = "Recognized op for kernel aten::adaptive_avg_pool2d";
+  let arguments = (ins
+    AnyTorchImmutableTensor:$self,
+    AnyTorchIntListType:$output_size
+  );
+  let results = (outs
+    AnyTorchImmutableTensor
+  );
+}
+
 def aten_ConvolutionOp: aten_Op<"convolution", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, AllowsTypeRefinement]> {
   let summary = "Recognized op for kernel aten::convolution_overrideable";
   let arguments = (ins
@@ -691,3 +735,4 @@ def aten_CopyInplaceOp: aten_Op<"copy.inplace", [DeclareOpInterfaceMethods,
diff --git a/include/npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td b/include/npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td
--- a/include/npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td
+++ b/include/npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td
-def aten_AdaptiveAvgPool2dOp: aten_Op<"_adaptive_avg_pool2d", [NoSideEffect, StatisticsOpInterface]>,
-    Results<(outs AnyTensor)> {
-  let arguments = (
-    ins AnyTensor:$self,
-    AnyType:$output_size
-  );
-  let summary = "aten _adaptive_avg_pool2d operator";
-  let description = [{
-    AdaptiveAvgPool2dOp
-    aten _adaptive_avg_pool2d operator
-  }];
-  let extraClassDeclaration = [{
-    std::map<std::string, uint64_t> getStatistics();
-  }];
-}
-
 def aten_AdaptiveAvgPool2dBackwardOp: aten_Op<"_adaptive_avg_pool2d_backward", [NoSideEffect, StatisticsOpInterface]>,
     Results<(outs AnyTensor)> {
   let arguments = (
diff --git a/lib/Dialect/ATen/IR/ATenDialectOpStats.cpp b/lib/Dialect/ATen/IR/ATenDialectOpStats.cpp
index 0d9c54122..42f83fbb3 100644
--- a/lib/Dialect/ATen/IR/ATenDialectOpStats.cpp
+++ b/lib/Dialect/ATen/IR/ATenDialectOpStats.cpp
@@ -40,13 +40,6 @@ namespace mlir {
 namespace NPCOMP {
 namespace aten {
 
-std::map<std::string, uint64_t> AdaptiveAvgPool2dOp::getStatistics() {
-  std::map<std::string, uint64_t> toReturn;
-  // FIXME: unimplemented
-  toReturn["reads"] = -1;
-  toReturn["writes"] = -1;
-  return toReturn;
-}
 std::map<std::string, uint64_t> AdaptiveAvgPool2dBackwardOp::getStatistics() {
   std::map<std::string, uint64_t> toReturn;
   // FIXME: unimplemented
@@ -130,46 +123,6 @@ std::map<std::string, uint64_t> AsStridedOp::getStatistics() {
   return toReturn;
 }
 
-// batch_norm
-std::map<std::string, uint64_t> BatchNormOp::getStatistics() {
-
-  std::map<std::string, uint64_t> toReturn;
-
-  TensorType resultTy = getResult(0).getType().cast<TensorType>();
-  uint64_t op_volume = getTensorVolume(resultTy);
-  uint64_t weight_volume = getTensorVolume(getOperand(1).getType());
-  uint64_t bias_volume = getTensorVolume(getOperand(2).getType());
-  toReturn["operand:0:activation_in"] = op_volume;
-  toReturn["result:0:activation_out"] = op_volume;
-  toReturn["operand:1:parameters_in:weight"] = weight_volume;
-  toReturn["operand:2:parameters_in:bias"] = bias_volume;
-
-  // Now for the arithmetic. Assume variance is calculated as sum of squares
-  uint64_t ifm_depth = resultTy.getShape()[1];
-
-  toReturn["ops:+"] = op_volume;  // Add up for mean
-  toReturn["ops:*"] = op_volume;  // Square for variance
-  toReturn["ops:+"] += op_volume; // Add up squares for variance
-
-  toReturn["ops:*"] += ifm_depth; // Calc channel means
-  toReturn["ops:-"] += ifm_depth; // Calc channel vars
-  toReturn["ops:*"] += ifm_depth; // Calc channel vars
-
-  toReturn["ops:sqrt"] = ifm_depth; // Convert to SD
-  toReturn["ops:/"] = ifm_depth;    // Get the reciprocal
-
-  toReturn["ops:+"] += op_volume; // Subtract mean off each pixel
-  toReturn["ops:*"] += op_volume; // Multiply by 1/SD for each pixel
-
-  toReturn["ops:+"] += op_volume; // Bias
-  toReturn["ops:*"] += op_volume; // Scale
-
-  toReturn["reads"] = op_volume + weight_volume + bias_volume;
-  toReturn["writes"] = op_volume;
-
-  return toReturn;
-}
-
 // div_
 std::map<std::string, uint64_t> DivUnderOp::getStatistics() {
@@ -266,33 +219,6 @@ std::map<std::string, uint64_t> HardtanhBackwardOp::getStatistics() {
   return toReturn;
 }
 
-// max_pool2d
-std::map<std::string, uint64_t> MaxPool2dOp::getStatistics() {
-
-  std::map<std::string, uint64_t> toReturn;
-
-  TensorType resultTy = getResult().getType().cast<TensorType>();
-  TensorType inputType = getOperand(0).getType().cast<TensorType>();
-
-  uint64_t ofm_volume = getTensorVolume(resultTy);
-  toReturn["result:0:activation_out"] = ofm_volume;
-
-  uint64_t ifm_volume = getTensorVolume(inputType);
-  toReturn["input:0:activation_in"] = ifm_volume;
-
-  // To find the number of compares, we need the filter extent
-
-  std::vector<uint64_t> kernel_size = unpackListConstant(getOperand(1));
-
-  uint64_t aperture = kernel_size[0] * kernel_size[1];
-  toReturn["ops:>"] = ofm_volume * (aperture - 1);
-
-  toReturn["reads"] = ifm_volume;
-  toReturn["writes"] = ofm_volume;
-
-  return toReturn;
-}
-
 // max_pool2d_with_indices
 std::map<std::string, uint64_t> MaxPool2dWithIndicesOp::getStatistics() {
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 33645db87..4b3441f0b 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -162,7 +162,9 @@ public:
   ChangeResult
   visitOperation(Operation *op,
                  ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
-    if (isa(op)) {
+    if (isa(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
     if (isa(op)) {
@@ -214,6 +216,52 @@ public:
           joinElementTypes(operands[1]->getValue().elementType,
                            operands[2]->getValue().elementType));
       return getLatticeElement(op->getResult(0)).join(knowledge);
+    } else if (isa(op)) {
+      auto knowledge =
+          ValueKnowledge::getPessimisticValueState(op->getContext());
+      knowledge.hasRank = true;
+      knowledge.sizes.resize(4, kUnknownSize);
+      // Running some experiments in PyTorch, the bias doesn't seem to
+      // contribute to the final element type.
+      knowledge.elementType =
+          joinElementTypes(operands[0]->getValue().elementType,
+                           operands[1]->getValue().elementType);
+      return getLatticeElement(op->getResult(0)).join(knowledge);
+    } else if (isa(op)) {
+      auto knowledge =
+          ValueKnowledge::getPessimisticValueState(op->getContext());
+      knowledge.hasRank = true;
+      knowledge.sizes.resize(4, kUnknownSize);
+      knowledge.elementType = operands[0]->getValue().elementType;
+      return getLatticeElement(op->getResult(0)).join(knowledge);
+    } else if (isa(op)) {
+      auto input = operands[0]->getValue();
+      auto knowledge =
+          ValueKnowledge::getPessimisticValueState(op->getContext());
+      if (input.hasRank) {
+        knowledge.hasRank = true;
+        knowledge.sizes.resize(input.sizes.size(), kUnknownSize);
+      }
+      knowledge.elementType = input.elementType;
+      return getLatticeElement(op->getResult(0)).join(knowledge);
+    } else if (isa(op)) {
+      // This is a general binary broadcasting shape transfer function.
+      // We currently don't track "size 1" in our lattice, but we might want to.
+      // We could make this more precise as well. But again, as with the other
+      // shape transfer functions, handling the statically-invalid case is
+      // tricky, so we defer that until we need it.
+      auto lhs = operands[0]->getValue();
+      auto rhs = operands[1]->getValue();
+      auto knowledge =
+          ValueKnowledge::getPessimisticValueState(op->getContext());
+      if (lhs.hasRank && rhs.hasRank) {
+        knowledge.hasRank = true;
+        knowledge.sizes.resize(std::max(lhs.sizes.size(), rhs.sizes.size()),
+                               kUnknownSize);
+      }
+      knowledge.elementType =
+          joinElementTypes(lhs.elementType, rhs.elementType);
+      return getLatticeElement(op->getResult(0)).join(knowledge);
     }
     // Otherwise, this is an unknown operation. Just mark all results as having
     // reached a pessimistic fixpoint.
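Note: the transfer functions added above for conv2d, max_pool2d, adaptive_avg_pool2d and add refine only rank and element type, never concrete sizes. A minimal Python sketch of the binary broadcasting rule used for add follows; the class and helper names are illustrative stand-ins, not npcomp APIs, and the dtype join is simplified to "keep a dtype only when both sides agree or one is unknown".

# Sketch of the rank/dtype transfer rule for the broadcasting case above:
# result rank = max of operand ranks (all sizes unknown), element type =
# join of the operand element types. Names here are illustrative only.
from dataclasses import dataclass, field
from typing import List, Optional

K_UNKNOWN_SIZE = -1

@dataclass
class ValueKnowledge:
    has_rank: bool = False
    sizes: List[int] = field(default_factory=list)
    element_type: Optional[str] = None  # None models an unknown dtype

def join_element_types(lhs: Optional[str], rhs: Optional[str]) -> Optional[str]:
    # Keep a known dtype only if the other side agrees or is unknown.
    if lhs is None:
        return rhs
    if rhs is None or lhs == rhs:
        return lhs
    return None

def broadcast_transfer(lhs: ValueKnowledge, rhs: ValueKnowledge) -> ValueKnowledge:
    out = ValueKnowledge()
    if lhs.has_rank and rhs.has_rank:
        out.has_rank = True
        out.sizes = [K_UNKNOWN_SIZE] * max(len(lhs.sizes), len(rhs.sizes))
    out.element_type = join_element_types(lhs.element_type, rhs.element_type)
    return out

# tensor<4x6x3xf32> + tensor<1x1x3xf32> -> rank 3, f32, sizes unknown
print(broadcast_transfer(ValueKnowledge(True, [4, 6, 3], "f32"),
                         ValueKnowledge(True, [1, 1, 3], "f32")))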
diff --git a/test/Dialect/ATen/aten_batchnorm.mlir b/test/Dialect/ATen/aten_batchnorm.mlir
deleted file mode 100644
index bf5eaa853..000000000
--- a/test/Dialect/ATen/aten_batchnorm.mlir
+++ /dev/null
@@ -1,24 +0,0 @@
-// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
-// CHECK-LABEL: "L0-batch_norm-0": {
-// CHECK-NEXT: "activation_in": 103320,
-// CHECK-NEXT: "activation_out": 103320,
-// CHECK-NEXT: "ops:*": 310206,
-// CHECK-NEXT: "ops:+": 413280,
-// CHECK-NEXT: "ops:-": 123,
-// CHECK-NEXT: "ops:/": 123,
-// CHECK-NEXT: "ops:sqrt": 123,
-// CHECK-NEXT: "parameters_in": 246,
-// CHECK-NEXT: "reads": 103566,
-// CHECK-NEXT: "writes": 103320
-
-module {
-  func @graph(%arg0: tensor<42x123x4x5xf32>, %arg1: tensor<123xf32>, %arg2: tensor<123xf32>, %arg3: tensor<123xf32>, %arg4: tensor<123xf32>, %arg5: tensor) -> tensor<42x123x4x5xf32> {
-    %0 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
-    %1 = "aten.constant"() {type = "f32", value = 1.000000e-01 : f32} : () -> f32
-    %2 = "aten.constant"() {type = "f32", value = 9.99999974E-6 : f32} : () -> f32
-    %3 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
-    %4:3 = "aten.batch_norm"(%arg0, %arg1, %arg2, %arg3, %arg4, %0, %1, %2, %3) : (tensor<42x123x4x5xf32>, tensor<123xf32>, tensor<123xf32>, tensor<123xf32>, tensor<123xf32>, i1, f32, f32, i1) -> (tensor<42x123x4x5xf32>, tensor<123xf32>, tensor<123xf32>)
-    return %4#0 : tensor<42x123x4x5xf32>
-  }
-}
diff --git a/test/Dialect/ATen/aten_maxpool2d.mlir b/test/Dialect/ATen/aten_maxpool2d.mlir
deleted file mode 100644
index f5b85956a..000000000
--- a/test/Dialect/ATen/aten_maxpool2d.mlir
+++ /dev/null
@@ -1,19 +0,0 @@
-// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
-// CHECK-LABEL: "L0-max_pool2d-0": {
-// CHECK-NEXT: "activation_in": 8192,
-// CHECK-NEXT: "activation_out": 2048,
-// CHECK-NEXT: "ops:>": 16384,
-// CHECK-NEXT: "reads": 8192,
-// CHECK-NEXT: "writes": 2048
-
-module {
-  func @graph(%arg0: tensor<1x32x16x16xf32>) -> tensor<1x32x8x8xf32> {
-    %0 = "aten.constant"() {type = "List[i32]", value = dense<3> : vector<2xi64>} : () -> !aten.list
-    %1 = "aten.constant"() {type = "List[i32]", value = dense<2> : vector<2xi64>} : () -> !aten.list
-    %2 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi64>} : () -> !aten.list
-    %3 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi64>} : () -> !aten.list
-    %4 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
-    %5 = "aten.max_pool2d"(%arg0, %0, %1, %2, %3, %4) : (tensor<1x32x16x16xf32>, !aten.list, !aten.list, !aten.list, !aten.list, i1) -> tensor<1x32x8x8xf32>
-    "std.return"(%5) : (tensor<1x32x8x8xf32>) -> ()
-  }
-}
diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir
index 23b014bb1..2c5c4ecbb 100644
--- a/test/Dialect/Torch/refine-types.mlir
+++ b/test/Dialect/Torch/refine-types.mlir
@@ -51,6 +51,76 @@ func @f(%arg0: tensor, %arg1: tensor<5x3xf32>, %arg2: tensor<5xf32>) ->
 
 // -----
 
+// CHECK-LABEL: func @f
+// CHECK: %[[CONV2D:.*]] = "aten.conv2d"{{.*}} -> tensor
+// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[CONV2D]] : tensor to tensor<*x!numpy.any_dtype>
+// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
+func @f(%arg0:tensor<*x!numpy.any_dtype>, %arg1:tensor<*x!numpy.any_dtype>, %arg2:tensor<*x!numpy.any_dtype>) ->tensor<*x!numpy.any_dtype> {
+  %c0_i64 = constant 0 : i64
+  %c1_i64 = constant 1 : i64
+  %0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
+  %1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
+  %2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
+  %3 = "aten.conv2d"(%arg0, %arg1, %arg2, %0, %1, %2, %c1_i64) : (tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) ->tensor<*x!numpy.any_dtype>
+  return %3 :tensor<*x!numpy.any_dtype>
+}
+
+// CHECK-LABEL: func @g
+// CHECK: %[[CONV2D:.*]] = "aten.conv2d"{{.*}} -> tensor
+// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[CONV2D]] : tensor to tensor<*x!numpy.any_dtype>
+// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
+func @g(%arg0:tensor<*xf32>, %arg1:tensor<*xf32>, %arg2:tensor<*xf32>) ->tensor<*x!numpy.any_dtype> {
+  %c0_i64 = constant 0 : i64
+  %c1_i64 = constant 1 : i64
+  %0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
+  %1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
+  %2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
+  %3 = "aten.conv2d"(%arg0, %arg1, %arg2, %0, %1, %2, %c1_i64) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) ->tensor<*x!numpy.any_dtype>
+  return %3 :tensor<*x!numpy.any_dtype>
+}
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor) -> tensor<*x!numpy.any_dtype> {
+  %c1_i64 = constant 1 : i64
+  %c3_i64 = constant 3 : i64
+  %c2_i64 = constant 2 : i64
+  %bool_false = basicpy.bool_constant false
+  %21 = basicpy.build_list %c3_i64, %c3_i64 : (i64, i64) -> !basicpy.ListType
+  %22 = basicpy.build_list %c2_i64, %c2_i64 : (i64, i64) -> !basicpy.ListType
+  %23 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
+  %24 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
+  // CHECK: "aten.max_pool2d"{{.*}} -> tensor
+  %27 = "aten.max_pool2d"(%arg0, %21, %22, %23, %24, %bool_false) : (tensor, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.BoolType) -> tensor<*x!numpy.any_dtype>
+  return %27 : tensor<*x!numpy.any_dtype>
+}
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor) -> tensor<*x!numpy.any_dtype> {
+  %c1_i64 = constant 1 : i64
+  %0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
+  // CHECK: "aten.adaptive_avg_pool2d"{{.*}} -> tensor
+  %1 = "aten.adaptive_avg_pool2d"(%arg0, %0) : (tensor, !basicpy.ListType) -> tensor<*x!numpy.any_dtype>
+  return %1 : tensor<*x!numpy.any_dtype>
+}
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<4x6x3xf32>, %arg1: tensor<1x1x3xf32>, %arg2: tensor) {
+  %c1_i64 = constant 1 : i64
+  // CHECK: "aten.add"{{.*}} -> tensor
+  %0 = "aten.add"(%arg0, %arg1, %c1_i64) : (tensor<4x6x3xf32>, tensor<1x1x3xf32>, i64) -> tensor<*x!numpy.any_dtype>
+  // CHECK: "aten.add"{{.*}} -> tensor
+  %1 = "aten.add"(%arg0, %arg2, %c1_i64) : (tensor<4x6x3xf32>, tensor, i64) -> tensor<*x!numpy.any_dtype>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @f
 func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
   // Check propagation through multiple ops.
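Usage note on the generator change at the top of this patch: the in-place variant name recorded in each op's build metadata is derived purely from the registered kernel signature, mirroring kernel_sig.partition("(")[0] + "_" from the diff. A small standalone sketch of that derivation follows; as the docstring in the diff notes, the name is recorded even when PyTorch has no matching in-place kernel, since patterns are generated to match anyway.

# Derive the in-place variant name the same way the generator change does:
# take the kernel name before "(" and append a trailing underscore, following
# PyTorch's convention that e.g. aten::relu has an in-place twin aten::relu_.
def inplace_variant_kernel_name(kernel_sig: str) -> str:
    kernel_name = kernel_sig.partition("(")[0]
    return kernel_name + "_"

assert inplace_variant_kernel_name("aten::relu(Tensor)") == "aten::relu_"
assert inplace_variant_kernel_name(
    "aten::batch_norm(Tensor,Tensor?,Tensor?,Tensor?,Tensor?,bool,float,float,bool)"
) == "aten::batch_norm_"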