mirror of https://github.com/llvm/torch-mlir
Add some more ResNet ops.
- aten::relu_, aten::max_pool2d, aten::adaptive_avg_pool2d, aten::batch_norm, aten::conv2d No aten-to-linalg conversion for the latter ones, as they are fairly substantial. At this point, I'm trying to get shape inference and stuff working for them and the IR cleaned up.pull/213/head
parent
9257457d8a
commit
ec6d06aa86
|
@ -106,6 +106,18 @@ def generate_ops(g: "OpGenerator"):
|
|||
g.ordinary_immutable_op("aten::linear(Tensor,Tensor,Tensor?)",
|
||||
"LinearOp",
|
||||
"linear")
|
||||
g.ordinary_immutable_op(
|
||||
"aten::batch_norm(Tensor,Tensor?,Tensor?,Tensor?,Tensor?,bool,float,float,bool)",
|
||||
"BatchNormOp",
|
||||
"batch_norm")
|
||||
g.ordinary_immutable_op(
|
||||
"aten::max_pool2d(Tensor,int[],int[],int[],int[],bool)",
|
||||
"MaxPool2dOp",
|
||||
"max_pool2d")
|
||||
g.ordinary_immutable_op(
|
||||
"aten::adaptive_avg_pool2d(Tensor,int[])",
|
||||
"AdaptiveAvgPool2dOp",
|
||||
"adaptive_avg_pool2d")
|
||||
g.ordinary_immutable_op(
|
||||
"aten::convolution_overrideable(Tensor,Tensor,Tensor?,int[],int[],int[],bool,int[],int)",
|
||||
"ConvolutionOp",
|
||||
|
@ -272,6 +284,7 @@ class OpGenerator:
|
|||
"int[]": "AnyTorchIntListType",
|
||||
"bool": "AnyTorchBoolType",
|
||||
"bool[]": "AnyTorchBoolListType",
|
||||
"float": "AnyFloat",
|
||||
},
|
||||
flag_transforms={
|
||||
"Tensor": ["kImmutableTensor"],
|
||||
|
@ -363,11 +376,13 @@ class OpGenerator:
|
|||
These take and return a tensor and typically have an out and inplace
|
||||
variant (they may not but we generate patterns to match anyway).
|
||||
"""
|
||||
kernel_name = kernel_sig.partition("(")[0]
|
||||
opdef = self.define_op(
|
||||
kernel_sig=kernel_sig,
|
||||
ods_name=ods_name,
|
||||
op_name=op_name,
|
||||
promote_trailing_out_tensor=promote_trailing_out_tensor,
|
||||
inplace_variant_kernel_name=kernel_name + "_",
|
||||
traits=list(traits) + ["NoSideEffect"],
|
||||
**kwargs)
|
||||
opdef.arg_transforms(
|
||||
|
|
|
@ -28,29 +28,6 @@ class aten_Op<string mnemonic, list<OpTrait> traits = [StatisticsOpInterface]> :
|
|||
include "npcomp/Dialect/ATen/IR/GeneratedATenOps.td"
|
||||
include "npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td"
|
||||
|
||||
def aten_BatchNormOp: aten_Op<"batch_norm", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor:$output, AnyTensor:$save_mean, AnyTensor:$save_invstd)> {
|
||||
let arguments = (
|
||||
ins AnyType:$arg0,
|
||||
AnyType:$arg1,
|
||||
AnyType:$arg2,
|
||||
AnyType:$arg3,
|
||||
AnyType:$arg4,
|
||||
AnyType:$arg5,
|
||||
AnyType:$arg6,
|
||||
AnyType:$arg7,
|
||||
AnyType:$arg8
|
||||
);
|
||||
|
||||
let summary = "BatchNorm operator";
|
||||
let description = [{
|
||||
BatchNorm operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
// We have list constants, which come out of pytorch. Represent them using
|
||||
// our own constant-like type, which gets lowered to std_ConstantOp later.
|
||||
def aten_ConstantOp: aten_Op<"constant", [NoSideEffect]>,
|
||||
|
@ -79,26 +56,6 @@ def aten_FlattenOp: aten_Op<"flatten", [NoSideEffect, StatisticsOpInterface]>,
|
|||
}];
|
||||
}
|
||||
|
||||
def aten_MaxPool2dOp: aten_Op<"max_pool2d", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyType:$arg0,
|
||||
AnyType:$arg1,
|
||||
AnyType:$arg2,
|
||||
AnyType:$arg3,
|
||||
AnyType:$arg4,
|
||||
AnyType:$arg5
|
||||
);
|
||||
|
||||
let summary = "MaxPool2d operator";
|
||||
let description = [{
|
||||
MaxPool2d operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_TypeCastOp : aten_Op<"type_cast", [NoSideEffect]>,
|
||||
Results<(outs AnyType)> {
|
||||
let summary = "TypeCast operator";
|
||||
|
|
|
@ -213,6 +213,7 @@ const Torch::BuildKernelMetadata &AbsOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::abs";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::abs_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -232,6 +233,7 @@ const Torch::BuildKernelMetadata &AcosOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::acos";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::acos_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -251,6 +253,7 @@ const Torch::BuildKernelMetadata &AngleOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::angle";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::angle_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -270,6 +273,7 @@ const Torch::BuildKernelMetadata &AsinOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::asin";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::asin_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -289,6 +293,7 @@ const Torch::BuildKernelMetadata &AtanOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::atan";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::atan_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -308,6 +313,7 @@ const Torch::BuildKernelMetadata &CeilOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::ceil";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::ceil_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -327,6 +333,7 @@ const Torch::BuildKernelMetadata &ConjOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::conj";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::conj_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -346,6 +353,7 @@ const Torch::BuildKernelMetadata &CosOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::cos";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::cos_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -365,6 +373,7 @@ const Torch::BuildKernelMetadata &CoshOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::cosh";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::cosh_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -384,6 +393,7 @@ const Torch::BuildKernelMetadata &DigammaOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::digamma";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::digamma_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -403,6 +413,7 @@ const Torch::BuildKernelMetadata &ErfOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::erf";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::erf_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -422,6 +433,7 @@ const Torch::BuildKernelMetadata &ErfcOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::erfc";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::erfc_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -441,6 +453,7 @@ const Torch::BuildKernelMetadata &ErfinvOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::erfinv";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::erfinv_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -460,6 +473,7 @@ const Torch::BuildKernelMetadata &ExpOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::exp";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::exp_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -479,6 +493,7 @@ const Torch::BuildKernelMetadata &Expm1Op::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::expm1";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::expm1_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -498,6 +513,7 @@ const Torch::BuildKernelMetadata &FloorOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::floor";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::floor_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -517,6 +533,7 @@ const Torch::BuildKernelMetadata &FracOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::frac";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::frac_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -536,6 +553,7 @@ const Torch::BuildKernelMetadata &LgammaOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::lgamma";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::lgamma_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -555,6 +573,7 @@ const Torch::BuildKernelMetadata &LogOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::log";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::log_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -574,6 +593,7 @@ const Torch::BuildKernelMetadata &Log10Op::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::log10";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::log10_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -593,6 +613,7 @@ const Torch::BuildKernelMetadata &Log1pOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::log1p";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::log1p_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -612,6 +633,7 @@ const Torch::BuildKernelMetadata &Log2Op::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::log2";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::log2_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -631,6 +653,7 @@ const Torch::BuildKernelMetadata &NegOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::neg";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::neg_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -650,6 +673,7 @@ const Torch::BuildKernelMetadata &ReluOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::relu";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::relu_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -669,6 +693,7 @@ const Torch::BuildKernelMetadata &ReciprocalOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::reciprocal";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::reciprocal_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -688,6 +713,7 @@ const Torch::BuildKernelMetadata &RoundOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::round";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::round_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -707,6 +733,7 @@ const Torch::BuildKernelMetadata &RsqrtOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::rsqrt";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::rsqrt_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -726,6 +753,7 @@ const Torch::BuildKernelMetadata &SigmoidOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::sigmoid";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::sigmoid_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -745,6 +773,7 @@ const Torch::BuildKernelMetadata &SignOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::sign";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::sign_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -764,6 +793,7 @@ const Torch::BuildKernelMetadata &SinOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::sin";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::sin_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -783,6 +813,7 @@ const Torch::BuildKernelMetadata &SinhOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::sinh";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::sinh_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -802,6 +833,7 @@ const Torch::BuildKernelMetadata &SqrtOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::sqrt";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::sqrt_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -821,6 +853,7 @@ const Torch::BuildKernelMetadata &TanOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::tan";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::tan_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -840,6 +873,7 @@ const Torch::BuildKernelMetadata &TanhOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::tanh";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::tanh_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -859,6 +893,7 @@ const Torch::BuildKernelMetadata &TruncOp::getTorchBuildKernelMetadata() {
|
|||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::trunc";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.inplaceVariantKernelName = "aten::trunc_";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
|
@ -949,6 +984,63 @@ const Torch::BuildKernelMetadata &LinearOp::getTorchBuildKernelMetadata() {
|
|||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata BatchNormOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &BatchNormOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::batch_norm";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor?", "Tensor?", "Tensor?", "Tensor?", "bool", "float", "float", "bool"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata MaxPool2dOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &MaxPool2dOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::max_pool2d";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "int[]", "int[]", "int[]", "int[]", "bool"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata AdaptiveAvgPool2dOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &AdaptiveAvgPool2dOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::adaptive_avg_pool2d";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "int[]"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kNone});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata ConvolutionOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
|
|
@ -525,6 +525,50 @@ def aten_LinearOp: aten_Op<"linear", [NoSideEffect, DeclareOpInterfaceMethods<To
|
|||
);
|
||||
}
|
||||
|
||||
def aten_BatchNormOp: aten_Op<"batch_norm", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
|
||||
let summary = "Recognized op for kernel aten::batch_norm";
|
||||
let arguments = (ins
|
||||
AnyTorchImmutableTensor:$input,
|
||||
AnyTorchOptionalImmutableTensor:$weight,
|
||||
AnyTorchOptionalImmutableTensor:$bias,
|
||||
AnyTorchOptionalImmutableTensor:$running_mean,
|
||||
AnyTorchOptionalImmutableTensor:$running_var,
|
||||
AnyTorchBoolType:$training,
|
||||
AnyFloat:$momentum,
|
||||
AnyFloat:$eps,
|
||||
AnyTorchBoolType:$cudnn_enabled
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchImmutableTensor
|
||||
);
|
||||
}
|
||||
|
||||
def aten_MaxPool2dOp: aten_Op<"max_pool2d", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
|
||||
let summary = "Recognized op for kernel aten::max_pool2d";
|
||||
let arguments = (ins
|
||||
AnyTorchImmutableTensor:$self,
|
||||
AnyTorchIntListType:$kernel_size,
|
||||
AnyTorchIntListType:$stride,
|
||||
AnyTorchIntListType:$padding,
|
||||
AnyTorchIntListType:$dilation,
|
||||
AnyTorchBoolType:$ceil_mode
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchImmutableTensor
|
||||
);
|
||||
}
|
||||
|
||||
def aten_AdaptiveAvgPool2dOp: aten_Op<"adaptive_avg_pool2d", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
|
||||
let summary = "Recognized op for kernel aten::adaptive_avg_pool2d";
|
||||
let arguments = (ins
|
||||
AnyTorchImmutableTensor:$self,
|
||||
AnyTorchIntListType:$output_size
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchImmutableTensor
|
||||
);
|
||||
}
|
||||
|
||||
def aten_ConvolutionOp: aten_Op<"convolution", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
|
||||
let summary = "Recognized op for kernel aten::convolution_overrideable";
|
||||
let arguments = (ins
|
||||
|
@ -691,3 +735,4 @@ def aten_CopyInplaceOp: aten_Op<"copy.inplace", [DeclareOpInterfaceMethods<Torch
|
|||
let results = (outs
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -406,22 +406,6 @@ def aten_HardtanhUnderOp: aten_Op<"hardtanh_", [NoSideEffect, StatisticsOpInterf
|
|||
}];
|
||||
}
|
||||
|
||||
def aten_AdaptiveAvgPool2dOp: aten_Op<"_adaptive_avg_pool2d", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyType:$output_size
|
||||
);
|
||||
let summary = "aten _adaptive_avg_pool2d operator";
|
||||
let description = [{
|
||||
AdaptiveAvgPool2dOp
|
||||
aten _adaptive_avg_pool2d operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_AdaptiveAvgPool2dBackwardOp: aten_Op<"_adaptive_avg_pool2d_backward", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
|
|
|
@ -40,13 +40,6 @@ namespace mlir {
|
|||
namespace NPCOMP {
|
||||
namespace aten {
|
||||
|
||||
std::map<std::string, uint64_t> AdaptiveAvgPool2dOp::getStatistics() {
|
||||
std::map<std::string, uint64_t> toReturn;
|
||||
// FIXME: unimplemented
|
||||
toReturn["reads"] = -1;
|
||||
toReturn["writes"] = -1;
|
||||
return toReturn;
|
||||
}
|
||||
std::map<std::string, uint64_t> AdaptiveAvgPool2dBackwardOp::getStatistics() {
|
||||
std::map<std::string, uint64_t> toReturn;
|
||||
// FIXME: unimplemented
|
||||
|
@ -130,46 +123,6 @@ std::map<std::string, uint64_t> AsStridedOp::getStatistics() {
|
|||
return toReturn;
|
||||
}
|
||||
|
||||
// batch_norm
|
||||
std::map<std::string, uint64_t> BatchNormOp::getStatistics() {
|
||||
|
||||
std::map<std::string, uint64_t> toReturn;
|
||||
|
||||
TensorType resultTy = getResult(0).getType().cast<TensorType>();
|
||||
uint64_t op_volume = getTensorVolume(resultTy);
|
||||
uint64_t weight_volume = getTensorVolume(getOperand(1).getType());
|
||||
uint64_t bias_volume = getTensorVolume(getOperand(2).getType());
|
||||
toReturn["operand:0:activation_in"] = op_volume;
|
||||
toReturn["result:0:activation_out"] = op_volume;
|
||||
toReturn["operand:1:parameters_in:weight"] = weight_volume;
|
||||
toReturn["operand:2:parameters_in:bias"] = bias_volume;
|
||||
|
||||
// Now for the arithmetic. Assume variance is calculated as sum of squares
|
||||
uint64_t ifm_depth = resultTy.getShape()[1];
|
||||
|
||||
toReturn["ops:+"] = op_volume; // Add up for mean
|
||||
toReturn["ops:*"] = op_volume; // Square for variance
|
||||
toReturn["ops:+"] += op_volume; // Add up squares for variance
|
||||
|
||||
toReturn["ops:*"] += ifm_depth; // Calc channel means
|
||||
toReturn["ops:-"] += ifm_depth; // Calc channel vars
|
||||
toReturn["ops:*"] += ifm_depth; // Calc channel vars
|
||||
|
||||
toReturn["ops:sqrt"] = ifm_depth; // Convert to SD
|
||||
toReturn["ops:/"] = ifm_depth; // Get the reciprocal
|
||||
|
||||
toReturn["ops:+"] += op_volume; // Subtract mean off each pixel
|
||||
toReturn["ops:*"] += op_volume; // Multiply by 1/SD for each pixel
|
||||
|
||||
toReturn["ops:+"] += op_volume; // Bias
|
||||
toReturn["ops:*"] += op_volume; // Scale
|
||||
|
||||
toReturn["reads"] = op_volume + weight_volume + bias_volume;
|
||||
toReturn["writes"] = op_volume;
|
||||
|
||||
return toReturn;
|
||||
}
|
||||
|
||||
// div_
|
||||
std::map<std::string, uint64_t> DivUnderOp::getStatistics() {
|
||||
|
||||
|
@ -266,33 +219,6 @@ std::map<std::string, uint64_t> HardtanhBackwardOp::getStatistics() {
|
|||
return toReturn;
|
||||
}
|
||||
|
||||
// max_pool2d
|
||||
std::map<std::string, uint64_t> MaxPool2dOp::getStatistics() {
|
||||
|
||||
std::map<std::string, uint64_t> toReturn;
|
||||
|
||||
TensorType resultTy = getResult().getType().cast<TensorType>();
|
||||
TensorType inputType = getOperand(0).getType().cast<TensorType>();
|
||||
|
||||
uint64_t ofm_volume = getTensorVolume(resultTy);
|
||||
toReturn["result:0:activation_out"] = ofm_volume;
|
||||
|
||||
uint64_t ifm_volume = getTensorVolume(inputType);
|
||||
toReturn["input:0:activation_in"] = ifm_volume;
|
||||
|
||||
// To find the number of compares, we need the filter extent
|
||||
|
||||
std::vector<uint64_t> kernel_size = unpackListConstant(getOperand(1));
|
||||
|
||||
uint64_t aperture = kernel_size[0] * kernel_size[1];
|
||||
toReturn["ops:>"] = ofm_volume * (aperture - 1);
|
||||
|
||||
toReturn["reads"] = ifm_volume;
|
||||
toReturn["writes"] = ofm_volume;
|
||||
|
||||
return toReturn;
|
||||
}
|
||||
|
||||
// max_pool2d_with_indices
|
||||
std::map<std::string, uint64_t> MaxPool2dWithIndicesOp::getStatistics() {
|
||||
|
||||
|
|
|
@ -162,7 +162,9 @@ public:
|
|||
ChangeResult
|
||||
visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
|
||||
if (isa<Numpy::TensorStaticInfoCastOp, aten::TanhOp>(op)) {
|
||||
if (isa<Numpy::TensorStaticInfoCastOp, Numpy::CopyToTensorOp,
|
||||
Numpy::CreateArrayFromTensorOp, aten::TanhOp, aten::BatchNormOp,
|
||||
aten::ReluOp>(op)) {
|
||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||
}
|
||||
if (isa<aten::MmOp>(op)) {
|
||||
|
@ -214,6 +216,52 @@ public:
|
|||
joinElementTypes(operands[1]->getValue().elementType,
|
||||
operands[2]->getValue().elementType));
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
} else if (isa<aten::Conv2dOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
knowledge.hasRank = true;
|
||||
knowledge.sizes.resize(4, kUnknownSize);
|
||||
// Running some experiments in PyTorch, the bias doesn't seem to
|
||||
// contribute to the final element type.
|
||||
knowledge.elementType =
|
||||
joinElementTypes(operands[0]->getValue().elementType,
|
||||
operands[1]->getValue().elementType);
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
} else if (isa<aten::MaxPool2dOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
knowledge.hasRank = true;
|
||||
knowledge.sizes.resize(4, kUnknownSize);
|
||||
knowledge.elementType = operands[0]->getValue().elementType;
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
} else if (isa<aten::AdaptiveAvgPool2dOp>(op)) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
if (input.hasRank) {
|
||||
knowledge.hasRank = true;
|
||||
knowledge.sizes.resize(input.sizes.size(), kUnknownSize);
|
||||
}
|
||||
knowledge.elementType = input.elementType;
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
} else if (isa<aten::AddOp>(op)) {
|
||||
// This is a general binary broadcasting shape transfer function.
|
||||
// We currently don't track "size 1" in our lattice, but we might want to.
|
||||
// We could make this more precise as well. But again, as with the other
|
||||
// shape transfer functions, handling the statically-invalid case is
|
||||
// tricky, so we defer that until we need it.
|
||||
auto lhs = operands[0]->getValue();
|
||||
auto rhs = operands[1]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
if (lhs.hasRank && rhs.hasRank) {
|
||||
knowledge.hasRank = true;
|
||||
knowledge.sizes.resize(std::max(lhs.sizes.size(), rhs.sizes.size()),
|
||||
kUnknownSize);
|
||||
}
|
||||
knowledge.elementType =
|
||||
joinElementTypes(lhs.elementType, rhs.elementType);
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
}
|
||||
// Otherwise, this is an unknown operation. Just mark all results as having
|
||||
// reached a pessimistic fixpoint.
|
||||
|
|
|
@ -1,24 +0,0 @@
|
|||
// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
|
||||
// CHECK-LABEL: "L0-batch_norm-0": {
|
||||
// CHECK-NEXT: "activation_in": 103320,
|
||||
// CHECK-NEXT: "activation_out": 103320,
|
||||
// CHECK-NEXT: "ops:*": 310206,
|
||||
// CHECK-NEXT: "ops:+": 413280,
|
||||
// CHECK-NEXT: "ops:-": 123,
|
||||
// CHECK-NEXT: "ops:/": 123,
|
||||
// CHECK-NEXT: "ops:sqrt": 123,
|
||||
// CHECK-NEXT: "parameters_in": 246,
|
||||
// CHECK-NEXT: "reads": 103566,
|
||||
// CHECK-NEXT: "writes": 103320
|
||||
|
||||
module {
|
||||
func @graph(%arg0: tensor<42x123x4x5xf32>, %arg1: tensor<123xf32>, %arg2: tensor<123xf32>, %arg3: tensor<123xf32>, %arg4: tensor<123xf32>, %arg5: tensor<?xi64>) -> tensor<42x123x4x5xf32> {
|
||||
%0 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
|
||||
%1 = "aten.constant"() {type = "f32", value = 1.000000e-01 : f32} : () -> f32
|
||||
%2 = "aten.constant"() {type = "f32", value = 9.99999974E-6 : f32} : () -> f32
|
||||
%3 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
|
||||
%4:3 = "aten.batch_norm"(%arg0, %arg1, %arg2, %arg3, %arg4, %0, %1, %2, %3) : (tensor<42x123x4x5xf32>, tensor<123xf32>, tensor
|
||||
<123xf32>, tensor<123xf32>, tensor<123xf32>, i1, f32, f32, i1) -> (tensor<42x123x4x5xf32>, tensor<123xf32>, tensor<123xf32>)
|
||||
return %4#0 : tensor<42x123x4x5xf32>
|
||||
}
|
||||
}
|
|
@ -1,19 +0,0 @@
|
|||
// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
|
||||
// CHECK-LABEL: "L0-max_pool2d-0": {
|
||||
// CHECK-NEXT: "activation_in": 8192,
|
||||
// CHECK-NEXT: "activation_out": 2048,
|
||||
// CHECK-NEXT: "ops:>": 16384,
|
||||
// CHECK-NEXT: "reads": 8192,
|
||||
// CHECK-NEXT: "writes": 2048
|
||||
|
||||
module {
|
||||
func @graph(%arg0: tensor<1x32x16x16xf32>) -> tensor<1x32x8x8xf32> {
|
||||
%0 = "aten.constant"() {type = "List[i32]", value = dense<3> : vector<2xi64>} : () -> !aten.list<i32>
|
||||
%1 = "aten.constant"() {type = "List[i32]", value = dense<2> : vector<2xi64>} : () -> !aten.list<i32>
|
||||
%2 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi64>} : () -> !aten.list<i32>
|
||||
%3 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi64>} : () -> !aten.list<i32>
|
||||
%4 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
|
||||
%5 = "aten.max_pool2d"(%arg0, %0, %1, %2, %3, %4) : (tensor<1x32x16x16xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>, i1) -> tensor<1x32x8x8xf32>
|
||||
"std.return"(%5) : (tensor<1x32x8x8xf32>) -> ()
|
||||
}
|
||||
}
|
|
@ -51,6 +51,76 @@ func @f(%arg0: tensor<?x3xf32>, %arg1: tensor<5x3xf32>, %arg2: tensor<5xf32>) ->
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @f
|
||||
// CHECK: %[[CONV2D:.*]] = "aten.conv2d"{{.*}} -> tensor<?x?x?x?x!numpy.any_dtype>
|
||||
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[CONV2D]] : tensor<?x?x?x?x!numpy.any_dtype> to tensor<*x!numpy.any_dtype>
|
||||
// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
|
||||
func @f(%arg0:tensor<*x!numpy.any_dtype>, %arg1:tensor<*x!numpy.any_dtype>, %arg2:tensor<*x!numpy.any_dtype>) ->tensor<*x!numpy.any_dtype> {
|
||||
%c0_i64 = constant 0 : i64
|
||||
%c1_i64 = constant 1 : i64
|
||||
%0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%3 = "aten.conv2d"(%arg0, %arg1, %arg2, %0, %1, %2, %c1_i64) : (tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) ->tensor<*x!numpy.any_dtype>
|
||||
return %3 :tensor<*x!numpy.any_dtype>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @g
|
||||
// CHECK: %[[CONV2D:.*]] = "aten.conv2d"{{.*}} -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[CONV2D]] : tensor<?x?x?x?xf32> to tensor<*x!numpy.any_dtype>
|
||||
// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
|
||||
func @g(%arg0:tensor<*xf32>, %arg1:tensor<*xf32>, %arg2:tensor<*xf32>) ->tensor<*x!numpy.any_dtype> {
|
||||
%c0_i64 = constant 0 : i64
|
||||
%c1_i64 = constant 1 : i64
|
||||
%0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%3 = "aten.conv2d"(%arg0, %arg1, %arg2, %0, %1, %2, %c1_i64) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) ->tensor<*x!numpy.any_dtype>
|
||||
return %3 :tensor<*x!numpy.any_dtype>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @f
|
||||
func @f(%arg0: tensor<?x?x?x?xf32>) -> tensor<*x!numpy.any_dtype> {
|
||||
%c1_i64 = constant 1 : i64
|
||||
%c3_i64 = constant 3 : i64
|
||||
%c2_i64 = constant 2 : i64
|
||||
%bool_false = basicpy.bool_constant false
|
||||
%21 = basicpy.build_list %c3_i64, %c3_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%22 = basicpy.build_list %c2_i64, %c2_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%23 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%24 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
// CHECK: "aten.max_pool2d"{{.*}} -> tensor<?x?x?x?xf32>
|
||||
%27 = "aten.max_pool2d"(%arg0, %21, %22, %23, %24, %bool_false) : (tensor<?x?x?x?xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.BoolType) -> tensor<*x!numpy.any_dtype>
|
||||
return %27 : tensor<*x!numpy.any_dtype>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @f
|
||||
func @f(%arg0: tensor<?x?x?x?xf32>) -> tensor<*x!numpy.any_dtype> {
|
||||
%c1_i64 = constant 1 : i64
|
||||
%0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
// CHECK: "aten.adaptive_avg_pool2d"{{.*}} -> tensor<?x?x?x?xf32>
|
||||
%1 = "aten.adaptive_avg_pool2d"(%arg0, %0) : (tensor<?x?x?x?xf32>, !basicpy.ListType) -> tensor<*x!numpy.any_dtype>
|
||||
return %1 : tensor<*x!numpy.any_dtype>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @f
|
||||
func @f(%arg0: tensor<4x6x3xf32>, %arg1: tensor<1x1x3xf32>, %arg2: tensor<?x3xf32>) {
|
||||
%c1_i64 = constant 1 : i64
|
||||
// CHECK: "aten.add"{{.*}} -> tensor<?x?x?xf32>
|
||||
%0 = "aten.add"(%arg0, %arg1, %c1_i64) : (tensor<4x6x3xf32>, tensor<1x1x3xf32>, i64) -> tensor<*x!numpy.any_dtype>
|
||||
// CHECK: "aten.add"{{.*}} -> tensor<?x?x?xf32>
|
||||
%1 = "aten.add"(%arg0, %arg2, %c1_i64) : (tensor<4x6x3xf32>, tensor<?x3xf32>, i64) -> tensor<*x!numpy.any_dtype>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @f
|
||||
func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
|
||||
// Check propagation through multiple ops.
|
||||
|
|
Loading…
Reference in New Issue