mirror of https://github.com/llvm/torch-mlir
1080 lines
36 KiB
C++
1080 lines
36 KiB
C++
//===-------------------------------------------------------------*- cc -*-===//
|
|
//
|
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
// Operation summaries and descriptions were systematically derived from public
|
|
// API docstrings and are licensed accordingly:
|
|
// https://github.com/pytorch/pytorch/blob/master/LICENSE
|
|
//===----------------------------------------------------------------------===//
|
|
// This file is automatically generated. Please do not edit.
|
|
// Generated via:
|
|
// python -m torch_mlir_utils.codegen.torch_signature_ods_gen
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
// clang-format off
|
|
// -----------------------------------------------------------------------------
|
|
// Binary arithmetic ops
|
|
// -----------------------------------------------------------------------------
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata AddOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::add" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &AddOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::add";
    // Two tensors (the second also flagged kPromoteScalar) plus a trailing
    // Scalar argument that gets no conversion (kNone).
    md.addArgTypes({"Tensor", "Tensor", "Scalar"});
    md.addArgConversions({KVC::kImmutableTensor,
                          KVC::kImmutableTensor | KVC::kPromoteScalar,
                          KVC::kNone});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata Atan2Op::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::atan2" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &Atan2Op::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::atan2";
    // Two tensor operands; the second is additionally flagged kPromoteScalar.
    md.addArgTypes({"Tensor", "Tensor"});
    md.addArgConversions(
        {KVC::kImmutableTensor, KVC::kImmutableTensor | KVC::kPromoteScalar});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata DivOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::div" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &DivOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::div";
    // Two tensor operands; the second is additionally flagged kPromoteScalar.
    md.addArgTypes({"Tensor", "Tensor"});
    md.addArgConversions(
        {KVC::kImmutableTensor, KVC::kImmutableTensor | KVC::kPromoteScalar});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata FloorDivideOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::floor_divide" kernel on the first call and
// caches it in a function-local static; every call returns a reference to it.
const Torch::BuildKernelMetadata &FloorDivideOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::floor_divide";
    // Two tensor operands; the second is additionally flagged kPromoteScalar.
    md.addArgTypes({"Tensor", "Tensor"});
    md.addArgConversions(
        {KVC::kImmutableTensor, KVC::kImmutableTensor | KVC::kPromoteScalar});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata MulOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::mul" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &MulOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::mul";
    // Two tensor operands; the second is additionally flagged kPromoteScalar.
    md.addArgTypes({"Tensor", "Tensor"});
    md.addArgConversions(
        {KVC::kImmutableTensor, KVC::kImmutableTensor | KVC::kPromoteScalar});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata RemainderOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::remainder" kernel on the first call and
// caches it in a function-local static; every call returns a reference to it.
const Torch::BuildKernelMetadata &RemainderOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::remainder";
    // Two tensor operands; the second is additionally flagged kPromoteScalar.
    md.addArgTypes({"Tensor", "Tensor"});
    md.addArgConversions(
        {KVC::kImmutableTensor, KVC::kImmutableTensor | KVC::kPromoteScalar});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata TrueDivideOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::true_divide" kernel on the first call and
// caches it in a function-local static; every call returns a reference to it.
const Torch::BuildKernelMetadata &TrueDivideOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::true_divide";
    // Two tensor operands; the second is additionally flagged kPromoteScalar.
    md.addArgTypes({"Tensor", "Tensor"});
    md.addArgConversions(
        {KVC::kImmutableTensor, KVC::kImmutableTensor | KVC::kPromoteScalar});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata MaximumOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::maximum" kernel on the first call and
// caches it in a function-local static; every call returns a reference to it.
const Torch::BuildKernelMetadata &MaximumOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::maximum";
    // Two tensor operands; the second is additionally flagged kPromoteScalar.
    md.addArgTypes({"Tensor", "Tensor"});
    md.addArgConversions(
        {KVC::kImmutableTensor, KVC::kImmutableTensor | KVC::kPromoteScalar});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata MinimumOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::minimum" kernel on the first call and
// caches it in a function-local static; every call returns a reference to it.
const Torch::BuildKernelMetadata &MinimumOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::minimum";
    // Two tensor operands; the second is additionally flagged kPromoteScalar.
    md.addArgTypes({"Tensor", "Tensor"});
    md.addArgConversions(
        {KVC::kImmutableTensor, KVC::kImmutableTensor | KVC::kPromoteScalar});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// -----------------------------------------------------------------------------
