Add some more ResNet ops.

- aten::relu_, aten::max_pool2d, aten::adaptive_avg_pool2d, aten::batch_norm, aten::conv2d

No aten-to-linalg conversion for the latter ops in that list, as they are
fairly substantial. At this point, I'm focused on getting shape inference
working for them and getting the IR cleaned up.
pull/213/head
Sean Silva 2021-04-28 15:29:02 -07:00
parent 9257457d8a
commit ec6d06aa86
10 changed files with 271 additions and 177 deletions
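
For context on the shape-inference goal mentioned above, here is a sketch of the intended refinement for one of the new ops, adapted from the refine-types test added at the end of this diff (the function and value names are illustrative, not part of the change itself):

// Input is ranked but otherwise unknown; before refinement the aten.max_pool2d
// result is an unranked tensor of unknown dtype.
func @max_pool2d_sketch(%arg0: tensor<?x?x?x?xf32>) -> tensor<*x!numpy.any_dtype> {
  %c1_i64 = constant 1 : i64
  %c2_i64 = constant 2 : i64
  %c3_i64 = constant 3 : i64
  %bool_false = basicpy.bool_constant false
  %kernel_size = basicpy.build_list %c3_i64, %c3_i64 : (i64, i64) -> !basicpy.ListType
  %stride = basicpy.build_list %c2_i64, %c2_i64 : (i64, i64) -> !basicpy.ListType
  %padding = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
  %dilation = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
  // After type refinement the result should be inferred as tensor<?x?x?x?xf32>:
  // rank 4, with the element type propagated from the input.
  %result = "aten.max_pool2d"(%arg0, %kernel_size, %stride, %padding, %dilation, %bool_false) : (tensor<?x?x?x?xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.BoolType) -> tensor<*x!numpy.any_dtype>
  return %result : tensor<*x!numpy.any_dtype>
}

In the shape-inference change below, aten.conv2d gets a similar rank-4 transfer function, aten.adaptive_avg_pool2d keeps the input's rank, and aten.batch_norm and aten.relu simply propagate the operand's value knowledge.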


@ -106,6 +106,18 @@ def generate_ops(g: "OpGenerator"):
g.ordinary_immutable_op("aten::linear(Tensor,Tensor,Tensor?)",
"LinearOp",
"linear")
g.ordinary_immutable_op(
"aten::batch_norm(Tensor,Tensor?,Tensor?,Tensor?,Tensor?,bool,float,float,bool)",
"BatchNormOp",
"batch_norm")
g.ordinary_immutable_op(
"aten::max_pool2d(Tensor,int[],int[],int[],int[],bool)",
"MaxPool2dOp",
"max_pool2d")
g.ordinary_immutable_op(
"aten::adaptive_avg_pool2d(Tensor,int[])",
"AdaptiveAvgPool2dOp",
"adaptive_avg_pool2d")
g.ordinary_immutable_op(
"aten::convolution_overrideable(Tensor,Tensor,Tensor?,int[],int[],int[],bool,int[],int)",
"ConvolutionOp",
@ -272,6 +284,7 @@ class OpGenerator:
"int[]": "AnyTorchIntListType",
"bool": "AnyTorchBoolType",
"bool[]": "AnyTorchBoolListType",
"float": "AnyFloat",
},
flag_transforms={
"Tensor": ["kImmutableTensor"],
@ -363,11 +376,13 @@ class OpGenerator:
These take and return a tensor and typically have an out and inplace
variant (they may not but we generate patterns to match anyway).
"""
kernel_name = kernel_sig.partition("(")[0]
opdef = self.define_op(
kernel_sig=kernel_sig,
ods_name=ods_name,
op_name=op_name,
promote_trailing_out_tensor=promote_trailing_out_tensor,
inplace_variant_kernel_name=kernel_name + "_",
traits=list(traits) + ["NoSideEffect"],
**kwargs)
opdef.arg_transforms(


@ -28,29 +28,6 @@ class aten_Op<string mnemonic, list<OpTrait> traits = [StatisticsOpInterface]> :
include "npcomp/Dialect/ATen/IR/GeneratedATenOps.td"
include "npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td"
def aten_BatchNormOp: aten_Op<"batch_norm", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor:$output, AnyTensor:$save_mean, AnyTensor:$save_invstd)> {
let arguments = (
ins AnyType:$arg0,
AnyType:$arg1,
AnyType:$arg2,
AnyType:$arg3,
AnyType:$arg4,
AnyType:$arg5,
AnyType:$arg6,
AnyType:$arg7,
AnyType:$arg8
);
let summary = "BatchNorm operator";
let description = [{
BatchNorm operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
// We have list constants, which come out of pytorch. Represent them using
// our own constant-like type, which gets lowered to std_ConstantOp later.
def aten_ConstantOp: aten_Op<"constant", [NoSideEffect]>,
@ -79,26 +56,6 @@ def aten_FlattenOp: aten_Op<"flatten", [NoSideEffect, StatisticsOpInterface]>,
}];
}
def aten_MaxPool2dOp: aten_Op<"max_pool2d", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyType:$arg0,
AnyType:$arg1,
AnyType:$arg2,
AnyType:$arg3,
AnyType:$arg4,
AnyType:$arg5
);
let summary = "MaxPool2d operator";
let description = [{
MaxPool2d operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_TypeCastOp : aten_Op<"type_cast", [NoSideEffect]>,
Results<(outs AnyType)> {
let summary = "TypeCast operator";


@ -213,6 +213,7 @@ const Torch::BuildKernelMetadata &AbsOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::abs";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::abs_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -232,6 +233,7 @@ const Torch::BuildKernelMetadata &AcosOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::acos";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::acos_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -251,6 +253,7 @@ const Torch::BuildKernelMetadata &AngleOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::angle";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::angle_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -270,6 +273,7 @@ const Torch::BuildKernelMetadata &AsinOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::asin";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::asin_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -289,6 +293,7 @@ const Torch::BuildKernelMetadata &AtanOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::atan";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::atan_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -308,6 +313,7 @@ const Torch::BuildKernelMetadata &CeilOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::ceil";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::ceil_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -327,6 +333,7 @@ const Torch::BuildKernelMetadata &ConjOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::conj";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::conj_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -346,6 +353,7 @@ const Torch::BuildKernelMetadata &CosOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::cos";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::cos_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -365,6 +373,7 @@ const Torch::BuildKernelMetadata &CoshOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::cosh";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::cosh_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -384,6 +393,7 @@ const Torch::BuildKernelMetadata &DigammaOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::digamma";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::digamma_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -403,6 +413,7 @@ const Torch::BuildKernelMetadata &ErfOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::erf";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::erf_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -422,6 +433,7 @@ const Torch::BuildKernelMetadata &ErfcOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::erfc";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::erfc_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -441,6 +453,7 @@ const Torch::BuildKernelMetadata &ErfinvOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::erfinv";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::erfinv_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -460,6 +473,7 @@ const Torch::BuildKernelMetadata &ExpOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::exp";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::exp_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -479,6 +493,7 @@ const Torch::BuildKernelMetadata &Expm1Op::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::expm1";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::expm1_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -498,6 +513,7 @@ const Torch::BuildKernelMetadata &FloorOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::floor";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::floor_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -517,6 +533,7 @@ const Torch::BuildKernelMetadata &FracOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::frac";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::frac_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -536,6 +553,7 @@ const Torch::BuildKernelMetadata &LgammaOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::lgamma";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::lgamma_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -555,6 +573,7 @@ const Torch::BuildKernelMetadata &LogOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::log";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::log_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -574,6 +593,7 @@ const Torch::BuildKernelMetadata &Log10Op::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::log10";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::log10_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -593,6 +613,7 @@ const Torch::BuildKernelMetadata &Log1pOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::log1p";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::log1p_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -612,6 +633,7 @@ const Torch::BuildKernelMetadata &Log2Op::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::log2";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::log2_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -631,6 +653,7 @@ const Torch::BuildKernelMetadata &NegOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::neg";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::neg_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -650,6 +673,7 @@ const Torch::BuildKernelMetadata &ReluOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::relu";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::relu_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -669,6 +693,7 @@ const Torch::BuildKernelMetadata &ReciprocalOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::reciprocal";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::reciprocal_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -688,6 +713,7 @@ const Torch::BuildKernelMetadata &RoundOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::round";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::round_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -707,6 +733,7 @@ const Torch::BuildKernelMetadata &RsqrtOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::rsqrt";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::rsqrt_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -726,6 +753,7 @@ const Torch::BuildKernelMetadata &SigmoidOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::sigmoid";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::sigmoid_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -745,6 +773,7 @@ const Torch::BuildKernelMetadata &SignOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::sign";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::sign_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -764,6 +793,7 @@ const Torch::BuildKernelMetadata &SinOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::sin";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::sin_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -783,6 +813,7 @@ const Torch::BuildKernelMetadata &SinhOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::sinh";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::sinh_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -802,6 +833,7 @@ const Torch::BuildKernelMetadata &SqrtOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::sqrt";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::sqrt_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -821,6 +853,7 @@ const Torch::BuildKernelMetadata &TanOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::tan";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::tan_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -840,6 +873,7 @@ const Torch::BuildKernelMetadata &TanhOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::tanh";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::tanh_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -859,6 +893,7 @@ const Torch::BuildKernelMetadata &TruncOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::trunc";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::trunc_";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
@ -949,6 +984,63 @@ const Torch::BuildKernelMetadata &LinearOp::getTorchBuildKernelMetadata() {
return metadata;
}
Torch::KernelMetadata BatchNormOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &BatchNormOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::batch_norm";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor?", "Tensor?", "Tensor?", "Tensor?", "bool", "float", "float", "bool"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata MaxPool2dOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &MaxPool2dOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::max_pool2d";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "int[]", "int[]", "int[]", "int[]", "bool"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata AdaptiveAvgPool2dOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &AdaptiveAvgPool2dOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::adaptive_avg_pool2d";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "int[]"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kNone});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata ConvolutionOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}


@ -525,6 +525,50 @@ def aten_LinearOp: aten_Op<"linear", [NoSideEffect, DeclareOpInterfaceMethods<To
);
}
def aten_BatchNormOp: aten_Op<"batch_norm", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
let summary = "Recognized op for kernel aten::batch_norm";
let arguments = (ins
AnyTorchImmutableTensor:$input,
AnyTorchOptionalImmutableTensor:$weight,
AnyTorchOptionalImmutableTensor:$bias,
AnyTorchOptionalImmutableTensor:$running_mean,
AnyTorchOptionalImmutableTensor:$running_var,
AnyTorchBoolType:$training,
AnyFloat:$momentum,
AnyFloat:$eps,
AnyTorchBoolType:$cudnn_enabled
);
let results = (outs
AnyTorchImmutableTensor
);
}
def aten_MaxPool2dOp: aten_Op<"max_pool2d", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
let summary = "Recognized op for kernel aten::max_pool2d";
let arguments = (ins
AnyTorchImmutableTensor:$self,
AnyTorchIntListType:$kernel_size,
AnyTorchIntListType:$stride,
AnyTorchIntListType:$padding,
AnyTorchIntListType:$dilation,
AnyTorchBoolType:$ceil_mode
);
let results = (outs
AnyTorchImmutableTensor
);
}
def aten_AdaptiveAvgPool2dOp: aten_Op<"adaptive_avg_pool2d", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
let summary = "Recognized op for kernel aten::adaptive_avg_pool2d";
let arguments = (ins
AnyTorchImmutableTensor:$self,
AnyTorchIntListType:$output_size
);
let results = (outs
AnyTorchImmutableTensor
);
}
def aten_ConvolutionOp: aten_Op<"convolution", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
let summary = "Recognized op for kernel aten::convolution_overrideable";
let arguments = (ins
@ -691,3 +735,4 @@ def aten_CopyInplaceOp: aten_Op<"copy.inplace", [DeclareOpInterfaceMethods<Torch
let results = (outs
);
}


@ -406,22 +406,6 @@ def aten_HardtanhUnderOp: aten_Op<"hardtanh_", [NoSideEffect, StatisticsOpInterf
}];
}
def aten_AdaptiveAvgPool2dOp: aten_Op<"_adaptive_avg_pool2d", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyType:$output_size
);
let summary = "aten _adaptive_avg_pool2d operator";
let description = [{
AdaptiveAvgPool2dOp
aten _adaptive_avg_pool2d operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_AdaptiveAvgPool2dBackwardOp: aten_Op<"_adaptive_avg_pool2d_backward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (


@ -40,13 +40,6 @@ namespace mlir {
namespace NPCOMP {
namespace aten {
std::map<std::string, uint64_t> AdaptiveAvgPool2dOp::getStatistics() {
std::map<std::string, uint64_t> toReturn;
// FIXME: unimplemented
toReturn["reads"] = -1;
toReturn["writes"] = -1;
return toReturn;
}
std::map<std::string, uint64_t> AdaptiveAvgPool2dBackwardOp::getStatistics() {
std::map<std::string, uint64_t> toReturn;
// FIXME: unimplemented
@ -130,46 +123,6 @@ std::map<std::string, uint64_t> AsStridedOp::getStatistics() {
return toReturn;
}
// batch_norm
std::map<std::string, uint64_t> BatchNormOp::getStatistics() {
std::map<std::string, uint64_t> toReturn;
TensorType resultTy = getResult(0).getType().cast<TensorType>();
uint64_t op_volume = getTensorVolume(resultTy);
uint64_t weight_volume = getTensorVolume(getOperand(1).getType());
uint64_t bias_volume = getTensorVolume(getOperand(2).getType());
toReturn["operand:0:activation_in"] = op_volume;
toReturn["result:0:activation_out"] = op_volume;
toReturn["operand:1:parameters_in:weight"] = weight_volume;
toReturn["operand:2:parameters_in:bias"] = bias_volume;
// Now for the arithmetic. Assume variance is calculated as sum of squares
uint64_t ifm_depth = resultTy.getShape()[1];
toReturn["ops:+"] = op_volume; // Add up for mean
toReturn["ops:*"] = op_volume; // Square for variance
toReturn["ops:+"] += op_volume; // Add up squares for variance
toReturn["ops:*"] += ifm_depth; // Calc channel means
toReturn["ops:-"] += ifm_depth; // Calc channel vars
toReturn["ops:*"] += ifm_depth; // Calc channel vars
toReturn["ops:sqrt"] = ifm_depth; // Convert to SD
toReturn["ops:/"] = ifm_depth; // Get the reciprocal
toReturn["ops:+"] += op_volume; // Subtract mean off each pixel
toReturn["ops:*"] += op_volume; // Multiply by 1/SD for each pixel
toReturn["ops:+"] += op_volume; // Bias
toReturn["ops:*"] += op_volume; // Scale
toReturn["reads"] = op_volume + weight_volume + bias_volume;
toReturn["writes"] = op_volume;
return toReturn;
}
// div_
std::map<std::string, uint64_t> DivUnderOp::getStatistics() {
@ -266,33 +219,6 @@ std::map<std::string, uint64_t> HardtanhBackwardOp::getStatistics() {
return toReturn;
}
// max_pool2d
std::map<std::string, uint64_t> MaxPool2dOp::getStatistics() {
std::map<std::string, uint64_t> toReturn;
TensorType resultTy = getResult().getType().cast<TensorType>();
TensorType inputType = getOperand(0).getType().cast<TensorType>();
uint64_t ofm_volume = getTensorVolume(resultTy);
toReturn["result:0:activation_out"] = ofm_volume;
uint64_t ifm_volume = getTensorVolume(inputType);
toReturn["input:0:activation_in"] = ifm_volume;
// To find the number of compares, we need the filter extent
std::vector<uint64_t> kernel_size = unpackListConstant(getOperand(1));
uint64_t aperture = kernel_size[0] * kernel_size[1];
toReturn["ops:>"] = ofm_volume * (aperture - 1);
toReturn["reads"] = ifm_volume;
toReturn["writes"] = ofm_volume;
return toReturn;
}
// max_pool2d_with_indices
std::map<std::string, uint64_t> MaxPool2dWithIndicesOp::getStatistics() {


@ -162,7 +162,9 @@ public:
ChangeResult
visitOperation(Operation *op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
if (isa<Numpy::TensorStaticInfoCastOp, aten::TanhOp>(op)) {
if (isa<Numpy::TensorStaticInfoCastOp, Numpy::CopyToTensorOp,
Numpy::CreateArrayFromTensorOp, aten::TanhOp, aten::BatchNormOp,
aten::ReluOp>(op)) {
return getLatticeElement(op->getResult(0)).join(*operands[0]);
}
if (isa<aten::MmOp>(op)) {
@ -214,6 +216,52 @@ public:
joinElementTypes(operands[1]->getValue().elementType,
operands[2]->getValue().elementType));
return getLatticeElement(op->getResult(0)).join(knowledge);
} else if (isa<aten::Conv2dOp>(op)) {
auto knowledge =
ValueKnowledge::getPessimisticValueState(op->getContext());
knowledge.hasRank = true;
knowledge.sizes.resize(4, kUnknownSize);
// Running some experiments in PyTorch, the bias doesn't seem to
// contribute to the final element type.
knowledge.elementType =
joinElementTypes(operands[0]->getValue().elementType,
operands[1]->getValue().elementType);
return getLatticeElement(op->getResult(0)).join(knowledge);
} else if (isa<aten::MaxPool2dOp>(op)) {
auto knowledge =
ValueKnowledge::getPessimisticValueState(op->getContext());
knowledge.hasRank = true;
knowledge.sizes.resize(4, kUnknownSize);
knowledge.elementType = operands[0]->getValue().elementType;
return getLatticeElement(op->getResult(0)).join(knowledge);
} else if (isa<aten::AdaptiveAvgPool2dOp>(op)) {
auto input = operands[0]->getValue();
auto knowledge =
ValueKnowledge::getPessimisticValueState(op->getContext());
if (input.hasRank) {
knowledge.hasRank = true;
knowledge.sizes.resize(input.sizes.size(), kUnknownSize);
}
knowledge.elementType = input.elementType;
return getLatticeElement(op->getResult(0)).join(knowledge);
} else if (isa<aten::AddOp>(op)) {
// This is a general binary broadcasting shape transfer function.
// We currently don't track "size 1" in our lattice, but we might want to.
// We could make this more precise as well. But again, as with the other
// shape transfer functions, handling the statically-invalid case is
// tricky, so we defer that until we need it.
auto lhs = operands[0]->getValue();
auto rhs = operands[1]->getValue();
auto knowledge =
ValueKnowledge::getPessimisticValueState(op->getContext());
if (lhs.hasRank && rhs.hasRank) {
knowledge.hasRank = true;
knowledge.sizes.resize(std::max(lhs.sizes.size(), rhs.sizes.size()),
kUnknownSize);
}
knowledge.elementType =
joinElementTypes(lhs.elementType, rhs.elementType);
return getLatticeElement(op->getResult(0)).join(knowledge);
}
// Otherwise, this is an unknown operation. Just mark all results as having
// reached a pessimistic fixpoint.


@ -1,24 +0,0 @@
// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
// CHECK-LABEL: "L0-batch_norm-0": {
// CHECK-NEXT: "activation_in": 103320,
// CHECK-NEXT: "activation_out": 103320,
// CHECK-NEXT: "ops:*": 310206,
// CHECK-NEXT: "ops:+": 413280,
// CHECK-NEXT: "ops:-": 123,
// CHECK-NEXT: "ops:/": 123,
// CHECK-NEXT: "ops:sqrt": 123,
// CHECK-NEXT: "parameters_in": 246,
// CHECK-NEXT: "reads": 103566,
// CHECK-NEXT: "writes": 103320
module {
func @graph(%arg0: tensor<42x123x4x5xf32>, %arg1: tensor<123xf32>, %arg2: tensor<123xf32>, %arg3: tensor<123xf32>, %arg4: tensor<123xf32>, %arg5: tensor<?xi64>) -> tensor<42x123x4x5xf32> {
%0 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
%1 = "aten.constant"() {type = "f32", value = 1.000000e-01 : f32} : () -> f32
%2 = "aten.constant"() {type = "f32", value = 9.99999974E-6 : f32} : () -> f32
%3 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
%4:3 = "aten.batch_norm"(%arg0, %arg1, %arg2, %arg3, %arg4, %0, %1, %2, %3) : (tensor<42x123x4x5xf32>, tensor<123xf32>, tensor
<123xf32>, tensor<123xf32>, tensor<123xf32>, i1, f32, f32, i1) -> (tensor<42x123x4x5xf32>, tensor<123xf32>, tensor<123xf32>)
return %4#0 : tensor<42x123x4x5xf32>
}
}


@ -1,19 +0,0 @@
// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
// CHECK-LABEL: "L0-max_pool2d-0": {
// CHECK-NEXT: "activation_in": 8192,
// CHECK-NEXT: "activation_out": 2048,
// CHECK-NEXT: "ops:>": 16384,
// CHECK-NEXT: "reads": 8192,
// CHECK-NEXT: "writes": 2048
module {
func @graph(%arg0: tensor<1x32x16x16xf32>) -> tensor<1x32x8x8xf32> {
%0 = "aten.constant"() {type = "List[i32]", value = dense<3> : vector<2xi64>} : () -> !aten.list<i32>
%1 = "aten.constant"() {type = "List[i32]", value = dense<2> : vector<2xi64>} : () -> !aten.list<i32>
%2 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi64>} : () -> !aten.list<i32>
%3 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi64>} : () -> !aten.list<i32>
%4 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
%5 = "aten.max_pool2d"(%arg0, %0, %1, %2, %3, %4) : (tensor<1x32x16x16xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>, i1) -> tensor<1x32x8x8xf32>
"std.return"(%5) : (tensor<1x32x8x8xf32>) -> ()
}
}


@ -51,6 +51,76 @@ func @f(%arg0: tensor<?x3xf32>, %arg1: tensor<5x3xf32>, %arg2: tensor<5xf32>) ->
// -----
// CHECK-LABEL: func @f
// CHECK: %[[CONV2D:.*]] = "aten.conv2d"{{.*}} -> tensor<?x?x?x?x!numpy.any_dtype>
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[CONV2D]] : tensor<?x?x?x?x!numpy.any_dtype> to tensor<*x!numpy.any_dtype>
// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
func @f(%arg0:tensor<*x!numpy.any_dtype>, %arg1:tensor<*x!numpy.any_dtype>, %arg2:tensor<*x!numpy.any_dtype>) ->tensor<*x!numpy.any_dtype> {
%c0_i64 = constant 0 : i64
%c1_i64 = constant 1 : i64
%0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
%1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
%2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
%3 = "aten.conv2d"(%arg0, %arg1, %arg2, %0, %1, %2, %c1_i64) : (tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) ->tensor<*x!numpy.any_dtype>
return %3 :tensor<*x!numpy.any_dtype>
}
// CHECK-LABEL: func @g
// CHECK: %[[CONV2D:.*]] = "aten.conv2d"{{.*}} -> tensor<?x?x?x?xf32>
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[CONV2D]] : tensor<?x?x?x?xf32> to tensor<*x!numpy.any_dtype>
// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
func @g(%arg0:tensor<*xf32>, %arg1:tensor<*xf32>, %arg2:tensor<*xf32>) ->tensor<*x!numpy.any_dtype> {
%c0_i64 = constant 0 : i64
%c1_i64 = constant 1 : i64
%0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
%1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
%2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
%3 = "aten.conv2d"(%arg0, %arg1, %arg2, %0, %1, %2, %c1_i64) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) ->tensor<*x!numpy.any_dtype>
return %3 :tensor<*x!numpy.any_dtype>
}
// -----
// CHECK-LABEL: func @f
func @f(%arg0: tensor<?x?x?x?xf32>) -> tensor<*x!numpy.any_dtype> {
%c1_i64 = constant 1 : i64
%c3_i64 = constant 3 : i64
%c2_i64 = constant 2 : i64
%bool_false = basicpy.bool_constant false
%21 = basicpy.build_list %c3_i64, %c3_i64 : (i64, i64) -> !basicpy.ListType
%22 = basicpy.build_list %c2_i64, %c2_i64 : (i64, i64) -> !basicpy.ListType
%23 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
%24 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
// CHECK: "aten.max_pool2d"{{.*}} -> tensor<?x?x?x?xf32>
%27 = "aten.max_pool2d"(%arg0, %21, %22, %23, %24, %bool_false) : (tensor<?x?x?x?xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.BoolType) -> tensor<*x!numpy.any_dtype>
return %27 : tensor<*x!numpy.any_dtype>
}
// -----
// CHECK-LABEL: func @f
func @f(%arg0: tensor<?x?x?x?xf32>) -> tensor<*x!numpy.any_dtype> {
%c1_i64 = constant 1 : i64
%0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
// CHECK: "aten.adaptive_avg_pool2d"{{.*}} -> tensor<?x?x?x?xf32>
%1 = "aten.adaptive_avg_pool2d"(%arg0, %0) : (tensor<?x?x?x?xf32>, !basicpy.ListType) -> tensor<*x!numpy.any_dtype>
return %1 : tensor<*x!numpy.any_dtype>
}
// -----
// CHECK-LABEL: func @f
func @f(%arg0: tensor<4x6x3xf32>, %arg1: tensor<1x1x3xf32>, %arg2: tensor<?x3xf32>) {
%c1_i64 = constant 1 : i64
// CHECK: "aten.add"{{.*}} -> tensor<?x?x?xf32>
%0 = "aten.add"(%arg0, %arg1, %c1_i64) : (tensor<4x6x3xf32>, tensor<1x1x3xf32>, i64) -> tensor<*x!numpy.any_dtype>
// CHECK: "aten.add"{{.*}} -> tensor<?x?x?xf32>
%1 = "aten.add"(%arg0, %arg2, %c1_i64) : (tensor<4x6x3xf32>, tensor<?x3xf32>, i64) -> tensor<*x!numpy.any_dtype>
return
}
// -----
// CHECK-LABEL: func @f
func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
// Check propagation through multiple ops.