|
|
// Unary arithmetic ops
|
|
// -----------------------------------------------------------------------------
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata AbsOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::abs" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &AbsOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::abs";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata AcosOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::acos" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &AcosOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::acos";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata AngleOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::angle" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &AngleOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::angle";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata AsinOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::asin" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &AsinOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::asin";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata AtanOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::atan" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &AtanOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::atan";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata CeilOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::ceil" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &CeilOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::ceil";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata ConjOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::conj" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &ConjOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::conj";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata CosOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::cos" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &CosOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::cos";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata CoshOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::cosh" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &CoshOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::cosh";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata DigammaOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::digamma" kernel on the first call and
// caches it in a function-local static; every call returns a reference to it.
const Torch::BuildKernelMetadata &DigammaOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::digamma";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata ErfOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::erf" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &ErfOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::erf";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata ErfcOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::erfc" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &ErfcOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::erfc";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata ErfinvOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::erfinv" kernel on the first call and
// caches it in a function-local static; every call returns a reference to it.
const Torch::BuildKernelMetadata &ErfinvOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::erfinv";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata ExpOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::exp" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &ExpOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::exp";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata Expm1Op::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::expm1" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &Expm1Op::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::expm1";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata FloorOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::floor" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &FloorOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::floor";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata FracOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::frac" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &FracOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::frac";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata LgammaOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::lgamma" kernel on the first call and
// caches it in a function-local static; every call returns a reference to it.
const Torch::BuildKernelMetadata &LgammaOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::lgamma";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata LogOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::log" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &LogOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::log";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata Log10Op::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::log10" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &Log10Op::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::log10";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata Log1pOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::log1p" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &Log1pOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::log1p";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata Log2Op::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::log2" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &Log2Op::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::log2";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata NegOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::neg" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &NegOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::neg";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata ReluOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::relu" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &ReluOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::relu";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata ReciprocalOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::reciprocal" kernel on the first call and
// caches it in a function-local static; every call returns a reference to it.
const Torch::BuildKernelMetadata &ReciprocalOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::reciprocal";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata RoundOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::round" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &RoundOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::round";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata RsqrtOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::rsqrt" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &RsqrtOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::rsqrt";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata SigmoidOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::sigmoid" kernel on the first call and
// caches it in a function-local static; every call returns a reference to it.
const Torch::BuildKernelMetadata &SigmoidOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::sigmoid";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata SignOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::sign" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &SignOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::sign";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata SinOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::sin" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &SinOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::sin";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata SinhOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::sinh" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &SinhOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::sinh";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata SqrtOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::sqrt" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &SqrtOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::sqrt";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata TanOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::tan" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &TanOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::tan";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata TanhOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::tanh" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &TanhOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::tanh";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata TruncOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::trunc" kernel on the first call and caches
// it in a function-local static; every call returns a reference to that object.
const Torch::BuildKernelMetadata &TruncOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::trunc";
    // A single tensor operand and a single tensor result.
    md.addArgTypes({"Tensor"});
    md.addArgConversions({KVC::kImmutableTensor});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// -----------------------------------------------------------------------------
|
|
// NN ops
|
|
// -----------------------------------------------------------------------------
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata ConvolutionOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::convolution_overrideable" kernel (with
// "aten::convolution" recorded as an alias kernel name) on the first call and
// caches it in a function-local static; every call returns a reference to it.
const Torch::BuildKernelMetadata &ConvolutionOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::convolution_overrideable";
    md.aliasKernelNames.push_back("aten::convolution");
    // Nine arguments: three tensor operands (the third typed "Tensor?") that
    // use kImmutableTensor, then six non-tensor arguments marked kNone.
    // NOTE(review): presumably (input, weight, bias, stride, padding, dilation,
    // transposed, output_padding, groups) per the ATen schema -- confirm.
    md.addArgTypes({"Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]",
                    "bool", "int[]", "int"});
    md.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor,
                          KVC::kImmutableTensor, KVC::kNone, KVC::kNone,
                          KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata Conv2dOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::conv2d" kernel on the first call and
// caches it in a function-local static; every call returns a reference to it.
const Torch::BuildKernelMetadata &Conv2dOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::conv2d";
    // Seven arguments: three tensor operands (the third typed "Tensor?") that
    // use kImmutableTensor, then four non-tensor arguments marked kNone.
    // NOTE(review): presumably (input, weight, bias, stride, padding, dilation,
    // groups) per the ATen schema -- confirm.
    md.addArgTypes(
        {"Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "int"});
    md.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor,
                          KVC::kImmutableTensor, KVC::kNone, KVC::kNone,
                          KVC::kNone, KVC::kNone});
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
// Returns this op's kernel metadata as a Torch::KernelMetadata value, copied
// out of the cached object returned by getTorchBuildKernelMetadata().
Torch::KernelMetadata ConvolutionBackwardOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &built = getTorchBuildKernelMetadata();
  return built;
}
|
|
|
|
// Builds the metadata for the "aten::convolution_backward_overrideable" kernel
// (with "aten::convolution_backward" recorded as an alias kernel name) on the
// first call and caches it in a function-local static; every call returns a
// reference to that object.
const Torch::BuildKernelMetadata &ConvolutionBackwardOp::getTorchBuildKernelMetadata() {
  static Torch::BuildKernelMetadata cached = [] {
    using KVC = Torch::KernelValueConversion::BitMask;
    Torch::BuildKernelMetadata md;
    md.promoteTrailingOutTensor = true;
    md.kernelName = "aten::convolution_backward_overrideable";
    md.aliasKernelNames.push_back("aten::convolution_backward");
    // Ten arguments: three tensor operands using kImmutableTensor, then seven
    // non-tensor arguments (ending in a "bool[]") marked kNone.
    md.addArgTypes({"Tensor", "Tensor", "Tensor", "int[]", "int[]", "int[]",
                    "bool", "int[]", "int", "bool[]"});
    md.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor,
                          KVC::kImmutableTensor, KVC::kNone, KVC::kNone,
                          KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone,
                          KVC::kNone});
    // Three results, each typed "Tensor?".
    md.addReturnTypes({"Tensor?", "Tensor?", "Tensor?"});
    md.addReturnConversions({KVC::kImmutableTensor, KVC::kImmutableTensor,
                             KVC::kImmutableTensor});
    return md;
  }();
  return cached;
}
|
|
|
|
/// Returns this op's kernel metadata by converting the cached
/// Torch::BuildKernelMetadata into the Torch::KernelMetadata return type.
Torch::KernelMetadata LogSoftmaxOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &buildMetadata =
      getTorchBuildKernelMetadata();
  return buildMetadata;
}

/// Builds, once, the registration metadata for `aten::_log_softmax`. The
/// function-local static is initialized on the first call; later calls return
/// the same object.
const Torch::BuildKernelMetadata &LogSoftmaxOp::getTorchBuildKernelMetadata() {
  using KVC = Torch::KernelValueConversion::BitMask;
  static Torch::BuildKernelMetadata metadata = [] {
    Torch::BuildKernelMetadata md;
    md.kernelName = "aten::_log_softmax";
    md.promoteTrailingOutTensor = true;

    // Operands: Tensor, int, bool. Only the tensor carries a conversion.
    md.addArgTypes({"Tensor", "int", "bool"});
    md.addArgConversions({KVC::kImmutableTensor, KVC::kNone, KVC::kNone});

    // One tensor result.
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return metadata;
}

/// Returns this op's kernel metadata by converting the cached
/// Torch::BuildKernelMetadata into the Torch::KernelMetadata return type.
Torch::KernelMetadata LogSoftmaxBackwardDataOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &buildMetadata =
      getTorchBuildKernelMetadata();
  return buildMetadata;
}

/// Builds, once, the registration metadata for
/// `aten::_log_softmax_backward_data`. The function-local static is
/// initialized on the first call; later calls return the same object.
const Torch::BuildKernelMetadata &LogSoftmaxBackwardDataOp::getTorchBuildKernelMetadata() {
  using KVC = Torch::KernelValueConversion::BitMask;
  static Torch::BuildKernelMetadata metadata = [] {
    Torch::BuildKernelMetadata md;
    md.kernelName = "aten::_log_softmax_backward_data";
    md.promoteTrailingOutTensor = true;

    // Operands: Tensor, Tensor, int, Tensor. Every tensor operand (including
    // the trailing one) is tagged kImmutableTensor; the int is not converted.
    md.addArgTypes({"Tensor", "Tensor", "int", "Tensor"});
    md.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor,
                          KVC::kNone, KVC::kImmutableTensor});

    // One tensor result.
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return metadata;
}

/// Returns this op's kernel metadata by converting the cached
/// Torch::BuildKernelMetadata into the Torch::KernelMetadata return type.
Torch::KernelMetadata MmOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &buildMetadata =
      getTorchBuildKernelMetadata();
  return buildMetadata;
}

/// Builds, once, the registration metadata for `aten::mm`. The function-local
/// static is initialized on the first call; later calls return the same
/// object.
const Torch::BuildKernelMetadata &MmOp::getTorchBuildKernelMetadata() {
  using KVC = Torch::KernelValueConversion::BitMask;
  static Torch::BuildKernelMetadata metadata = [] {
    Torch::BuildKernelMetadata md;
    md.kernelName = "aten::mm";
    md.promoteTrailingOutTensor = true;

    // Two tensor operands, both tagged kImmutableTensor.
    md.addArgTypes({"Tensor", "Tensor"});
    md.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor});

    // One tensor result.
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return metadata;
}

// -----------------------------------------------------------------------------
// Loss function ops
// -----------------------------------------------------------------------------

/// Returns this op's kernel metadata by converting the cached
/// Torch::BuildKernelMetadata into the Torch::KernelMetadata return type.
Torch::KernelMetadata NllLossForwardOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &buildMetadata =
      getTorchBuildKernelMetadata();
  return buildMetadata;
}

/// Builds, once, the registration metadata for `aten::nll_loss_forward`. The
/// function-local static is initialized on the first call; later calls return
/// the same object.
const Torch::BuildKernelMetadata &NllLossForwardOp::getTorchBuildKernelMetadata() {
  using KVC = Torch::KernelValueConversion::BitMask;
  static Torch::BuildKernelMetadata metadata = [] {
    Torch::BuildKernelMetadata md;
    md.kernelName = "aten::nll_loss_forward";
    md.promoteTrailingOutTensor = true;

    // Operands: Tensor, Tensor, Tensor?, int, int.
    // Only the three tensor operands carry a conversion (kImmutableTensor).
    md.addArgTypes({"Tensor", "Tensor", "Tensor?", "int", "int"});
    md.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor,
                          KVC::kImmutableTensor, KVC::kNone, KVC::kNone});

    // Two tensor results.
    md.addReturnTypes({"Tensor", "Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor, KVC::kImmutableTensor});
    return md;
  }();
  return metadata;
}

/// Returns this op's kernel metadata by converting the cached
/// Torch::BuildKernelMetadata into the Torch::KernelMetadata return type.
Torch::KernelMetadata NllLossBackwardOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &buildMetadata =
      getTorchBuildKernelMetadata();
  return buildMetadata;
}

/// Builds, once, the registration metadata for `aten::nll_loss_backward`. The
/// function-local static is initialized on the first call; later calls return
/// the same object.
const Torch::BuildKernelMetadata &NllLossBackwardOp::getTorchBuildKernelMetadata() {
  using KVC = Torch::KernelValueConversion::BitMask;
  static Torch::BuildKernelMetadata metadata = [] {
    Torch::BuildKernelMetadata md;
    md.kernelName = "aten::nll_loss_backward";
    md.promoteTrailingOutTensor = true;

    // Operands: Tensor, Tensor, Tensor, Tensor?, int, int, Tensor.
    // Every tensor operand (including the trailing one) is tagged
    // kImmutableTensor; the two ints are not converted.
    md.addArgTypes({"Tensor", "Tensor", "Tensor", "Tensor?",
                    "int", "int", "Tensor"});
    md.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor,
                          KVC::kImmutableTensor, KVC::kImmutableTensor,
                          KVC::kNone, KVC::kNone, KVC::kImmutableTensor});

    // One tensor result.
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return metadata;
}

/// Returns this op's kernel metadata by converting the cached
/// Torch::BuildKernelMetadata into the Torch::KernelMetadata return type.
Torch::KernelMetadata NllLoss2dForwardOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &buildMetadata =
      getTorchBuildKernelMetadata();
  return buildMetadata;
}

/// Builds, once, the registration metadata for `aten::nll_loss2d_forward`. The
/// function-local static is initialized on the first call; later calls return
/// the same object.
const Torch::BuildKernelMetadata &NllLoss2dForwardOp::getTorchBuildKernelMetadata() {
  using KVC = Torch::KernelValueConversion::BitMask;
  static Torch::BuildKernelMetadata metadata = [] {
    Torch::BuildKernelMetadata md;
    md.kernelName = "aten::nll_loss2d_forward";
    md.promoteTrailingOutTensor = true;

    // Operands: Tensor, Tensor, Tensor?, int, int.
    // Only the three tensor operands carry a conversion (kImmutableTensor).
    md.addArgTypes({"Tensor", "Tensor", "Tensor?", "int", "int"});
    md.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor,
                          KVC::kImmutableTensor, KVC::kNone, KVC::kNone});

    // Two tensor results.
    md.addReturnTypes({"Tensor", "Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor, KVC::kImmutableTensor});
    return md;
  }();
  return metadata;
}

/// Returns this op's kernel metadata by converting the cached
/// Torch::BuildKernelMetadata into the Torch::KernelMetadata return type.
Torch::KernelMetadata NllLoss2dBackwardOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &buildMetadata =
      getTorchBuildKernelMetadata();
  return buildMetadata;
}

/// Builds, once, the registration metadata for `aten::nll_loss2d_backward`.
/// The function-local static is initialized on the first call; later calls
/// return the same object.
const Torch::BuildKernelMetadata &NllLoss2dBackwardOp::getTorchBuildKernelMetadata() {
  using KVC = Torch::KernelValueConversion::BitMask;
  static Torch::BuildKernelMetadata metadata = [] {
    Torch::BuildKernelMetadata md;
    md.kernelName = "aten::nll_loss2d_backward";
    md.promoteTrailingOutTensor = true;

    // Operands: Tensor, Tensor, Tensor, Tensor?, int, int, Tensor.
    // Every tensor operand (including the trailing one) is tagged
    // kImmutableTensor; the two ints are not converted.
    md.addArgTypes({"Tensor", "Tensor", "Tensor", "Tensor?",
                    "int", "int", "Tensor"});
    md.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor,
                          KVC::kImmutableTensor, KVC::kImmutableTensor,
                          KVC::kNone, KVC::kNone, KVC::kImmutableTensor});

    // One tensor result.
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kImmutableTensor});
    return md;
  }();
  return metadata;
}

/// Returns this op's kernel metadata by converting the cached
/// Torch::BuildKernelMetadata into the Torch::KernelMetadata return type.
Torch::KernelMetadata CopyInplaceOp::getTorchKernelMetadata() {
  const Torch::BuildKernelMetadata &buildMetadata =
      getTorchBuildKernelMetadata();
  return buildMetadata;
}

/// Builds, once, the registration metadata for the in-place `aten::copy_`.
/// The function-local static is initialized on the first call; later calls
/// return the same object.
const Torch::BuildKernelMetadata &CopyInplaceOp::getTorchBuildKernelMetadata() {
  using KVC = Torch::KernelValueConversion::BitMask;
  static Torch::BuildKernelMetadata metadata = [] {
    Torch::BuildKernelMetadata md;
    md.kernelName = "aten::copy_";
    // Unlike the other ops in this file section, promoteTrailingOutTensor is
    // not assigned here, so it keeps whatever default BuildKernelMetadata has.

    // Operands: Tensor, Tensor, bool. The first tensor gets no conversion
    // (kNone), the second is tagged kImmutableTensor, and the bool is tagged
    // kDrop.
    md.addArgTypes({"Tensor", "Tensor", "bool"});
    md.addArgConversions({KVC::kNone, KVC::kImmutableTensor, KVC::kDrop});

    // One tensor result tagged kDropReturnAndAliasArg0. Judging by the name,
    // the result presumably aliases argument 0 rather than being a fresh
    // value; NOTE(review): confirm against KernelValueConversion.
    md.addReturnTypes({"Tensor"});
    md.addReturnConversions({KVC::kDropReturnAndAliasArg0});
    return md;
  }();
  return metadata;
}
