diff --git a/CMakeLists.txt b/CMakeLists.txt index 64c283410..788acc8c3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -129,6 +129,12 @@ message(STATUS "Found ppython libraries: ${PYTHON_LIBRARIES}") find_package(pybind11 CONFIG REQUIRED) message(STATUS "Found pybind11 v${pybind11_VERSION}: ${pybind11_INCLUDE_DIRS}") +#------------------------------------------------------------------------------- +# Pytorch Configuration +#------------------------------------------------------------------------------- + +find_package(Torch) + #------------------------------------------------------------------------------- # Directory setup #------------------------------------------------------------------------------- @@ -137,9 +143,12 @@ set(MLIR_NPCOMP_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) set(MLIR_NPCOMP_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) add_custom_target(check-npcomp) +add_custom_target(check-all) +add_dependencies(check-all check-npcomp) add_subdirectory(include/npcomp) add_subdirectory(lib) add_subdirectory(tools) add_subdirectory(python) add_subdirectory(test) +add_subdirectory(frontends) diff --git a/frontends/CMakeLists.txt b/frontends/CMakeLists.txt new file mode 100644 index 000000000..1c7105443 --- /dev/null +++ b/frontends/CMakeLists.txt @@ -0,0 +1,5 @@ +if(${TORCH_FOUND}) + add_subdirectory(pytorch) +else() + message("Skipping pytorch frontend, because PyTorch not found!") +endif() diff --git a/frontends/__init__.py b/frontends/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/frontends/pytorch/CMakeLists.txt b/frontends/pytorch/CMakeLists.txt new file mode 100644 index 000000000..8c59690f1 --- /dev/null +++ b/frontends/pytorch/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(lib) +add_subdirectory(csrc) +add_subdirectory(test) diff --git a/frontends/pytorch/LICENSE b/frontends/pytorch/LICENSE new file mode 100644 index 000000000..b6c3eaad6 --- /dev/null +++ b/frontends/pytorch/LICENSE @@ -0,0 +1,65 @@ +In order to facilitate future incorporation in pytorch, the code in this +directory (frontends/pytorch) is provided under the below license. + +Copyright (c) 2020 LLVM Foundation. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + +The design of this code is highly inspired by the design of the xla device for +pytorch (git@github.com:pytorch/xla.git). The license for pytorch/xla is: + +Copyright (c) 2018 Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/frontends/pytorch/README.md b/frontends/pytorch/README.md index 1d9376433..7a7e78f82 100644 --- a/frontends/pytorch/README.md +++ b/frontends/pytorch/README.md @@ -13,5 +13,46 @@ along with their lowerings to common intermediate dialects and backends. This directory should be purely about interfacing with the PyTorch/LibTorch components for extracting and executing programs. -See the [overall documentation for frontends](../README.md) for further details -about code layout and integration philosophy. +The code in this directory is intended to integrate tightly with pytorch, and +follows the code style for pytorch. See the [overall documentation for +frontends](../README.md) for further details about code layout and integration +philosophy. In particular, this directory exists to provide a working +frontend to an MLIR based pytorch compilation flow and is not intended to be +contributed to the LLVM monorepo. If the project is successful, it makes more +sense to either break it out as an independent project that depends on +LLVM/MLIR/npcomp or contribute it upstream to PyTorch. However, as it will be +quite some time before the components are in a state to support such a +dependency, it is being carried in-tree in the interim. 
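+
+### Building
+
+The top-level CMake build only descends into `frontends/pytorch` when
+`find_package(Torch)` succeeds; otherwise the frontend is skipped with a
+status message. The snippet below is an illustrative sketch only: the prefix
+path is a placeholder for wherever your PyTorch installation keeps its CMake
+package files (for example, the directory reported by
+`torch.utils.cmake_prefix_path`), and is not something this patch sets up for
+you.
+
+```
+# Hypothetical configure-time hint so that find_package(Torch) can locate
+# TorchConfig.cmake; substitute your own PyTorch installation prefix.
+list(APPEND CMAKE_PREFIX_PATH "/path/to/site-packages/torch/share/cmake")
+find_package(Torch)  # sets TORCH_FOUND, TORCH_INCLUDE_DIRS, TORCH_LIBRARIES
+```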
+
+### Program capture with an ATen pseudo-device
+
+Integration with a pseudo-device is typified by code like the following:
+
+```
+import torch
+import npcomp.frontends.pytorch as torch_mlir
+
+dev = torch_mlir.mlir_device()
+t0 = torch.randn((4,4), device=dev)
+t1 = torch.randn((4,4)).to(dev)
+t2 = t0 + t1
+t2_mlir = torch_mlir.get_mlir(t2)
+t2_cpu = t2.to('cpu')
+```
+
+In this case, `t2_cpu` contains the result of the computation, and `t2_mlir`
+contains the MLIR description of the computation. Tensors are allocated
+directly on the virtual device using the `device=` argument, or computed on
+the host and then moved to the virtual device using the `to(dev)`
+call. Subsequent calls on those tensors construct a graph of computation, but
+do not perform compute in most cases. This computation graph is returned in
+MLIR format by the `get_mlir` call, or lazily evaluated to return a regular
+pytorch tensor by the `to('cpu')` call.
+
+This technique has several advantages and disadvantages. For training use
+cases, it generates a backward path automatically using the same method that
+pytorch natively uses. The resulting graph also tends to be simpler, since it
+will not reflect conditionals in the original Python code. Lastly, it is
+natural if MLIR is being used as a frontend target for an actual device of
+some sort. In this case, the MLIR could go through a device-specific lowering
+path and the resulting code run on that device.
+The implementation of this technique is largely modeled after pytorch_xla.
diff --git a/frontends/pytorch/csrc/CMakeLists.txt b/frontends/pytorch/csrc/CMakeLists.txt
new file mode 100644
index 000000000..22f8b841b
--- /dev/null
+++ b/frontends/pytorch/csrc/CMakeLists.txt
@@ -0,0 +1,32 @@
+include_directories(
+  ${TORCH_INCLUDE_DIRS}
+  ${TORCH_INSTALL_PREFIX}/include/TH
+  ${TORCH_INSTALL_PREFIX}/include/THC/opt/pytorch/pytorch
+  ${CMAKE_CURRENT_SOURCE_DIR}
+  ${CMAKE_CURRENT_BINARY_DIR}
+  ${PYTHON_INCLUDE_DIRS}
+  )
+link_directories("${TORCH_INSTALL_PREFIX}/lib")
+add_library(_torch_mlir SHARED
+  aten_mlir_bridge.cpp
+  aten_mlir_type.cpp
+  aten_mlir_type_default.cpp
+  device.cpp
+  init_python_bindings.cpp
+  ir.cpp
+  jit.cpp
+  mlir_gen.cpp
+  tensor.cpp
+  tensor_impl.cpp
+  torch_util.cpp
+  )
+set_target_properties(_torch_mlir PROPERTIES PREFIX "")
+
+get_property(mlir_libs GLOBAL PROPERTY MLIR_ALL_LIBS)
+target_link_libraries(_torch_mlir
+  NPCOMPATenDialect
+  ${TORCH_LIBRARIES}
+  ${mlir_libs}
+  ${PYTHON_LIBRARIES}
+  torch_python
+  )
diff --git a/frontends/pytorch/csrc/aten_mlir_bridge.cpp b/frontends/pytorch/csrc/aten_mlir_bridge.cpp
new file mode 100644
index 000000000..dd4fd65c0
--- /dev/null
+++ b/frontends/pytorch/csrc/aten_mlir_bridge.cpp
@@ -0,0 +1,192 @@
+//===- aten_mlir_bridge.cpp -------------------------------------*- C++ -*-===//
+//
+// This file is licensed under a pytorch-style license
+// See frontends/pytorch/LICENSE for license information.
+// +//===----------------------------------------------------------------------===// + +// Structured similarly to code from git@github.com:pytorch/xla.git + +#include "aten_mlir_bridge.h" + +#include +#include + +#include "device.h" +#include "tensor_impl.h" + +namespace torch_mlir { +namespace bridge { +namespace { + +class AtenMLIRDeviceMapper { +public: + static AtenMLIRDeviceMapper *Get(); + + size_t GetDeviceOrdinal(const Device &device) const { + auto it = devices_ordinals_.find(device); + assert(it != devices_ordinals_.end()); + return it->second; + } + + const Device &GetDeviceFromOrdinal(size_t ordinal) const { + return devices_.at(ordinal); + } + +private: + AtenMLIRDeviceMapper() { + std::vector local_devices{"mlir:0", "mlir:1", "mlir:2"}; + for (auto &device_str : local_devices) { + devices_.emplace_back(device_str); + devices_ordinals_[devices_.back()] = devices_.size() - 1; + } + } + + std::vector devices_; + std::map devices_ordinals_; +}; + +AtenMLIRDeviceMapper *AtenMLIRDeviceMapper::Get() { + static AtenMLIRDeviceMapper *device_mapper = new AtenMLIRDeviceMapper(); + return device_mapper; +} + +} // namespace + +c10::optional TryGetMLIRTensor(const at::Tensor &tensor) { + MLIRTensorImpl *impl = + dynamic_cast(tensor.unsafeGetTensorImpl()); + if (impl == nullptr) { + return c10::nullopt; + } + return impl->tensor(); +} + +MLIRTensor GetMLIRTensor(const at::Tensor &tensor) { + auto xtensor = TryGetMLIRTensor(tensor); + assert(xtensor && "Input tensor is not an MLIR tensor"); + return *xtensor; +} + +MLIRTensor GetOrCreateMLIRTensor(const at::Tensor &tensor, + const Device &device) { + if (!tensor.defined()) { + return MLIRTensor(); + } + auto xtensor = TryGetMLIRTensor(tensor); + return xtensor ? *xtensor : MLIRTensor::Create(tensor, device); +} + +std::vector MLIRCreateTensorList(const at::TensorList &tensors) { + + std::vector aten_device_tensors(tensors.size()); + std::vector device_tensors; + + std::vector to_translate(tensors.size()); + + for (size_t i = 0; i < tensors.size(); ++i) { + const at::Tensor &tensor = tensors[i]; + if (tensor.defined()) { + auto xtensor = TryGetMLIRTensor(tensor); + if (xtensor) { + to_translate[i] = true; + device_tensors.push_back(*xtensor); + } else { + aten_device_tensors[i] = tensor; + } + } + } + + for (size_t i = 0, defined_pos = 0; i < tensors.size(); ++i) { + if (to_translate[i]) { + aten_device_tensors[i] = + std::move(device_tensors[defined_pos++].ToTensor()); + } + } + return aten_device_tensors; +} + +c10::optional GetMLIRDevice(const at::TensorList &tensors) { + for (const auto &tensor : tensors) { + auto device = GetMLIRDevice(tensor); + if (device) { + return device; + } + } + return c10::nullopt; +} + +c10::optional GetMLIRDevice(const at::TensorOptions &tensor_options) { + if (!tensor_options.has_device()) { + return c10::nullopt; + } + return GetMLIRDevice(tensor_options.device()); +} + +c10::optional GetMLIRDevice(const c10::Device &device) { + if (device.type() != at::kXLA) { + return c10::nullopt; + } + return AtenDeviceToMLIRDevice(device); +} + +c10::optional GetMLIRDevice(const at::Tensor &tensor) { + auto xtensor = TryGetMLIRTensor(tensor); + if (!xtensor) { + return c10::nullopt; + } + return xtensor->GetDevice(); +} + +Device AtenDeviceToMLIRDevice(const c10::Device &device) { + assert(device.type() == at::kXLA); + int ordinal = device.has_index() ? 
device.index() : -1; + if (ordinal < 0) { + c10::Device current_device = MLIRTensorImpl::GetCurrentAtenDevice(); + if (current_device.has_index()) { + ordinal = current_device.index(); + } + } + if (ordinal < 0) { + return *GetDefaultDevice(); + } + return AtenMLIRDeviceMapper::Get()->GetDeviceFromOrdinal(ordinal); +} + +c10::Device MLIRDeviceToAtenDevice(const Device &device) { + // TODO: define our own device and stop hijacking the xla device. + return c10::Device(at::kXLA, + AtenMLIRDeviceMapper::Get()->GetDeviceOrdinal(device)); +} + +at::Tensor MLIRToAtenTensor(MLIRTensor device_tensor, + const at::TensorOptions &tensor_options) { + if (tensor_options.has_device()) { + assert(tensor_options.device().type() != at::kXLA); + } + + at::Tensor tensor = device_tensor.ToTensor(); + + // We need to copy the tensor since it is cached within the MLIRTensor, and + // returning it directly might expose it to in place changes. + return tensor.to(tensor_options, /*non_blocking=*/false, /*copy=*/true); +} + +at::Tensor AtenFromMLIRTensor(MLIRTensor device_tensor) { + assert(!device_tensor.is_null()); + at::Tensor ret = + at::Tensor(c10::make_intrusive(std::move(device_tensor))); + return ret; +} + +at::Tensor CreateMLIRTensor(at::Tensor tensor, + const c10::optional &device) { + if (tensor.defined() && device) { + MLIRTensor device_tensor = MLIRTensor::Create(std::move(tensor), *device); + tensor = AtenFromMLIRTensor(device_tensor); + } + return tensor; +} + +} // namespace bridge +} // namespace torch_mlir diff --git a/frontends/pytorch/csrc/aten_mlir_bridge.h b/frontends/pytorch/csrc/aten_mlir_bridge.h new file mode 100644 index 000000000..835f515a5 --- /dev/null +++ b/frontends/pytorch/csrc/aten_mlir_bridge.h @@ -0,0 +1,61 @@ +//===- aten_mlir_bridge.h ---------------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. +// +//===----------------------------------------------------------------------===// + +#pragma once + +// Structured similarly to code from git@github.com:pytorch/xla.git + +// This file implements a bridge which moves data back and forth from torch +// tensors (at::Tensor) to MLIRTensor, which represents a tensor associated +// with our virtual 'MLIR' device. + +#include "device.h" +#include "tensor.h" + +#include +#include +#include + +namespace torch_mlir { +namespace bridge { + +c10::optional TryGetMLIRTensor(const at::Tensor &tensor); + +// Return an MLIR tensor that is computed the same way as the given at::Tensor +MLIRTensor GetMLIRTensor(const at::Tensor &tensor); + +MLIRTensor GetOrCreateMLIRTensor(const at::Tensor &tensor, + const Device &device); + +// Creates a vector of at::Tensor objects extracted from a list of MLIR tensors. +std::vector MLIRCreateTensorList(const at::TensorList &tensors); + +c10::optional GetMLIRDevice(const at::TensorList &tensors); + +c10::optional GetMLIRDevice(const at::TensorOptions &tensor_options); + +c10::optional GetMLIRDevice(const c10::Device &device); + +c10::optional GetMLIRDevice(const at::Tensor &tensor); + +Device AtenDeviceToMLIRDevice(const c10::Device &device); + +c10::Device MLIRDeviceToAtenDevice(const Device &device); + +at::Tensor MLIRToAtenTensor(MLIRTensor device_tensor, + const at::TensorOptions &tensor_options); + +// Create an Aten tensor with MLIR type id from MLIRTensor +at::Tensor AtenFromMLIRTensor(MLIRTensor device_tensor); + +// Creates an MLIR tensor holding the data in tensor, on the given device. 
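+// If no device is provided or the tensor is undefined, the input tensor is
+// returned unchanged (see the definition in aten_mlir_bridge.cpp).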
+at::Tensor CreateMLIRTensor(at::Tensor tensor, + const c10::optional &device); + +} // namespace bridge + +} // namespace torch_mlir diff --git a/frontends/pytorch/csrc/aten_mlir_type.cpp b/frontends/pytorch/csrc/aten_mlir_type.cpp new file mode 100644 index 000000000..3f5bd5ccd --- /dev/null +++ b/frontends/pytorch/csrc/aten_mlir_type.cpp @@ -0,0 +1,669 @@ +//===- aten_mlir_type.cpp ---------------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. +// +//===----------------------------------------------------------------------===// + +// Structured similarly to code from git@github.com:pytorch/xla.git + +#include "llvm/Support/Debug.h" + +#include "aten_mlir_bridge.h" +#include "aten_mlir_type.h" +#include "aten_mlir_type_default.h" +#include "ir.h" +#include "tensor_impl.h" +#include "torch_util.h" + +#include + +#define DEBUG_TYPE "torch_mlir" + +namespace torch_mlir { +namespace { + +struct MLIROptions { + MLIROptions(const at::TensorOptions &options, + c10::optional device_opt = c10::nullopt, + c10::optional scalar_type_opt = c10::nullopt) + : device(std::move(device_opt)), scalar_type(std::move(scalar_type_opt)) { + if (options.has_device()) { + device = bridge::AtenDeviceToMLIRDevice(options.device()); + } + if (options.has_dtype()) { + scalar_type = c10::typeMetaToScalarType(options.dtype()); + } + } + + Device get_device() const { return device ? *device : *GetDefaultDevice(); } + + at::ScalarType + get_scalar_type(at::ScalarType defval = at::ScalarType::Float) const { + return scalar_type ? *scalar_type : defval; + } + + c10::optional device; + c10::optional scalar_type; +}; + +std::tuple +GetPromotedMLIRTensorsForBinaryOp(const at::Tensor &self, + const at::Tensor &other) { + // this requires slightly newer than pytorch 1.3.0, disable for now. 
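+  // (When re-enabled, at::result_type computes the promoted result dtype of
+  // the two operands so that both MLIR tensors can be given a common
+  // scalar type before the binary op.)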
+ // at::ScalarType dtype = at::result_type(self, other); + MLIRTensor tensor1 = bridge::GetMLIRTensor(self); + MLIRTensor tensor2 = + bridge::GetOrCreateMLIRTensor(other, tensor1.GetDevice()); + // tensor1.SetScalarType(dtype); + // tensor2.SetScalarType(dtype); + return std::make_tuple(tensor1, tensor2); +} + +void AtenInitialize() { + RegisterAtenTypeFunctions(); + ir::RegisterAtenIR(); +} + +} // namespace + +void ATenMLIRType::InitializeAtenBindings() { + static std::once_flag once; + std::call_once(once, []() { AtenInitialize(); }); +} + +at::Tensor ATenMLIRType::_adaptive_avg_pool2d(const at::Tensor &self, + at::IntArrayRef output_size) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + return bridge::AtenFromMLIRTensor(MLIRTensor::_adaptive_avg_pool2d( + bridge::GetMLIRTensor(self), output_size)); +} + +at::Tensor +ATenMLIRType::_adaptive_avg_pool2d_backward(const at::Tensor &grad_output, + const at::Tensor &self) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(self); + auto grad_output_tensor = + bridge::GetOrCreateMLIRTensor(grad_output, input_tensor.GetDevice()); + + return bridge::AtenFromMLIRTensor(MLIRTensor::_adaptive_avg_pool2d_backward( + grad_output_tensor, input_tensor)); +} + +at::Tensor ATenMLIRType::add(const at::Tensor &self, const at::Tensor &other, + at::Scalar alpha) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto tensors = GetPromotedMLIRTensorsForBinaryOp(self, other); + return bridge::AtenFromMLIRTensor( + MLIRTensor::add(std::get<0>(tensors), std::get<1>(tensors), alpha)); +} + +at::Tensor &ATenMLIRType::add_(at::Tensor &self, const at::Tensor &other, + at::Scalar alpha) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto tensors = GetPromotedMLIRTensorsForBinaryOp(self, other); + auto result = bridge::AtenFromMLIRTensor( + MLIRTensor::add_(std::get<0>(tensors), std::get<1>(tensors), alpha)); + MLIRTensorImpl *self_impl = + dynamic_cast(self.unsafeGetTensorImpl()); + self_impl->shallow_copy_from(result.getIntrusivePtr()); + return self; +} + +at::Tensor ATenMLIRType::addmm(const at::Tensor &self, const at::Tensor &mat1, + const at::Tensor &mat2, at::Scalar beta, + at::Scalar alpha) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto tensor = bridge::GetMLIRTensor(self); + return bridge::AtenFromMLIRTensor(MLIRTensor::addmm( + tensor, bridge::GetOrCreateMLIRTensor(mat1, tensor.GetDevice()), + bridge::GetOrCreateMLIRTensor(mat2, tensor.GetDevice()), beta, alpha)); +} + +at::Tensor ATenMLIRType::as_strided(const at::Tensor &self, + at::IntArrayRef size, + at::IntArrayRef stride, + c10::optional storage_offset) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + return bridge::AtenFromMLIRTensor(MLIRTensor::as_strided( + bridge::GetMLIRTensor(self), size, stride, storage_offset)); +} + +at::Tensor ATenMLIRType::clone(const at::Tensor &self) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + + return bridge::AtenFromMLIRTensor( + MLIRTensor::clone(bridge::GetMLIRTensor(self))); +} + +at::Tensor &ATenMLIRType::copy_(at::Tensor &self, const at::Tensor &src, + bool non_blocking) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + + auto self_tensor = bridge::TryGetMLIRTensor(self); + auto src_tensor = bridge::TryGetMLIRTensor(src); + + if (!src_tensor) { + assert(self_tensor); + self_tensor->SetTensor(util::CopyTensor(src, self.scalar_type())); 
+ } else if (!self_tensor) { + at::Tensor t = src_tensor->ToTensor(); + const_cast(self).unsafeGetTensorImpl()->shallow_copy_from( + t.getIntrusivePtr()); + } else { + MLIRTensor::copy_(*self_tensor, *src_tensor); + } + return self; +} + +at::Tensor ATenMLIRType::_copy_from(const at::Tensor &self, + const at::Tensor &dst, bool non_blocking) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + + std::vector tensors = {self}; + auto device_tensors = bridge::MLIRCreateTensorList(tensors); + // Hack in an overwrite of a const tensor. + at::Tensor t = util::CopyTensor(device_tensors.front(), dst.scalar_type()); + const_cast(dst).unsafeGetTensorImpl()->shallow_copy_from( + t.getIntrusivePtr()); + return dst; +} + +std::tuple +ATenMLIRType::convolution_backward_overrideable( + const at::Tensor &grad_output, const at::Tensor &input, + const at::Tensor &weight, at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, + int64_t groups, std::array output_mask) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(input); + auto weight_tensor = + bridge::GetOrCreateMLIRTensor(weight, input_tensor.GetDevice()); + auto grad_output_tensor = + bridge::GetOrCreateMLIRTensor(grad_output, input_tensor.GetDevice()); + + auto ret = MLIRTensor::convolution_backward( + grad_output_tensor, input_tensor, weight_tensor, stride, padding, + dilation, transposed, output_padding, groups, output_mask); + return std::make_tuple(bridge::AtenFromMLIRTensor(std::get<0>(ret)), + bridge::AtenFromMLIRTensor(std::get<1>(ret)), + bridge::AtenFromMLIRTensor(std::get<2>(ret))); +} + +at::Tensor ATenMLIRType::convolution_overrideable( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + bool transposed, at::IntArrayRef output_padding, int64_t groups) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(input); + auto weight_tensor = + bridge::GetOrCreateMLIRTensor(weight, input_tensor.GetDevice()); + + auto bias_tensor = + bias.defined() + ? 
bridge::GetOrCreateMLIRTensor(bias, input_tensor.GetDevice()) + : bridge::GetOrCreateMLIRTensor( + at::zeros(at::IntArrayRef{weight.sizes()[0]}), + input_tensor.GetDevice()); + + return bridge::AtenFromMLIRTensor(MLIRTensor::convolution( + input_tensor, weight_tensor, bias_tensor, stride, padding, dilation, + transposed, output_padding, groups)); +} + +at::Tensor ATenMLIRType::div(const at::Tensor &self, at::Scalar other) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(self); + return bridge::AtenFromMLIRTensor(MLIRTensor::div(input_tensor, other)); +} + +at::Tensor ATenMLIRType::div(const at::Tensor &self, const at::Tensor &other) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto tensors = GetPromotedMLIRTensorsForBinaryOp(self, other); + return bridge::AtenFromMLIRTensor( + MLIRTensor::div(std::get<0>(tensors), std::get<1>(tensors))); +} + +at::Tensor &ATenMLIRType::div_(at::Tensor &self, const at::Tensor &other) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto tensors = GetPromotedMLIRTensorsForBinaryOp(self, other); + auto result = bridge::AtenFromMLIRTensor( + MLIRTensor::div_(std::get<0>(tensors), std::get<1>(tensors))); + MLIRTensorImpl *self_impl = + dynamic_cast(self.unsafeGetTensorImpl()); + self_impl->shallow_copy_from(result.getIntrusivePtr()); + return self; +} + +at::Tensor ATenMLIRType::expand(const at::Tensor &self, at::IntArrayRef size, + bool implicit) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(self); + return bridge::AtenFromMLIRTensor( + MLIRTensor::expand(input_tensor, size, implicit)); +} + +at::Tensor ATenMLIRType::gather(const at::Tensor &self, int64_t dim, + const at::Tensor &index, bool sparse_grad) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(self); + auto index_tensor = + bridge::GetOrCreateMLIRTensor(index, input_tensor.GetDevice()); + return bridge::AtenFromMLIRTensor( + MLIRTensor::gather(input_tensor, dim, index_tensor, sparse_grad)); +} + +at::Tensor ATenMLIRType::hardtanh(const at::Tensor &self, at::Scalar min_val, + at::Scalar max_val) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(self); + auto result = bridge::AtenFromMLIRTensor( + MLIRTensor::hardtanh(input_tensor, min_val, max_val)); + MLIRTensorImpl *self_impl = + dynamic_cast(self.unsafeGetTensorImpl()); + self_impl->shallow_copy_from(result.getIntrusivePtr()); + return self; +} + +at::Tensor &ATenMLIRType::hardtanh_(at::Tensor &self, at::Scalar min_val, + at::Scalar max_val) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(self); + auto result = bridge::AtenFromMLIRTensor( + MLIRTensor::hardtanh_(input_tensor, min_val, max_val)); + MLIRTensorImpl *self_impl = + dynamic_cast(self.unsafeGetTensorImpl()); + self_impl->shallow_copy_from(result.getIntrusivePtr()); + return self; +} + +at::Tensor ATenMLIRType::hardtanh_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::Scalar min_val, + at::Scalar max_val) { + auto input_tensor = bridge::GetMLIRTensor(self); + auto grad_output_tensor = + bridge::GetOrCreateMLIRTensor(grad_output, input_tensor.GetDevice()); + return bridge::AtenFromMLIRTensor(MLIRTensor::hardtanh_backward( + grad_output_tensor, input_tensor, min_val, max_val)); +} + +at::Tensor 
ATenMLIRType::_log_softmax(const at::Tensor &self, int64_t dim, + bool half_to_float) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(self); + return bridge::AtenFromMLIRTensor( + MLIRTensor::_log_softmax(input_tensor, dim, half_to_float)); +} + +at::Tensor +ATenMLIRType::_log_softmax_backward_data(const at::Tensor &grad_output, + const at::Tensor &output, int64_t dim, + const at::Tensor &self) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(self); + auto output_tensor = + bridge::GetOrCreateMLIRTensor(output, input_tensor.GetDevice()); + auto grad_output_tensor = + bridge::GetOrCreateMLIRTensor(grad_output, input_tensor.GetDevice()); + return bridge::AtenFromMLIRTensor(MLIRTensor::_log_softmax_backward_data( + grad_output_tensor, output_tensor, dim, input_tensor)); +} + +std::tuple ATenMLIRType::max_pool2d_with_indices( + const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(self); + auto ret = MLIRTensor::max_pool2d_with_indices( + input_tensor, kernel_size, stride, padding, dilation, ceil_mode); + return std::make_tuple(bridge::AtenFromMLIRTensor(std::get<0>(ret)), + bridge::AtenFromMLIRTensor(std::get<1>(ret))); +} + +at::Tensor ATenMLIRType::max_pool2d_with_indices_backward( + const at::Tensor &grad_output, const at::Tensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, + const at::Tensor &indices) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(self); + auto grad_output_tensor = + bridge::GetOrCreateMLIRTensor(grad_output, input_tensor.GetDevice()); + auto indices_tensor = + bridge::GetOrCreateMLIRTensor(indices, input_tensor.GetDevice()); + + return bridge::AtenFromMLIRTensor( + MLIRTensor::max_pool2d_with_indices_backward( + grad_output_tensor, input_tensor, kernel_size, stride, padding, + dilation, ceil_mode, indices_tensor)); +} + +at::Tensor ATenMLIRType::mean(const at::Tensor &self, + c10::optional dtype) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + return bridge::AtenFromMLIRTensor( + MLIRTensor::mean(bridge::GetMLIRTensor(self), dtype)); +} + +at::Tensor ATenMLIRType::mean(const at::Tensor &self, at::IntArrayRef dim, + bool keepdim, + c10::optional dtype) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + return bridge::AtenFromMLIRTensor( + MLIRTensor::mean(bridge::GetMLIRTensor(self), dim, keepdim, dtype)); +} + +at::Tensor ATenMLIRType::mm(const at::Tensor &input, const at::Tensor &mat2) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(input); + auto mat2_tensor = + bridge::GetOrCreateMLIRTensor(mat2, input_tensor.GetDevice()); + return bridge::AtenFromMLIRTensor(MLIRTensor::mm(input_tensor, mat2_tensor)); +} + +at::Tensor ATenMLIRType::mul(const at::Tensor &self, const at::Tensor &other) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto tensors = GetPromotedMLIRTensorsForBinaryOp(self, other); + return bridge::AtenFromMLIRTensor( + MLIRTensor::mul(std::get<0>(tensors), std::get<1>(tensors))); +} + +at::Tensor &ATenMLIRType::mul_(at::Tensor &self, const 
at::Tensor &other) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto tensors = GetPromotedMLIRTensorsForBinaryOp(self, other); + auto result = bridge::AtenFromMLIRTensor( + MLIRTensor::mul_(std::get<0>(tensors), std::get<1>(tensors))); + MLIRTensorImpl *self_impl = + dynamic_cast(self.unsafeGetTensorImpl()); + self_impl->shallow_copy_from(result.getIntrusivePtr()); + return self; +} + +std::tuple ATenMLIRType::native_batch_norm( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, + const at::Tensor &running_mean, const at::Tensor &running_var, + bool training, double momentum, double eps) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(input); + auto weight_tensor = + bridge::GetOrCreateMLIRTensor(weight, input_tensor.GetDevice()); + auto bias_tensor = + bridge::GetOrCreateMLIRTensor(bias, input_tensor.GetDevice()); + auto running_mean_tensor = + bridge::GetOrCreateMLIRTensor(running_mean, input_tensor.GetDevice()); + auto running_var_tensor = + bridge::GetOrCreateMLIRTensor(running_var, input_tensor.GetDevice()); + + auto ret = MLIRTensor::native_batch_norm( + input_tensor, weight_tensor, bias_tensor, running_mean_tensor, + running_var_tensor, training, momentum, eps); + + return std::make_tuple(bridge::AtenFromMLIRTensor(std::get<0>(ret)), + bridge::AtenFromMLIRTensor(std::get<1>(ret)), + bridge::AtenFromMLIRTensor(std::get<2>(ret))); +} + +std::tuple +ATenMLIRType::native_batch_norm_backward( + const at::Tensor &grad_out, const at::Tensor &input, + const at::Tensor &weight, const at::Tensor &running_mean, + const at::Tensor &running_var, const at::Tensor &save_mean, + const at::Tensor &save_invstd, bool train, double eps, + std::array output_mask) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(input); + auto grad_out_tensor = + bridge::GetOrCreateMLIRTensor(grad_out, input_tensor.GetDevice()); + auto weight_tensor = + bridge::GetOrCreateMLIRTensor(weight, input_tensor.GetDevice()); + auto running_mean_tensor = + bridge::GetOrCreateMLIRTensor(running_mean, input_tensor.GetDevice()); + auto running_var_tensor = + bridge::GetOrCreateMLIRTensor(running_var, input_tensor.GetDevice()); + auto save_mean_tensor = + bridge::GetOrCreateMLIRTensor(save_mean, input_tensor.GetDevice()); + auto save_invstd_tensor = + bridge::GetOrCreateMLIRTensor(save_invstd, input_tensor.GetDevice()); + + auto ret = MLIRTensor::native_batch_norm_backward( + grad_out_tensor, input_tensor, weight_tensor, running_mean_tensor, + running_var_tensor, save_mean_tensor, save_invstd_tensor, train, eps, + output_mask); + + return std::make_tuple(bridge::AtenFromMLIRTensor(std::get<0>(ret)), + bridge::AtenFromMLIRTensor(std::get<1>(ret)), + bridge::AtenFromMLIRTensor(std::get<2>(ret))); +} + +at::Tensor ATenMLIRType::neg(const at::Tensor &self) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(self); + return bridge::AtenFromMLIRTensor(MLIRTensor::neg(input_tensor)); +} + +std::tuple ATenMLIRType::nll_loss2d_forward( + const at::Tensor &self, const at::Tensor &target, const at::Tensor &weight, + int64_t reduction, int64_t ignore_index) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(self); + auto target_tensor = + bridge::GetOrCreateMLIRTensor(target, input_tensor.GetDevice()); + + auto weight_tensor = + 
weight.defined() + ? bridge::GetOrCreateMLIRTensor(weight, input_tensor.GetDevice()) + : bridge::GetOrCreateMLIRTensor(at::ones(self.sizes()[1]), + input_tensor.GetDevice()); + + auto ret = MLIRTensor::nll_loss2d_forward( + input_tensor, target_tensor, weight_tensor, reduction, ignore_index); + + return std::make_tuple(bridge::AtenFromMLIRTensor(std::get<0>(ret)), + bridge::AtenFromMLIRTensor(std::get<1>(ret))); +} + +at::Tensor ATenMLIRType::nll_loss2d_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &target, const at::Tensor &weight, int64_t reduction, + int64_t ignore_index, const at::Tensor &total_weight) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(self); + auto grad_output_tensor = + bridge::GetOrCreateMLIRTensor(grad_output, input_tensor.GetDevice()); + auto target_tensor = + bridge::GetOrCreateMLIRTensor(target, input_tensor.GetDevice()); + + auto weight_tensor = + weight.defined() + ? bridge::GetOrCreateMLIRTensor(weight, input_tensor.GetDevice()) + : bridge::GetOrCreateMLIRTensor(at::ones(self.sizes()[1]), + input_tensor.GetDevice()); + auto total_weight_tensor = + bridge::GetOrCreateMLIRTensor(total_weight, input_tensor.GetDevice()); + + return bridge::AtenFromMLIRTensor(MLIRTensor::nll_loss2d_backward( + grad_output_tensor, input_tensor, target_tensor, weight_tensor, reduction, + ignore_index, total_weight_tensor)); +} + +std::tuple +ATenMLIRType::nll_loss_forward(const at::Tensor &self, const at::Tensor &target, + const at::Tensor &weight, int64_t reduction, + int64_t ignore_index) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(self); + auto target_tensor = + bridge::GetOrCreateMLIRTensor(target, input_tensor.GetDevice()); + + auto weight_tensor = + weight.defined() + ? bridge::GetOrCreateMLIRTensor(weight, input_tensor.GetDevice()) + : bridge::GetOrCreateMLIRTensor(at::ones(self.sizes()[1]), + input_tensor.GetDevice()); + + auto ret = MLIRTensor::nll_loss_forward( + input_tensor, target_tensor, weight_tensor, reduction, ignore_index); + + return std::make_tuple(bridge::AtenFromMLIRTensor(std::get<0>(ret)), + bridge::AtenFromMLIRTensor(std::get<1>(ret))); +} + +at::Tensor ATenMLIRType::nll_loss_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &target, const at::Tensor &weight, int64_t reduction, + int64_t ignore_index, const at::Tensor &total_weight) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(self); + auto grad_output_tensor = + bridge::GetOrCreateMLIRTensor(grad_output, input_tensor.GetDevice()); + auto target_tensor = + bridge::GetOrCreateMLIRTensor(target, input_tensor.GetDevice()); + + auto weight_tensor = + weight.defined() + ? 
bridge::GetOrCreateMLIRTensor(weight, input_tensor.GetDevice()) + : bridge::GetOrCreateMLIRTensor(at::ones(self.sizes()[1]), + input_tensor.GetDevice()); + auto total_weight_tensor = + bridge::GetOrCreateMLIRTensor(total_weight, input_tensor.GetDevice()); + + return bridge::AtenFromMLIRTensor(MLIRTensor::nll_loss_backward( + grad_output_tensor, input_tensor, target_tensor, weight_tensor, reduction, + ignore_index, total_weight_tensor)); +} + +at::Tensor ATenMLIRType::relu(const at::Tensor &self) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + return bridge::AtenFromMLIRTensor( + MLIRTensor::relu(bridge::GetMLIRTensor(self))); +} + +at::Tensor &ATenMLIRType::relu_(at::Tensor &self) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(self); + auto result = bridge::AtenFromMLIRTensor(MLIRTensor::relu_(input_tensor)); + MLIRTensorImpl *self_impl = + dynamic_cast(self.unsafeGetTensorImpl()); + self_impl->shallow_copy_from(result.getIntrusivePtr()); + return self; +} + +int64_t ATenMLIRType::size(const at::Tensor &self, int64_t dim) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + return bridge::GetMLIRTensor(self).sizes()[dim]; +} + +at::Tensor ATenMLIRType::squeeze(const at::Tensor &self, int64_t dim) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + return bridge::AtenFromMLIRTensor( + MLIRTensor::squeeze(bridge::GetMLIRTensor(self), dim)); +} + +at::Tensor ATenMLIRType::sub(const at::Tensor &self, const at::Tensor &other, + at::Scalar alpha) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto tensors = GetPromotedMLIRTensorsForBinaryOp(self, other); + return bridge::AtenFromMLIRTensor( + MLIRTensor::sub(std::get<0>(tensors), std::get<1>(tensors), alpha)); +} + +at::Tensor &ATenMLIRType::sub_(at::Tensor &self, const at::Tensor &other, + at::Scalar alpha) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto tensors = GetPromotedMLIRTensorsForBinaryOp(self, other); + auto result = bridge::AtenFromMLIRTensor( + MLIRTensor::sub_(std::get<0>(tensors), std::get<1>(tensors), alpha)); + MLIRTensorImpl *self_impl = + dynamic_cast(self.unsafeGetTensorImpl()); + self_impl->shallow_copy_from(result.getIntrusivePtr()); + return self; +} + +at::Tensor ATenMLIRType::sum(const at::Tensor &self, at::IntArrayRef dim, + bool keepdim, + c10::optional dtype) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + return bridge::AtenFromMLIRTensor( + MLIRTensor::sum(bridge::GetMLIRTensor(self), dim, keepdim, dtype)); +} + +at::Tensor ATenMLIRType::t(const at::Tensor &self) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + return bridge::AtenFromMLIRTensor(MLIRTensor::t(bridge::GetMLIRTensor(self))); +} + +at::Tensor ATenMLIRType::threshold_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::Scalar threshold) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + auto input_tensor = bridge::GetMLIRTensor(self); + auto grad_output_tensor = + bridge::GetOrCreateMLIRTensor(grad_output, input_tensor.GetDevice()); + return bridge::AtenFromMLIRTensor(MLIRTensor::threshold_backward( + grad_output_tensor, input_tensor, threshold)); +} + +at::Tensor ATenMLIRType::to(const at::Tensor &self, + const at::TensorOptions &options, + bool /* non_blocking */, bool /* copy */) { + + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + + auto self_tensor = 
bridge::TryGetMLIRTensor(self); + if (!self_tensor) { + assert(options.has_device()); + at::ScalarType dtype = options.has_dtype() + ? c10::typeMetaToScalarType(options.dtype()) + : self.scalar_type(); + MLIRTensor xtensor = + MLIRTensor::Create(util::CopyTensor(self, dtype), + bridge::AtenDeviceToMLIRDevice(options.device())); + return bridge::AtenFromMLIRTensor(xtensor); + } + if (options.has_device() && options.device().type() != at::kXLA) { + return bridge::MLIRToAtenTensor(*self_tensor, options); + } + MLIROptions mlir_options(options, self_tensor->GetDevice(), + self_tensor->dtype()); + return bridge::AtenFromMLIRTensor(MLIRTensor::to( + *self_tensor, mlir_options.device, mlir_options.scalar_type)); +} + +at::Tensor ATenMLIRType::to(const at::Tensor &self, c10::Device device, + at::ScalarType dtype, bool non_blocking, + bool copy) { + return to(self, self.options().device(device).dtype(dtype), non_blocking, + copy); +} + +at::Tensor ATenMLIRType::to(const at::Tensor &self, at::ScalarType dtype, + bool non_blocking, bool copy) { + return to(self, self.options().dtype(dtype), non_blocking, copy); +} + +at::Tensor ATenMLIRType::to(const at::Tensor &self, const at::Tensor &other, + bool non_blocking, bool copy) { + return to(self, other.options(), non_blocking, copy); +} + +at::Tensor ATenMLIRType::_unsafe_view(const at::Tensor &self, + at::IntArrayRef size) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + return bridge::AtenFromMLIRTensor( + MLIRTensor::view(bridge::GetMLIRTensor(self), size)); +} + +at::Tensor ATenMLIRType::unsqueeze(const at::Tensor &self, int64_t dim) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + return bridge::AtenFromMLIRTensor( + MLIRTensor::unsqueeze(bridge::GetMLIRTensor(self), dim)); +} + +at::Tensor ATenMLIRType::view(const at::Tensor &self, at::IntArrayRef size) { + LLVM_DEBUG(llvm::dbgs() << "ATenMLIRType::" << __func__ << "\n"); + return bridge::AtenFromMLIRTensor( + MLIRTensor::view(bridge::GetMLIRTensor(self), size)); +} +} // namespace torch_mlir diff --git a/frontends/pytorch/csrc/aten_mlir_type.h b/frontends/pytorch/csrc/aten_mlir_type.h new file mode 100644 index 000000000..57fe6b08f --- /dev/null +++ b/frontends/pytorch/csrc/aten_mlir_type.h @@ -0,0 +1,212 @@ +//===- aten_mlir_type.h -----------------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. +// +//===----------------------------------------------------------------------===// + +// Structured similarly to code from git@github.com:pytorch/xla.git + +#pragma once + +#include + +namespace torch_mlir { + +// Base ATEN Type class where the MLIR specific overrides should be defined. +class ATenMLIRType { +public: + static void InitializeAtenBindings(); + + ////////////////////////////////////////////////////////////////////////////// + // ATEN API overrides in alphabetical order. + // Note: The C++ signatures must match the ones listed within the following + // pytorch folder file: + // build/aten/src/ATen/RegistrationDeclarations.h + ///////////////////////////////////////////////////////////////////////////// + // The static method definitions here have multiple uses. Each function + // signature here will override the default implementation provided by + // aten_mlir_type_defaults.h. Most of these overrides are used to construct + // a small internal IR that can be used for different purposes. 
Primarily, + // in this code, the IR will be converted to MLIR. As such there is a often + // a 1:1 correspondance between code here and operations in the ATen MLIR + // dialect. + + // This file is parsed by gen_aten_dialect.py to generate + // aten_mlir_type_defaults.*, including the appropriate bindings in that + // file for all pytorch methods. + + static at::Tensor _adaptive_avg_pool2d(const at::Tensor &self, + at::IntArrayRef output_size); + + static at::Tensor _adaptive_avg_pool2d_backward(const at::Tensor &grad_output, + const at::Tensor &self); + + static at::Tensor add(const at::Tensor &self, const at::Tensor &other, + at::Scalar alpha); + + static at::Tensor &add_(at::Tensor &self, const at::Tensor &other, + at::Scalar alpha); + + static at::Tensor addmm(const at::Tensor &self, const at::Tensor &mat1, + const at::Tensor &mat2, at::Scalar beta, + at::Scalar alpha); + + static at::Tensor as_strided(const at::Tensor &self, at::IntArrayRef size, + at::IntArrayRef stride, + c10::optional storage_offset); + + static at::Tensor clone(const at::Tensor &self); + + static std::tuple + convolution_backward_overrideable( + const at::Tensor &grad_output, const at::Tensor &input, + const at::Tensor &weight, at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, + int64_t groups, std::array output_mask); + + static at::Tensor convolution_overrideable( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + bool transposed, at::IntArrayRef output_padding, int64_t groups); + + static at::Tensor ©_(at::Tensor &self, const at::Tensor &src, + bool non_blocking); + + static at::Tensor _copy_from(const at::Tensor &self, const at::Tensor &dst, + bool non_blocking); + + static at::Tensor div(const at::Tensor &self, const at::Tensor &other); + + static at::Tensor &div_(at::Tensor &self, const at::Tensor &other); + + static at::Tensor div(const at::Tensor &self, at::Scalar other); + + static at::Tensor expand(const at::Tensor &self, at::IntArrayRef size, + bool implicit); + + static at::Tensor gather(const at::Tensor &self, int64_t dim, + const at::Tensor &index, bool sparse_grad); + + static at::Tensor hardtanh(const at::Tensor &self, at::Scalar min_val, + at::Scalar max_val); + + static at::Tensor &hardtanh_(at::Tensor &self, at::Scalar min_val, + at::Scalar max_val); + + static at::Tensor hardtanh_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::Scalar min_val, at::Scalar max_val); + + static at::Tensor _log_softmax(const at::Tensor &self, int64_t dim, + bool half_to_float); + + static at::Tensor _log_softmax_backward_data(const at::Tensor &grad_output, + const at::Tensor &output, + int64_t dim, + const at::Tensor &self); + + static std::tuple + max_pool2d_with_indices(const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, bool ceil_mode); + + static at::Tensor max_pool2d_with_indices_backward( + const at::Tensor &grad_output, const at::Tensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, + const at::Tensor &indices); + + static at::Tensor mean(const at::Tensor &self, + c10::optional dtype); + + static at::Tensor mean(const at::Tensor &self, at::IntArrayRef dim, + bool keepdim, c10::optional dtype); + + static at::Tensor mm(const at::Tensor &self, 
const at::Tensor &mat2); + + static at::Tensor mul(const at::Tensor &self, const at::Tensor &other); + + static at::Tensor &mul_(at::Tensor &self, const at::Tensor &other); + + static std::tuple + native_batch_norm(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, const at::Tensor &running_mean, + const at::Tensor &running_var, bool training, + double momentum, double eps); + + static std::tuple + native_batch_norm_backward(const at::Tensor &grad_out, + const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &running_mean, + const at::Tensor &running_var, + const at::Tensor &save_mean, + const at::Tensor &save_invstd, bool train, + double eps, std::array output_mask); + + static at::Tensor neg(const at::Tensor &self); + + static std::tuple + nll_loss2d_forward(const at::Tensor &self, const at::Tensor &target, + const at::Tensor &weight, int64_t reduction, + int64_t ignore_index); + + static at::Tensor nll_loss2d_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, + const at::Tensor &weight, + int64_t reduction, int64_t ignore_index, + const at::Tensor &total_weight); + + static std::tuple + nll_loss_forward(const at::Tensor &self, const at::Tensor &target, + const at::Tensor &weight, int64_t reduction, + int64_t ignore_index); + + static at::Tensor nll_loss_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, + const at::Tensor &weight, + int64_t reduction, int64_t ignore_index, + const at::Tensor &total_weight); + + static at::Tensor relu(const at::Tensor &self); + + static at::Tensor &relu_(at::Tensor &self); + + static int64_t size(const at::Tensor &self, int64_t dim); + + static at::Tensor squeeze(const at::Tensor &self, int64_t dim); + + static at::Tensor sub(const at::Tensor &self, const at::Tensor &other, + at::Scalar alpha); + + static at::Tensor &sub_(at::Tensor &self, const at::Tensor &other, + at::Scalar alpha); + + static at::Tensor sum(const at::Tensor &self, at::IntArrayRef dim, + bool keepdim, c10::optional dtype); + + static at::Tensor t(const at::Tensor &self); + + static at::Tensor threshold_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::Scalar threshold); + + static at::Tensor to(const at::Tensor &self, const at::TensorOptions &options, + bool non_blocking, bool copy); + static at::Tensor to(const at::Tensor &self, c10::Device device, + at::ScalarType dtype, bool non_blocking, bool copy); + static at::Tensor to(const at::Tensor &self, at::ScalarType dtype, + bool non_blocking, bool copy); + static at::Tensor to(const at::Tensor &self, const at::Tensor &other, + bool non_blocking, bool copy); + + static at::Tensor _unsafe_view(const at::Tensor &self, at::IntArrayRef size); + + static at::Tensor unsqueeze(const at::Tensor &self, int64_t dim); + + static at::Tensor view(const at::Tensor &self, at::IntArrayRef size); +}; + +} // namespace torch_mlir diff --git a/frontends/pytorch/csrc/aten_mlir_type_default.cpp b/frontends/pytorch/csrc/aten_mlir_type_default.cpp new file mode 100644 index 000000000..a0c1c306e --- /dev/null +++ b/frontends/pytorch/csrc/aten_mlir_type_default.cpp @@ -0,0 +1,24500 @@ +//===- aten_mlir_type_default.cpp -------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. 
+// +//===----------------------------------------------------------------------===// + +#include "aten_mlir_type_default.h" + +#include +#include +#include +#include + +#include "aten_mlir_bridge.h" +#include "aten_mlir_type.h" + +namespace torch_mlir { + +at::Tensor ATenMLIRTypeDefault::_cast_Byte(const at::Tensor &self, + bool non_blocking) { + std::cout << "aten::_cast_Byte" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cast_Byte(mlirtens[0], non_blocking); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_cast_Char(const at::Tensor &self, + bool non_blocking) { + std::cout << "aten::_cast_Char" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cast_Char(mlirtens[0], non_blocking); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_cast_Double(const at::Tensor &self, + bool non_blocking) { + std::cout << "aten::_cast_Double" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cast_Double(mlirtens[0], non_blocking); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_cast_Float(const at::Tensor &self, + bool non_blocking) { + std::cout << "aten::_cast_Float" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cast_Float(mlirtens[0], non_blocking); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_cast_Int(const at::Tensor &self, + bool non_blocking) { + std::cout << "aten::_cast_Int" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cast_Int(mlirtens[0], non_blocking); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_cast_Long(const at::Tensor &self, + bool non_blocking) { + std::cout << "aten::_cast_Long" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cast_Long(mlirtens[0], non_blocking); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_cast_Short(const at::Tensor &self, + bool non_blocking) { + std::cout << "aten::_cast_Short" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cast_Short(mlirtens[0], non_blocking); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_cast_Half(const at::Tensor &self, + bool non_blocking) { + std::cout << "aten::_cast_Half" << std::endl; + std::vector 
mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cast_Half(mlirtens[0], non_blocking); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +void ATenMLIRTypeDefault::backward(const at::Tensor &self, + const at::Tensor &gradient, bool keep_graph, + bool create_graph) { + std::cout << "aten::backward" << std::endl; + std::vector mlirtens_tensors = {self, gradient}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + mlirtens[0].backward(mlirtens[1], keep_graph, create_graph); +} + +void ATenMLIRTypeDefault::set_data(const at::Tensor &self, + const at::Tensor &new_data) { + std::cout << "aten::set_data" << std::endl; + std::vector mlirtens_tensors = {self, new_data}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + mlirtens[0].set_data(mlirtens[1]); +} + +at::Tensor ATenMLIRTypeDefault::data(const at::Tensor &self) { + std::cout << "aten::data" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].data(); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +int64_t +ATenMLIRTypeDefault::_debug_has_internal_overlap(const at::Tensor &self) { + std::cout << "aten::_debug_has_internal_overlap" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_debug_has_internal_overlap(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +std::tuple +ATenMLIRTypeDefault::_fused_dropout(const at::Tensor &self, double p, + at::Generator *generator) { + std::cout << "aten::_fused_dropout" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_fused_dropout(mlirtens[0], p, generator); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor ATenMLIRTypeDefault::_masked_scale(const at::Tensor &self, + const at::Tensor &mask, + double scale) { + std::cout << "aten::_masked_scale" << std::endl; + std::vector mlirtens_tensors = {self, mask}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_masked_scale(mlirtens[0], mlirtens[1], scale); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple ATenMLIRTypeDefault::_sobol_engine_draw( + const at::Tensor &quasi, int64_t n, const at::Tensor &sobolstate, + int64_t dimension, int64_t num_generated, + c10::optional dtype) { + std::cout << "aten::_sobol_engine_draw" << std::endl; + std::vector mlirtens_tensors = {quasi, sobolstate}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_sobol_engine_draw(mlirtens[0], n, mlirtens[1], + dimension, num_generated, dtype); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(quasi)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(quasi))); 
+}
+
+at::Tensor &ATenMLIRTypeDefault::_sobol_engine_ff_(at::Tensor &self, int64_t n,
+                                                   const at::Tensor &sobolstate,
+                                                   int64_t dimension,
+                                                   int64_t num_generated) {
+  std::cout << "aten::_sobol_engine_ff_" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self, sobolstate};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::_sobol_engine_ff_(mlirtens[0], n, mlirtens[1],
+                                          dimension, num_generated);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return self;
+}
+
+at::Tensor &ATenMLIRTypeDefault::_sobol_engine_scramble_(at::Tensor &self,
+                                                         const at::Tensor &ltm,
+                                                         int64_t dimension) {
+  std::cout << "aten::_sobol_engine_scramble_" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self, ltm};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result =
+      at::_sobol_engine_scramble_(mlirtens[0], mlirtens[1], dimension);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return self;
+}
+
+at::Tensor &
+ATenMLIRTypeDefault::_sobol_engine_initialize_state_(at::Tensor &self,
+                                                     int64_t dimension) {
+  std::cout << "aten::_sobol_engine_initialize_state_" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::_sobol_engine_initialize_state_(mlirtens[0], dimension);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return self;
+}
+
+at::Tensor ATenMLIRTypeDefault::_reshape_from_tensor(const at::Tensor &self,
+                                                     const at::Tensor &shape) {
+  std::cout << "aten::_reshape_from_tensor" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self, shape};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::_reshape_from_tensor(mlirtens[0], mlirtens[1]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor ATenMLIRTypeDefault::_shape_as_tensor(const at::Tensor &self) {
+  std::cout << "aten::_shape_as_tensor" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::_shape_as_tensor(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor ATenMLIRTypeDefault::dropout(const at::Tensor &input, double p,
+                                        bool train) {
+  std::cout << "aten::dropout" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::dropout(mlirtens[0], p, train);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input));
+}
+
+at::Tensor &ATenMLIRTypeDefault::dropout_(at::Tensor &self, double p,
+                                          bool train) {
+  std::cout << "aten::dropout_" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::dropout_(mlirtens[0], p, train);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return self;
+}
+
+at::Tensor ATenMLIRTypeDefault::feature_dropout(const at::Tensor &input,
+                                                double p, bool train) {
+  std::cout << "aten::feature_dropout" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::feature_dropout(mlirtens[0], p, train);
+  static_cast<void>(x_result); // Avoid warnings in case not used
return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor &ATenMLIRTypeDefault::feature_dropout_(at::Tensor &self, double p, + bool train) { + std::cout << "aten::feature_dropout_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::feature_dropout_(mlirtens[0], p, train); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::alpha_dropout(const at::Tensor &input, double p, + bool train) { + std::cout << "aten::alpha_dropout" << std::endl; + std::vector mlirtens_tensors = {input}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::alpha_dropout(mlirtens[0], p, train); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor &ATenMLIRTypeDefault::alpha_dropout_(at::Tensor &self, double p, + bool train) { + std::cout << "aten::alpha_dropout_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::alpha_dropout_(mlirtens[0], p, train); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::feature_alpha_dropout(const at::Tensor &input, + double p, bool train) { + std::cout << "aten::feature_alpha_dropout" << std::endl; + std::vector mlirtens_tensors = {input}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::feature_alpha_dropout(mlirtens[0], p, train); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor &ATenMLIRTypeDefault::feature_alpha_dropout_(at::Tensor &self, + double p, bool train) { + std::cout << "aten::feature_alpha_dropout_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::feature_alpha_dropout_(mlirtens[0], p, train); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::abs(const at::Tensor &self) { + std::cout << "aten::abs" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::abs(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::abs_(at::Tensor &self) { + std::cout << "aten::abs_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::abs_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::abs_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::abs_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::abs_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::acos(const at::Tensor &self) { + std::cout << "aten::acos" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + 
auto &&x_result = at::acos(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::acos_(at::Tensor &self) { + std::cout << "aten::acos_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::acos_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::acos_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::acos_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::acos_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::avg_pool1d( + const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, bool ceil_mode, bool count_include_pad) { + std::cout << "aten::avg_pool1d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::avg_pool1d(mlirtens[0], kernel_size, stride, padding, + ceil_mode, count_include_pad); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor +ATenMLIRTypeDefault::adaptive_avg_pool1d(const at::Tensor &self, + at::IntArrayRef output_size) { + std::cout << "aten::adaptive_avg_pool1d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::adaptive_avg_pool1d(mlirtens[0], output_size); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::adaptive_max_pool1d(const at::Tensor &self, + at::IntArrayRef output_size) { + std::cout << "aten::adaptive_max_pool1d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::adaptive_max_pool1d(mlirtens[0], output_size); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor ATenMLIRTypeDefault::add(const at::Tensor &self, + const at::Tensor &other, at::Scalar alpha) { + std::cout << "aten::add" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::add(mlirtens[0], mlirtens[1], alpha); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::add_(at::Tensor &self, const at::Tensor &other, + at::Scalar alpha) { + std::cout << "aten::add_" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].add_(mlirtens[1], alpha); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::add_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &other, + 
at::Scalar alpha) { + std::cout << "aten::add_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::add_out(mlirtens[0], mlirtens[1], mlirtens[2], alpha); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::add(const at::Tensor &self, at::Scalar other, + at::Scalar alpha) { + std::cout << "aten::add" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::add(mlirtens[0], other, alpha); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::add_(at::Tensor &self, at::Scalar other, + at::Scalar alpha) { + std::cout << "aten::add_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].add_(other, alpha); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::addmv(const at::Tensor &self, + const at::Tensor &mat, + const at::Tensor &vec, at::Scalar beta, + at::Scalar alpha) { + std::cout << "aten::addmv" << std::endl; + std::vector mlirtens_tensors = {self, mat, vec}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::addmv(mlirtens[0], mlirtens[1], mlirtens[2], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::addmv_(at::Tensor &self, const at::Tensor &mat, + const at::Tensor &vec, at::Scalar beta, + at::Scalar alpha) { + std::cout << "aten::addmv_" << std::endl; + std::vector mlirtens_tensors = {self, mat, vec}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::addmv_(mlirtens[0], mlirtens[1], mlirtens[2], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::addmv_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &mat, + const at::Tensor &vec, + at::Scalar beta, at::Scalar alpha) { + std::cout << "aten::addmv_out" << std::endl; + std::vector mlirtens_tensors = {out, self, mat, vec}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::addmv_out(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::addr(const at::Tensor &self, + const at::Tensor &vec1, + const at::Tensor &vec2, at::Scalar beta, + at::Scalar alpha) { + std::cout << "aten::addr" << std::endl; + std::vector mlirtens_tensors = {self, vec1, vec2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::addr(mlirtens[0], mlirtens[1], mlirtens[2], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::addr_(at::Tensor &self, const at::Tensor &vec1, + const at::Tensor &vec2, at::Scalar beta, + at::Scalar alpha) { + std::cout << "aten::addr_" << std::endl; + std::vector mlirtens_tensors = {self, vec1, vec2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto 
&&x_result = mlirtens[0].addr_(mlirtens[1], mlirtens[2], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::addr_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &vec1, + const at::Tensor &vec2, + at::Scalar beta, at::Scalar alpha) { + std::cout << "aten::addr_out" << std::endl; + std::vector mlirtens_tensors = {out, self, vec1, vec2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::addr_out(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::affine_grid_generator(const at::Tensor &theta, + at::IntArrayRef size, + bool align_corners) { + std::cout << "aten::affine_grid_generator" << std::endl; + std::vector mlirtens_tensors = {theta}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::affine_grid_generator(mlirtens[0], size, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(theta)); +} + +at::Tensor ATenMLIRTypeDefault::affine_grid_generator_backward( + const at::Tensor &grad, at::IntArrayRef size, bool align_corners) { + std::cout << "aten::affine_grid_generator_backward" << std::endl; + std::vector mlirtens_tensors = {grad}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::affine_grid_generator_backward(mlirtens[0], size, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad)); +} + +at::Tensor ATenMLIRTypeDefault::all(const at::Tensor &self, int64_t dim, + bool keepdim) { + std::cout << "aten::all" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::all(mlirtens[0], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::all_out(at::Tensor &out, + const at::Tensor &self, int64_t dim, + bool keepdim) { + std::cout << "aten::all_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::all_out(mlirtens[0], mlirtens[1], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +bool ATenMLIRTypeDefault::allclose(const at::Tensor &self, + const at::Tensor &other, double rtol, + double atol, bool equal_nan) { + std::cout << "aten::allclose" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::allclose(mlirtens[0], mlirtens[1], rtol, atol, equal_nan); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +at::Tensor ATenMLIRTypeDefault::any(const at::Tensor &self, int64_t dim, + bool keepdim) { + std::cout << "aten::any" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::any(mlirtens[0], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::any_out(at::Tensor &out, + 
const at::Tensor &self, int64_t dim, + bool keepdim) { + std::cout << "aten::any_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::any_out(mlirtens[0], mlirtens[1], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::arange(at::Scalar end, + const at::TensorOptions &options) { + std::cout << "aten::arange" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::arange(end, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::arange(at::Scalar start, at::Scalar end, + const at::TensorOptions &options) { + std::cout << "aten::arange" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::arange(start, end, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::arange(at::Scalar start, at::Scalar end, + at::Scalar step, + const at::TensorOptions &options) { + std::cout << "aten::arange" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::arange(start, end, step, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor &ATenMLIRTypeDefault::arange_out(at::Tensor &out, at::Scalar end) { + std::cout << "aten::arange_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::arange_out(mlirtens[0], end); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor &ATenMLIRTypeDefault::arange_out(at::Tensor &out, at::Scalar start, + at::Scalar end, at::Scalar step) { + std::cout << "aten::arange_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::arange_out(mlirtens[0], start, end, step); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::_dim_arange(const at::Tensor &like, + int64_t dim) { + std::cout << "aten::_dim_arange" << std::endl; + std::vector mlirtens_tensors = {like}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_dim_arange(mlirtens[0], dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(like)); +} + +at::Tensor ATenMLIRTypeDefault::argmax(const at::Tensor &self, + c10::optional dim, + bool keepdim) { + std::cout << "aten::argmax" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::argmax(mlirtens[0], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::argmin(const at::Tensor &self, + c10::optional dim, + bool keepdim) { + std::cout << "aten::argmin" << std::endl; + 
std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::argmin(mlirtens[0], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor +ATenMLIRTypeDefault::as_strided(const at::Tensor &self, at::IntArrayRef size, + at::IntArrayRef stride, + c10::optional storage_offset) { + std::cout << "aten::as_strided" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::as_strided(mlirtens[0], size, stride, storage_offset); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor & +ATenMLIRTypeDefault::as_strided_(at::Tensor &self, at::IntArrayRef size, + at::IntArrayRef stride, + c10::optional storage_offset) { + std::cout << "aten::as_strided_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::as_strided_(mlirtens[0], size, stride, storage_offset); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::asin(const at::Tensor &self) { + std::cout << "aten::asin" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::asin(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::asin_(at::Tensor &self) { + std::cout << "aten::asin_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::asin_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::asin_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::asin_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::asin_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::atan(const at::Tensor &self) { + std::cout << "aten::atan" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::atan(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::atan_(at::Tensor &self) { + std::cout << "aten::atan_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::atan_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::atan_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::atan_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::atan_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor 
ATenMLIRTypeDefault::baddbmm(const at::Tensor &self, + const at::Tensor &batch1, + const at::Tensor &batch2, + at::Scalar beta, at::Scalar alpha) { + std::cout << "aten::baddbmm" << std::endl; + std::vector mlirtens_tensors = {self, batch1, batch2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::baddbmm(mlirtens[0], mlirtens[1], mlirtens[2], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::baddbmm_(at::Tensor &self, + const at::Tensor &batch1, + const at::Tensor &batch2, + at::Scalar beta, at::Scalar alpha) { + std::cout << "aten::baddbmm_" << std::endl; + std::vector mlirtens_tensors = {self, batch1, batch2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].baddbmm_(mlirtens[1], mlirtens[2], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::_baddbmm_mkl_(at::Tensor &self, + const at::Tensor &batch1, + const at::Tensor &batch2, + at::Scalar beta, + at::Scalar alpha) { + std::cout << "aten::_baddbmm_mkl_" << std::endl; + std::vector mlirtens_tensors = {self, batch1, batch2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_baddbmm_mkl_(mlirtens[0], mlirtens[1], mlirtens[2], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::baddbmm_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &batch1, + const at::Tensor &batch2, at::Scalar beta, at::Scalar alpha) { + std::cout << "aten::baddbmm_out" << std::endl; + std::vector mlirtens_tensors = {out, self, batch1, batch2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::baddbmm_out(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor +ATenMLIRTypeDefault::bartlett_window(int64_t window_length, + const at::TensorOptions &options) { + std::cout << "aten::bartlett_window" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::bartlett_window(window_length, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor +ATenMLIRTypeDefault::bartlett_window(int64_t window_length, bool periodic, + const at::TensorOptions &options) { + std::cout << "aten::bartlett_window" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::bartlett_window(window_length, periodic, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::batch_norm( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, + const at::Tensor &running_mean, const at::Tensor &running_var, + bool training, double momentum, double eps, bool cudnn_enabled) { + std::cout << "aten::batch_norm" << std::endl; + std::vector mlirtens_tensors = {input, weight, bias, running_mean, + running_var}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + 
at::batch_norm(mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], + mlirtens[4], training, momentum, eps, cudnn_enabled); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +std::tuple +ATenMLIRTypeDefault::_batch_norm_impl_index( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, + const at::Tensor &running_mean, const at::Tensor &running_var, + bool training, double momentum, double eps, bool cudnn_enabled) { + std::cout << "aten::_batch_norm_impl_index" << std::endl; + std::vector mlirtens_tensors = {input, weight, bias, running_mean, + running_var}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_batch_norm_impl_index( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], training, + momentum, eps, cudnn_enabled); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(input)), + std::get<3>(x_result)); +} + +std::tuple +ATenMLIRTypeDefault::_batch_norm_impl_index_backward( + int64_t impl_index, const at::Tensor &input, const at::Tensor &grad_output, + const at::Tensor &weight, const at::Tensor &running_mean, + const at::Tensor &running_var, const at::Tensor &save_mean, + const at::Tensor &save_var_transform, bool train, double eps, + std::array output_mask) { + std::cout << "aten::_batch_norm_impl_index_backward" << std::endl; + std::vector mlirtens_tensors = { + input, grad_output, weight, running_mean, running_var, + save_mean, save_var_transform}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_batch_norm_impl_index_backward( + impl_index, mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], + mlirtens[4], mlirtens[5], mlirtens[6], train, eps, output_mask); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(input))); +} + +at::Tensor ATenMLIRTypeDefault::bernoulli(const at::Tensor &self, + at::Generator *generator) { + std::cout << "aten::bernoulli" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::bernoulli(mlirtens[0], generator); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::bernoulli_out(at::Tensor &out, + const at::Tensor &self, + at::Generator *generator) { + std::cout << "aten::bernoulli_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::bernoulli_out(mlirtens[0], mlirtens[1], generator); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor &ATenMLIRTypeDefault::bernoulli_(at::Tensor &self, + const at::Tensor &p, + at::Generator *generator) { + std::cout << "aten::bernoulli_" << std::endl; + std::vector mlirtens_tensors = {self, p}; + auto mlirtens = 
bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].bernoulli_(mlirtens[1], generator); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::bernoulli_(at::Tensor &self, double p, + at::Generator *generator) { + std::cout << "aten::bernoulli_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].bernoulli_(p, generator); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::bernoulli(const at::Tensor &self, double p, + at::Generator *generator) { + std::cout << "aten::bernoulli" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::bernoulli(mlirtens[0], p, generator); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::bilinear(const at::Tensor &input1, + const at::Tensor &input2, + const at::Tensor &weight, + const at::Tensor &bias) { + std::cout << "aten::bilinear" << std::endl; + std::vector mlirtens_tensors = {input1, input2, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::bilinear(mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input1)); +} + +at::Tensor ATenMLIRTypeDefault::binary_cross_entropy_with_logits( + const at::Tensor &self, const at::Tensor &target, const at::Tensor &weight, + const at::Tensor &pos_weight, int64_t reduction) { + std::cout << "aten::binary_cross_entropy_with_logits" << std::endl; + std::vector mlirtens_tensors = {self, target, weight, pos_weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::binary_cross_entropy_with_logits( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], reduction); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::binary_cross_entropy_with_logits_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &target, const at::Tensor &weight, + const at::Tensor &pos_weight, int64_t reduction) { + std::cout << "aten::binary_cross_entropy_with_logits_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, target, weight, + pos_weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::binary_cross_entropy_with_logits_backward( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + reduction); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor ATenMLIRTypeDefault::bincount(const at::Tensor &self, + const at::Tensor &weights, + int64_t minlength) { + std::cout << "aten::bincount" << std::endl; + std::vector mlirtens_tensors = {self, weights}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::bincount(mlirtens[0], mlirtens[1], minlength); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor 
ATenMLIRTypeDefault::bitwise_not(const at::Tensor &self) { + std::cout << "aten::bitwise_not" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::bitwise_not(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::bitwise_not_(at::Tensor &self) { + std::cout << "aten::bitwise_not_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].bitwise_not_(); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::bitwise_not_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::bitwise_not_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::bitwise_not_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::logical_not(const at::Tensor &self) { + std::cout << "aten::logical_not" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::logical_not(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::logical_not_(at::Tensor &self) { + std::cout << "aten::logical_not_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].logical_not_(); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::logical_not_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::logical_not_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::logical_not_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::logical_xor(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::logical_xor" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::logical_xor(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::logical_xor_(at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::logical_xor_" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].logical_xor_(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::logical_xor_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::logical_xor_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::logical_xor_out(mlirtens[0], 
mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor +ATenMLIRTypeDefault::blackman_window(int64_t window_length, + const at::TensorOptions &options) { + std::cout << "aten::blackman_window" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::blackman_window(window_length, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor +ATenMLIRTypeDefault::blackman_window(int64_t window_length, bool periodic, + const at::TensorOptions &options) { + std::cout << "aten::blackman_window" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::blackman_window(window_length, periodic, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::bmm(const at::Tensor &self, + const at::Tensor &mat2) { + std::cout << "aten::bmm" << std::endl; + std::vector mlirtens_tensors = {self, mat2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::bmm(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::bmm_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &mat2) { + std::cout << "aten::bmm_out" << std::endl; + std::vector mlirtens_tensors = {out, self, mat2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::bmm_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +std::vector +ATenMLIRTypeDefault::broadcast_tensors(at::TensorList tensors) { + std::cout << "aten::broadcast_tensors" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::broadcast_tensors(tensors); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +at::Tensor ATenMLIRTypeDefault::cat(at::TensorList tensors, int64_t dim) { + std::cout << "aten::cat" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::cat(tensors, dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(tensors)); +} + +at::Tensor &ATenMLIRTypeDefault::cat_out(at::Tensor &out, + at::TensorList tensors, int64_t dim) { + std::cout << "aten::cat_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::cat_out(mlirtens[0], tensors, dim); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::ceil(const at::Tensor &self) { + std::cout << "aten::ceil" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::ceil(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor 
&ATenMLIRTypeDefault::ceil_(at::Tensor &self) { + std::cout << "aten::ceil_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::ceil_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::ceil_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::ceil_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::ceil_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::chain_matmul(at::TensorList matrices) { + std::cout << "aten::chain_matmul" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::chain_matmul(matrices); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(matrices)); +} + +std::vector ATenMLIRTypeDefault::chunk(const at::Tensor &self, + int64_t chunks, + int64_t dim) { + std::cout << "aten::chunk" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::chunk(mlirtens[0], chunks, dim); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +at::Tensor ATenMLIRTypeDefault::clamp(const at::Tensor &self, + c10::optional min, + c10::optional max) { + std::cout << "aten::clamp" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::clamp(mlirtens[0], min, max); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::clamp_(at::Tensor &self, + c10::optional min, + c10::optional max) { + std::cout << "aten::clamp_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::clamp_(mlirtens[0], min, max); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::clamp_out(at::Tensor &out, + const at::Tensor &self, + c10::optional min, + c10::optional max) { + std::cout << "aten::clamp_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::clamp_out(mlirtens[0], mlirtens[1], min, max); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::clamp_max(const at::Tensor &self, + at::Scalar max) { + std::cout << "aten::clamp_max" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::clamp_max(mlirtens[0], max); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::clamp_max_(at::Tensor &self, at::Scalar max) { + std::cout << "aten::clamp_max_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::clamp_max_(mlirtens[0], max); + static_cast(x_result); // Avoid 
warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::clamp_max_out(at::Tensor &out, + const at::Tensor &self, + at::Scalar max) { + std::cout << "aten::clamp_max_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::clamp_max_out(mlirtens[0], mlirtens[1], max); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::clamp_min(const at::Tensor &self, + at::Scalar min) { + std::cout << "aten::clamp_min" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::clamp_min(mlirtens[0], min); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::clamp_min_(at::Tensor &self, at::Scalar min) { + std::cout << "aten::clamp_min_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::clamp_min_(mlirtens[0], min); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::clamp_min_out(at::Tensor &out, + const at::Tensor &self, + at::Scalar min) { + std::cout << "aten::clamp_min_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::clamp_min_out(mlirtens[0], mlirtens[1], min); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::constant_pad_nd(const at::Tensor &self, + at::IntArrayRef pad, + at::Scalar value) { + std::cout << "aten::constant_pad_nd" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::constant_pad_nd(mlirtens[0], pad, value); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::contiguous(const at::Tensor &self, + at::MemoryFormat memory_format) { + std::cout << "aten::contiguous" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].contiguous(memory_format); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::convolution( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + bool transposed, at::IntArrayRef output_padding, int64_t groups) { + std::cout << "aten::convolution" << std::endl; + std::vector mlirtens_tensors = {input, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::convolution(mlirtens[0], mlirtens[1], mlirtens[2], stride, padding, + dilation, transposed, output_padding, groups); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor ATenMLIRTypeDefault::convolution_overrideable( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, 
at::IntArrayRef dilation, + bool transposed, at::IntArrayRef output_padding, int64_t groups) { + std::cout << "aten::convolution_overrideable" << std::endl; + std::vector mlirtens_tensors = {input, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::convolution_overrideable( + mlirtens[0], mlirtens[1], mlirtens[2], stride, padding, dilation, + transposed, output_padding, groups); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +std::tuple +ATenMLIRTypeDefault::convolution_backward_overrideable( + const at::Tensor &grad_output, const at::Tensor &input, + const at::Tensor &weight, at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, + int64_t groups, std::array output_mask) { + std::cout << "aten::convolution_backward_overrideable" << std::endl; + std::vector mlirtens_tensors = {grad_output, input, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::convolution_backward_overrideable( + mlirtens[0], mlirtens[1], mlirtens[2], stride, padding, dilation, + transposed, output_padding, groups, output_mask); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(grad_output))); +} + +at::Tensor ATenMLIRTypeDefault::_convolution( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + bool transposed, at::IntArrayRef output_padding, int64_t groups, + bool benchmark, bool deterministic, bool cudnn_enabled) { + std::cout << "aten::_convolution" << std::endl; + std::vector mlirtens_tensors = {input, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_convolution(mlirtens[0], mlirtens[1], mlirtens[2], stride, padding, + dilation, transposed, output_padding, groups, benchmark, + deterministic, cudnn_enabled); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor ATenMLIRTypeDefault::_convolution_nogroup( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + bool transposed, at::IntArrayRef output_padding) { + std::cout << "aten::_convolution_nogroup" << std::endl; + std::vector mlirtens_tensors = {input, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_convolution_nogroup(mlirtens[0], mlirtens[1], mlirtens[2], stride, + padding, dilation, transposed, output_padding); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +std::tuple +ATenMLIRTypeDefault::_convolution_double_backward( + const at::Tensor &ggI, const at::Tensor &ggW, const at::Tensor &ggb, + const at::Tensor &gO, const at::Tensor &weight, const at::Tensor &self, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + bool transposed, at::IntArrayRef output_padding, int64_t groups, + bool 
benchmark, bool deterministic, bool cudnn_enabled, + std::array output_mask) { + std::cout << "aten::_convolution_double_backward" << std::endl; + std::vector mlirtens_tensors = {ggI, ggW, ggb, gO, weight, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_convolution_double_backward( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + mlirtens[5], stride, padding, dilation, transposed, output_padding, + groups, benchmark, deterministic, cudnn_enabled, output_mask); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(ggI)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(ggI)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(ggI))); +} + +at::Tensor +ATenMLIRTypeDefault::conv1d(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, + int64_t groups) { + std::cout << "aten::conv1d" << std::endl; + std::vector mlirtens_tensors = {input, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::conv1d(mlirtens[0], mlirtens[1], mlirtens[2], stride, + padding, dilation, groups); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor +ATenMLIRTypeDefault::conv2d(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, + int64_t groups) { + std::cout << "aten::conv2d" << std::endl; + std::vector mlirtens_tensors = {input, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::conv2d(mlirtens[0], mlirtens[1], mlirtens[2], stride, + padding, dilation, groups); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor +ATenMLIRTypeDefault::conv3d(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, + int64_t groups) { + std::cout << "aten::conv3d" << std::endl; + std::vector mlirtens_tensors = {input, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::conv3d(mlirtens[0], mlirtens[1], mlirtens[2], stride, + padding, dilation, groups); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor ATenMLIRTypeDefault::conv_tbc(const at::Tensor &self, + const at::Tensor &weight, + const at::Tensor &bias, int64_t pad) { + std::cout << "aten::conv_tbc" << std::endl; + std::vector mlirtens_tensors = {self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::conv_tbc(mlirtens[0], mlirtens[1], mlirtens[2], pad); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::conv_tbc_backward(const at::Tensor &self, + const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, int64_t pad) { + std::cout << "aten::conv_tbc_backward" << std::endl; + std::vector mlirtens_tensors = {self, input, weight, 
bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::conv_tbc_backward(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], pad); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor ATenMLIRTypeDefault::conv_transpose1d( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef output_padding, int64_t groups, at::IntArrayRef dilation) { + std::cout << "aten::conv_transpose1d" << std::endl; + std::vector mlirtens_tensors = {input, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::conv_transpose1d(mlirtens[0], mlirtens[1], mlirtens[2], stride, + padding, output_padding, groups, dilation); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor ATenMLIRTypeDefault::conv_transpose2d( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef output_padding, int64_t groups, at::IntArrayRef dilation) { + std::cout << "aten::conv_transpose2d" << std::endl; + std::vector mlirtens_tensors = {input, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::conv_transpose2d(mlirtens[0], mlirtens[1], mlirtens[2], stride, + padding, output_padding, groups, dilation); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor ATenMLIRTypeDefault::conv_transpose3d( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef output_padding, int64_t groups, at::IntArrayRef dilation) { + std::cout << "aten::conv_transpose3d" << std::endl; + std::vector mlirtens_tensors = {input, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::conv_transpose3d(mlirtens[0], mlirtens[1], mlirtens[2], stride, + padding, output_padding, groups, dilation); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor &ATenMLIRTypeDefault::copy_(at::Tensor &self, const at::Tensor &src, + bool non_blocking) { + std::cout << "aten::copy_" << std::endl; + std::vector mlirtens_tensors = {self, src}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].copy_(mlirtens[1], non_blocking); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::_copy_from(const at::Tensor &self, + const at::Tensor &dst, + bool non_blocking) { + std::cout << "aten::_copy_from" << std::endl; + std::vector mlirtens_tensors = {self, dst}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_copy_from(mlirtens[0], mlirtens[1], non_blocking); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor 
ATenMLIRTypeDefault::cos(const at::Tensor &self) {
+  std::cout << "aten::cos" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::cos(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::cos_(at::Tensor &self) {
+  std::cout << "aten::cos_" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::cos_(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return self;
+}
+
+at::Tensor &ATenMLIRTypeDefault::cos_out(at::Tensor &out,
+                                         const at::Tensor &self) {
+  std::cout << "aten::cos_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::cos_out(mlirtens[0], mlirtens[1]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::cosh(const at::Tensor &self) {
+  std::cout << "aten::cosh" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::cosh(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::cosh_(at::Tensor &self) {
+  std::cout << "aten::cosh_" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::cosh_(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return self;
+}
+
+at::Tensor &ATenMLIRTypeDefault::cosh_out(at::Tensor &out,
+                                          const at::Tensor &self) {
+  std::cout << "aten::cosh_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::cosh_out(mlirtens[0], mlirtens[1]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::cosine_embedding_loss(const at::Tensor &input1,
+                                                      const at::Tensor &input2,
+                                                      const at::Tensor &target,
+                                                      double margin,
+                                                      int64_t reduction) {
+  std::cout << "aten::cosine_embedding_loss" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input1, input2, target};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::cosine_embedding_loss(mlirtens[0], mlirtens[1],
+                                              mlirtens[2], margin, reduction);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input1));
+}
+
+at::Tensor ATenMLIRTypeDefault::cumsum(const at::Tensor &self, int64_t dim,
+                                       c10::optional<at::ScalarType> dtype) {
+  std::cout << "aten::cumsum" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::cumsum(mlirtens[0], dim, dtype);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &
+ATenMLIRTypeDefault::cumsum_out(at::Tensor &out, const at::Tensor &self,
+                                int64_t dim,
+                                c10::optional<at::ScalarType> dtype) {
+  std::cout << "aten::cumsum_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+
auto &&x_result = at::cumsum_out(mlirtens[0], mlirtens[1], dim, dtype); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::cumprod(const at::Tensor &self, int64_t dim, + c10::optional dtype) { + std::cout << "aten::cumprod" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::cumprod(mlirtens[0], dim, dtype); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor & +ATenMLIRTypeDefault::cumprod_out(at::Tensor &out, const at::Tensor &self, + int64_t dim, + c10::optional dtype) { + std::cout << "aten::cumprod_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::cumprod_out(mlirtens[0], mlirtens[1], dim, dtype); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::ctc_loss(const at::Tensor &log_probs, + const at::Tensor &targets, + at::IntArrayRef input_lengths, + at::IntArrayRef target_lengths, + int64_t blank, int64_t reduction, + bool zero_infinity) { + std::cout << "aten::ctc_loss" << std::endl; + std::vector mlirtens_tensors = {log_probs, targets}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::ctc_loss(mlirtens[0], mlirtens[1], input_lengths, target_lengths, + blank, reduction, zero_infinity); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(log_probs)); +} + +at::Tensor ATenMLIRTypeDefault::ctc_loss(const at::Tensor &log_probs, + const at::Tensor &targets, + const at::Tensor &input_lengths, + const at::Tensor &target_lengths, + int64_t blank, int64_t reduction, + bool zero_infinity) { + std::cout << "aten::ctc_loss" << std::endl; + std::vector mlirtens_tensors = {log_probs, targets, input_lengths, + target_lengths}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::ctc_loss(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], blank, reduction, zero_infinity); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(log_probs)); +} + +std::tuple ATenMLIRTypeDefault::_ctc_loss( + const at::Tensor &log_probs, const at::Tensor &targets, + at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, + int64_t blank, bool zero_infinity) { + std::cout << "aten::_ctc_loss" << std::endl; + std::vector mlirtens_tensors = {log_probs, targets}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_ctc_loss(mlirtens[0], mlirtens[1], input_lengths, + target_lengths, blank, zero_infinity); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(log_probs)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(log_probs))); +} + +at::Tensor ATenMLIRTypeDefault::_ctc_loss_backward( + const at::Tensor &grad, const at::Tensor &log_probs, + const at::Tensor &targets, at::IntArrayRef input_lengths, + at::IntArrayRef target_lengths, const at::Tensor &neg_log_likelihood, + const at::Tensor &log_alpha, int64_t blank, bool zero_infinity) { + std::cout << "aten::_ctc_loss_backward" << std::endl; + std::vector mlirtens_tensors 
= {grad, log_probs, targets, + neg_log_likelihood, log_alpha}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_ctc_loss_backward( + mlirtens[0], mlirtens[1], mlirtens[2], input_lengths, target_lengths, + mlirtens[3], mlirtens[4], blank, zero_infinity); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad)); +} + +at::Tensor ATenMLIRTypeDefault::det(const at::Tensor &self) { + std::cout << "aten::det" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::det(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::diag_embed(const at::Tensor &self, + int64_t offset, int64_t dim1, + int64_t dim2) { + std::cout << "aten::diag_embed" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::diag_embed(mlirtens[0], offset, dim1, dim2); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::diagflat(const at::Tensor &self, + int64_t offset) { + std::cout << "aten::diagflat" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::diagflat(mlirtens[0], offset); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::diagonal(const at::Tensor &self, int64_t offset, + int64_t dim1, int64_t dim2) { + std::cout << "aten::diagonal" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::diagonal(mlirtens[0], offset, dim1, dim2); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::fill_diagonal_(at::Tensor &self, + at::Scalar fill_value, + bool wrap) { + std::cout << "aten::fill_diagonal_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].fill_diagonal_(fill_value, wrap); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::div(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::div" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::div(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::div_(at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::div_" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].div_(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::div_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &other) { + 
std::cout << "aten::div_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::div_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::div(const at::Tensor &self, at::Scalar other) { + std::cout << "aten::div" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::div(mlirtens[0], other); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::div_(at::Tensor &self, at::Scalar other) { + std::cout << "aten::div_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].div_(other); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::dot(const at::Tensor &self, + const at::Tensor &tensor) { + std::cout << "aten::dot" << std::endl; + std::vector mlirtens_tensors = {self, tensor}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::dot(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::dot_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &tensor) { + std::cout << "aten::dot_out" << std::endl; + std::vector mlirtens_tensors = {out, self, tensor}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::dot_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::einsum(std::string equation, + at::TensorList tensors) { + std::cout << "aten::einsum" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::einsum(equation, tensors); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(tensors)); +} + +at::Tensor ATenMLIRTypeDefault::embedding(const at::Tensor &weight, + const at::Tensor &indices, + int64_t padding_idx, + bool scale_grad_by_freq, + bool sparse) { + std::cout << "aten::embedding" << std::endl; + std::vector mlirtens_tensors = {weight, indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::embedding(mlirtens[0], mlirtens[1], padding_idx, + scale_grad_by_freq, sparse); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(weight)); +} + +at::Tensor ATenMLIRTypeDefault::embedding_backward( + const at::Tensor &grad, const at::Tensor &indices, int64_t num_weights, + int64_t padding_idx, bool scale_grad_by_freq, bool sparse) { + std::cout << "aten::embedding_backward" << std::endl; + std::vector mlirtens_tensors = {grad, indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::embedding_backward(mlirtens[0], mlirtens[1], num_weights, padding_idx, + scale_grad_by_freq, sparse); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, 
bridge::GetMLIRDevice(grad)); +} + +at::Tensor ATenMLIRTypeDefault::embedding_dense_backward( + const at::Tensor &grad_output, const at::Tensor &indices, + int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { + std::cout << "aten::embedding_dense_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::embedding_dense_backward( + mlirtens[0], mlirtens[1], num_weights, padding_idx, scale_grad_by_freq); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::embedding_renorm_(at::Tensor &self, + const at::Tensor &indices, + double max_norm, + double norm_type) { + std::cout << "aten::embedding_renorm_" << std::endl; + std::vector mlirtens_tensors = {self, indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::embedding_renorm_(mlirtens[0], mlirtens[1], max_norm, norm_type); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::embedding_sparse_backward( + const at::Tensor &grad, const at::Tensor &indices, int64_t num_weights, + int64_t padding_idx, bool scale_grad_by_freq) { + std::cout << "aten::embedding_sparse_backward" << std::endl; + std::vector mlirtens_tensors = {grad, indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::embedding_sparse_backward( + mlirtens[0], mlirtens[1], num_weights, padding_idx, scale_grad_by_freq); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad)); +} + +std::tuple +ATenMLIRTypeDefault::embedding_bag(const at::Tensor &weight, + const at::Tensor &indices, + const at::Tensor &offsets, + bool scale_grad_by_freq, int64_t mode, + bool sparse, + const at::Tensor &per_sample_weights) { + std::cout << "aten::embedding_bag" << std::endl; + std::vector mlirtens_tensors = {weight, indices, offsets, + per_sample_weights}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::embedding_bag(mlirtens[0], mlirtens[1], mlirtens[2], + scale_grad_by_freq, mode, sparse, mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(weight)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(weight)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(weight)), + bridge::CreateMLIRTensor(std::get<3>(x_result), + bridge::GetMLIRDevice(weight))); +} + +std::tuple +ATenMLIRTypeDefault::_embedding_bag(const at::Tensor &weight, + const at::Tensor &indices, + const at::Tensor &offsets, + bool scale_grad_by_freq, int64_t mode, + bool sparse, + const at::Tensor &per_sample_weights) { + std::cout << "aten::_embedding_bag" << std::endl; + std::vector mlirtens_tensors = {weight, indices, offsets, + per_sample_weights}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_embedding_bag(mlirtens[0], mlirtens[1], mlirtens[2], + scale_grad_by_freq, mode, sparse, mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(weight)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + 
bridge::GetMLIRDevice(weight)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(weight)), + bridge::CreateMLIRTensor(std::get<3>(x_result), + bridge::GetMLIRDevice(weight))); +} + +at::Tensor ATenMLIRTypeDefault::_embedding_bag_backward( + const at::Tensor &grad, const at::Tensor &indices, + const at::Tensor &offsets, const at::Tensor &offset2bag, + const at::Tensor &bag_size, const at::Tensor &maximum_indices, + int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, + const at::Tensor &per_sample_weights) { + std::cout << "aten::_embedding_bag_backward" << std::endl; + std::vector mlirtens_tensors = { + grad, indices, offsets, offset2bag, + bag_size, maximum_indices, per_sample_weights}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_embedding_bag_backward( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + mlirtens[5], num_weights, scale_grad_by_freq, mode, sparse, mlirtens[6]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad)); +} + +at::Tensor ATenMLIRTypeDefault::_embedding_bag_sparse_backward( + const at::Tensor &grad, const at::Tensor &indices, + const at::Tensor &offsets, const at::Tensor &offset2bag, + const at::Tensor &bag_size, int64_t num_weights, bool scale_grad_by_freq, + int64_t mode, const at::Tensor &per_sample_weights) { + std::cout << "aten::_embedding_bag_sparse_backward" << std::endl; + std::vector mlirtens_tensors = { + grad, indices, offsets, offset2bag, bag_size, per_sample_weights}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_embedding_bag_sparse_backward( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + num_weights, scale_grad_by_freq, mode, mlirtens[5]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad)); +} + +at::Tensor ATenMLIRTypeDefault::_embedding_bag_dense_backward( + const at::Tensor &grad, const at::Tensor &indices, + const at::Tensor &offsets, const at::Tensor &offset2bag, + const at::Tensor &bag_size, const at::Tensor &maximum_indices, + int64_t num_weights, bool scale_grad_by_freq, int64_t mode, + const at::Tensor &per_sample_weights) { + std::cout << "aten::_embedding_bag_dense_backward" << std::endl; + std::vector mlirtens_tensors = { + grad, indices, offsets, offset2bag, + bag_size, maximum_indices, per_sample_weights}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_embedding_bag_dense_backward( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + mlirtens[5], num_weights, scale_grad_by_freq, mode, mlirtens[6]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad)); +} + +at::Tensor ATenMLIRTypeDefault::_embedding_bag_per_sample_weights_backward( + const at::Tensor &grad, const at::Tensor &weight, const at::Tensor &indices, + const at::Tensor &offsets, const at::Tensor &offset2bag, int64_t mode) { + std::cout << "aten::_embedding_bag_per_sample_weights_backward" << std::endl; + std::vector mlirtens_tensors = {grad, weight, indices, offsets, + offset2bag}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_embedding_bag_per_sample_weights_backward( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], mode); + 
static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad)); +} + +at::Tensor +ATenMLIRTypeDefault::empty(at::IntArrayRef size, + const at::TensorOptions &options, + c10::optional memory_format) { + at::TensorOptions o_options = options.device(at::DeviceType::CPU); + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::empty(size, o_options, memory_format); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::new_empty(const at::Tensor &self, + at::IntArrayRef size, + const at::TensorOptions &options) { + std::cout << "aten::new_empty" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].new_empty(size, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::new_full(const at::Tensor &self, + at::IntArrayRef size, + at::Scalar fill_value, + const at::TensorOptions &options) { + std::cout << "aten::new_full" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].new_full(size, fill_value, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_empty_affine_quantized( + at::IntArrayRef size, const at::TensorOptions &options, double scale, + int64_t zero_point, c10::optional memory_format) { + std::cout << "aten::_empty_affine_quantized" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_empty_affine_quantized(size, options, scale, + zero_point, memory_format); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::_empty_per_channel_affine_quantized_like( + const at::Tensor &self, const at::Tensor &zero_points, at::IntArrayRef size, + at::IntArrayRef axis, const at::TensorOptions &options, + c10::optional memory_format) { + std::cout << "aten::_empty_per_channel_affine_quantized_like" << std::endl; + std::vector mlirtens_tensors = {self, zero_points}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_empty_per_channel_affine_quantized_like( + mlirtens[0], mlirtens[1], size, axis, options, memory_format); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::resize_(at::Tensor &self, + at::IntArrayRef size) { + std::cout << "aten::resize_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].resize_(size); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor & +ATenMLIRTypeDefault::empty_out(at::Tensor &out, at::IntArrayRef size, + c10::optional memory_format) { + std::cout << "aten::empty_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = 
bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::empty_out(mlirtens[0], size, memory_format); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::empty_like(const at::Tensor &self) { + std::cout << "aten::empty_like" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::empty_like(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor +ATenMLIRTypeDefault::empty_like(const at::Tensor &self, + const at::TensorOptions &options, + c10::optional memory_format) { + std::cout << "aten::empty_like" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::empty_like(mlirtens[0], options, memory_format); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor +ATenMLIRTypeDefault::empty_strided(at::IntArrayRef size, at::IntArrayRef stride, + const at::TensorOptions &options) { + std::cout << "aten::empty_strided" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::empty_strided(size, stride, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::erf(const at::Tensor &self) { + std::cout << "aten::erf" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::erf(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::erf_(at::Tensor &self) { + std::cout << "aten::erf_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::erf_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::erf_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::erf_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::erf_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::erfc(const at::Tensor &self) { + std::cout << "aten::erfc" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::erfc(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::erfc_(at::Tensor &self) { + std::cout << "aten::erfc_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::erfc_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::erfc_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << 
"aten::erfc_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::erfc_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::exp(const at::Tensor &self) { + std::cout << "aten::exp" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::exp(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::exp_(at::Tensor &self) { + std::cout << "aten::exp_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::exp_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::exp_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::exp_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::exp_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::expm1(const at::Tensor &self) { + std::cout << "aten::expm1" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::expm1(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::expm1_(at::Tensor &self) { + std::cout << "aten::expm1_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::expm1_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::expm1_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::expm1_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::expm1_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::expand(const at::Tensor &self, + at::IntArrayRef size, bool implicit) { + std::cout << "aten::expand" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].expand(size, implicit); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::expand_as(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::expand_as" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].expand_as(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::eye(int64_t n, + const at::TensorOptions &options) { + std::cout << "aten::eye" << std::endl; + std::vector 
mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::eye(n, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::eye(int64_t n, int64_t m, + const at::TensorOptions &options) { + std::cout << "aten::eye" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::eye(n, m, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor &ATenMLIRTypeDefault::eye_out(at::Tensor &out, int64_t n) { + std::cout << "aten::eye_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::eye_out(mlirtens[0], n); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor &ATenMLIRTypeDefault::eye_out(at::Tensor &out, int64_t n, + int64_t m) { + std::cout << "aten::eye_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::eye_out(mlirtens[0], n, m); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::flatten(const at::Tensor &self, + int64_t start_dim, int64_t end_dim) { + std::cout << "aten::flatten" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::flatten(mlirtens[0], start_dim, end_dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::fill_(at::Tensor &self, at::Scalar value) { + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::fill_(mlirtens[0], value); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::fill_(at::Tensor &self, + const at::Tensor &value) { + std::vector mlirtens_tensors = {self, value}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::fill_(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::floor(const at::Tensor &self) { + std::cout << "aten::floor" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::floor(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::floor_(at::Tensor &self) { + std::cout << "aten::floor_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::floor_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::floor_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::floor_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = 
at::floor_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::frac(const at::Tensor &self) { + std::cout << "aten::frac" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::frac(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::frac_(at::Tensor &self) { + std::cout << "aten::frac_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::frac_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::frac_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::frac_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::frac_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::full(at::IntArrayRef size, + at::Scalar fill_value, + const at::TensorOptions &options) { + std::cout << "aten::full" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::full(size, fill_value, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor &ATenMLIRTypeDefault::full_out(at::Tensor &out, at::IntArrayRef size, + at::Scalar fill_value) { + std::cout << "aten::full_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::full_out(mlirtens[0], size, fill_value); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::full_like(const at::Tensor &self, + at::Scalar fill_value) { + std::cout << "aten::full_like" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::full_like(mlirtens[0], fill_value); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::full_like(const at::Tensor &self, + at::Scalar fill_value, + const at::TensorOptions &options) { + std::cout << "aten::full_like" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::full_like(mlirtens[0], fill_value, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::from_file(std::string filename, + c10::optional shared, + c10::optional size, + const at::TensorOptions &options) { + std::cout << "aten::from_file" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::from_file(filename, shared, size, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + 
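Every stub in this generated file follows the same fallback shape: log the ATen op name, materialize the operands through the bridge helpers, call the stock ATen implementation, and re-wrap the result for the originating MLIR device. As a rough illustrative sketch only (it is not a line of this patch; fallback_unary_op is a hypothetical name, and it assumes the same ATen headers and bridge helpers used above), a hand-written unary stub of this shape would be:

// Hypothetical illustration of the shared fallback pattern used by the
// generated stubs above; not part of the patch itself.
static at::Tensor fallback_unary_op(const at::Tensor &self,
                                    at::Tensor (*aten_op)(const at::Tensor &),
                                    const char *name) {
  std::cout << name << std::endl;                        // trace the fallback
  std::vector<at::Tensor> tensors = {self};
  auto mlirtens = bridge::MLIRCreateTensorList(tensors);  // materialize operand
  auto &&result = aten_op(mlirtens[0]);                   // run stock ATen kernel
  // Re-wrap the eager result as a tensor on the original MLIR device.
  return bridge::CreateMLIRTensor(result, bridge::GetMLIRDevice(self));
}

Called as, say, fallback_unary_op(self, at::floor, "aten::floor"), this would behave like the generated floor stub above; the generated code simply spells this pattern out per op so that individual stubs can later be replaced with real lowerings.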
+at::Tensor ATenMLIRTypeDefault::grid_sampler(const at::Tensor &input, + const at::Tensor &grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners) { + std::cout << "aten::grid_sampler" << std::endl; + std::vector mlirtens_tensors = {input, grid}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::grid_sampler(mlirtens[0], mlirtens[1], interpolation_mode, + padding_mode, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor ATenMLIRTypeDefault::grid_sampler_2d(const at::Tensor &input, + const at::Tensor &grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners) { + std::cout << "aten::grid_sampler_2d" << std::endl; + std::vector mlirtens_tensors = {input, grid}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::grid_sampler_2d(mlirtens[0], mlirtens[1], interpolation_mode, + padding_mode, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +std::tuple +ATenMLIRTypeDefault::grid_sampler_2d_backward(const at::Tensor &grad_output, + const at::Tensor &input, + const at::Tensor &grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners) { + std::cout << "aten::grid_sampler_2d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, input, grid}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::grid_sampler_2d_backward( + mlirtens[0], mlirtens[1], mlirtens[2], interpolation_mode, padding_mode, + align_corners); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(grad_output))); +} + +at::Tensor ATenMLIRTypeDefault::grid_sampler_3d(const at::Tensor &input, + const at::Tensor &grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners) { + std::cout << "aten::grid_sampler_3d" << std::endl; + std::vector mlirtens_tensors = {input, grid}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::grid_sampler_3d(mlirtens[0], mlirtens[1], interpolation_mode, + padding_mode, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +std::tuple +ATenMLIRTypeDefault::grid_sampler_3d_backward(const at::Tensor &grad_output, + const at::Tensor &input, + const at::Tensor &grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners) { + std::cout << "aten::grid_sampler_3d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, input, grid}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::grid_sampler_3d_backward( + mlirtens[0], mlirtens[1], mlirtens[2], interpolation_mode, padding_mode, + align_corners); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(grad_output))); +} + +at::Tensor ATenMLIRTypeDefault::hann_window(int64_t window_length, + const at::TensorOptions &options) { + 
std::cout << "aten::hann_window" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::hann_window(window_length, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::hann_window(int64_t window_length, + bool periodic, + const at::TensorOptions &options) { + std::cout << "aten::hann_window" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::hann_window(window_length, periodic, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor +ATenMLIRTypeDefault::hamming_window(int64_t window_length, + const at::TensorOptions &options) { + std::cout << "aten::hamming_window" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::hamming_window(window_length, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor +ATenMLIRTypeDefault::hamming_window(int64_t window_length, bool periodic, + const at::TensorOptions &options) { + std::cout << "aten::hamming_window" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::hamming_window(window_length, periodic, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor +ATenMLIRTypeDefault::hamming_window(int64_t window_length, bool periodic, + double alpha, + const at::TensorOptions &options) { + std::cout << "aten::hamming_window" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::hamming_window(window_length, periodic, alpha, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor +ATenMLIRTypeDefault::hamming_window(int64_t window_length, bool periodic, + double alpha, double beta, + const at::TensorOptions &options) { + std::cout << "aten::hamming_window" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::hamming_window(window_length, periodic, alpha, beta, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::hinge_embedding_loss(const at::Tensor &self, + const at::Tensor &target, + double margin, + int64_t reduction) { + std::cout << "aten::hinge_embedding_loss" << std::endl; + std::vector mlirtens_tensors = {self, target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::hinge_embedding_loss(mlirtens[0], mlirtens[1], margin, reduction); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::ger(const at::Tensor &self, + const at::Tensor &vec2) { + std::cout << "aten::ger" << std::endl; + 
std::vector mlirtens_tensors = {self, vec2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::ger(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::ger_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &vec2) { + std::cout << "aten::ger_out" << std::endl; + std::vector mlirtens_tensors = {out, self, vec2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::ger_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::group_norm(const at::Tensor &input, + int64_t num_groups, + const at::Tensor &weight, + const at::Tensor &bias, double eps, + bool cudnn_enabled) { + std::cout << "aten::group_norm" << std::endl; + std::vector mlirtens_tensors = {input, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::group_norm(mlirtens[0], num_groups, mlirtens[1], + mlirtens[2], eps, cudnn_enabled); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor ATenMLIRTypeDefault::fft(const at::Tensor &self, int64_t signal_ndim, + bool normalized) { + std::cout << "aten::fft" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::fft(mlirtens[0], signal_ndim, normalized); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::ifft(const at::Tensor &self, + int64_t signal_ndim, bool normalized) { + std::cout << "aten::ifft" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::ifft(mlirtens[0], signal_ndim, normalized); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::rfft(const at::Tensor &self, + int64_t signal_ndim, bool normalized, + bool onesided) { + std::cout << "aten::rfft" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::rfft(mlirtens[0], signal_ndim, normalized, onesided); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::irfft(const at::Tensor &self, + int64_t signal_ndim, bool normalized, + bool onesided, + at::IntArrayRef signal_sizes) { + std::cout << "aten::irfft" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::irfft(mlirtens[0], signal_ndim, normalized, onesided, signal_sizes); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_fft_with_size( + const at::Tensor &self, int64_t signal_ndim, bool complex_input, + bool complex_output, bool inverse, at::IntArrayRef checked_signal_sizes, + bool normalized, bool onesided, at::IntArrayRef output_sizes) { + std::cout << 
"aten::_fft_with_size" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_fft_with_size( + mlirtens[0], signal_ndim, complex_input, complex_output, inverse, + checked_signal_sizes, normalized, onesided, output_sizes); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +int64_t ATenMLIRTypeDefault::_cufft_get_plan_cache_size(int64_t device_index) { + std::cout << "aten::_cufft_get_plan_cache_size" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cufft_get_plan_cache_size(device_index); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +int64_t +ATenMLIRTypeDefault::_cufft_get_plan_cache_max_size(int64_t device_index) { + std::cout << "aten::_cufft_get_plan_cache_max_size" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cufft_get_plan_cache_max_size(device_index); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +void ATenMLIRTypeDefault::_cufft_set_plan_cache_max_size(int64_t device_index, + int64_t max_size) { + std::cout << "aten::_cufft_set_plan_cache_max_size" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + at::_cufft_set_plan_cache_max_size(device_index, max_size); +} + +void ATenMLIRTypeDefault::_cufft_clear_plan_cache(int64_t device_index) { + std::cout << "aten::_cufft_clear_plan_cache" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + at::_cufft_clear_plan_cache(device_index); +} + +at::Tensor ATenMLIRTypeDefault::index(const at::Tensor &self, + at::TensorList indices) { + std::cout << "aten::index" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::index(mlirtens[0], indices); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::index_copy_(at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &source) { + std::cout << "aten::index_copy_" << std::endl; + std::vector mlirtens_tensors = {self, index, source}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].index_copy_(dim, mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::index_copy(const at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &source) { + std::cout << "aten::index_copy" << std::endl; + std::vector mlirtens_tensors = {self, index, source}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::index_copy(mlirtens[0], dim, mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::index_put_(at::Tensor &self, + at::TensorList indices, + const at::Tensor &values, + bool accumulate) { + std::cout << "aten::index_put_" << std::endl; + std::vector mlirtens_tensors = {self, values}; + auto 
mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::index_put_(mlirtens[0], indices, mlirtens[1], accumulate); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::index_put(const at::Tensor &self, + at::TensorList indices, + const at::Tensor &values, + bool accumulate) { + std::cout << "aten::index_put" << std::endl; + std::vector mlirtens_tensors = {self, values}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::index_put(mlirtens[0], indices, mlirtens[1], accumulate); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::_index_put_impl_(at::Tensor &self, + at::TensorList indices, + const at::Tensor &values, + bool accumulate, + bool unsafe) { + std::cout << "aten::_index_put_impl_" << std::endl; + std::vector mlirtens_tensors = {self, values}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_index_put_impl_(mlirtens[0], indices, mlirtens[1], + accumulate, unsafe); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::instance_norm( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, + const at::Tensor &running_mean, const at::Tensor &running_var, + bool use_input_stats, double momentum, double eps, bool cudnn_enabled) { + std::cout << "aten::instance_norm" << std::endl; + std::vector mlirtens_tensors = {input, weight, bias, running_mean, + running_var}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::instance_norm(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], mlirtens[4], use_input_stats, + momentum, eps, cudnn_enabled); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor ATenMLIRTypeDefault::inverse(const at::Tensor &self) { + std::cout << "aten::inverse" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::inverse(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::inverse_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::inverse_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::inverse_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::_inverse_helper(const at::Tensor &self) { + std::cout << "aten::_inverse_helper" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_inverse_helper(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::isclose(const at::Tensor &self, + const at::Tensor &other, double rtol, + double atol, bool equal_nan) { + std::cout << "aten::isclose" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto 
&&x_result =
+      at::isclose(mlirtens[0], mlirtens[1], rtol, atol, equal_nan);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor ATenMLIRTypeDefault::isnan(const at::Tensor &self) {
+  std::cout << "aten::isnan" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::isnan(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+bool ATenMLIRTypeDefault::is_distributed(const at::Tensor &self) {
+  std::cout << "aten::is_distributed" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::is_distributed(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return x_result;
+}
+
+bool ATenMLIRTypeDefault::is_floating_point(const at::Tensor &self) {
+  std::cout << "aten::is_floating_point" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::is_floating_point(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return x_result;
+}
+
+bool ATenMLIRTypeDefault::is_complex(const at::Tensor &self) {
+  std::cout << "aten::is_complex" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::is_complex(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return x_result;
+}
+
+bool ATenMLIRTypeDefault::is_nonzero(const at::Tensor &self) {
+  std::cout << "aten::is_nonzero" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::is_nonzero(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return x_result;
+}
+
+bool ATenMLIRTypeDefault::is_same_size(const at::Tensor &self,
+                                       const at::Tensor &other) {
+  std::cout << "aten::is_same_size" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self, other};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::is_same_size(mlirtens[0], mlirtens[1]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return x_result;
+}
+
+bool ATenMLIRTypeDefault::is_signed(const at::Tensor &self) {
+  std::cout << "aten::is_signed" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::is_signed(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return x_result;
+}
+
+at::Tensor ATenMLIRTypeDefault::kl_div(const at::Tensor &self,
+                                       const at::Tensor &target,
+                                       int64_t reduction) {
+  std::cout << "aten::kl_div" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self, target};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::kl_div(mlirtens[0], mlirtens[1], reduction);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor ATenMLIRTypeDefault::kl_div_backward(const at::Tensor &grad_output,
+                                                const at::Tensor &self,
+                                                const at::Tensor &target,
+                                                int64_t reduction) {
+  std::cout << "aten::kl_div_backward" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {grad_output, self,
+                                              target};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result =
+      at::kl_div_backward(mlirtens[0], mlirtens[1], mlirtens[2], reduction);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output));
+}
+
+std::tuple<at::Tensor, at::Tensor>
+ATenMLIRTypeDefault::kthvalue(const at::Tensor &self, int64_t k, int64_t dim,
+                              bool keepdim) {
+  std::cout << "aten::kthvalue" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::kthvalue(mlirtens[0], k, dim, keepdim);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return std::tuple<at::Tensor, at::Tensor>(
+      bridge::CreateMLIRTensor(std::get<0>(x_result),
+                               bridge::GetMLIRDevice(self)),
+      bridge::CreateMLIRTensor(std::get<1>(x_result),
+                               bridge::GetMLIRDevice(self)));
+}
+
+std::tuple<at::Tensor &, at::Tensor &>
+ATenMLIRTypeDefault::kthvalue_out(at::Tensor &values, at::Tensor &indices,
+                                  const at::Tensor &self, int64_t k,
+                                  int64_t dim, bool keepdim) {
+  std::cout << "aten::kthvalue_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {values, indices, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result =
+      at::kthvalue_out(mlirtens[0], mlirtens[1], mlirtens[2], k, dim, keepdim);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return std::tuple<at::Tensor &, at::Tensor &>(values, indices);
+}
+
+at::Tensor ATenMLIRTypeDefault::layer_norm(const at::Tensor &input,
+                                           at::IntArrayRef normalized_shape,
+                                           const at::Tensor &weight,
+                                           const at::Tensor &bias, double eps,
+                                           bool cudnn_enable) {
+  std::cout << "aten::layer_norm" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input, weight, bias};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::layer_norm(mlirtens[0], normalized_shape, mlirtens[1],
+                                   mlirtens[2], eps, cudnn_enable);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input));
+}
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor>
+ATenMLIRTypeDefault::native_layer_norm(const at::Tensor &input,
+                                       const at::Tensor &weight,
+                                       const at::Tensor &bias, int64_t M,
+                                       int64_t N, double eps) {
+  std::cout << "aten::native_layer_norm" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input, weight, bias};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result =
+      at::native_layer_norm(mlirtens[0], mlirtens[1], mlirtens[2], M, N, eps);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return std::tuple<at::Tensor, at::Tensor, at::Tensor>(
+      bridge::CreateMLIRTensor(std::get<0>(x_result),
+                               bridge::GetMLIRDevice(input)),
+      bridge::CreateMLIRTensor(std::get<1>(x_result),
+                               bridge::GetMLIRDevice(input)),
+      bridge::CreateMLIRTensor(std::get<2>(x_result),
+                               bridge::GetMLIRDevice(input)));
+}
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor>
+ATenMLIRTypeDefault::native_layer_norm_backward(
+    const at::Tensor &grad_out, const at::Tensor &input, const at::Tensor &mean,
+    const at::Tensor &rstd, const at::Tensor &weight, int64_t M, int64_t N,
+    std::array<bool, 3> output_mask) {
+  std::cout << "aten::native_layer_norm_backward" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {grad_out, input, mean, rstd,
+                                              weight};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::native_layer_norm_backward(
+      mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], M, N,
+      output_mask);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return std::tuple<at::Tensor, at::Tensor, at::Tensor>(
+      bridge::CreateMLIRTensor(std::get<0>(x_result),
+                               bridge::GetMLIRDevice(grad_out)),
+      bridge::CreateMLIRTensor(std::get<1>(x_result),
+                               bridge::GetMLIRDevice(grad_out)),
+      bridge::CreateMLIRTensor(std::get<2>(x_result),
+                               bridge::GetMLIRDevice(grad_out)));
+}
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor>
+ATenMLIRTypeDefault::native_layer_norm_double_backward(
+    const at::Tensor &ggI, const at::Tensor &ggW, const at::Tensor &ggb,
+    const at::Tensor &gO, const at::Tensor &input, const at::Tensor &mean,
+    const at::Tensor &rstd, const at::Tensor &weight, int64_t M, int64_t N,
+    std::array<bool, 3> output_mask) {
+  std::cout << "aten::native_layer_norm_double_backward" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {ggI, ggW, ggb, gO,
+                                              input, mean, rstd, weight};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::native_layer_norm_double_backward(
+      mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4],
+      mlirtens[5], mlirtens[6], mlirtens[7], M, N, output_mask);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return std::tuple<at::Tensor, at::Tensor, at::Tensor>(
+      bridge::CreateMLIRTensor(std::get<0>(x_result),
+                               bridge::GetMLIRDevice(ggI)),
+      bridge::CreateMLIRTensor(std::get<1>(x_result),
+                               bridge::GetMLIRDevice(ggI)),
+      bridge::CreateMLIRTensor(std::get<2>(x_result),
+                               bridge::GetMLIRDevice(ggI)));
+}
+
+at::Tensor ATenMLIRTypeDefault::linear(const at::Tensor &input,
+                                       const at::Tensor &weight,
+                                       const at::Tensor &bias) {
+  std::cout << "aten::linear" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input, weight, bias};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::linear(mlirtens[0], mlirtens[1], mlirtens[2]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input));
+}
+
+at::Tensor ATenMLIRTypeDefault::mkldnn_linear(const at::Tensor &input,
+                                              const at::Tensor &weight,
+                                              const at::Tensor &bias) {
+  std::cout << "aten::mkldnn_linear" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input, weight, bias};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::mkldnn_linear(mlirtens[0], mlirtens[1], mlirtens[2]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input));
+}
+
+at::Tensor ATenMLIRTypeDefault::fbgemm_linear_int8_weight_fp32_activation(
+    const at::Tensor &input, const at::Tensor &weight, const at::Tensor &packed,
+    const at::Tensor &col_offsets, at::Scalar weight_scale,
+    at::Scalar weight_zero_point, const at::Tensor &bias) {
+  std::cout << "aten::fbgemm_linear_int8_weight_fp32_activation" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input, weight, packed,
+                                              col_offsets, bias};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::fbgemm_linear_int8_weight_fp32_activation(
+      mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], weight_scale,
+      weight_zero_point, mlirtens[4]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input));
+}
+
+at::Tensor ATenMLIRTypeDefault::fbgemm_linear_int8_weight(
+    const at::Tensor &input, const at::Tensor &weight, const at::Tensor &packed,
+    const at::Tensor &col_offsets, at::Scalar weight_scale,
+    at::Scalar weight_zero_point, const at::Tensor &bias) {
+  std::cout << "aten::fbgemm_linear_int8_weight" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input, weight, packed,
+                                              col_offsets, bias};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::fbgemm_linear_int8_weight(
+      mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], weight_scale,
+      weight_zero_point, mlirtens[4]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input));
+}
+
+std::tuple<at::Tensor, at::Tensor, double, int64_t>
+ATenMLIRTypeDefault::fbgemm_linear_quantize_weight(const at::Tensor &input) {
+  std::cout << "aten::fbgemm_linear_quantize_weight" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::fbgemm_linear_quantize_weight(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return std::tuple<at::Tensor, at::Tensor, double, int64_t>(
+      bridge::CreateMLIRTensor(std::get<0>(x_result),
+                               bridge::GetMLIRDevice(input)),
+      bridge::CreateMLIRTensor(std::get<1>(x_result),
+                               bridge::GetMLIRDevice(input)),
+      std::get<2>(x_result), std::get<3>(x_result));
+}
+
+at::Tensor
+ATenMLIRTypeDefault::fbgemm_pack_gemm_matrix_fp16(const at::Tensor &input) {
+  std::cout << "aten::fbgemm_pack_gemm_matrix_fp16" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::fbgemm_pack_gemm_matrix_fp16(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input));
+}
+
+at::Tensor ATenMLIRTypeDefault::fbgemm_linear_fp16_weight_fp32_activation(
+    const at::Tensor &input, const at::Tensor &packed_weight,
+    const at::Tensor &bias) {
+  std::cout << "aten::fbgemm_linear_fp16_weight_fp32_activation" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input, packed_weight, bias};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::fbgemm_linear_fp16_weight_fp32_activation(
+      mlirtens[0], mlirtens[1], mlirtens[2]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input));
+}
+
+at::Tensor
+ATenMLIRTypeDefault::fbgemm_linear_fp16_weight(const at::Tensor &input,
+                                               const at::Tensor &packed_weight,
+                                               const at::Tensor &bias) {
+  std::cout << "aten::fbgemm_linear_fp16_weight" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input, packed_weight, bias};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result =
+      at::fbgemm_linear_fp16_weight(mlirtens[0], mlirtens[1], mlirtens[2]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input));
+}
+
+at::Tensor
+ATenMLIRTypeDefault::fbgemm_pack_quantized_matrix(const at::Tensor &input) {
+  std::cout << "aten::fbgemm_pack_quantized_matrix" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::fbgemm_pack_quantized_matrix(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input));
+}
+
+at::Tensor
+ATenMLIRTypeDefault::fbgemm_pack_quantized_matrix(const at::Tensor &input,
+                                                  int64_t K, int64_t N) {
+  std::cout << "aten::fbgemm_pack_quantized_matrix" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::fbgemm_pack_quantized_matrix(mlirtens[0], K, N);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input));
+}
+
+at::Tensor ATenMLIRTypeDefault::linspace(at::Scalar start, at::Scalar end,
+                                         int64_t steps,
+                                         const at::TensorOptions &options) {
+  std::cout << "aten::linspace" << std::endl;
+  at::TensorOptions o_options = options.device(at::DeviceType::CPU);
+  std::vector<at::Tensor> mlirtens_tensors = {};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::linspace(start, end, steps, o_options);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options));
+}
+
+at::Tensor &ATenMLIRTypeDefault::linspace_out(at::Tensor &out, at::Scalar start,
+                                              at::Scalar end, int64_t steps) {
+  std::cout << "aten::linspace_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::linspace_out(mlirtens[0], start, end, steps);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::log(const at::Tensor &self) {
+  std::cout << "aten::log" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::log(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::log_(at::Tensor &self) {
+  std::cout << "aten::log_" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::log_(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return self;
+}
+
+at::Tensor &ATenMLIRTypeDefault::log_out(at::Tensor &out,
+                                         const at::Tensor &self) {
+  std::cout << "aten::log_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::log_out(mlirtens[0], mlirtens[1]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::log10(const at::Tensor &self) {
+  std::cout << "aten::log10" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::log10(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::log10_(at::Tensor &self) {
+  std::cout << "aten::log10_" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::log10_(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return self;
+}
+
+at::Tensor &ATenMLIRTypeDefault::log10_out(at::Tensor &out,
+                                           const at::Tensor &self) {
+  std::cout << "aten::log10_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::log10_out(mlirtens[0], mlirtens[1]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::log1p(const at::Tensor &self) {
+  std::cout << "aten::log1p" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::log1p(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::log1p_(at::Tensor &self) {
+  std::cout << "aten::log1p_" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::log1p_(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return self;
+}
+
+at::Tensor &ATenMLIRTypeDefault::log1p_out(at::Tensor &out,
+                                           const at::Tensor &self) {
+  std::cout << "aten::log1p_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::log1p_out(mlirtens[0], mlirtens[1]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::log2(const at::Tensor &self) {
+  std::cout << "aten::log2" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::log2(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::log2_(at::Tensor &self) {
+  std::cout << "aten::log2_" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::log2_(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return self;
+}
+
+at::Tensor &ATenMLIRTypeDefault::log2_out(at::Tensor &out,
+                                          const at::Tensor &self) {
+  std::cout << "aten::log2_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::log2_out(mlirtens[0], mlirtens[1]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::logdet(const at::Tensor &self) {
+  std::cout << "aten::logdet" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::logdet(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor ATenMLIRTypeDefault::logspace(at::Scalar start, at::Scalar end,
+                                         int64_t steps, double base,
+                                         const at::TensorOptions &options) {
+  std::cout << "aten::logspace" << std::endl;
+  at::TensorOptions o_options = options.device(at::DeviceType::CPU);
+  std::vector<at::Tensor> mlirtens_tensors = {};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::logspace(start, end, steps, base, o_options);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options));
+}
+
+at::Tensor &ATenMLIRTypeDefault::logspace_out(at::Tensor &out, at::Scalar start,
+                                              at::Scalar end, int64_t steps,
+                                              double base) {
+  std::cout << "aten::logspace_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::logspace_out(mlirtens[0], start, end, steps, base);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor
+ATenMLIRTypeDefault::log_softmax(const at::Tensor &self, int64_t dim,
+                                 c10::optional<at::ScalarType> dtype) {
+  std::cout << "aten::log_softmax" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens
= bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::log_softmax(mlirtens[0], dim, dtype); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_log_softmax(const at::Tensor &self, + int64_t dim, bool half_to_float) { + std::cout << "aten::_log_softmax" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_log_softmax(mlirtens[0], dim, half_to_float); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_log_softmax_backward_data( + const at::Tensor &grad_output, const at::Tensor &output, int64_t dim, + const at::Tensor &self) { + std::cout << "aten::_log_softmax_backward_data" << std::endl; + std::vector mlirtens_tensors = {grad_output, output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_log_softmax_backward_data(mlirtens[0], mlirtens[1], + dim, mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor ATenMLIRTypeDefault::logsumexp(const at::Tensor &self, + at::IntArrayRef dim, bool keepdim) { + std::cout << "aten::logsumexp" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::logsumexp(mlirtens[0], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::logsumexp_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef dim, + bool keepdim) { + std::cout << "aten::logsumexp_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::logsumexp_out(mlirtens[0], mlirtens[1], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::margin_ranking_loss(const at::Tensor &input1, + const at::Tensor &input2, + const at::Tensor &target, + double margin, + int64_t reduction) { + std::cout << "aten::margin_ranking_loss" << std::endl; + std::vector mlirtens_tensors = {input1, input2, target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::margin_ranking_loss(mlirtens[0], mlirtens[1], + mlirtens[2], margin, reduction); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input1)); +} + +at::Tensor ATenMLIRTypeDefault::matmul(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::matmul" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::matmul(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::matmul_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::matmul_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = 
bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::matmul_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::matrix_rank(const at::Tensor &self, double tol, + bool symmetric) { + std::cout << "aten::matrix_rank" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::matrix_rank(mlirtens[0], tol, symmetric); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::matrix_rank(const at::Tensor &self, + bool symmetric) { + std::cout << "aten::matrix_rank" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::matrix_rank(mlirtens[0], symmetric); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::matrix_power(const at::Tensor &self, + int64_t n) { + std::cout << "aten::matrix_power" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::matrix_power(mlirtens[0], n); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::max(const at::Tensor &self, int64_t dim, bool keepdim) { + std::cout << "aten::max" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max(mlirtens[0], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::max_out(at::Tensor &max, at::Tensor &max_values, + const at::Tensor &self, int64_t dim, + bool keepdim) { + std::cout << "aten::max_out" << std::endl; + std::vector mlirtens_tensors = {max, max_values, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::max_out(mlirtens[0], mlirtens[1], mlirtens[2], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(max, max_values); +} + +at::Tensor ATenMLIRTypeDefault::max_values(const at::Tensor &self, + at::IntArrayRef dim, bool keepdim) { + std::cout << "aten::max_values" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_values(mlirtens[0], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple ATenMLIRTypeDefault::max_pool1d_with_indices( + const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + std::cout << "aten::max_pool1d_with_indices" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_pool1d_with_indices( + mlirtens[0], kernel_size, stride, padding, dilation, ceil_mode); + 
static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor ATenMLIRTypeDefault::max_pool1d( + const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + std::cout << "aten::max_pool1d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_pool1d(mlirtens[0], kernel_size, stride, padding, + dilation, ceil_mode); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::max_pool2d( + const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + std::cout << "aten::max_pool2d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_pool2d(mlirtens[0], kernel_size, stride, padding, + dilation, ceil_mode); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::mkldnn_max_pool2d( + const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + std::cout << "aten::mkldnn_max_pool2d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mkldnn_max_pool2d(mlirtens[0], kernel_size, stride, + padding, dilation, ceil_mode); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::quantized_max_pool2d( + const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation) { + std::cout << "aten::quantized_max_pool2d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::quantized_max_pool2d(mlirtens[0], kernel_size, stride, + padding, dilation); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::max_pool3d( + const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + std::cout << "aten::max_pool3d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_pool3d(mlirtens[0], kernel_size, stride, padding, + dilation, ceil_mode); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::mean(const at::Tensor &self, + c10::optional dtype) { + std::cout << "aten::mean" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mean(mlirtens[0], dtype); + static_cast(x_result); // Avoid 
warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::mean(const at::Tensor &self, + at::IntArrayRef dim, bool keepdim, + c10::optional dtype) { + std::cout << "aten::mean" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mean(mlirtens[0], dim, keepdim, dtype); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::mean_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef dim, bool keepdim, + c10::optional dtype) { + std::cout << "aten::mean_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mean_out(mlirtens[0], mlirtens[1], dim, keepdim, dtype); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +std::tuple +ATenMLIRTypeDefault::median(const at::Tensor &self, int64_t dim, bool keepdim) { + std::cout << "aten::median" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::median(mlirtens[0], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::median_out(at::Tensor &values, at::Tensor &indices, + const at::Tensor &self, int64_t dim, + bool keepdim) { + std::cout << "aten::median_out" << std::endl; + std::vector mlirtens_tensors = {values, indices, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::median_out(mlirtens[0], mlirtens[1], mlirtens[2], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(values, indices); +} + +std::tuple +ATenMLIRTypeDefault::min(const at::Tensor &self, int64_t dim, bool keepdim) { + std::cout << "aten::min" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::min(mlirtens[0], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::min_out(at::Tensor &min, at::Tensor &min_indices, + const at::Tensor &self, int64_t dim, + bool keepdim) { + std::cout << "aten::min_out" << std::endl; + std::vector mlirtens_tensors = {min, min_indices, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::min_out(mlirtens[0], mlirtens[1], mlirtens[2], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(min, min_indices); +} + +at::Tensor ATenMLIRTypeDefault::min_values(const at::Tensor &self, + at::IntArrayRef dim, bool keepdim) { + std::cout << "aten::min_values" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::min_values(mlirtens[0], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return 
bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::mkldnn_convolution( + const at::Tensor &self, const at::Tensor &weight, const at::Tensor &bias, + at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, + int64_t groups) { + std::cout << "aten::mkldnn_convolution" << std::endl; + std::vector mlirtens_tensors = {self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mkldnn_convolution( + mlirtens[0], mlirtens[1], mlirtens[2], padding, stride, dilation, groups); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::mkldnn_convolution_backward_input( + at::IntArrayRef self_size, const at::Tensor &grad_output, + const at::Tensor &weight, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool bias_defined) { + std::cout << "aten::mkldnn_convolution_backward_input" << std::endl; + std::vector mlirtens_tensors = {grad_output, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mkldnn_convolution_backward_input( + self_size, mlirtens[0], mlirtens[1], padding, stride, dilation, groups, + bias_defined); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +std::tuple +ATenMLIRTypeDefault::mkldnn_convolution_backward_weights( + at::IntArrayRef weight_size, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool bias_defined) { + std::cout << "aten::mkldnn_convolution_backward_weights" << std::endl; + std::vector mlirtens_tensors = {grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mkldnn_convolution_backward_weights( + weight_size, mlirtens[0], mlirtens[1], padding, stride, dilation, groups, + bias_defined); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(grad_output))); +} + +std::tuple +ATenMLIRTypeDefault::mkldnn_convolution_backward( + const at::Tensor &self, const at::Tensor &grad_output, + const at::Tensor &weight, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, std::array output_mask) { + std::cout << "aten::mkldnn_convolution_backward" << std::endl; + std::vector mlirtens_tensors = {self, grad_output, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mkldnn_convolution_backward( + mlirtens[0], mlirtens[1], mlirtens[2], padding, stride, dilation, groups, + output_mask); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::miopen_batch_norm( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, + const at::Tensor &running_mean, const at::Tensor &running_var, + bool training, double 
exponential_average_factor, double epsilon) { + std::cout << "aten::miopen_batch_norm" << std::endl; + std::vector mlirtens_tensors = {input, weight, bias, running_mean, + running_var}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::miopen_batch_norm(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], mlirtens[4], training, + exponential_average_factor, epsilon); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(input))); +} + +std::tuple +ATenMLIRTypeDefault::miopen_batch_norm_backward( + const at::Tensor &input, const at::Tensor &grad_output, + const at::Tensor &weight, const at::Tensor &running_mean, + const at::Tensor &running_var, const at::Tensor &save_mean, + const at::Tensor &save_var, double epsilon) { + std::cout << "aten::miopen_batch_norm_backward" << std::endl; + std::vector mlirtens_tensors = { + input, grad_output, weight, running_mean, + running_var, save_mean, save_var}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::miopen_batch_norm_backward( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + mlirtens[5], mlirtens[6], epsilon); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(input))); +} + +at::Tensor ATenMLIRTypeDefault::miopen_convolution( + const at::Tensor &self, const at::Tensor &weight, const at::Tensor &bias, + at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic) { + std::cout << "aten::miopen_convolution" << std::endl; + std::vector mlirtens_tensors = {self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::miopen_convolution( + mlirtens[0], mlirtens[1], mlirtens[2], padding, stride, dilation, groups, + benchmark, deterministic); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::miopen_convolution_backward_input( + at::IntArrayRef self_size, const at::Tensor &grad_output, + const at::Tensor &weight, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool benchmark, + bool deterministic) { + std::cout << "aten::miopen_convolution_backward_input" << std::endl; + std::vector mlirtens_tensors = {grad_output, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::miopen_convolution_backward_input( + self_size, mlirtens[0], mlirtens[1], padding, stride, dilation, groups, + benchmark, deterministic); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +std::tuple +ATenMLIRTypeDefault::miopen_convolution_backward( + const at::Tensor &self, const at::Tensor &grad_output, + const at::Tensor &weight, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool benchmark, + 
bool deterministic, std::array output_mask) { + std::cout << "aten::miopen_convolution_backward" << std::endl; + std::vector mlirtens_tensors = {self, grad_output, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::miopen_convolution_backward( + mlirtens[0], mlirtens[1], mlirtens[2], padding, stride, dilation, groups, + benchmark, deterministic, output_mask); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor ATenMLIRTypeDefault::miopen_convolution_backward_bias( + const at::Tensor &grad_output) { + std::cout << "aten::miopen_convolution_backward_bias" << std::endl; + std::vector mlirtens_tensors = {grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::miopen_convolution_backward_bias(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor ATenMLIRTypeDefault::miopen_convolution_backward_weight( + at::IntArrayRef weight_size, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool benchmark, + bool deterministic) { + std::cout << "aten::miopen_convolution_backward_weight" << std::endl; + std::vector mlirtens_tensors = {grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::miopen_convolution_backward_weight( + weight_size, mlirtens[0], mlirtens[1], padding, stride, dilation, groups, + benchmark, deterministic); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor ATenMLIRTypeDefault::miopen_convolution_transpose( + const at::Tensor &self, const at::Tensor &weight, const at::Tensor &bias, + at::IntArrayRef padding, at::IntArrayRef output_padding, + at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic) { + std::cout << "aten::miopen_convolution_transpose" << std::endl; + std::vector mlirtens_tensors = {self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::miopen_convolution_transpose( + mlirtens[0], mlirtens[1], mlirtens[2], padding, output_padding, stride, + dilation, groups, benchmark, deterministic); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::miopen_convolution_transpose_backward( + const at::Tensor &self, const at::Tensor &grad_output, + const at::Tensor &weight, at::IntArrayRef padding, + at::IntArrayRef output_padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool benchmark, + bool deterministic, std::array output_mask) { + std::cout << "aten::miopen_convolution_transpose_backward" << std::endl; + std::vector mlirtens_tensors = {self, grad_output, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::miopen_convolution_transpose_backward( + mlirtens[0], mlirtens[1], mlirtens[2], padding, output_padding, stride, + dilation, groups, 
benchmark, deterministic, output_mask); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor ATenMLIRTypeDefault::miopen_convolution_transpose_backward_input( + const at::Tensor &grad_output, const at::Tensor &weight, + at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic) { + std::cout << "aten::miopen_convolution_transpose_backward_input" << std::endl; + std::vector mlirtens_tensors = {grad_output, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::miopen_convolution_transpose_backward_input( + mlirtens[0], mlirtens[1], padding, stride, dilation, groups, benchmark, + deterministic); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor ATenMLIRTypeDefault::miopen_convolution_transpose_backward_weight( + at::IntArrayRef weight_size, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool benchmark, + bool deterministic) { + std::cout << "aten::miopen_convolution_transpose_backward_weight" + << std::endl; + std::vector mlirtens_tensors = {grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::miopen_convolution_transpose_backward_weight( + weight_size, mlirtens[0], mlirtens[1], padding, stride, dilation, groups, + benchmark, deterministic); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor ATenMLIRTypeDefault::miopen_depthwise_convolution( + const at::Tensor &self, const at::Tensor &weight, const at::Tensor &bias, + at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic) { + std::cout << "aten::miopen_depthwise_convolution" << std::endl; + std::vector mlirtens_tensors = {self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::miopen_depthwise_convolution( + mlirtens[0], mlirtens[1], mlirtens[2], padding, stride, dilation, groups, + benchmark, deterministic); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::miopen_depthwise_convolution_backward_input( + at::IntArrayRef self_size, const at::Tensor &grad_output, + const at::Tensor &weight, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool benchmark, + bool deterministic) { + std::cout << "aten::miopen_depthwise_convolution_backward_input" << std::endl; + std::vector mlirtens_tensors = {grad_output, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::miopen_depthwise_convolution_backward_input( + self_size, mlirtens[0], mlirtens[1], padding, stride, dilation, groups, + benchmark, deterministic); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + 
+std::tuple +ATenMLIRTypeDefault::miopen_depthwise_convolution_backward( + const at::Tensor &self, const at::Tensor &grad_output, + const at::Tensor &weight, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool benchmark, + bool deterministic, std::array output_mask) { + std::cout << "aten::miopen_depthwise_convolution_backward" << std::endl; + std::vector mlirtens_tensors = {self, grad_output, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::miopen_depthwise_convolution_backward( + mlirtens[0], mlirtens[1], mlirtens[2], padding, stride, dilation, groups, + benchmark, deterministic, output_mask); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor ATenMLIRTypeDefault::miopen_depthwise_convolution_backward_weight( + at::IntArrayRef weight_size, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool benchmark, + bool deterministic) { + std::cout << "aten::miopen_depthwise_convolution_backward_weight" + << std::endl; + std::vector mlirtens_tensors = {grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::miopen_depthwise_convolution_backward_weight( + weight_size, mlirtens[0], mlirtens[1], padding, stride, dilation, groups, + benchmark, deterministic); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +std::tuple +ATenMLIRTypeDefault::miopen_rnn(const at::Tensor &input, at::TensorList weight, + int64_t weight_stride0, const at::Tensor &hx, + const at::Tensor &cx, int64_t mode, + int64_t hidden_size, int64_t num_layers, + bool batch_first, double dropout, bool train, + bool bidirectional, at::IntArrayRef batch_sizes, + const at::Tensor &dropout_state) { + std::cout << "aten::miopen_rnn" << std::endl; + std::vector mlirtens_tensors = {input, hx, cx, dropout_state}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::miopen_rnn(mlirtens[0], weight, weight_stride0, mlirtens[1], + mlirtens[2], mode, hidden_size, num_layers, batch_first, + dropout, train, bidirectional, batch_sizes, mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<3>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<4>(x_result), + bridge::GetMLIRDevice(input))); +} + +std::tuple> +ATenMLIRTypeDefault::miopen_rnn_backward( + const at::Tensor &input, at::TensorList weight, int64_t weight_stride0, + const at::Tensor &weight_buf, const at::Tensor &hx, const at::Tensor &cx, + const at::Tensor &output, const at::Tensor &grad_output, + const at::Tensor &grad_hy, const at::Tensor &grad_cy, int64_t mode, + int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, + bool train, bool bidirectional, 
at::IntArrayRef batch_sizes, + const at::Tensor &dropout_state, const at::Tensor &reserve, + std::array output_mask) { + std::cout << "aten::miopen_rnn_backward" << std::endl; + std::vector mlirtens_tensors = { + input, weight_buf, hx, cx, output, grad_output, grad_hy, + grad_cy, dropout_state, reserve}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::miopen_rnn_backward( + mlirtens[0], weight, weight_stride0, mlirtens[1], mlirtens[2], + mlirtens[3], mlirtens[4], mlirtens[5], mlirtens[6], mlirtens[7], mode, + hidden_size, num_layers, batch_first, dropout, train, bidirectional, + batch_sizes, mlirtens[8], mlirtens[9], output_mask); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple>( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(input)), + std::get<3>(x_result)); +} + +at::Tensor ATenMLIRTypeDefault::mm(const at::Tensor &self, + const at::Tensor &mat2) { + std::cout << "aten::mm" << std::endl; + std::vector mlirtens_tensors = {self, mat2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mm(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::mm_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &mat2) { + std::cout << "aten::mm_out" << std::endl; + std::vector mlirtens_tensors = {out, self, mat2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mm_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::_sparse_mm(const at::Tensor &sparse, + const at::Tensor &dense) { + std::cout << "aten::_sparse_mm" << std::endl; + std::vector mlirtens_tensors = {sparse, dense}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_sparse_mm(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(sparse)); +} + +std::tuple +ATenMLIRTypeDefault::mode(const at::Tensor &self, int64_t dim, bool keepdim) { + std::cout << "aten::mode" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mode(mlirtens[0], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::mode_out(at::Tensor &values, at::Tensor &indices, + const at::Tensor &self, int64_t dim, + bool keepdim) { + std::cout << "aten::mode_out" << std::endl; + std::vector mlirtens_tensors = {values, indices, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::mode_out(mlirtens[0], mlirtens[1], mlirtens[2], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(values, indices); +} + +at::Tensor ATenMLIRTypeDefault::mul(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::mul" << std::endl; + 
std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mul(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::mul_(at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::mul_" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].mul_(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::mul_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::mul_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mul_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::mul(const at::Tensor &self, at::Scalar other) { + std::cout << "aten::mul" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mul(mlirtens[0], other); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::mul_(at::Tensor &self, at::Scalar other) { + std::cout << "aten::mul_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].mul_(other); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::mv(const at::Tensor &self, + const at::Tensor &vec) { + std::cout << "aten::mv" << std::endl; + std::vector mlirtens_tensors = {self, vec}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mv(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::mv_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &vec) { + std::cout << "aten::mv_out" << std::endl; + std::vector mlirtens_tensors = {out, self, vec}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mv_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::mvlgamma(const at::Tensor &self, int64_t p) { + std::cout << "aten::mvlgamma" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mvlgamma(mlirtens[0], p); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::mvlgamma_(at::Tensor &self, int64_t p) { + std::cout << "aten::mvlgamma_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].mvlgamma_(p); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::narrow_copy(const 
at::Tensor &self, int64_t dim, + int64_t start, int64_t length) { + std::cout << "aten::narrow_copy" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].narrow_copy(dim, start, length); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::narrow(const at::Tensor &self, int64_t dim, + int64_t start, int64_t length) { + std::cout << "aten::narrow" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::narrow(mlirtens[0], dim, start, length); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::native_batch_norm( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, + const at::Tensor &running_mean, const at::Tensor &running_var, + bool training, double momentum, double eps) { + std::cout << "aten::native_batch_norm" << std::endl; + std::vector mlirtens_tensors = {input, weight, bias, running_mean, + running_var}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::native_batch_norm(mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], + mlirtens[4], training, momentum, eps); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(input))); +} + +std::tuple +ATenMLIRTypeDefault::batch_norm_stats(const at::Tensor &input, double eps) { + std::cout << "aten::batch_norm_stats" << std::endl; + std::vector mlirtens_tensors = {input}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::batch_norm_stats(mlirtens[0], eps); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input))); +} + +at::Tensor ATenMLIRTypeDefault::batch_norm_elemt( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, + const at::Tensor &mean, const at::Tensor &invstd, double eps) { + std::cout << "aten::batch_norm_elemt" << std::endl; + std::vector mlirtens_tensors = {input, weight, bias, mean, + invstd}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::batch_norm_elemt(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], mlirtens[4], eps); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +std::tuple ATenMLIRTypeDefault::batch_norm_gather_stats( + const at::Tensor &input, const at::Tensor &mean, const at::Tensor &invstd, + const at::Tensor &running_mean, const at::Tensor &running_var, + double momentum, double eps, int64_t count) { + std::cout << "aten::batch_norm_gather_stats" << std::endl; + std::vector mlirtens_tensors = {input, mean, invstd, running_mean, + running_var}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::batch_norm_gather_stats( + mlirtens[0], 
mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], momentum, + eps, count); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input))); +} + +std::tuple +ATenMLIRTypeDefault::batch_norm_gather_stats_with_counts( + const at::Tensor &input, const at::Tensor &mean, const at::Tensor &invstd, + const at::Tensor &running_mean, const at::Tensor &running_var, + double momentum, double eps, at::IntArrayRef counts) { + std::cout << "aten::batch_norm_gather_stats_with_counts" << std::endl; + std::vector mlirtens_tensors = {input, mean, invstd, running_mean, + running_var}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::batch_norm_gather_stats_with_counts( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], momentum, + eps, counts); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input))); +} + +std::tuple +ATenMLIRTypeDefault::native_batch_norm_backward( + const at::Tensor &grad_out, const at::Tensor &input, + const at::Tensor &weight, const at::Tensor &running_mean, + const at::Tensor &running_var, const at::Tensor &save_mean, + const at::Tensor &save_invstd, bool train, double eps, + std::array output_mask) { + std::cout << "aten::native_batch_norm_backward" << std::endl; + std::vector mlirtens_tensors = { + grad_out, input, weight, running_mean, + running_var, save_mean, save_invstd}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::native_batch_norm_backward( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + mlirtens[5], mlirtens[6], train, eps, output_mask); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(grad_out)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(grad_out)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(grad_out))); +} + +std::tuple +ATenMLIRTypeDefault::batch_norm_backward_reduce( + const at::Tensor &grad_out, const at::Tensor &input, const at::Tensor &mean, + const at::Tensor &invstd, const at::Tensor &weight, bool input_g, + bool weight_g, bool bias_g) { + std::cout << "aten::batch_norm_backward_reduce" << std::endl; + std::vector mlirtens_tensors = {grad_out, input, mean, invstd, + weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::batch_norm_backward_reduce( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], input_g, + weight_g, bias_g); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(grad_out)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(grad_out)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(grad_out)), + bridge::CreateMLIRTensor(std::get<3>(x_result), + bridge::GetMLIRDevice(grad_out))); +} + +at::Tensor ATenMLIRTypeDefault::batch_norm_backward_elemt( + const at::Tensor &grad_out, const at::Tensor &input, const at::Tensor &mean, + const at::Tensor &invstd, const at::Tensor &weight, + const 
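The batch-norm entries return tuples rather than a single tensor; the pattern is unchanged except that each element of the CPU result is re-wrapped individually onto the input's device. A minimal sketch of that step follows (hypothetical helper name; bridge:: calls assumed as used above).

// Sketch only: element-wise re-wrapping of a tuple result, mirroring the
// native_batch_norm / batch_norm_backward wrappers above.
#include <ATen/ATen.h>
#include <tuple>

std::tuple<at::Tensor, at::Tensor, at::Tensor>
wrap_tuple3(const std::tuple<at::Tensor, at::Tensor, at::Tensor> &cpu_result,
            const at::Tensor &reference) {
  // The reference tensor only supplies the target MLIR device.
  auto device = bridge::GetMLIRDevice(reference);
  return std::make_tuple(
      bridge::CreateMLIRTensor(std::get<0>(cpu_result), device),
      bridge::CreateMLIRTensor(std::get<1>(cpu_result), device),
      bridge::CreateMLIRTensor(std::get<2>(cpu_result), device));
}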
at::Tensor &mean_dy, const at::Tensor &mean_dy_xmu) { + std::cout << "aten::batch_norm_backward_elemt" << std::endl; + std::vector mlirtens_tensors = { + grad_out, input, mean, invstd, weight, mean_dy, mean_dy_xmu}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::batch_norm_backward_elemt( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + mlirtens[5], mlirtens[6]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_out)); +} + +std::tuple ATenMLIRTypeDefault::batch_norm_update_stats( + const at::Tensor &input, const at::Tensor &running_mean, + const at::Tensor &running_var, double momentum) { + std::cout << "aten::batch_norm_update_stats" << std::endl; + std::vector mlirtens_tensors = {input, running_mean, running_var}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::batch_norm_update_stats(mlirtens[0], mlirtens[1], + mlirtens[2], momentum); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input))); +} + +at::Tensor ATenMLIRTypeDefault::_nnpack_spatial_convolution( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, + at::IntArrayRef padding) { + std::cout << "aten::_nnpack_spatial_convolution" << std::endl; + std::vector mlirtens_tensors = {input, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_nnpack_spatial_convolution(mlirtens[0], mlirtens[1], + mlirtens[2], padding); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +std::tuple +ATenMLIRTypeDefault::_nnpack_spatial_convolution_backward( + const at::Tensor &input, const at::Tensor &grad_output, + const at::Tensor &weight, at::IntArrayRef padding, + std::array output_mask) { + std::cout << "aten::_nnpack_spatial_convolution_backward" << std::endl; + std::vector mlirtens_tensors = {input, grad_output, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_nnpack_spatial_convolution_backward( + mlirtens[0], mlirtens[1], mlirtens[2], padding, output_mask); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(input))); +} + +at::Tensor ATenMLIRTypeDefault::_nnpack_spatial_convolution_backward_input( + const at::Tensor &input, const at::Tensor &grad_output, + const at::Tensor &weight, at::IntArrayRef padding) { + std::cout << "aten::_nnpack_spatial_convolution_backward_input" << std::endl; + std::vector mlirtens_tensors = {input, grad_output, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_nnpack_spatial_convolution_backward_input( + mlirtens[0], mlirtens[1], mlirtens[2], padding); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor ATenMLIRTypeDefault::_nnpack_spatial_convolution_backward_weight( + const at::Tensor &input, at::IntArrayRef weightsize, 
+ const at::Tensor &grad_output, at::IntArrayRef padding) { + std::cout << "aten::_nnpack_spatial_convolution_backward_weight" << std::endl; + std::vector mlirtens_tensors = {input, grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_nnpack_spatial_convolution_backward_weight( + mlirtens[0], weightsize, mlirtens[1], padding); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor &ATenMLIRTypeDefault::ones_out(at::Tensor &out, + at::IntArrayRef size) { + std::cout << "aten::ones_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::ones_out(mlirtens[0], size); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::pairwise_distance(const at::Tensor &x1, + const at::Tensor &x2, + double p, double eps, + bool keepdim) { + std::cout << "aten::pairwise_distance" << std::endl; + std::vector mlirtens_tensors = {x1, x2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::pairwise_distance(mlirtens[0], mlirtens[1], p, eps, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(x1)); +} + +at::Tensor ATenMLIRTypeDefault::cdist(const at::Tensor &x1, + const at::Tensor &x2, double p) { + std::cout << "aten::cdist" << std::endl; + std::vector mlirtens_tensors = {x1, x2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::cdist(mlirtens[0], mlirtens[1], p); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(x1)); +} + +at::Tensor ATenMLIRTypeDefault::_cdist_backward(const at::Tensor &grad, + const at::Tensor &x1, + const at::Tensor &x2, double p, + const at::Tensor &cdist) { + std::cout << "aten::_cdist_backward" << std::endl; + std::vector mlirtens_tensors = {grad, x1, x2, cdist}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cdist_backward(mlirtens[0], mlirtens[1], mlirtens[2], + p, mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad)); +} + +at::Tensor ATenMLIRTypeDefault::pdist(const at::Tensor &self, double p) { + std::cout << "aten::pdist" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::pdist(mlirtens[0], p); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_pdist_forward(const at::Tensor &self, + double p) { + std::cout << "aten::_pdist_forward" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_pdist_forward(mlirtens[0], p); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_pdist_backward(const at::Tensor &grad, + const at::Tensor &self, + double p, + const at::Tensor &pdist) { + std::cout << "aten::_pdist_backward" << std::endl; + std::vector mlirtens_tensors = {grad, self, pdist}; + auto mlirtens = 
bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_pdist_backward(mlirtens[0], mlirtens[1], p, mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad)); +} + +at::Tensor ATenMLIRTypeDefault::cosine_similarity(const at::Tensor &x1, + const at::Tensor &x2, + int64_t dim, double eps) { + std::cout << "aten::cosine_similarity" << std::endl; + std::vector mlirtens_tensors = {x1, x2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::cosine_similarity(mlirtens[0], mlirtens[1], dim, eps); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(x1)); +} + +at::Tensor ATenMLIRTypeDefault::permute(const at::Tensor &self, + at::IntArrayRef dims) { + std::cout << "aten::permute" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].permute(dims); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::numpy_T(const at::Tensor &self) { + std::cout << "aten::numpy_T" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].numpy_T(); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::pixel_shuffle(const at::Tensor &self, + int64_t upscale_factor) { + std::cout << "aten::pixel_shuffle" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::pixel_shuffle(mlirtens[0], upscale_factor); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +bool ATenMLIRTypeDefault::is_pinned(const at::Tensor &self) { + std::cout << "aten::is_pinned" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].is_pinned(); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +at::Tensor ATenMLIRTypeDefault::pin_memory(const at::Tensor &self) { + std::cout << "aten::pin_memory" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].pin_memory(); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::pinverse(const at::Tensor &self, double rcond) { + std::cout << "aten::pinverse" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::pinverse(mlirtens[0], rcond); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::poisson_nll_loss(const at::Tensor &input, + const at::Tensor &target, + bool log_input, bool full, + double eps, + int64_t reduction) { + std::cout << "aten::poisson_nll_loss" << std::endl; + std::vector mlirtens_tensors = {input, target}; + auto mlirtens = 
bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::poisson_nll_loss(mlirtens[0], mlirtens[1], log_input, + full, eps, reduction); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor +ATenMLIRTypeDefault::scalar_tensor(at::Scalar s, + const at::TensorOptions &options) { + std::cout << "aten::scalar_tensor" << std::endl; + at::TensorOptions o_options = options.device(at::DeviceType::CPU); + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::scalar_tensor(s, o_options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::rand(at::IntArrayRef size, + const at::TensorOptions &options) { + std::cout << "aten::rand" << std::endl; + at::TensorOptions o_options = options.device(at::DeviceType::CPU); + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::rand(size, o_options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::rand(at::IntArrayRef size, + at::Generator *generator, + const at::TensorOptions &options) { + std::cout << "aten::rand" << std::endl; + at::TensorOptions o_options = options.device(at::DeviceType::CPU); + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::rand(size, generator, o_options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor &ATenMLIRTypeDefault::rand_out(at::Tensor &out, + at::IntArrayRef size) { + std::cout << "aten::rand_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::rand_out(mlirtens[0], size); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor &ATenMLIRTypeDefault::rand_out(at::Tensor &out, at::IntArrayRef size, + at::Generator *generator) { + std::cout << "aten::rand_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::rand_out(mlirtens[0], size, generator); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::rand_like(const at::Tensor &self) { + std::cout << "aten::rand_like" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::rand_like(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::rand_like(const at::Tensor &self, + const at::TensorOptions &options) { + std::cout << "aten::rand_like" << std::endl; + at::TensorOptions o_options = options.device(at::DeviceType::CPU); + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::rand_like(mlirtens[0], o_options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, 
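Factory functions such as scalar_tensor, rand, randint, randn and randperm have no tensor inputs to bridge. Their wrappers instead rewrite the caller's TensorOptions to target the CPU backend, build the tensor there, and place the result on the MLIR device recorded in the original options. A condensed sketch (hypothetical function name; the GetMLIRDevice(options) overload is used exactly as above):

// Sketch only: the factory-function pattern (rand/randint/randn/randperm).
#include <ATen/ATen.h>

at::Tensor mlir_factory_pattern(at::IntArrayRef size,
                                const at::TensorOptions &options) {
  // Redirect construction to the CPU backend...
  at::TensorOptions cpu_options = options.device(at::DeviceType::CPU);
  at::Tensor cpu_result = at::rand(size, cpu_options);
  // ...then re-home the result on the MLIR device the caller asked for.
  return bridge::CreateMLIRTensor(cpu_result, bridge::GetMLIRDevice(options));
}

Note that the two range overloads further down pass the caller's options straight through to at::range instead of the CPU-redirected copy; that difference is visible in the generated code itself.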
bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::randint(int64_t high, at::IntArrayRef size, + const at::TensorOptions &options) { + std::cout << "aten::randint" << std::endl; + at::TensorOptions o_options = options.device(at::DeviceType::CPU); + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randint(high, size, o_options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::randint(int64_t high, at::IntArrayRef size, + at::Generator *generator, + const at::TensorOptions &options) { + std::cout << "aten::randint" << std::endl; + at::TensorOptions o_options = options.device(at::DeviceType::CPU); + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randint(high, size, generator, o_options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::randint(int64_t low, int64_t high, + at::IntArrayRef size, + const at::TensorOptions &options) { + std::cout << "aten::randint" << std::endl; + at::TensorOptions o_options = options.device(at::DeviceType::CPU); + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randint(low, high, size, o_options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::randint(int64_t low, int64_t high, + at::IntArrayRef size, + at::Generator *generator, + const at::TensorOptions &options) { + std::cout << "aten::randint" << std::endl; + at::TensorOptions o_options = options.device(at::DeviceType::CPU); + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randint(low, high, size, generator, o_options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor &ATenMLIRTypeDefault::randint_out(at::Tensor &out, int64_t high, + at::IntArrayRef size) { + std::cout << "aten::randint_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randint_out(mlirtens[0], high, size); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor &ATenMLIRTypeDefault::randint_out(at::Tensor &out, int64_t high, + at::IntArrayRef size, + at::Generator *generator) { + std::cout << "aten::randint_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randint_out(mlirtens[0], high, size, generator); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor &ATenMLIRTypeDefault::randint_out(at::Tensor &out, int64_t low, + int64_t high, + at::IntArrayRef size) { + std::cout << "aten::randint_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randint_out(mlirtens[0], low, high, size); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor 
&ATenMLIRTypeDefault::randint_out(at::Tensor &out, int64_t low, + int64_t high, at::IntArrayRef size, + at::Generator *generator) { + std::cout << "aten::randint_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randint_out(mlirtens[0], low, high, size, generator); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::randint_like(const at::Tensor &self, + int64_t high) { + std::cout << "aten::randint_like" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randint_like(mlirtens[0], high); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::randint_like(const at::Tensor &self, + int64_t low, int64_t high) { + std::cout << "aten::randint_like" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randint_like(mlirtens[0], low, high); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::randint_like(const at::Tensor &self, + int64_t high, + const at::TensorOptions &options) { + std::cout << "aten::randint_like" << std::endl; + at::TensorOptions o_options = options.device(at::DeviceType::CPU); + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randint_like(mlirtens[0], high, o_options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::randint_like(const at::Tensor &self, + int64_t low, int64_t high, + const at::TensorOptions &options) { + std::cout << "aten::randint_like" << std::endl; + at::TensorOptions o_options = options.device(at::DeviceType::CPU); + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randint_like(mlirtens[0], low, high, o_options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::randn(at::IntArrayRef size, + const at::TensorOptions &options) { + std::cout << "aten::randn" << std::endl; + at::TensorOptions o_options = options.device(at::DeviceType::CPU); + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randn(size, o_options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::randn(at::IntArrayRef size, + at::Generator *generator, + const at::TensorOptions &options) { + std::cout << "aten::randn" << std::endl; + at::TensorOptions o_options = options.device(at::DeviceType::CPU); + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randn(size, generator, o_options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor 
&ATenMLIRTypeDefault::randn_out(at::Tensor &out, + at::IntArrayRef size) { + std::cout << "aten::randn_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randn_out(mlirtens[0], size); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor &ATenMLIRTypeDefault::randn_out(at::Tensor &out, + at::IntArrayRef size, + at::Generator *generator) { + std::cout << "aten::randn_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randn_out(mlirtens[0], size, generator); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::randn_like(const at::Tensor &self) { + std::cout << "aten::randn_like" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randn_like(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::randn_like(const at::Tensor &self, + const at::TensorOptions &options) { + std::cout << "aten::randn_like" << std::endl; + at::TensorOptions o_options = options.device(at::DeviceType::CPU); + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randn_like(mlirtens[0], o_options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::randperm(int64_t n, + const at::TensorOptions &options) { + std::cout << "aten::randperm" << std::endl; + at::TensorOptions o_options = options.device(at::DeviceType::CPU); + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randperm(n, o_options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::randperm(int64_t n, at::Generator *generator, + const at::TensorOptions &options) { + std::cout << "aten::randperm" << std::endl; + at::TensorOptions o_options = options.device(at::DeviceType::CPU); + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randperm(n, generator, o_options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor &ATenMLIRTypeDefault::randperm_out(at::Tensor &out, int64_t n) { + std::cout << "aten::randperm_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randperm_out(mlirtens[0], n); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor &ATenMLIRTypeDefault::randperm_out(at::Tensor &out, int64_t n, + at::Generator *generator) { + std::cout << "aten::randperm_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::randperm_out(mlirtens[0], n, generator); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor 
ATenMLIRTypeDefault::range(at::Scalar start, at::Scalar end, + at::Scalar step, + const at::TensorOptions &options) { + std::cout << "aten::range" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::range(start, end, step, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::range(at::Scalar start, at::Scalar end, + const at::TensorOptions &options) { + std::cout << "aten::range" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::range(start, end, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor &ATenMLIRTypeDefault::range_out(at::Tensor &out, at::Scalar start, + at::Scalar end, at::Scalar step) { + std::cout << "aten::range_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::range_out(mlirtens[0], start, end, step); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::reciprocal(const at::Tensor &self) { + std::cout << "aten::reciprocal" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::reciprocal(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::reciprocal_(at::Tensor &self) { + std::cout << "aten::reciprocal_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::reciprocal_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::reciprocal_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::reciprocal_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::reciprocal_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::neg(const at::Tensor &self) { + std::cout << "aten::neg" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::neg(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::neg_(at::Tensor &self) { + std::cout << "aten::neg_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::neg_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::neg_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::neg_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::neg_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // 
Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::repeat(const at::Tensor &self, + at::IntArrayRef repeats) { + std::cout << "aten::repeat" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].repeat(repeats); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::repeat_interleave(const at::Tensor &repeats) { + std::cout << "aten::repeat_interleave" << std::endl; + std::vector mlirtens_tensors = {repeats}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::repeat_interleave(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(repeats)); +} + +at::Tensor ATenMLIRTypeDefault::repeat_interleave(const at::Tensor &self, + const at::Tensor &repeats, + c10::optional dim) { + std::cout << "aten::repeat_interleave" << std::endl; + std::vector mlirtens_tensors = {self, repeats}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::repeat_interleave(mlirtens[0], mlirtens[1], dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::repeat_interleave(const at::Tensor &self, + int64_t repeats, + c10::optional dim) { + std::cout << "aten::repeat_interleave" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::repeat_interleave(mlirtens[0], repeats, dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::reshape(const at::Tensor &self, + at::IntArrayRef shape) { + std::cout << "aten::reshape" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::reshape(mlirtens[0], shape); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_mkldnn_reshape(const at::Tensor &self, + at::IntArrayRef shape) { + std::cout << "aten::_mkldnn_reshape" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_mkldnn_reshape(mlirtens[0], shape); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::reshape_as(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::reshape_as" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].reshape_as(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::round(const at::Tensor &self) { + std::cout << "aten::round" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::round(mlirtens[0]); + 
static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::round_(at::Tensor &self) { + std::cout << "aten::round_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::round_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::round_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::round_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::round_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::rrelu(const at::Tensor &self, at::Scalar lower, + at::Scalar upper, bool training, + at::Generator *generator) { + std::cout << "aten::rrelu" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::rrelu(mlirtens[0], lower, upper, training, generator); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::rrelu_(at::Tensor &self, at::Scalar lower, + at::Scalar upper, bool training, + at::Generator *generator) { + std::cout << "aten::rrelu_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::rrelu_(mlirtens[0], lower, upper, training, generator); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::relu(const at::Tensor &self) { + std::cout << "aten::relu" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::relu(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::relu_(at::Tensor &self) { + std::cout << "aten::relu_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::relu_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::prelu(const at::Tensor &self, + const at::Tensor &weight) { + std::cout << "aten::prelu" << std::endl; + std::vector mlirtens_tensors = {self, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::prelu(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::prelu_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &weight) { + std::cout << "aten::prelu_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::prelu_backward(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + 
bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(grad_output))); +} + +at::Tensor ATenMLIRTypeDefault::gelu(const at::Tensor &self) { + std::cout << "aten::gelu" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::gelu(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::gelu_backward(const at::Tensor &grad, + const at::Tensor &self) { + std::cout << "aten::gelu_backward" << std::endl; + std::vector mlirtens_tensors = {grad, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::gelu_backward(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad)); +} + +at::Tensor ATenMLIRTypeDefault::hardshrink(const at::Tensor &self, + at::Scalar lambd) { + std::cout << "aten::hardshrink" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::hardshrink(mlirtens[0], lambd); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::hardshrink_backward(const at::Tensor &grad_out, + const at::Tensor &self, + at::Scalar lambd) { + std::cout << "aten::hardshrink_backward" << std::endl; + std::vector mlirtens_tensors = {grad_out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::hardshrink_backward(mlirtens[0], mlirtens[1], lambd); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_out)); +} + +at::Tensor ATenMLIRTypeDefault::rsqrt(const at::Tensor &self) { + std::cout << "aten::rsqrt" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::rsqrt(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::rsqrt_(at::Tensor &self) { + std::cout << "aten::rsqrt_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::rsqrt_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::rsqrt_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::rsqrt_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::rsqrt_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::select(const at::Tensor &self, int64_t dim, + int64_t index) { + std::cout << "aten::select" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::select(mlirtens[0], dim, index); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor 
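The in-place variants in this stretch (rrelu_, relu_ and rsqrt_ above, and the celu_/sigmoid_/sin_ forms just below) follow a reduced form of the same pattern: the kernel runs on the bridged CPU copies and the wrapper returns the original self reference rather than constructing a new MLIR tensor. A sketch of that shape (hypothetical function name; bridge:: calls as used above):

// Sketch only: shape of the in-place fallback wrappers above.
#include <ATen/ATen.h>
#include <iostream>
#include <vector>

at::Tensor &mlir_inplace_pattern(at::Tensor &self) {
  std::cout << "aten::relu_" << std::endl;
  std::vector<at::Tensor> tensors = {self};
  auto cpu = bridge::MLIRCreateTensorList(tensors);  // bridged CPU copies
  at::relu_(cpu[0]);                                 // kernel runs on the copy
  return self;                                       // caller keeps its MLIR tensor
}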
ATenMLIRTypeDefault::selu(const at::Tensor &self) { + std::cout << "aten::selu" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::selu(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::selu_(at::Tensor &self) { + std::cout << "aten::selu_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::selu_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::celu(const at::Tensor &self, at::Scalar alpha) { + std::cout << "aten::celu" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::celu(mlirtens[0], alpha); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::celu_(at::Tensor &self, at::Scalar alpha) { + std::cout << "aten::celu_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::celu_(mlirtens[0], alpha); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::sigmoid(const at::Tensor &self) { + std::cout << "aten::sigmoid" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::sigmoid(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::sigmoid_(at::Tensor &self) { + std::cout << "aten::sigmoid_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::sigmoid_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::sigmoid_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::sigmoid_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::sigmoid_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::sin(const at::Tensor &self) { + std::cout << "aten::sin" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::sin(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::sin_(at::Tensor &self) { + std::cout << "aten::sin_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::sin_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::sin_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::sin_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens 
= bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::sin_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::sinh(const at::Tensor &self) { + std::cout << "aten::sinh" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::sinh(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::sinh_(at::Tensor &self) { + std::cout << "aten::sinh_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::sinh_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::sinh_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::sinh_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::sinh_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::detach(const at::Tensor &self) { + std::cout << "aten::detach" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::detach(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::detach_(at::Tensor &self) { + std::cout << "aten::detach_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::detach_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +int64_t ATenMLIRTypeDefault::size(const at::Tensor &self, int64_t dim) { + std::cout << "aten::size" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::size(mlirtens[0], dim); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +at::Tensor ATenMLIRTypeDefault::slice(const at::Tensor &self, int64_t dim, + int64_t start, int64_t end, + int64_t step) { + std::cout << "aten::slice" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::slice(mlirtens[0], dim, start, end, step); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::slogdet(const at::Tensor &self) { + std::cout << "aten::slogdet" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::slogdet(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor ATenMLIRTypeDefault::smm(const at::Tensor &self, + const at::Tensor &mat2) { + std::cout << "aten::smm" << std::endl; + std::vector 
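Wrappers whose result is not a single tensor (is_pinned earlier, size here, and stride plus the split/split_with_sizes overloads further down, which return a std::vector of tensors) skip the CreateMLIRTensor step entirely and return whatever the CPU kernel produced. For example (sketch, hypothetical name):

// Sketch only: scalar-returning queries are forwarded to the bridged copy and
// the plain result is returned unchanged, mirroring size/stride above.
#include <ATen/ATen.h>
#include <vector>

int64_t mlir_scalar_query_pattern(const at::Tensor &self, int64_t dim) {
  std::vector<at::Tensor> tensors = {self};
  auto cpu = bridge::MLIRCreateTensorList(tensors);  // assumed bridge:: API
  return at::size(cpu[0], dim);  // int64_t result, no device re-wrapping needed
}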
mlirtens_tensors = {self, mat2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::smm(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::softmax(const at::Tensor &self, int64_t dim, + c10::optional dtype) { + std::cout << "aten::softmax" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::softmax(mlirtens[0], dim, dtype); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_softmax(const at::Tensor &self, int64_t dim, + bool half_to_float) { + std::cout << "aten::_softmax" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_softmax(mlirtens[0], dim, half_to_float); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_softmax_backward_data( + const at::Tensor &grad_output, const at::Tensor &output, int64_t dim, + const at::Tensor &self) { + std::cout << "aten::_softmax_backward_data" << std::endl; + std::vector mlirtens_tensors = {grad_output, output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_softmax_backward_data(mlirtens[0], mlirtens[1], dim, mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::_sparse_add_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &other, + at::Scalar alpha) { + std::cout << "aten::_sparse_add_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_sparse_add_out(mlirtens[0], mlirtens[1], mlirtens[2], alpha); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor &ATenMLIRTypeDefault::_sparse_dense_add_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &other, + at::Scalar alpha) { + std::cout << "aten::_sparse_dense_add_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_sparse_dense_add_out(mlirtens[0], mlirtens[1], mlirtens[2], alpha); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor &ATenMLIRTypeDefault::_sparse_div_zerodim_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &other) { + std::cout << "aten::_sparse_div_zerodim_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_sparse_div_zerodim_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor &ATenMLIRTypeDefault::_sparse_div_scalar_out(at::Tensor &out, + const at::Tensor &self, + at::Scalar other) { + std::cout << "aten::_sparse_div_scalar_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto 
&&x_result = at::_sparse_div_scalar_out(mlirtens[0], mlirtens[1], other); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor &ATenMLIRTypeDefault::_sparse_mul_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::_sparse_mul_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_sparse_mul_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor &ATenMLIRTypeDefault::_sparse_mul_zerodim_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &other) { + std::cout << "aten::_sparse_mul_zerodim_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_sparse_mul_zerodim_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor &ATenMLIRTypeDefault::_sparse_mul_scalar_out(at::Tensor &out, + const at::Tensor &self, + at::Scalar other) { + std::cout << "aten::_sparse_mul_scalar_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_sparse_mul_scalar_out(mlirtens[0], mlirtens[1], other); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +std::vector ATenMLIRTypeDefault::split(const at::Tensor &self, + int64_t split_size, + int64_t dim) { + std::cout << "aten::split" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::split(mlirtens[0], split_size, dim); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +std::vector ATenMLIRTypeDefault::split_with_sizes( + const at::Tensor &self, at::IntArrayRef split_sizes, int64_t dim) { + std::cout << "aten::split_with_sizes" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::split_with_sizes(mlirtens[0], split_sizes, dim); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +at::Tensor ATenMLIRTypeDefault::squeeze(const at::Tensor &self) { + std::cout << "aten::squeeze" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::squeeze(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::squeeze(const at::Tensor &self, int64_t dim) { + std::cout << "aten::squeeze" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::squeeze(mlirtens[0], dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::squeeze_(at::Tensor &self) { + std::cout << "aten::squeeze_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].squeeze_(); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor 
&ATenMLIRTypeDefault::squeeze_(at::Tensor &self, int64_t dim) { + std::cout << "aten::squeeze_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].squeeze_(dim); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::sspaddmm(const at::Tensor &self, + const at::Tensor &mat1, + const at::Tensor &mat2, + at::Scalar beta, at::Scalar alpha) { + std::cout << "aten::sspaddmm" << std::endl; + std::vector mlirtens_tensors = {self, mat1, mat2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::sspaddmm(mlirtens[0], mlirtens[1], mlirtens[2], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::sspaddmm_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &mat1, + const at::Tensor &mat2, at::Scalar beta, at::Scalar alpha) { + std::cout << "aten::sspaddmm_out" << std::endl; + std::vector mlirtens_tensors = {out, self, mat1, mat2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::sspaddmm_out(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::stack(at::TensorList tensors, int64_t dim) { + std::cout << "aten::stack" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::stack(tensors, dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(tensors)); +} + +at::Tensor &ATenMLIRTypeDefault::stack_out(at::Tensor &out, + at::TensorList tensors, + int64_t dim) { + std::cout << "aten::stack_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::stack_out(mlirtens[0], tensors, dim); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::stft(const at::Tensor &self, int64_t n_fft, + c10::optional hop_length, + c10::optional win_length, + const at::Tensor &window, bool normalized, + bool onesided) { + std::cout << "aten::stft" << std::endl; + std::vector mlirtens_tensors = {self, window}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::stft(mlirtens[0], n_fft, hop_length, win_length, + mlirtens[1], normalized, onesided); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +int64_t ATenMLIRTypeDefault::stride(const at::Tensor &self, int64_t dim) { + std::cout << "aten::stride" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::stride(mlirtens[0], dim); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +at::Tensor ATenMLIRTypeDefault::sum(const at::Tensor &self, + c10::optional dtype) { + std::cout << "aten::sum" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::sum(mlirtens[0], dtype); + static_cast(x_result); // Avoid warnings in case not used + return 
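stack and stack_out are the only entries in this hunk that take an at::TensorList; the generated wrapper forwards the list to the stock kernel without building an mlirtens copy and derives the destination device from the list itself via the GetMLIRDevice(tensors) overload. In outline (sketch only; whether the list's elements are bridged elsewhere is not shown in this hunk):

// Sketch only: the TensorList form used by the stack wrapper above.
#include <ATen/ATen.h>

at::Tensor mlir_stack_pattern(at::TensorList tensors, int64_t dim) {
  at::Tensor cpu_result = at::stack(tensors, dim);   // list passed through as-is
  return bridge::CreateMLIRTensor(cpu_result, bridge::GetMLIRDevice(tensors));
}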
+}
+
+at::Tensor ATenMLIRTypeDefault::sum(const at::Tensor &self, at::IntArrayRef dim,
+                                    bool keepdim,
+                                    c10::optional<at::ScalarType> dtype) {
+  std::cout << "aten::sum" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::sum(mlirtens[0], dim, keepdim, dtype);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::sum_out(at::Tensor &out,
+                                         const at::Tensor &self,
+                                         at::IntArrayRef dim, bool keepdim,
+                                         c10::optional<at::ScalarType> dtype) {
+  std::cout << "aten::sum_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::sum_out(mlirtens[0], mlirtens[1], dim, keepdim, dtype);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::sum_to_size(const at::Tensor &self,
+                                            at::IntArrayRef size) {
+  std::cout << "aten::sum_to_size" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = mlirtens[0].sum_to_size(size);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor ATenMLIRTypeDefault::sqrt(const at::Tensor &self) {
+  std::cout << "aten::sqrt" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::sqrt(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::sqrt_(at::Tensor &self) {
+  std::cout << "aten::sqrt_" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::sqrt_(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return self;
+}
+
+at::Tensor &ATenMLIRTypeDefault::sqrt_out(at::Tensor &out,
+                                          const at::Tensor &self) {
+  std::cout << "aten::sqrt_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::sqrt_out(mlirtens[0], mlirtens[1]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::std(const at::Tensor &self, bool unbiased) {
+  std::cout << "aten::std" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::std(mlirtens[0], unbiased);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor ATenMLIRTypeDefault::std(const at::Tensor &self, at::IntArrayRef dim,
+                                    bool unbiased, bool keepdim) {
+  std::cout << "aten::std" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::std(mlirtens[0], dim, unbiased, keepdim);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+std::tuple<at::Tensor, at::Tensor>
+ATenMLIRTypeDefault::std_mean(const at::Tensor &self, bool unbiased) {
+  std::cout <<
"aten::std_mean" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::std_mean(mlirtens[0], unbiased); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::std_mean(const at::Tensor &self, at::IntArrayRef dim, + bool unbiased, bool keepdim) { + std::cout << "aten::std_mean" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::std_mean(mlirtens[0], dim, unbiased, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor &ATenMLIRTypeDefault::std_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef dim, bool unbiased, + bool keepdim) { + std::cout << "aten::std_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::std_out(mlirtens[0], mlirtens[1], dim, unbiased, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::prod(const at::Tensor &self, + c10::optional dtype) { + std::cout << "aten::prod" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::prod(mlirtens[0], dtype); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::prod(const at::Tensor &self, int64_t dim, + bool keepdim, + c10::optional dtype) { + std::cout << "aten::prod" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::prod(mlirtens[0], dim, keepdim, dtype); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::prod_out(at::Tensor &out, + const at::Tensor &self, int64_t dim, + bool keepdim, + c10::optional dtype) { + std::cout << "aten::prod_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::prod_out(mlirtens[0], mlirtens[1], dim, keepdim, dtype); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::t(const at::Tensor &self) { + std::cout << "aten::t" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::t(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::t_(at::Tensor &self) { + std::cout << "aten::t_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].t_(); + static_cast(x_result); // Avoid warnings in case not used + return 
self; +} + +at::Tensor ATenMLIRTypeDefault::tan(const at::Tensor &self) { + std::cout << "aten::tan" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::tan(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::tan_(at::Tensor &self) { + std::cout << "aten::tan_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::tan_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::tan_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::tan_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::tan_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::tanh(const at::Tensor &self) { + std::cout << "aten::tanh" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::tanh(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::tanh_(at::Tensor &self) { + std::cout << "aten::tanh_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::tanh_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::tanh_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::tanh_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::tanh_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::tensordot(const at::Tensor &self, + const at::Tensor &other, + at::IntArrayRef dims_self, + at::IntArrayRef dims_other) { + std::cout << "aten::tensordot" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::tensordot(mlirtens[0], mlirtens[1], dims_self, dims_other); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::threshold(const at::Tensor &self, + at::Scalar threshold, + at::Scalar value) { + std::cout << "aten::threshold" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::threshold(mlirtens[0], threshold, value); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::threshold_(at::Tensor &self, + at::Scalar threshold, + at::Scalar value) { + std::cout << "aten::threshold_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::threshold_(mlirtens[0], 
threshold, value); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::threshold_out(at::Tensor &out, + const at::Tensor &self, + at::Scalar threshold, + at::Scalar value) { + std::cout << "aten::threshold_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::threshold_out(mlirtens[0], mlirtens[1], threshold, value); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor +ATenMLIRTypeDefault::threshold_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::Scalar threshold) { + std::cout << "aten::threshold_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::threshold_backward(mlirtens[0], mlirtens[1], threshold); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor ATenMLIRTypeDefault::transpose(const at::Tensor &self, int64_t dim0, + int64_t dim1) { + std::cout << "aten::transpose" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::transpose(mlirtens[0], dim0, dim1); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_mkldnn_transpose(const at::Tensor &self, + int64_t dim0, int64_t dim1) { + std::cout << "aten::_mkldnn_transpose" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_mkldnn_transpose(mlirtens[0], dim0, dim1); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::transpose_(at::Tensor &self, int64_t dim0, + int64_t dim1) { + std::cout << "aten::transpose_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].transpose_(dim0, dim1); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::_mkldnn_transpose_(at::Tensor &self, + int64_t dim0, + int64_t dim1) { + std::cout << "aten::_mkldnn_transpose_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_mkldnn_transpose_(mlirtens[0], dim0, dim1); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::one_hot(const at::Tensor &self, + int64_t num_classes) { + std::cout << "aten::one_hot" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::one_hot(mlirtens[0], num_classes); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::flip(const at::Tensor &self, + at::IntArrayRef dims) { + std::cout << "aten::flip" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::flip(mlirtens[0], 
dims); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::roll(const at::Tensor &self, + at::IntArrayRef shifts, + at::IntArrayRef dims) { + std::cout << "aten::roll" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::roll(mlirtens[0], shifts, dims); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::rot90(const at::Tensor &self, int64_t k, + at::IntArrayRef dims) { + std::cout << "aten::rot90" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::rot90(mlirtens[0], k, dims); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::trapz(const at::Tensor &y, const at::Tensor &x, + int64_t dim) { + std::cout << "aten::trapz" << std::endl; + std::vector mlirtens_tensors = {y, x}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::trapz(mlirtens[0], mlirtens[1], dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(y)); +} + +at::Tensor ATenMLIRTypeDefault::trapz(const at::Tensor &y, double dx, + int64_t dim) { + std::cout << "aten::trapz" << std::endl; + std::vector mlirtens_tensors = {y}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::trapz(mlirtens[0], dx, dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(y)); +} + +at::Tensor ATenMLIRTypeDefault::_trilinear( + const at::Tensor &i1, const at::Tensor &i2, const at::Tensor &i3, + at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, + at::IntArrayRef sumdim, int64_t unroll_dim) { + std::cout << "aten::_trilinear" << std::endl; + std::vector mlirtens_tensors = {i1, i2, i3}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_trilinear(mlirtens[0], mlirtens[1], mlirtens[2], expand1, expand2, + expand3, sumdim, unroll_dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(i1)); +} + +at::Tensor ATenMLIRTypeDefault::triplet_margin_loss(const at::Tensor &anchor, + const at::Tensor &positive, + const at::Tensor &negative, + double margin, double p, + double eps, bool swap, + int64_t reduction) { + std::cout << "aten::triplet_margin_loss" << std::endl; + std::vector mlirtens_tensors = {anchor, positive, negative}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::triplet_margin_loss( + mlirtens[0], mlirtens[1], mlirtens[2], margin, p, eps, swap, reduction); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(anchor)); +} + +at::Tensor ATenMLIRTypeDefault::trunc(const at::Tensor &self) { + std::cout << "aten::trunc" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::trunc(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not 
used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::trunc_(at::Tensor &self) { + std::cout << "aten::trunc_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::trunc_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::trunc_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::trunc_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::trunc_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::type_as(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::type_as" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].type_as(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +bool ATenMLIRTypeDefault::_has_compatible_shallow_copy_type( + const at::Tensor &self, const at::Tensor &from) { + std::cout << "aten::_has_compatible_shallow_copy_type" << std::endl; + std::vector mlirtens_tensors = {self, from}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_has_compatible_shallow_copy_type(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +std::tuple +ATenMLIRTypeDefault::_unique(const at::Tensor &self, bool sorted, + bool return_inverse) { + std::cout << "aten::_unique" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_unique(mlirtens[0], sorted, return_inverse); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::unique_dim(const at::Tensor &self, int64_t dim, + bool sorted, bool return_inverse, + bool return_counts) { + std::cout << "aten::unique_dim" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::unique_dim(mlirtens[0], dim, sorted, return_inverse, return_counts); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::unique_consecutive(const at::Tensor &self, + bool return_inverse, bool return_counts, + c10::optional dim) { + std::cout << "aten::unique_consecutive" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::unique_consecutive(mlirtens[0], return_inverse, return_counts, dim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + 
bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::unique_dim_consecutive(const at::Tensor &self, int64_t dim, + bool return_inverse, + bool return_counts) { + std::cout << "aten::unique_dim_consecutive" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::unique_dim_consecutive(mlirtens[0], dim, return_inverse, + return_counts); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::_unique2(const at::Tensor &self, bool sorted, + bool return_inverse, bool return_counts) { + std::cout << "aten::_unique2" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_unique2(mlirtens[0], sorted, return_inverse, return_counts); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor ATenMLIRTypeDefault::_unsafe_view(const at::Tensor &self, + at::IntArrayRef size) { + std::cout << "aten::_unsafe_view" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_unsafe_view(mlirtens[0], size); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::unsqueeze(const at::Tensor &self, int64_t dim) { + std::cout << "aten::unsqueeze" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::unsqueeze(mlirtens[0], dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::unsqueeze_(at::Tensor &self, int64_t dim) { + std::cout << "aten::unsqueeze_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].unsqueeze_(dim); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::var(const at::Tensor &self, bool unbiased) { + std::cout << "aten::var" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::var(mlirtens[0], unbiased); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::var(const at::Tensor &self, at::IntArrayRef dim, + bool unbiased, bool keepdim) { + std::cout << "aten::var" << std::endl; + std::vector mlirtens_tensors = {self}; + auto 
mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::var(mlirtens[0], dim, unbiased, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::var_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef dim, bool unbiased, + bool keepdim) { + std::cout << "aten::var_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::var_out(mlirtens[0], mlirtens[1], dim, unbiased, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +std::tuple +ATenMLIRTypeDefault::var_mean(const at::Tensor &self, bool unbiased) { + std::cout << "aten::var_mean" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::var_mean(mlirtens[0], unbiased); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::var_mean(const at::Tensor &self, at::IntArrayRef dim, + bool unbiased, bool keepdim) { + std::cout << "aten::var_mean" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::var_mean(mlirtens[0], dim, unbiased, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor ATenMLIRTypeDefault::view_as(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::view_as" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].view_as(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::where(const at::Tensor &condition, + const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::where" << std::endl; + std::vector mlirtens_tensors = {condition, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::where(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(condition)); +} + +std::vector +ATenMLIRTypeDefault::where(const at::Tensor &condition) { + std::cout << "aten::where" << std::endl; + std::vector mlirtens_tensors = {condition}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::where(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +at::Tensor ATenMLIRTypeDefault::_s_where(const at::Tensor &condition, + const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::_s_where" << std::endl; + std::vector mlirtens_tensors = {condition, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_s_where(mlirtens[0], mlirtens[1], mlirtens[2]); + 
static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(condition)); +} + +at::Tensor ATenMLIRTypeDefault::norm_except_dim(const at::Tensor &v, + int64_t pow, int64_t dim) { + std::cout << "aten::norm_except_dim" << std::endl; + std::vector mlirtens_tensors = {v}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::norm_except_dim(mlirtens[0], pow, dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(v)); +} + +at::Tensor ATenMLIRTypeDefault::_weight_norm(const at::Tensor &v, + const at::Tensor &g, int64_t dim) { + std::cout << "aten::_weight_norm" << std::endl; + std::vector mlirtens_tensors = {v, g}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_weight_norm(mlirtens[0], mlirtens[1], dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(v)); +} + +std::tuple +ATenMLIRTypeDefault::_weight_norm_cuda_interface(const at::Tensor &v, + const at::Tensor &g, + int64_t dim) { + std::cout << "aten::_weight_norm_cuda_interface" << std::endl; + std::vector mlirtens_tensors = {v, g}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_weight_norm_cuda_interface(mlirtens[0], mlirtens[1], dim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), bridge::GetMLIRDevice(v)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(v))); +} + +std::tuple +ATenMLIRTypeDefault::_weight_norm_cuda_interface_backward( + const at::Tensor &grad_w, const at::Tensor &saved_v, + const at::Tensor &saved_g, const at::Tensor &saved_norms, int64_t dim) { + std::cout << "aten::_weight_norm_cuda_interface_backward" << std::endl; + std::vector mlirtens_tensors = {grad_w, saved_v, saved_g, + saved_norms}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_weight_norm_cuda_interface_backward( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], dim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(grad_w)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(grad_w))); +} + +std::tuple +ATenMLIRTypeDefault::_weight_norm_differentiable_backward( + const at::Tensor &grad_w, const at::Tensor &saved_v, + const at::Tensor &saved_g, const at::Tensor &saved_norms, int64_t dim) { + std::cout << "aten::_weight_norm_differentiable_backward" << std::endl; + std::vector mlirtens_tensors = {grad_w, saved_v, saved_g, + saved_norms}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_weight_norm_differentiable_backward( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], dim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(grad_w)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(grad_w))); +} + +at::Tensor &ATenMLIRTypeDefault::zeros_out(at::Tensor &out, + at::IntArrayRef size) { + std::cout << "aten::zeros_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = 
at::zeros_out(mlirtens[0], size); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::_standard_gamma_grad(const at::Tensor &self, + const at::Tensor &output) { + std::cout << "aten::_standard_gamma_grad" << std::endl; + std::vector mlirtens_tensors = {self, output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_standard_gamma_grad(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_standard_gamma(const at::Tensor &self, + at::Generator *generator) { + std::cout << "aten::_standard_gamma" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_standard_gamma(mlirtens[0], generator); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_dirichlet_grad(const at::Tensor &x, + const at::Tensor &alpha, + const at::Tensor &total) { + std::cout << "aten::_dirichlet_grad" << std::endl; + std::vector mlirtens_tensors = {x, alpha, total}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_dirichlet_grad(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(x)); +} + +at::Tensor ATenMLIRTypeDefault::_sample_dirichlet(const at::Tensor &self, + at::Generator *generator) { + std::cout << "aten::_sample_dirichlet" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_sample_dirichlet(mlirtens[0], generator); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::poisson(const at::Tensor &self, + at::Generator *generator) { + std::cout << "aten::poisson" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::poisson(mlirtens[0], generator); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::native_norm(const at::Tensor &self, + at::Scalar p) { + std::cout << "aten::native_norm" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::native_norm(mlirtens[0], p); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_sparse_sum(const at::Tensor &self) { + std::cout << "aten::_sparse_sum" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_sparse_sum(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_sparse_sum(const at::Tensor &self, + at::ScalarType dtype) { + std::cout << "aten::_sparse_sum" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = 
bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_sparse_sum(mlirtens[0], dtype); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_sparse_sum(const at::Tensor &self, + at::IntArrayRef dim) { + std::cout << "aten::_sparse_sum" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_sparse_sum(mlirtens[0], dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_sparse_sum(const at::Tensor &self, + at::IntArrayRef dim, + at::ScalarType dtype) { + std::cout << "aten::_sparse_sum" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_sparse_sum(mlirtens[0], dim, dtype); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_sparse_sum_backward(const at::Tensor &grad, + const at::Tensor &self, + at::IntArrayRef dim) { + std::cout << "aten::_sparse_sum_backward" << std::endl; + std::vector mlirtens_tensors = {grad, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_sparse_sum_backward(mlirtens[0], mlirtens[1], dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad)); +} + +at::Tensor ATenMLIRTypeDefault::norm(const at::Tensor &self, + c10::optional p, + at::ScalarType dtype) { + std::cout << "aten::norm" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::norm(mlirtens[0], p, dtype); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::norm(const at::Tensor &self, at::Scalar p) { + std::cout << "aten::norm" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::norm(mlirtens[0], p); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::norm(const at::Tensor &self, + c10::optional p, + at::IntArrayRef dim, bool keepdim, + at::ScalarType dtype) { + std::cout << "aten::norm" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::norm(mlirtens[0], p, dim, keepdim, dtype); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::norm(const at::Tensor &self, + c10::optional p, + at::IntArrayRef dim, bool keepdim) { + std::cout << "aten::norm" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::norm(mlirtens[0], p, dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor 
&ATenMLIRTypeDefault::norm_out(at::Tensor &out, + const at::Tensor &self, + c10::optional p, + at::IntArrayRef dim, bool keepdim, + at::ScalarType dtype) { + std::cout << "aten::norm_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::norm_out(mlirtens[0], mlirtens[1], p, dim, keepdim, dtype); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor &ATenMLIRTypeDefault::norm_out(at::Tensor &out, + const at::Tensor &self, + c10::optional p, + at::IntArrayRef dim, bool keepdim) { + std::cout << "aten::norm_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::norm_out(mlirtens[0], mlirtens[1], p, dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::frobenius_norm(const at::Tensor &self) { + std::cout << "aten::frobenius_norm" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::frobenius_norm(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::frobenius_norm(const at::Tensor &self, + at::IntArrayRef dim, + bool keepdim) { + std::cout << "aten::frobenius_norm" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::frobenius_norm(mlirtens[0], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::frobenius_norm_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef dim, + bool keepdim) { + std::cout << "aten::frobenius_norm_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::frobenius_norm_out(mlirtens[0], mlirtens[1], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::nuclear_norm(const at::Tensor &self, + bool keepdim) { + std::cout << "aten::nuclear_norm" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::nuclear_norm(mlirtens[0], keepdim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::nuclear_norm_out(at::Tensor &out, + const at::Tensor &self, + bool keepdim) { + std::cout << "aten::nuclear_norm_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::nuclear_norm_out(mlirtens[0], mlirtens[1], keepdim); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::nuclear_norm(const at::Tensor &self, + at::IntArrayRef dim, + bool keepdim) { + std::cout << "aten::nuclear_norm" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::nuclear_norm(mlirtens[0], dim, keepdim); + static_cast(x_result); // Avoid 
warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::nuclear_norm_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef dim, + bool keepdim) { + std::cout << "aten::nuclear_norm_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::nuclear_norm_out(mlirtens[0], mlirtens[1], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::clone(const at::Tensor &self) { + std::cout << "aten::clone" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::clone(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::resize_as_(at::Tensor &self, + const at::Tensor &the_template) { + std::cout << "aten::resize_as_" << std::endl; + std::vector mlirtens_tensors = {self, the_template}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::resize_as_(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::pow_out(at::Tensor &out, + const at::Tensor &self, + at::Scalar exponent) { + std::cout << "aten::pow_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::pow_out(mlirtens[0], mlirtens[1], exponent); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::pow(const at::Tensor &self, + at::Scalar exponent) { + std::cout << "aten::pow" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::pow(mlirtens[0], exponent); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::zero_(at::Tensor &self) { + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::zero_(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::sub_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &other, + at::Scalar alpha) { + std::cout << "aten::sub_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::sub_out(mlirtens[0], mlirtens[1], mlirtens[2], alpha); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::sub(const at::Tensor &self, + const at::Tensor &other, at::Scalar alpha) { + std::cout << "aten::sub" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::sub(mlirtens[0], mlirtens[1], alpha); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::sub_(at::Tensor &self, const at::Tensor &other, + at::Scalar alpha) { + std::cout << 
"aten::sub_" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].sub_(mlirtens[1], alpha); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::sub(const at::Tensor &self, at::Scalar other, + at::Scalar alpha) { + std::cout << "aten::sub" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::sub(mlirtens[0], other, alpha); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::sub_(at::Tensor &self, at::Scalar other, + at::Scalar alpha) { + std::cout << "aten::sub_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].sub_(other, alpha); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::rsub(const at::Tensor &self, + const at::Tensor &other, + at::Scalar alpha) { + std::cout << "aten::rsub" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::rsub(mlirtens[0], mlirtens[1], alpha); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::rsub(const at::Tensor &self, at::Scalar other, + at::Scalar alpha) { + std::cout << "aten::rsub" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::rsub(mlirtens[0], other, alpha); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::s_native_addmm_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &mat1, + const at::Tensor &mat2, at::Scalar beta, at::Scalar alpha) { + std::cout << "aten::s_native_addmm_out" << std::endl; + std::vector mlirtens_tensors = {out, self, mat1, mat2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::s_native_addmm_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::s_native_addmm(const at::Tensor &self, + const at::Tensor &mat1, + const at::Tensor &mat2, + at::Scalar beta, + at::Scalar alpha) { + std::cout << "aten::s_native_addmm" << std::endl; + std::vector mlirtens_tensors = {self, mat1, mat2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::s_native_addmm(mlirtens[0], mlirtens[1], mlirtens[2], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::s_native_addmm_(at::Tensor &self, + const at::Tensor &mat1, + const at::Tensor &mat2, + at::Scalar beta, + at::Scalar alpha) { + std::cout << "aten::s_native_addmm_" << std::endl; + std::vector mlirtens_tensors = {self, mat1, mat2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::s_native_addmm_(mlirtens[0], mlirtens[1], mlirtens[2], 
beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::_sparse_addmm(const at::Tensor &self, + const at::Tensor &sparse, + const at::Tensor &dense, + at::Scalar beta, + at::Scalar alpha) { + std::cout << "aten::_sparse_addmm" << std::endl; + std::vector mlirtens_tensors = {self, sparse, dense}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_sparse_addmm(mlirtens[0], mlirtens[1], mlirtens[2], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::addmm_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &mat1, + const at::Tensor &mat2, + at::Scalar beta, at::Scalar alpha) { + std::cout << "aten::addmm_out" << std::endl; + std::vector mlirtens_tensors = {out, self, mat1, mat2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::addmm_out(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::addmm(const at::Tensor &self, + const at::Tensor &mat1, + const at::Tensor &mat2, at::Scalar beta, + at::Scalar alpha) { + std::cout << "aten::addmm" << std::endl; + std::vector mlirtens_tensors = {self, mat1, mat2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::addmm(mlirtens[0], mlirtens[1], mlirtens[2], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::addmm_(at::Tensor &self, + const at::Tensor &mat1, + const at::Tensor &mat2, at::Scalar beta, + at::Scalar alpha) { + std::cout << "aten::addmm_" << std::endl; + std::vector mlirtens_tensors = {self, mat1, mat2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].addmm_(mlirtens[1], mlirtens[2], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor +ATenMLIRTypeDefault::sparse_coo_tensor(at::IntArrayRef size, + const at::TensorOptions &options) { + std::cout << "aten::sparse_coo_tensor" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::sparse_coo_tensor(size, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor +ATenMLIRTypeDefault::sparse_coo_tensor(const at::Tensor &indices, + const at::Tensor &values, + const at::TensorOptions &options) { + std::cout << "aten::sparse_coo_tensor" << std::endl; + std::vector mlirtens_tensors = {indices, values}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::sparse_coo_tensor(mlirtens[0], mlirtens[1], options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(indices)); +} + +at::Tensor ATenMLIRTypeDefault::sparse_coo_tensor( + const at::Tensor &indices, const at::Tensor &values, at::IntArrayRef size, + const at::TensorOptions &options) { + std::cout << "aten::sparse_coo_tensor" << std::endl; + std::vector mlirtens_tensors = {indices, values}; + auto mlirtens = 
bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::sparse_coo_tensor(mlirtens[0], mlirtens[1], size, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(indices)); +} + +at::Tensor ATenMLIRTypeDefault::_sparse_coo_tensor_unsafe( + const at::Tensor &indices, const at::Tensor &values, at::IntArrayRef size, + const at::TensorOptions &options) { + std::cout << "aten::_sparse_coo_tensor_unsafe" << std::endl; + std::vector mlirtens_tensors = {indices, values}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_sparse_coo_tensor_unsafe(mlirtens[0], mlirtens[1], size, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(indices)); +} + +at::Tensor ATenMLIRTypeDefault::_sparse_coo_tensor_with_dims( + int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, + const at::TensorOptions &options) { + std::cout << "aten::_sparse_coo_tensor_with_dims" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_sparse_coo_tensor_with_dims(sparse_dim, dense_dim, size, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::_sparse_coo_tensor_with_dims_and_tensors( + int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, + const at::Tensor &indices, const at::Tensor &values, + const at::TensorOptions &options) { + std::cout << "aten::_sparse_coo_tensor_with_dims_and_tensors" << std::endl; + std::vector mlirtens_tensors = {indices, values}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_sparse_coo_tensor_with_dims_and_tensors( + sparse_dim, dense_dim, size, mlirtens[0], mlirtens[1], options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(indices)); +} + +at::Tensor &ATenMLIRTypeDefault::sparse_resize_(at::Tensor &self, + at::IntArrayRef size, + int64_t sparse_dim, + int64_t dense_dim) { + std::cout << "aten::sparse_resize_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].sparse_resize_(size, sparse_dim, dense_dim); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::sparse_resize_and_clear_(at::Tensor &self, + at::IntArrayRef size, + int64_t sparse_dim, + int64_t dense_dim) { + std::cout << "aten::sparse_resize_and_clear_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + mlirtens[0].sparse_resize_and_clear_(size, sparse_dim, dense_dim); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::sparse_mask(const at::Tensor &self, + const at::Tensor &mask) { + std::cout << "aten::sparse_mask" << std::endl; + std::vector mlirtens_tensors = {self, mask}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].sparse_mask(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor 
ATenMLIRTypeDefault::to_dense(const at::Tensor &self) { + std::cout << "aten::to_dense" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].to_dense(); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::to_dense_backward(const at::Tensor &grad, + const at::Tensor &input) { + std::cout << "aten::to_dense_backward" << std::endl; + std::vector mlirtens_tensors = {grad, input}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::to_dense_backward(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad)); +} + +int64_t ATenMLIRTypeDefault::sparse_dim(const at::Tensor &self) { + std::cout << "aten::sparse_dim" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].sparse_dim(); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +int64_t ATenMLIRTypeDefault::_dimI(const at::Tensor &self) { + std::cout << "aten::_dimI" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0]._dimI(); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +int64_t ATenMLIRTypeDefault::dense_dim(const at::Tensor &self) { + std::cout << "aten::dense_dim" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].dense_dim(); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +int64_t ATenMLIRTypeDefault::_dimV(const at::Tensor &self) { + std::cout << "aten::_dimV" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0]._dimV(); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +int64_t ATenMLIRTypeDefault::_nnz(const at::Tensor &self) { + std::cout << "aten::_nnz" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0]._nnz(); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +at::Tensor ATenMLIRTypeDefault::coalesce(const at::Tensor &self) { + std::cout << "aten::coalesce" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].coalesce(); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +bool ATenMLIRTypeDefault::is_coalesced(const at::Tensor &self) { + std::cout << "aten::is_coalesced" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].is_coalesced(); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +at::Tensor ATenMLIRTypeDefault::_indices(const at::Tensor &self) { + std::cout << "aten::_indices" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = 
bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0]._indices(); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_values(const at::Tensor &self) { + std::cout << "aten::_values" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0]._values(); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::_coalesced_(at::Tensor &self, bool coalesced) { + std::cout << "aten::_coalesced_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0]._coalesced_(coalesced); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::indices(const at::Tensor &self) { + std::cout << "aten::indices" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].indices(); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::values(const at::Tensor &self) { + std::cout << "aten::values" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].values(); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::hspmm_out(at::Tensor &out, + const at::Tensor &mat1, + const at::Tensor &mat2) { + std::cout << "aten::hspmm_out" << std::endl; + std::vector mlirtens_tensors = {out, mat1, mat2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::hspmm_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::hspmm(const at::Tensor &mat1, + const at::Tensor &mat2) { + std::cout << "aten::hspmm" << std::endl; + std::vector mlirtens_tensors = {mat1, mat2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::hspmm(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(mat1)); +} + +at::Tensor &ATenMLIRTypeDefault::copy_sparse_to_sparse_(at::Tensor &self, + const at::Tensor &src, + bool non_blocking) { + std::cout << "aten::copy_sparse_to_sparse_" << std::endl; + std::vector mlirtens_tensors = {self, src}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::copy_sparse_to_sparse_(mlirtens[0], mlirtens[1], non_blocking); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +std::vector ATenMLIRTypeDefault::unbind(const at::Tensor &self, + int64_t dim) { + std::cout << "aten::unbind" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::unbind(mlirtens[0], dim); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +at::Tensor 
ATenMLIRTypeDefault::to_sparse(const at::Tensor &self, + int64_t sparse_dim) { + std::cout << "aten::to_sparse" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].to_sparse(sparse_dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::to_sparse(const at::Tensor &self) { + std::cout << "aten::to_sparse" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].to_sparse(); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::to_mkldnn(const at::Tensor &self) { + std::cout << "aten::to_mkldnn" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].to_mkldnn(); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::mkldnn_reorder_conv2d_weight( + const at::Tensor &self, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups) { + std::cout << "aten::mkldnn_reorder_conv2d_weight" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mkldnn_reorder_conv2d_weight(mlirtens[0], padding, + stride, dilation, groups); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::to_mkldnn_backward(const at::Tensor &grad, + const at::Tensor &input) { + std::cout << "aten::to_mkldnn_backward" << std::endl; + std::vector mlirtens_tensors = {grad, input}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::to_mkldnn_backward(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad)); +} + +at::Tensor ATenMLIRTypeDefault::quantize_linear(const at::Tensor &self, + double scale, + int64_t zero_point, + at::ScalarType dtype) { + std::cout << "aten::quantize_linear" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::quantize_linear(mlirtens[0], scale, zero_point, dtype); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::quantize_linear_per_channel( + const at::Tensor &self, const at::Tensor &scales, + const at::Tensor &zero_points, at::IntArrayRef axis, at::ScalarType dtype) { + std::cout << "aten::quantize_linear_per_channel" << std::endl; + std::vector mlirtens_tensors = {self, scales, zero_points}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::quantize_linear_per_channel(mlirtens[0], mlirtens[1], + mlirtens[2], axis, dtype); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::dequantize(const at::Tensor &self) { + std::cout 
<< "aten::dequantize" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::dequantize(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_dequantize_linear(const at::Tensor &self, + double scale, + int64_t zero_point, + at::ScalarType dtype) { + std::cout << "aten::_dequantize_linear" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_dequantize_linear(mlirtens[0], scale, zero_point, dtype); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +double ATenMLIRTypeDefault::q_scale(const at::Tensor &self) { + std::cout << "aten::q_scale" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::q_scale(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +int64_t ATenMLIRTypeDefault::q_zero_point(const at::Tensor &self) { + std::cout << "aten::q_zero_point" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::q_zero_point(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +at::Tensor ATenMLIRTypeDefault::q_per_channel_scales(const at::Tensor &self) { + std::cout << "aten::q_per_channel_scales" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::q_per_channel_scales(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor +ATenMLIRTypeDefault::q_per_channel_zero_points(const at::Tensor &self) { + std::cout << "aten::q_per_channel_zero_points" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::q_per_channel_zero_points(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::int_repr(const at::Tensor &self) { + std::cout << "aten::int_repr" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::int_repr(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_per_tensor_affine_qtensor( + const at::Tensor &self, double scale, int64_t zero_point) { + std::cout << "aten::_per_tensor_affine_qtensor" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_per_tensor_affine_qtensor(mlirtens[0], scale, zero_point); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_per_channel_affine_qtensor( + const at::Tensor &self, const at::Tensor &scale, + const at::Tensor &zero_point, at::IntArrayRef 
axis) { + std::cout << "aten::_per_channel_affine_qtensor" << std::endl; + std::vector mlirtens_tensors = {self, scale, zero_point}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_per_channel_affine_qtensor(mlirtens[0], mlirtens[1], + mlirtens[2], axis); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::QScheme ATenMLIRTypeDefault::qscheme(const at::Tensor &self) { + std::cout << "aten::qscheme" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].qscheme(); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +at::Tensor ATenMLIRTypeDefault::fake_quantize_per_tensor_affine( + const at::Tensor &self, double scale, int64_t zero_point, int64_t quant_min, + int64_t quant_max) { + std::cout << "aten::fake_quantize_per_tensor_affine" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::fake_quantize_per_tensor_affine( + mlirtens[0], scale, zero_point, quant_min, quant_max); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::fake_quantize_per_tensor_affine_backward( + const at::Tensor &grad, const at::Tensor &self, double scale, + int64_t zero_point, int64_t quant_min, int64_t quant_max) { + std::cout << "aten::fake_quantize_per_tensor_affine_backward" << std::endl; + std::vector mlirtens_tensors = {grad, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::fake_quantize_per_tensor_affine_backward( + mlirtens[0], mlirtens[1], scale, zero_point, quant_min, quant_max); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad)); +} + +at::Tensor ATenMLIRTypeDefault::to(const at::Tensor &self, + const at::TensorOptions &options, + bool non_blocking, bool copy) { + std::cout << "aten::to" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].to(options, non_blocking, copy); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::to(const at::Tensor &self, c10::Device device, + at::ScalarType dtype, bool non_blocking, + bool copy) { + std::cout << "aten::to" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].to(device, dtype, non_blocking, copy); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::to(const at::Tensor &self, at::ScalarType dtype, + bool non_blocking, bool copy) { + std::cout << "aten::to" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].to(dtype, non_blocking, copy); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::to(const at::Tensor &self, + 
const at::Tensor &other, bool non_blocking,
+                                    bool copy) {
+  std::cout << "aten::to" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self, other};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = mlirtens[0].to(mlirtens[1], non_blocking, copy);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+std::vector<at::Tensor> ATenMLIRTypeDefault::meshgrid(at::TensorList tensors) {
+  std::cout << "aten::meshgrid" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::meshgrid(tensors);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return x_result;
+}
+
+at::Tensor ATenMLIRTypeDefault::cartesian_prod(at::TensorList tensors) {
+  std::cout << "aten::cartesian_prod" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::cartesian_prod(tensors);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(tensors));
+}
+
+at::Tensor ATenMLIRTypeDefault::combinations(const at::Tensor &self, int64_t r,
+                                             bool with_replacement) {
+  std::cout << "aten::combinations" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::combinations(mlirtens[0], r, with_replacement);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Scalar ATenMLIRTypeDefault::item(const at::Tensor &self) {
+  std::cout << "aten::item" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = mlirtens[0].item();
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return x_result;
+}
+
+at::Scalar ATenMLIRTypeDefault::_local_scalar_dense(const at::Tensor &self) {
+  std::cout << "aten::_local_scalar_dense" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::_local_scalar_dense(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return x_result;
+}
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor>
+ATenMLIRTypeDefault::_thnn_fused_lstm_cell(const at::Tensor &input_gates,
+                                           const at::Tensor &hidden_gates,
+                                           const at::Tensor &cx,
+                                           const at::Tensor &input_bias,
+                                           const at::Tensor &hidden_bias) {
+  std::cout << "aten::_thnn_fused_lstm_cell" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input_gates, hidden_gates, cx,
+                                              input_bias, hidden_bias};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::_thnn_fused_lstm_cell(
+      mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return std::tuple<at::Tensor, at::Tensor, at::Tensor>(
+      bridge::CreateMLIRTensor(std::get<0>(x_result),
+                               bridge::GetMLIRDevice(input_gates)),
+      bridge::CreateMLIRTensor(std::get<1>(x_result),
+                               bridge::GetMLIRDevice(input_gates)),
+      bridge::CreateMLIRTensor(std::get<2>(x_result),
+                               bridge::GetMLIRDevice(input_gates)));
+}
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+ATenMLIRTypeDefault::_thnn_fused_lstm_cell_backward(
+    const at::Tensor &grad_hy, const at::Tensor &grad_cy, const at::Tensor &cx,
+    const at::Tensor &cy, const at::Tensor &workspace, bool has_bias) {
+  std::cout << "aten::_thnn_fused_lstm_cell_backward" <<
std::endl; + std::vector mlirtens_tensors = {grad_hy, grad_cy, cx, cy, + workspace}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_thnn_fused_lstm_cell_backward(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], mlirtens[4], has_bias); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(grad_hy)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(grad_hy)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(grad_hy)), + bridge::CreateMLIRTensor(std::get<3>(x_result), + bridge::GetMLIRDevice(grad_hy)), + bridge::CreateMLIRTensor(std::get<4>(x_result), + bridge::GetMLIRDevice(grad_hy))); +} + +std::tuple ATenMLIRTypeDefault::_thnn_fused_gru_cell( + const at::Tensor &input_gates, const at::Tensor &hidden_gates, + const at::Tensor &hx, const at::Tensor &input_bias, + const at::Tensor &hidden_bias) { + std::cout << "aten::_thnn_fused_gru_cell" << std::endl; + std::vector mlirtens_tensors = {input_gates, hidden_gates, hx, + input_bias, hidden_bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_thnn_fused_gru_cell( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input_gates)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input_gates))); +} + +std::tuple +ATenMLIRTypeDefault::_thnn_fused_gru_cell_backward(const at::Tensor &grad_hy, + const at::Tensor &workspace, + bool has_bias) { + std::cout << "aten::_thnn_fused_gru_cell_backward" << std::endl; + std::vector mlirtens_tensors = {grad_hy, workspace}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_thnn_fused_gru_cell_backward(mlirtens[0], mlirtens[1], has_bias); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(grad_hy)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(grad_hy)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(grad_hy)), + bridge::CreateMLIRTensor(std::get<3>(x_result), + bridge::GetMLIRDevice(grad_hy)), + bridge::CreateMLIRTensor(std::get<4>(x_result), + bridge::GetMLIRDevice(grad_hy))); +} + +std::tuple +ATenMLIRTypeDefault::lstm(const at::Tensor &input, at::TensorList hx, + at::TensorList params, bool has_biases, + int64_t num_layers, double dropout, bool train, + bool bidirectional, bool batch_first) { + std::cout << "aten::lstm" << std::endl; + std::vector mlirtens_tensors = {input}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::lstm(mlirtens[0], hx, params, has_biases, num_layers, + dropout, train, bidirectional, batch_first); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(input))); +} + +std::tuple +ATenMLIRTypeDefault::lstm(const at::Tensor &data, const at::Tensor &batch_sizes, + at::TensorList hx, at::TensorList params, + bool has_biases, int64_t num_layers, double 
dropout, + bool train, bool bidirectional) { + std::cout << "aten::lstm" << std::endl; + std::vector mlirtens_tensors = {data, batch_sizes}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::lstm(mlirtens[0], mlirtens[1], hx, params, has_biases, + num_layers, dropout, train, bidirectional); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(data)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(data)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(data))); +} + +std::tuple +ATenMLIRTypeDefault::gru(const at::Tensor &input, const at::Tensor &hx, + at::TensorList params, bool has_biases, + int64_t num_layers, double dropout, bool train, + bool bidirectional, bool batch_first) { + std::cout << "aten::gru" << std::endl; + std::vector mlirtens_tensors = {input, hx}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::gru(mlirtens[0], mlirtens[1], params, has_biases, num_layers, dropout, + train, bidirectional, batch_first); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input))); +} + +std::tuple +ATenMLIRTypeDefault::gru(const at::Tensor &data, const at::Tensor &batch_sizes, + const at::Tensor &hx, at::TensorList params, + bool has_biases, int64_t num_layers, double dropout, + bool train, bool bidirectional) { + std::cout << "aten::gru" << std::endl; + std::vector mlirtens_tensors = {data, batch_sizes, hx}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::gru(mlirtens[0], mlirtens[1], mlirtens[2], params, has_biases, + num_layers, dropout, train, bidirectional); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(data)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(data))); +} + +std::tuple +ATenMLIRTypeDefault::rnn_tanh(const at::Tensor &input, const at::Tensor &hx, + at::TensorList params, bool has_biases, + int64_t num_layers, double dropout, bool train, + bool bidirectional, bool batch_first) { + std::cout << "aten::rnn_tanh" << std::endl; + std::vector mlirtens_tensors = {input, hx}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::rnn_tanh(mlirtens[0], mlirtens[1], params, has_biases, num_layers, + dropout, train, bidirectional, batch_first); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input))); +} + +std::tuple ATenMLIRTypeDefault::rnn_tanh( + const at::Tensor &data, const at::Tensor &batch_sizes, const at::Tensor &hx, + at::TensorList params, bool has_biases, int64_t num_layers, double dropout, + bool train, bool bidirectional) { + std::cout << "aten::rnn_tanh" << std::endl; + std::vector mlirtens_tensors = {data, batch_sizes, hx}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::rnn_tanh(mlirtens[0], mlirtens[1], mlirtens[2], params, has_biases, + num_layers, dropout, train, bidirectional); + static_cast(x_result); 
// Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(data)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(data))); +} + +std::tuple +ATenMLIRTypeDefault::rnn_relu(const at::Tensor &input, const at::Tensor &hx, + at::TensorList params, bool has_biases, + int64_t num_layers, double dropout, bool train, + bool bidirectional, bool batch_first) { + std::cout << "aten::rnn_relu" << std::endl; + std::vector mlirtens_tensors = {input, hx}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::rnn_relu(mlirtens[0], mlirtens[1], params, has_biases, num_layers, + dropout, train, bidirectional, batch_first); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input))); +} + +std::tuple ATenMLIRTypeDefault::rnn_relu( + const at::Tensor &data, const at::Tensor &batch_sizes, const at::Tensor &hx, + at::TensorList params, bool has_biases, int64_t num_layers, double dropout, + bool train, bool bidirectional) { + std::cout << "aten::rnn_relu" << std::endl; + std::vector mlirtens_tensors = {data, batch_sizes, hx}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::rnn_relu(mlirtens[0], mlirtens[1], mlirtens[2], params, has_biases, + num_layers, dropout, train, bidirectional); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(data)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(data))); +} + +std::tuple +ATenMLIRTypeDefault::lstm_cell(const at::Tensor &input, at::TensorList hx, + const at::Tensor &w_ih, const at::Tensor &w_hh, + const at::Tensor &b_ih, const at::Tensor &b_hh) { + std::cout << "aten::lstm_cell" << std::endl; + std::vector mlirtens_tensors = {input, w_ih, w_hh, b_ih, b_hh}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::lstm_cell(mlirtens[0], hx, mlirtens[1], mlirtens[2], + mlirtens[3], mlirtens[4]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input))); +} + +at::Tensor +ATenMLIRTypeDefault::gru_cell(const at::Tensor &input, const at::Tensor &hx, + const at::Tensor &w_ih, const at::Tensor &w_hh, + const at::Tensor &b_ih, const at::Tensor &b_hh) { + std::cout << "aten::gru_cell" << std::endl; + std::vector mlirtens_tensors = {input, hx, w_ih, + w_hh, b_ih, b_hh}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::gru_cell(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], mlirtens[4], mlirtens[5]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor ATenMLIRTypeDefault::rnn_tanh_cell( + const at::Tensor &input, const at::Tensor &hx, const at::Tensor &w_ih, + const at::Tensor &w_hh, const at::Tensor &b_ih, const at::Tensor &b_hh) { + std::cout << "aten::rnn_tanh_cell" << std::endl; + std::vector mlirtens_tensors = {input, hx, w_ih, + w_hh, b_ih, b_hh}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto 
&&x_result = at::rnn_tanh_cell(mlirtens[0], mlirtens[1], mlirtens[2],
+                                      mlirtens[3], mlirtens[4], mlirtens[5]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input));
+}
+
+at::Tensor ATenMLIRTypeDefault::rnn_relu_cell(
+    const at::Tensor &input, const at::Tensor &hx, const at::Tensor &w_ih,
+    const at::Tensor &w_hh, const at::Tensor &b_ih, const at::Tensor &b_hh) {
+  std::cout << "aten::rnn_relu_cell" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input, hx, w_ih,
+                                              w_hh, b_ih, b_hh};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::rnn_relu_cell(mlirtens[0], mlirtens[1], mlirtens[2],
+                                      mlirtens[3], mlirtens[4], mlirtens[5]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input));
+}
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor>
+ATenMLIRTypeDefault::quantized_lstm(const at::Tensor &input, at::TensorList hx,
+                                    at::TensorList params, bool has_biases,
+                                    int64_t num_layers, double dropout,
+                                    bool train, bool bidirectional,
+                                    bool batch_first,
+                                    c10::optional<at::ScalarType> dtype,
+                                    bool use_dynamic) {
+  std::cout << "aten::quantized_lstm" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::quantized_lstm(
+      mlirtens[0], hx, params, has_biases, num_layers, dropout, train,
+      bidirectional, batch_first, dtype, use_dynamic);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return std::tuple<at::Tensor, at::Tensor, at::Tensor>(
+      bridge::CreateMLIRTensor(std::get<0>(x_result),
+                               bridge::GetMLIRDevice(input)),
+      bridge::CreateMLIRTensor(std::get<1>(x_result),
+                               bridge::GetMLIRDevice(input)),
+      bridge::CreateMLIRTensor(std::get<2>(x_result),
+                               bridge::GetMLIRDevice(input)));
+}
+
+std::tuple<at::Tensor, at::Tensor> ATenMLIRTypeDefault::quantized_gru(
+    const at::Tensor &input, const at::Tensor &hx, at::TensorList params,
+    bool has_biases, int64_t num_layers, double dropout, bool train,
+    bool bidirectional, bool batch_first) {
+  std::cout << "aten::quantized_gru" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {input, hx};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result =
+      at::quantized_gru(mlirtens[0], mlirtens[1], params, has_biases,
+                        num_layers, dropout, train, bidirectional, batch_first);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return std::tuple<at::Tensor, at::Tensor>(
+      bridge::CreateMLIRTensor(std::get<0>(x_result),
+                               bridge::GetMLIRDevice(input)),
+      bridge::CreateMLIRTensor(std::get<1>(x_result),
+                               bridge::GetMLIRDevice(input)));
+}
+
+std::tuple<at::Tensor, at::Tensor> ATenMLIRTypeDefault::quantized_gru(
+    const at::Tensor &data, const at::Tensor &batch_sizes, const at::Tensor &hx,
+    at::TensorList params, bool has_biases, int64_t num_layers, double dropout,
+    bool train, bool bidirectional) {
+  std::cout << "aten::quantized_gru" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {data, batch_sizes, hx};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result =
+      at::quantized_gru(mlirtens[0], mlirtens[1], mlirtens[2], params,
+                        has_biases, num_layers, dropout, train, bidirectional);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return std::tuple<at::Tensor, at::Tensor>(
+      bridge::CreateMLIRTensor(std::get<0>(x_result),
+                               bridge::GetMLIRDevice(data)),
+      bridge::CreateMLIRTensor(std::get<1>(x_result),
+                               bridge::GetMLIRDevice(data)));
+}
+
+std::tuple<at::Tensor, at::Tensor> ATenMLIRTypeDefault::quantized_lstm_cell(
+    const at::Tensor &input, at::TensorList hx, const at::Tensor &w_ih,
+
const at::Tensor &w_hh, const at::Tensor &b_ih, const at::Tensor &b_hh, + const at::Tensor &packed_ih, const at::Tensor &packed_hh, + const at::Tensor &col_offsets_ih, const at::Tensor &col_offsets_hh, + at::Scalar scale_ih, at::Scalar scale_hh, at::Scalar zero_point_ih, + at::Scalar zero_point_hh) { + std::cout << "aten::quantized_lstm_cell" << std::endl; + std::vector mlirtens_tensors = { + input, w_ih, w_hh, b_ih, b_hh, + packed_ih, packed_hh, col_offsets_ih, col_offsets_hh}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::quantized_lstm_cell( + mlirtens[0], hx, mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + mlirtens[5], mlirtens[6], mlirtens[7], mlirtens[8], scale_ih, scale_hh, + zero_point_ih, zero_point_hh); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input))); +} + +at::Tensor ATenMLIRTypeDefault::quantized_gru_cell( + const at::Tensor &input, const at::Tensor &hx, const at::Tensor &w_ih, + const at::Tensor &w_hh, const at::Tensor &b_ih, const at::Tensor &b_hh, + const at::Tensor &packed_ih, const at::Tensor &packed_hh, + const at::Tensor &col_offsets_ih, const at::Tensor &col_offsets_hh, + at::Scalar scale_ih, at::Scalar scale_hh, at::Scalar zero_point_ih, + at::Scalar zero_point_hh) { + std::cout << "aten::quantized_gru_cell" << std::endl; + std::vector mlirtens_tensors = { + input, hx, w_ih, w_hh, b_ih, + b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::quantized_gru_cell( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + mlirtens[5], mlirtens[6], mlirtens[7], mlirtens[8], mlirtens[9], scale_ih, + scale_hh, zero_point_ih, zero_point_hh); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor ATenMLIRTypeDefault::quantized_rnn_relu_cell( + const at::Tensor &input, const at::Tensor &hx, const at::Tensor &w_ih, + const at::Tensor &w_hh, const at::Tensor &b_ih, const at::Tensor &b_hh, + const at::Tensor &packed_ih, const at::Tensor &packed_hh, + const at::Tensor &col_offsets_ih, const at::Tensor &col_offsets_hh, + at::Scalar scale_ih, at::Scalar scale_hh, at::Scalar zero_point_ih, + at::Scalar zero_point_hh) { + std::cout << "aten::quantized_rnn_relu_cell" << std::endl; + std::vector mlirtens_tensors = { + input, hx, w_ih, w_hh, b_ih, + b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::quantized_rnn_relu_cell( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + mlirtens[5], mlirtens[6], mlirtens[7], mlirtens[8], mlirtens[9], scale_ih, + scale_hh, zero_point_ih, zero_point_hh); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +at::Tensor ATenMLIRTypeDefault::quantized_rnn_tanh_cell( + const at::Tensor &input, const at::Tensor &hx, const at::Tensor &w_ih, + const at::Tensor &w_hh, const at::Tensor &b_ih, const at::Tensor &b_hh, + const at::Tensor &packed_ih, const at::Tensor &packed_hh, + const at::Tensor &col_offsets_ih, const at::Tensor &col_offsets_hh, + at::Scalar scale_ih, at::Scalar scale_hh, at::Scalar zero_point_ih, + 
at::Scalar zero_point_hh) { + std::cout << "aten::quantized_rnn_tanh_cell" << std::endl; + std::vector mlirtens_tensors = { + input, hx, w_ih, w_hh, b_ih, + b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::quantized_rnn_tanh_cell( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + mlirtens[5], mlirtens[6], mlirtens[7], mlirtens[8], mlirtens[9], scale_ih, + scale_hh, zero_point_ih, zero_point_hh); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(input)); +} + +std::tuple ATenMLIRTypeDefault::_pack_padded_sequence( + const at::Tensor &input, const at::Tensor &lengths, bool batch_first) { + std::cout << "aten::_pack_padded_sequence" << std::endl; + std::vector mlirtens_tensors = {input, lengths}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_pack_padded_sequence(mlirtens[0], mlirtens[1], batch_first); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(input)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(input))); +} + +at::Tensor ATenMLIRTypeDefault::_pack_padded_sequence_backward( + const at::Tensor &grad, at::IntArrayRef input_size, + const at::Tensor &batch_sizes, bool batch_first) { + std::cout << "aten::_pack_padded_sequence_backward" << std::endl; + std::vector mlirtens_tensors = {grad, batch_sizes}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_pack_padded_sequence_backward( + mlirtens[0], input_size, mlirtens[1], batch_first); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad)); +} + +std::tuple ATenMLIRTypeDefault::_pad_packed_sequence( + const at::Tensor &data, const at::Tensor &batch_sizes, bool batch_first, + at::Scalar padding_value, int64_t total_length) { + std::cout << "aten::_pad_packed_sequence" << std::endl; + std::vector mlirtens_tensors = {data, batch_sizes}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_pad_packed_sequence( + mlirtens[0], mlirtens[1], batch_first, padding_value, total_length); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(data)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(data))); +} + +at::Tensor &ATenMLIRTypeDefault::set_(at::Tensor &self, at::Storage source) { + std::cout << "aten::set_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].set_(source); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::set_(at::Tensor &self, at::Storage source, + int64_t storage_offset, + at::IntArrayRef size, + at::IntArrayRef stride) { + std::cout << "aten::set_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].set_(source, storage_offset, size, stride); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::set_(at::Tensor &self, + const at::Tensor &source) { + 
std::cout << "aten::set_" << std::endl; + std::vector mlirtens_tensors = {self, source}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].set_(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::set_(at::Tensor &self) { + std::cout << "aten::set_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].set_(); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor & +ATenMLIRTypeDefault::set_quantizer_(at::Tensor &self, + at::ConstQuantizerPtr quantizer) { + std::cout << "aten::set_quantizer_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].set_quantizer_(quantizer); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +bool ATenMLIRTypeDefault::is_set_to(const at::Tensor &self, + const at::Tensor &tensor) { + std::cout << "aten::is_set_to" << std::endl; + std::vector mlirtens_tensors = {self, tensor}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].is_set_to(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +at::Tensor &ATenMLIRTypeDefault::masked_fill_(at::Tensor &self, + const at::Tensor &mask, + at::Scalar value) { + std::cout << "aten::masked_fill_" << std::endl; + std::vector mlirtens_tensors = {self, mask}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].masked_fill_(mlirtens[1], value); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::masked_fill(const at::Tensor &self, + const at::Tensor &mask, + at::Scalar value) { + std::cout << "aten::masked_fill" << std::endl; + std::vector mlirtens_tensors = {self, mask}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::masked_fill(mlirtens[0], mlirtens[1], value); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::masked_fill_(at::Tensor &self, + const at::Tensor &mask, + const at::Tensor &value) { + std::cout << "aten::masked_fill_" << std::endl; + std::vector mlirtens_tensors = {self, mask, value}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].masked_fill_(mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::masked_fill(const at::Tensor &self, + const at::Tensor &mask, + const at::Tensor &value) { + std::cout << "aten::masked_fill" << std::endl; + std::vector mlirtens_tensors = {self, mask, value}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::masked_fill(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::masked_scatter_(at::Tensor &self, + const at::Tensor &mask, + const at::Tensor &source) { + std::cout << "aten::masked_scatter_" << std::endl; + std::vector mlirtens_tensors = {self, mask, source}; + auto mlirtens = 
bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].masked_scatter_(mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::masked_scatter(const at::Tensor &self, + const at::Tensor &mask, + const at::Tensor &source) { + std::cout << "aten::masked_scatter" << std::endl; + std::vector mlirtens_tensors = {self, mask, source}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::masked_scatter(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::view(const at::Tensor &self, + at::IntArrayRef size) { + std::cout << "aten::view" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].view(size); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::put_(at::Tensor &self, const at::Tensor &index, + const at::Tensor &source, + bool accumulate) { + std::cout << "aten::put_" << std::endl; + std::vector mlirtens_tensors = {self, index, source}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].put_(mlirtens[1], mlirtens[2], accumulate); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::index_add_(at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &source) { + std::cout << "aten::index_add_" << std::endl; + std::vector mlirtens_tensors = {self, index, source}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].index_add_(dim, mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::index_add(const at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &source) { + std::cout << "aten::index_add" << std::endl; + std::vector mlirtens_tensors = {self, index, source}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::index_add(mlirtens[0], dim, mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::index_fill_(at::Tensor &self, int64_t dim, + const at::Tensor &index, + at::Scalar value) { + std::cout << "aten::index_fill_" << std::endl; + std::vector mlirtens_tensors = {self, index}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].index_fill_(dim, mlirtens[1], value); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::index_fill(const at::Tensor &self, int64_t dim, + const at::Tensor &index, + at::Scalar value) { + std::cout << "aten::index_fill" << std::endl; + std::vector mlirtens_tensors = {self, index}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::index_fill(mlirtens[0], dim, mlirtens[1], value); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor 
&ATenMLIRTypeDefault::index_fill_(at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &value) { + std::cout << "aten::index_fill_" << std::endl; + std::vector mlirtens_tensors = {self, index, value}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].index_fill_(dim, mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::index_fill(const at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &value) { + std::cout << "aten::index_fill" << std::endl; + std::vector mlirtens_tensors = {self, index, value}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::index_fill(mlirtens[0], dim, mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::scatter_(at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &src) { + std::cout << "aten::scatter_" << std::endl; + std::vector mlirtens_tensors = {self, index, src}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].scatter_(dim, mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::scatter(const at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &src) { + std::cout << "aten::scatter" << std::endl; + std::vector mlirtens_tensors = {self, index, src}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::scatter(mlirtens[0], dim, mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::scatter_(at::Tensor &self, int64_t dim, + const at::Tensor &index, + at::Scalar value) { + std::cout << "aten::scatter_" << std::endl; + std::vector mlirtens_tensors = {self, index}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].scatter_(dim, mlirtens[1], value); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::scatter(const at::Tensor &self, int64_t dim, + const at::Tensor &index, + at::Scalar value) { + std::cout << "aten::scatter" << std::endl; + std::vector mlirtens_tensors = {self, index}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::scatter(mlirtens[0], dim, mlirtens[1], value); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::scatter_add_(at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &src) { + std::cout << "aten::scatter_add_" << std::endl; + std::vector mlirtens_tensors = {self, index, src}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].scatter_add_(dim, mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::scatter_add(const at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &src) { + std::cout << "aten::scatter_add" << std::endl; + std::vector mlirtens_tensors = {self, index, src}; + 
auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::scatter_add(mlirtens[0], dim, mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::lt_(at::Tensor &self, at::Scalar other) { + std::cout << "aten::lt_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].lt_(other); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::lt_(at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::lt_" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].lt_(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::gt_(at::Tensor &self, at::Scalar other) { + std::cout << "aten::gt_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].gt_(other); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::gt_(at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::gt_" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].gt_(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::le_(at::Tensor &self, at::Scalar other) { + std::cout << "aten::le_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].le_(other); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::le_(at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::le_" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].le_(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::ge_(at::Tensor &self, at::Scalar other) { + std::cout << "aten::ge_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].ge_(other); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::ge_(at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::ge_" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].ge_(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::eq_(at::Tensor &self, at::Scalar other) { + std::cout << "aten::eq_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].eq_(other); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::eq_(at::Tensor 
&self, + const at::Tensor &other) { + std::cout << "aten::eq_" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].eq_(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::ne_(at::Tensor &self, at::Scalar other) { + std::cout << "aten::ne_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].ne_(other); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::ne_(at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::ne_" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].ne_(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::__and__(const at::Tensor &self, + at::Scalar other) { + std::cout << "aten::__and__" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::__and__(mlirtens[0], other); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::__and__(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::__and__" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::__and__(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::__iand__(at::Tensor &self, at::Scalar other) { + std::cout << "aten::__iand__" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].__iand__(other); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::__iand__(at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::__iand__" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].__iand__(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::__or__(const at::Tensor &self, + at::Scalar other) { + std::cout << "aten::__or__" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::__or__(mlirtens[0], other); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::__or__(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::__or__" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::__or__(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + 
+at::Tensor &ATenMLIRTypeDefault::__ior__(at::Tensor &self, at::Scalar other) { + std::cout << "aten::__ior__" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].__ior__(other); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::__ior__(at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::__ior__" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].__ior__(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::__xor__(const at::Tensor &self, + at::Scalar other) { + std::cout << "aten::__xor__" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::__xor__(mlirtens[0], other); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::__xor__(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::__xor__" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::__xor__(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::__ixor__(at::Tensor &self, at::Scalar other) { + std::cout << "aten::__ixor__" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].__ixor__(other); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::__ixor__(at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::__ixor__" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].__ixor__(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::__lshift__(const at::Tensor &self, + at::Scalar other) { + std::cout << "aten::__lshift__" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::__lshift__(mlirtens[0], other); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::__lshift__(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::__lshift__" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::__lshift__(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::__ilshift__(at::Tensor &self, + at::Scalar other) { + std::cout << "aten::__ilshift__" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = 
mlirtens[0].__ilshift__(other); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::__ilshift__(at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::__ilshift__" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].__ilshift__(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::__rshift__(const at::Tensor &self, + at::Scalar other) { + std::cout << "aten::__rshift__" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::__rshift__(mlirtens[0], other); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::__rshift__(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::__rshift__" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::__rshift__(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::__irshift__(at::Tensor &self, + at::Scalar other) { + std::cout << "aten::__irshift__" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].__irshift__(other); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::__irshift__(at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::__irshift__" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].__irshift__(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::lgamma_(at::Tensor &self) { + std::cout << "aten::lgamma_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].lgamma_(); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::atan2_(at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::atan2_" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].atan2_(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::tril_(at::Tensor &self, int64_t diagonal) { + std::cout << "aten::tril_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].tril_(diagonal); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::triu_(at::Tensor &self, int64_t diagonal) { + std::cout << "aten::triu_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].triu_(diagonal); + static_cast(x_result); // 
Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::digamma_(at::Tensor &self) { + std::cout << "aten::digamma_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].digamma_(); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::polygamma_(at::Tensor &self, int64_t n) { + std::cout << "aten::polygamma_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].polygamma_(n); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::renorm_(at::Tensor &self, at::Scalar p, + int64_t dim, at::Scalar maxnorm) { + std::cout << "aten::renorm_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].renorm_(p, dim, maxnorm); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::pow_(at::Tensor &self, at::Scalar exponent) { + std::cout << "aten::pow_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].pow_(exponent); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::pow_(at::Tensor &self, + const at::Tensor &exponent) { + std::cout << "aten::pow_" << std::endl; + std::vector mlirtens_tensors = {self, exponent}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].pow_(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::lerp_(at::Tensor &self, const at::Tensor &end, + at::Scalar weight) { + std::cout << "aten::lerp_" << std::endl; + std::vector mlirtens_tensors = {self, end}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].lerp_(mlirtens[1], weight); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::lerp_(at::Tensor &self, const at::Tensor &end, + const at::Tensor &weight) { + std::cout << "aten::lerp_" << std::endl; + std::vector mlirtens_tensors = {self, end, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].lerp_(mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::fmod_(at::Tensor &self, at::Scalar other) { + std::cout << "aten::fmod_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].fmod_(other); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::fmod_(at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::fmod_" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].fmod_(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::remainder_(at::Tensor &self, + at::Scalar other) { + std::cout << "aten::remainder_" 
<< std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].remainder_(other); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::remainder_(at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::remainder_" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].remainder_(mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::addbmm_(at::Tensor &self, + const at::Tensor &batch1, + const at::Tensor &batch2, + at::Scalar beta, at::Scalar alpha) { + std::cout << "aten::addbmm_" << std::endl; + std::vector mlirtens_tensors = {self, batch1, batch2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].addbmm_(mlirtens[1], mlirtens[2], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::addbmm_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &batch1, + const at::Tensor &batch2, + at::Scalar beta, at::Scalar alpha) { + std::cout << "aten::addbmm_out" << std::endl; + std::vector mlirtens_tensors = {out, self, batch1, batch2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::addbmm_out(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::addbmm(const at::Tensor &self, + const at::Tensor &batch1, + const at::Tensor &batch2, + at::Scalar beta, at::Scalar alpha) { + std::cout << "aten::addbmm" << std::endl; + std::vector mlirtens_tensors = {self, batch1, batch2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::addbmm(mlirtens[0], mlirtens[1], mlirtens[2], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::addcdiv_(at::Tensor &self, + const at::Tensor &tensor1, + const at::Tensor &tensor2, + at::Scalar value) { + std::cout << "aten::addcdiv_" << std::endl; + std::vector mlirtens_tensors = {self, tensor1, tensor2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].addcdiv_(mlirtens[1], mlirtens[2], value); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::random_(at::Tensor &self, int64_t from, + int64_t to, at::Generator *generator) { + std::cout << "aten::random_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].random_(from, to, generator); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::random_(at::Tensor &self, int64_t to, + at::Generator *generator) { + std::cout << "aten::random_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].random_(to, generator); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::random_(at::Tensor &self, + 
at::Generator *generator) { + std::cout << "aten::random_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].random_(generator); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::uniform_(at::Tensor &self, double from, + double to, at::Generator *generator) { + std::cout << "aten::uniform_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].uniform_(from, to, generator); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::normal_(at::Tensor &self, double mean, + double std, at::Generator *generator) { + std::cout << "aten::normal_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].normal_(mean, std, generator); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::cauchy_(at::Tensor &self, double median, + double sigma, + at::Generator *generator) { + std::cout << "aten::cauchy_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].cauchy_(median, sigma, generator); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::log_normal_(at::Tensor &self, double mean, + double std, + at::Generator *generator) { + std::cout << "aten::log_normal_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].log_normal_(mean, std, generator); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::exponential_(at::Tensor &self, double lambd, + at::Generator *generator) { + std::cout << "aten::exponential_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].exponential_(lambd, generator); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::geometric_(at::Tensor &self, double p, + at::Generator *generator) { + std::cout << "aten::geometric_" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].geometric_(p, generator); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::diag_out(at::Tensor &out, + const at::Tensor &self, + int64_t diagonal) { + std::cout << "aten::diag_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::diag_out(mlirtens[0], mlirtens[1], diagonal); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::diag(const at::Tensor &self, int64_t diagonal) { + std::cout << "aten::diag" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::diag(mlirtens[0], diagonal); + static_cast(x_result); // Avoid warnings in case not used + return 
bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::cross_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &other, + c10::optional dim) { + std::cout << "aten::cross_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::cross_out(mlirtens[0], mlirtens[1], mlirtens[2], dim); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::cross(const at::Tensor &self, + const at::Tensor &other, + c10::optional dim) { + std::cout << "aten::cross" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::cross(mlirtens[0], mlirtens[1], dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::triu_out(at::Tensor &out, + const at::Tensor &self, + int64_t diagonal) { + std::cout << "aten::triu_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::triu_out(mlirtens[0], mlirtens[1], diagonal); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::triu(const at::Tensor &self, int64_t diagonal) { + std::cout << "aten::triu" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::triu(mlirtens[0], diagonal); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::tril_out(at::Tensor &out, + const at::Tensor &self, + int64_t diagonal) { + std::cout << "aten::tril_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::tril_out(mlirtens[0], mlirtens[1], diagonal); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::tril(const at::Tensor &self, int64_t diagonal) { + std::cout << "aten::tril" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::tril(mlirtens[0], diagonal); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::tril_indices(int64_t row, int64_t col, + int64_t offset, + const at::TensorOptions &options) { + std::cout << "aten::tril_indices" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::tril_indices(row, col, offset, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::triu_indices(int64_t row, int64_t col, + int64_t offset, + const at::TensorOptions &options) { + std::cout << "aten::triu_indices" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::triu_indices(row, col, offset, options); + static_cast(x_result); // Avoid warnings 
in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor ATenMLIRTypeDefault::trace(const at::Tensor &self) { + std::cout << "aten::trace" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::trace(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::ne_out(at::Tensor &out, const at::Tensor &self, + at::Scalar other) { + std::cout << "aten::ne_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::ne_out(mlirtens[0], mlirtens[1], other); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::ne(const at::Tensor &self, at::Scalar other) { + std::cout << "aten::ne" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::ne(mlirtens[0], other); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::ne_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::ne_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::ne_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::ne(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::ne" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::ne(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::eq_out(at::Tensor &out, const at::Tensor &self, + at::Scalar other) { + std::cout << "aten::eq_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::eq_out(mlirtens[0], mlirtens[1], other); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::eq(const at::Tensor &self, at::Scalar other) { + std::cout << "aten::eq" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::eq(mlirtens[0], other); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::eq_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::eq_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::eq_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::eq(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::eq" << std::endl; + 
std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::eq(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::ge_out(at::Tensor &out, const at::Tensor &self, + at::Scalar other) { + std::cout << "aten::ge_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::ge_out(mlirtens[0], mlirtens[1], other); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::ge(const at::Tensor &self, at::Scalar other) { + std::cout << "aten::ge" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::ge(mlirtens[0], other); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::ge_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::ge_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::ge_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::ge(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::ge" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::ge(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::le_out(at::Tensor &out, const at::Tensor &self, + at::Scalar other) { + std::cout << "aten::le_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::le_out(mlirtens[0], mlirtens[1], other); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::le(const at::Tensor &self, at::Scalar other) { + std::cout << "aten::le" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::le(mlirtens[0], other); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::le_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::le_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::le_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::le(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::le" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::le(mlirtens[0], mlirtens[1]); + static_cast(x_result); // 
Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::gt_out(at::Tensor &out, const at::Tensor &self, + at::Scalar other) { + std::cout << "aten::gt_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::gt_out(mlirtens[0], mlirtens[1], other); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::gt(const at::Tensor &self, at::Scalar other) { + std::cout << "aten::gt" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::gt(mlirtens[0], other); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::gt_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::gt_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::gt_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::gt(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::gt" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::gt(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::lt_out(at::Tensor &out, const at::Tensor &self, + at::Scalar other) { + std::cout << "aten::lt_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::lt_out(mlirtens[0], mlirtens[1], other); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::lt(const at::Tensor &self, at::Scalar other) { + std::cout << "aten::lt" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::lt(mlirtens[0], other); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::lt_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::lt_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::lt_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::lt(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::lt" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::lt(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::take_out(at::Tensor &out, + const at::Tensor &self, + 
const at::Tensor &index) { + std::cout << "aten::take_out" << std::endl; + std::vector mlirtens_tensors = {out, self, index}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::take_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::take(const at::Tensor &self, + const at::Tensor &index) { + std::cout << "aten::take" << std::endl; + std::vector mlirtens_tensors = {self, index}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::take(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::index_select_out(at::Tensor &out, + const at::Tensor &self, + int64_t dim, + const at::Tensor &index) { + std::cout << "aten::index_select_out" << std::endl; + std::vector mlirtens_tensors = {out, self, index}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::index_select_out(mlirtens[0], mlirtens[1], dim, mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::index_select(const at::Tensor &self, + int64_t dim, + const at::Tensor &index) { + std::cout << "aten::index_select" << std::endl; + std::vector mlirtens_tensors = {self, index}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::index_select(mlirtens[0], dim, mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::masked_select_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &mask) { + std::cout << "aten::masked_select_out" << std::endl; + std::vector mlirtens_tensors = {out, self, mask}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::masked_select_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::masked_select(const at::Tensor &self, + const at::Tensor &mask) { + std::cout << "aten::masked_select" << std::endl; + std::vector mlirtens_tensors = {self, mask}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::masked_select(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::nonzero_out(at::Tensor &out, + const at::Tensor &self) { + std::cout << "aten::nonzero_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::nonzero_out(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::nonzero(const at::Tensor &self) { + std::cout << "aten::nonzero" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::nonzero(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::vector +ATenMLIRTypeDefault::nonzero_numpy(const at::Tensor &self) { + 
std::cout << "aten::nonzero_numpy" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::nonzero_numpy(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +at::Tensor &ATenMLIRTypeDefault::gather_out(at::Tensor &out, + const at::Tensor &self, int64_t dim, + const at::Tensor &index, + bool sparse_grad) { + std::cout << "aten::gather_out" << std::endl; + std::vector mlirtens_tensors = {out, self, index}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::gather_out(mlirtens[0], mlirtens[1], dim, mlirtens[2], sparse_grad); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::gather(const at::Tensor &self, int64_t dim, + const at::Tensor &index, + bool sparse_grad) { + std::cout << "aten::gather" << std::endl; + std::vector mlirtens_tensors = {self, index}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::gather(mlirtens[0], dim, mlirtens[1], sparse_grad); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_gather_sparse_backward( + const at::Tensor &self, int64_t dim, const at::Tensor &index, + const at::Tensor &grad) { + std::cout << "aten::_gather_sparse_backward" << std::endl; + std::vector mlirtens_tensors = {self, index, grad}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_gather_sparse_backward(mlirtens[0], dim, mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::addcmul_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &tensor1, + const at::Tensor &tensor2, + at::Scalar value) { + std::cout << "aten::addcmul_out" << std::endl; + std::vector mlirtens_tensors = {out, self, tensor1, tensor2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::addcmul_out(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], value); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::addcmul(const at::Tensor &self, + const at::Tensor &tensor1, + const at::Tensor &tensor2, + at::Scalar value) { + std::cout << "aten::addcmul" << std::endl; + std::vector mlirtens_tensors = {self, tensor1, tensor2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::addcmul(mlirtens[0], mlirtens[1], mlirtens[2], value); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::addcmul_(at::Tensor &self, + const at::Tensor &tensor1, + const at::Tensor &tensor2, + at::Scalar value) { + std::cout << "aten::addcmul_" << std::endl; + std::vector mlirtens_tensors = {self, tensor1, tensor2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].addcmul_(mlirtens[1], mlirtens[2], value); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::addcdiv_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &tensor1, + const at::Tensor &tensor2, + at::Scalar value) { + std::cout 
<< "aten::addcdiv_out" << std::endl; + std::vector mlirtens_tensors = {out, self, tensor1, tensor2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::addcdiv_out(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], value); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::addcdiv(const at::Tensor &self, + const at::Tensor &tensor1, + const at::Tensor &tensor2, + at::Scalar value) { + std::cout << "aten::addcdiv" << std::endl; + std::vector mlirtens_tensors = {self, tensor1, tensor2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::addcdiv(mlirtens[0], mlirtens[1], mlirtens[2], value); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::lstsq_out(at::Tensor &X, at::Tensor &qr, + const at::Tensor &self, const at::Tensor &A) { + std::cout << "aten::lstsq_out" << std::endl; + std::vector mlirtens_tensors = {X, qr, self, A}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::lstsq_out(mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(X, qr); +} + +std::tuple +ATenMLIRTypeDefault::lstsq(const at::Tensor &self, const at::Tensor &A) { + std::cout << "aten::lstsq" << std::endl; + std::vector mlirtens_tensors = {self, A}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::lstsq(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::triangular_solve_out(at::Tensor &X, at::Tensor &M, + const at::Tensor &self, + const at::Tensor &A, bool upper, + bool transpose, bool unitriangular) { + std::cout << "aten::triangular_solve_out" << std::endl; + std::vector mlirtens_tensors = {X, M, self, A}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::triangular_solve_out(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], upper, transpose, unitriangular); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(X, M); +} + +std::tuple +ATenMLIRTypeDefault::triangular_solve(const at::Tensor &self, + const at::Tensor &A, bool upper, + bool transpose, bool unitriangular) { + std::cout << "aten::triangular_solve" << std::endl; + std::vector mlirtens_tensors = {self, A}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::triangular_solve(mlirtens[0], mlirtens[1], upper, + transpose, unitriangular); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::_triangular_solve_helper(const at::Tensor &self, + const at::Tensor &A, bool upper, + bool transpose, + bool unitriangular) { + std::cout << "aten::_triangular_solve_helper" << std::endl; + std::vector mlirtens_tensors = {self, A}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_triangular_solve_helper( + mlirtens[0], 
mlirtens[1], upper, transpose, unitriangular); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::symeig_out(at::Tensor &e, at::Tensor &V, + const at::Tensor &self, bool eigenvectors, + bool upper) { + std::cout << "aten::symeig_out" << std::endl; + std::vector mlirtens_tensors = {e, V, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::symeig_out(mlirtens[0], mlirtens[1], mlirtens[2], + eigenvectors, upper); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(e, V); +} + +std::tuple +ATenMLIRTypeDefault::symeig(const at::Tensor &self, bool eigenvectors, + bool upper) { + std::cout << "aten::symeig" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::symeig(mlirtens[0], eigenvectors, upper); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::_symeig_helper(const at::Tensor &self, bool eigenvectors, + bool upper) { + std::cout << "aten::_symeig_helper" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_symeig_helper(mlirtens[0], eigenvectors, upper); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::eig_out(at::Tensor &e, at::Tensor &v, + const at::Tensor &self, bool eigenvectors) { + std::cout << "aten::eig_out" << std::endl; + std::vector mlirtens_tensors = {e, v, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::eig_out(mlirtens[0], mlirtens[1], mlirtens[2], eigenvectors); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(e, v); +} + +std::tuple +ATenMLIRTypeDefault::eig(const at::Tensor &self, bool eigenvectors) { + std::cout << "aten::eig" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::eig(mlirtens[0], eigenvectors); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::svd_out(at::Tensor &U, at::Tensor &S, at::Tensor &V, + const at::Tensor &self, bool some, + bool compute_uv) { + std::cout << "aten::svd_out" << std::endl; + std::vector mlirtens_tensors = {U, S, V, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::svd_out(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], some, compute_uv); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(U, S, V); +} + +std::tuple +ATenMLIRTypeDefault::svd(const at::Tensor &self, bool some, bool compute_uv) { + std::cout << 
"aten::svd" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::svd(mlirtens[0], some, compute_uv); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::_svd_helper(const at::Tensor &self, bool some, + bool compute_uv) { + std::cout << "aten::_svd_helper" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_svd_helper(mlirtens[0], some, compute_uv); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor &ATenMLIRTypeDefault::cholesky_out(at::Tensor &out, + const at::Tensor &self, + bool upper) { + std::cout << "aten::cholesky_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::cholesky_out(mlirtens[0], mlirtens[1], upper); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::cholesky(const at::Tensor &self, bool upper) { + std::cout << "aten::cholesky" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::cholesky(mlirtens[0], upper); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_cholesky_helper(const at::Tensor &self, + bool upper) { + std::cout << "aten::_cholesky_helper" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cholesky_helper(mlirtens[0], upper); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::cholesky_solve_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &input2, + bool upper) { + std::cout << "aten::cholesky_solve_out" << std::endl; + std::vector mlirtens_tensors = {out, self, input2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::cholesky_solve_out(mlirtens[0], mlirtens[1], mlirtens[2], upper); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::cholesky_solve(const at::Tensor &self, + const at::Tensor &input2, + bool upper) { + std::cout << "aten::cholesky_solve" << std::endl; + std::vector mlirtens_tensors = {self, input2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::cholesky_solve(mlirtens[0], mlirtens[1], upper); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_cholesky_solve_helper(const at::Tensor &self, + const 
at::Tensor &A, + bool upper) { + std::cout << "aten::_cholesky_solve_helper" << std::endl; + std::vector mlirtens_tensors = {self, A}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cholesky_solve_helper(mlirtens[0], mlirtens[1], upper); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::solve(const at::Tensor &self, const at::Tensor &A) { + std::cout << "aten::solve" << std::endl; + std::vector mlirtens_tensors = {self, A}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::solve(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::solve_out(at::Tensor &solution, at::Tensor &lu, + const at::Tensor &self, const at::Tensor &A) { + std::cout << "aten::solve_out" << std::endl; + std::vector mlirtens_tensors = {solution, lu, self, A}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::solve_out(mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(solution, lu); +} + +std::tuple +ATenMLIRTypeDefault::_solve_helper(const at::Tensor &self, + const at::Tensor &A) { + std::cout << "aten::_solve_helper" << std::endl; + std::vector mlirtens_tensors = {self, A}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_solve_helper(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor &ATenMLIRTypeDefault::cholesky_inverse_out(at::Tensor &out, + const at::Tensor &self, + bool upper) { + std::cout << "aten::cholesky_inverse_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::cholesky_inverse_out(mlirtens[0], mlirtens[1], upper); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::cholesky_inverse(const at::Tensor &self, + bool upper) { + std::cout << "aten::cholesky_inverse" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::cholesky_inverse(mlirtens[0], upper); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::qr_out(at::Tensor &Q, at::Tensor &R, + const at::Tensor &self, bool some) { + std::cout << "aten::qr_out" << std::endl; + std::vector mlirtens_tensors = {Q, R, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::qr_out(mlirtens[0], mlirtens[1], mlirtens[2], some); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(Q, R); +} + +std::tuple +ATenMLIRTypeDefault::qr(const at::Tensor &self, bool some) { + std::cout << "aten::qr" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = 
bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::qr(mlirtens[0], some); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::_qr_helper(const at::Tensor &self, bool some) { + std::cout << "aten::_qr_helper" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_qr_helper(mlirtens[0], some); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::geqrf_out(at::Tensor &a, at::Tensor &tau, + const at::Tensor &self) { + std::cout << "aten::geqrf_out" << std::endl; + std::vector mlirtens_tensors = {a, tau, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::geqrf_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(a, tau); +} + +std::tuple +ATenMLIRTypeDefault::geqrf(const at::Tensor &self) { + std::cout << "aten::geqrf" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::geqrf(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor &ATenMLIRTypeDefault::orgqr_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &input2) { + std::cout << "aten::orgqr_out" << std::endl; + std::vector mlirtens_tensors = {out, self, input2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::orgqr_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::orgqr(const at::Tensor &self, + const at::Tensor &input2) { + std::cout << "aten::orgqr" << std::endl; + std::vector mlirtens_tensors = {self, input2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::orgqr(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::ormqr_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &input2, + const at::Tensor &input3, bool left, + bool transpose) { + std::cout << "aten::ormqr_out" << std::endl; + std::vector mlirtens_tensors = {out, self, input2, input3}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::ormqr_out(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], left, transpose); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::ormqr(const at::Tensor &self, + const at::Tensor &input2, + const at::Tensor &input3, bool left, + bool transpose) { + std::cout << "aten::ormqr" << std::endl; + std::vector mlirtens_tensors = {self, input2, input3}; + auto mlirtens = 
bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::ormqr(mlirtens[0], mlirtens[1], mlirtens[2], left, transpose); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::_lu_with_info(const at::Tensor &self, bool pivot, + bool check_errors) { + std::cout << "aten::_lu_with_info" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_lu_with_info(mlirtens[0], pivot, check_errors); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor &ATenMLIRTypeDefault::lu_solve_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &LU_data, + const at::Tensor &LU_pivots) { + std::cout << "aten::lu_solve_out" << std::endl; + std::vector mlirtens_tensors = {out, self, LU_data, LU_pivots}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::lu_solve_out(mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::lu_solve(const at::Tensor &self, + const at::Tensor &LU_data, + const at::Tensor &LU_pivots) { + std::cout << "aten::lu_solve" << std::endl; + std::vector mlirtens_tensors = {self, LU_data, LU_pivots}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::lu_solve(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_lu_solve_helper(const at::Tensor &self, + const at::Tensor &LU_data, + const at::Tensor &LU_pivots) { + std::cout << "aten::_lu_solve_helper" << std::endl; + std::vector mlirtens_tensors = {self, LU_data, LU_pivots}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_lu_solve_helper(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::multinomial_out(at::Tensor &out, + const at::Tensor &self, + int64_t num_samples, + bool replacement, + at::Generator *generator) { + std::cout << "aten::multinomial_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::multinomial_out(mlirtens[0], mlirtens[1], num_samples, + replacement, generator); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::multinomial(const at::Tensor &self, + int64_t num_samples, + bool replacement, + at::Generator *generator) { + std::cout << "aten::multinomial" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::multinomial(mlirtens[0], num_samples, replacement, generator); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + 
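+// The tuple-returning fallbacks below follow the same convention as the
+// tensor-returning ones above: each element of the tuple produced by the
+// reference ATen call is wrapped individually with bridge::CreateMLIRTensor
+// on the device of the first tensor argument, so callers should get device
+// tensors back regardless of how many results an op has.
+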
+std::tuple<at::Tensor, at::Tensor>
+ATenMLIRTypeDefault::_multinomial_alias_setup(const at::Tensor &probs) {
+  std::cout << "aten::_multinomial_alias_setup" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {probs};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::_multinomial_alias_setup(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return std::tuple<at::Tensor, at::Tensor>(
+      bridge::CreateMLIRTensor(std::get<0>(x_result),
+                               bridge::GetMLIRDevice(probs)),
+      bridge::CreateMLIRTensor(std::get<1>(x_result),
+                               bridge::GetMLIRDevice(probs)));
+}
+
+at::Tensor ATenMLIRTypeDefault::_multinomial_alias_draw(
+    const at::Tensor &J, const at::Tensor &q, int64_t num_samples,
+    at::Generator *generator) {
+  std::cout << "aten::_multinomial_alias_draw" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {J, q};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::_multinomial_alias_draw(mlirtens[0], mlirtens[1],
+                                                num_samples, generator);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(J));
+}
+
+at::Tensor &ATenMLIRTypeDefault::lgamma_out(at::Tensor &out,
+                                            const at::Tensor &self) {
+  std::cout << "aten::lgamma_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::lgamma_out(mlirtens[0], mlirtens[1]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::lgamma(const at::Tensor &self) {
+  std::cout << "aten::lgamma" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::lgamma(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::digamma_out(at::Tensor &out,
+                                             const at::Tensor &self) {
+  std::cout << "aten::digamma_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::digamma_out(mlirtens[0], mlirtens[1]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::digamma(const at::Tensor &self) {
+  std::cout << "aten::digamma" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::digamma(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::polygamma_out(at::Tensor &out, int64_t n,
+                                               const at::Tensor &self) {
+  std::cout << "aten::polygamma_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::polygamma_out(mlirtens[0], n, mlirtens[1]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::polygamma(int64_t n, const at::Tensor &self) {
+  std::cout << "aten::polygamma" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::polygamma(n, mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor ATenMLIRTypeDefault::erfinv(const at::Tensor &self) {
+  std::cout << "aten::erfinv" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::erfinv(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::erfinv_(at::Tensor &self) {
+  std::cout << "aten::erfinv_" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = mlirtens[0].erfinv_();
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return self;
+}
+
+at::Tensor &ATenMLIRTypeDefault::erfinv_out(at::Tensor &out,
+                                            const at::Tensor &self) {
+  std::cout << "aten::erfinv_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::erfinv_out(mlirtens[0], mlirtens[1]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::sign(const at::Tensor &self) {
+  std::cout << "aten::sign" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::sign(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::sign_(at::Tensor &self) {
+  std::cout << "aten::sign_" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = mlirtens[0].sign_();
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return self;
+}
+
+at::Tensor &ATenMLIRTypeDefault::sign_out(at::Tensor &out,
+                                          const at::Tensor &self) {
+  std::cout << "aten::sign_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::sign_out(mlirtens[0], mlirtens[1]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::dist(const at::Tensor &self,
+                                     const at::Tensor &other, at::Scalar p) {
+  std::cout << "aten::dist" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self, other};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::dist(mlirtens[0], mlirtens[1], p);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::atan2_out(at::Tensor &out,
+                                           const at::Tensor &self,
+                                           const at::Tensor &other) {
+  std::cout << "aten::atan2_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self, other};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::atan2_out(mlirtens[0], mlirtens[1], mlirtens[2]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::atan2(const at::Tensor &self,
+                                      const at::Tensor &other) {
+  std::cout << "aten::atan2" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self, other};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::atan2(mlirtens[0], mlirtens[1]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result,
bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::lerp_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &end, + at::Scalar weight) { + std::cout << "aten::lerp_out" << std::endl; + std::vector mlirtens_tensors = {out, self, end}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::lerp_out(mlirtens[0], mlirtens[1], mlirtens[2], weight); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor &ATenMLIRTypeDefault::lerp_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &end, + const at::Tensor &weight) { + std::cout << "aten::lerp_out" << std::endl; + std::vector mlirtens_tensors = {out, self, end, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::lerp_out(mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::lerp(const at::Tensor &self, + const at::Tensor &end, at::Scalar weight) { + std::cout << "aten::lerp" << std::endl; + std::vector mlirtens_tensors = {self, end}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::lerp(mlirtens[0], mlirtens[1], weight); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::lerp(const at::Tensor &self, + const at::Tensor &end, + const at::Tensor &weight) { + std::cout << "aten::lerp" << std::endl; + std::vector mlirtens_tensors = {self, end, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::lerp(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::histc_out(at::Tensor &out, + const at::Tensor &self, int64_t bins, + at::Scalar min, at::Scalar max) { + std::cout << "aten::histc_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::histc_out(mlirtens[0], mlirtens[1], bins, min, max); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::histc(const at::Tensor &self, int64_t bins, + at::Scalar min, at::Scalar max) { + std::cout << "aten::histc" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::histc(mlirtens[0], bins, min, max); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::fmod_out(at::Tensor &out, + const at::Tensor &self, + at::Scalar other) { + std::cout << "aten::fmod_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::fmod_out(mlirtens[0], mlirtens[1], other); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::fmod(const at::Tensor &self, at::Scalar other) { + std::cout << "aten::fmod" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::fmod(mlirtens[0], other); + static_cast(x_result); 
// Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::fmod_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::fmod_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::fmod_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::fmod(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::fmod" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::fmod(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::remainder_out(at::Tensor &out, + const at::Tensor &self, + at::Scalar other) { + std::cout << "aten::remainder_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::remainder_out(mlirtens[0], mlirtens[1], other); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::remainder(const at::Tensor &self, + at::Scalar other) { + std::cout << "aten::remainder" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::remainder(mlirtens[0], other); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::remainder_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::remainder_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::remainder_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::remainder(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::remainder" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::remainder(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::min_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::min_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::min_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::min(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::min" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::min(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return 
bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::min(const at::Tensor &self) { + std::cout << "aten::min" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::min(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::max_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::max_out" << std::endl; + std::vector mlirtens_tensors = {out, self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::max(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::max" << std::endl; + std::vector mlirtens_tensors = {self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::max(const at::Tensor &self) { + std::cout << "aten::max" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::median(const at::Tensor &self) { + std::cout << "aten::median" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::median(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::sort_out(at::Tensor &values, at::Tensor &indices, + const at::Tensor &self, int64_t dim, + bool descending) { + std::cout << "aten::sort_out" << std::endl; + std::vector mlirtens_tensors = {values, indices, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::sort_out(mlirtens[0], mlirtens[1], mlirtens[2], dim, descending); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(values, indices); +} + +std::tuple +ATenMLIRTypeDefault::sort(const at::Tensor &self, int64_t dim, + bool descending) { + std::cout << "aten::sort" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::sort(mlirtens[0], dim, descending); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor ATenMLIRTypeDefault::argsort(const at::Tensor &self, int64_t dim, + bool descending) { + std::cout << "aten::argsort" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::argsort(mlirtens[0], dim, descending); + static_cast(x_result); // Avoid warnings in case 
not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::topk_out(at::Tensor &values, at::Tensor &indices, + const at::Tensor &self, int64_t k, int64_t dim, + bool largest, bool sorted) { + std::cout << "aten::topk_out" << std::endl; + std::vector mlirtens_tensors = {values, indices, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::topk_out(mlirtens[0], mlirtens[1], mlirtens[2], k, dim, + largest, sorted); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(values, indices); +} + +std::tuple +ATenMLIRTypeDefault::topk(const at::Tensor &self, int64_t k, int64_t dim, + bool largest, bool sorted) { + std::cout << "aten::topk" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::topk(mlirtens[0], k, dim, largest, sorted); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor ATenMLIRTypeDefault::all(const at::Tensor &self) { + std::cout << "aten::all" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::all(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::any(const at::Tensor &self) { + std::cout << "aten::any" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::any(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::renorm_out(at::Tensor &out, + const at::Tensor &self, + at::Scalar p, int64_t dim, + at::Scalar maxnorm) { + std::cout << "aten::renorm_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::renorm_out(mlirtens[0], mlirtens[1], p, dim, maxnorm); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::renorm(const at::Tensor &self, at::Scalar p, + int64_t dim, at::Scalar maxnorm) { + std::cout << "aten::renorm" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::renorm(mlirtens[0], p, dim, maxnorm); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::unfold(const at::Tensor &self, + int64_t dimension, int64_t size, + int64_t step) { + std::cout << "aten::unfold" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = mlirtens[0].unfold(dimension, size, step); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +bool ATenMLIRTypeDefault::equal(const at::Tensor &self, + const at::Tensor &other) { + std::cout << "aten::equal" << std::endl; + std::vector mlirtens_tensors = 
{self, other}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::equal(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return x_result; +} + +at::Tensor &ATenMLIRTypeDefault::pow_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &exponent) { + std::cout << "aten::pow_out" << std::endl; + std::vector mlirtens_tensors = {out, self, exponent}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::pow_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::pow(const at::Tensor &self, + const at::Tensor &exponent) { + std::cout << "aten::pow" << std::endl; + std::vector mlirtens_tensors = {self, exponent}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::pow(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::pow_out(at::Tensor &out, at::Scalar self, + const at::Tensor &exponent) { + std::cout << "aten::pow_out" << std::endl; + std::vector mlirtens_tensors = {out, exponent}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::pow_out(mlirtens[0], self, mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::pow(at::Scalar self, + const at::Tensor &exponent) { + std::cout << "aten::pow" << std::endl; + std::vector mlirtens_tensors = {exponent}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::pow(self, mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(exponent)); +} + +at::Tensor &ATenMLIRTypeDefault::normal_out(at::Tensor &out, + const at::Tensor &mean, double std, + at::Generator *generator) { + std::cout << "aten::normal_out" << std::endl; + std::vector mlirtens_tensors = {out, mean}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::normal_out(mlirtens[0], mlirtens[1], std, generator); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::normal(const at::Tensor &mean, double std, + at::Generator *generator) { + std::cout << "aten::normal" << std::endl; + std::vector mlirtens_tensors = {mean}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::normal(mlirtens[0], std, generator); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(mean)); +} + +at::Tensor &ATenMLIRTypeDefault::normal_out(at::Tensor &out, double mean, + const at::Tensor &std, + at::Generator *generator) { + std::cout << "aten::normal_out" << std::endl; + std::vector mlirtens_tensors = {out, std}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::normal_out(mlirtens[0], mean, mlirtens[1], generator); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::normal(double mean, const at::Tensor &std, + at::Generator *generator) { + std::cout << "aten::normal" << std::endl; + std::vector mlirtens_tensors = {std}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + 
auto &&x_result = at::normal(mean, mlirtens[0], generator); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(std)); +} + +at::Tensor &ATenMLIRTypeDefault::normal_out(at::Tensor &out, + const at::Tensor &mean, + const at::Tensor &std, + at::Generator *generator) { + std::cout << "aten::normal_out" << std::endl; + std::vector mlirtens_tensors = {out, mean, std}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::normal_out(mlirtens[0], mlirtens[1], mlirtens[2], generator); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::normal(const at::Tensor &mean, + const at::Tensor &std, + at::Generator *generator) { + std::cout << "aten::normal" << std::endl; + std::vector mlirtens_tensors = {mean, std}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::normal(mlirtens[0], mlirtens[1], generator); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(mean)); +} + +at::Tensor ATenMLIRTypeDefault::normal(double mean, double std, + at::IntArrayRef size, + at::Generator *generator, + const at::TensorOptions &options) { + std::cout << "aten::normal" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::normal(mean, std, size, generator, options); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(options)); +} + +at::Tensor &ATenMLIRTypeDefault::normal_out(at::Tensor &out, double mean, + double std, at::IntArrayRef size, + at::Generator *generator) { + std::cout << "aten::normal_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::normal_out(mlirtens[0], mean, std, size, generator); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::alias(const at::Tensor &self) { + std::cout << "aten::alias" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::alias(mlirtens[0]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_addr(const at::Tensor &self, + const at::Tensor &vec1, + const at::Tensor &vec2, at::Scalar beta, + at::Scalar alpha) { + std::cout << "aten::_addr" << std::endl; + std::vector mlirtens_tensors = {self, vec1, vec2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_addr(mlirtens[0], mlirtens[1], mlirtens[2], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::_addr_(at::Tensor &self, + const at::Tensor &vec1, + const at::Tensor &vec2, at::Scalar beta, + at::Scalar alpha) { + std::cout << "aten::_addr_" << std::endl; + std::vector mlirtens_tensors = {self, vec1, vec2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_addr_(mlirtens[0], mlirtens[1], mlirtens[2], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor 
&ATenMLIRTypeDefault::_addr_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &vec1, + const at::Tensor &vec2, + at::Scalar beta, at::Scalar alpha) { + std::cout << "aten::_addr_out" << std::endl; + std::vector mlirtens_tensors = {out, self, vec1, vec2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_addr_out(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor &ATenMLIRTypeDefault::_index_copy_(at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &source) { + std::cout << "aten::_index_copy_" << std::endl; + std::vector mlirtens_tensors = {self, index, source}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_index_copy_(mlirtens[0], dim, mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::_cumsum(const at::Tensor &self, int64_t dim) { + std::cout << "aten::_cumsum" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cumsum(mlirtens[0], dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::_cumsum_out(at::Tensor &out, + const at::Tensor &self, + int64_t dim) { + std::cout << "aten::_cumsum_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cumsum_out(mlirtens[0], mlirtens[1], dim); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::_cumprod(const at::Tensor &self, int64_t dim) { + std::cout << "aten::_cumprod" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cumprod(mlirtens[0], dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::_cumprod_out(at::Tensor &out, + const at::Tensor &self, + int64_t dim) { + std::cout << "aten::_cumprod_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cumprod_out(mlirtens[0], mlirtens[1], dim); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::_var(const at::Tensor &self, bool unbiased) { + std::cout << "aten::_var" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_var(mlirtens[0], unbiased); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_std(const at::Tensor &self, bool unbiased) { + std::cout << "aten::_std" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_std(mlirtens[0], unbiased); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor 
&ATenMLIRTypeDefault::_addmm_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &mat1, + const at::Tensor &mat2, + at::Scalar beta, at::Scalar alpha) { + std::cout << "aten::_addmm_out" << std::endl; + std::vector mlirtens_tensors = {out, self, mat1, mat2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_addmm_out(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::_addmm(const at::Tensor &self, + const at::Tensor &mat1, + const at::Tensor &mat2, at::Scalar beta, + at::Scalar alpha) { + std::cout << "aten::_addmm" << std::endl; + std::vector mlirtens_tensors = {self, mat1, mat2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_addmm(mlirtens[0], mlirtens[1], mlirtens[2], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::_addmm_(at::Tensor &self, + const at::Tensor &mat1, + const at::Tensor &mat2, + at::Scalar beta, at::Scalar alpha) { + std::cout << "aten::_addmm_" << std::endl; + std::vector mlirtens_tensors = {self, mat1, mat2}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_addmm_(mlirtens[0], mlirtens[1], mlirtens[2], beta, alpha); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor ATenMLIRTypeDefault::_cat(at::TensorList tensors, int64_t dim) { + std::cout << "aten::_cat" << std::endl; + std::vector mlirtens_tensors = {}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cat(tensors, dim); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(tensors)); +} + +at::Tensor &ATenMLIRTypeDefault::_cat_out(at::Tensor &out, + at::TensorList tensors, int64_t dim) { + std::cout << "aten::_cat_out" << std::endl; + std::vector mlirtens_tensors = {out}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_cat_out(mlirtens[0], tensors, dim); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +std::tuple +ATenMLIRTypeDefault::_mode(const at::Tensor &self, int64_t dim, bool keepdim) { + std::cout << "aten::_mode" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_mode(mlirtens[0], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::_mode_out(at::Tensor &values, at::Tensor &indices, + const at::Tensor &self, int64_t dim, + bool keepdim) { + std::cout << "aten::_mode_out" << std::endl; + std::vector mlirtens_tensors = {values, indices, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_mode_out(mlirtens[0], mlirtens[1], mlirtens[2], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(values, indices); +} + +std::tuple +ATenMLIRTypeDefault::_max(const at::Tensor &self, int64_t dim, bool keepdim) { + std::cout << "aten::_max" << std::endl; + std::vector 
mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_max(mlirtens[0], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::_max_out(at::Tensor &max, at::Tensor &max_indices, + const at::Tensor &self, int64_t dim, + bool keepdim) { + std::cout << "aten::_max_out" << std::endl; + std::vector mlirtens_tensors = {max, max_indices, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_max_out(mlirtens[0], mlirtens[1], mlirtens[2], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(max, max_indices); +} + +std::tuple +ATenMLIRTypeDefault::_min(const at::Tensor &self, int64_t dim, bool keepdim) { + std::cout << "aten::_min" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_min(mlirtens[0], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::_min_out(at::Tensor &min, at::Tensor &min_indices, + const at::Tensor &self, int64_t dim, + bool keepdim) { + std::cout << "aten::_min_out" << std::endl; + std::vector mlirtens_tensors = {min, min_indices, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::_min_out(mlirtens[0], mlirtens[1], mlirtens[2], dim, keepdim); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(min, min_indices); +} + +at::Tensor &ATenMLIRTypeDefault::binary_cross_entropy_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &target, + const at::Tensor &weight, int64_t reduction) { + std::cout << "aten::binary_cross_entropy_out" << std::endl; + std::vector mlirtens_tensors = {out, self, target, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::binary_cross_entropy_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], reduction); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::binary_cross_entropy(const at::Tensor &self, + const at::Tensor &target, + const at::Tensor &weight, + int64_t reduction) { + std::cout << "aten::binary_cross_entropy" << std::endl; + std::vector mlirtens_tensors = {self, target, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::binary_cross_entropy(mlirtens[0], mlirtens[1], + mlirtens[2], reduction); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::binary_cross_entropy_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &target, const at::Tensor &weight, + int64_t reduction) { + std::cout << "aten::binary_cross_entropy_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self, + target, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto 
&&x_result = at::binary_cross_entropy_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + reduction); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::binary_cross_entropy_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &target, const at::Tensor &weight, int64_t reduction) { + std::cout << "aten::binary_cross_entropy_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, target, + weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::binary_cross_entropy_backward( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], reduction); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::mse_loss_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &target, + int64_t reduction) { + std::cout << "aten::mse_loss_out" << std::endl; + std::vector mlirtens_tensors = {out, self, target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::mse_loss_out(mlirtens[0], mlirtens[1], mlirtens[2], reduction); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::mse_loss(const at::Tensor &self, + const at::Tensor &target, + int64_t reduction) { + std::cout << "aten::mse_loss" << std::endl; + std::vector mlirtens_tensors = {self, target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mse_loss(mlirtens[0], mlirtens[1], reduction); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::mse_loss_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &target, int64_t reduction) { + std::cout << "aten::mse_loss_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self, + target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mse_loss_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], reduction); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::mse_loss_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, + int64_t reduction) { + std::cout << "aten::mse_loss_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::mse_loss_backward(mlirtens[0], mlirtens[1], mlirtens[2], reduction); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::l1_loss_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &target, + int64_t reduction) { + std::cout << "aten::l1_loss_out" << std::endl; + std::vector mlirtens_tensors = {out, self, target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::l1_loss_out(mlirtens[0], mlirtens[1], mlirtens[2], reduction); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::l1_loss(const 
at::Tensor &self, + const at::Tensor &target, + int64_t reduction) { + std::cout << "aten::l1_loss" << std::endl; + std::vector mlirtens_tensors = {self, target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::l1_loss(mlirtens[0], mlirtens[1], reduction); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::l1_loss_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &target, int64_t reduction) { + std::cout << "aten::l1_loss_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self, + target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::l1_loss_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], reduction); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::l1_loss_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, + int64_t reduction) { + std::cout << "aten::l1_loss_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::l1_loss_backward(mlirtens[0], mlirtens[1], mlirtens[2], reduction); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::multi_margin_loss_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &target, + at::Scalar p, at::Scalar margin, const at::Tensor &weight, + int64_t reduction) { + std::cout << "aten::multi_margin_loss_out" << std::endl; + std::vector mlirtens_tensors = {out, self, target, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::multi_margin_loss_out( + mlirtens[0], mlirtens[1], mlirtens[2], p, margin, mlirtens[3], reduction); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::multi_margin_loss( + const at::Tensor &self, const at::Tensor &target, at::Scalar p, + at::Scalar margin, const at::Tensor &weight, int64_t reduction) { + std::cout << "aten::multi_margin_loss" << std::endl; + std::vector mlirtens_tensors = {self, target, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::multi_margin_loss(mlirtens[0], mlirtens[1], p, margin, + mlirtens[2], reduction); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::multi_margin_loss_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &target, at::Scalar p, + at::Scalar margin, const at::Tensor &weight, int64_t reduction) { + std::cout << "aten::multi_margin_loss_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self, + target, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::multi_margin_loss_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], p, margin, + mlirtens[4], reduction); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor 
ATenMLIRTypeDefault::multi_margin_loss_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &target, at::Scalar p, at::Scalar margin, + const at::Tensor &weight, int64_t reduction) { + std::cout << "aten::multi_margin_loss_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, target, + weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::multi_margin_loss_backward( + mlirtens[0], mlirtens[1], mlirtens[2], p, margin, mlirtens[3], reduction); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::multilabel_margin_loss_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &target, + int64_t reduction) { + std::cout << "aten::multilabel_margin_loss_out" << std::endl; + std::vector mlirtens_tensors = {out, self, target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::multilabel_margin_loss_out(mlirtens[0], mlirtens[1], + mlirtens[2], reduction); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::multilabel_margin_loss(const at::Tensor &self, + const at::Tensor &target, + int64_t reduction) { + std::cout << "aten::multilabel_margin_loss" << std::endl; + std::vector mlirtens_tensors = {self, target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::multilabel_margin_loss(mlirtens[0], mlirtens[1], reduction); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::multilabel_margin_loss_forward_out( + at::Tensor &output, at::Tensor &is_target, const at::Tensor &self, + const at::Tensor &target, int64_t reduction) { + std::cout << "aten::multilabel_margin_loss_forward_out" << std::endl; + std::vector mlirtens_tensors = {output, is_target, self, target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::multilabel_margin_loss_forward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], reduction); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(output, is_target); +} + +std::tuple +ATenMLIRTypeDefault::multilabel_margin_loss_forward(const at::Tensor &self, + const at::Tensor &target, + int64_t reduction) { + std::cout << "aten::multilabel_margin_loss_forward" << std::endl; + std::vector mlirtens_tensors = {self, target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::multilabel_margin_loss_forward(mlirtens[0], mlirtens[1], reduction); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor &ATenMLIRTypeDefault::multilabel_margin_loss_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &target, int64_t reduction, + const at::Tensor &is_target) { + std::cout << "aten::multilabel_margin_loss_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self, + target, is_target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = 
at::multilabel_margin_loss_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], reduction, + mlirtens[4]); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::multilabel_margin_loss_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &target, int64_t reduction, const at::Tensor &is_target) { + std::cout << "aten::multilabel_margin_loss_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, target, + is_target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::multilabel_margin_loss_backward( + mlirtens[0], mlirtens[1], mlirtens[2], reduction, mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::nll_loss_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &target, + const at::Tensor &weight, int64_t reduction, int64_t ignore_index) { + std::cout << "aten::nll_loss_out" << std::endl; + std::vector mlirtens_tensors = {out, self, target, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::nll_loss_out(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], reduction, ignore_index); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::nll_loss(const at::Tensor &self, + const at::Tensor &target, + const at::Tensor &weight, + int64_t reduction, + int64_t ignore_index) { + std::cout << "aten::nll_loss" << std::endl; + std::vector mlirtens_tensors = {self, target, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::nll_loss(mlirtens[0], mlirtens[1], mlirtens[2], + reduction, ignore_index); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::nll_loss_forward_out( + at::Tensor &output, at::Tensor &total_weight, const at::Tensor &self, + const at::Tensor &target, const at::Tensor &weight, int64_t reduction, + int64_t ignore_index) { + std::cout << "aten::nll_loss_forward_out" << std::endl; + std::vector mlirtens_tensors = {output, total_weight, self, + target, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::nll_loss_forward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + reduction, ignore_index); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(output, total_weight); +} + +std::tuple ATenMLIRTypeDefault::nll_loss_forward( + const at::Tensor &self, const at::Tensor &target, const at::Tensor &weight, + int64_t reduction, int64_t ignore_index) { + std::cout << "aten::nll_loss_forward" << std::endl; + std::vector mlirtens_tensors = {self, target, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::nll_loss_forward(mlirtens[0], mlirtens[1], mlirtens[2], + reduction, ignore_index); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor &ATenMLIRTypeDefault::nll_loss_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor 
&self, const at::Tensor &target, const at::Tensor &weight, + int64_t reduction, int64_t ignore_index, const at::Tensor &total_weight) { + std::cout << "aten::nll_loss_backward_out" << std::endl; + std::vector mlirtens_tensors = { + grad_input, grad_output, self, target, weight, total_weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::nll_loss_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + reduction, ignore_index, mlirtens[5]); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::nll_loss_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &target, const at::Tensor &weight, int64_t reduction, + int64_t ignore_index, const at::Tensor &total_weight) { + std::cout << "aten::nll_loss_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, target, weight, + total_weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::nll_loss_backward(mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], + reduction, ignore_index, mlirtens[4]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::nll_loss2d_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &target, + const at::Tensor &weight, int64_t reduction, int64_t ignore_index) { + std::cout << "aten::nll_loss2d_out" << std::endl; + std::vector mlirtens_tensors = {out, self, target, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::nll_loss2d_out(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], reduction, ignore_index); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::nll_loss2d(const at::Tensor &self, + const at::Tensor &target, + const at::Tensor &weight, + int64_t reduction, + int64_t ignore_index) { + std::cout << "aten::nll_loss2d" << std::endl; + std::vector mlirtens_tensors = {self, target, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::nll_loss2d(mlirtens[0], mlirtens[1], mlirtens[2], + reduction, ignore_index); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::nll_loss2d_forward_out( + at::Tensor &output, at::Tensor &total_weight, const at::Tensor &self, + const at::Tensor &target, const at::Tensor &weight, int64_t reduction, + int64_t ignore_index) { + std::cout << "aten::nll_loss2d_forward_out" << std::endl; + std::vector mlirtens_tensors = {output, total_weight, self, + target, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::nll_loss2d_forward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + reduction, ignore_index); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(output, total_weight); +} + +std::tuple ATenMLIRTypeDefault::nll_loss2d_forward( + const at::Tensor &self, const at::Tensor &target, const at::Tensor &weight, + int64_t reduction, int64_t ignore_index) { + std::cout << "aten::nll_loss2d_forward" << std::endl; + std::vector mlirtens_tensors = {self, target, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = 
at::nll_loss2d_forward( + mlirtens[0], mlirtens[1], mlirtens[2], reduction, ignore_index); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor &ATenMLIRTypeDefault::nll_loss2d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &target, const at::Tensor &weight, + int64_t reduction, int64_t ignore_index, const at::Tensor &total_weight) { + std::cout << "aten::nll_loss2d_backward_out" << std::endl; + std::vector mlirtens_tensors = { + grad_input, grad_output, self, target, weight, total_weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::nll_loss2d_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + reduction, ignore_index, mlirtens[5]); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::nll_loss2d_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &target, const at::Tensor &weight, int64_t reduction, + int64_t ignore_index, const at::Tensor &total_weight) { + std::cout << "aten::nll_loss2d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, target, weight, + total_weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::nll_loss2d_backward(mlirtens[0], mlirtens[1], + mlirtens[2], mlirtens[3], reduction, + ignore_index, mlirtens[4]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::smooth_l1_loss_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &target, + int64_t reduction) { + std::cout << "aten::smooth_l1_loss_out" << std::endl; + std::vector mlirtens_tensors = {out, self, target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::smooth_l1_loss_out(mlirtens[0], mlirtens[1], mlirtens[2], reduction); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::smooth_l1_loss(const at::Tensor &self, + const at::Tensor &target, + int64_t reduction) { + std::cout << "aten::smooth_l1_loss" << std::endl; + std::vector mlirtens_tensors = {self, target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::smooth_l1_loss(mlirtens[0], mlirtens[1], reduction); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::smooth_l1_loss_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &target, int64_t reduction) { + std::cout << "aten::smooth_l1_loss_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self, + target}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::smooth_l1_loss_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], reduction); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::smooth_l1_loss_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const 
at::Tensor &target, int64_t reduction) {
+  std::cout << "aten::smooth_l1_loss_backward" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {grad_output, self, target};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::smooth_l1_loss_backward(mlirtens[0], mlirtens[1],
+                                                mlirtens[2], reduction);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output));
+}
+
+at::Tensor &ATenMLIRTypeDefault::soft_margin_loss_out(at::Tensor &out,
+                                                       const at::Tensor &self,
+                                                       const at::Tensor &target,
+                                                       int64_t reduction) {
+  std::cout << "aten::soft_margin_loss_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self, target};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::soft_margin_loss_out(mlirtens[0], mlirtens[1],
+                                             mlirtens[2], reduction);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::soft_margin_loss(const at::Tensor &self,
+                                                 const at::Tensor &target,
+                                                 int64_t reduction) {
+  std::cout << "aten::soft_margin_loss" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self, target};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::soft_margin_loss(mlirtens[0], mlirtens[1], reduction);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::soft_margin_loss_backward_out(
+    at::Tensor &grad_input, const at::Tensor &grad_output,
+    const at::Tensor &self, const at::Tensor &target, int64_t reduction) {
+  std::cout << "aten::soft_margin_loss_backward_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {grad_input, grad_output, self,
+                                              target};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::soft_margin_loss_backward_out(
+      mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], reduction);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return grad_input;
+}
+
+at::Tensor ATenMLIRTypeDefault::soft_margin_loss_backward(
+    const at::Tensor &grad_output, const at::Tensor &self,
+    const at::Tensor &target, int64_t reduction) {
+  std::cout << "aten::soft_margin_loss_backward" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {grad_output, self, target};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::soft_margin_loss_backward(mlirtens[0], mlirtens[1],
+                                                  mlirtens[2], reduction);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output));
+}
+
+at::Tensor &ATenMLIRTypeDefault::elu_out(at::Tensor &out,
+                                         const at::Tensor &self,
+                                         at::Scalar alpha, at::Scalar scale,
+                                         at::Scalar input_scale) {
+  std::cout << "aten::elu_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result =
+      at::elu_out(mlirtens[0], mlirtens[1], alpha, scale, input_scale);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::elu(const at::Tensor &self, at::Scalar alpha,
+                                    at::Scalar scale, at::Scalar input_scale) {
+  std::cout << "aten::elu" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::elu(mlirtens[0], alpha, scale, input_scale);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::elu_backward_out(
+    at::Tensor &grad_input, const at::Tensor &grad_output, at::Scalar alpha,
+    at::Scalar scale, at::Scalar input_scale, const at::Tensor &output) {
+  std::cout << "aten::elu_backward_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {grad_input, grad_output, output};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::elu_backward_out(mlirtens[0], mlirtens[1], alpha, scale,
+                                         input_scale, mlirtens[2]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return grad_input;
+}
+
+at::Tensor ATenMLIRTypeDefault::elu_backward(const at::Tensor &grad_output,
+                                             at::Scalar alpha, at::Scalar scale,
+                                             at::Scalar input_scale,
+                                             const at::Tensor &output) {
+  std::cout << "aten::elu_backward" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {grad_output, output};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result =
+      at::elu_backward(mlirtens[0], alpha, scale, input_scale, mlirtens[1]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output));
+}
+
+at::Tensor &ATenMLIRTypeDefault::elu_(at::Tensor &self, at::Scalar alpha,
+                                      at::Scalar scale,
+                                      at::Scalar input_scale) {
+  std::cout << "aten::elu_" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::elu_(mlirtens[0], alpha, scale, input_scale);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return self;
+}
+
+at::Tensor &ATenMLIRTypeDefault::glu_out(at::Tensor &out,
+                                         const at::Tensor &self, int64_t dim) {
+  std::cout << "aten::glu_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::glu_out(mlirtens[0], mlirtens[1], dim);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::glu(const at::Tensor &self, int64_t dim) {
+  std::cout << "aten::glu" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::glu(mlirtens[0], dim);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::glu_backward_out(at::Tensor &grad_input,
+                                                  const at::Tensor &grad_output,
+                                                  const at::Tensor &self,
+                                                  int64_t dim) {
+  std::cout << "aten::glu_backward_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {grad_input, grad_output, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result =
+      at::glu_backward_out(mlirtens[0], mlirtens[1], mlirtens[2], dim);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return grad_input;
+}
+
+at::Tensor ATenMLIRTypeDefault::glu_backward(const at::Tensor &grad_output,
+                                             const at::Tensor &self,
+                                             int64_t dim) {
+  std::cout << "aten::glu_backward" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {grad_output, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::glu_backward(mlirtens[0], mlirtens[1], dim);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output));
+}
+
+at::Tensor &ATenMLIRTypeDefault::hardtanh_out(at::Tensor &out,
+                                              const at::Tensor &self,
+                                              at::Scalar min_val,
+                                              at::Scalar max_val) {
+  std::cout << "aten::hardtanh_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result =
+      at::hardtanh_out(mlirtens[0], mlirtens[1], min_val, max_val);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::hardtanh(const at::Tensor &self,
+                                         at::Scalar min_val,
+                                         at::Scalar max_val) {
+  std::cout << "aten::hardtanh" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::hardtanh(mlirtens[0], min_val, max_val);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::hardtanh_backward_out(
+    at::Tensor &grad_input, const at::Tensor &grad_output,
+    const at::Tensor &self, at::Scalar min_val, at::Scalar max_val) {
+  std::cout << "aten::hardtanh_backward_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {grad_input, grad_output, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::hardtanh_backward_out(mlirtens[0], mlirtens[1],
+                                              mlirtens[2], min_val, max_val);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return grad_input;
+}
+
+at::Tensor ATenMLIRTypeDefault::hardtanh_backward(const at::Tensor &grad_output,
+                                                  const at::Tensor &self,
+                                                  at::Scalar min_val,
+                                                  at::Scalar max_val) {
+  std::cout << "aten::hardtanh_backward" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {grad_output, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result =
+      at::hardtanh_backward(mlirtens[0], mlirtens[1], min_val, max_val);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output));
+}
+
+at::Tensor &ATenMLIRTypeDefault::hardtanh_(at::Tensor &self, at::Scalar min_val,
+                                           at::Scalar max_val) {
+  std::cout << "aten::hardtanh_" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::hardtanh_(mlirtens[0], min_val, max_val);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return self;
+}
+
+at::Tensor &ATenMLIRTypeDefault::leaky_relu_out(at::Tensor &out,
+                                                const at::Tensor &self,
+                                                at::Scalar negative_slope) {
+  std::cout << "aten::leaky_relu_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result =
+      at::leaky_relu_out(mlirtens[0], mlirtens[1], negative_slope);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::leaky_relu(const at::Tensor &self,
+                                           at::Scalar negative_slope) {
+  std::cout << "aten::leaky_relu" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::leaky_relu(mlirtens[0], negative_slope);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::leaky_relu_backward_out(
+    at::Tensor &grad_input, const at::Tensor &grad_output,
+    const at::Tensor &self, at::Scalar negative_slope) {
+  std::cout << "aten::leaky_relu_backward_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {grad_input, grad_output, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::leaky_relu_backward_out(mlirtens[0], mlirtens[1],
+                                                mlirtens[2], negative_slope);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return grad_input;
+}
+
+at::Tensor
+ATenMLIRTypeDefault::leaky_relu_backward(const at::Tensor &grad_output,
+                                         const at::Tensor &self,
+                                         at::Scalar negative_slope) {
+  std::cout << "aten::leaky_relu_backward" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {grad_output, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result =
+      at::leaky_relu_backward(mlirtens[0], mlirtens[1], negative_slope);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output));
+}
+
+at::Tensor &ATenMLIRTypeDefault::leaky_relu_(at::Tensor &self,
+                                             at::Scalar negative_slope) {
+  std::cout << "aten::leaky_relu_" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::leaky_relu_(mlirtens[0], negative_slope);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return self;
+}
+
+at::Tensor &ATenMLIRTypeDefault::log_sigmoid_out(at::Tensor &out,
+                                                 const at::Tensor &self) {
+  std::cout << "aten::log_sigmoid_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::log_sigmoid_out(mlirtens[0], mlirtens[1]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::log_sigmoid(const at::Tensor &self) {
+  std::cout << "aten::log_sigmoid" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::log_sigmoid(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+std::tuple<at::Tensor &, at::Tensor &>
+ATenMLIRTypeDefault::log_sigmoid_forward_out(at::Tensor &output,
+                                             at::Tensor &buffer,
+                                             const at::Tensor &self) {
+  std::cout << "aten::log_sigmoid_forward_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {output, buffer, self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result =
+      at::log_sigmoid_forward_out(mlirtens[0], mlirtens[1], mlirtens[2]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return std::tuple<at::Tensor &, at::Tensor &>(output, buffer);
+}
+
+std::tuple<at::Tensor, at::Tensor>
+ATenMLIRTypeDefault::log_sigmoid_forward(const at::Tensor &self) {
+  std::cout << "aten::log_sigmoid_forward" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::log_sigmoid_forward(mlirtens[0]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return std::tuple<at::Tensor, at::Tensor>(
+      bridge::CreateMLIRTensor(std::get<0>(x_result),
+                               bridge::GetMLIRDevice(self)),
+      bridge::CreateMLIRTensor(std::get<1>(x_result),
+                               bridge::GetMLIRDevice(self)));
+}
+
+at::Tensor &ATenMLIRTypeDefault::log_sigmoid_backward_out(
+    at::Tensor &grad_input, const at::Tensor &grad_output,
+    const at::Tensor &self, const at::Tensor &buffer) {
+  std::cout << "aten::log_sigmoid_backward_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {grad_input, grad_output, self,
+                                              buffer};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::log_sigmoid_backward_out(mlirtens[0], mlirtens[1],
+                                                 mlirtens[2], mlirtens[3]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return grad_input;
+}
+
+at::Tensor
+ATenMLIRTypeDefault::log_sigmoid_backward(const at::Tensor &grad_output,
+                                          const at::Tensor &self,
+                                          const at::Tensor &buffer) {
+  std::cout << "aten::log_sigmoid_backward" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {grad_output, self, buffer};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result =
+      at::log_sigmoid_backward(mlirtens[0], mlirtens[1], mlirtens[2]);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output));
+}
+
+at::Tensor &ATenMLIRTypeDefault::rrelu_with_noise_out(
+    at::Tensor &out, const at::Tensor &self, const at::Tensor &noise,
+    at::Scalar lower, at::Scalar upper, bool training,
+    at::Generator *generator) {
+  std::cout << "aten::rrelu_with_noise_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {out, self, noise};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::rrelu_with_noise_out(
+      mlirtens[0], mlirtens[1], mlirtens[2], lower, upper, training, generator);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return out;
+}
+
+at::Tensor ATenMLIRTypeDefault::rrelu_with_noise(
+    const at::Tensor &self, const at::Tensor &noise, at::Scalar lower,
+    at::Scalar upper, bool training, at::Generator *generator) {
+  std::cout << "aten::rrelu_with_noise" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {self, noise};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::rrelu_with_noise(mlirtens[0], mlirtens[1], lower, upper,
+                                         training, generator);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
+}
+
+at::Tensor &ATenMLIRTypeDefault::rrelu_with_noise_backward_out(
+    at::Tensor &grad_input, const at::Tensor &grad_output,
+    const at::Tensor &self, const at::Tensor &noise, at::Scalar lower,
+    at::Scalar upper, bool training) {
+  std::cout << "aten::rrelu_with_noise_backward_out" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {grad_input, grad_output, self,
+                                              noise};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result =
+      at::rrelu_with_noise_backward_out(mlirtens[0], mlirtens[1], mlirtens[2],
+                                        mlirtens[3], lower, upper, training);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return grad_input;
+}
+
+at::Tensor ATenMLIRTypeDefault::rrelu_with_noise_backward(
+    const at::Tensor &grad_output, const at::Tensor &self,
+    const at::Tensor &noise, at::Scalar lower, at::Scalar upper,
+    bool training) {
+  std::cout << "aten::rrelu_with_noise_backward" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors = {grad_output, self, noise};
+  auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
+  auto &&x_result = at::rrelu_with_noise_backward(
+      mlirtens[0], mlirtens[1], mlirtens[2], lower, upper, training);
+  static_cast<void>(x_result); // Avoid warnings in case not used
+  return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output));
+}
+
+at::Tensor &ATenMLIRTypeDefault::rrelu_with_noise_(
+    at::Tensor &self, const at::Tensor &noise, at::Scalar lower,
+    at::Scalar upper, bool training, at::Generator *generator) {
+  std::cout << "aten::rrelu_with_noise_" << std::endl;
+  std::vector<at::Tensor> mlirtens_tensors
= {self, noise}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::rrelu_with_noise_(mlirtens[0], mlirtens[1], lower, + upper, training, generator); + static_cast(x_result); // Avoid warnings in case not used + return self; +} + +at::Tensor &ATenMLIRTypeDefault::softplus_out(at::Tensor &out, + const at::Tensor &self, + at::Scalar beta, + at::Scalar threshold) { + std::cout << "aten::softplus_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::softplus_out(mlirtens[0], mlirtens[1], beta, threshold); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::softplus(const at::Tensor &self, + at::Scalar beta, + at::Scalar threshold) { + std::cout << "aten::softplus" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::softplus(mlirtens[0], beta, threshold); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::softplus_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::Scalar beta, at::Scalar threshold, + const at::Tensor &output) { + std::cout << "aten::softplus_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self, + output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::softplus_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], beta, threshold, mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::softplus_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::Scalar beta, + at::Scalar threshold, + const at::Tensor &output) { + std::cout << "aten::softplus_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::softplus_backward(mlirtens[0], mlirtens[1], beta, + threshold, mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::softshrink_out(at::Tensor &out, + const at::Tensor &self, + at::Scalar lambd) { + std::cout << "aten::softshrink_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::softshrink_out(mlirtens[0], mlirtens[1], lambd); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::softshrink(const at::Tensor &self, + at::Scalar lambd) { + std::cout << "aten::softshrink" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::softshrink(mlirtens[0], lambd); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::softshrink_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::Scalar lambd) { + std::cout << "aten::softshrink_backward_out" << std::endl; + std::vector mlirtens_tensors = 
{grad_input, grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::softshrink_backward_out(mlirtens[0], mlirtens[1], mlirtens[2], lambd); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::softshrink_backward( + const at::Tensor &grad_output, const at::Tensor &self, at::Scalar lambd) { + std::cout << "aten::softshrink_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::softshrink_backward(mlirtens[0], mlirtens[1], lambd); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::adaptive_avg_pool2d_out( + at::Tensor &out, const at::Tensor &self, at::IntArrayRef output_size) { + std::cout << "aten::adaptive_avg_pool2d_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::adaptive_avg_pool2d_out(mlirtens[0], mlirtens[1], output_size); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor +ATenMLIRTypeDefault::adaptive_avg_pool2d(const at::Tensor &self, + at::IntArrayRef output_size) { + std::cout << "aten::adaptive_avg_pool2d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::adaptive_avg_pool2d(mlirtens[0], output_size); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor +ATenMLIRTypeDefault::mkldnn_adaptive_avg_pool2d(const at::Tensor &self, + at::IntArrayRef output_size) { + std::cout << "aten::mkldnn_adaptive_avg_pool2d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::mkldnn_adaptive_avg_pool2d(mlirtens[0], output_size); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor +ATenMLIRTypeDefault::_adaptive_avg_pool2d(const at::Tensor &self, + at::IntArrayRef output_size) { + std::cout << "aten::_adaptive_avg_pool2d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_adaptive_avg_pool2d(mlirtens[0], output_size); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor ATenMLIRTypeDefault::_adaptive_avg_pool2d_backward( + const at::Tensor &grad_output, const at::Tensor &self) { + std::cout << "aten::_adaptive_avg_pool2d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::_adaptive_avg_pool2d_backward(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::adaptive_avg_pool3d_out( + at::Tensor &out, const at::Tensor &self, at::IntArrayRef output_size) { + std::cout << "aten::adaptive_avg_pool3d_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto 
mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::adaptive_avg_pool3d_out(mlirtens[0], mlirtens[1], output_size); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor +ATenMLIRTypeDefault::adaptive_avg_pool3d(const at::Tensor &self, + at::IntArrayRef output_size) { + std::cout << "aten::adaptive_avg_pool3d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::adaptive_avg_pool3d(mlirtens[0], output_size); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::adaptive_avg_pool3d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self) { + std::cout << "aten::adaptive_avg_pool3d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::adaptive_avg_pool3d_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor +ATenMLIRTypeDefault::adaptive_avg_pool3d_backward(const at::Tensor &grad_output, + const at::Tensor &self) { + std::cout << "aten::adaptive_avg_pool3d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::adaptive_avg_pool3d_backward(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +std::tuple +ATenMLIRTypeDefault::adaptive_max_pool2d_out(at::Tensor &out, + at::Tensor &indices, + const at::Tensor &self, + at::IntArrayRef output_size) { + std::cout << "aten::adaptive_max_pool2d_out" << std::endl; + std::vector mlirtens_tensors = {out, indices, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::adaptive_max_pool2d_out(mlirtens[0], mlirtens[1], + mlirtens[2], output_size); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(out, indices); +} + +std::tuple +ATenMLIRTypeDefault::adaptive_max_pool2d(const at::Tensor &self, + at::IntArrayRef output_size) { + std::cout << "aten::adaptive_max_pool2d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::adaptive_max_pool2d(mlirtens[0], output_size); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor &ATenMLIRTypeDefault::adaptive_max_pool2d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &indices) { + std::cout << "aten::adaptive_max_pool2d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self, + indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::adaptive_max_pool2d_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor 
+ATenMLIRTypeDefault::adaptive_max_pool2d_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &indices) { + std::cout << "aten::adaptive_max_pool2d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::adaptive_max_pool2d_backward(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +std::tuple +ATenMLIRTypeDefault::adaptive_max_pool3d_out(at::Tensor &out, + at::Tensor &indices, + const at::Tensor &self, + at::IntArrayRef output_size) { + std::cout << "aten::adaptive_max_pool3d_out" << std::endl; + std::vector mlirtens_tensors = {out, indices, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::adaptive_max_pool3d_out(mlirtens[0], mlirtens[1], + mlirtens[2], output_size); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(out, indices); +} + +std::tuple +ATenMLIRTypeDefault::adaptive_max_pool3d(const at::Tensor &self, + at::IntArrayRef output_size) { + std::cout << "aten::adaptive_max_pool3d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::adaptive_max_pool3d(mlirtens[0], output_size); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor &ATenMLIRTypeDefault::adaptive_max_pool3d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &indices) { + std::cout << "aten::adaptive_max_pool3d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self, + indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::adaptive_max_pool3d_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor +ATenMLIRTypeDefault::adaptive_max_pool3d_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &indices) { + std::cout << "aten::adaptive_max_pool3d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::adaptive_max_pool3d_backward(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::avg_pool2d_out( + at::Tensor &out, const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, + bool count_include_pad, c10::optional divisor_override) { + std::cout << "aten::avg_pool2d_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::avg_pool2d_out(mlirtens[0], mlirtens[1], kernel_size, stride, padding, + ceil_mode, count_include_pad, divisor_override); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor 
ATenMLIRTypeDefault::avg_pool2d( + const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, + c10::optional divisor_override) { + std::cout << "aten::avg_pool2d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::avg_pool2d(mlirtens[0], kernel_size, stride, padding, ceil_mode, + count_include_pad, divisor_override); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::avg_pool2d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, + c10::optional divisor_override) { + std::cout << "aten::avg_pool2d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::avg_pool2d_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, stride, padding, + ceil_mode, count_include_pad, divisor_override); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::avg_pool2d_backward( + const at::Tensor &grad_output, const at::Tensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, + c10::optional divisor_override) { + std::cout << "aten::avg_pool2d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::avg_pool2d_backward( + mlirtens[0], mlirtens[1], kernel_size, stride, padding, ceil_mode, + count_include_pad, divisor_override); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::avg_pool3d_out( + at::Tensor &out, const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, + bool count_include_pad, c10::optional divisor_override) { + std::cout << "aten::avg_pool3d_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::avg_pool3d_out(mlirtens[0], mlirtens[1], kernel_size, stride, padding, + ceil_mode, count_include_pad, divisor_override); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::avg_pool3d( + const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, + c10::optional divisor_override) { + std::cout << "aten::avg_pool3d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::avg_pool3d(mlirtens[0], kernel_size, stride, padding, ceil_mode, + count_include_pad, divisor_override); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::avg_pool3d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const 
at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, + c10::optional divisor_override) { + std::cout << "aten::avg_pool3d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::avg_pool3d_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, stride, padding, + ceil_mode, count_include_pad, divisor_override); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::avg_pool3d_backward( + const at::Tensor &grad_output, const at::Tensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, + c10::optional divisor_override) { + std::cout << "aten::avg_pool3d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::avg_pool3d_backward( + mlirtens[0], mlirtens[1], kernel_size, stride, padding, ceil_mode, + count_include_pad, divisor_override); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +std::tuple +ATenMLIRTypeDefault::fractional_max_pool2d_out( + at::Tensor &output, at::Tensor &indices, const at::Tensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef output_size, + const at::Tensor &random_samples) { + std::cout << "aten::fractional_max_pool2d_out" << std::endl; + std::vector mlirtens_tensors = {output, indices, self, + random_samples}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::fractional_max_pool2d_out(mlirtens[0], mlirtens[1], mlirtens[2], + kernel_size, output_size, mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(output, indices); +} + +std::tuple ATenMLIRTypeDefault::fractional_max_pool2d( + const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef output_size, const at::Tensor &random_samples) { + std::cout << "aten::fractional_max_pool2d" << std::endl; + std::vector mlirtens_tensors = {self, random_samples}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::fractional_max_pool2d(mlirtens[0], kernel_size, + output_size, mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor &ATenMLIRTypeDefault::fractional_max_pool2d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef output_size, const at::Tensor &indices) { + std::cout << "aten::fractional_max_pool2d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self, + indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::fractional_max_pool2d_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, output_size, + mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::fractional_max_pool2d_backward( + const at::Tensor &grad_output, const at::Tensor 
&self, + at::IntArrayRef kernel_size, at::IntArrayRef output_size, + const at::Tensor &indices) { + std::cout << "aten::fractional_max_pool2d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::fractional_max_pool2d_backward( + mlirtens[0], mlirtens[1], kernel_size, output_size, mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +std::tuple +ATenMLIRTypeDefault::fractional_max_pool3d_out( + at::Tensor &output, at::Tensor &indices, const at::Tensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef output_size, + const at::Tensor &random_samples) { + std::cout << "aten::fractional_max_pool3d_out" << std::endl; + std::vector mlirtens_tensors = {output, indices, self, + random_samples}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::fractional_max_pool3d_out(mlirtens[0], mlirtens[1], mlirtens[2], + kernel_size, output_size, mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(output, indices); +} + +std::tuple ATenMLIRTypeDefault::fractional_max_pool3d( + const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef output_size, const at::Tensor &random_samples) { + std::cout << "aten::fractional_max_pool3d" << std::endl; + std::vector mlirtens_tensors = {self, random_samples}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::fractional_max_pool3d(mlirtens[0], kernel_size, + output_size, mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor &ATenMLIRTypeDefault::fractional_max_pool3d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef output_size, const at::Tensor &indices) { + std::cout << "aten::fractional_max_pool3d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self, + indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::fractional_max_pool3d_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, output_size, + mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::fractional_max_pool3d_backward( + const at::Tensor &grad_output, const at::Tensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef output_size, + const at::Tensor &indices) { + std::cout << "aten::fractional_max_pool3d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::fractional_max_pool3d_backward( + mlirtens[0], mlirtens[1], kernel_size, output_size, mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +std::tuple +ATenMLIRTypeDefault::max_pool2d_with_indices_out( + at::Tensor &out, at::Tensor &indices, const at::Tensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool 
ceil_mode) { + std::cout << "aten::max_pool2d_with_indices_out" << std::endl; + std::vector mlirtens_tensors = {out, indices, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_pool2d_with_indices_out( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, stride, padding, + dilation, ceil_mode); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(out, indices); +} + +std::tuple ATenMLIRTypeDefault::max_pool2d_with_indices( + const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + std::cout << "aten::max_pool2d_with_indices" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_pool2d_with_indices( + mlirtens[0], kernel_size, stride, padding, dilation, ceil_mode); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor &ATenMLIRTypeDefault::max_pool2d_with_indices_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, + const at::Tensor &indices) { + std::cout << "aten::max_pool2d_with_indices_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self, + indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_pool2d_with_indices_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, stride, padding, + dilation, ceil_mode, mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::max_pool2d_with_indices_backward( + const at::Tensor &grad_output, const at::Tensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, + const at::Tensor &indices) { + std::cout << "aten::max_pool2d_with_indices_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_pool2d_with_indices_backward( + mlirtens[0], mlirtens[1], kernel_size, stride, padding, dilation, + ceil_mode, mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +std::tuple +ATenMLIRTypeDefault::max_pool3d_with_indices_out( + at::Tensor &out, at::Tensor &indices, const at::Tensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + std::cout << "aten::max_pool3d_with_indices_out" << std::endl; + std::vector mlirtens_tensors = {out, indices, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_pool3d_with_indices_out( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, stride, padding, + dilation, ceil_mode); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(out, indices); +} + +std::tuple ATenMLIRTypeDefault::max_pool3d_with_indices( + const at::Tensor &self, at::IntArrayRef kernel_size, 
at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + std::cout << "aten::max_pool3d_with_indices" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_pool3d_with_indices( + mlirtens[0], kernel_size, stride, padding, dilation, ceil_mode); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self))); +} + +at::Tensor &ATenMLIRTypeDefault::max_pool3d_with_indices_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, + const at::Tensor &indices) { + std::cout << "aten::max_pool3d_with_indices_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self, + indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_pool3d_with_indices_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, stride, padding, + dilation, ceil_mode, mlirtens[3]); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::max_pool3d_with_indices_backward( + const at::Tensor &grad_output, const at::Tensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, + const at::Tensor &indices) { + std::cout << "aten::max_pool3d_with_indices_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_pool3d_with_indices_backward( + mlirtens[0], mlirtens[1], kernel_size, stride, padding, dilation, + ceil_mode, mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::max_unpool2d_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &indices, + at::IntArrayRef output_size) { + std::cout << "aten::max_unpool2d_out" << std::endl; + std::vector mlirtens_tensors = {out, self, indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::max_unpool2d_out(mlirtens[0], mlirtens[1], mlirtens[2], output_size); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::max_unpool2d(const at::Tensor &self, + const at::Tensor &indices, + at::IntArrayRef output_size) { + std::cout << "aten::max_unpool2d" << std::endl; + std::vector mlirtens_tensors = {self, indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_unpool2d(mlirtens[0], mlirtens[1], output_size); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::max_unpool2d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &indices, + at::IntArrayRef output_size) { + std::cout << "aten::max_unpool2d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self, + indices}; + 
auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_unpool2d_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], output_size); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::max_unpool2d_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &indices, at::IntArrayRef output_size) { + std::cout << "aten::max_unpool2d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_unpool2d_backward(mlirtens[0], mlirtens[1], + mlirtens[2], output_size); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::max_unpool3d_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &indices, + at::IntArrayRef output_size, + at::IntArrayRef stride, + at::IntArrayRef padding) { + std::cout << "aten::max_unpool3d_out" << std::endl; + std::vector mlirtens_tensors = {out, self, indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_unpool3d_out(mlirtens[0], mlirtens[1], mlirtens[2], + output_size, stride, padding); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::max_unpool3d(const at::Tensor &self, + const at::Tensor &indices, + at::IntArrayRef output_size, + at::IntArrayRef stride, + at::IntArrayRef padding) { + std::cout << "aten::max_unpool3d" << std::endl; + std::vector mlirtens_tensors = {self, indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::max_unpool3d(mlirtens[0], mlirtens[1], output_size, stride, padding); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::max_unpool3d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &indices, + at::IntArrayRef output_size, at::IntArrayRef stride, + at::IntArrayRef padding) { + std::cout << "aten::max_unpool3d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self, + indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::max_unpool3d_backward_out(mlirtens[0], mlirtens[1], mlirtens[2], + mlirtens[3], output_size, stride, padding); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::max_unpool3d_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &indices, at::IntArrayRef output_size, + at::IntArrayRef stride, at::IntArrayRef padding) { + std::cout << "aten::max_unpool3d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, indices}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::max_unpool3d_backward( + mlirtens[0], mlirtens[1], mlirtens[2], output_size, stride, padding); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::reflection_pad1d_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef padding) { + std::cout << 
"aten::reflection_pad1d_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::reflection_pad1d_out(mlirtens[0], mlirtens[1], padding); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::reflection_pad1d(const at::Tensor &self, + at::IntArrayRef padding) { + std::cout << "aten::reflection_pad1d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::reflection_pad1d(mlirtens[0], padding); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::reflection_pad1d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef padding) { + std::cout << "aten::reflection_pad1d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::reflection_pad1d_backward_out(mlirtens[0], mlirtens[1], + mlirtens[2], padding); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor +ATenMLIRTypeDefault::reflection_pad1d_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::IntArrayRef padding) { + std::cout << "aten::reflection_pad1d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::reflection_pad1d_backward(mlirtens[0], mlirtens[1], padding); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::reflection_pad2d_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef padding) { + std::cout << "aten::reflection_pad2d_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::reflection_pad2d_out(mlirtens[0], mlirtens[1], padding); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::reflection_pad2d(const at::Tensor &self, + at::IntArrayRef padding) { + std::cout << "aten::reflection_pad2d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::reflection_pad2d(mlirtens[0], padding); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::reflection_pad2d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef padding) { + std::cout << "aten::reflection_pad2d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::reflection_pad2d_backward_out(mlirtens[0], mlirtens[1], + mlirtens[2], padding); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor +ATenMLIRTypeDefault::reflection_pad2d_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::IntArrayRef padding) { + 
std::cout << "aten::reflection_pad2d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::reflection_pad2d_backward(mlirtens[0], mlirtens[1], padding); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::replication_pad1d_out( + at::Tensor &out, const at::Tensor &self, at::IntArrayRef padding) { + std::cout << "aten::replication_pad1d_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::replication_pad1d_out(mlirtens[0], mlirtens[1], padding); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::replication_pad1d(const at::Tensor &self, + at::IntArrayRef padding) { + std::cout << "aten::replication_pad1d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::replication_pad1d(mlirtens[0], padding); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::replication_pad1d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef padding) { + std::cout << "aten::replication_pad1d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::replication_pad1d_backward_out(mlirtens[0], mlirtens[1], + mlirtens[2], padding); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor +ATenMLIRTypeDefault::replication_pad1d_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::IntArrayRef padding) { + std::cout << "aten::replication_pad1d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::replication_pad1d_backward(mlirtens[0], mlirtens[1], padding); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::replication_pad2d_out( + at::Tensor &out, const at::Tensor &self, at::IntArrayRef padding) { + std::cout << "aten::replication_pad2d_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::replication_pad2d_out(mlirtens[0], mlirtens[1], padding); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::replication_pad2d(const at::Tensor &self, + at::IntArrayRef padding) { + std::cout << "aten::replication_pad2d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::replication_pad2d(mlirtens[0], padding); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::replication_pad2d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor 
&self, at::IntArrayRef padding) { + std::cout << "aten::replication_pad2d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::replication_pad2d_backward_out(mlirtens[0], mlirtens[1], + mlirtens[2], padding); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor +ATenMLIRTypeDefault::replication_pad2d_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::IntArrayRef padding) { + std::cout << "aten::replication_pad2d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::replication_pad2d_backward(mlirtens[0], mlirtens[1], padding); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::replication_pad3d_out( + at::Tensor &out, const at::Tensor &self, at::IntArrayRef padding) { + std::cout << "aten::replication_pad3d_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::replication_pad3d_out(mlirtens[0], mlirtens[1], padding); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::replication_pad3d(const at::Tensor &self, + at::IntArrayRef padding) { + std::cout << "aten::replication_pad3d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::replication_pad3d(mlirtens[0], padding); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::replication_pad3d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef padding) { + std::cout << "aten::replication_pad3d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::replication_pad3d_backward_out(mlirtens[0], mlirtens[1], + mlirtens[2], padding); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor +ATenMLIRTypeDefault::replication_pad3d_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::IntArrayRef padding) { + std::cout << "aten::replication_pad3d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::replication_pad3d_backward(mlirtens[0], mlirtens[1], padding); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::upsample_linear1d_out( + at::Tensor &out, const at::Tensor &self, at::IntArrayRef output_size, + bool align_corners) { + std::cout << "aten::upsample_linear1d_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::upsample_linear1d_out(mlirtens[0], mlirtens[1], + output_size, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return 
out; +} + +at::Tensor ATenMLIRTypeDefault::upsample_linear1d(const at::Tensor &self, + at::IntArrayRef output_size, + bool align_corners) { + std::cout << "aten::upsample_linear1d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::upsample_linear1d(mlirtens[0], output_size, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::upsample_linear1d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + at::IntArrayRef output_size, at::IntArrayRef input_size, + bool align_corners) { + std::cout << "aten::upsample_linear1d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::upsample_linear1d_backward_out( + mlirtens[0], mlirtens[1], output_size, input_size, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::upsample_linear1d_backward( + const at::Tensor &grad_output, at::IntArrayRef output_size, + at::IntArrayRef input_size, bool align_corners) { + std::cout << "aten::upsample_linear1d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::upsample_linear1d_backward(mlirtens[0], output_size, + input_size, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::upsample_bilinear2d_out( + at::Tensor &out, const at::Tensor &self, at::IntArrayRef output_size, + bool align_corners) { + std::cout << "aten::upsample_bilinear2d_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::upsample_bilinear2d_out(mlirtens[0], mlirtens[1], + output_size, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::upsample_bilinear2d(const at::Tensor &self, + at::IntArrayRef output_size, + bool align_corners) { + std::cout << "aten::upsample_bilinear2d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::upsample_bilinear2d(mlirtens[0], output_size, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::upsample_bilinear2d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + at::IntArrayRef output_size, at::IntArrayRef input_size, + bool align_corners) { + std::cout << "aten::upsample_bilinear2d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::upsample_bilinear2d_backward_out( + mlirtens[0], mlirtens[1], output_size, input_size, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::upsample_bilinear2d_backward( + const at::Tensor &grad_output, at::IntArrayRef output_size, + at::IntArrayRef input_size, 
bool align_corners) { + std::cout << "aten::upsample_bilinear2d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::upsample_bilinear2d_backward(mlirtens[0], output_size, + input_size, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::upsample_bicubic2d_out( + at::Tensor &out, const at::Tensor &self, at::IntArrayRef output_size, + bool align_corners) { + std::cout << "aten::upsample_bicubic2d_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::upsample_bicubic2d_out(mlirtens[0], mlirtens[1], + output_size, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::upsample_bicubic2d(const at::Tensor &self, + at::IntArrayRef output_size, + bool align_corners) { + std::cout << "aten::upsample_bicubic2d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::upsample_bicubic2d(mlirtens[0], output_size, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::upsample_bicubic2d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + at::IntArrayRef output_size, at::IntArrayRef input_size, + bool align_corners) { + std::cout << "aten::upsample_bicubic2d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::upsample_bicubic2d_backward_out( + mlirtens[0], mlirtens[1], output_size, input_size, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::upsample_bicubic2d_backward( + const at::Tensor &grad_output, at::IntArrayRef output_size, + at::IntArrayRef input_size, bool align_corners) { + std::cout << "aten::upsample_bicubic2d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::upsample_bicubic2d_backward(mlirtens[0], output_size, + input_size, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::upsample_trilinear3d_out( + at::Tensor &out, const at::Tensor &self, at::IntArrayRef output_size, + bool align_corners) { + std::cout << "aten::upsample_trilinear3d_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::upsample_trilinear3d_out(mlirtens[0], mlirtens[1], + output_size, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::upsample_trilinear3d( + const at::Tensor &self, at::IntArrayRef output_size, bool align_corners) { + std::cout << "aten::upsample_trilinear3d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + 
at::upsample_trilinear3d(mlirtens[0], output_size, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::upsample_trilinear3d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + at::IntArrayRef output_size, at::IntArrayRef input_size, + bool align_corners) { + std::cout << "aten::upsample_trilinear3d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::upsample_trilinear3d_backward_out( + mlirtens[0], mlirtens[1], output_size, input_size, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::upsample_trilinear3d_backward( + const at::Tensor &grad_output, at::IntArrayRef output_size, + at::IntArrayRef input_size, bool align_corners) { + std::cout << "aten::upsample_trilinear3d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::upsample_trilinear3d_backward( + mlirtens[0], output_size, input_size, align_corners); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::upsample_nearest1d_out( + at::Tensor &out, const at::Tensor &self, at::IntArrayRef output_size) { + std::cout << "aten::upsample_nearest1d_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::upsample_nearest1d_out(mlirtens[0], mlirtens[1], output_size); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor +ATenMLIRTypeDefault::upsample_nearest1d(const at::Tensor &self, + at::IntArrayRef output_size) { + std::cout << "aten::upsample_nearest1d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::upsample_nearest1d(mlirtens[0], output_size); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::upsample_nearest1d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + at::IntArrayRef output_size, at::IntArrayRef input_size) { + std::cout << "aten::upsample_nearest1d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::upsample_nearest1d_backward_out( + mlirtens[0], mlirtens[1], output_size, input_size); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor +ATenMLIRTypeDefault::upsample_nearest1d_backward(const at::Tensor &grad_output, + at::IntArrayRef output_size, + at::IntArrayRef input_size) { + std::cout << "aten::upsample_nearest1d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::upsample_nearest1d_backward(mlirtens[0], output_size, input_size); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + 
+at::Tensor &ATenMLIRTypeDefault::upsample_nearest2d_out( + at::Tensor &out, const at::Tensor &self, at::IntArrayRef output_size) { + std::cout << "aten::upsample_nearest2d_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::upsample_nearest2d_out(mlirtens[0], mlirtens[1], output_size); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor +ATenMLIRTypeDefault::upsample_nearest2d(const at::Tensor &self, + at::IntArrayRef output_size) { + std::cout << "aten::upsample_nearest2d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::upsample_nearest2d(mlirtens[0], output_size); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::upsample_nearest2d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + at::IntArrayRef output_size, at::IntArrayRef input_size) { + std::cout << "aten::upsample_nearest2d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::upsample_nearest2d_backward_out( + mlirtens[0], mlirtens[1], output_size, input_size); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor +ATenMLIRTypeDefault::upsample_nearest2d_backward(const at::Tensor &grad_output, + at::IntArrayRef output_size, + at::IntArrayRef input_size) { + std::cout << "aten::upsample_nearest2d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::upsample_nearest2d_backward(mlirtens[0], output_size, input_size); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::upsample_nearest3d_out( + at::Tensor &out, const at::Tensor &self, at::IntArrayRef output_size) { + std::cout << "aten::upsample_nearest3d_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::upsample_nearest3d_out(mlirtens[0], mlirtens[1], output_size); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor +ATenMLIRTypeDefault::upsample_nearest3d(const at::Tensor &self, + at::IntArrayRef output_size) { + std::cout << "aten::upsample_nearest3d" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::upsample_nearest3d(mlirtens[0], output_size); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::upsample_nearest3d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + at::IntArrayRef output_size, at::IntArrayRef input_size) { + std::cout << "aten::upsample_nearest3d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::upsample_nearest3d_backward_out( + mlirtens[0], mlirtens[1], output_size, input_size); + 
static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor +ATenMLIRTypeDefault::upsample_nearest3d_backward(const at::Tensor &grad_output, + at::IntArrayRef output_size, + at::IntArrayRef input_size) { + std::cout << "aten::upsample_nearest3d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::upsample_nearest3d_backward(mlirtens[0], output_size, input_size); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor & +ATenMLIRTypeDefault::sigmoid_backward_out(at::Tensor &grad_input, + const at::Tensor &grad_output, + const at::Tensor &output) { + std::cout << "aten::sigmoid_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::sigmoid_backward_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::sigmoid_backward(const at::Tensor &grad_output, + const at::Tensor &output) { + std::cout << "aten::sigmoid_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::sigmoid_backward(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor & +ATenMLIRTypeDefault::tanh_backward_out(at::Tensor &grad_input, + const at::Tensor &grad_output, + const at::Tensor &output) { + std::cout << "aten::tanh_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output, output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::tanh_backward_out(mlirtens[0], mlirtens[1], mlirtens[2]); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::tanh_backward(const at::Tensor &grad_output, + const at::Tensor &output) { + std::cout << "aten::tanh_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::tanh_backward(mlirtens[0], mlirtens[1]); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::slow_conv_transpose2d_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef output_padding, + at::IntArrayRef dilation) { + std::cout << "aten::slow_conv_transpose2d_out" << std::endl; + std::vector mlirtens_tensors = {out, self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::slow_conv_transpose2d_out( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, mlirtens[3], stride, + padding, output_padding, dilation); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::slow_conv_transpose2d( + const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, 
at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef output_padding, + at::IntArrayRef dilation) { + std::cout << "aten::slow_conv_transpose2d" << std::endl; + std::vector mlirtens_tensors = {self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::slow_conv_transpose2d( + mlirtens[0], mlirtens[1], kernel_size, mlirtens[2], stride, padding, + output_padding, dilation); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::slow_conv_transpose2d_backward_out( + at::Tensor &grad_input, at::Tensor &grad_weight, at::Tensor &grad_bias, + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef output_padding, at::IntArrayRef dilation, + const at::Tensor &columns, const at::Tensor &ones) { + std::cout << "aten::slow_conv_transpose2d_backward_out" << std::endl; + std::vector mlirtens_tensors = { + grad_input, grad_weight, grad_bias, grad_output, + self, weight, columns, ones}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::slow_conv_transpose2d_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + mlirtens[5], kernel_size, stride, padding, output_padding, dilation, + mlirtens[6], mlirtens[7]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + grad_input, grad_weight, grad_bias); +} + +std::tuple +ATenMLIRTypeDefault::slow_conv_transpose2d_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef output_padding, at::IntArrayRef dilation, + const at::Tensor &columns, const at::Tensor &ones, + std::array output_mask) { + std::cout << "aten::slow_conv_transpose2d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, weight, + columns, ones}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::slow_conv_transpose2d_backward( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, stride, padding, + output_padding, dilation, mlirtens[3], mlirtens[4], output_mask); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(grad_output))); +} + +at::Tensor &ATenMLIRTypeDefault::slow_conv_transpose3d_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef output_padding, + at::IntArrayRef dilation) { + std::cout << "aten::slow_conv_transpose3d_out" << std::endl; + std::vector mlirtens_tensors = {out, self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::slow_conv_transpose3d_out( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, mlirtens[3], stride, + padding, output_padding, dilation); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::slow_conv_transpose3d( + const at::Tensor 
&self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef output_padding, + at::IntArrayRef dilation) { + std::cout << "aten::slow_conv_transpose3d" << std::endl; + std::vector mlirtens_tensors = {self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::slow_conv_transpose3d( + mlirtens[0], mlirtens[1], kernel_size, mlirtens[2], stride, padding, + output_padding, dilation); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::slow_conv_transpose3d_backward_out( + at::Tensor &grad_input, at::Tensor &grad_weight, at::Tensor &grad_bias, + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef output_padding, at::IntArrayRef dilation, + const at::Tensor &finput, const at::Tensor &fgrad_input) { + std::cout << "aten::slow_conv_transpose3d_backward_out" << std::endl; + std::vector mlirtens_tensors = { + grad_input, grad_weight, grad_bias, grad_output, + self, weight, finput, fgrad_input}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::slow_conv_transpose3d_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + mlirtens[5], kernel_size, stride, padding, output_padding, dilation, + mlirtens[6], mlirtens[7]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + grad_input, grad_weight, grad_bias); +} + +std::tuple +ATenMLIRTypeDefault::slow_conv_transpose3d_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef output_padding, at::IntArrayRef dilation, + const at::Tensor &finput, const at::Tensor &fgrad_input, + std::array output_mask) { + std::cout << "aten::slow_conv_transpose3d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, weight, finput, + fgrad_input}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::slow_conv_transpose3d_backward( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, stride, padding, + output_padding, dilation, mlirtens[3], mlirtens[4], output_mask); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(grad_output))); +} + +at::Tensor &ATenMLIRTypeDefault::thnn_conv2d_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding) { + std::cout << "aten::thnn_conv2d_out" << std::endl; + std::vector mlirtens_tensors = {out, self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::thnn_conv2d_out(mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, + mlirtens[3], stride, padding); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::thnn_conv2d(const at::Tensor &self, + const 
at::Tensor &weight, + at::IntArrayRef kernel_size, + const at::Tensor &bias, + at::IntArrayRef stride, + at::IntArrayRef padding) { + std::cout << "aten::thnn_conv2d" << std::endl; + std::vector mlirtens_tensors = {self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::thnn_conv2d(mlirtens[0], mlirtens[1], kernel_size, + mlirtens[2], stride, padding); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::thnn_conv2d_forward_out( + at::Tensor &output, at::Tensor &finput, at::Tensor &fgrad_input, + const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding) { + std::cout << "aten::thnn_conv2d_forward_out" << std::endl; + std::vector mlirtens_tensors = {output, finput, fgrad_input, + self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::thnn_conv2d_forward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + kernel_size, mlirtens[5], stride, padding); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(output, finput, + fgrad_input); +} + +std::tuple +ATenMLIRTypeDefault::thnn_conv2d_forward(const at::Tensor &self, + const at::Tensor &weight, + at::IntArrayRef kernel_size, + const at::Tensor &bias, + at::IntArrayRef stride, + at::IntArrayRef padding) { + std::cout << "aten::thnn_conv2d_forward" << std::endl; + std::vector mlirtens_tensors = {self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::thnn_conv2d_forward( + mlirtens[0], mlirtens[1], kernel_size, mlirtens[2], stride, padding); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::thnn_conv2d_backward_out( + at::Tensor &grad_input, at::Tensor &grad_weight, at::Tensor &grad_bias, + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, const at::Tensor &finput, + const at::Tensor &fgrad_input) { + std::cout << "aten::thnn_conv2d_backward_out" << std::endl; + std::vector mlirtens_tensors = { + grad_input, grad_weight, grad_bias, grad_output, + self, weight, finput, fgrad_input}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::thnn_conv2d_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + mlirtens[5], kernel_size, stride, padding, mlirtens[6], mlirtens[7]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + grad_input, grad_weight, grad_bias); +} + +std::tuple +ATenMLIRTypeDefault::thnn_conv2d_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, const at::Tensor &finput, + const at::Tensor &fgrad_input, std::array output_mask) { + std::cout << "aten::thnn_conv2d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, weight, finput, + fgrad_input}; + auto 
mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::thnn_conv2d_backward( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, stride, padding, + mlirtens[3], mlirtens[4], output_mask); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(grad_output))); +} + +at::Tensor &ATenMLIRTypeDefault::thnn_conv_depthwise2d_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation) { + std::cout << "aten::thnn_conv_depthwise2d_out" << std::endl; + std::vector mlirtens_tensors = {out, self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::thnn_conv_depthwise2d_out( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, mlirtens[3], stride, + padding, dilation); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::thnn_conv_depthwise2d( + const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation) { + std::cout << "aten::thnn_conv_depthwise2d" << std::endl; + std::vector mlirtens_tensors = {self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::thnn_conv_depthwise2d(mlirtens[0], mlirtens[1], kernel_size, + mlirtens[2], stride, padding, dilation); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::thnn_conv_depthwise2d_forward_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation) { + std::cout << "aten::thnn_conv_depthwise2d_forward_out" << std::endl; + std::vector mlirtens_tensors = {out, self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::thnn_conv_depthwise2d_forward_out( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, mlirtens[3], stride, + padding, dilation); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::thnn_conv_depthwise2d_forward( + const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation) { + std::cout << "aten::thnn_conv_depthwise2d_forward" << std::endl; + std::vector mlirtens_tensors = {self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::thnn_conv_depthwise2d_forward(mlirtens[0], mlirtens[1], kernel_size, + mlirtens[2], stride, padding, dilation); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::thnn_conv_depthwise2d_backward_out( + at::Tensor &grad_input, at::Tensor &grad_weight, + const at::Tensor &grad_output, const at::Tensor &self, + const 
at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation) { + std::cout << "aten::thnn_conv_depthwise2d_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_weight, + grad_output, self, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::thnn_conv_depthwise2d_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + kernel_size, stride, padding, dilation); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(grad_input, grad_weight); +} + +std::tuple +ATenMLIRTypeDefault::thnn_conv_depthwise2d_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + std::array output_mask) { + std::cout << "aten::thnn_conv_depthwise2d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::thnn_conv_depthwise2d_backward( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, stride, padding, + dilation, output_mask); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(grad_output))); +} + +at::Tensor &ATenMLIRTypeDefault::thnn_conv3d_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding) { + std::cout << "aten::thnn_conv3d_out" << std::endl; + std::vector mlirtens_tensors = {out, self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::thnn_conv3d_out(mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, + mlirtens[3], stride, padding); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::thnn_conv3d(const at::Tensor &self, + const at::Tensor &weight, + at::IntArrayRef kernel_size, + const at::Tensor &bias, + at::IntArrayRef stride, + at::IntArrayRef padding) { + std::cout << "aten::thnn_conv3d" << std::endl; + std::vector mlirtens_tensors = {self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::thnn_conv3d(mlirtens[0], mlirtens[1], kernel_size, + mlirtens[2], stride, padding); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::thnn_conv3d_forward_out( + at::Tensor &output, at::Tensor &finput, at::Tensor &fgrad_input, + const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding) { + std::cout << "aten::thnn_conv3d_forward_out" << std::endl; + std::vector mlirtens_tensors = {output, finput, fgrad_input, + self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::thnn_conv3d_forward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + kernel_size, mlirtens[5], stride, padding); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple(output, finput, + fgrad_input); +} + 
+std::tuple +ATenMLIRTypeDefault::thnn_conv3d_forward(const at::Tensor &self, + const at::Tensor &weight, + at::IntArrayRef kernel_size, + const at::Tensor &bias, + at::IntArrayRef stride, + at::IntArrayRef padding) { + std::cout << "aten::thnn_conv3d_forward" << std::endl; + std::vector mlirtens_tensors = {self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::thnn_conv3d_forward( + mlirtens[0], mlirtens[1], kernel_size, mlirtens[2], stride, padding); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(self)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(self))); +} + +std::tuple +ATenMLIRTypeDefault::thnn_conv3d_backward_out( + at::Tensor &grad_input, at::Tensor &grad_weight, at::Tensor &grad_bias, + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, const at::Tensor &finput, + const at::Tensor &fgrad_input) { + std::cout << "aten::thnn_conv3d_backward_out" << std::endl; + std::vector mlirtens_tensors = { + grad_input, grad_weight, grad_bias, grad_output, + self, weight, finput, fgrad_input}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::thnn_conv3d_backward_out( + mlirtens[0], mlirtens[1], mlirtens[2], mlirtens[3], mlirtens[4], + mlirtens[5], kernel_size, stride, padding, mlirtens[6], mlirtens[7]); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + grad_input, grad_weight, grad_bias); +} + +std::tuple +ATenMLIRTypeDefault::thnn_conv3d_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, const at::Tensor &finput, + const at::Tensor &fgrad_input, std::array output_mask) { + std::cout << "aten::thnn_conv3d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, weight, finput, + fgrad_input}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::thnn_conv3d_backward( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, stride, padding, + mlirtens[3], mlirtens[4], output_mask); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(grad_output))); +} + +at::Tensor ATenMLIRTypeDefault::slow_conv_dilated2d( + const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation) { + std::cout << "aten::slow_conv_dilated2d" << std::endl; + std::vector mlirtens_tensors = {self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::slow_conv_dilated2d(mlirtens[0], mlirtens[1], kernel_size, + mlirtens[2], stride, padding, dilation); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::slow_conv_dilated2d_backward( + 
const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + std::array output_mask) { + std::cout << "aten::slow_conv_dilated2d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::slow_conv_dilated2d_backward( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, stride, padding, + dilation, output_mask); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(grad_output))); +} + +at::Tensor ATenMLIRTypeDefault::slow_conv_dilated3d( + const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation) { + std::cout << "aten::slow_conv_dilated3d" << std::endl; + std::vector mlirtens_tensors = {self, weight, bias}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::slow_conv_dilated3d(mlirtens[0], mlirtens[1], kernel_size, + mlirtens[2], stride, padding, dilation); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +std::tuple +ATenMLIRTypeDefault::slow_conv_dilated3d_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + std::array output_mask) { + std::cout << "aten::slow_conv_dilated3d_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output, self, weight}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::slow_conv_dilated3d_backward( + mlirtens[0], mlirtens[1], mlirtens[2], kernel_size, stride, padding, + dilation, output_mask); + static_cast(x_result); // Avoid warnings in case not used + return std::tuple( + bridge::CreateMLIRTensor(std::get<0>(x_result), + bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<1>(x_result), + bridge::GetMLIRDevice(grad_output)), + bridge::CreateMLIRTensor(std::get<2>(x_result), + bridge::GetMLIRDevice(grad_output))); +} + +at::Tensor &ATenMLIRTypeDefault::col2im_out( + at::Tensor &out, const at::Tensor &self, at::IntArrayRef output_size, + at::IntArrayRef kernel_size, at::IntArrayRef dilation, + at::IntArrayRef padding, at::IntArrayRef stride) { + std::cout << "aten::col2im_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::col2im_out(mlirtens[0], mlirtens[1], output_size, + kernel_size, dilation, padding, stride); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::col2im(const at::Tensor &self, + at::IntArrayRef output_size, + at::IntArrayRef kernel_size, + at::IntArrayRef dilation, + at::IntArrayRef padding, + at::IntArrayRef stride) { + std::cout << "aten::col2im" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = 
at::col2im(mlirtens[0], output_size, kernel_size, dilation, + padding, stride); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::col2im_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + at::IntArrayRef kernel_size, at::IntArrayRef dilation, + at::IntArrayRef padding, at::IntArrayRef stride) { + std::cout << "aten::col2im_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::col2im_backward_out( + mlirtens[0], mlirtens[1], kernel_size, dilation, padding, stride); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::col2im_backward(const at::Tensor &grad_output, + at::IntArrayRef kernel_size, + at::IntArrayRef dilation, + at::IntArrayRef padding, + at::IntArrayRef stride) { + std::cout << "aten::col2im_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::col2im_backward(mlirtens[0], kernel_size, dilation, padding, stride); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +at::Tensor &ATenMLIRTypeDefault::im2col_out( + at::Tensor &out, const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + std::cout << "aten::im2col_out" << std::endl; + std::vector mlirtens_tensors = {out, self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::im2col_out(mlirtens[0], mlirtens[1], kernel_size, + dilation, padding, stride); + static_cast(x_result); // Avoid warnings in case not used + return out; +} + +at::Tensor ATenMLIRTypeDefault::im2col(const at::Tensor &self, + at::IntArrayRef kernel_size, + at::IntArrayRef dilation, + at::IntArrayRef padding, + at::IntArrayRef stride) { + std::cout << "aten::im2col" << std::endl; + std::vector mlirtens_tensors = {self}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::im2col(mlirtens[0], kernel_size, dilation, padding, stride); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self)); +} + +at::Tensor &ATenMLIRTypeDefault::im2col_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + at::IntArrayRef input_size, at::IntArrayRef kernel_size, + at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + std::cout << "aten::im2col_backward_out" << std::endl; + std::vector mlirtens_tensors = {grad_input, grad_output}; + auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = + at::im2col_backward_out(mlirtens[0], mlirtens[1], input_size, kernel_size, + dilation, padding, stride); + static_cast(x_result); // Avoid warnings in case not used + return grad_input; +} + +at::Tensor ATenMLIRTypeDefault::im2col_backward(const at::Tensor &grad_output, + at::IntArrayRef input_size, + at::IntArrayRef kernel_size, + at::IntArrayRef dilation, + at::IntArrayRef padding, + at::IntArrayRef stride) { + std::cout << "aten::im2col_backward" << std::endl; + std::vector mlirtens_tensors = {grad_output}; + auto mlirtens = 
bridge::MLIRCreateTensorList(mlirtens_tensors); + auto &&x_result = at::im2col_backward(mlirtens[0], input_size, kernel_size, + dilation, padding, stride); + static_cast(x_result); // Avoid warnings in case not used + return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(grad_output)); +} + +void RegisterAtenTypeFunctions() { + static auto dispatch = + torch::RegisterOperators() + .op(torch::RegisterOperators::options() + .schema("aten::_cast_Byte(Tensor self, bool " + "non_blocking=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cast_Char(Tensor self, bool " + "non_blocking=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cast_Double(Tensor self, bool " + "non_blocking=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cast_Float(Tensor self, bool " + "non_blocking=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cast_Int(Tensor self, bool " + "non_blocking=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cast_Long(Tensor self, bool " + "non_blocking=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cast_Short(Tensor self, bool " + "non_blocking=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cast_Half(Tensor self, bool " + "non_blocking=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::backward(Tensor self, Tensor? gradient=None, bool " + "keep_graph=False, bool create_graph=False) -> void") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::set_data(Tensor(a!) self, Tensor new_data) -> " + "void") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::data(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_debug_has_internal_overlap(Tensor self) -> int") + .impl_unboxedOnlyKernel< + int64_t(const at::Tensor &), + &ATenMLIRTypeDefault::_debug_has_internal_overlap>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_fused_dropout(Tensor self, float p, " + "Generator? 
generator=None) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, double, + at::Generator *), + &ATenMLIRTypeDefault::_fused_dropout>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_masked_scale(Tensor self, Tensor mask, float " + "scale) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_sobol_engine_draw(Tensor quasi, int n, Tensor " + "sobolstate, int dimension, int num_generated, " + "ScalarType? dtype) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, + int64_t, const at::Tensor &, int64_t, int64_t, + c10::optional), + &ATenMLIRTypeDefault::_sobol_engine_draw>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_sobol_engine_ff_(Tensor(a!) self, int n, " + "Tensor sobolstate, int dimension, int " + "num_generated) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, int64_t, const at::Tensor &, + int64_t, int64_t), + &ATenMLIRTypeDefault::_sobol_engine_ff_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_sobol_engine_scramble_(Tensor(a!) self, " + "Tensor ltm, int dimension) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::_sobol_engine_scramble_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_sobol_engine_initialize_state_(Tensor(a!) " + "self, int dimension) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, int64_t), + &ATenMLIRTypeDefault::_sobol_engine_initialize_state_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_reshape_from_tensor(Tensor self, Tensor " + "shape) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::_reshape_from_tensor>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_shape_as_tensor(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &), + &ATenMLIRTypeDefault::_shape_as_tensor>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::dropout(Tensor input, float p, bool train) -> " + "Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::dropout_(Tensor(a!) 
self, float p, bool " + "train) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::feature_dropout(Tensor input, float p, bool " + "train) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, double, bool), + &ATenMLIRTypeDefault::feature_dropout>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::feature_dropout_(Tensor(a!) self, float p, " + "bool train) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, double, bool), + &ATenMLIRTypeDefault::feature_dropout_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::alpha_dropout(Tensor input, float p, bool " + "train) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::alpha_dropout_(Tensor(a!) self, float p, bool " + "train) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::feature_alpha_dropout(Tensor input, float p, " + "bool train) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, double, bool), + &ATenMLIRTypeDefault::feature_alpha_dropout>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::feature_alpha_dropout_(Tensor(a!) self, float " + "p, bool train) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, double, bool), + &ATenMLIRTypeDefault::feature_alpha_dropout_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::abs(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::abs_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::abs.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::acos(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::acos_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::acos.out(Tensor self, *, Tensor(a!) 
out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::avg_pool1d(Tensor self, int[1] kernel_size, " + "int[1] stride=[], int[1] padding=0, bool " + "ceil_mode=False, bool count_include_pad=True) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, bool, bool), + &ATenMLIRTypeDefault::avg_pool1d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::adaptive_avg_pool1d(Tensor self, int[1] " + "output_size) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::adaptive_avg_pool1d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::adaptive_max_pool1d(Tensor self, int[1] " + "output_size) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::adaptive_max_pool1d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::add.Tensor(Tensor self, Tensor other, *, " + "Scalar alpha=1) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::add_.Tensor(Tensor(a!) self, Tensor other, *, " + "Scalar alpha=1) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::add.out(Tensor self, Tensor other, *, Scalar " + "alpha=1, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::add.Scalar(Tensor self, Scalar other, Scalar " + "alpha=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::add>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::add_.Scalar(Tensor(a!) self, Scalar other, " + "Scalar alpha=1) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::add_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::addmv(Tensor self, Tensor mat, Tensor vec, *, " + "Scalar beta=1, Scalar alpha=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::addmv>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::addmv_(Tensor(a!) 
self, Tensor mat, Tensor vec, " + "*, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::addmv_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::addmv.out(Tensor self, Tensor mat, Tensor " + "vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) " + "out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::addmv_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::addr(Tensor self, Tensor vec1, Tensor vec2, " + "*, Scalar beta=1, Scalar alpha=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::addr>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, " + "*, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::addr_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::addr.out(Tensor self, Tensor vec1, Tensor " + "vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) " + "out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::addr_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::affine_grid_generator(Tensor theta, int[] " + "size, bool align_corners) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, bool), + &ATenMLIRTypeDefault::affine_grid_generator>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::affine_grid_generator_backward(Tensor grad, " + "int[] size, bool align_corners) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, bool), + &ATenMLIRTypeDefault::affine_grid_generator_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::all.dim(Tensor self, int dim, bool " + "keepdim=False) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t, bool), + &ATenMLIRTypeDefault::all>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::all.out(Tensor self, int dim, bool " + "keepdim=False, *, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::allclose(Tensor self, Tensor other, float " + "rtol=1e-05, float atol=1e-08, bool equal_nan=False) " + "-> bool") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::any.dim(Tensor self, int dim, bool " + "keepdim=False) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t, bool), + &ATenMLIRTypeDefault::any>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::any.out(Tensor self, int dim, bool " + "keepdim=False, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::arange(Scalar end, *, ScalarType? dtype=None, " + "Layout? layout=None, Device? device=None, bool? " + "pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::arange.start(Scalar start, Scalar end, *, " + "ScalarType? dtype=None, Layout? layout=None, Device? " + "device=None, bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::arange.start_step(Scalar start, Scalar end, " + "Scalar step, *, ScalarType? dtype=None, Layout? " + "layout=None, Device? device=None, bool? " + "pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::arange.out(Scalar end, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::arange.start_out(Scalar start, Scalar end, " + "Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_dim_arange(Tensor like, int dim) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::argmax(Tensor self, int? dim=None, bool " + "keepdim=False) -> Tensor") + .impl_unboxedOnlyKernel, + bool), + &ATenMLIRTypeDefault::argmax>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::argmin(Tensor self, int? dim=None, bool " + "keepdim=False) -> Tensor") + .impl_unboxedOnlyKernel, + bool), + &ATenMLIRTypeDefault::argmin>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::as_strided(Tensor(a) self, int[] size, int[] " + "stride, int? 
storage_offset=None) -> Tensor(a)") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, c10::optional), + &ATenMLIRType::as_strided>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::as_strided_(Tensor(a!) self, int[] size, int[] " + "stride, int? storage_offset=None) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, at::IntArrayRef, + at::IntArrayRef, c10::optional), + &ATenMLIRTypeDefault::as_strided_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::asin(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::asin_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::asin.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::atan(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::atan_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::atan.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::baddbmm(Tensor self, Tensor batch1, Tensor " + "batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::baddbmm>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::baddbmm_(Tensor(a!) self, Tensor batch1, Tensor " + "batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::baddbmm_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_baddbmm_mkl_(Tensor(a!) self, Tensor batch1, " + "Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::_baddbmm_mkl_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::baddbmm.out(Tensor self, Tensor batch1, " + "Tensor batch2, *, Scalar beta=1, Scalar alpha=1, " + "Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::baddbmm_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::bartlett_window(int window_length, *, ScalarType? " + "dtype=None, Layout? layout=None, Device? device=None, " + "bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(int64_t, const at::TensorOptions &), + &ATenMLIRTypeDefault::bartlett_window>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::bartlett_window.periodic(int window_length, " + "bool periodic, *, ScalarType? dtype=None, Layout? " + "layout=None, Device? device=None, bool? " + "pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(int64_t, bool, const at::TensorOptions &), + &ATenMLIRTypeDefault::bartlett_window>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::batch_norm(Tensor input, Tensor? weight, " + "Tensor? bias, Tensor? running_mean, Tensor? " + "running_var, bool training, float momentum, float " + "eps, bool cudnn_enabled) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, bool, double, double, + bool), + &ATenMLIRTypeDefault::batch_norm>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_batch_norm_impl_index(Tensor input, Tensor? " + "weight, Tensor? bias, Tensor? running_mean, Tensor? " + "running_var, bool training, float momentum, float eps, " + "bool cudnn_enabled) -> (Tensor, Tensor, Tensor, int)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, bool, double, double, bool), + &ATenMLIRTypeDefault::_batch_norm_impl_index>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_batch_norm_impl_index_backward(int " + "impl_index, Tensor input, Tensor grad_output, " + "Tensor? weight, Tensor? running_mean, Tensor? " + "running_var, Tensor? save_mean, Tensor? " + "save_var_transform, bool train, float eps, bool[3] " + "output_mask) -> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + int64_t, const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, bool, double, + std::array), + &ATenMLIRTypeDefault::_batch_norm_impl_index_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::bernoulli(Tensor self, *, Generator? " + "generator=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::bernoulli.out(Tensor self, *, Generator? " + "generator=None, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::bernoulli_.Tensor(Tensor(a!) self, Tensor p, " + "*, Generator? generator=None) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::bernoulli_.float(Tensor(a!) self, float " + "p=0.5, *, Generator? generator=None) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::bernoulli.p(Tensor self, float p, *, " + "Generator? generator=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::bilinear(Tensor input1, Tensor input2, Tensor " + "weight, Tensor? bias) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::bilinear>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::binary_cross_entropy_with_logits(Tensor self, " + "Tensor target, Tensor? weight=None, Tensor? " + "pos_weight=None, int reduction=Mean) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + int64_t), + &ATenMLIRTypeDefault::binary_cross_entropy_with_logits>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::binary_cross_entropy_with_logits_backward(" + "Tensor grad_output, Tensor self, Tensor target, " + "Tensor? weight=None, Tensor? pos_weight=None, int " + "reduction=Mean) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t), + &ATenMLIRTypeDefault:: + binary_cross_entropy_with_logits_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::bincount(Tensor self, Tensor? weights=None, " + "int minlength=0) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::bitwise_not(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::bitwise_not_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::bitwise_not.out(Tensor self, *, Tensor(a!) 
" + "out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::bitwise_not_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::logical_not(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::logical_not_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::logical_not.out(Tensor self, *, Tensor(a!) " + "out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::logical_not_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::logical_xor(Tensor self, Tensor other) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::logical_xor_(Tensor(a!) self, Tensor other) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::logical_xor.out(Tensor self, Tensor other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::logical_xor_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::blackman_window(int window_length, *, ScalarType? " + "dtype=None, Layout? layout=None, Device? device=None, " + "bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(int64_t, const at::TensorOptions &), + &ATenMLIRTypeDefault::blackman_window>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::blackman_window.periodic(int window_length, " + "bool periodic, *, ScalarType? dtype=None, Layout? " + "layout=None, Device? device=None, bool? " + "pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(int64_t, bool, const at::TensorOptions &), + &ATenMLIRTypeDefault::blackman_window>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::bmm(Tensor self, Tensor mat2) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::bmm>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::bmm.out(Tensor self, Tensor mat2, *, " + "Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::broadcast_tensors(Tensor[] tensors) -> Tensor[]") + .impl_unboxedOnlyKernel< + std::vector(at::TensorList), + &ATenMLIRTypeDefault::broadcast_tensors>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cat(Tensor[] tensors, int dim=0) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cat.out(Tensor[] tensors, int dim=0, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::ceil(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::ceil_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::ceil.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::chain_matmul(Tensor[] matrices) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::chunk(Tensor(a) self, int chunks, int dim=0) " + "-> Tensor(a)[]") + .impl_unboxedOnlyKernel( + const at::Tensor &, int64_t, + int64_t), + &ATenMLIRTypeDefault::chunk>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::clamp(Tensor self, Scalar? min=None, Scalar? " + "max=None) -> Tensor") + .impl_unboxedOnlyKernel, + c10::optional), + &ATenMLIRTypeDefault::clamp>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::clamp_(Tensor(a!) self, Scalar? min=None, " + "Scalar? max=None) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, c10::optional, + c10::optional), + &ATenMLIRTypeDefault::clamp_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::clamp.out(Tensor self, Scalar? min=None, " + "Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel, + c10::optional), + &ATenMLIRTypeDefault::clamp_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::clamp_max(Tensor self, Scalar max) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::clamp_max_(Tensor(a!) 
self, Scalar max) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::clamp_max.out(Tensor self, Scalar max, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::clamp_min(Tensor self, Scalar min) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::clamp_min_(Tensor(a!) self, Scalar min) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::clamp_min.out(Tensor self, Scalar min, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::constant_pad_nd(Tensor self, int[] pad, " + "Scalar value=0) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::Scalar), + &ATenMLIRTypeDefault::constant_pad_nd>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::contiguous(Tensor self, *, MemoryFormat " + "memory_format=contiguous_format) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::convolution(Tensor input, Tensor weight, Tensor? " + "bias, int[] stride, int[] padding, int[] dilation, bool " + "transposed, int[] output_padding, int groups) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, bool, + at::IntArrayRef, int64_t), + &ATenMLIRTypeDefault::convolution>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::convolution_overrideable(Tensor input, Tensor " + "weight, Tensor? 
bias, int[] stride, int[] padding, " + "int[] dilation, bool transposed, int[] " + "output_padding, int groups) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, bool, + at::IntArrayRef, int64_t), + &ATenMLIRType::convolution_overrideable>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::convolution_backward_overrideable(Tensor " + "grad_output, Tensor input, Tensor weight, int[] " + "stride, int[] padding, int[] dilation, bool " + "transposed, int[] output_padding, int groups, " + "bool[3] output_mask) -> (Tensor grad_input, Tensor " + "grad_weight, Tensor grad_bias)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, bool, at::IntArrayRef, int64_t, + std::array), + &ATenMLIRType::convolution_backward_overrideable>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_convolution(Tensor input, Tensor weight, " + "Tensor? bias, int[] stride, int[] padding, int[] " + "dilation, bool transposed, int[] output_padding, " + "int groups, bool benchmark, bool deterministic, " + "bool cudnn_enabled) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, bool, + at::IntArrayRef, int64_t, bool, bool, bool), + &ATenMLIRTypeDefault::_convolution>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_convolution_nogroup(Tensor input, Tensor " + "weight, Tensor? bias, int[] stride, int[] padding, " + "int[] dilation, bool transposed, int[] " + "output_padding) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, bool, + at::IntArrayRef), + &ATenMLIRTypeDefault::_convolution_nogroup>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_convolution_double_backward(Tensor? ggI, Tensor? " + "ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor " + "self, int[] stride, int[] padding, int[] dilation, bool " + "transposed, int[] output_padding, int groups, bool " + "benchmark, bool deterministic, bool cudnn_enabled, " + "bool[3] output_mask) -> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, + bool, at::IntArrayRef, int64_t, bool, bool, bool, + std::array), + &ATenMLIRTypeDefault::_convolution_double_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::conv1d(Tensor input, Tensor weight, Tensor? 
" + "bias=None, int[1] stride=1, int[1] padding=0, " + "int[1] dilation=1, int groups=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t), + &ATenMLIRTypeDefault::conv1d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::conv2d(Tensor input, Tensor weight, Tensor? " + "bias=None, int[2] stride=1, int[2] padding=0, " + "int[2] dilation=1, int groups=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t), + &ATenMLIRTypeDefault::conv2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::conv3d(Tensor input, Tensor weight, Tensor? " + "bias=None, int[3] stride=1, int[3] padding=0, " + "int[3] dilation=1, int groups=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t), + &ATenMLIRTypeDefault::conv3d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::conv_tbc(Tensor self, Tensor weight, Tensor " + "bias, int pad=0) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::conv_tbc>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::conv_tbc_backward(Tensor self, Tensor input, " + "Tensor weight, Tensor bias, int pad) -> (Tensor, " + "Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::conv_tbc_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::conv_transpose1d(Tensor input, Tensor weight, " + "Tensor? bias=None, int[1] stride=1, int[1] " + "padding=0, int[1] output_padding=0, int groups=1, " + "int[1] dilation=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t, + at::IntArrayRef), + &ATenMLIRTypeDefault::conv_transpose1d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::conv_transpose2d.input(Tensor input, Tensor " + "weight, Tensor? bias=None, int[2] stride=1, int[2] " + "padding=0, int[2] output_padding=0, int groups=1, " + "int[2] dilation=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t, + at::IntArrayRef), + &ATenMLIRTypeDefault::conv_transpose2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::conv_transpose3d.input(Tensor input, Tensor " + "weight, Tensor? 
bias=None, int[3] stride=1, int[3] " + "padding=0, int[3] output_padding=0, int groups=1, " + "int[3] dilation=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t, + at::IntArrayRef), + &ATenMLIRTypeDefault::conv_transpose3d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::copy_(Tensor(a!) self, Tensor src, bool " + "non_blocking=False) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, bool), + &ATenMLIRType::copy_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_copy_from(Tensor self, Tensor dst, bool " + "non_blocking=False) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, bool), + &ATenMLIRType::_copy_from>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cos(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cos_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cos.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cosh(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cosh_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cosh.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cosine_embedding_loss(Tensor input1, Tensor " + "input2, Tensor target, float margin=0.0, int " + "reduction=Mean) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, double, int64_t), + &ATenMLIRTypeDefault::cosine_embedding_loss>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cumsum(Tensor self, int dim, *, ScalarType? " + "dtype=None) -> Tensor") + .impl_unboxedOnlyKernel), + &ATenMLIRTypeDefault::cumsum>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::cumsum.out(Tensor self, int dim, *, ScalarType? " + "dtype=None, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, int64_t, + c10::optional), + &ATenMLIRTypeDefault::cumsum_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cumprod(Tensor self, int dim, *, ScalarType? " + "dtype=None) -> Tensor") + .impl_unboxedOnlyKernel), + &ATenMLIRTypeDefault::cumprod>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::cumprod.out(Tensor self, int dim, *, ScalarType? " + "dtype=None, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, int64_t, + c10::optional), + &ATenMLIRTypeDefault::cumprod_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::ctc_loss.IntList(Tensor log_probs, Tensor " + "targets, int[] input_lengths, int[] target_lengths, " + "int blank=0, int reduction=Mean, bool " + "zero_infinity=False) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, int64_t, + int64_t, bool), + &ATenMLIRTypeDefault::ctc_loss>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::ctc_loss.Tensor(Tensor log_probs, Tensor " + "targets, Tensor input_lengths, Tensor " + "target_lengths, int blank=0, int reduction=Mean, " + "bool zero_infinity=False) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + int64_t, int64_t, bool), + &ATenMLIRTypeDefault::ctc_loss>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_ctc_loss(Tensor log_probs, Tensor targets, int[] " + "input_lengths, int[] target_lengths, int blank=0, bool " + "zero_infinity=False) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + int64_t, bool), + &ATenMLIRTypeDefault::_ctc_loss>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_ctc_loss_backward(Tensor grad, Tensor " + "log_probs, Tensor targets, int[] input_lengths, " + "int[] target_lengths, Tensor neg_log_likelihood, " + "Tensor log_alpha, int blank, bool " + "zero_infinity=False) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, const at::Tensor &, + const at::Tensor &, int64_t, bool), + &ATenMLIRTypeDefault::_ctc_loss_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::det(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::diag_embed(Tensor self, int offset=0, int " + "dim1=-2, int dim2=-1) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::diagflat(Tensor 
self, int offset=0) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::diagonal(Tensor(a) self, int offset=0, int " + "dim1=0, int dim2=1) -> Tensor(a)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fill_diagonal_(Tensor(a!) self, Scalar " + "fill_value, bool wrap=False) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::div.Tensor(Tensor self, Tensor other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRType::div>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::div_.Tensor(Tensor(a!) self, Tensor other) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &), + &ATenMLIRType::div_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::div.out(Tensor self, Tensor other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::div.Scalar(Tensor self, Scalar other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Scalar), + &ATenMLIRType::div>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::div_.Scalar(Tensor(a!) self, Scalar other) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::div_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::dot(Tensor self, Tensor tensor) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::dot>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::dot.out(Tensor self, Tensor tensor, *, " + "Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::einsum(str equation, Tensor[] tensors) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::embedding(Tensor weight, Tensor indices, int " + "padding_idx=-1, bool scale_grad_by_freq=False, bool " + "sparse=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::embedding_backward(Tensor grad, Tensor " + "indices, int num_weights, int padding_idx, bool " + "scale_grad_by_freq, bool sparse) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + int64_t, int64_t, bool, bool), + &ATenMLIRTypeDefault::embedding_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::embedding_dense_backward(Tensor grad_output, " + "Tensor indices, int num_weights, int padding_idx, " + "bool scale_grad_by_freq) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + int64_t, int64_t, bool), + &ATenMLIRTypeDefault::embedding_dense_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::embedding_renorm_(Tensor(a!) self, Tensor " + "indices, float max_norm, float norm_type) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, double, + double), + &ATenMLIRTypeDefault::embedding_renorm_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::embedding_sparse_backward(Tensor grad, Tensor " + "indices, int num_weights, int padding_idx, bool " + "scale_grad_by_freq) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + int64_t, int64_t, bool), + &ATenMLIRTypeDefault::embedding_sparse_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::embedding_bag(Tensor weight, Tensor indices, " + "Tensor offsets, bool scale_grad_by_freq=False, int " + "mode=0, bool sparse=False, Tensor? " + "per_sample_weights=None) -> (Tensor, Tensor, " + "Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + const at::Tensor &, + const at::Tensor &, bool, int64_t, + bool, const at::Tensor &), + &ATenMLIRTypeDefault::embedding_bag>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_embedding_bag(Tensor weight, Tensor indices, " + "Tensor offsets, bool scale_grad_by_freq=False, int " + "mode=0, bool sparse=False, Tensor? 
" + "per_sample_weights=None) -> (Tensor, Tensor, " + "Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + const at::Tensor &, + const at::Tensor &, bool, int64_t, + bool, const at::Tensor &), + &ATenMLIRTypeDefault::_embedding_bag>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_embedding_bag_backward(Tensor grad, Tensor " + "indices, Tensor offsets, Tensor offset2bag, Tensor " + "bag_size, Tensor maximum_indices, int num_weights, " + "bool scale_grad_by_freq, int mode, bool sparse, " + "Tensor? per_sample_weights) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + int64_t, bool, int64_t, bool, + const at::Tensor &), + &ATenMLIRTypeDefault::_embedding_bag_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_embedding_bag_sparse_backward(Tensor grad, " + "Tensor indices, Tensor offsets, Tensor offset2bag, " + "Tensor bag_size, int num_weights, bool " + "scale_grad_by_freq, int mode, Tensor? " + "per_sample_weights) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t, bool, int64_t, + const at::Tensor &), + &ATenMLIRTypeDefault::_embedding_bag_sparse_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_embedding_bag_dense_backward(Tensor grad, " + "Tensor indices, Tensor offsets, Tensor offset2bag, " + "Tensor bag_size, Tensor maximum_indices, int " + "num_weights, bool scale_grad_by_freq, int mode, " + "Tensor? per_sample_weights) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + int64_t, bool, int64_t, const at::Tensor &), + &ATenMLIRTypeDefault::_embedding_bag_dense_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_embedding_bag_per_sample_weights_backward(" + "Tensor grad, Tensor weight, Tensor indices, Tensor " + "offsets, Tensor offset2bag, int mode) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t), + &ATenMLIRTypeDefault:: + _embedding_bag_per_sample_weights_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::empty.memory_format(int[] size, *, " + "ScalarType? dtype=None, Layout? layout=None, " + "Device? device=None, bool? pin_memory=None, " + "MemoryFormat? memory_format=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(at::IntArrayRef, const at::TensorOptions &, + c10::optional), + &ATenMLIRTypeDefault::empty>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::new_empty(Tensor self, int[] size, *, ScalarType? " + "dtype=None, Layout? layout=None, Device? device=None, " + "bool? 
pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::new_full(Tensor self, int[] size, Scalar " + "fill_value, *, ScalarType? dtype=None, Layout? " + "layout=None, Device? device=None, bool? " + "pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::Scalar, const at::TensorOptions &), + &ATenMLIRTypeDefault::new_full>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_empty_affine_quantized(int[] size, *, " + "ScalarType? dtype=None, Layout? layout=None, " + "Device? device=None, bool? pin_memory=None, float " + "scale=1, int zero_point=0, MemoryFormat? " + "memory_format=contiguous_format) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(at::IntArrayRef, const at::TensorOptions &, + double, int64_t, + c10::optional), + &ATenMLIRTypeDefault::_empty_affine_quantized>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_empty_per_channel_affine_quantized_like(Tensor " + "self, Tensor zero_points, int[] size, int[] axis, *, " + "ScalarType? dtype=None, Layout? layout=None, Device? " + "device=None, bool? pin_memory=None, MemoryFormat? " + "memory_format=contiguous_format) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + const at::TensorOptions &, + c10::optional), + &ATenMLIRTypeDefault:: + _empty_per_channel_affine_quantized_like>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::resize_(Tensor(a!) self, int[] size) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::empty.out(int[] size, *, MemoryFormat? " + "memory_format=None, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel), + &ATenMLIRTypeDefault::empty_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::empty_like(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::empty_like.dtype(Tensor self, *, ScalarType " + "dtype, Layout layout, Device device, bool " + "pin_memory=False, MemoryFormat? " + "memory_format=contiguous_format) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::TensorOptions &, + c10::optional), + &ATenMLIRTypeDefault::empty_like>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::empty_strided(int[] size, int[] stride, *, " + "ScalarType? dtype=None, Layout? layout=None, Device? " + "device=None, bool? 
pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::erf(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::erf_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::erf.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::erfc(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::erfc_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::erfc.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::exp(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::exp_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::exp.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::expm1(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::expm1_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::expm1.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::expand(Tensor(a) self, int[] size, *, bool " + "implicit=False) -> Tensor(a)") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, bool), + &ATenMLIRType::expand>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::expand_as(Tensor self, Tensor other) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::eye(int n, *, ScalarType? dtype=None, Layout? " + "layout=None, Device? device=None, bool? 
" + "pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(int64_t, const at::TensorOptions &), + &ATenMLIRTypeDefault::eye>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::eye.m(int n, int m, *, ScalarType? " + "dtype=None, Layout? layout=None, Device? " + "device=None, bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(int64_t, int64_t, const at::TensorOptions &), + &ATenMLIRTypeDefault::eye>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::eye.out(int n, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::eye.m_out(int n, int m, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::flatten.using_ints(Tensor self, int " + "start_dim=0, int end_dim=-1) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fill_.Scalar(Tensor(a!) self, Scalar value) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fill_.Tensor(Tensor(a!) self, Tensor value) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::floor(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::floor_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::floor.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::frac(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::frac_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::frac.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::full(int[] size, Scalar fill_value, *, " + "ScalarType? dtype=None, Layout? layout=None, Device? " + "device=None, bool? 
pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::full.out(int[] size, Scalar fill_value, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, at::IntArrayRef, at::Scalar), + &ATenMLIRTypeDefault::full_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::full_like(Tensor self, Scalar fill_value) -> " + "Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::full_like.dtype(Tensor self, Scalar " + "fill_value, *, ScalarType dtype, Layout layout, " + "Device device, bool pin_memory=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::from_file(str filename, bool? shared=None, int? " + "size=0, *, ScalarType? dtype=None, Layout? layout=None, " + "Device? device=None, bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel, + c10::optional, + const at::TensorOptions &), + &ATenMLIRTypeDefault::from_file>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::grid_sampler(Tensor input, Tensor grid, int " + "interpolation_mode, int padding_mode, bool " + "align_corners) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::grid_sampler_2d(Tensor input, Tensor grid, " + "int interpolation_mode, int padding_mode, bool " + "align_corners) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + int64_t, int64_t, bool), + &ATenMLIRTypeDefault::grid_sampler_2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::grid_sampler_2d_backward(Tensor grad_output, " + "Tensor input, Tensor grid, int interpolation_mode, int " + "padding_mode, bool align_corners) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + int64_t, int64_t, + bool), + &ATenMLIRTypeDefault::grid_sampler_2d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::grid_sampler_3d(Tensor input, Tensor grid, " + "int interpolation_mode, int padding_mode, bool " + "align_corners) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + int64_t, int64_t, bool), + &ATenMLIRTypeDefault::grid_sampler_3d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::grid_sampler_3d_backward(Tensor grad_output, " + "Tensor input, Tensor grid, int interpolation_mode, int " + "padding_mode, bool align_corners) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + int64_t, int64_t, + bool), + &ATenMLIRTypeDefault::grid_sampler_3d_backward>( + 
at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::hann_window(int window_length, *, ScalarType? " + "dtype=None, Layout? layout=None, Device? " + "device=None, bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::hann_window.periodic(int window_length, bool " + "periodic, *, ScalarType? dtype=None, Layout? " + "layout=None, Device? device=None, bool? " + "pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::hamming_window(int window_length, *, ScalarType? " + "dtype=None, Layout? layout=None, Device? device=None, " + "bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::hamming_window.periodic(int window_length, " + "bool periodic, *, ScalarType? dtype=None, Layout? " + "layout=None, Device? device=None, bool? " + "pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::hamming_window.periodic_alpha(int window_length, " + "bool periodic, float alpha, *, ScalarType? dtype=None, " + "Layout? layout=None, Device? device=None, bool? " + "pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::hamming_window.periodic_alpha_beta(int " + "window_length, bool periodic, float alpha, float beta, " + "*, ScalarType? dtype=None, Layout? layout=None, Device? " + "device=None, bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::hinge_embedding_loss(Tensor self, Tensor target, " + "float margin=1.0, int reduction=Mean) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, double, + int64_t), + &ATenMLIRTypeDefault::hinge_embedding_loss>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::ger(Tensor self, Tensor vec2) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::ger>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::ger.out(Tensor self, Tensor vec2, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::group_norm(Tensor input, int num_groups, " + "Tensor? weight=None, Tensor? 
bias=None, float " + "eps=1e-05, bool cudnn_enabled=True) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fft(Tensor self, int signal_ndim, bool " + "normalized=False) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t, bool), + &ATenMLIRTypeDefault::fft>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::ifft(Tensor self, int signal_ndim, bool " + "normalized=False) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t, bool), + &ATenMLIRTypeDefault::ifft>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rfft(Tensor self, int signal_ndim, bool " + "normalized=False, bool onesided=True) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t, bool, bool), + &ATenMLIRTypeDefault::rfft>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::irfft(Tensor self, int signal_ndim, bool " + "normalized=False, bool onesided=True, int[] " + "signal_sizes=[]) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_fft_with_size(Tensor self, int signal_ndim, bool " + "complex_input, bool complex_output, bool inverse, int[] " + "checked_signal_sizes, bool normalized, bool onesided, " + "int[] output_sizes) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t, bool, bool, bool, + at::IntArrayRef, bool, bool, at::IntArrayRef), + &ATenMLIRTypeDefault::_fft_with_size>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cufft_get_plan_cache_size(int device_index) " + "-> int") + .impl_unboxedOnlyKernel< + int64_t(int64_t), + &ATenMLIRTypeDefault::_cufft_get_plan_cache_size>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cufft_get_plan_cache_max_size(int " + "device_index) -> int") + .impl_unboxedOnlyKernel< + int64_t(int64_t), + &ATenMLIRTypeDefault::_cufft_get_plan_cache_max_size>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cufft_set_plan_cache_max_size(int " + "device_index, int max_size) -> void") + .impl_unboxedOnlyKernel< + void(int64_t, int64_t), + &ATenMLIRTypeDefault::_cufft_set_plan_cache_max_size>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_cufft_clear_plan_cache(int device_index) -> void") + .impl_unboxedOnlyKernel< + void(int64_t), + &ATenMLIRTypeDefault::_cufft_clear_plan_cache>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::index.Tensor(Tensor self, Tensor?[] indices) " + "-> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + 
.schema("aten::index_copy_(Tensor(a!) self, int dim, Tensor " + "index, Tensor source) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::index_copy(Tensor self, int dim, Tensor " + "index, Tensor source) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t, + const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::index_copy>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::index_put_(Tensor(a!) self, Tensor?[] indices, " + "Tensor values, bool accumulate=False) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::index_put(Tensor self, Tensor?[] indices, " + "Tensor values, bool accumulate=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_index_put_impl_(Tensor(a!) self, Tensor?[] " + "indices, Tensor values, bool accumulate=False, bool " + "unsafe=False) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, at::TensorList, + const at::Tensor &, bool, bool), + &ATenMLIRTypeDefault::_index_put_impl_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::instance_norm(Tensor input, Tensor? weight, " + "Tensor? bias, Tensor? running_mean, Tensor? " + "running_var, bool use_input_stats, float momentum, " + "float eps, bool cudnn_enabled) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, bool, double, double, + bool), + &ATenMLIRTypeDefault::instance_norm>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::inverse(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::inverse.out(Tensor self, *, Tensor(a!) 
out) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_inverse_helper(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &), + &ATenMLIRTypeDefault::_inverse_helper>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::isclose(Tensor self, Tensor other, float " + "rtol=1e-05, float atol=1e-08, bool equal_nan=False) " + "-> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::isnan(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::is_distributed(Tensor self) -> bool") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::is_floating_point(Tensor self) -> bool") + .impl_unboxedOnlyKernel< + bool(const at::Tensor &), + &ATenMLIRTypeDefault::is_floating_point>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::is_complex(Tensor self) -> bool") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::is_nonzero(Tensor self) -> bool") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::is_same_size(Tensor self, Tensor other) -> bool") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::is_signed(Tensor self) -> bool") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::kl_div(Tensor self, Tensor target, int " + "reduction=Mean) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::kl_div_backward(Tensor grad_output, Tensor " + "self, Tensor target, int reduction=Mean) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::kl_div_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::kthvalue(Tensor self, int k, int dim=-1, bool " + "keepdim=False) -> (Tensor values, Tensor indices)") + .impl_unboxedOnlyKernel( + const at::Tensor &, int64_t, + int64_t, bool), + &ATenMLIRTypeDefault::kthvalue>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::kthvalue.values(Tensor self, int k, int dim=-1, " + "bool keepdim=False, *, Tensor(a!) values, Tensor(b!) " + "indices) -> (Tensor(a!) values, Tensor(b!) 
indices)") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + int64_t, int64_t, bool), + &ATenMLIRTypeDefault::kthvalue_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::layer_norm(Tensor input, int[] normalized_shape, " + "Tensor? weight=None, Tensor? bias=None, float " + "eps=1e-05, bool cudnn_enable=True) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + const at::Tensor &, const at::Tensor &, double, + bool), + &ATenMLIRTypeDefault::layer_norm>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::native_layer_norm(Tensor input, Tensor? " + "weight, Tensor? bias, int M, int N, float eps) -> " + "(Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t, int64_t, double), + &ATenMLIRTypeDefault::native_layer_norm>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::native_layer_norm_backward(Tensor grad_out, " + "Tensor input, Tensor mean, Tensor rstd, Tensor? " + "weight, int M, int N, bool[3] output_mask) -> " + "(Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t, int64_t, + std::array), + &ATenMLIRTypeDefault::native_layer_norm_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::native_layer_norm_double_backward(Tensor? ggI, " + "Tensor? ggW, Tensor? ggb, Tensor gO, Tensor input, " + "Tensor mean, Tensor rstd, Tensor? weight, int M, int N, " + "bool[3] output_mask) -> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, int64_t, + int64_t, std::array), + &ATenMLIRTypeDefault::native_layer_norm_double_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::linear(Tensor input, Tensor weight, Tensor? " + "bias=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mkldnn_linear(Tensor input, Tensor weight, " + "Tensor? 
bias=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fbgemm_linear_int8_weight_fp32_activation(" + "Tensor input, Tensor weight, Tensor packed, Tensor " + "col_offsets, Scalar weight_scale, Scalar " + "weight_zero_point, Tensor bias) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::Scalar, at::Scalar, const at::Tensor &), + &ATenMLIRTypeDefault:: + fbgemm_linear_int8_weight_fp32_activation>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fbgemm_linear_int8_weight(Tensor input, " + "Tensor weight, Tensor packed, Tensor col_offsets, " + "Scalar weight_scale, Scalar weight_zero_point, " + "Tensor bias) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::Scalar, at::Scalar, const at::Tensor &), + &ATenMLIRTypeDefault::fbgemm_linear_int8_weight>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fbgemm_linear_quantize_weight(Tensor input) " + "-> (Tensor, Tensor, float, int)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &), + &ATenMLIRTypeDefault::fbgemm_linear_quantize_weight>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fbgemm_pack_gemm_matrix_fp16(Tensor input) -> " + "Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &), + &ATenMLIRTypeDefault::fbgemm_pack_gemm_matrix_fp16>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::fbgemm_linear_fp16_weight_fp32_activation(Tensor " + "input, Tensor packed_weight, Tensor bias) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault:: + fbgemm_linear_fp16_weight_fp32_activation>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fbgemm_linear_fp16_weight(Tensor input, " + "Tensor packed_weight, Tensor bias) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::fbgemm_linear_fp16_weight>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fbgemm_pack_quantized_matrix(Tensor input) -> " + "Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &), + &ATenMLIRTypeDefault::fbgemm_pack_quantized_matrix>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fbgemm_pack_quantized_matrix.KN(Tensor input, " + "int K, int N) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t, int64_t), + &ATenMLIRTypeDefault::fbgemm_pack_quantized_matrix>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::linspace(Scalar start, Scalar end, int steps=100, 
" + "*, ScalarType? dtype=None, Layout? layout=None, Device? " + "device=None, bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::linspace.out(Scalar start, Scalar end, int " + "steps=100, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log10(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log10_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log10.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log1p(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log1p_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log1p.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log2(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log2_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log2.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::logdet(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::logspace(Scalar start, Scalar end, int " + "steps=100, float base=10.0, *, ScalarType? " + "dtype=None, Layout? layout=None, Device? " + "device=None, bool? 
pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::logspace.out(Scalar start, Scalar end, int " + "steps=100, float base=10.0, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log_softmax(Tensor self, int dim, ScalarType? " + "dtype=None) -> Tensor") + .impl_unboxedOnlyKernel), + &ATenMLIRTypeDefault::log_softmax>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_log_softmax(Tensor self, int dim, bool " + "half_to_float) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_log_softmax_backward_data(Tensor grad_output, " + "Tensor output, int dim, Tensor self) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + int64_t, const at::Tensor &), + &ATenMLIRType::_log_softmax_backward_data>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::logsumexp(Tensor self, int[1] dim, bool " + "keepdim=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::logsumexp.out(Tensor self, int[1] dim, bool " + "keepdim=False, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::margin_ranking_loss(Tensor input1, Tensor " + "input2, Tensor target, float margin=0.0, int " + "reduction=Mean) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, double, int64_t), + &ATenMLIRTypeDefault::margin_ranking_loss>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::matmul(Tensor self, Tensor other) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::matmul.out(Tensor self, Tensor other, *, " + "Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::matrix_rank.tol(Tensor self, float tol, bool " + "symmetric=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::matrix_rank(Tensor self, bool " + "symmetric=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::matrix_power(Tensor self, int n) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::max.dim(Tensor self, int dim, bool " + "keepdim=False) -> (Tensor values, Tensor indices)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + int64_t, bool), + &ATenMLIRTypeDefault::max>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::max.dim_max(Tensor self, int dim, bool " + "keepdim=False, *, Tensor(a!) max, Tensor(b!) " + "max_values) -> (Tensor(a!) values, Tensor(b!) indices)") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + int64_t, bool), + &ATenMLIRTypeDefault::max_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::max_values(Tensor self, int[1] dim, bool " + "keepdim=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::max_pool1d_with_indices(Tensor self, int[1] " + "kernel_size, int[1] stride=[], int[1] padding=0, int[1] " + "dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::max_pool1d_with_indices>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::max_pool1d(Tensor self, int[1] kernel_size, " + "int[1] stride=[], int[1] padding=0, int[1] " + "dilation=1, bool ceil_mode=False) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::max_pool1d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::max_pool2d(Tensor self, int[2] kernel_size, " + "int[2] stride=[], int[2] padding=0, int[2] " + "dilation=1, bool ceil_mode=False) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::max_pool2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mkldnn_max_pool2d(Tensor self, int[2] " + "kernel_size, int[2] stride=[], int[2] padding=0, " + "int[2] dilation=1, bool ceil_mode=False) -> Tensor") + 
.impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::mkldnn_max_pool2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::quantized_max_pool2d(Tensor self, int[2] " + "kernel_size, int[2] stride=[], int[2] padding=0, " + "int[2] dilation=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::quantized_max_pool2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::max_pool3d(Tensor self, int[3] kernel_size, " + "int[3] stride=[], int[3] padding=0, int[3] " + "dilation=1, bool ceil_mode=False) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::max_pool3d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mean(Tensor self, *, ScalarType? dtype=None) " + "-> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, + c10::optional), + &ATenMLIRType::mean>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mean.dim(Tensor self, int[1] dim, bool " + "keepdim=False, *, ScalarType? dtype=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, bool, + c10::optional), + &ATenMLIRType::mean>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mean.out(Tensor self, int[1] dim, bool " + "keepdim=False, *, ScalarType? dtype=None, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel), + &ATenMLIRTypeDefault::mean_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::median.dim(Tensor self, int dim, bool " + "keepdim=False) -> (Tensor values, Tensor indices)") + .impl_unboxedOnlyKernel( + const at::Tensor &, int64_t, + bool), + &ATenMLIRTypeDefault::median>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::median.dim_values(Tensor self, int dim, bool " + "keepdim=False, *, Tensor(a!) values, Tensor(b!) " + "indices) -> (Tensor(a!) values, Tensor(b!) indices)") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + int64_t, bool), + &ATenMLIRTypeDefault::median_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::min.dim(Tensor self, int dim, bool " + "keepdim=False) -> (Tensor values, Tensor indices)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + int64_t, bool), + &ATenMLIRTypeDefault::min>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::min.dim_min(Tensor self, int dim, bool " + "keepdim=False, *, Tensor(a!) min, Tensor(b!) " + "min_indices) -> (Tensor(a!) values, Tensor(b!) 
indices)") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + int64_t, bool), + &ATenMLIRTypeDefault::min_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::min_values(Tensor self, int[1] dim, bool " + "keepdim=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mkldnn_convolution(Tensor self, Tensor " + "weight, Tensor? bias, int[] padding, int[] stride, " + "int[] dilation, int groups) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t), + &ATenMLIRTypeDefault::mkldnn_convolution>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mkldnn_convolution_backward_input(int[] " + "self_size, Tensor grad_output, Tensor weight, int[] " + "padding, int[] stride, int[] dilation, int groups, " + "bool bias_defined) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(at::IntArrayRef, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t, + bool), + &ATenMLIRTypeDefault::mkldnn_convolution_backward_input>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mkldnn_convolution_backward_weights(int[] " + "weight_size, Tensor grad_output, Tensor self, int[] " + "padding, int[] stride, int[] dilation, int groups, " + "bool bias_defined) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + at::IntArrayRef, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, int64_t, bool), + &ATenMLIRTypeDefault:: + mkldnn_convolution_backward_weights>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mkldnn_convolution_backward(Tensor self, " + "Tensor grad_output, Tensor weight, int[] padding, " + "int[] stride, int[] dilation, int groups, bool[3] " + "output_mask) -> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, int64_t, std::array), + &ATenMLIRTypeDefault::mkldnn_convolution_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::miopen_batch_norm(Tensor input, Tensor " + "weight, Tensor? bias, Tensor? running_mean, Tensor? " + "running_var, bool training, float " + "exponential_average_factor, float epsilon) -> " + "(Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, bool, double, double), + &ATenMLIRTypeDefault::miopen_batch_norm>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::miopen_batch_norm_backward(Tensor input, Tensor " + "grad_output, Tensor weight, Tensor? running_mean, " + "Tensor? running_var, Tensor? save_mean, Tensor? 
" + "save_var, float epsilon) -> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, double), + &ATenMLIRTypeDefault::miopen_batch_norm_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::miopen_convolution(Tensor self, Tensor " + "weight, Tensor? bias, int[] padding, int[] stride, " + "int[] dilation, int groups, bool benchmark, bool " + "deterministic) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t, + bool, bool), + &ATenMLIRTypeDefault::miopen_convolution>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::miopen_convolution_backward_input(int[] " + "self_size, Tensor grad_output, Tensor weight, int[] " + "padding, int[] stride, int[] dilation, int groups, " + "bool benchmark, bool deterministic) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(at::IntArrayRef, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t, + bool, bool), + &ATenMLIRTypeDefault::miopen_convolution_backward_input>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::miopen_convolution_backward(Tensor self, " + "Tensor grad_output, Tensor weight, int[] padding, " + "int[] stride, int[] dilation, int groups, bool " + "benchmark, bool deterministic, bool[3] output_mask) " + "-> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, int64_t, bool, bool, + std::array), + &ATenMLIRTypeDefault::miopen_convolution_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::miopen_convolution_backward_bias(Tensor " + "grad_output) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &), + &ATenMLIRTypeDefault::miopen_convolution_backward_bias>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::miopen_convolution_backward_weight(int[] " + "weight_size, Tensor grad_output, Tensor self, int[] " + "padding, int[] stride, int[] dilation, int groups, " + "bool benchmark, bool deterministic) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(at::IntArrayRef, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t, + bool, bool), + &ATenMLIRTypeDefault::miopen_convolution_backward_weight>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::miopen_convolution_transpose(Tensor self, Tensor " + "weight, Tensor? 
bias, int[] padding, int[] " + "output_padding, int[] stride, int[] dilation, int " + "groups, bool benchmark, bool deterministic) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, int64_t, bool, bool), + &ATenMLIRTypeDefault::miopen_convolution_transpose>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::miopen_convolution_transpose_backward(Tensor " + "self, Tensor grad_output, Tensor weight, int[] padding, " + "int[] output_padding, int[] stride, int[] dilation, int " + "groups, bool benchmark, bool deterministic, bool[3] " + "output_mask) -> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, + std::array), + &ATenMLIRTypeDefault:: + miopen_convolution_transpose_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::miopen_convolution_transpose_backward_input(" + "Tensor grad_output, Tensor weight, int[] padding, " + "int[] stride, int[] dilation, int groups, bool " + "benchmark, bool deterministic) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, int64_t, bool, bool), + &ATenMLIRTypeDefault:: + miopen_convolution_transpose_backward_input>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::miopen_convolution_transpose_backward_weight(int[]" + " weight_size, Tensor grad_output, Tensor self, int[] " + "padding, int[] stride, int[] dilation, int groups, bool " + "benchmark, bool deterministic) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(at::IntArrayRef, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t, + bool, bool), + &ATenMLIRTypeDefault:: + miopen_convolution_transpose_backward_weight>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::miopen_depthwise_convolution(Tensor self, " + "Tensor weight, Tensor? 
bias, int[] padding, int[] " + "stride, int[] dilation, int groups, bool benchmark, " + "bool deterministic) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t, + bool, bool), + &ATenMLIRTypeDefault::miopen_depthwise_convolution>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::miopen_depthwise_convolution_backward_input(int[] " + "self_size, Tensor grad_output, Tensor weight, int[] " + "padding, int[] stride, int[] dilation, int groups, bool " + "benchmark, bool deterministic) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(at::IntArrayRef, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t, + bool, bool), + &ATenMLIRTypeDefault:: + miopen_depthwise_convolution_backward_input>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::miopen_depthwise_convolution_backward(Tensor " + "self, Tensor grad_output, Tensor weight, int[] " + "padding, int[] stride, int[] dilation, int groups, " + "bool benchmark, bool deterministic, bool[3] " + "output_mask) -> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, int64_t, bool, bool, + std::array), + &ATenMLIRTypeDefault:: + miopen_depthwise_convolution_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::miopen_depthwise_convolution_backward_weight(int[]" + " weight_size, Tensor grad_output, Tensor self, int[] " + "padding, int[] stride, int[] dilation, int groups, bool " + "benchmark, bool deterministic) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(at::IntArrayRef, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t, + bool, bool), + &ATenMLIRTypeDefault:: + miopen_depthwise_convolution_backward_weight>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::miopen_rnn(Tensor input, Tensor[] weight, int " + "weight_stride0, Tensor hx, Tensor? cx, int mode, " + "int hidden_size, int num_layers, bool batch_first, " + "float dropout, bool train, bool bidirectional, " + "int[] batch_sizes, Tensor? dropout_state) -> " + "(Tensor, Tensor, Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, at::TensorList, int64_t, + const at::Tensor &, const at::Tensor &, int64_t, + int64_t, int64_t, bool, double, bool, bool, + at::IntArrayRef, const at::Tensor &), + &ATenMLIRTypeDefault::miopen_rnn>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::miopen_rnn_backward(Tensor input, Tensor[] " + "weight, int weight_stride0, Tensor weight_buf, " + "Tensor hx, Tensor? cx, Tensor output, Tensor? " + "grad_output, Tensor? grad_hy, Tensor? grad_cy, int " + "mode, int hidden_size, int num_layers, bool " + "batch_first, float dropout, bool train, bool " + "bidirectional, int[] batch_sizes, Tensor? 
" + "dropout_state, Tensor reserve, bool[4] output_mask) " + "-> (Tensor, Tensor, Tensor, Tensor[])") + .impl_unboxedOnlyKernel< + std::tuple>( + const at::Tensor &, at::TensorList, int64_t, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t, int64_t, int64_t, bool, + double, bool, bool, at::IntArrayRef, + const at::Tensor &, const at::Tensor &, + std::array), + &ATenMLIRTypeDefault::miopen_rnn_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mm(Tensor self, Tensor mat2) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRType::mm>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mm.out(Tensor self, Tensor mat2, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_sparse_mm(Tensor sparse, Tensor dense) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mode(Tensor self, int dim=-1, bool " + "keepdim=False) -> (Tensor values, Tensor indices)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + int64_t, bool), + &ATenMLIRTypeDefault::mode>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mode.values(Tensor self, int dim=-1, bool " + "keepdim=False, *, Tensor(a!) values, Tensor(b!) " + "indices) -> (Tensor(a!) values, Tensor(b!) indices)") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + int64_t, bool), + &ATenMLIRTypeDefault::mode_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::mul.Tensor(Tensor self, Tensor other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRType::mul>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &), + &ATenMLIRType::mul_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mul.out(Tensor self, Tensor other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::mul.Scalar(Tensor self, Scalar other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::mul>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mul_.Scalar(Tensor(a!) 
self, Scalar other) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::mul_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mv(Tensor self, Tensor vec) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::mv>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mv.out(Tensor self, Tensor vec, *, Tensor(a!) " + "out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mvlgamma(Tensor self, int p) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::mvlgamma_(Tensor(a!) self, int p) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::narrow_copy(Tensor self, int dim, int start, " + "int length) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::narrow(Tensor(a) self, int dim, int start, " + "int length) -> Tensor(a)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::native_batch_norm(Tensor input, Tensor? " + "weight, Tensor? bias, Tensor? running_mean, Tensor? " + "running_var, bool training, float momentum, float " + "eps) -> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, bool, double, double), + &ATenMLIRType::native_batch_norm>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::batch_norm_stats(Tensor input, float eps) -> " + "(Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + double), + &ATenMLIRTypeDefault::batch_norm_stats>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::batch_norm_elemt(Tensor input, Tensor? " + "weight, Tensor? bias, Tensor mean, Tensor invstd, " + "float eps) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, double), + &ATenMLIRTypeDefault::batch_norm_elemt>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::batch_norm_gather_stats(Tensor input, Tensor " + "mean, Tensor invstd, Tensor? running_mean, Tensor? 
" + "running_var, float momentum, float eps, int count) " + "-> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + double, double, + int64_t), + &ATenMLIRTypeDefault::batch_norm_gather_stats>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::batch_norm_gather_stats_with_counts(Tensor " + "input, Tensor mean, Tensor invstd, Tensor? " + "running_mean, Tensor? running_var, float momentum, " + "float eps, int[] counts) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + double, double, + at::IntArrayRef), + &ATenMLIRTypeDefault:: + batch_norm_gather_stats_with_counts>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::native_batch_norm_backward(Tensor grad_out, " + "Tensor input, Tensor? weight, Tensor? running_mean, " + "Tensor? running_var, Tensor? save_mean, Tensor? " + "save_invstd, bool train, float eps, bool[3] " + "output_mask) -> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, bool, double, + std::array), + &ATenMLIRType::native_batch_norm_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::batch_norm_backward_reduce(Tensor grad_out, " + "Tensor input, Tensor mean, Tensor invstd, Tensor? " + "weight, bool input_g, bool weight_g, bool bias_g) " + "-> (Tensor, Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std:: + tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, bool, bool, bool), + &ATenMLIRTypeDefault::batch_norm_backward_reduce>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::batch_norm_backward_elemt(Tensor grad_out, Tensor " + "input, Tensor mean, Tensor invstd, Tensor? weight, " + "Tensor mean_dy, Tensor mean_dy_xmu) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::batch_norm_backward_elemt>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::batch_norm_update_stats(Tensor input, Tensor? " + "running_mean, Tensor? running_var, float momentum) " + "-> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + double), + &ATenMLIRTypeDefault::batch_norm_update_stats>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_nnpack_spatial_convolution(Tensor input, Tensor " + "weight, Tensor? 
bias, int[2] padding) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::_nnpack_spatial_convolution>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_nnpack_spatial_convolution_backward(Tensor " + "input, Tensor grad_output, Tensor weight, int[2] " + "padding, bool[3] output_mask) -> (Tensor, Tensor, " + "Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + std::array), + &ATenMLIRTypeDefault:: + _nnpack_spatial_convolution_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_nnpack_spatial_convolution_backward_input(" + "Tensor input, Tensor grad_output, Tensor weight, " + "int[2] padding) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault:: + _nnpack_spatial_convolution_backward_input>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_nnpack_spatial_convolution_backward_weight(" + "Tensor input, int[] weightsize, Tensor grad_output, " + "int[2] padding) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault:: + _nnpack_spatial_convolution_backward_weight>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::ones.out(int[] size, *, Tensor(a!) 
out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::pairwise_distance(Tensor x1, Tensor x2, float " + "p=2, float eps=1e-06, bool keepdim=False) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, double, + double, bool), + &ATenMLIRTypeDefault::pairwise_distance>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::cdist(Tensor x1, Tensor x2, float p=2) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cdist_backward(Tensor grad, Tensor x1, " + "Tensor x2, float p, Tensor cdist) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, double, + const at::Tensor &), + &ATenMLIRTypeDefault::_cdist_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::pdist(Tensor self, float p=2) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_pdist_forward(Tensor self, float p=2) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_pdist_backward(Tensor grad, Tensor self, " + "float p, Tensor pdist) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, double, + const at::Tensor &), + &ATenMLIRTypeDefault::_pdist_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cosine_similarity(Tensor x1, Tensor x2, int " + "dim=1, float eps=1e-08) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + int64_t, double), + &ATenMLIRTypeDefault::cosine_similarity>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::numpy_T(Tensor(a) self) -> Tensor(a)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::pixel_shuffle(Tensor self, int " + "upscale_factor) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::is_pinned(Tensor self) -> bool") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::pin_memory(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + 
.schema("aten::pinverse(Tensor self, float rcond=1e-15) -> " + "Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::poisson_nll_loss(Tensor input, Tensor target, " + "bool log_input, bool full, float eps, int " + "reduction) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, bool, + bool, double, int64_t), + &ATenMLIRTypeDefault::poisson_nll_loss>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::scalar_tensor(Scalar s, *, ScalarType? " + "dtype=None, Layout? layout=None, Device? " + "device=None, bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rand(int[] size, *, ScalarType? dtype=None, " + "Layout? layout=None, Device? device=None, bool? " + "pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(at::IntArrayRef, const at::TensorOptions &), + &ATenMLIRTypeDefault::rand>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::rand.generator(int[] size, *, Generator? " + "generator, ScalarType? dtype=None, Layout? layout=None, " + "Device? device=None, bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rand.out(int[] size, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rand.generator_out(int[] size, *, Generator? " + "generator, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rand_like(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rand_like.dtype(Tensor self, *, ScalarType " + "dtype, Layout layout, Device device, bool " + "pin_memory=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::randint(int high, int[] size, *, ScalarType? " + "dtype=None, Layout? layout=None, Device? " + "device=None, bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::randint.generator(int high, int[] size, *, " + "Generator? generator, ScalarType? dtype=None, " + "Layout? layout=None, Device? device=None, bool? " + "pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::randint.low(int low, int high, int[] size, *, " + "ScalarType? dtype=None, Layout? layout=None, Device? " + "device=None, bool? 
pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::randint.low_generator(int low, int high, " + "int[] size, *, Generator? generator, ScalarType? " + "dtype=None, Layout? layout=None, Device? " + "device=None, bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(int64_t, int64_t, at::IntArrayRef, + at::Generator *, const at::TensorOptions &), + &ATenMLIRTypeDefault::randint>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::randint.out(int high, int[] size, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::randint.generator_out(int high, int[] size, *, " + "Generator? generator, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::randint.low_out(int low, int high, int[] " + "size, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::randint.low_generator_out(int low, int high, " + "int[] size, *, Generator? generator, Tensor(a!) " + "out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::randint_like(Tensor self, int high) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::randint_like.low(Tensor self, int low, int " + "high) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::randint_like.dtype(Tensor self, int high, *, " + "ScalarType dtype, Layout layout, Device device, " + "bool pin_memory=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::randint_like.low_dtype(Tensor self, int low, " + "int high, *, ScalarType dtype, Layout layout, " + "Device device, bool pin_memory=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::randn(int[] size, *, ScalarType? dtype=None, " + "Layout? layout=None, Device? device=None, bool? " + "pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::randn.generator(int[] size, *, Generator? " + "generator, ScalarType? dtype=None, Layout? layout=None, " + "Device? device=None, bool? 
pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::randn.out(int[] size, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::randn.generator_out(int[] size, *, Generator? " + "generator, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::randn_like(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::randn_like.dtype(Tensor self, *, ScalarType " + "dtype, Layout layout, Device device, bool " + "pin_memory=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::randperm(int n, *, ScalarType? dtype=None, " + "Layout? layout=None, Device? device=None, bool? " + "pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::randperm.generator(int n, *, Generator? " + "generator, ScalarType? dtype=None, Layout? layout=None, " + "Device? device=None, bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::randperm.out(int n, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::randperm.generator_out(int n, *, Generator? " + "generator, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::range.step(Scalar start, Scalar end, Scalar " + "step=1, *, ScalarType? dtype=None, Layout? layout=None, " + "Device? device=None, bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::range(Scalar start, Scalar end, *, ScalarType? " + "dtype=None, Layout? layout=None, Device? device=None, " + "bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::range.out(Scalar start, Scalar end, Scalar " + "step=1, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::reciprocal(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::reciprocal_(Tensor(a!) 
self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::reciprocal.out(Tensor self, *, Tensor(a!) " + "out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::neg(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::neg_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::neg.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::repeat(Tensor self, int[] repeats) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::repeat_interleave.Tensor(Tensor repeats) -> " + "Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &), + &ATenMLIRTypeDefault::repeat_interleave>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::repeat_interleave.self_Tensor(Tensor self, " + "Tensor repeats, int? dim=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + c10::optional), + &ATenMLIRTypeDefault::repeat_interleave>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::repeat_interleave.self_int(Tensor self, int " + "repeats, int? dim=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t, + c10::optional), + &ATenMLIRTypeDefault::repeat_interleave>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::reshape(Tensor self, int[] shape) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_mkldnn_reshape(Tensor self, int[] shape) -> " + "Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::_mkldnn_reshape>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::reshape_as(Tensor self, Tensor other) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::round(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::round_(Tensor(a!) 
self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::round.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rrelu(Tensor self, Scalar lower=0.125, Scalar " + "upper=0.3333333333333333, bool training=False, " + "Generator? generator=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::rrelu_(Tensor(a!) self, Scalar lower=0.125, " + "Scalar upper=0.3333333333333333, bool training=False, " + "Generator? generator=None) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::relu(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::relu_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::prelu(Tensor self, Tensor weight) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::prelu_backward(Tensor grad_output, Tensor " + "self, Tensor weight) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::prelu_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::gelu(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::gelu_backward(Tensor grad, Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::hardshrink(Tensor self, Scalar lambd=0.5) -> " + "Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::hardshrink_backward(Tensor grad_out, Tensor " + "self, Scalar lambd) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::Scalar), + &ATenMLIRTypeDefault::hardshrink_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rsqrt(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rsqrt_(Tensor(a!) 
self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rsqrt.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::select.int(Tensor(a) self, int dim, int " + "index) -> Tensor(a)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::selu(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::selu_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::celu(Tensor self, Scalar alpha=1.0) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::celu>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::celu_(Tensor(a!) self, Scalar alpha=1.0) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sigmoid(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sigmoid_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sigmoid.out(Tensor self, *, Tensor(a!) out) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sin(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sin_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sin.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sinh(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sinh_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sinh.out(Tensor self, *, Tensor(a!) 
out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::detach(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::detach_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::size.int(Tensor self, int dim) -> int") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::slice.Tensor(Tensor(a) self, int dim=0, int " + "start=0, int end=9223372036854775807, int step=1) " + "-> Tensor(a)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::slogdet(Tensor self) -> (Tensor sign, Tensor " + "logabsdet)") + .impl_unboxedOnlyKernel( + const at::Tensor &), + &ATenMLIRTypeDefault::slogdet>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::smm(Tensor self, Tensor mat2) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::smm>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::softmax(Tensor self, int dim, ScalarType? " + "dtype=None) -> Tensor") + .impl_unboxedOnlyKernel), + &ATenMLIRTypeDefault::softmax>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_softmax(Tensor self, int dim, bool " + "half_to_float) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_softmax_backward_data(Tensor grad_output, " + "Tensor output, int dim, Tensor self) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + int64_t, const at::Tensor &), + &ATenMLIRTypeDefault::_softmax_backward_data>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_sparse_add.out(Tensor self, Tensor other, *, " + "Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::_sparse_add_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_sparse_dense_add.out(Tensor self, Tensor other, " + "*, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::_sparse_dense_add_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_sparse_div_zerodim.out(Tensor self, Tensor " + "other, *, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::_sparse_div_zerodim_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_sparse_div_scalar.out(Tensor self, Scalar " + "other, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::Scalar), + &ATenMLIRTypeDefault::_sparse_div_scalar_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_sparse_mul.out(Tensor self, Tensor other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::_sparse_mul_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_sparse_mul_zerodim.out(Tensor self, Tensor " + "other, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::_sparse_mul_zerodim_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_sparse_mul_scalar.out(Tensor self, Scalar " + "other, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::Scalar), + &ATenMLIRTypeDefault::_sparse_mul_scalar_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::split.Tensor(Tensor(a) self, int split_size, " + "int dim=0) -> Tensor(a)[]") + .impl_unboxedOnlyKernel( + const at::Tensor &, int64_t, + int64_t), + &ATenMLIRTypeDefault::split>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::split_with_sizes(Tensor self, int[] " + "split_sizes, int dim=0) -> Tensor[]") + .impl_unboxedOnlyKernel< + std::vector( + const at::Tensor &, + at::IntArrayRef, int64_t), + &ATenMLIRTypeDefault::split_with_sizes>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::squeeze(Tensor(a) self) -> Tensor(a)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t), + &ATenMLIRType::squeeze>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::squeeze_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::squeeze_.dim(Tensor(a!) 
self, int dim) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sspaddmm(Tensor self, Tensor mat1, Tensor " + "mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::sspaddmm>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sspaddmm.out(Tensor self, Tensor mat1, Tensor " + "mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) " + "out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::sspaddmm_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::stack(Tensor[] tensors, int dim=0) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::stack.out(Tensor[] tensors, int dim=0, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::stft(Tensor self, int n_fft, int? " + "hop_length=None, int? win_length=None, Tensor? " + "window=None, bool normalized=False, bool " + "onesided=True) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t, + c10::optional, c10::optional, + const at::Tensor &, bool, bool), + &ATenMLIRTypeDefault::stft>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::stride.int(Tensor self, int dim) -> int") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sum(Tensor self, *, ScalarType? dtype=None) " + "-> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, + c10::optional), + &ATenMLIRTypeDefault::sum>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sum.dim_IntList(Tensor self, int[1] dim, bool " + "keepdim=False, *, ScalarType? dtype=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, bool, + c10::optional), + &ATenMLIRType::sum>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sum.IntList_out(Tensor self, int[1] dim, bool " + "keepdim=False, *, ScalarType? dtype=None, " + "Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel), + &ATenMLIRTypeDefault::sum_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::sum_to_size(Tensor self, int[] size) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sqrt(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sqrt_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sqrt.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::std(Tensor self, bool unbiased=True) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::std.dim(Tensor self, int[1] dim, bool " + "unbiased=True, bool keepdim=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::std_mean(Tensor self, bool unbiased=True) -> " + "(Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, bool), + &ATenMLIRTypeDefault::std_mean>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::std_mean.dim(Tensor self, int[1] dim, bool " + "unbiased=True, bool keepdim=False) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + at::IntArrayRef, bool, bool), + &ATenMLIRTypeDefault::std_mean>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::std.out(Tensor self, int[1] dim, bool " + "unbiased=True, bool keepdim=False, *, Tensor(a!) " + "out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::prod(Tensor self, *, ScalarType? dtype=None) " + "-> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, + c10::optional), + &ATenMLIRTypeDefault::prod>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::prod.dim_int(Tensor self, int dim, bool " + "keepdim=False, *, ScalarType? dtype=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t, bool, + c10::optional), + &ATenMLIRTypeDefault::prod>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::prod.int_out(Tensor self, int dim, bool " + "keepdim=False, *, ScalarType? dtype=None, " + "Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, int64_t, + bool, c10::optional), + &ATenMLIRTypeDefault::prod_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::t(Tensor(a) self) -> Tensor(a)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::t_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::tan(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::tan_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::tan.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::tanh(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::tanh_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::tanh.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::tensordot(Tensor self, Tensor other, int[] " + "dims_self, int[] dims_other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::tensordot>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::threshold(Tensor self, Scalar threshold, " + "Scalar value) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::threshold_(Tensor(a!) self, Scalar threshold, " + "Scalar value) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::threshold.out(Tensor self, Scalar threshold, " + "Scalar value, *, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::threshold_backward(Tensor grad_output, Tensor " + "self, Scalar threshold) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::transpose.int(Tensor(a) self, int dim0, int " + "dim1) -> Tensor(a)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_mkldnn_transpose(Tensor self, int dim0, int " + "dim1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t, int64_t), + &ATenMLIRTypeDefault::_mkldnn_transpose>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::transpose_(Tensor(a!) self, int dim0, int " + "dim1) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_mkldnn_transpose_(Tensor(a!) self, int dim0, " + "int dim1) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, int64_t, int64_t), + &ATenMLIRTypeDefault::_mkldnn_transpose_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::one_hot(Tensor self, int num_classes=-1) -> " + "Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::flip(Tensor self, int[] dims) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::flip>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::roll(Tensor self, int[1] shifts, int[1] " + "dims=[]) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rot90(Tensor self, int k=1, int[] dims=[0,1]) " + "-> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::trapz.x(Tensor y, Tensor x, *, int dim=-1) -> " + "Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::trapz.dx(Tensor y, *, float dx=1, int dim=-1) " + "-> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_trilinear(Tensor i1, Tensor i2, Tensor i3, " + "int[] expand1, int[] expand2, int[] expand3, int[] " + "sumdim, int unroll_dim=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, int64_t), + &ATenMLIRTypeDefault::_trilinear>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + 
.op(torch::RegisterOperators::options() + .schema("aten::triplet_margin_loss(Tensor anchor, Tensor " + "positive, Tensor negative, float margin=1.0, float " + "p=2, float eps=1e-06, bool swap=False, int " + "reduction=Mean) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, double, double, double, + bool, int64_t), + &ATenMLIRTypeDefault::triplet_margin_loss>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::trunc(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::trunc_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::trunc.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::type_as(Tensor self, Tensor other) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_has_compatible_shallow_copy_type(Tensor " + "self, Tensor from) -> bool") + .impl_unboxedOnlyKernel< + bool(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::_has_compatible_shallow_copy_type>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_unique(Tensor self, bool sorted=True, bool " + "return_inverse=False) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, bool, bool), + &ATenMLIRTypeDefault::_unique>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::unique_dim(Tensor self, int dim, bool " + "sorted=True, bool return_inverse=False, bool " + "return_counts=False) -> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, int64_t, bool, bool, bool), + &ATenMLIRTypeDefault::unique_dim>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::unique_consecutive(Tensor self, bool " + "return_inverse=False, bool return_counts=False, " + "int? 
dim=None) -> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, bool, bool, + c10::optional), + &ATenMLIRTypeDefault::unique_consecutive>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::unique_dim_consecutive(Tensor self, int dim, " + "bool return_inverse=False, bool " + "return_counts=False) -> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, int64_t, bool, bool), + &ATenMLIRTypeDefault::unique_dim_consecutive>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_unique2(Tensor self, bool sorted=True, bool " + "return_inverse=False, bool return_counts=False) -> " + "(Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, bool, bool, bool), + &ATenMLIRTypeDefault::_unique2>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_unsafe_view(Tensor self, int[] size) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t), + &ATenMLIRType::unsqueeze>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::unsqueeze_(Tensor(a!) self, int dim) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::var(Tensor self, bool unbiased=True) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::var.dim(Tensor self, int[1] dim, bool " + "unbiased=True, bool keepdim=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::var.out(Tensor self, int[1] dim, bool " + "unbiased=True, bool keepdim=False, *, Tensor(a!) 
" + "out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::var_mean(Tensor self, bool unbiased=True) -> " + "(Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, bool), + &ATenMLIRTypeDefault::var_mean>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::var_mean.dim(Tensor self, int[1] dim, bool " + "unbiased=True, bool keepdim=False) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + at::IntArrayRef, bool, bool), + &ATenMLIRTypeDefault::var_mean>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::view_as(Tensor self, Tensor other) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::where.self(Tensor condition, Tensor self, " + "Tensor other) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::where(Tensor condition) -> Tensor[]") + .impl_unboxedOnlyKernel( + const at::Tensor &), + &ATenMLIRTypeDefault::where>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_s_where(Tensor condition, Tensor self, " + "Tensor other) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::norm_except_dim(Tensor v, int pow=2, int " + "dim=0) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t, int64_t), + &ATenMLIRTypeDefault::norm_except_dim>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_weight_norm(Tensor v, Tensor g, int dim=0) " + "-> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_weight_norm_cuda_interface(Tensor v, Tensor " + "g, int dim=0) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + const at::Tensor &, + int64_t), + &ATenMLIRTypeDefault::_weight_norm_cuda_interface>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_weight_norm_cuda_interface_backward(Tensor " + "grad_w, Tensor saved_v, Tensor saved_g, Tensor " + "saved_norms, int dim) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + int64_t), + &ATenMLIRTypeDefault:: + _weight_norm_cuda_interface_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_weight_norm_differentiable_backward(Tensor " + "grad_w, Tensor saved_v, Tensor saved_g, Tensor " + "saved_norms, int dim) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + 
const at::Tensor &, + int64_t), + &ATenMLIRTypeDefault:: + _weight_norm_differentiable_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::zeros.out(int[] size, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_standard_gamma_grad(Tensor self, Tensor " + "output) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::_standard_gamma_grad>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_standard_gamma(Tensor self, Generator? " + "generator=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Generator *), + &ATenMLIRTypeDefault::_standard_gamma>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_dirichlet_grad(Tensor x, Tensor alpha, " + "Tensor total) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::_dirichlet_grad>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_sample_dirichlet(Tensor self, Generator? " + "generator=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Generator *), + &ATenMLIRTypeDefault::_sample_dirichlet>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::poisson(Tensor self, Generator? 
" + "generator=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::native_norm(Tensor self, Scalar p=2) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_sparse_sum(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_sparse_sum.dtype(Tensor self, *, ScalarType " + "dtype) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_sparse_sum.dim(Tensor self, int[1] dim) -> " + "Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_sparse_sum.dim_dtype(Tensor self, int[1] " + "dim, *, ScalarType dtype) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_sparse_sum_backward(Tensor grad, Tensor " + "self, int[] dim) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::_sparse_sum_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::norm.ScalarOpt_dtype(Tensor self, Scalar? p, " + "*, ScalarType dtype) -> Tensor") + .impl_unboxedOnlyKernel, + at::ScalarType), + &ATenMLIRTypeDefault::norm>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::norm.Scalar(Tensor self, Scalar p=2) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::norm>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::norm.ScalarOpt_dim_dtype(Tensor self, Scalar? " + "p, int[1] dim, bool keepdim, *, ScalarType dtype) " + "-> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, c10::optional, + at::IntArrayRef, bool, at::ScalarType), + &ATenMLIRTypeDefault::norm>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, " + "int[1] dim, bool keepdim=False) -> Tensor") + .impl_unboxedOnlyKernel, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::norm>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::norm.dtype_out(Tensor self, Scalar? p, int[1] " + "dim, bool keepdim, *, ScalarType dtype, Tensor(a!) 
" + "out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + c10::optional, at::IntArrayRef, + bool, at::ScalarType), + &ATenMLIRTypeDefault::norm_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::norm.out(Tensor self, Scalar? p, int[1] dim, bool " + "keepdim=False, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::norm_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::frobenius_norm(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::frobenius_norm.dim(Tensor self, int[1] dim, " + "bool keepdim=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::frobenius_norm.out(Tensor self, int[1] dim, bool " + "keepdim=False, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::frobenius_norm_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::nuclear_norm(Tensor self, bool keepdim=False) " + "-> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::nuclear_norm.out(Tensor self, bool " + "keepdim=False, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, bool), + &ATenMLIRTypeDefault::nuclear_norm_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::nuclear_norm.dim(Tensor self, int[2] dim, " + "bool keepdim=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::nuclear_norm.dim_out(Tensor self, int[2] dim, " + "bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::nuclear_norm_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::clone(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::resize_as_(Tensor(a!) self, Tensor " + "the_template) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::pow.Tensor_Scalar_out(Tensor self, Scalar " + "exponent, *, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::pow.Tensor_Scalar(Tensor self, Scalar " + "exponent) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::pow>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::zero_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sub.out(Tensor self, Tensor other, *, Scalar " + "alpha=1, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sub.Tensor(Tensor self, Tensor other, *, " + "Scalar alpha=1) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, " + "Scalar alpha=1) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sub.Scalar(Tensor self, Scalar other, Scalar " + "alpha=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::sub>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sub_.Scalar(Tensor(a!) self, Scalar other, " + "Scalar alpha=1) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::sub_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rsub.Tensor(Tensor self, Tensor other, *, " + "Scalar alpha=1) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rsub.Scalar(Tensor self, Scalar other, Scalar " + "alpha=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::rsub>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::s_native_addmm.out(Tensor self, Tensor mat1, " + "Tensor mat2, *, Scalar beta=1, Scalar alpha=1, " + "Tensor(a!) 
+                    "out) -> Tensor(a!)")
+                .impl_unboxedOnlyKernel<
+                    at::Tensor &(at::Tensor &, const at::Tensor &,
+                                 const at::Tensor &, const at::Tensor &,
+                                 at::Scalar, at::Scalar),
+                    &ATenMLIRTypeDefault::s_native_addmm_out>(
+                    at::TensorTypeId::XLATensorId)
+                .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))
+        .op(torch::RegisterOperators::options()
+                .schema(
+                    "aten::s_native_addmm(Tensor self, Tensor mat1, Tensor "
+                    "mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor")
+                .impl_unboxedOnlyKernel<
+                    at::Tensor(const at::Tensor &, const at::Tensor &,
+                               const at::Tensor &, at::Scalar, at::Scalar),
+                    &ATenMLIRTypeDefault::s_native_addmm>(
+                    at::TensorTypeId::XLATensorId)
+                .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))
+        .op(torch::RegisterOperators::options()
+                .schema("aten::s_native_addmm_(Tensor(a!) self, Tensor mat1, "
+                        "Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> "
+                        "Tensor(a!)")
+                .impl_unboxedOnlyKernel<
+                    at::Tensor &(at::Tensor &, const at::Tensor &,
+                                 const at::Tensor &, at::Scalar, at::Scalar),
+                    &ATenMLIRTypeDefault::s_native_addmm_>(
+                    at::TensorTypeId::XLATensorId)
+                .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))
+        .op(torch::RegisterOperators::options()
+                .schema(
+                    "aten::_sparse_addmm(Tensor self, Tensor sparse, Tensor "
+                    "dense, *, Scalar beta=1, Scalar alpha=1) -> Tensor")
+                .impl_unboxedOnlyKernel<
+                    at::Tensor(const at::Tensor &, const at::Tensor &,
+                               const at::Tensor &, at::Scalar, at::Scalar),
+                    &ATenMLIRTypeDefault::_sparse_addmm>(
+                    at::TensorTypeId::XLATensorId)
+                .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))
+        .op(torch::RegisterOperators::options()
+                .schema("aten::addmm.out(Tensor self, Tensor mat1, Tensor "
+                        "mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) "
+                        "out) -> Tensor(a!)")
+                .impl_unboxedOnlyKernel<
+                    at::Tensor &(at::Tensor &, const at::Tensor &,
+                                 const at::Tensor &, const at::Tensor &,
+                                 at::Scalar, at::Scalar),
+                    &ATenMLIRTypeDefault::addmm_out>(
+                    at::TensorTypeId::XLATensorId)
+                .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))
+        .op(torch::RegisterOperators::options()
+                .schema("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, "
+                        "*, Scalar beta=1, Scalar alpha=1) -> Tensor")
+                .impl_unboxedOnlyKernel<
+                    at::Tensor(const at::Tensor &, const at::Tensor &,
+                               const at::Tensor &, at::Scalar, at::Scalar),
+                    &ATenMLIRType::addmm>(at::TensorTypeId::XLATensorId)
+                .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))
+        .op(torch::RegisterOperators::options()
+                .schema(
+                    "aten::addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, "
+                    "*, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)")
+                .impl_unboxedOnlyKernel<
+                    at::Tensor &(at::Tensor &, const at::Tensor &,
+                                 const at::Tensor &, at::Scalar, at::Scalar),
+                    &ATenMLIRTypeDefault::addmm_>(
+                    at::TensorTypeId::XLATensorId)
+                .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))
+        .op(torch::RegisterOperators::options()
+                .schema("aten::sparse_coo_tensor.size(int[] size, *, "
+                        "ScalarType dtype, Layout layout, Device device, "
+                        "bool pin_memory=False) -> Tensor")
+                .impl_unboxedOnlyKernel<
+                    at::Tensor(at::IntArrayRef, const at::TensorOptions &),
+                    &ATenMLIRTypeDefault::sparse_coo_tensor>(
+                    at::TensorTypeId::XLATensorId)
+                .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))
+        .op(torch::RegisterOperators::options()
+                .schema(
+                    "aten::sparse_coo_tensor.indices(Tensor indices, Tensor "
+                    "values, *, ScalarType? dtype=None, Layout? layout=None, "
+                    "Device? device=None, bool?
pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::TensorOptions &), + &ATenMLIRTypeDefault::sparse_coo_tensor>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sparse_coo_tensor.indices_size(Tensor " + "indices, Tensor values, int[] size, *, ScalarType? " + "dtype=None, Layout? layout=None, Device? " + "device=None, bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, const at::TensorOptions &), + &ATenMLIRTypeDefault::sparse_coo_tensor>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_sparse_coo_tensor_unsafe(Tensor indices, " + "Tensor values, int[] size, *, ScalarType? " + "dtype=None, Layout? layout=None, Device? " + "device=None, bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, const at::TensorOptions &), + &ATenMLIRTypeDefault::_sparse_coo_tensor_unsafe>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_sparse_coo_tensor_with_dims(int sparse_dim, int " + "dense_dim, int[] size, *, ScalarType dtype, Layout " + "layout, Device device, bool pin_memory=False) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(int64_t, int64_t, at::IntArrayRef, + const at::TensorOptions &), + &ATenMLIRTypeDefault::_sparse_coo_tensor_with_dims>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_sparse_coo_tensor_with_dims_and_tensors(int " + "sparse_dim, int dense_dim, int[] size, Tensor indices, " + "Tensor values, *, ScalarType dtype, Layout layout, " + "Device device, bool pin_memory=False) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(int64_t, int64_t, at::IntArrayRef, + const at::Tensor &, const at::Tensor &, + const at::TensorOptions &), + &ATenMLIRTypeDefault:: + _sparse_coo_tensor_with_dims_and_tensors>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sparse_resize_(Tensor(a!) self, int[] size, " + "int sparse_dim, int dense_dim) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::sparse_resize_and_clear_(Tensor(a!) 
self, int[] " + "size, int sparse_dim, int dense_dim) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, at::IntArrayRef, int64_t, + int64_t), + &ATenMLIRTypeDefault::sparse_resize_and_clear_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::sparse_mask(Tensor self, Tensor mask) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::to_dense(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::to_dense_backward(Tensor grad, Tensor input) " + "-> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::to_dense_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sparse_dim(Tensor self) -> int") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_dimI(Tensor self) -> int") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::dense_dim(Tensor self) -> int") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_dimV(Tensor self) -> int") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_nnz(Tensor self) -> int") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::coalesce(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::is_coalesced(Tensor self) -> bool") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_indices(Tensor(a) self) -> Tensor(a)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_values(Tensor(a) self) -> Tensor(a)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_coalesced_(Tensor(a!) 
self, bool coalesced) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::indices(Tensor(a) self) -> Tensor(a)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::values(Tensor(a) self) -> Tensor(a)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::hspmm.out(Tensor mat1, Tensor mat2, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::hspmm(Tensor mat1, Tensor mat2) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::copy_sparse_to_sparse_(Tensor(a!) self, " + "Tensor src, bool non_blocking=False) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, bool), + &ATenMLIRTypeDefault::copy_sparse_to_sparse_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::unbind.int(Tensor(a) self, int dim=0) -> " + "Tensor(a)[]") + .impl_unboxedOnlyKernel( + const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::unbind>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::to_sparse.sparse_dim(Tensor self, int " + "sparse_dim) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::to_sparse(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::to_mkldnn(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mkldnn_reorder_conv2d_weight(Tensor self, " + "int[2] padding=0, int[2] stride=1, int[2] " + "dilation=1, int groups=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t), + &ATenMLIRTypeDefault::mkldnn_reorder_conv2d_weight>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::to_mkldnn_backward(Tensor grad, Tensor input) " + "-> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::to_mkldnn_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::quantize_linear(Tensor self, float scale, int " + "zero_point, ScalarType dtype) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, double, int64_t, + at::ScalarType), + &ATenMLIRTypeDefault::quantize_linear>( + at::TensorTypeId::XLATensorId) + 
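+            // Note: every .op(...) entry in this generated chain follows the
+            // same pattern. An ATen operator schema string is registered for
+            // the at::TensorTypeId::XLATensorId dispatch key with an unboxed
+            // kernel that forwards to the matching ATenMLIRTypeDefault
+            // method (or, for a few ops that appear above and below such as
+            // addmm, view, and to, to ATenMLIRType), and aliasing
+            // information is taken from the schema itself via
+            // c10::AliasAnalysisKind::FROM_SCHEMA.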
.aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::quantize_linear_per_channel(Tensor self, " + "Tensor scales, Tensor zero_points, int[] axis, " + "ScalarType dtype) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::ScalarType), + &ATenMLIRTypeDefault::quantize_linear_per_channel>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::dequantize(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_dequantize_linear(Tensor self, float scale, " + "int zero_point, ScalarType dtype) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, double, int64_t, + at::ScalarType), + &ATenMLIRTypeDefault::_dequantize_linear>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::q_scale(Tensor self) -> float") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::q_zero_point(Tensor self) -> int") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::q_per_channel_scales(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &), + &ATenMLIRTypeDefault::q_per_channel_scales>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::q_per_channel_zero_points(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &), + &ATenMLIRTypeDefault::q_per_channel_zero_points>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::int_repr(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_per_tensor_affine_qtensor(Tensor self, float " + "scale, int zero_point) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, double, int64_t), + &ATenMLIRTypeDefault::_per_tensor_affine_qtensor>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_per_channel_affine_qtensor(Tensor self, Tensor " + "scale, Tensor zero_point, int[] axis) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::_per_channel_affine_qtensor>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::qscheme(Tensor self) -> QScheme") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fake_quantize_per_tensor_affine(Tensor self, " + "float scale, int zero_point, int quant_min, int " + "quant_max) -> Tensor") + 
.impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, double, int64_t, int64_t, + int64_t), + &ATenMLIRTypeDefault::fake_quantize_per_tensor_affine>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fake_quantize_per_tensor_affine_backward(" + "Tensor grad, Tensor self, float scale, int " + "zero_point, int quant_min, int quant_max) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, double, + int64_t, int64_t, int64_t), + &ATenMLIRTypeDefault:: + fake_quantize_per_tensor_affine_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::to.dtype_layout(Tensor self, *, ScalarType dtype, " + "Layout layout, Device device, bool pin_memory=False, " + "bool non_blocking=False, bool copy=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::to.device(Tensor self, Device device, " + "ScalarType dtype, bool non_blocking=False, bool " + "copy=False) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, c10::Device, + at::ScalarType, bool, bool), + &ATenMLIRType::to>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::to.dtype(Tensor self, ScalarType dtype, bool " + "non_blocking=False, bool copy=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::to.other(Tensor self, Tensor other, bool " + "non_blocking=False, bool copy=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::meshgrid(Tensor[] tensors) -> Tensor[]") + .impl_unboxedOnlyKernel( + at::TensorList), + &ATenMLIRTypeDefault::meshgrid>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cartesian_prod(Tensor[] tensors) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::combinations(Tensor self, int r=2, bool " + "with_replacement=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::item(Tensor self) -> Scalar") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_local_scalar_dense(Tensor self) -> Scalar") + .impl_unboxedOnlyKernel< + at::Scalar(const at::Tensor &), + &ATenMLIRTypeDefault::_local_scalar_dense>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_thnn_fused_lstm_cell(Tensor input_gates, Tensor " + "hidden_gates, Tensor cx, Tensor? input_bias=None, " + "Tensor? 
hidden_bias=None) -> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::_thnn_fused_lstm_cell>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_thnn_fused_lstm_cell_backward(Tensor? " + "grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, " + "Tensor workspace, bool has_bias) -> (Tensor, " + "Tensor, Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, bool), + &ATenMLIRTypeDefault::_thnn_fused_lstm_cell_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_thnn_fused_gru_cell(Tensor input_gates, Tensor " + "hidden_gates, Tensor hx, Tensor? input_bias=None, " + "Tensor? hidden_bias=None) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::_thnn_fused_gru_cell>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_thnn_fused_gru_cell_backward(Tensor grad_hy, " + "Tensor workspace, bool has_bias) -> (Tensor, " + "Tensor, Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + const at::Tensor &, bool), + &ATenMLIRTypeDefault::_thnn_fused_gru_cell_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lstm.input(Tensor input, Tensor[] hx, " + "Tensor[] params, bool has_biases, int num_layers, " + "float dropout, bool train, bool bidirectional, bool " + "batch_first) -> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, at::TensorList, at::TensorList, + bool, int64_t, double, bool, bool, bool), + &ATenMLIRTypeDefault::lstm>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lstm.data(Tensor data, Tensor batch_sizes, " + "Tensor[] hx, Tensor[] params, bool has_biases, int " + "num_layers, float dropout, bool train, bool " + "bidirectional) -> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + at::TensorList, at::TensorList, bool, int64_t, double, + bool, bool), + &ATenMLIRTypeDefault::lstm>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::gru.input(Tensor input, Tensor hx, Tensor[] " + "params, bool has_biases, int num_layers, float " + "dropout, bool train, bool bidirectional, bool " + "batch_first) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + const at::Tensor &, + at::TensorList, bool, int64_t, + double, bool, bool, bool), + &ATenMLIRTypeDefault::gru>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::gru.data(Tensor data, Tensor batch_sizes, " + "Tensor hx, Tensor[] params, bool has_biases, int " + "num_layers, float dropout, bool train, bool " + 
"bidirectional) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + at::TensorList, bool, int64_t, + double, bool, bool), + &ATenMLIRTypeDefault::gru>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rnn_tanh.input(Tensor input, Tensor hx, " + "Tensor[] params, bool has_biases, int num_layers, " + "float dropout, bool train, bool bidirectional, bool " + "batch_first) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + const at::Tensor &, + at::TensorList, bool, int64_t, + double, bool, bool, bool), + &ATenMLIRTypeDefault::rnn_tanh>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rnn_tanh.data(Tensor data, Tensor " + "batch_sizes, Tensor hx, Tensor[] params, bool " + "has_biases, int num_layers, float dropout, bool " + "train, bool bidirectional) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + at::TensorList, bool, int64_t, + double, bool, bool), + &ATenMLIRTypeDefault::rnn_tanh>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rnn_relu.input(Tensor input, Tensor hx, " + "Tensor[] params, bool has_biases, int num_layers, " + "float dropout, bool train, bool bidirectional, bool " + "batch_first) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + const at::Tensor &, + at::TensorList, bool, int64_t, + double, bool, bool, bool), + &ATenMLIRTypeDefault::rnn_relu>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rnn_relu.data(Tensor data, Tensor " + "batch_sizes, Tensor hx, Tensor[] params, bool " + "has_biases, int num_layers, float dropout, bool " + "train, bool bidirectional) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + at::TensorList, bool, int64_t, + double, bool, bool), + &ATenMLIRTypeDefault::rnn_relu>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lstm_cell(Tensor input, Tensor[] hx, Tensor " + "w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? " + "b_hh=None) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, + at::TensorList, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::lstm_cell>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::gru_cell(Tensor input, Tensor hx, Tensor " + "w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? " + "b_hh=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::gru_cell>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rnn_tanh_cell(Tensor input, Tensor hx, Tensor " + "w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? 
" + "b_hh=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::rnn_tanh_cell>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rnn_relu_cell(Tensor input, Tensor hx, Tensor " + "w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? " + "b_hh=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::rnn_relu_cell>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::quantized_lstm(Tensor input, Tensor[] hx, " + "Tensor[] params, bool has_biases, int num_layers, " + "float dropout, bool train, bool bidirectional, bool " + "batch_first, *, ScalarType? dtype=None, bool " + "use_dynamic=False) -> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, at::TensorList, at::TensorList, + bool, int64_t, double, bool, bool, bool, + c10::optional, bool), + &ATenMLIRTypeDefault::quantized_lstm>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::quantized_gru.input(Tensor input, Tensor hx, " + "Tensor[] params, bool has_biases, int num_layers, " + "float dropout, bool train, bool bidirectional, bool " + "batch_first) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + const at::Tensor &, + at::TensorList, bool, int64_t, + double, bool, bool, bool), + &ATenMLIRTypeDefault::quantized_gru>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::quantized_gru.data(Tensor data, Tensor " + "batch_sizes, Tensor hx, Tensor[] params, bool " + "has_biases, int num_layers, float dropout, bool " + "train, bool bidirectional) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + at::TensorList, bool, int64_t, + double, bool, bool), + &ATenMLIRTypeDefault::quantized_gru>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::quantized_lstm_cell(Tensor input, Tensor[] " + "hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor " + "b_hh, Tensor packed_ih, Tensor packed_hh, Tensor " + "col_offsets_ih, Tensor col_offsets_hh, Scalar " + "scale_ih, Scalar scale_hh, Scalar zero_point_ih, " + "Scalar zero_point_hh) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, + at::TensorList, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar, + at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::quantized_lstm_cell>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::quantized_gru_cell(Tensor input, Tensor hx, " + "Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, " + "Tensor packed_ih, Tensor packed_hh, Tensor " + "col_offsets_ih, Tensor col_offsets_hh, Scalar " + "scale_ih, Scalar scale_hh, Scalar 
zero_point_ih, " + "Scalar zero_point_hh) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::Scalar, at::Scalar, at::Scalar, + at::Scalar), + &ATenMLIRTypeDefault::quantized_gru_cell>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::quantized_rnn_relu_cell(Tensor input, Tensor " + "hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor " + "b_hh, Tensor packed_ih, Tensor packed_hh, Tensor " + "col_offsets_ih, Tensor col_offsets_hh, Scalar " + "scale_ih, Scalar scale_hh, Scalar zero_point_ih, " + "Scalar zero_point_hh) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::Scalar, at::Scalar, at::Scalar, + at::Scalar), + &ATenMLIRTypeDefault::quantized_rnn_relu_cell>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::quantized_rnn_tanh_cell(Tensor input, Tensor " + "hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor " + "b_hh, Tensor packed_ih, Tensor packed_hh, Tensor " + "col_offsets_ih, Tensor col_offsets_hh, Scalar " + "scale_ih, Scalar scale_hh, Scalar zero_point_ih, " + "Scalar zero_point_hh) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::Scalar, at::Scalar, at::Scalar, + at::Scalar), + &ATenMLIRTypeDefault::quantized_rnn_tanh_cell>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_pack_padded_sequence(Tensor input, Tensor " + "lengths, bool batch_first) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + const at::Tensor &, + bool), + &ATenMLIRTypeDefault::_pack_padded_sequence>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_pack_padded_sequence_backward(Tensor grad, " + "int[] input_size, Tensor batch_sizes, bool " + "batch_first) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + const at::Tensor &, bool), + &ATenMLIRTypeDefault::_pack_padded_sequence_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_pad_packed_sequence(Tensor data, Tensor " + "batch_sizes, bool batch_first, Scalar padding_value, " + "int total_length) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + const at::Tensor &, + bool, at::Scalar, + int64_t), + &ATenMLIRTypeDefault::_pad_packed_sequence>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::set_.source_Storage(Tensor(a!) 
self, Storage " + "source) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, at::Storage), + &ATenMLIRTypeDefault::set_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::set_.source_Storage_storage_offset(Tensor(a!) " + "self, Storage source, int storage_offset, int[] " + "size, int[] stride=[]) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, at::Storage, int64_t, + at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::set_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::set_.source_Tensor(Tensor(a!) self, Tensor " + "source) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::set_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::set_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::set_quantizer_(Tensor(a!) self, " + "ConstQuantizerPtr quantizer) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::is_set_to(Tensor self, Tensor tensor) -> bool") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::masked_fill_.Scalar(Tensor(a!) self, Tensor " + "mask, Scalar value) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::masked_fill.Scalar(Tensor self, Tensor mask, " + "Scalar value) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::masked_fill_.Tensor(Tensor(a!) self, Tensor " + "mask, Tensor value) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::masked_fill.Tensor(Tensor self, Tensor mask, " + "Tensor value) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::masked_scatter_(Tensor(a!) 
self, Tensor mask, " + "Tensor source) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::masked_scatter_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::masked_scatter(Tensor self, Tensor mask, " + "Tensor source) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::view(Tensor(a) self, int[] size) -> Tensor(a)") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef), + &ATenMLIRType::view>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::put_(Tensor(a!) self, Tensor index, Tensor " + "source, bool accumulate=False) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, bool), + &ATenMLIRTypeDefault::put_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::index_add_(Tensor(a!) self, int dim, Tensor " + "index, Tensor source) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::index_add(Tensor self, int dim, Tensor index, " + "Tensor source) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t, + const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::index_add>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::index_fill_.Scalar(Tensor(a!) self, int dim, " + "Tensor index, Scalar value) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::index_fill.Scalar(Tensor self, int dim, " + "Tensor index, Scalar value) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::index_fill_.Tensor(Tensor(a!) self, int dim, " + "Tensor index, Tensor value) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::index_fill.Tensor(Tensor self, int dim, " + "Tensor index, Tensor value) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t, + const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::index_fill>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::scatter_.src(Tensor(a!) 
self, int dim, Tensor " + "index, Tensor src) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::scatter.src(Tensor self, int dim, Tensor " + "index, Tensor src) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t, + const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::scatter>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::scatter_.value(Tensor(a!) self, int dim, " + "Tensor index, Scalar value) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::scatter.value(Tensor self, int dim, Tensor " + "index, Scalar value) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::scatter_add_(Tensor(a!) self, int dim, Tensor " + "index, Tensor src) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::scatter_add(Tensor self, int dim, Tensor " + "index, Tensor src) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t, + const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::scatter_add>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lt_.Scalar(Tensor(a!) self, Scalar other) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::lt_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lt_.Tensor(Tensor(a!) self, Tensor other) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::lt_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::gt_.Scalar(Tensor(a!) self, Scalar other) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::gt_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::gt_.Tensor(Tensor(a!) self, Tensor other) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::gt_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::le_.Scalar(Tensor(a!) self, Scalar other) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::le_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::le_.Tensor(Tensor(a!) 
self, Tensor other) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::le_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::ge_.Scalar(Tensor(a!) self, Scalar other) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::ge_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::ge_.Tensor(Tensor(a!) self, Tensor other) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::ge_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::eq_.Scalar(Tensor(a!) self, Scalar other) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::eq_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::eq_.Tensor(Tensor(a!) self, Tensor other) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::eq_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::ne_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::ne_.Tensor(Tensor(a!) self, Tensor other) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::ne_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__and__.Scalar(Tensor self, Scalar other) -> " + "Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__and__.Tensor(Tensor self, Tensor other) -> " + "Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__iand__.Scalar(Tensor(a!) self, Scalar " + "other) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__iand__.Tensor(Tensor(a!) 
self, Tensor " + "other) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__or__.Scalar(Tensor self, Scalar other) -> " + "Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__or__.Tensor(Tensor self, Tensor other) -> " + "Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__ior__.Scalar(Tensor(a!) self, Scalar other) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__ior__.Tensor(Tensor(a!) self, Tensor other) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__xor__.Scalar(Tensor self, Scalar other) -> " + "Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__xor__.Tensor(Tensor self, Tensor other) -> " + "Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__ixor__.Scalar(Tensor(a!) self, Scalar " + "other) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__ixor__.Tensor(Tensor(a!) self, Tensor " + "other) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__lshift__.Scalar(Tensor self, Scalar other) " + "-> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__lshift__.Tensor(Tensor self, Tensor other) " + "-> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__ilshift__.Scalar(Tensor(a!) self, Scalar " + "other) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__ilshift__.Tensor(Tensor(a!) 
self, Tensor " + "other) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__rshift__.Scalar(Tensor self, Scalar other) " + "-> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__rshift__.Tensor(Tensor self, Tensor other) " + "-> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__irshift__.Scalar(Tensor(a!) self, Scalar " + "other) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::__irshift__.Tensor(Tensor(a!) self, Tensor " + "other) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lgamma_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::atan2_(Tensor(a!) self, Tensor other) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::tril_(Tensor(a!) self, int diagonal=0) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::triu_(Tensor(a!) self, int diagonal=0) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::digamma_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::polygamma_(Tensor(a!) self, int n) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::renorm_(Tensor(a!) self, Scalar p, int dim, " + "Scalar maxnorm) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::pow_.Scalar(Tensor(a!) self, Scalar exponent) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::pow_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::pow_.Tensor(Tensor(a!) self, Tensor exponent) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::pow_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lerp_.Scalar(Tensor(a!) 
self, Tensor end, " + "Scalar weight) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lerp_.Tensor(Tensor(a!) self, Tensor end, " + "Tensor weight) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fmod_.Scalar(Tensor(a!) self, Scalar other) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fmod_.Tensor(Tensor(a!) self, Tensor other) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::remainder_.Scalar(Tensor(a!) self, Scalar " + "other) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::remainder_.Tensor(Tensor(a!) self, Tensor " + "other) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::addbmm_(Tensor(a!) self, Tensor batch1, Tensor " + "batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::addbmm_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::addbmm.out(Tensor self, Tensor batch1, Tensor " + "batch2, *, Scalar beta=1, Scalar alpha=1, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::addbmm_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::addbmm(Tensor self, Tensor batch1, Tensor " + "batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::addbmm>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::addcdiv_(Tensor(a!) self, Tensor tensor1, " + "Tensor tensor2, *, Scalar value=1) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::random_.from(Tensor(a!) self, int from, int " + "to, *, Generator? generator=None) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::random_.to(Tensor(a!) self, int to, *, " + "Generator? generator=None) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::random_(Tensor(a!) 
self, *, Generator? " + "generator=None) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::uniform_(Tensor(a!) self, float from=0, float " + "to=1, *, Generator? generator=None) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::normal_(Tensor(a!) self, float mean=0, float " + "std=1, *, Generator? generator=None) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::cauchy_(Tensor(a!) self, float median=0, float " + "sigma=1, *, Generator? generator=None) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::log_normal_(Tensor(a!) self, float mean=1, float " + "std=2, *, Generator? generator=None) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::exponential_(Tensor(a!) self, float lambd=1, " + "*, Generator? generator=None) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::geometric_(Tensor(a!) self, float p, *, " + "Generator? generator=None) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::diag.out(Tensor self, int diagonal=0, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::diag_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::diag(Tensor self, int diagonal=0) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::diag>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cross.out(Tensor self, Tensor other, int? " + "dim=None, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, c10::optional), + &ATenMLIRTypeDefault::cross_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cross(Tensor self, Tensor other, int? " + "dim=None) -> Tensor") + .impl_unboxedOnlyKernel), + &ATenMLIRTypeDefault::cross>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::triu.out(Tensor self, int diagonal=0, *, " + "Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::triu_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::triu(Tensor self, int diagonal=0) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::triu>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::tril.out(Tensor self, int diagonal=0, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::tril_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::tril(Tensor self, int diagonal=0) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::tril>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::tril_indices(int row, int col, int offset=0, *, " + "ScalarType? dtype=long, Layout? layout=None, Device? " + "device=None, bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::triu_indices(int row, int col, int offset=0, *, " + "ScalarType? dtype=long, Layout? layout=None, Device? " + "device=None, bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::trace(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::ne.Scalar_out(Tensor self, Scalar other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::ne.Scalar(Tensor self, Scalar other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::ne>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::ne.Tensor_out(Tensor self, Tensor other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::ne.Tensor(Tensor self, Tensor other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::ne>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::eq.Scalar_out(Tensor self, Scalar other, *, " + "Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::eq.Scalar(Tensor self, Scalar other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::eq>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::eq.Tensor_out(Tensor self, Tensor other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::eq.Tensor(Tensor self, Tensor other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::eq>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::ge.Scalar_out(Tensor self, Scalar other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::ge.Scalar(Tensor self, Scalar other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::ge>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::ge.Tensor_out(Tensor self, Tensor other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::ge.Tensor(Tensor self, Tensor other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::ge>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::le.Scalar_out(Tensor self, Scalar other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::le.Scalar(Tensor self, Scalar other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::le>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::le.Tensor_out(Tensor self, Tensor other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::le.Tensor(Tensor self, Tensor other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::le>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::gt.Scalar_out(Tensor self, Scalar other, *, " + "Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::gt.Scalar(Tensor self, Scalar other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::gt>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::gt.Tensor_out(Tensor self, Tensor other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::gt.Tensor(Tensor self, Tensor other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::gt>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lt.Scalar_out(Tensor self, Scalar other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::lt.Scalar(Tensor self, Scalar other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::lt>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lt.Tensor_out(Tensor self, Tensor other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::lt.Tensor(Tensor self, Tensor other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::lt>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::take.out(Tensor self, Tensor index, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::take(Tensor self, Tensor index) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::take>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::index_select.out(Tensor self, int dim, Tensor " + "index, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, int64_t, + const at::Tensor &), + &ATenMLIRTypeDefault::index_select_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::index_select(Tensor self, int dim, Tensor " + "index) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::masked_select.out(Tensor self, Tensor mask, " + "*, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::masked_select_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::masked_select(Tensor self, Tensor mask) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::nonzero.out(Tensor self, *, Tensor(a!) out) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::nonzero(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::nonzero_numpy(Tensor self) -> Tensor[]") + .impl_unboxedOnlyKernel( + const at::Tensor &), + &ATenMLIRTypeDefault::nonzero_numpy>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::gather.out(Tensor self, int dim, Tensor index, *, " + "bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, int64_t, + const at::Tensor &, bool), + &ATenMLIRTypeDefault::gather_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::gather(Tensor self, int dim, Tensor index, *, " + "bool sparse_grad=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_gather_sparse_backward(Tensor self, int dim, " + "Tensor index, Tensor grad) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t, + const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::_gather_sparse_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::addcmul.out(Tensor self, Tensor tensor1, " + "Tensor tensor2, *, Scalar value=1, Tensor(a!) out) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::addcmul(Tensor self, Tensor tensor1, Tensor " + "tensor2, *, Scalar value=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::addcmul>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::addcmul_(Tensor(a!) self, Tensor tensor1, " + "Tensor tensor2, *, Scalar value=1) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::addcdiv.out(Tensor self, Tensor tensor1, " + "Tensor tensor2, *, Scalar value=1, Tensor(a!) 
out) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::addcdiv(Tensor self, Tensor tensor1, Tensor " + "tensor2, *, Scalar value=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::addcdiv>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, " + "Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::lstsq_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lstsq(Tensor self, Tensor A) -> (Tensor " + "solution, Tensor QR)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::lstsq>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::triangular_solve.X(Tensor self, Tensor A, bool " + "upper=True, bool transpose=False, bool " + "unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> " + "(Tensor(a!) solution, Tensor(b!) cloned_coefficient)") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + const at::Tensor &, bool, bool, bool), + &ATenMLIRTypeDefault::triangular_solve_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::triangular_solve(Tensor self, Tensor A, bool " + "upper=True, bool transpose=False, bool " + "unitriangular=False) -> (Tensor solution, Tensor " + "cloned_coefficient)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + const at::Tensor &, + bool, bool, bool), + &ATenMLIRTypeDefault::triangular_solve>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_triangular_solve_helper(Tensor self, Tensor " + "A, bool upper, bool transpose, bool unitriangular) " + "-> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + const at::Tensor &, + bool, bool, bool), + &ATenMLIRTypeDefault::_triangular_solve_helper>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::symeig.e(Tensor self, bool eigenvectors=False, " + "bool upper=True, *, Tensor(a!) e, Tensor(b!) V) -> " + "(Tensor(a!) eigenvalues, Tensor(b!) 
eigenvectors)") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, bool, + bool), + &ATenMLIRTypeDefault::symeig_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::symeig(Tensor self, bool eigenvectors=False, " + "bool upper=True) -> (Tensor eigenvalues, Tensor " + "eigenvectors)") + .impl_unboxedOnlyKernel( + const at::Tensor &, bool, bool), + &ATenMLIRTypeDefault::symeig>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_symeig_helper(Tensor self, bool " + "eigenvectors, bool upper) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, bool, bool), + &ATenMLIRTypeDefault::_symeig_helper>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::eig.e(Tensor self, bool eigenvectors=False, " + "*, Tensor(a!) e, Tensor(b!) v) -> (Tensor(a!) " + "eigenvalues, Tensor(b!) eigenvectors)") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, bool), + &ATenMLIRTypeDefault::eig_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::eig(Tensor self, bool eigenvectors=False) -> " + "(Tensor eigenvalues, Tensor eigenvectors)") + .impl_unboxedOnlyKernel( + const at::Tensor &, bool), + &ATenMLIRTypeDefault::eig>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::svd.U(Tensor self, bool some=True, bool " + "compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, " + "Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, " + "Tensor(c!) V)") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, at::Tensor &, + const at::Tensor &, bool, bool), + &ATenMLIRTypeDefault::svd_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::svd(Tensor self, bool some=True, bool " + "compute_uv=True) -> (Tensor U, Tensor S, Tensor V)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, bool, bool), + &ATenMLIRTypeDefault::svd>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_svd_helper(Tensor self, bool some, bool " + "compute_uv) -> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, bool, bool), + &ATenMLIRTypeDefault::_svd_helper>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cholesky.out(Tensor self, bool upper=False, " + "*, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, bool), + &ATenMLIRTypeDefault::cholesky_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::cholesky(Tensor self, bool upper=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cholesky_helper(Tensor self, bool upper) -> " + "Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, bool), + &ATenMLIRTypeDefault::_cholesky_helper>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::cholesky_solve.out(Tensor self, Tensor input2, " + "bool upper=False, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, bool), + &ATenMLIRTypeDefault::cholesky_solve_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cholesky_solve(Tensor self, Tensor input2, " + "bool upper=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cholesky_solve_helper(Tensor self, Tensor A, " + "bool upper) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, bool), + &ATenMLIRTypeDefault::_cholesky_solve_helper>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::solve(Tensor self, Tensor A) -> (Tensor " + "solution, Tensor LU)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::solve>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::solve.solution(Tensor self, Tensor A, *, " + "Tensor(a!) solution, Tensor(b!) lu) -> (Tensor(a!) " + "solution, Tensor(b!) LU)") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::solve_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_solve_helper(Tensor self, Tensor A) -> " + "(Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::_solve_helper>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cholesky_inverse.out(Tensor self, bool " + "upper=False, *, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, bool), + &ATenMLIRTypeDefault::cholesky_inverse_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::cholesky_inverse(Tensor self, bool " + "upper=False) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, bool), + &ATenMLIRTypeDefault::cholesky_inverse>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::qr.Q(Tensor self, bool some=True, *, Tensor(a!) " + "Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R)") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, bool), + &ATenMLIRTypeDefault::qr_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::qr(Tensor self, bool some=True) -> (Tensor Q, " + "Tensor R)") + .impl_unboxedOnlyKernel( + const at::Tensor &, bool), + &ATenMLIRTypeDefault::qr>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_qr_helper(Tensor self, bool some) -> " + "(Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, bool), + &ATenMLIRTypeDefault::_qr_helper>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::geqrf.a(Tensor self, *, Tensor(a!) a, " + "Tensor(b!) tau) -> (Tensor(a!) a, Tensor(b!) tau)") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::geqrf_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::geqrf(Tensor self) -> (Tensor a, Tensor tau)") + .impl_unboxedOnlyKernel( + const at::Tensor &), + &ATenMLIRTypeDefault::geqrf>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::orgqr.out(Tensor self, Tensor input2, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::orgqr(Tensor self, Tensor input2) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::ormqr.out(Tensor self, Tensor input2, Tensor " + "input3, bool left=True, bool transpose=False, *, " + "Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::ormqr(Tensor self, Tensor input2, Tensor input3, " + "bool left=True, bool transpose=False) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, bool, bool), + &ATenMLIRTypeDefault::ormqr>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_lu_with_info(Tensor self, bool pivot=True, " + "bool check_errors=True) -> (Tensor, Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, bool, bool), + &ATenMLIRTypeDefault::_lu_with_info>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lu_solve.out(Tensor self, Tensor LU_data, " + "Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::lu_solve_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lu_solve(Tensor self, Tensor LU_data, Tensor " + "LU_pivots) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_lu_solve_helper(Tensor self, Tensor LU_data, " + "Tensor LU_pivots) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::_lu_solve_helper>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::multinomial.out(Tensor self, int num_samples, " + "bool replacement=False, *, Generator? " + "generator=None, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, int64_t, + bool, at::Generator *), + &ATenMLIRTypeDefault::multinomial_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::multinomial(Tensor self, int num_samples, " + "bool replacement=False, *, Generator? " + "generator=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_multinomial_alias_setup(Tensor probs) -> " + "(Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &), + &ATenMLIRTypeDefault::_multinomial_alias_setup>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_multinomial_alias_draw(Tensor J, Tensor q, int " + "num_samples, *, Generator? generator=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + int64_t, at::Generator *), + &ATenMLIRTypeDefault::_multinomial_alias_draw>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lgamma.out(Tensor self, *, Tensor(a!) 
out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lgamma(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::digamma.out(Tensor self, *, Tensor(a!) out) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::digamma(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::polygamma.out(int n, Tensor self, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::polygamma(int n, Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::erfinv(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::erfinv_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::erfinv.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sign(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sign_(Tensor(a!) self) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sign.out(Tensor self, *, Tensor(a!) out) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::dist(Tensor self, Tensor other, Scalar p=2) " + "-> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::atan2.out(Tensor self, Tensor other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::atan2(Tensor self, Tensor other) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lerp.Scalar_out(Tensor self, Tensor end, " + "Scalar weight, *, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lerp.Tensor_out(Tensor self, Tensor end, " + "Tensor weight, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::lerp_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lerp.Scalar(Tensor self, Tensor end, Scalar " + "weight) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::lerp.Tensor(Tensor self, Tensor end, Tensor " + "weight) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::histc.out(Tensor self, int bins=100, Scalar " + "min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::histc(Tensor self, int bins=100, Scalar " + "min=0, Scalar max=0) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fmod.Scalar_out(Tensor self, Scalar other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::fmod>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fmod.Tensor_out(Tensor self, Tensor other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::fmod>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::remainder.Scalar_out(Tensor self, Scalar " + "other, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::remainder.Scalar(Tensor self, Scalar other) " + "-> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::remainder.Tensor_out(Tensor self, Tensor " + "other, *, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::remainder.Tensor(Tensor self, Tensor other) " + "-> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::min.out(Tensor self, Tensor other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::min.other(Tensor self, Tensor other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::min>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::min(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::max.out(Tensor self, Tensor other, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::max.other(Tensor self, Tensor other) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::max>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::max(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::median(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sort.values(Tensor self, int dim=-1, bool " + "descending=False, *, Tensor(a!) values, Tensor(b!) " + "indices) -> (Tensor(a!) values, Tensor(b!) indices)") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + int64_t, bool), + &ATenMLIRTypeDefault::sort_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::sort(Tensor self, int dim=-1, bool " + "descending=False) -> (Tensor values, Tensor indices)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + int64_t, bool), + &ATenMLIRTypeDefault::sort>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::argsort(Tensor self, int dim=-1, bool " + "descending=False) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::topk.values(Tensor self, int k, int dim=-1, " + "bool largest=True, bool sorted=True, *, Tensor(a!) " + "values, Tensor(b!) indices) ->(Tensor(a!) values, " + "Tensor(b!) 
indices)") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + int64_t, int64_t, bool, bool), + &ATenMLIRTypeDefault::topk_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::topk(Tensor self, int k, int dim=-1, bool " + "largest=True, bool sorted=True) -> (Tensor values, " + "Tensor indices)") + .impl_unboxedOnlyKernel( + const at::Tensor &, int64_t, + int64_t, bool, bool), + &ATenMLIRTypeDefault::topk>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::all(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::any(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::renorm.out(Tensor self, Scalar p, int dim, " + "Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::renorm(Tensor self, Scalar p, int dim, Scalar " + "maxnorm) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::unfold(Tensor(a) self, int dimension, int " + "size, int step) -> Tensor(a)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::equal(Tensor self, Tensor other) -> bool") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::pow.Tensor_Tensor_out(Tensor self, Tensor " + "exponent, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::pow.Tensor_Tensor(Tensor self, Tensor " + "exponent) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::pow>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::pow.Scalar_out(Scalar self, Tensor exponent, " + "*, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::pow.Scalar(Scalar self, Tensor exponent) -> " + "Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(at::Scalar, const at::Tensor &), + &ATenMLIRTypeDefault::pow>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::normal.Tensor_float_out(Tensor mean, float " + "std=1, *, Generator? generator=None, Tensor(a!) 
" + "out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::normal.Tensor_float(Tensor mean, float std=1, " + "*, Generator? generator=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::normal.float_Tensor_out(float mean, Tensor " + "std, *, Generator? generator=None, Tensor(a!) out) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::normal.float_Tensor(float mean, Tensor std, " + "*, Generator? generator=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::normal.Tensor_Tensor_out(Tensor mean, Tensor " + "std, *, Generator? generator=None, Tensor(a!) out) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Generator *), + &ATenMLIRTypeDefault::normal_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::normal.Tensor_Tensor(Tensor mean, Tensor std, " + "*, Generator? generator=None) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::normal.float_float(float mean, float std, int[] " + "size, *, Generator? generator=None, ScalarType? " + "dtype=None, Layout? layout=None, Device? device=None, " + "bool? pin_memory=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(double, double, at::IntArrayRef, + at::Generator *, const at::TensorOptions &), + &ATenMLIRTypeDefault::normal>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::normal.float_float_out(float mean, float std, " + "int[] size, *, Generator? generator=None, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::alias(Tensor(a) self) -> Tensor(a)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_addr(Tensor self, Tensor vec1, Tensor vec2, " + "*, Scalar beta=1, Scalar alpha=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::_addr>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_addr_(Tensor(a!) 
self, Tensor vec1, Tensor vec2, " + "*, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::_addr_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_addr.out(Tensor self, Tensor vec1, Tensor " + "vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) " + "out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::_addr_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_index_copy_(Tensor(a!) self, int dim, Tensor " + "index, Tensor source) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cumsum(Tensor self, int dim) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cumsum.out(Tensor self, int dim, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::_cumsum_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cumprod(Tensor self, int dim) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cumprod.out(Tensor self, int dim, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::_cumprod_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_var(Tensor self, bool unbiased=True) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_std(Tensor self, bool unbiased=True) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_addmm.out(Tensor self, Tensor mat1, Tensor " + "mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) 
" + "out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::_addmm_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_addmm(Tensor self, Tensor mat1, Tensor mat2, " + "*, Scalar beta=1, Scalar alpha=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::_addmm>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::_addmm_(Tensor(a!) self, Tensor mat1, Tensor " + "mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::_addmm_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cat(Tensor[] tensors, int dim=0) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_cat.out(Tensor[] tensors, int dim=0, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_mode(Tensor self, int dim=-1, bool " + "keepdim=False) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel( + const at::Tensor &, int64_t, + bool), + &ATenMLIRTypeDefault::_mode>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_mode.values(Tensor self, int dim=-1, bool " + "keepdim=False, *, Tensor(a!) values, Tensor(b!) " + "indices) -> (Tensor(a!), Tensor(b!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + int64_t, bool), + &ATenMLIRTypeDefault::_mode_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_max(Tensor self, int dim, bool " + "keepdim=False) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + int64_t, bool), + &ATenMLIRTypeDefault::_max>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_max.max(Tensor self, int dim, bool " + "keepdim=False, *, Tensor(a!) max, Tensor(b!) " + "max_indices) -> (Tensor(a!), Tensor(b!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + int64_t, bool), + &ATenMLIRTypeDefault::_max_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_min(Tensor self, int dim, bool " + "keepdim=False) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + int64_t, bool), + &ATenMLIRTypeDefault::_min>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_min.min(Tensor self, int dim, bool " + "keepdim=False, *, Tensor(a!) 
min, Tensor(b!) " + "min_indices) -> (Tensor(a!), Tensor(b!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + int64_t, bool), + &ATenMLIRTypeDefault::_min_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::binary_cross_entropy.out(Tensor self, Tensor " + "target, Tensor? weight=None, int reduction=Mean, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + int64_t), + &ATenMLIRTypeDefault::binary_cross_entropy_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::binary_cross_entropy(Tensor self, Tensor target, " + "Tensor? weight=None, int reduction=Mean) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::binary_cross_entropy>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::binary_cross_entropy_backward.grad_input(" + "Tensor grad_output, Tensor self, Tensor target, " + "Tensor? weight=None, int reduction=Mean, *, " + "Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::binary_cross_entropy_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::binary_cross_entropy_backward(Tensor " + "grad_output, Tensor self, Tensor target, Tensor? " + "weight=None, int reduction=Mean) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + int64_t), + &ATenMLIRTypeDefault::binary_cross_entropy_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mse_loss.out(Tensor self, Tensor target, int " + "reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mse_loss(Tensor self, Tensor target, int " + "reduction=Mean) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mse_loss_backward.grad_input(Tensor " + "grad_output, Tensor self, Tensor target, int " + "reduction, *, Tensor(a!) 
grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + int64_t), + &ATenMLIRTypeDefault::mse_loss_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mse_loss_backward(Tensor grad_output, Tensor " + "self, Tensor target, int reduction) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::mse_loss_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::l1_loss.out(Tensor self, Tensor target, int " + "reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::l1_loss(Tensor self, Tensor target, int " + "reduction=Mean) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::l1_loss_backward.grad_input(Tensor " + "grad_output, Tensor self, Tensor target, int " + "reduction, *, Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + int64_t), + &ATenMLIRTypeDefault::l1_loss_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::l1_loss_backward(Tensor grad_output, Tensor " + "self, Tensor target, int reduction) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::l1_loss_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::multi_margin_loss.out(Tensor self, Tensor target, " + "Scalar p=1, Scalar margin=1, Tensor? weight=None, int " + "reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar, + const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::multi_margin_loss_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::multi_margin_loss(Tensor self, Tensor target, " + "Scalar p=1, Scalar margin=1, Tensor? weight=None, " + "int reduction=Mean) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::Scalar, at::Scalar, const at::Tensor &, + int64_t), + &ATenMLIRTypeDefault::multi_margin_loss>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::multi_margin_loss_backward.grad_input(Tensor " + "grad_output, Tensor self, Tensor target, Scalar p, " + "Scalar margin, Tensor? weight=None, int reduction=Mean, " + "*, Tensor(a!) 
grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::Scalar, at::Scalar, const at::Tensor &, + int64_t), + &ATenMLIRTypeDefault::multi_margin_loss_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::multi_margin_loss_backward(Tensor grad_output, " + "Tensor self, Tensor target, Scalar p, Scalar margin, " + "Tensor? weight=None, int reduction=Mean) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar, + const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::multi_margin_loss_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::multilabel_margin_loss.out(Tensor self, " + "Tensor target, int reduction=Mean, *, Tensor(a!) " + "out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::multilabel_margin_loss_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::multilabel_margin_loss(Tensor self, Tensor " + "target, int reduction=Mean) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + int64_t), + &ATenMLIRTypeDefault::multilabel_margin_loss>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::multilabel_margin_loss_forward.output(Tensor " + "self, Tensor target, int reduction, *, Tensor(a!) " + "output, Tensor(b!) is_target) -> (Tensor(a!), " + "Tensor(b!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::multilabel_margin_loss_forward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::multilabel_margin_loss_forward(Tensor self, " + "Tensor target, int reduction) -> (Tensor output, " + "Tensor is_target)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + const at::Tensor &, + int64_t), + &ATenMLIRTypeDefault::multilabel_margin_loss_forward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::multilabel_margin_loss_backward.grad_input(" + "Tensor grad_output, Tensor self, Tensor target, int " + "reduction, Tensor is_target, *, Tensor(a!) 
" + "grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + int64_t, const at::Tensor &), + &ATenMLIRTypeDefault:: + multilabel_margin_loss_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::multilabel_margin_loss_backward(Tensor " + "grad_output, Tensor self, Tensor target, int " + "reduction, Tensor is_target) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t, + const at::Tensor &), + &ATenMLIRTypeDefault::multilabel_margin_loss_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::nll_loss.out(Tensor self, Tensor target, " + "Tensor? weight=None, int reduction=Mean, int " + "ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + int64_t, int64_t), + &ATenMLIRTypeDefault::nll_loss_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::nll_loss(Tensor self, Tensor target, Tensor? " + "weight=None, int reduction=Mean, int " + "ignore_index=-100) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t, int64_t), + &ATenMLIRTypeDefault::nll_loss>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::nll_loss_forward.output(Tensor self, Tensor " + "target, Tensor? weight, int reduction, int " + "ignore_index, *, Tensor(a!) output, Tensor(b!) " + "total_weight) -> (Tensor(a!), Tensor(b!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, int64_t, + int64_t), + &ATenMLIRTypeDefault::nll_loss_forward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::nll_loss_forward(Tensor self, Tensor target, " + "Tensor? weight, int reduction, int ignore_index) -> " + "(Tensor output, Tensor total_weight)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + const at::Tensor &, + const at::Tensor &, int64_t, + int64_t), + &ATenMLIRType::nll_loss_forward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::nll_loss_backward.grad_input(Tensor grad_output, " + "Tensor self, Tensor target, Tensor? weight, int " + "reduction, int ignore_index, Tensor total_weight, *, " + "Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t, int64_t, + const at::Tensor &), + &ATenMLIRTypeDefault::nll_loss_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::nll_loss_backward(Tensor grad_output, Tensor " + "self, Tensor target, Tensor? 
weight, int reduction, " + "int ignore_index, Tensor total_weight) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + int64_t, int64_t, const at::Tensor &), + &ATenMLIRType::nll_loss_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::nll_loss2d.out(Tensor self, Tensor target, " + "Tensor? weight=None, int reduction=Mean, int " + "ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + int64_t, int64_t), + &ATenMLIRTypeDefault::nll_loss2d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::nll_loss2d(Tensor self, Tensor target, " + "Tensor? weight=None, int reduction=Mean, int " + "ignore_index=-100) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t, int64_t), + &ATenMLIRTypeDefault::nll_loss2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::nll_loss2d_forward.output(Tensor self, Tensor " + "target, Tensor? weight, int reduction, int " + "ignore_index, *, Tensor(a!) output, Tensor(b!) " + "total_weight) -> (Tensor(a!), Tensor(b!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, int64_t, + int64_t), + &ATenMLIRTypeDefault::nll_loss2d_forward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::nll_loss2d_forward(Tensor self, Tensor target, " + "Tensor? weight, int reduction, int ignore_index) -> " + "(Tensor output, Tensor total_weight)") + .impl_unboxedOnlyKernel( + const at::Tensor &, + const at::Tensor &, + const at::Tensor &, int64_t, + int64_t), + &ATenMLIRType::nll_loss2d_forward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::nll_loss2d_backward.grad_input(Tensor " + "grad_output, Tensor self, Tensor target, Tensor? " + "weight, int reduction, int ignore_index, Tensor " + "total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t, int64_t, + const at::Tensor &), + &ATenMLIRTypeDefault::nll_loss2d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::nll_loss2d_backward(Tensor grad_output, Tensor " + "self, Tensor target, Tensor? weight, int reduction, int " + "ignore_index, Tensor total_weight) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + int64_t, int64_t, const at::Tensor &), + &ATenMLIRType::nll_loss2d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::smooth_l1_loss.out(Tensor self, Tensor target, " + "int reduction=Mean, *, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::smooth_l1_loss_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::smooth_l1_loss(Tensor self, Tensor target, " + "int reduction=Mean) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::smooth_l1_loss_backward.grad_input(Tensor " + "grad_output, Tensor self, Tensor target, int " + "reduction, *, Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + int64_t), + &ATenMLIRTypeDefault::smooth_l1_loss_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::smooth_l1_loss_backward(Tensor grad_output, " + "Tensor self, Tensor target, int reduction) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::smooth_l1_loss_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::soft_margin_loss.out(Tensor self, Tensor target, " + "int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::soft_margin_loss_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::soft_margin_loss(Tensor self, Tensor target, " + "int reduction=Mean) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + int64_t), + &ATenMLIRTypeDefault::soft_margin_loss>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::soft_margin_loss_backward.grad_input(Tensor " + "grad_output, Tensor self, Tensor target, int " + "reduction, *, Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + int64_t), + &ATenMLIRTypeDefault::soft_margin_loss_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::soft_margin_loss_backward(Tensor grad_output, " + "Tensor self, Tensor target, int reduction) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::soft_margin_loss_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::elu.out(Tensor self, Scalar alpha=1, Scalar " + "scale=1, Scalar input_scale=1, *, Tensor(a!) 
out) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, at::Scalar, + at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::elu_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::elu(Tensor self, Scalar alpha=1, Scalar " + "scale=1, Scalar input_scale=1) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::elu_backward.grad_input(Tensor grad_output, " + "Scalar alpha, Scalar scale, Scalar input_scale, Tensor " + "output, *, Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, at::Scalar, + at::Scalar, at::Scalar, const at::Tensor &), + &ATenMLIRTypeDefault::elu_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::elu_backward(Tensor grad_output, Scalar " + "alpha, Scalar scale, Scalar input_scale, Tensor " + "output) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Scalar, at::Scalar, + at::Scalar, const at::Tensor &), + &ATenMLIRTypeDefault::elu_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::elu_(Tensor(a!) self, Scalar alpha=1, Scalar " + "scale=1, Scalar input_scale=1) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::glu.out(Tensor self, int dim=-1, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::glu_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::glu(Tensor self, int dim=-1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::glu>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::glu_backward.grad_input(Tensor grad_output, " + "Tensor self, int dim, *, Tensor(a!) grad_input) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, int64_t), + &ATenMLIRTypeDefault::glu_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::glu_backward(Tensor grad_output, Tensor self, " + "int dim) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::hardtanh.out(Tensor self, Scalar min_val=-1, " + "Scalar max_val=1, *, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::hardtanh(Tensor self, Scalar min_val=-1, " + "Scalar max_val=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRType::hardtanh>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::hardtanh_backward.grad_input(Tensor " + "grad_output, Tensor self, Scalar min_val, Scalar " + "max_val, *, Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRTypeDefault::hardtanh_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::hardtanh_backward(Tensor grad_output, Tensor " + "self, Scalar min_val, Scalar max_val) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::hardtanh_(Tensor(a!) self, Scalar min_val=-1, " + "Scalar max_val=1) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, at::Scalar, at::Scalar), + &ATenMLIRType::hardtanh_>(at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::leaky_relu.out(Tensor self, Scalar " + "negative_slope=0.01, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::leaky_relu(Tensor self, Scalar " + "negative_slope=0.01) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::leaky_relu_backward.grad_input(Tensor " + "grad_output, Tensor self, Scalar negative_slope, *, " + "Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::leaky_relu_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::leaky_relu_backward(Tensor grad_output, " + "Tensor self, Scalar negative_slope) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::Scalar), + &ATenMLIRTypeDefault::leaky_relu_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::leaky_relu_(Tensor(a!) self, Scalar " + "negative_slope=0.01) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log_sigmoid.out(Tensor self, *, Tensor(a!) 
" + "out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::log_sigmoid_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log_sigmoid(Tensor self) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log_sigmoid_forward.output(Tensor self, *, " + "Tensor(a!) output, Tensor(b!) buffer) -> " + "(Tensor(a!), Tensor(b!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::log_sigmoid_forward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log_sigmoid_forward(Tensor self) -> (Tensor " + "output, Tensor buffer)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &), + &ATenMLIRTypeDefault::log_sigmoid_forward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log_sigmoid_backward.grad_input(Tensor " + "grad_output, Tensor self, Tensor buffer, *, " + "Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::log_sigmoid_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::log_sigmoid_backward(Tensor grad_output, " + "Tensor self, Tensor buffer) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::log_sigmoid_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::rrelu_with_noise.out(Tensor self, Tensor noise, " + "Scalar lower=0.125, Scalar upper=0.3333333333333333, " + "bool training=False, Generator? generator=None, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar, + bool, at::Generator *), + &ATenMLIRTypeDefault::rrelu_with_noise_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rrelu_with_noise(Tensor self, Tensor noise, " + "Scalar lower=0.125, Scalar " + "upper=0.3333333333333333, bool training=False, " + "Generator? generator=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::Scalar, at::Scalar, bool, at::Generator *), + &ATenMLIRTypeDefault::rrelu_with_noise>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rrelu_with_noise_backward.grad_input(Tensor " + "grad_output, Tensor self, Tensor noise, Scalar " + "lower, Scalar upper, bool training, *, Tensor(a!) 
" + "grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::Scalar, at::Scalar, bool), + &ATenMLIRTypeDefault::rrelu_with_noise_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rrelu_with_noise_backward(Tensor grad_output, " + "Tensor self, Tensor noise, Scalar lower, Scalar " + "upper, bool training) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar, + bool), + &ATenMLIRTypeDefault::rrelu_with_noise_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::rrelu_with_noise_(Tensor(a!) self, Tensor " + "noise, Scalar lower=0.125, Scalar " + "upper=0.3333333333333333, bool training=False, " + "Generator? generator=None) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, at::Scalar, + at::Scalar, bool, at::Generator *), + &ATenMLIRTypeDefault::rrelu_with_noise_>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::softplus.out(Tensor self, Scalar beta=1, Scalar " + "threshold=20, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::softplus(Tensor self, Scalar beta=1, Scalar " + "threshold=20) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::softplus_backward.grad_input(Tensor grad_output, " + "Tensor self, Scalar beta, Scalar threshold, Tensor " + "output, *, Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar, at::Scalar, + const at::Tensor &), + &ATenMLIRTypeDefault::softplus_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::softplus_backward(Tensor grad_output, Tensor " + "self, Scalar beta, Scalar threshold, Tensor output) " + "-> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::Scalar, at::Scalar, const at::Tensor &), + &ATenMLIRTypeDefault::softplus_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::softshrink.out(Tensor self, Scalar lambd=0.5, " + "*, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::softshrink(Tensor self, Scalar lambd=0.5) -> " + "Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::softshrink_backward.grad_input(Tensor " + "grad_output, Tensor self, Scalar lambd, *, " + "Tensor(a!) 
grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::Scalar), + &ATenMLIRTypeDefault::softshrink_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::softshrink_backward(Tensor grad_output, " + "Tensor self, Scalar lambd) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::Scalar), + &ATenMLIRTypeDefault::softshrink_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::adaptive_avg_pool2d.out(Tensor self, int[2] " + "output_size, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::adaptive_avg_pool2d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::adaptive_avg_pool2d(Tensor self, int[2] " + "output_size) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::adaptive_avg_pool2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::mkldnn_adaptive_avg_pool2d(Tensor self, " + "int[2] output_size) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::mkldnn_adaptive_avg_pool2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_adaptive_avg_pool2d(Tensor self, int[2] " + "output_size) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::_adaptive_avg_pool2d_backward(Tensor " + "grad_output, Tensor self) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRType::_adaptive_avg_pool2d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::adaptive_avg_pool3d.out(Tensor self, int[3] " + "output_size, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::adaptive_avg_pool3d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::adaptive_avg_pool3d(Tensor self, int[3] " + "output_size) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::adaptive_avg_pool3d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::adaptive_avg_pool3d_backward.grad_input(" + "Tensor grad_output, Tensor self, *, Tensor(a!) 
" + "grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::adaptive_avg_pool3d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::adaptive_avg_pool3d_backward(Tensor " + "grad_output, Tensor self) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::adaptive_avg_pool3d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::adaptive_max_pool2d.out(Tensor self, int[2] " + "output_size, *, Tensor(a!) out, Tensor(b!) indices) " + "-> (Tensor(a!), Tensor(b!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::adaptive_max_pool2d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::adaptive_max_pool2d(Tensor self, int[2] " + "output_size) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::adaptive_max_pool2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::adaptive_max_pool2d_backward.grad_input(" + "Tensor grad_output, Tensor self, Tensor indices, *, " + "Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::adaptive_max_pool2d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::adaptive_max_pool2d_backward(Tensor " + "grad_output, Tensor self, Tensor indices) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::adaptive_max_pool2d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::adaptive_max_pool3d.out(Tensor self, int[3] " + "output_size, *, Tensor(a!) out, Tensor(b!) indices) " + "-> (Tensor(a!), Tensor(b!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::adaptive_max_pool3d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::adaptive_max_pool3d(Tensor self, int[3] " + "output_size) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple(const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::adaptive_max_pool3d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::adaptive_max_pool3d_backward.grad_input(" + "Tensor grad_output, Tensor self, Tensor indices, *, " + "Tensor(a!) 
grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::adaptive_max_pool3d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::adaptive_max_pool3d_backward(Tensor " + "grad_output, Tensor self, Tensor indices) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::adaptive_max_pool3d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::avg_pool2d.out(Tensor self, int[2] kernel_size, " + "int[2] stride=[], int[2] padding=0, bool " + "ceil_mode=False, bool count_include_pad=True, int? " + "divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel), + &ATenMLIRTypeDefault::avg_pool2d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::avg_pool2d(Tensor self, int[2] kernel_size, " + "int[2] stride=[], int[2] padding=0, bool " + "ceil_mode=False, bool count_include_pad=True, int? " + "divisor_override=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, bool, bool, + c10::optional), + &ATenMLIRTypeDefault::avg_pool2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::avg_pool2d_backward.grad_input(Tensor " + "grad_output, Tensor self, int[2] kernel_size, " + "int[2] stride, int[2] padding, bool ceil_mode, bool " + "count_include_pad, int? divisor_override, *, " + "Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, bool, bool, + c10::optional), + &ATenMLIRTypeDefault::avg_pool2d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::avg_pool2d_backward(Tensor grad_output, " + "Tensor self, int[2] kernel_size, int[2] stride, " + "int[2] padding, bool ceil_mode, bool " + "count_include_pad, int? divisor_override) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, bool, bool, + c10::optional), + &ATenMLIRTypeDefault::avg_pool2d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::avg_pool3d.out(Tensor self, int[3] kernel_size, " + "int[3] stride=[], int[3] padding=0, bool " + "ceil_mode=False, bool count_include_pad=True, int? " + "divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel), + &ATenMLIRTypeDefault::avg_pool3d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::avg_pool3d(Tensor self, int[3] kernel_size, " + "int[3] stride=[], int[3] padding=0, bool " + "ceil_mode=False, bool count_include_pad=True, int? 
" + "divisor_override=None) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, bool, bool, + c10::optional), + &ATenMLIRTypeDefault::avg_pool3d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::avg_pool3d_backward.grad_input(Tensor " + "grad_output, Tensor self, int[3] kernel_size, " + "int[3] stride, int[3] padding, bool ceil_mode, bool " + "count_include_pad, int? divisor_override, *, " + "Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, bool, bool, + c10::optional), + &ATenMLIRTypeDefault::avg_pool3d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::avg_pool3d_backward(Tensor grad_output, " + "Tensor self, int[3] kernel_size, int[3] stride, " + "int[3] padding, bool ceil_mode, bool " + "count_include_pad, int? divisor_override) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, bool, bool, + c10::optional), + &ATenMLIRTypeDefault::avg_pool3d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fractional_max_pool2d.output(Tensor self, " + "int[2] kernel_size, int[2] output_size, Tensor " + "random_samples, *, Tensor(a!) output, Tensor(b!) " + "indices) -> (Tensor(a!), Tensor(b!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, const at::Tensor &), + &ATenMLIRTypeDefault::fractional_max_pool2d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fractional_max_pool2d(Tensor self, int[2] " + "kernel_size, int[2] output_size, Tensor " + "random_samples) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, const at::Tensor &), + &ATenMLIRTypeDefault::fractional_max_pool2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fractional_max_pool2d_backward.grad_input(" + "Tensor grad_output, Tensor self, int[2] " + "kernel_size, int[2] output_size, Tensor indices, *, " + "Tensor(a!) 
grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, const at::Tensor &), + &ATenMLIRTypeDefault::fractional_max_pool2d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fractional_max_pool2d_backward(Tensor " + "grad_output, Tensor self, int[2] kernel_size, " + "int[2] output_size, Tensor indices) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + const at::Tensor &), + &ATenMLIRTypeDefault::fractional_max_pool2d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fractional_max_pool3d.output(Tensor self, " + "int[3] kernel_size, int[3] output_size, Tensor " + "random_samples, *, Tensor(a!) output, Tensor(b!) " + "indices) -> (Tensor(a!), Tensor(b!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, const at::Tensor &), + &ATenMLIRTypeDefault::fractional_max_pool3d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fractional_max_pool3d(Tensor self, int[3] " + "kernel_size, int[3] output_size, Tensor " + "random_samples) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, const at::Tensor &), + &ATenMLIRTypeDefault::fractional_max_pool3d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fractional_max_pool3d_backward.grad_input(" + "Tensor grad_output, Tensor self, int[3] " + "kernel_size, int[3] output_size, Tensor indices, *, " + "Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, const at::Tensor &), + &ATenMLIRTypeDefault::fractional_max_pool3d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::fractional_max_pool3d_backward(Tensor " + "grad_output, Tensor self, int[3] kernel_size, " + "int[3] output_size, Tensor indices) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + const at::Tensor &), + &ATenMLIRTypeDefault::fractional_max_pool3d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::max_pool2d_with_indices.out(Tensor self, int[2] " + "kernel_size, int[2] stride=[], int[2] padding=0, int[2] " + "dilation=1, bool ceil_mode=False, *, Tensor(a!) out, " + "Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::max_pool2d_with_indices_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::max_pool2d_with_indices(Tensor self, int[2] " + "kernel_size, int[2] stride=[], int[2] padding=0, int[2] " + "dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, bool), + &ATenMLIRType::max_pool2d_with_indices>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::max_pool2d_with_indices_backward.grad_input(" + "Tensor grad_output, Tensor self, int[2] " + "kernel_size, int[2] stride, int[2] padding, int[2] " + "dilation, bool ceil_mode, Tensor indices, *, " + "Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, bool, const at::Tensor &), + &ATenMLIRTypeDefault:: + max_pool2d_with_indices_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::max_pool2d_with_indices_backward(Tensor " + "grad_output, Tensor self, int[2] kernel_size, " + "int[2] stride, int[2] padding, int[2] dilation, " + "bool ceil_mode, Tensor indices) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, bool, + const at::Tensor &), + &ATenMLIRType::max_pool2d_with_indices_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::max_pool3d_with_indices.out(Tensor self, int[3] " + "kernel_size, int[3] stride=[], int[3] padding=0, int[3] " + "dilation=1, bool ceil_mode=False, *, Tensor(a!) out, " + "Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::max_pool3d_with_indices_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::max_pool3d_with_indices(Tensor self, int[3] " + "kernel_size, int[3] stride=[], int[3] padding=0, int[3] " + "dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::max_pool3d_with_indices>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::max_pool3d_with_indices_backward.grad_input(" + "Tensor grad_output, Tensor self, int[3] " + "kernel_size, int[3] stride, int[3] padding, int[3] " + "dilation, bool ceil_mode, Tensor indices, *, " + "Tensor(a!) 
grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, bool, const at::Tensor &), + &ATenMLIRTypeDefault:: + max_pool3d_with_indices_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::max_pool3d_with_indices_backward(Tensor " + "grad_output, Tensor self, int[3] kernel_size, " + "int[3] stride, int[3] padding, int[3] dilation, " + "bool ceil_mode, Tensor indices) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, bool, + const at::Tensor &), + &ATenMLIRTypeDefault::max_pool3d_with_indices_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::max_unpool2d.out(Tensor self, Tensor indices, " + "int[2] output_size, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::max_unpool2d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::max_unpool2d(Tensor self, Tensor indices, " + "int[2] output_size) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::max_unpool2d_backward.grad_input(Tensor " + "grad_output, Tensor self, Tensor indices, int[2] " + "output_size, *, Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::max_unpool2d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::max_unpool2d_backward(Tensor grad_output, Tensor " + "self, Tensor indices, int[2] output_size) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::max_unpool2d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::max_unpool3d.out(Tensor self, Tensor indices, " + "int[3] output_size, int[3] stride, int[3] padding, " + "*, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::max_unpool3d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::max_unpool3d(Tensor self, Tensor indices, int[3] " + "output_size, int[3] stride, int[3] padding) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::max_unpool3d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::max_unpool3d_backward.grad_input(Tensor " + "grad_output, Tensor self, Tensor indices, int[3] " + "output_size, int[3] stride, int[3] padding, *, " + "Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::max_unpool3d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::max_unpool3d_backward(Tensor grad_output, " + "Tensor self, Tensor indices, int[3] output_size, " + "int[3] stride, int[3] padding) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::max_unpool3d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::reflection_pad1d.out(Tensor self, int[2] " + "padding, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::reflection_pad1d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::reflection_pad1d(Tensor self, int[2] padding) " + "-> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::reflection_pad1d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::reflection_pad1d_backward.grad_input(Tensor " + "grad_output, Tensor self, int[2] padding, *, " + "Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::reflection_pad1d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::reflection_pad1d_backward(Tensor grad_output, " + "Tensor self, int[2] padding) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::reflection_pad1d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::reflection_pad2d.out(Tensor self, int[4] " + "padding, *, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::reflection_pad2d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::reflection_pad2d(Tensor self, int[4] padding) " + "-> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::reflection_pad2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::reflection_pad2d_backward.grad_input(Tensor " + "grad_output, Tensor self, int[4] padding, *, " + "Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::reflection_pad2d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::reflection_pad2d_backward(Tensor grad_output, " + "Tensor self, int[4] padding) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::reflection_pad2d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::replication_pad1d.out(Tensor self, int[2] " + "padding, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::replication_pad1d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::replication_pad1d(Tensor self, int[2] " + "padding) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::replication_pad1d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::replication_pad1d_backward.grad_input(Tensor " + "grad_output, Tensor self, int[2] padding, *, " + "Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::replication_pad1d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::replication_pad1d_backward(Tensor " + "grad_output, Tensor self, int[2] padding) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::replication_pad1d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::replication_pad2d.out(Tensor self, int[4] " + "padding, *, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::replication_pad2d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::replication_pad2d(Tensor self, int[4] " + "padding) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::replication_pad2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::replication_pad2d_backward.grad_input(Tensor " + "grad_output, Tensor self, int[4] padding, *, " + "Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::replication_pad2d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::replication_pad2d_backward(Tensor " + "grad_output, Tensor self, int[4] padding) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::replication_pad2d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::replication_pad3d.out(Tensor self, int[6] " + "padding, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::replication_pad3d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::replication_pad3d(Tensor self, int[6] " + "padding) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::replication_pad3d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::replication_pad3d_backward.grad_input(Tensor " + "grad_output, Tensor self, int[6] padding, *, " + "Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::replication_pad3d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::replication_pad3d_backward(Tensor " + "grad_output, Tensor self, int[6] padding) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::replication_pad3d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_linear1d.out(Tensor self, int[1] " + "output_size, bool align_corners, *, Tensor(a!) 
out) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::upsample_linear1d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_linear1d(Tensor self, int[1] " + "output_size, bool align_corners) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, bool), + &ATenMLIRTypeDefault::upsample_linear1d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_linear1d_backward.grad_input(Tensor " + "grad_output, int[1] output_size, int[3] input_size, " + "bool align_corners, *, Tensor(a!) grad_input) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, bool), + &ATenMLIRTypeDefault::upsample_linear1d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_linear1d_backward(Tensor " + "grad_output, int[1] output_size, int[3] input_size, " + "bool align_corners) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::upsample_linear1d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_bilinear2d.out(Tensor self, int[2] " + "output_size, bool align_corners, *, Tensor(a!) out) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::upsample_bilinear2d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_bilinear2d(Tensor self, int[2] " + "output_size, bool align_corners) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, bool), + &ATenMLIRTypeDefault::upsample_bilinear2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_bilinear2d_backward.grad_input(" + "Tensor grad_output, int[2] output_size, int[4] " + "input_size, bool align_corners, *, Tensor(a!) " + "grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, bool), + &ATenMLIRTypeDefault::upsample_bilinear2d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_bilinear2d_backward(Tensor " + "grad_output, int[2] output_size, int[4] input_size, " + "bool align_corners) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::upsample_bilinear2d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_bicubic2d.out(Tensor self, int[2] " + "output_size, bool align_corners, *, Tensor(a!) 
out) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::upsample_bicubic2d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_bicubic2d(Tensor self, int[2] " + "output_size, bool align_corners) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, bool), + &ATenMLIRTypeDefault::upsample_bicubic2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_bicubic2d_backward.grad_input(Tensor " + "grad_output, int[2] output_size, int[4] input_size, " + "bool align_corners, *, Tensor(a!) grad_input) -> " + "Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, bool), + &ATenMLIRTypeDefault::upsample_bicubic2d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_bicubic2d_backward(Tensor " + "grad_output, int[2] output_size, int[4] input_size, " + "bool align_corners) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::upsample_bicubic2d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_trilinear3d.out(Tensor self, int[3] " + "output_size, bool align_corners, *, Tensor(a!) out) " + "-> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::upsample_trilinear3d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_trilinear3d(Tensor self, int[3] " + "output_size, bool align_corners) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, bool), + &ATenMLIRTypeDefault::upsample_trilinear3d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_trilinear3d_backward.grad_input(" + "Tensor grad_output, int[3] output_size, int[5] " + "input_size, bool align_corners, *, Tensor(a!) " + "grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, bool), + &ATenMLIRTypeDefault::upsample_trilinear3d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_trilinear3d_backward(Tensor " + "grad_output, int[3] output_size, int[5] input_size, " + "bool align_corners) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, bool), + &ATenMLIRTypeDefault::upsample_trilinear3d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_nearest1d.out(Tensor self, int[1] " + "output_size, *, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::upsample_nearest1d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_nearest1d(Tensor self, int[1] " + "output_size) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::upsample_nearest1d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_nearest1d_backward.grad_input(Tensor " + "grad_output, int[1] output_size, int[3] input_size, " + "*, Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::upsample_nearest1d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::upsample_nearest1d_backward(Tensor grad_output, " + "int[1] output_size, int[3] input_size) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::upsample_nearest1d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_nearest2d.out(Tensor self, int[2] " + "output_size, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::upsample_nearest2d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_nearest2d(Tensor self, int[2] " + "output_size) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::upsample_nearest2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_nearest2d_backward.grad_input(Tensor " + "grad_output, int[2] output_size, int[4] input_size, " + "*, Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::upsample_nearest2d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::upsample_nearest2d_backward(Tensor grad_output, " + "int[2] output_size, int[4] input_size) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::upsample_nearest2d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_nearest3d.out(Tensor self, int[3] " + "output_size, *, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef), + &ATenMLIRTypeDefault::upsample_nearest3d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_nearest3d(Tensor self, int[3] " + "output_size) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef), + &ATenMLIRTypeDefault::upsample_nearest3d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::upsample_nearest3d_backward.grad_input(Tensor " + "grad_output, int[3] output_size, int[5] input_size, " + "*, Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::upsample_nearest3d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::upsample_nearest3d_backward(Tensor grad_output, " + "int[3] output_size, int[5] input_size) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::upsample_nearest3d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::sigmoid_backward.grad_input(Tensor grad_output, " + "Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::sigmoid_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::sigmoid_backward(Tensor grad_output, Tensor " + "output) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::sigmoid_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::tanh_backward.grad_input(Tensor grad_output, " + "Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::tanh_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::tanh_backward(Tensor grad_output, Tensor " + "output) -> Tensor") + .impl_unboxedOnlyKernel( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::slow_conv_transpose2d.out(Tensor self, Tensor " + "weight, int[2] kernel_size, Tensor? bias=None, int[2] " + "stride=1, int[2] padding=0, int[2] output_padding=0, " + "int[2] dilation=1, *, Tensor(a!) 
out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &( + at::Tensor &, const at::Tensor &, const at::Tensor &, + at::IntArrayRef, const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::slow_conv_transpose2d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::slow_conv_transpose2d(Tensor self, Tensor " + "weight, int[2] kernel_size, Tensor? bias=None, " + "int[2] stride=1, int[2] padding=0, int[2] " + "output_padding=0, int[2] dilation=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::slow_conv_transpose2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::slow_conv_transpose2d_backward.grad_output(" + "Tensor grad_output, Tensor self, Tensor weight, " + "int[2] kernel_size, int[2] stride, int[2] padding, " + "int[2] output_padding, int[2] dilation, Tensor " + "columns, Tensor ones, *, Tensor?(a!) grad_input, " + "Tensor?(b!) grad_weight, Tensor?(c!) grad_bias) -> " + "(Tensor(a!), Tensor(b!), Tensor(c!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, + const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::slow_conv_transpose2d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::slow_conv_transpose2d_backward.output_mask(Tensor " + "grad_output, Tensor self, Tensor weight, int[2] " + "kernel_size, int[2] stride, int[2] padding, int[2] " + "output_padding, int[2] dilation, Tensor columns, Tensor " + "ones, bool[3] output_mask) -> (Tensor grad_input, " + "Tensor grad_weight, Tensor grad_bias)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, + const at::Tensor &, const at::Tensor &, + std::array), + &ATenMLIRTypeDefault::slow_conv_transpose2d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::slow_conv_transpose3d.out(Tensor self, Tensor " + "weight, int[3] kernel_size, Tensor? bias=None, int[3] " + "stride=1, int[3] padding=0, int[3] output_padding=0, " + "int[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &( + at::Tensor &, const at::Tensor &, const at::Tensor &, + at::IntArrayRef, const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::slow_conv_transpose3d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::slow_conv_transpose3d(Tensor self, Tensor " + "weight, int[3] kernel_size, Tensor? 
bias=None, " + "int[3] stride=1, int[3] padding=0, int[3] " + "output_padding=0, int[3] dilation=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::slow_conv_transpose3d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::slow_conv_transpose3d_backward.grad_output(" + "Tensor grad_output, Tensor self, Tensor weight, " + "int[3] kernel_size, int[3] stride, int[3] padding, " + "int[3] output_padding, int[3] dilation, Tensor " + "finput, Tensor fgrad_input, *, Tensor?(a!) " + "grad_input, Tensor?(b!) grad_weight, Tensor?(c!) " + "grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, + const at::Tensor &, const at::Tensor &), + &ATenMLIRTypeDefault::slow_conv_transpose3d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::slow_conv_transpose3d_backward.output_mask(Tensor " + "grad_output, Tensor self, Tensor weight, int[3] " + "kernel_size, int[3] stride, int[3] padding, int[3] " + "output_padding, int[3] dilation, Tensor finput, Tensor " + "fgrad_input, bool[3] output_mask) -> (Tensor " + "grad_input, Tensor grad_weight, Tensor grad_bias)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, + const at::Tensor &, const at::Tensor &, + std::array), + &ATenMLIRTypeDefault::slow_conv_transpose3d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::thnn_conv2d.out(Tensor self, Tensor weight, " + "int[2] kernel_size, Tensor? bias=None, int[2] stride=1, " + "int[2] padding=0, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::thnn_conv2d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::thnn_conv2d(Tensor self, Tensor weight, " + "int[2] kernel_size, Tensor? bias=None, int[2] " + "stride=1, int[2] padding=0) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::thnn_conv2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::thnn_conv2d_forward.output(Tensor self, " + "Tensor weight, int[2] kernel_size, Tensor? bias, " + "int[2] stride, int[2] padding, *, Tensor(a!) " + "output, Tensor(b!) finput, Tensor(c!) 
fgrad_input) " + "-> (Tensor(a!), Tensor(b!), Tensor(c!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::IntArrayRef, const at::Tensor &, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::thnn_conv2d_forward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::thnn_conv2d_forward(Tensor self, Tensor " + "weight, int[2] kernel_size, Tensor? bias, int[2] " + "stride, int[2] padding) -> (Tensor output, Tensor " + "finput, Tensor fgrad_input)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + at::IntArrayRef, const at::Tensor &, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::thnn_conv2d_forward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::thnn_conv2d_backward.grad_input(Tensor " + "grad_output, Tensor self, Tensor weight, int[2] " + "kernel_size, int[2] stride, int[2] padding, Tensor " + "finput, Tensor fgrad_input, *, Tensor?(a!) " + "grad_input, Tensor?(b!) grad_weight, Tensor?(c!) " + "grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::thnn_conv2d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::thnn_conv2d_backward.output_mask(Tensor " + "grad_output, Tensor self, Tensor weight, int[2] " + "kernel_size, int[2] stride, int[2] padding, Tensor " + "finput, Tensor fgrad_input, bool[3] output_mask) -> " + "(Tensor grad_input, Tensor grad_weight, Tensor " + "grad_bias)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, const at::Tensor &, + const at::Tensor &, std::array), + &ATenMLIRTypeDefault::thnn_conv2d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::thnn_conv_depthwise2d.out(Tensor self, Tensor " + "weight, int[2] kernel_size, Tensor? bias=None, " + "int[2] stride=1, int[2] padding=0, int[2] " + "dilation=1, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::thnn_conv_depthwise2d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::thnn_conv_depthwise2d(Tensor self, Tensor weight, " + "int[2] kernel_size, Tensor? 
bias=None, int[2] stride=1, " + "int[2] padding=0, int[2] dilation=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::thnn_conv_depthwise2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::thnn_conv_depthwise2d_forward.out(Tensor " + "self, Tensor weight, int[2] kernel_size, Tensor? " + "bias, int[2] stride, int[2] padding, int[2] " + "dilation, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::thnn_conv_depthwise2d_forward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::thnn_conv_depthwise2d_forward(Tensor self, Tensor " + "weight, int[2] kernel_size, Tensor? bias, int[2] " + "stride, int[2] padding, int[2] dilation) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::thnn_conv_depthwise2d_forward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::thnn_conv_depthwise2d_backward.grad_input(Tensor " + "grad_output, Tensor self, Tensor weight, int[2] " + "kernel_size, int[2] stride, int[2] padding, int[2] " + "dilation, *, Tensor?(a!) grad_input, Tensor?(b!) " + "grad_weight) -> (Tensor(a!), Tensor(b!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, const at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::thnn_conv_depthwise2d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::thnn_conv_depthwise2d_backward.output_mask(" + "Tensor grad_output, Tensor self, Tensor weight, " + "int[2] kernel_size, int[2] stride, int[2] padding, " + "int[2] dilation, bool[2] output_mask) -> (Tensor " + "grad_input, Tensor grad_weight)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, + const at::Tensor &, + const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, std::array), + &ATenMLIRTypeDefault::thnn_conv_depthwise2d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::thnn_conv3d.out(Tensor self, Tensor weight, " + "int[3] kernel_size, Tensor? bias=None, int[3] stride=1, " + "int[3] padding=0, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, + const at::Tensor &, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::thnn_conv3d_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::thnn_conv3d(Tensor self, Tensor weight, " + "int[3] kernel_size, Tensor? 
bias=None, int[3] " + "stride=1, int[3] padding=0) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::thnn_conv3d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::thnn_conv3d_forward.output(Tensor self, " + "Tensor weight, int[3] kernel_size, Tensor? bias, " + "int[3] stride, int[3] padding, *, Tensor(a!) " + "output, Tensor(b!) finput, Tensor(c!) fgrad_input) " + "-> (Tensor(a!), Tensor(b!), Tensor(c!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, at::Tensor &, + const at::Tensor &, const at::Tensor &, + at::IntArrayRef, const at::Tensor &, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::thnn_conv3d_forward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::thnn_conv3d_forward(Tensor self, Tensor " + "weight, int[3] kernel_size, Tensor? bias, int[3] " + "stride, int[3] padding) -> (Tensor output, Tensor " + "finput, Tensor fgrad_input)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + at::IntArrayRef, const at::Tensor &, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::thnn_conv3d_forward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::thnn_conv3d_backward.grad_input(Tensor " + "grad_output, Tensor self, Tensor weight, int[3] " + "kernel_size, int[3] stride, int[3] padding, Tensor " + "finput, Tensor fgrad_input, *, Tensor?(a!) " + "grad_input, Tensor?(b!) grad_weight, Tensor?(c!) " + "grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))") + .impl_unboxedOnlyKernel< + std::tuple( + at::Tensor &, at::Tensor &, at::Tensor &, + const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, const at::Tensor &, + const at::Tensor &), + &ATenMLIRTypeDefault::thnn_conv3d_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::thnn_conv3d_backward.output_mask(Tensor " + "grad_output, Tensor self, Tensor weight, int[3] " + "kernel_size, int[3] stride, int[3] padding, Tensor " + "finput, Tensor fgrad_input, bool[3] output_mask) -> " + "(Tensor grad_input, Tensor grad_weight, Tensor " + "grad_bias)") + .impl_unboxedOnlyKernel< + std::tuple( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, const at::Tensor &, + const at::Tensor &, std::array), + &ATenMLIRTypeDefault::thnn_conv3d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::slow_conv_dilated2d(Tensor self, Tensor weight, " + "int[2] kernel_size, Tensor? 
bias=None, int[2] stride=1, " + "int[2] padding=0, int[2] dilation=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::slow_conv_dilated2d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::slow_conv_dilated2d_backward(Tensor " + "grad_output, Tensor self, Tensor weight, int[2] " + "kernel_size, int[2] stride, int[2] padding, int[2] " + "dilation, bool[3] output_mask) -> (Tensor " + "grad_input, Tensor grad_weight, Tensor grad_bias)") + .impl_unboxedOnlyKernel< + std::tuple<at::Tensor, at::Tensor, at::Tensor>( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, + std::array<bool, 3>), + &ATenMLIRTypeDefault::slow_conv_dilated2d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::slow_conv_dilated3d(Tensor self, Tensor weight, " + "int[3] kernel_size, Tensor? bias=None, int[3] stride=1, " + "int[3] padding=0, int[3] dilation=1) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, const at::Tensor &, + at::IntArrayRef, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::slow_conv_dilated3d>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::slow_conv_dilated3d_backward(Tensor " + "grad_output, Tensor self, Tensor weight, int[3] " + "kernel_size, int[3] stride, int[3] padding, int[3] " + "dilation, bool[3] output_mask) -> (Tensor " + "grad_input, Tensor grad_weight, Tensor grad_bias)") + .impl_unboxedOnlyKernel< + std::tuple<at::Tensor, at::Tensor, at::Tensor>( + const at::Tensor &, const at::Tensor &, + const at::Tensor &, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, + std::array<bool, 3>), + &ATenMLIRTypeDefault::slow_conv_dilated3d_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::col2im.out(Tensor self, int[2] output_size, " + "int[2] kernel_size, int[2] dilation, int[2] padding, " + "int[2] stride, *, Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::col2im_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::col2im(Tensor self, int[2] output_size, " + "int[2] kernel_size, int[2] dilation, int[2] " + "padding, int[2] stride) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::col2im>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::col2im_backward.grad_input(Tensor grad_output, " + "int[2] kernel_size, int[2] dilation, int[2] padding, " + "int[2] stride, *, Tensor(a!)
grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::col2im_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::col2im_backward(Tensor grad_output, int[2] " + "kernel_size, int[2] dilation, int[2] padding, " + "int[2] stride) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::col2im_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::im2col.out(Tensor self, int[2] kernel_size, " + "int[2] dilation, int[2] padding, int[2] stride, *, " + "Tensor(a!) out) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::im2col_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema( + "aten::im2col(Tensor self, int[2] kernel_size, int[2] " + "dilation, int[2] padding, int[2] stride) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::im2col>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::im2col_backward.grad_input(Tensor " + "grad_output, int[2] input_size, int[2] kernel_size, " + "int[2] dilation, int[2] padding, int[2] stride, *, " + "Tensor(a!) grad_input) -> Tensor(a!)") + .impl_unboxedOnlyKernel< + at::Tensor &(at::Tensor &, const at::Tensor &, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef), + &ATenMLIRTypeDefault::im2col_backward_out>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::im2col_backward(Tensor grad_output, int[2] " + "input_size, int[2] kernel_size, int[2] dilation, " + "int[2] padding, int[2] stride) -> Tensor") + .impl_unboxedOnlyKernel< + at::Tensor(const at::Tensor &, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef), + &ATenMLIRTypeDefault::im2col_backward>( + at::TensorTypeId::XLATensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)); +} + +} // namespace torch_mlir diff --git a/frontends/pytorch/csrc/aten_mlir_type_default.h b/frontends/pytorch/csrc/aten_mlir_type_default.h new file mode 100644 index 000000000..7e21306e1 --- /dev/null +++ b/frontends/pytorch/csrc/aten_mlir_type_default.h @@ -0,0 +1,2907 @@ +//===- aten_mlir_type_default.h ---------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information.
+// +//===----------------------------------------------------------------------===// + +#include + +namespace torch_mlir { + +class ATenMLIRTypeDefault { +public: + static at::Tensor _cast_Byte(const at::Tensor &self, bool non_blocking); + static at::Tensor _cast_Char(const at::Tensor &self, bool non_blocking); + static at::Tensor _cast_Double(const at::Tensor &self, bool non_blocking); + static at::Tensor _cast_Float(const at::Tensor &self, bool non_blocking); + static at::Tensor _cast_Int(const at::Tensor &self, bool non_blocking); + static at::Tensor _cast_Long(const at::Tensor &self, bool non_blocking); + static at::Tensor _cast_Short(const at::Tensor &self, bool non_blocking); + static at::Tensor _cast_Half(const at::Tensor &self, bool non_blocking); + static void backward(const at::Tensor &self, const at::Tensor &gradient, + bool keep_graph, bool create_graph); + static void set_data(const at::Tensor &self, const at::Tensor &new_data); + static at::Tensor data(const at::Tensor &self); + static int64_t _debug_has_internal_overlap(const at::Tensor &self); + static std::tuple<at::Tensor, at::Tensor> + _fused_dropout(const at::Tensor &self, double p, at::Generator *generator); + static at::Tensor _masked_scale(const at::Tensor &self, + const at::Tensor &mask, double scale); + static std::tuple<at::Tensor, at::Tensor> + _sobol_engine_draw(const at::Tensor &quasi, int64_t n, + const at::Tensor &sobolstate, int64_t dimension, + int64_t num_generated, + c10::optional<at::ScalarType> dtype); + static at::Tensor &_sobol_engine_ff_(at::Tensor &self, int64_t n, + const at::Tensor &sobolstate, + int64_t dimension, + int64_t num_generated); + static at::Tensor &_sobol_engine_scramble_(at::Tensor &self, + const at::Tensor &ltm, + int64_t dimension); + static at::Tensor &_sobol_engine_initialize_state_(at::Tensor &self, + int64_t dimension); + static at::Tensor _reshape_from_tensor(const at::Tensor &self, + const at::Tensor &shape); + static at::Tensor _shape_as_tensor(const at::Tensor &self); + static at::Tensor dropout(const at::Tensor &input, double p, bool train); + static at::Tensor &dropout_(at::Tensor &self, double p, bool train); + static at::Tensor feature_dropout(const at::Tensor &input, double p, + bool train); + static at::Tensor &feature_dropout_(at::Tensor &self, double p, bool train); + static at::Tensor alpha_dropout(const at::Tensor &input, double p, + bool train); + static at::Tensor &alpha_dropout_(at::Tensor &self, double p, bool train); + static at::Tensor feature_alpha_dropout(const at::Tensor &input, double p, + bool train); + static at::Tensor &feature_alpha_dropout_(at::Tensor &self, double p, + bool train); + static at::Tensor abs(const at::Tensor &self); + static at::Tensor &abs_(at::Tensor &self); + static at::Tensor &abs_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor acos(const at::Tensor &self); + static at::Tensor &acos_(at::Tensor &self); + static at::Tensor &acos_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor avg_pool1d(const at::Tensor &self, + at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + bool ceil_mode, bool count_include_pad); + static at::Tensor adaptive_avg_pool1d(const at::Tensor &self, + at::IntArrayRef output_size); + static std::tuple<at::Tensor, at::Tensor> + adaptive_max_pool1d(const at::Tensor &self, at::IntArrayRef output_size); + static at::Tensor add(const at::Tensor &self, const at::Tensor &other, + at::Scalar alpha); + static at::Tensor &add_(at::Tensor &self, const at::Tensor &other, + at::Scalar alpha); + static at::Tensor &add_out(at::Tensor &out, const
at::Tensor &self, + const at::Tensor &other, at::Scalar alpha); + static at::Tensor add(const at::Tensor &self, at::Scalar other, + at::Scalar alpha); + static at::Tensor &add_(at::Tensor &self, at::Scalar other, at::Scalar alpha); + static at::Tensor addmv(const at::Tensor &self, const at::Tensor &mat, + const at::Tensor &vec, at::Scalar beta, + at::Scalar alpha); + static at::Tensor &addmv_(at::Tensor &self, const at::Tensor &mat, + const at::Tensor &vec, at::Scalar beta, + at::Scalar alpha); + static at::Tensor &addmv_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &mat, const at::Tensor &vec, + at::Scalar beta, at::Scalar alpha); + static at::Tensor addr(const at::Tensor &self, const at::Tensor &vec1, + const at::Tensor &vec2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor &addr_(at::Tensor &self, const at::Tensor &vec1, + const at::Tensor &vec2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor &addr_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &vec1, const at::Tensor &vec2, + at::Scalar beta, at::Scalar alpha); + static at::Tensor affine_grid_generator(const at::Tensor &theta, + at::IntArrayRef size, + bool align_corners); + static at::Tensor affine_grid_generator_backward(const at::Tensor &grad, + at::IntArrayRef size, + bool align_corners); + static at::Tensor all(const at::Tensor &self, int64_t dim, bool keepdim); + static at::Tensor &all_out(at::Tensor &out, const at::Tensor &self, + int64_t dim, bool keepdim); + static bool allclose(const at::Tensor &self, const at::Tensor &other, + double rtol, double atol, bool equal_nan); + static at::Tensor any(const at::Tensor &self, int64_t dim, bool keepdim); + static at::Tensor &any_out(at::Tensor &out, const at::Tensor &self, + int64_t dim, bool keepdim); + static at::Tensor arange(at::Scalar end, const at::TensorOptions &options); + static at::Tensor arange(at::Scalar start, at::Scalar end, + const at::TensorOptions &options); + static at::Tensor arange(at::Scalar start, at::Scalar end, at::Scalar step, + const at::TensorOptions &options); + static at::Tensor &arange_out(at::Tensor &out, at::Scalar end); + static at::Tensor &arange_out(at::Tensor &out, at::Scalar start, + at::Scalar end, at::Scalar step); + static at::Tensor _dim_arange(const at::Tensor &like, int64_t dim); + static at::Tensor argmax(const at::Tensor &self, c10::optional dim, + bool keepdim); + static at::Tensor argmin(const at::Tensor &self, c10::optional dim, + bool keepdim); + static at::Tensor as_strided(const at::Tensor &self, at::IntArrayRef size, + at::IntArrayRef stride, + c10::optional storage_offset); + static at::Tensor &as_strided_(at::Tensor &self, at::IntArrayRef size, + at::IntArrayRef stride, + c10::optional storage_offset); + static at::Tensor asin(const at::Tensor &self); + static at::Tensor &asin_(at::Tensor &self); + static at::Tensor &asin_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor atan(const at::Tensor &self); + static at::Tensor &atan_(at::Tensor &self); + static at::Tensor &atan_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor baddbmm(const at::Tensor &self, const at::Tensor &batch1, + const at::Tensor &batch2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor &baddbmm_(at::Tensor &self, const at::Tensor &batch1, + const at::Tensor &batch2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor &_baddbmm_mkl_(at::Tensor &self, const at::Tensor &batch1, + const at::Tensor &batch2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor 
&baddbmm_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &batch1, + const at::Tensor &batch2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor bartlett_window(int64_t window_length, + const at::TensorOptions &options); + static at::Tensor bartlett_window(int64_t window_length, bool periodic, + const at::TensorOptions &options); + static at::Tensor batch_norm(const at::Tensor &input, + const at::Tensor &weight, const at::Tensor &bias, + const at::Tensor &running_mean, + const at::Tensor &running_var, bool training, + double momentum, double eps, bool cudnn_enabled); + static std::tuple + _batch_norm_impl_index(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, const at::Tensor &running_mean, + const at::Tensor &running_var, bool training, + double momentum, double eps, bool cudnn_enabled); + static std::tuple + _batch_norm_impl_index_backward( + int64_t impl_index, const at::Tensor &input, + const at::Tensor &grad_output, const at::Tensor &weight, + const at::Tensor &running_mean, const at::Tensor &running_var, + const at::Tensor &save_mean, const at::Tensor &save_var_transform, + bool train, double eps, std::array output_mask); + static at::Tensor bernoulli(const at::Tensor &self, at::Generator *generator); + static at::Tensor &bernoulli_out(at::Tensor &out, const at::Tensor &self, + at::Generator *generator); + static at::Tensor &bernoulli_(at::Tensor &self, const at::Tensor &p, + at::Generator *generator); + static at::Tensor &bernoulli_(at::Tensor &self, double p, + at::Generator *generator); + static at::Tensor bernoulli(const at::Tensor &self, double p, + at::Generator *generator); + static at::Tensor bilinear(const at::Tensor &input1, const at::Tensor &input2, + const at::Tensor &weight, const at::Tensor &bias); + static at::Tensor binary_cross_entropy_with_logits( + const at::Tensor &self, const at::Tensor &target, + const at::Tensor &weight, const at::Tensor &pos_weight, + int64_t reduction); + static at::Tensor binary_cross_entropy_with_logits_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &target, const at::Tensor &weight, + const at::Tensor &pos_weight, int64_t reduction); + static at::Tensor bincount(const at::Tensor &self, const at::Tensor &weights, + int64_t minlength); + static at::Tensor bitwise_not(const at::Tensor &self); + static at::Tensor &bitwise_not_(at::Tensor &self); + static at::Tensor &bitwise_not_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor logical_not(const at::Tensor &self); + static at::Tensor &logical_not_(at::Tensor &self); + static at::Tensor &logical_not_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor logical_xor(const at::Tensor &self, + const at::Tensor &other); + static at::Tensor &logical_xor_(at::Tensor &self, const at::Tensor &other); + static at::Tensor &logical_xor_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other); + static at::Tensor blackman_window(int64_t window_length, + const at::TensorOptions &options); + static at::Tensor blackman_window(int64_t window_length, bool periodic, + const at::TensorOptions &options); + static at::Tensor bmm(const at::Tensor &self, const at::Tensor &mat2); + static at::Tensor &bmm_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &mat2); + static std::vector broadcast_tensors(at::TensorList tensors); + static at::Tensor cat(at::TensorList tensors, int64_t dim); + static at::Tensor &cat_out(at::Tensor &out, at::TensorList tensors, + int64_t dim); + 
static at::Tensor ceil(const at::Tensor &self); + static at::Tensor &ceil_(at::Tensor &self); + static at::Tensor &ceil_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor chain_matmul(at::TensorList matrices); + static std::vector chunk(const at::Tensor &self, int64_t chunks, + int64_t dim); + static at::Tensor clamp(const at::Tensor &self, c10::optional min, + c10::optional max); + static at::Tensor &clamp_(at::Tensor &self, c10::optional min, + c10::optional max); + static at::Tensor &clamp_out(at::Tensor &out, const at::Tensor &self, + c10::optional min, + c10::optional max); + static at::Tensor clamp_max(const at::Tensor &self, at::Scalar max); + static at::Tensor &clamp_max_(at::Tensor &self, at::Scalar max); + static at::Tensor &clamp_max_out(at::Tensor &out, const at::Tensor &self, + at::Scalar max); + static at::Tensor clamp_min(const at::Tensor &self, at::Scalar min); + static at::Tensor &clamp_min_(at::Tensor &self, at::Scalar min); + static at::Tensor &clamp_min_out(at::Tensor &out, const at::Tensor &self, + at::Scalar min); + static at::Tensor constant_pad_nd(const at::Tensor &self, at::IntArrayRef pad, + at::Scalar value); + static at::Tensor contiguous(const at::Tensor &self, + at::MemoryFormat memory_format); + static at::Tensor convolution(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, + at::IntArrayRef dilation, bool transposed, + at::IntArrayRef output_padding, int64_t groups); + static at::Tensor convolution_overrideable( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + bool transposed, at::IntArrayRef output_padding, int64_t groups); + static std::tuple + convolution_backward_overrideable( + const at::Tensor &grad_output, const at::Tensor &input, + const at::Tensor &weight, at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, + int64_t groups, std::array output_mask); + static at::Tensor + _convolution(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, + bool transposed, at::IntArrayRef output_padding, int64_t groups, + bool benchmark, bool deterministic, bool cudnn_enabled); + static at::Tensor + _convolution_nogroup(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, + bool transposed, at::IntArrayRef output_padding); + static std::tuple + _convolution_double_backward(const at::Tensor &ggI, const at::Tensor &ggW, + const at::Tensor &ggb, const at::Tensor &gO, + const at::Tensor &weight, const at::Tensor &self, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, bool transposed, + at::IntArrayRef output_padding, int64_t groups, + bool benchmark, bool deterministic, + bool cudnn_enabled, + std::array output_mask); + static at::Tensor conv1d(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, + int64_t groups); + static at::Tensor conv2d(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, + int64_t groups); + static at::Tensor conv3d(const at::Tensor &input, const 
at::Tensor &weight, + const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, + int64_t groups); + static at::Tensor conv_tbc(const at::Tensor &self, const at::Tensor &weight, + const at::Tensor &bias, int64_t pad); + static std::tuple + conv_tbc_backward(const at::Tensor &self, const at::Tensor &input, + const at::Tensor &weight, const at::Tensor &bias, + int64_t pad); + static at::Tensor + conv_transpose1d(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef output_padding, + int64_t groups, at::IntArrayRef dilation); + static at::Tensor + conv_transpose2d(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef output_padding, + int64_t groups, at::IntArrayRef dilation); + static at::Tensor + conv_transpose3d(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef output_padding, + int64_t groups, at::IntArrayRef dilation); + static at::Tensor ©_(at::Tensor &self, const at::Tensor &src, + bool non_blocking); + static at::Tensor _copy_from(const at::Tensor &self, const at::Tensor &dst, + bool non_blocking); + static at::Tensor cos(const at::Tensor &self); + static at::Tensor &cos_(at::Tensor &self); + static at::Tensor &cos_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor cosh(const at::Tensor &self); + static at::Tensor &cosh_(at::Tensor &self); + static at::Tensor &cosh_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor cosine_embedding_loss(const at::Tensor &input1, + const at::Tensor &input2, + const at::Tensor &target, + double margin, int64_t reduction); + static at::Tensor cumsum(const at::Tensor &self, int64_t dim, + c10::optional dtype); + static at::Tensor &cumsum_out(at::Tensor &out, const at::Tensor &self, + int64_t dim, + c10::optional dtype); + static at::Tensor cumprod(const at::Tensor &self, int64_t dim, + c10::optional dtype); + static at::Tensor &cumprod_out(at::Tensor &out, const at::Tensor &self, + int64_t dim, + c10::optional dtype); + static at::Tensor ctc_loss(const at::Tensor &log_probs, + const at::Tensor &targets, + at::IntArrayRef input_lengths, + at::IntArrayRef target_lengths, int64_t blank, + int64_t reduction, bool zero_infinity); + static at::Tensor ctc_loss(const at::Tensor &log_probs, + const at::Tensor &targets, + const at::Tensor &input_lengths, + const at::Tensor &target_lengths, int64_t blank, + int64_t reduction, bool zero_infinity); + static std::tuple + _ctc_loss(const at::Tensor &log_probs, const at::Tensor &targets, + at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, + int64_t blank, bool zero_infinity); + static at::Tensor _ctc_loss_backward( + const at::Tensor &grad, const at::Tensor &log_probs, + const at::Tensor &targets, at::IntArrayRef input_lengths, + at::IntArrayRef target_lengths, const at::Tensor &neg_log_likelihood, + const at::Tensor &log_alpha, int64_t blank, bool zero_infinity); + static at::Tensor det(const at::Tensor &self); + static at::Tensor diag_embed(const at::Tensor &self, int64_t offset, + int64_t dim1, int64_t dim2); + static at::Tensor diagflat(const at::Tensor &self, int64_t offset); + static at::Tensor diagonal(const at::Tensor &self, int64_t offset, + int64_t dim1, int64_t dim2); + static at::Tensor &fill_diagonal_(at::Tensor &self, at::Scalar fill_value, + bool wrap); 
+ static at::Tensor div(const at::Tensor &self, const at::Tensor &other); + static at::Tensor &div_(at::Tensor &self, const at::Tensor &other); + static at::Tensor &div_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other); + static at::Tensor div(const at::Tensor &self, at::Scalar other); + static at::Tensor &div_(at::Tensor &self, at::Scalar other); + static at::Tensor dot(const at::Tensor &self, const at::Tensor &tensor); + static at::Tensor &dot_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &tensor); + static at::Tensor einsum(std::string equation, at::TensorList tensors); + static at::Tensor embedding(const at::Tensor &weight, + const at::Tensor &indices, int64_t padding_idx, + bool scale_grad_by_freq, bool sparse); + static at::Tensor embedding_backward(const at::Tensor &grad, + const at::Tensor &indices, + int64_t num_weights, int64_t padding_idx, + bool scale_grad_by_freq, bool sparse); + static at::Tensor embedding_dense_backward(const at::Tensor &grad_output, + const at::Tensor &indices, + int64_t num_weights, + int64_t padding_idx, + bool scale_grad_by_freq); + static at::Tensor &embedding_renorm_(at::Tensor &self, + const at::Tensor &indices, + double max_norm, double norm_type); + static at::Tensor embedding_sparse_backward(const at::Tensor &grad, + const at::Tensor &indices, + int64_t num_weights, + int64_t padding_idx, + bool scale_grad_by_freq); + static std::tuple + embedding_bag(const at::Tensor &weight, const at::Tensor &indices, + const at::Tensor &offsets, bool scale_grad_by_freq, + int64_t mode, bool sparse, + const at::Tensor &per_sample_weights); + static std::tuple + _embedding_bag(const at::Tensor &weight, const at::Tensor &indices, + const at::Tensor &offsets, bool scale_grad_by_freq, + int64_t mode, bool sparse, + const at::Tensor &per_sample_weights); + static at::Tensor _embedding_bag_backward( + const at::Tensor &grad, const at::Tensor &indices, + const at::Tensor &offsets, const at::Tensor &offset2bag, + const at::Tensor &bag_size, const at::Tensor &maximum_indices, + int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, + const at::Tensor &per_sample_weights); + static at::Tensor _embedding_bag_sparse_backward( + const at::Tensor &grad, const at::Tensor &indices, + const at::Tensor &offsets, const at::Tensor &offset2bag, + const at::Tensor &bag_size, int64_t num_weights, bool scale_grad_by_freq, + int64_t mode, const at::Tensor &per_sample_weights); + static at::Tensor _embedding_bag_dense_backward( + const at::Tensor &grad, const at::Tensor &indices, + const at::Tensor &offsets, const at::Tensor &offset2bag, + const at::Tensor &bag_size, const at::Tensor &maximum_indices, + int64_t num_weights, bool scale_grad_by_freq, int64_t mode, + const at::Tensor &per_sample_weights); + static at::Tensor _embedding_bag_per_sample_weights_backward( + const at::Tensor &grad, const at::Tensor &weight, + const at::Tensor &indices, const at::Tensor &offsets, + const at::Tensor &offset2bag, int64_t mode); + static at::Tensor empty(at::IntArrayRef size, + const at::TensorOptions &options, + c10::optional memory_format); + static at::Tensor new_empty(const at::Tensor &self, at::IntArrayRef size, + const at::TensorOptions &options); + static at::Tensor new_full(const at::Tensor &self, at::IntArrayRef size, + at::Scalar fill_value, + const at::TensorOptions &options); + static at::Tensor _empty_affine_quantized( + at::IntArrayRef size, const at::TensorOptions &options, double scale, + int64_t zero_point, c10::optional 
memory_format); + static at::Tensor _empty_per_channel_affine_quantized_like( + const at::Tensor &self, const at::Tensor &zero_points, + at::IntArrayRef size, at::IntArrayRef axis, + const at::TensorOptions &options, + c10::optional memory_format); + static at::Tensor &resize_(at::Tensor &self, at::IntArrayRef size); + static at::Tensor &empty_out(at::Tensor &out, at::IntArrayRef size, + c10::optional memory_format); + static at::Tensor empty_like(const at::Tensor &self); + static at::Tensor empty_like(const at::Tensor &self, + const at::TensorOptions &options, + c10::optional memory_format); + static at::Tensor empty_strided(at::IntArrayRef size, at::IntArrayRef stride, + const at::TensorOptions &options); + static at::Tensor erf(const at::Tensor &self); + static at::Tensor &erf_(at::Tensor &self); + static at::Tensor &erf_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor erfc(const at::Tensor &self); + static at::Tensor &erfc_(at::Tensor &self); + static at::Tensor &erfc_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor exp(const at::Tensor &self); + static at::Tensor &exp_(at::Tensor &self); + static at::Tensor &exp_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor expm1(const at::Tensor &self); + static at::Tensor &expm1_(at::Tensor &self); + static at::Tensor &expm1_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor expand(const at::Tensor &self, at::IntArrayRef size, + bool implicit); + static at::Tensor expand_as(const at::Tensor &self, const at::Tensor &other); + static at::Tensor eye(int64_t n, const at::TensorOptions &options); + static at::Tensor eye(int64_t n, int64_t m, const at::TensorOptions &options); + static at::Tensor &eye_out(at::Tensor &out, int64_t n); + static at::Tensor &eye_out(at::Tensor &out, int64_t n, int64_t m); + static at::Tensor flatten(const at::Tensor &self, int64_t start_dim, + int64_t end_dim); + static at::Tensor &fill_(at::Tensor &self, at::Scalar value); + static at::Tensor &fill_(at::Tensor &self, const at::Tensor &value); + static at::Tensor floor(const at::Tensor &self); + static at::Tensor &floor_(at::Tensor &self); + static at::Tensor &floor_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor frac(const at::Tensor &self); + static at::Tensor &frac_(at::Tensor &self); + static at::Tensor &frac_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor full(at::IntArrayRef size, at::Scalar fill_value, + const at::TensorOptions &options); + static at::Tensor &full_out(at::Tensor &out, at::IntArrayRef size, + at::Scalar fill_value); + static at::Tensor full_like(const at::Tensor &self, at::Scalar fill_value); + static at::Tensor full_like(const at::Tensor &self, at::Scalar fill_value, + const at::TensorOptions &options); + static at::Tensor from_file(std::string filename, c10::optional shared, + c10::optional size, + const at::TensorOptions &options); + static at::Tensor grid_sampler(const at::Tensor &input, + const at::Tensor &grid, + int64_t interpolation_mode, + int64_t padding_mode, bool align_corners); + static at::Tensor grid_sampler_2d(const at::Tensor &input, + const at::Tensor &grid, + int64_t interpolation_mode, + int64_t padding_mode, bool align_corners); + static std::tuple + grid_sampler_2d_backward(const at::Tensor &grad_output, + const at::Tensor &input, const at::Tensor &grid, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners); + static at::Tensor grid_sampler_3d(const at::Tensor &input, + const at::Tensor &grid, + int64_t 
interpolation_mode, + int64_t padding_mode, bool align_corners); + static std::tuple + grid_sampler_3d_backward(const at::Tensor &grad_output, + const at::Tensor &input, const at::Tensor &grid, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners); + static at::Tensor hann_window(int64_t window_length, + const at::TensorOptions &options); + static at::Tensor hann_window(int64_t window_length, bool periodic, + const at::TensorOptions &options); + static at::Tensor hamming_window(int64_t window_length, + const at::TensorOptions &options); + static at::Tensor hamming_window(int64_t window_length, bool periodic, + const at::TensorOptions &options); + static at::Tensor hamming_window(int64_t window_length, bool periodic, + double alpha, + const at::TensorOptions &options); + static at::Tensor hamming_window(int64_t window_length, bool periodic, + double alpha, double beta, + const at::TensorOptions &options); + static at::Tensor hinge_embedding_loss(const at::Tensor &self, + const at::Tensor &target, + double margin, int64_t reduction); + static at::Tensor ger(const at::Tensor &self, const at::Tensor &vec2); + static at::Tensor &ger_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &vec2); + static at::Tensor group_norm(const at::Tensor &input, int64_t num_groups, + const at::Tensor &weight, const at::Tensor &bias, + double eps, bool cudnn_enabled); + static at::Tensor fft(const at::Tensor &self, int64_t signal_ndim, + bool normalized); + static at::Tensor ifft(const at::Tensor &self, int64_t signal_ndim, + bool normalized); + static at::Tensor rfft(const at::Tensor &self, int64_t signal_ndim, + bool normalized, bool onesided); + static at::Tensor irfft(const at::Tensor &self, int64_t signal_ndim, + bool normalized, bool onesided, + at::IntArrayRef signal_sizes); + static at::Tensor _fft_with_size(const at::Tensor &self, int64_t signal_ndim, + bool complex_input, bool complex_output, + bool inverse, + at::IntArrayRef checked_signal_sizes, + bool normalized, bool onesided, + at::IntArrayRef output_sizes); + static int64_t _cufft_get_plan_cache_size(int64_t device_index); + static int64_t _cufft_get_plan_cache_max_size(int64_t device_index); + static void _cufft_set_plan_cache_max_size(int64_t device_index, + int64_t max_size); + static void _cufft_clear_plan_cache(int64_t device_index); + static at::Tensor index(const at::Tensor &self, at::TensorList indices); + static at::Tensor &index_copy_(at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &source); + static at::Tensor index_copy(const at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &source); + static at::Tensor &index_put_(at::Tensor &self, at::TensorList indices, + const at::Tensor &values, bool accumulate); + static at::Tensor index_put(const at::Tensor &self, at::TensorList indices, + const at::Tensor &values, bool accumulate); + static at::Tensor &_index_put_impl_(at::Tensor &self, at::TensorList indices, + const at::Tensor &values, bool accumulate, + bool unsafe); + static at::Tensor + instance_norm(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, const at::Tensor &running_mean, + const at::Tensor &running_var, bool use_input_stats, + double momentum, double eps, bool cudnn_enabled); + static at::Tensor inverse(const at::Tensor &self); + static at::Tensor &inverse_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor _inverse_helper(const at::Tensor &self); + static at::Tensor isclose(const at::Tensor &self, 
const at::Tensor &other, + double rtol, double atol, bool equal_nan); + static at::Tensor isnan(const at::Tensor &self); + static bool is_distributed(const at::Tensor &self); + static bool is_floating_point(const at::Tensor &self); + static bool is_complex(const at::Tensor &self); + static bool is_nonzero(const at::Tensor &self); + static bool is_same_size(const at::Tensor &self, const at::Tensor &other); + static bool is_signed(const at::Tensor &self); + static at::Tensor kl_div(const at::Tensor &self, const at::Tensor &target, + int64_t reduction); + static at::Tensor kl_div_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, + int64_t reduction); + static std::tuple + kthvalue(const at::Tensor &self, int64_t k, int64_t dim, bool keepdim); + static std::tuple + kthvalue_out(at::Tensor &values, at::Tensor &indices, const at::Tensor &self, + int64_t k, int64_t dim, bool keepdim); + static at::Tensor layer_norm(const at::Tensor &input, + at::IntArrayRef normalized_shape, + const at::Tensor &weight, const at::Tensor &bias, + double eps, bool cudnn_enable); + static std::tuple + native_layer_norm(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, int64_t M, int64_t N, double eps); + static std::tuple + native_layer_norm_backward(const at::Tensor &grad_out, + const at::Tensor &input, const at::Tensor &mean, + const at::Tensor &rstd, const at::Tensor &weight, + int64_t M, int64_t N, + std::array output_mask); + static std::tuple + native_layer_norm_double_backward( + const at::Tensor &ggI, const at::Tensor &ggW, const at::Tensor &ggb, + const at::Tensor &gO, const at::Tensor &input, const at::Tensor &mean, + const at::Tensor &rstd, const at::Tensor &weight, int64_t M, int64_t N, + std::array output_mask); + static at::Tensor linear(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias); + static at::Tensor mkldnn_linear(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias); + static at::Tensor fbgemm_linear_int8_weight_fp32_activation( + const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &packed, const at::Tensor &col_offsets, + at::Scalar weight_scale, at::Scalar weight_zero_point, + const at::Tensor &bias); + static at::Tensor fbgemm_linear_int8_weight(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &packed, + const at::Tensor &col_offsets, + at::Scalar weight_scale, + at::Scalar weight_zero_point, + const at::Tensor &bias); + static std::tuple + fbgemm_linear_quantize_weight(const at::Tensor &input); + static at::Tensor fbgemm_pack_gemm_matrix_fp16(const at::Tensor &input); + static at::Tensor + fbgemm_linear_fp16_weight_fp32_activation(const at::Tensor &input, + const at::Tensor &packed_weight, + const at::Tensor &bias); + static at::Tensor fbgemm_linear_fp16_weight(const at::Tensor &input, + const at::Tensor &packed_weight, + const at::Tensor &bias); + static at::Tensor fbgemm_pack_quantized_matrix(const at::Tensor &input); + static at::Tensor fbgemm_pack_quantized_matrix(const at::Tensor &input, + int64_t K, int64_t N); + static at::Tensor linspace(at::Scalar start, at::Scalar end, int64_t steps, + const at::TensorOptions &options); + static at::Tensor &linspace_out(at::Tensor &out, at::Scalar start, + at::Scalar end, int64_t steps); + static at::Tensor log(const at::Tensor &self); + static at::Tensor &log_(at::Tensor &self); + static at::Tensor &log_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor log10(const 
at::Tensor &self); + static at::Tensor &log10_(at::Tensor &self); + static at::Tensor &log10_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor log1p(const at::Tensor &self); + static at::Tensor &log1p_(at::Tensor &self); + static at::Tensor &log1p_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor log2(const at::Tensor &self); + static at::Tensor &log2_(at::Tensor &self); + static at::Tensor &log2_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor logdet(const at::Tensor &self); + static at::Tensor logspace(at::Scalar start, at::Scalar end, int64_t steps, + double base, const at::TensorOptions &options); + static at::Tensor &logspace_out(at::Tensor &out, at::Scalar start, + at::Scalar end, int64_t steps, double base); + static at::Tensor log_softmax(const at::Tensor &self, int64_t dim, + c10::optional dtype); + static at::Tensor _log_softmax(const at::Tensor &self, int64_t dim, + bool half_to_float); + static at::Tensor _log_softmax_backward_data(const at::Tensor &grad_output, + const at::Tensor &output, + int64_t dim, + const at::Tensor &self); + static at::Tensor logsumexp(const at::Tensor &self, at::IntArrayRef dim, + bool keepdim); + static at::Tensor &logsumexp_out(at::Tensor &out, const at::Tensor &self, + at::IntArrayRef dim, bool keepdim); + static at::Tensor margin_ranking_loss(const at::Tensor &input1, + const at::Tensor &input2, + const at::Tensor &target, double margin, + int64_t reduction); + static at::Tensor matmul(const at::Tensor &self, const at::Tensor &other); + static at::Tensor &matmul_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other); + static at::Tensor matrix_rank(const at::Tensor &self, double tol, + bool symmetric); + static at::Tensor matrix_rank(const at::Tensor &self, bool symmetric); + static at::Tensor matrix_power(const at::Tensor &self, int64_t n); + static std::tuple max(const at::Tensor &self, + int64_t dim, bool keepdim); + static std::tuple + max_out(at::Tensor &max, at::Tensor &max_values, const at::Tensor &self, + int64_t dim, bool keepdim); + static at::Tensor max_values(const at::Tensor &self, at::IntArrayRef dim, + bool keepdim); + static std::tuple + max_pool1d_with_indices(const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, bool ceil_mode); + static at::Tensor max_pool1d(const at::Tensor &self, + at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, bool ceil_mode); + static at::Tensor max_pool2d(const at::Tensor &self, + at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, bool ceil_mode); + static at::Tensor mkldnn_max_pool2d(const at::Tensor &self, + at::IntArrayRef kernel_size, + at::IntArrayRef stride, + at::IntArrayRef padding, + at::IntArrayRef dilation, bool ceil_mode); + static at::Tensor quantized_max_pool2d(const at::Tensor &self, + at::IntArrayRef kernel_size, + at::IntArrayRef stride, + at::IntArrayRef padding, + at::IntArrayRef dilation); + static at::Tensor max_pool3d(const at::Tensor &self, + at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, bool ceil_mode); + static at::Tensor mean(const at::Tensor &self, + c10::optional dtype); + static at::Tensor mean(const at::Tensor &self, at::IntArrayRef dim, + bool keepdim, c10::optional dtype); + static at::Tensor &mean_out(at::Tensor &out, const at::Tensor &self, + 
at::IntArrayRef dim, bool keepdim, + c10::optional dtype); + static std::tuple median(const at::Tensor &self, + int64_t dim, bool keepdim); + static std::tuple + median_out(at::Tensor &values, at::Tensor &indices, const at::Tensor &self, + int64_t dim, bool keepdim); + static std::tuple min(const at::Tensor &self, + int64_t dim, bool keepdim); + static std::tuple + min_out(at::Tensor &min, at::Tensor &min_indices, const at::Tensor &self, + int64_t dim, bool keepdim); + static at::Tensor min_values(const at::Tensor &self, at::IntArrayRef dim, + bool keepdim); + static at::Tensor + mkldnn_convolution(const at::Tensor &self, const at::Tensor &weight, + const at::Tensor &bias, at::IntArrayRef padding, + at::IntArrayRef stride, at::IntArrayRef dilation, + int64_t groups); + static at::Tensor mkldnn_convolution_backward_input( + at::IntArrayRef self_size, const at::Tensor &grad_output, + const at::Tensor &weight, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool bias_defined); + static std::tuple mkldnn_convolution_backward_weights( + at::IntArrayRef weight_size, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool bias_defined); + static std::tuple + mkldnn_convolution_backward(const at::Tensor &self, + const at::Tensor &grad_output, + const at::Tensor &weight, at::IntArrayRef padding, + at::IntArrayRef stride, at::IntArrayRef dilation, + int64_t groups, std::array output_mask); + static std::tuple + miopen_batch_norm(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, const at::Tensor &running_mean, + const at::Tensor &running_var, bool training, + double exponential_average_factor, double epsilon); + static std::tuple + miopen_batch_norm_backward(const at::Tensor &input, + const at::Tensor &grad_output, + const at::Tensor &weight, + const at::Tensor &running_mean, + const at::Tensor &running_var, + const at::Tensor &save_mean, + const at::Tensor &save_var, double epsilon); + static at::Tensor + miopen_convolution(const at::Tensor &self, const at::Tensor &weight, + const at::Tensor &bias, at::IntArrayRef padding, + at::IntArrayRef stride, at::IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic); + static at::Tensor miopen_convolution_backward_input( + at::IntArrayRef self_size, const at::Tensor &grad_output, + const at::Tensor &weight, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool benchmark, + bool deterministic); + static std::tuple + miopen_convolution_backward(const at::Tensor &self, + const at::Tensor &grad_output, + const at::Tensor &weight, at::IntArrayRef padding, + at::IntArrayRef stride, at::IntArrayRef dilation, + int64_t groups, bool benchmark, + bool deterministic, + std::array output_mask); + static at::Tensor + miopen_convolution_backward_bias(const at::Tensor &grad_output); + static at::Tensor miopen_convolution_backward_weight( + at::IntArrayRef weight_size, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool benchmark, + bool deterministic); + static at::Tensor miopen_convolution_transpose( + const at::Tensor &self, const at::Tensor &weight, const at::Tensor &bias, + at::IntArrayRef padding, at::IntArrayRef output_padding, + at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic); + 
static std::tuple + miopen_convolution_transpose_backward( + const at::Tensor &self, const at::Tensor &grad_output, + const at::Tensor &weight, at::IntArrayRef padding, + at::IntArrayRef output_padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool benchmark, + bool deterministic, std::array output_mask); + static at::Tensor miopen_convolution_transpose_backward_input( + const at::Tensor &grad_output, const at::Tensor &weight, + at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic); + static at::Tensor miopen_convolution_transpose_backward_weight( + at::IntArrayRef weight_size, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool benchmark, + bool deterministic); + static at::Tensor miopen_depthwise_convolution( + const at::Tensor &self, const at::Tensor &weight, const at::Tensor &bias, + at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic); + static at::Tensor miopen_depthwise_convolution_backward_input( + at::IntArrayRef self_size, const at::Tensor &grad_output, + const at::Tensor &weight, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool benchmark, + bool deterministic); + static std::tuple + miopen_depthwise_convolution_backward( + const at::Tensor &self, const at::Tensor &grad_output, + const at::Tensor &weight, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool benchmark, + bool deterministic, std::array output_mask); + static at::Tensor miopen_depthwise_convolution_backward_weight( + at::IntArrayRef weight_size, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef padding, at::IntArrayRef stride, + at::IntArrayRef dilation, int64_t groups, bool benchmark, + bool deterministic); + static std::tuple + miopen_rnn(const at::Tensor &input, at::TensorList weight, + int64_t weight_stride0, const at::Tensor &hx, const at::Tensor &cx, + int64_t mode, int64_t hidden_size, int64_t num_layers, + bool batch_first, double dropout, bool train, bool bidirectional, + at::IntArrayRef batch_sizes, const at::Tensor &dropout_state); + static std::tuple> + miopen_rnn_backward(const at::Tensor &input, at::TensorList weight, + int64_t weight_stride0, const at::Tensor &weight_buf, + const at::Tensor &hx, const at::Tensor &cx, + const at::Tensor &output, const at::Tensor &grad_output, + const at::Tensor &grad_hy, const at::Tensor &grad_cy, + int64_t mode, int64_t hidden_size, int64_t num_layers, + bool batch_first, double dropout, bool train, + bool bidirectional, at::IntArrayRef batch_sizes, + const at::Tensor &dropout_state, + const at::Tensor &reserve, + std::array output_mask); + static at::Tensor mm(const at::Tensor &self, const at::Tensor &mat2); + static at::Tensor &mm_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &mat2); + static at::Tensor _sparse_mm(const at::Tensor &sparse, + const at::Tensor &dense); + static std::tuple mode(const at::Tensor &self, + int64_t dim, bool keepdim); + static std::tuple + mode_out(at::Tensor &values, at::Tensor &indices, const at::Tensor &self, + int64_t dim, bool keepdim); + static at::Tensor mul(const at::Tensor &self, const at::Tensor &other); + static at::Tensor &mul_(at::Tensor &self, const at::Tensor &other); + static at::Tensor &mul_out(at::Tensor &out, 
const at::Tensor &self, + const at::Tensor &other); + static at::Tensor mul(const at::Tensor &self, at::Scalar other); + static at::Tensor &mul_(at::Tensor &self, at::Scalar other); + static at::Tensor mv(const at::Tensor &self, const at::Tensor &vec); + static at::Tensor &mv_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &vec); + static at::Tensor mvlgamma(const at::Tensor &self, int64_t p); + static at::Tensor &mvlgamma_(at::Tensor &self, int64_t p); + static at::Tensor narrow_copy(const at::Tensor &self, int64_t dim, + int64_t start, int64_t length); + static at::Tensor narrow(const at::Tensor &self, int64_t dim, int64_t start, + int64_t length); + static std::tuple + native_batch_norm(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, const at::Tensor &running_mean, + const at::Tensor &running_var, bool training, + double momentum, double eps); + static std::tuple + batch_norm_stats(const at::Tensor &input, double eps); + static at::Tensor batch_norm_elemt(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + const at::Tensor &mean, + const at::Tensor &invstd, double eps); + static std::tuple batch_norm_gather_stats( + const at::Tensor &input, const at::Tensor &mean, const at::Tensor &invstd, + const at::Tensor &running_mean, const at::Tensor &running_var, + double momentum, double eps, int64_t count); + static std::tuple batch_norm_gather_stats_with_counts( + const at::Tensor &input, const at::Tensor &mean, const at::Tensor &invstd, + const at::Tensor &running_mean, const at::Tensor &running_var, + double momentum, double eps, at::IntArrayRef counts); + static std::tuple + native_batch_norm_backward(const at::Tensor &grad_out, + const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &running_mean, + const at::Tensor &running_var, + const at::Tensor &save_mean, + const at::Tensor &save_invstd, bool train, + double eps, std::array output_mask); + static std::tuple + batch_norm_backward_reduce(const at::Tensor &grad_out, + const at::Tensor &input, const at::Tensor &mean, + const at::Tensor &invstd, const at::Tensor &weight, + bool input_g, bool weight_g, bool bias_g); + static at::Tensor + batch_norm_backward_elemt(const at::Tensor &grad_out, const at::Tensor &input, + const at::Tensor &mean, const at::Tensor &invstd, + const at::Tensor &weight, const at::Tensor &mean_dy, + const at::Tensor &mean_dy_xmu); + static std::tuple + batch_norm_update_stats(const at::Tensor &input, + const at::Tensor &running_mean, + const at::Tensor &running_var, double momentum); + static at::Tensor _nnpack_spatial_convolution(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + at::IntArrayRef padding); + static std::tuple + _nnpack_spatial_convolution_backward(const at::Tensor &input, + const at::Tensor &grad_output, + const at::Tensor &weight, + at::IntArrayRef padding, + std::array output_mask); + static at::Tensor _nnpack_spatial_convolution_backward_input( + const at::Tensor &input, const at::Tensor &grad_output, + const at::Tensor &weight, at::IntArrayRef padding); + static at::Tensor _nnpack_spatial_convolution_backward_weight( + const at::Tensor &input, at::IntArrayRef weightsize, + const at::Tensor &grad_output, at::IntArrayRef padding); + static at::Tensor &ones_out(at::Tensor &out, at::IntArrayRef size); + static at::Tensor pairwise_distance(const at::Tensor &x1, + const at::Tensor &x2, double p, + double eps, bool keepdim); + static at::Tensor cdist(const at::Tensor &x1, const 
at::Tensor &x2, double p); + static at::Tensor _cdist_backward(const at::Tensor &grad, + const at::Tensor &x1, const at::Tensor &x2, + double p, const at::Tensor &cdist); + static at::Tensor pdist(const at::Tensor &self, double p); + static at::Tensor _pdist_forward(const at::Tensor &self, double p); + static at::Tensor _pdist_backward(const at::Tensor &grad, + const at::Tensor &self, double p, + const at::Tensor &pdist); + static at::Tensor cosine_similarity(const at::Tensor &x1, + const at::Tensor &x2, int64_t dim, + double eps); + static at::Tensor permute(const at::Tensor &self, at::IntArrayRef dims); + static at::Tensor numpy_T(const at::Tensor &self); + static at::Tensor pixel_shuffle(const at::Tensor &self, + int64_t upscale_factor); + static bool is_pinned(const at::Tensor &self); + static at::Tensor pin_memory(const at::Tensor &self); + static at::Tensor pinverse(const at::Tensor &self, double rcond); + static at::Tensor poisson_nll_loss(const at::Tensor &input, + const at::Tensor &target, bool log_input, + bool full, double eps, int64_t reduction); + static at::Tensor scalar_tensor(at::Scalar s, + const at::TensorOptions &options); + static at::Tensor rand(at::IntArrayRef size, + const at::TensorOptions &options); + static at::Tensor rand(at::IntArrayRef size, at::Generator *generator, + const at::TensorOptions &options); + static at::Tensor &rand_out(at::Tensor &out, at::IntArrayRef size); + static at::Tensor &rand_out(at::Tensor &out, at::IntArrayRef size, + at::Generator *generator); + static at::Tensor rand_like(const at::Tensor &self); + static at::Tensor rand_like(const at::Tensor &self, + const at::TensorOptions &options); + static at::Tensor randint(int64_t high, at::IntArrayRef size, + const at::TensorOptions &options); + static at::Tensor randint(int64_t high, at::IntArrayRef size, + at::Generator *generator, + const at::TensorOptions &options); + static at::Tensor randint(int64_t low, int64_t high, at::IntArrayRef size, + const at::TensorOptions &options); + static at::Tensor randint(int64_t low, int64_t high, at::IntArrayRef size, + at::Generator *generator, + const at::TensorOptions &options); + static at::Tensor &randint_out(at::Tensor &out, int64_t high, + at::IntArrayRef size); + static at::Tensor &randint_out(at::Tensor &out, int64_t high, + at::IntArrayRef size, + at::Generator *generator); + static at::Tensor &randint_out(at::Tensor &out, int64_t low, int64_t high, + at::IntArrayRef size); + static at::Tensor &randint_out(at::Tensor &out, int64_t low, int64_t high, + at::IntArrayRef size, + at::Generator *generator); + static at::Tensor randint_like(const at::Tensor &self, int64_t high); + static at::Tensor randint_like(const at::Tensor &self, int64_t low, + int64_t high); + static at::Tensor randint_like(const at::Tensor &self, int64_t high, + const at::TensorOptions &options); + static at::Tensor randint_like(const at::Tensor &self, int64_t low, + int64_t high, + const at::TensorOptions &options); + static at::Tensor randn(at::IntArrayRef size, + const at::TensorOptions &options); + static at::Tensor randn(at::IntArrayRef size, at::Generator *generator, + const at::TensorOptions &options); + static at::Tensor &randn_out(at::Tensor &out, at::IntArrayRef size); + static at::Tensor &randn_out(at::Tensor &out, at::IntArrayRef size, + at::Generator *generator); + static at::Tensor randn_like(const at::Tensor &self); + static at::Tensor randn_like(const at::Tensor &self, + const at::TensorOptions &options); + static at::Tensor randperm(int64_t n, const 
at::TensorOptions &options); + static at::Tensor randperm(int64_t n, at::Generator *generator, + const at::TensorOptions &options); + static at::Tensor &randperm_out(at::Tensor &out, int64_t n); + static at::Tensor &randperm_out(at::Tensor &out, int64_t n, + at::Generator *generator); + static at::Tensor range(at::Scalar start, at::Scalar end, at::Scalar step, + const at::TensorOptions &options); + static at::Tensor range(at::Scalar start, at::Scalar end, + const at::TensorOptions &options); + static at::Tensor &range_out(at::Tensor &out, at::Scalar start, + at::Scalar end, at::Scalar step); + static at::Tensor reciprocal(const at::Tensor &self); + static at::Tensor &reciprocal_(at::Tensor &self); + static at::Tensor &reciprocal_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor neg(const at::Tensor &self); + static at::Tensor &neg_(at::Tensor &self); + static at::Tensor &neg_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor repeat(const at::Tensor &self, at::IntArrayRef repeats); + static at::Tensor repeat_interleave(const at::Tensor &repeats); + static at::Tensor repeat_interleave(const at::Tensor &self, + const at::Tensor &repeats, + c10::optional dim); + static at::Tensor repeat_interleave(const at::Tensor &self, int64_t repeats, + c10::optional dim); + static at::Tensor reshape(const at::Tensor &self, at::IntArrayRef shape); + static at::Tensor _mkldnn_reshape(const at::Tensor &self, + at::IntArrayRef shape); + static at::Tensor reshape_as(const at::Tensor &self, const at::Tensor &other); + static at::Tensor round(const at::Tensor &self); + static at::Tensor &round_(at::Tensor &self); + static at::Tensor &round_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor rrelu(const at::Tensor &self, at::Scalar lower, + at::Scalar upper, bool training, + at::Generator *generator); + static at::Tensor &rrelu_(at::Tensor &self, at::Scalar lower, + at::Scalar upper, bool training, + at::Generator *generator); + static at::Tensor relu(const at::Tensor &self); + static at::Tensor &relu_(at::Tensor &self); + static at::Tensor prelu(const at::Tensor &self, const at::Tensor &weight); + static std::tuple + prelu_backward(const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight); + static at::Tensor gelu(const at::Tensor &self); + static at::Tensor gelu_backward(const at::Tensor &grad, + const at::Tensor &self); + static at::Tensor hardshrink(const at::Tensor &self, at::Scalar lambd); + static at::Tensor hardshrink_backward(const at::Tensor &grad_out, + const at::Tensor &self, + at::Scalar lambd); + static at::Tensor rsqrt(const at::Tensor &self); + static at::Tensor &rsqrt_(at::Tensor &self); + static at::Tensor &rsqrt_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor select(const at::Tensor &self, int64_t dim, int64_t index); + static at::Tensor selu(const at::Tensor &self); + static at::Tensor &selu_(at::Tensor &self); + static at::Tensor celu(const at::Tensor &self, at::Scalar alpha); + static at::Tensor &celu_(at::Tensor &self, at::Scalar alpha); + static at::Tensor sigmoid(const at::Tensor &self); + static at::Tensor &sigmoid_(at::Tensor &self); + static at::Tensor &sigmoid_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor sin(const at::Tensor &self); + static at::Tensor &sin_(at::Tensor &self); + static at::Tensor &sin_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor sinh(const at::Tensor &self); + static at::Tensor &sinh_(at::Tensor &self); + static at::Tensor 
&sinh_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor detach(const at::Tensor &self); + static at::Tensor &detach_(at::Tensor &self); + static int64_t size(const at::Tensor &self, int64_t dim); + static at::Tensor slice(const at::Tensor &self, int64_t dim, int64_t start, + int64_t end, int64_t step); + static std::tuple slogdet(const at::Tensor &self); + static at::Tensor smm(const at::Tensor &self, const at::Tensor &mat2); + static at::Tensor softmax(const at::Tensor &self, int64_t dim, + c10::optional dtype); + static at::Tensor _softmax(const at::Tensor &self, int64_t dim, + bool half_to_float); + static at::Tensor _softmax_backward_data(const at::Tensor &grad_output, + const at::Tensor &output, + int64_t dim, const at::Tensor &self); + static at::Tensor &_sparse_add_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other, at::Scalar alpha); + static at::Tensor &_sparse_dense_add_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &other, + at::Scalar alpha); + static at::Tensor &_sparse_div_zerodim_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &other); + static at::Tensor &_sparse_div_scalar_out(at::Tensor &out, + const at::Tensor &self, + at::Scalar other); + static at::Tensor &_sparse_mul_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other); + static at::Tensor &_sparse_mul_zerodim_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &other); + static at::Tensor &_sparse_mul_scalar_out(at::Tensor &out, + const at::Tensor &self, + at::Scalar other); + static std::vector split(const at::Tensor &self, + int64_t split_size, int64_t dim); + static std::vector split_with_sizes(const at::Tensor &self, + at::IntArrayRef split_sizes, + int64_t dim); + static at::Tensor squeeze(const at::Tensor &self); + static at::Tensor squeeze(const at::Tensor &self, int64_t dim); + static at::Tensor &squeeze_(at::Tensor &self); + static at::Tensor &squeeze_(at::Tensor &self, int64_t dim); + static at::Tensor sspaddmm(const at::Tensor &self, const at::Tensor &mat1, + const at::Tensor &mat2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor &sspaddmm_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &mat1, + const at::Tensor &mat2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor stack(at::TensorList tensors, int64_t dim); + static at::Tensor &stack_out(at::Tensor &out, at::TensorList tensors, + int64_t dim); + static at::Tensor stft(const at::Tensor &self, int64_t n_fft, + c10::optional hop_length, + c10::optional win_length, + const at::Tensor &window, bool normalized, + bool onesided); + static int64_t stride(const at::Tensor &self, int64_t dim); + static at::Tensor sum(const at::Tensor &self, + c10::optional dtype); + static at::Tensor sum(const at::Tensor &self, at::IntArrayRef dim, + bool keepdim, c10::optional dtype); + static at::Tensor &sum_out(at::Tensor &out, const at::Tensor &self, + at::IntArrayRef dim, bool keepdim, + c10::optional dtype); + static at::Tensor sum_to_size(const at::Tensor &self, at::IntArrayRef size); + static at::Tensor sqrt(const at::Tensor &self); + static at::Tensor &sqrt_(at::Tensor &self); + static at::Tensor &sqrt_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor std(const at::Tensor &self, bool unbiased); + static at::Tensor std(const at::Tensor &self, at::IntArrayRef dim, + bool unbiased, bool keepdim); + static std::tuple std_mean(const at::Tensor &self, + bool unbiased); + static std::tuple std_mean(const at::Tensor &self, + 
at::IntArrayRef dim, + bool unbiased, + bool keepdim); + static at::Tensor &std_out(at::Tensor &out, const at::Tensor &self, + at::IntArrayRef dim, bool unbiased, bool keepdim); + static at::Tensor prod(const at::Tensor &self, + c10::optional dtype); + static at::Tensor prod(const at::Tensor &self, int64_t dim, bool keepdim, + c10::optional dtype); + static at::Tensor &prod_out(at::Tensor &out, const at::Tensor &self, + int64_t dim, bool keepdim, + c10::optional dtype); + static at::Tensor t(const at::Tensor &self); + static at::Tensor &t_(at::Tensor &self); + static at::Tensor tan(const at::Tensor &self); + static at::Tensor &tan_(at::Tensor &self); + static at::Tensor &tan_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor tanh(const at::Tensor &self); + static at::Tensor &tanh_(at::Tensor &self); + static at::Tensor &tanh_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor tensordot(const at::Tensor &self, const at::Tensor &other, + at::IntArrayRef dims_self, + at::IntArrayRef dims_other); + static at::Tensor threshold(const at::Tensor &self, at::Scalar threshold, + at::Scalar value); + static at::Tensor &threshold_(at::Tensor &self, at::Scalar threshold, + at::Scalar value); + static at::Tensor &threshold_out(at::Tensor &out, const at::Tensor &self, + at::Scalar threshold, at::Scalar value); + static at::Tensor threshold_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::Scalar threshold); + static at::Tensor transpose(const at::Tensor &self, int64_t dim0, + int64_t dim1); + static at::Tensor _mkldnn_transpose(const at::Tensor &self, int64_t dim0, + int64_t dim1); + static at::Tensor &transpose_(at::Tensor &self, int64_t dim0, int64_t dim1); + static at::Tensor &_mkldnn_transpose_(at::Tensor &self, int64_t dim0, + int64_t dim1); + static at::Tensor one_hot(const at::Tensor &self, int64_t num_classes); + static at::Tensor flip(const at::Tensor &self, at::IntArrayRef dims); + static at::Tensor roll(const at::Tensor &self, at::IntArrayRef shifts, + at::IntArrayRef dims); + static at::Tensor rot90(const at::Tensor &self, int64_t k, + at::IntArrayRef dims); + static at::Tensor trapz(const at::Tensor &y, const at::Tensor &x, + int64_t dim); + static at::Tensor trapz(const at::Tensor &y, double dx, int64_t dim); + static at::Tensor _trilinear(const at::Tensor &i1, const at::Tensor &i2, + const at::Tensor &i3, at::IntArrayRef expand1, + at::IntArrayRef expand2, at::IntArrayRef expand3, + at::IntArrayRef sumdim, int64_t unroll_dim); + static at::Tensor triplet_margin_loss(const at::Tensor &anchor, + const at::Tensor &positive, + const at::Tensor &negative, + double margin, double p, double eps, + bool swap, int64_t reduction); + static at::Tensor trunc(const at::Tensor &self); + static at::Tensor &trunc_(at::Tensor &self); + static at::Tensor &trunc_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor type_as(const at::Tensor &self, const at::Tensor &other); + static bool _has_compatible_shallow_copy_type(const at::Tensor &self, + const at::Tensor &from); + static std::tuple + _unique(const at::Tensor &self, bool sorted, bool return_inverse); + static std::tuple + unique_dim(const at::Tensor &self, int64_t dim, bool sorted, + bool return_inverse, bool return_counts); + static std::tuple + unique_consecutive(const at::Tensor &self, bool return_inverse, + bool return_counts, c10::optional dim); + static std::tuple + unique_dim_consecutive(const at::Tensor &self, int64_t dim, + bool return_inverse, bool return_counts); + static 
std::tuple + _unique2(const at::Tensor &self, bool sorted, bool return_inverse, + bool return_counts); + static at::Tensor _unsafe_view(const at::Tensor &self, at::IntArrayRef size); + static at::Tensor unsqueeze(const at::Tensor &self, int64_t dim); + static at::Tensor &unsqueeze_(at::Tensor &self, int64_t dim); + static at::Tensor var(const at::Tensor &self, bool unbiased); + static at::Tensor var(const at::Tensor &self, at::IntArrayRef dim, + bool unbiased, bool keepdim); + static at::Tensor &var_out(at::Tensor &out, const at::Tensor &self, + at::IntArrayRef dim, bool unbiased, bool keepdim); + static std::tuple var_mean(const at::Tensor &self, + bool unbiased); + static std::tuple var_mean(const at::Tensor &self, + at::IntArrayRef dim, + bool unbiased, + bool keepdim); + static at::Tensor view_as(const at::Tensor &self, const at::Tensor &other); + static at::Tensor where(const at::Tensor &condition, const at::Tensor &self, + const at::Tensor &other); + static std::vector where(const at::Tensor &condition); + static at::Tensor _s_where(const at::Tensor &condition, + const at::Tensor &self, const at::Tensor &other); + static at::Tensor norm_except_dim(const at::Tensor &v, int64_t pow, + int64_t dim); + static at::Tensor _weight_norm(const at::Tensor &v, const at::Tensor &g, + int64_t dim); + static std::tuple + _weight_norm_cuda_interface(const at::Tensor &v, const at::Tensor &g, + int64_t dim); + static std::tuple + _weight_norm_cuda_interface_backward(const at::Tensor &grad_w, + const at::Tensor &saved_v, + const at::Tensor &saved_g, + const at::Tensor &saved_norms, + int64_t dim); + static std::tuple + _weight_norm_differentiable_backward(const at::Tensor &grad_w, + const at::Tensor &saved_v, + const at::Tensor &saved_g, + const at::Tensor &saved_norms, + int64_t dim); + static at::Tensor &zeros_out(at::Tensor &out, at::IntArrayRef size); + static at::Tensor _standard_gamma_grad(const at::Tensor &self, + const at::Tensor &output); + static at::Tensor _standard_gamma(const at::Tensor &self, + at::Generator *generator); + static at::Tensor _dirichlet_grad(const at::Tensor &x, + const at::Tensor &alpha, + const at::Tensor &total); + static at::Tensor _sample_dirichlet(const at::Tensor &self, + at::Generator *generator); + static at::Tensor poisson(const at::Tensor &self, at::Generator *generator); + static at::Tensor native_norm(const at::Tensor &self, at::Scalar p); + static at::Tensor _sparse_sum(const at::Tensor &self); + static at::Tensor _sparse_sum(const at::Tensor &self, at::ScalarType dtype); + static at::Tensor _sparse_sum(const at::Tensor &self, at::IntArrayRef dim); + static at::Tensor _sparse_sum(const at::Tensor &self, at::IntArrayRef dim, + at::ScalarType dtype); + static at::Tensor _sparse_sum_backward(const at::Tensor &grad, + const at::Tensor &self, + at::IntArrayRef dim); + static at::Tensor norm(const at::Tensor &self, c10::optional p, + at::ScalarType dtype); + static at::Tensor norm(const at::Tensor &self, at::Scalar p); + static at::Tensor norm(const at::Tensor &self, c10::optional p, + at::IntArrayRef dim, bool keepdim, + at::ScalarType dtype); + static at::Tensor norm(const at::Tensor &self, c10::optional p, + at::IntArrayRef dim, bool keepdim); + static at::Tensor &norm_out(at::Tensor &out, const at::Tensor &self, + c10::optional p, at::IntArrayRef dim, + bool keepdim, at::ScalarType dtype); + static at::Tensor &norm_out(at::Tensor &out, const at::Tensor &self, + c10::optional p, at::IntArrayRef dim, + bool keepdim); + static at::Tensor frobenius_norm(const 
at::Tensor &self); + static at::Tensor frobenius_norm(const at::Tensor &self, at::IntArrayRef dim, + bool keepdim); + static at::Tensor &frobenius_norm_out(at::Tensor &out, const at::Tensor &self, + at::IntArrayRef dim, bool keepdim); + static at::Tensor nuclear_norm(const at::Tensor &self, bool keepdim); + static at::Tensor &nuclear_norm_out(at::Tensor &out, const at::Tensor &self, + bool keepdim); + static at::Tensor nuclear_norm(const at::Tensor &self, at::IntArrayRef dim, + bool keepdim); + static at::Tensor &nuclear_norm_out(at::Tensor &out, const at::Tensor &self, + at::IntArrayRef dim, bool keepdim); + static at::Tensor clone(const at::Tensor &self); + static at::Tensor &resize_as_(at::Tensor &self, + const at::Tensor &the_template); + static at::Tensor &pow_out(at::Tensor &out, const at::Tensor &self, + at::Scalar exponent); + static at::Tensor pow(const at::Tensor &self, at::Scalar exponent); + static at::Tensor &zero_(at::Tensor &self); + static at::Tensor &sub_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other, at::Scalar alpha); + static at::Tensor sub(const at::Tensor &self, const at::Tensor &other, + at::Scalar alpha); + static at::Tensor &sub_(at::Tensor &self, const at::Tensor &other, + at::Scalar alpha); + static at::Tensor sub(const at::Tensor &self, at::Scalar other, + at::Scalar alpha); + static at::Tensor &sub_(at::Tensor &self, at::Scalar other, at::Scalar alpha); + static at::Tensor rsub(const at::Tensor &self, const at::Tensor &other, + at::Scalar alpha); + static at::Tensor rsub(const at::Tensor &self, at::Scalar other, + at::Scalar alpha); + static at::Tensor &s_native_addmm_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &mat1, + const at::Tensor &mat2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor s_native_addmm(const at::Tensor &self, + const at::Tensor &mat1, + const at::Tensor &mat2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor &s_native_addmm_(at::Tensor &self, const at::Tensor &mat1, + const at::Tensor &mat2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor _sparse_addmm(const at::Tensor &self, + const at::Tensor &sparse, + const at::Tensor &dense, at::Scalar beta, + at::Scalar alpha); + static at::Tensor &addmm_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &mat1, const at::Tensor &mat2, + at::Scalar beta, at::Scalar alpha); + static at::Tensor addmm(const at::Tensor &self, const at::Tensor &mat1, + const at::Tensor &mat2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor &addmm_(at::Tensor &self, const at::Tensor &mat1, + const at::Tensor &mat2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor sparse_coo_tensor(at::IntArrayRef size, + const at::TensorOptions &options); + static at::Tensor sparse_coo_tensor(const at::Tensor &indices, + const at::Tensor &values, + const at::TensorOptions &options); + static at::Tensor sparse_coo_tensor(const at::Tensor &indices, + const at::Tensor &values, + at::IntArrayRef size, + const at::TensorOptions &options); + static at::Tensor _sparse_coo_tensor_unsafe(const at::Tensor &indices, + const at::Tensor &values, + at::IntArrayRef size, + const at::TensorOptions &options); + static at::Tensor + _sparse_coo_tensor_with_dims(int64_t sparse_dim, int64_t dense_dim, + at::IntArrayRef size, + const at::TensorOptions &options); + static at::Tensor _sparse_coo_tensor_with_dims_and_tensors( + int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, + const at::Tensor &indices, const at::Tensor &values, + const 
at::TensorOptions &options); + static at::Tensor &sparse_resize_(at::Tensor &self, at::IntArrayRef size, + int64_t sparse_dim, int64_t dense_dim); + static at::Tensor &sparse_resize_and_clear_(at::Tensor &self, + at::IntArrayRef size, + int64_t sparse_dim, + int64_t dense_dim); + static at::Tensor sparse_mask(const at::Tensor &self, const at::Tensor &mask); + static at::Tensor to_dense(const at::Tensor &self); + static at::Tensor to_dense_backward(const at::Tensor &grad, + const at::Tensor &input); + static int64_t sparse_dim(const at::Tensor &self); + static int64_t _dimI(const at::Tensor &self); + static int64_t dense_dim(const at::Tensor &self); + static int64_t _dimV(const at::Tensor &self); + static int64_t _nnz(const at::Tensor &self); + static at::Tensor coalesce(const at::Tensor &self); + static bool is_coalesced(const at::Tensor &self); + static at::Tensor _indices(const at::Tensor &self); + static at::Tensor _values(const at::Tensor &self); + static at::Tensor &_coalesced_(at::Tensor &self, bool coalesced); + static at::Tensor indices(const at::Tensor &self); + static at::Tensor values(const at::Tensor &self); + static at::Tensor &hspmm_out(at::Tensor &out, const at::Tensor &mat1, + const at::Tensor &mat2); + static at::Tensor hspmm(const at::Tensor &mat1, const at::Tensor &mat2); + static at::Tensor &copy_sparse_to_sparse_(at::Tensor &self, + const at::Tensor &src, + bool non_blocking); + static std::vector unbind(const at::Tensor &self, int64_t dim); + static at::Tensor to_sparse(const at::Tensor &self, int64_t sparse_dim); + static at::Tensor to_sparse(const at::Tensor &self); + static at::Tensor to_mkldnn(const at::Tensor &self); + static at::Tensor mkldnn_reorder_conv2d_weight(const at::Tensor &self, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups); + static at::Tensor to_mkldnn_backward(const at::Tensor &grad, + const at::Tensor &input); + static at::Tensor quantize_linear(const at::Tensor &self, double scale, + int64_t zero_point, at::ScalarType dtype); + static at::Tensor quantize_linear_per_channel(const at::Tensor &self, + const at::Tensor &scales, + const at::Tensor &zero_points, + at::IntArrayRef axis, + at::ScalarType dtype); + static at::Tensor dequantize(const at::Tensor &self); + static at::Tensor _dequantize_linear(const at::Tensor &self, double scale, + int64_t zero_point, + at::ScalarType dtype); + static double q_scale(const at::Tensor &self); + static int64_t q_zero_point(const at::Tensor &self); + static at::Tensor q_per_channel_scales(const at::Tensor &self); + static at::Tensor q_per_channel_zero_points(const at::Tensor &self); + static at::Tensor int_repr(const at::Tensor &self); + static at::Tensor _per_tensor_affine_qtensor(const at::Tensor &self, + double scale, + int64_t zero_point); + static at::Tensor _per_channel_affine_qtensor(const at::Tensor &self, + const at::Tensor &scale, + const at::Tensor &zero_point, + at::IntArrayRef axis); + static at::QScheme qscheme(const at::Tensor &self); + static at::Tensor fake_quantize_per_tensor_affine(const at::Tensor &self, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max); + static at::Tensor fake_quantize_per_tensor_affine_backward( + const at::Tensor &grad, const at::Tensor &self, double scale, + int64_t zero_point, int64_t quant_min, int64_t quant_max); + static at::Tensor to(const at::Tensor &self, const at::TensorOptions &options, + bool non_blocking, bool copy); + static at::Tensor to(const at::Tensor &self, c10::Device device, +
at::ScalarType dtype, bool non_blocking, bool copy); + static at::Tensor to(const at::Tensor &self, at::ScalarType dtype, + bool non_blocking, bool copy); + static at::Tensor to(const at::Tensor &self, const at::Tensor &other, + bool non_blocking, bool copy); + static std::vector meshgrid(at::TensorList tensors); + static at::Tensor cartesian_prod(at::TensorList tensors); + static at::Tensor combinations(const at::Tensor &self, int64_t r, + bool with_replacement); + static at::Scalar item(const at::Tensor &self); + static at::Scalar _local_scalar_dense(const at::Tensor &self); + static std::tuple + _thnn_fused_lstm_cell(const at::Tensor &input_gates, + const at::Tensor &hidden_gates, const at::Tensor &cx, + const at::Tensor &input_bias, + const at::Tensor &hidden_bias); + static std::tuple + _thnn_fused_lstm_cell_backward(const at::Tensor &grad_hy, + const at::Tensor &grad_cy, + const at::Tensor &cx, const at::Tensor &cy, + const at::Tensor &workspace, bool has_bias); + static std::tuple + _thnn_fused_gru_cell(const at::Tensor &input_gates, + const at::Tensor &hidden_gates, const at::Tensor &hx, + const at::Tensor &input_bias, + const at::Tensor &hidden_bias); + static std::tuple + _thnn_fused_gru_cell_backward(const at::Tensor &grad_hy, + const at::Tensor &workspace, bool has_bias); + static std::tuple + lstm(const at::Tensor &input, at::TensorList hx, at::TensorList params, + bool has_biases, int64_t num_layers, double dropout, bool train, + bool bidirectional, bool batch_first); + static std::tuple + lstm(const at::Tensor &data, const at::Tensor &batch_sizes, at::TensorList hx, + at::TensorList params, bool has_biases, int64_t num_layers, + double dropout, bool train, bool bidirectional); + static std::tuple + gru(const at::Tensor &input, const at::Tensor &hx, at::TensorList params, + bool has_biases, int64_t num_layers, double dropout, bool train, + bool bidirectional, bool batch_first); + static std::tuple + gru(const at::Tensor &data, const at::Tensor &batch_sizes, + const at::Tensor &hx, at::TensorList params, bool has_biases, + int64_t num_layers, double dropout, bool train, bool bidirectional); + static std::tuple + rnn_tanh(const at::Tensor &input, const at::Tensor &hx, at::TensorList params, + bool has_biases, int64_t num_layers, double dropout, bool train, + bool bidirectional, bool batch_first); + static std::tuple + rnn_tanh(const at::Tensor &data, const at::Tensor &batch_sizes, + const at::Tensor &hx, at::TensorList params, bool has_biases, + int64_t num_layers, double dropout, bool train, bool bidirectional); + static std::tuple + rnn_relu(const at::Tensor &input, const at::Tensor &hx, at::TensorList params, + bool has_biases, int64_t num_layers, double dropout, bool train, + bool bidirectional, bool batch_first); + static std::tuple + rnn_relu(const at::Tensor &data, const at::Tensor &batch_sizes, + const at::Tensor &hx, at::TensorList params, bool has_biases, + int64_t num_layers, double dropout, bool train, bool bidirectional); + static std::tuple + lstm_cell(const at::Tensor &input, at::TensorList hx, const at::Tensor &w_ih, + const at::Tensor &w_hh, const at::Tensor &b_ih, + const at::Tensor &b_hh); + static at::Tensor gru_cell(const at::Tensor &input, const at::Tensor &hx, + const at::Tensor &w_ih, const at::Tensor &w_hh, + const at::Tensor &b_ih, const at::Tensor &b_hh); + static at::Tensor rnn_tanh_cell(const at::Tensor &input, const at::Tensor &hx, + const at::Tensor &w_ih, + const at::Tensor &w_hh, + const at::Tensor &b_ih, + const at::Tensor &b_hh); + static 
at::Tensor rnn_relu_cell(const at::Tensor &input, const at::Tensor &hx, + const at::Tensor &w_ih, + const at::Tensor &w_hh, + const at::Tensor &b_ih, + const at::Tensor &b_hh); + static std::tuple + quantized_lstm(const at::Tensor &input, at::TensorList hx, + at::TensorList params, bool has_biases, int64_t num_layers, + double dropout, bool train, bool bidirectional, + bool batch_first, c10::optional dtype, + bool use_dynamic); + static std::tuple + quantized_gru(const at::Tensor &input, const at::Tensor &hx, + at::TensorList params, bool has_biases, int64_t num_layers, + double dropout, bool train, bool bidirectional, + bool batch_first); + static std::tuple + quantized_gru(const at::Tensor &data, const at::Tensor &batch_sizes, + const at::Tensor &hx, at::TensorList params, bool has_biases, + int64_t num_layers, double dropout, bool train, + bool bidirectional); + static std::tuple quantized_lstm_cell( + const at::Tensor &input, at::TensorList hx, const at::Tensor &w_ih, + const at::Tensor &w_hh, const at::Tensor &b_ih, const at::Tensor &b_hh, + const at::Tensor &packed_ih, const at::Tensor &packed_hh, + const at::Tensor &col_offsets_ih, const at::Tensor &col_offsets_hh, + at::Scalar scale_ih, at::Scalar scale_hh, at::Scalar zero_point_ih, + at::Scalar zero_point_hh); + static at::Tensor quantized_gru_cell( + const at::Tensor &input, const at::Tensor &hx, const at::Tensor &w_ih, + const at::Tensor &w_hh, const at::Tensor &b_ih, const at::Tensor &b_hh, + const at::Tensor &packed_ih, const at::Tensor &packed_hh, + const at::Tensor &col_offsets_ih, const at::Tensor &col_offsets_hh, + at::Scalar scale_ih, at::Scalar scale_hh, at::Scalar zero_point_ih, + at::Scalar zero_point_hh); + static at::Tensor quantized_rnn_relu_cell( + const at::Tensor &input, const at::Tensor &hx, const at::Tensor &w_ih, + const at::Tensor &w_hh, const at::Tensor &b_ih, const at::Tensor &b_hh, + const at::Tensor &packed_ih, const at::Tensor &packed_hh, + const at::Tensor &col_offsets_ih, const at::Tensor &col_offsets_hh, + at::Scalar scale_ih, at::Scalar scale_hh, at::Scalar zero_point_ih, + at::Scalar zero_point_hh); + static at::Tensor quantized_rnn_tanh_cell( + const at::Tensor &input, const at::Tensor &hx, const at::Tensor &w_ih, + const at::Tensor &w_hh, const at::Tensor &b_ih, const at::Tensor &b_hh, + const at::Tensor &packed_ih, const at::Tensor &packed_hh, + const at::Tensor &col_offsets_ih, const at::Tensor &col_offsets_hh, + at::Scalar scale_ih, at::Scalar scale_hh, at::Scalar zero_point_ih, + at::Scalar zero_point_hh); + static std::tuple + _pack_padded_sequence(const at::Tensor &input, const at::Tensor &lengths, + bool batch_first); + static at::Tensor _pack_padded_sequence_backward( + const at::Tensor &grad, at::IntArrayRef input_size, + const at::Tensor &batch_sizes, bool batch_first); + static std::tuple + _pad_packed_sequence(const at::Tensor &data, const at::Tensor &batch_sizes, + bool batch_first, at::Scalar padding_value, + int64_t total_length); + static at::Tensor &set_(at::Tensor &self, at::Storage source); + static at::Tensor &set_(at::Tensor &self, at::Storage source, + int64_t storage_offset, at::IntArrayRef size, + at::IntArrayRef stride); + static at::Tensor &set_(at::Tensor &self, const at::Tensor &source); + static at::Tensor &set_(at::Tensor &self); + static at::Tensor &set_quantizer_(at::Tensor &self, + at::ConstQuantizerPtr quantizer); + static bool is_set_to(const at::Tensor &self, const at::Tensor &tensor); + static at::Tensor &masked_fill_(at::Tensor &self, const at::Tensor &mask, + 
at::Scalar value); + static at::Tensor masked_fill(const at::Tensor &self, const at::Tensor &mask, + at::Scalar value); + static at::Tensor &masked_fill_(at::Tensor &self, const at::Tensor &mask, + const at::Tensor &value); + static at::Tensor masked_fill(const at::Tensor &self, const at::Tensor &mask, + const at::Tensor &value); + static at::Tensor &masked_scatter_(at::Tensor &self, const at::Tensor &mask, + const at::Tensor &source); + static at::Tensor masked_scatter(const at::Tensor &self, + const at::Tensor &mask, + const at::Tensor &source); + static at::Tensor view(const at::Tensor &self, at::IntArrayRef size); + static at::Tensor &put_(at::Tensor &self, const at::Tensor &index, + const at::Tensor &source, bool accumulate); + static at::Tensor &index_add_(at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &source); + static at::Tensor index_add(const at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &source); + static at::Tensor &index_fill_(at::Tensor &self, int64_t dim, + const at::Tensor &index, at::Scalar value); + static at::Tensor index_fill(const at::Tensor &self, int64_t dim, + const at::Tensor &index, at::Scalar value); + static at::Tensor &index_fill_(at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &value); + static at::Tensor index_fill(const at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &value); + static at::Tensor &scatter_(at::Tensor &self, int64_t dim, + const at::Tensor &index, const at::Tensor &src); + static at::Tensor scatter(const at::Tensor &self, int64_t dim, + const at::Tensor &index, const at::Tensor &src); + static at::Tensor &scatter_(at::Tensor &self, int64_t dim, + const at::Tensor &index, at::Scalar value); + static at::Tensor scatter(const at::Tensor &self, int64_t dim, + const at::Tensor &index, at::Scalar value); + static at::Tensor &scatter_add_(at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &src); + static at::Tensor scatter_add(const at::Tensor &self, int64_t dim, + const at::Tensor &index, const at::Tensor &src); + static at::Tensor &lt_(at::Tensor &self, at::Scalar other); + static at::Tensor &lt_(at::Tensor &self, const at::Tensor &other); + static at::Tensor &gt_(at::Tensor &self, at::Scalar other); + static at::Tensor &gt_(at::Tensor &self, const at::Tensor &other); + static at::Tensor &le_(at::Tensor &self, at::Scalar other); + static at::Tensor &le_(at::Tensor &self, const at::Tensor &other); + static at::Tensor &ge_(at::Tensor &self, at::Scalar other); + static at::Tensor &ge_(at::Tensor &self, const at::Tensor &other); + static at::Tensor &eq_(at::Tensor &self, at::Scalar other); + static at::Tensor &eq_(at::Tensor &self, const at::Tensor &other); + static at::Tensor &ne_(at::Tensor &self, at::Scalar other); + static at::Tensor &ne_(at::Tensor &self, const at::Tensor &other); + static at::Tensor __and__(const at::Tensor &self, at::Scalar other); + static at::Tensor __and__(const at::Tensor &self, const at::Tensor &other); + static at::Tensor &__iand__(at::Tensor &self, at::Scalar other); + static at::Tensor &__iand__(at::Tensor &self, const at::Tensor &other); + static at::Tensor __or__(const at::Tensor &self, at::Scalar other); + static at::Tensor __or__(const at::Tensor &self, const at::Tensor &other); + static at::Tensor &__ior__(at::Tensor &self, at::Scalar other); + static at::Tensor &__ior__(at::Tensor &self, const at::Tensor &other); + static at::Tensor __xor__(const at::Tensor &self, at::Scalar
other); + static at::Tensor __xor__(const at::Tensor &self, const at::Tensor &other); + static at::Tensor &__ixor__(at::Tensor &self, at::Scalar other); + static at::Tensor &__ixor__(at::Tensor &self, const at::Tensor &other); + static at::Tensor __lshift__(const at::Tensor &self, at::Scalar other); + static at::Tensor __lshift__(const at::Tensor &self, const at::Tensor &other); + static at::Tensor &__ilshift__(at::Tensor &self, at::Scalar other); + static at::Tensor &__ilshift__(at::Tensor &self, const at::Tensor &other); + static at::Tensor __rshift__(const at::Tensor &self, at::Scalar other); + static at::Tensor __rshift__(const at::Tensor &self, const at::Tensor &other); + static at::Tensor &__irshift__(at::Tensor &self, at::Scalar other); + static at::Tensor &__irshift__(at::Tensor &self, const at::Tensor &other); + static at::Tensor &lgamma_(at::Tensor &self); + static at::Tensor &atan2_(at::Tensor &self, const at::Tensor &other); + static at::Tensor &tril_(at::Tensor &self, int64_t diagonal); + static at::Tensor &triu_(at::Tensor &self, int64_t diagonal); + static at::Tensor &digamma_(at::Tensor &self); + static at::Tensor &polygamma_(at::Tensor &self, int64_t n); + static at::Tensor &renorm_(at::Tensor &self, at::Scalar p, int64_t dim, + at::Scalar maxnorm); + static at::Tensor &pow_(at::Tensor &self, at::Scalar exponent); + static at::Tensor &pow_(at::Tensor &self, const at::Tensor &exponent); + static at::Tensor &lerp_(at::Tensor &self, const at::Tensor &end, + at::Scalar weight); + static at::Tensor &lerp_(at::Tensor &self, const at::Tensor &end, + const at::Tensor &weight); + static at::Tensor &fmod_(at::Tensor &self, at::Scalar other); + static at::Tensor &fmod_(at::Tensor &self, const at::Tensor &other); + static at::Tensor &remainder_(at::Tensor &self, at::Scalar other); + static at::Tensor &remainder_(at::Tensor &self, const at::Tensor &other); + static at::Tensor &addbmm_(at::Tensor &self, const at::Tensor &batch1, + const at::Tensor &batch2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor &addbmm_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &batch1, + const at::Tensor &batch2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor addbmm(const at::Tensor &self, const at::Tensor &batch1, + const at::Tensor &batch2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor &addcdiv_(at::Tensor &self, const at::Tensor &tensor1, + const at::Tensor &tensor2, at::Scalar value); + static at::Tensor &random_(at::Tensor &self, int64_t from, int64_t to, + at::Generator *generator); + static at::Tensor &random_(at::Tensor &self, int64_t to, + at::Generator *generator); + static at::Tensor &random_(at::Tensor &self, at::Generator *generator); + static at::Tensor &uniform_(at::Tensor &self, double from, double to, + at::Generator *generator); + static at::Tensor &normal_(at::Tensor &self, double mean, double std, + at::Generator *generator); + static at::Tensor &cauchy_(at::Tensor &self, double median, double sigma, + at::Generator *generator); + static at::Tensor &log_normal_(at::Tensor &self, double mean, double std, + at::Generator *generator); + static at::Tensor &exponential_(at::Tensor &self, double lambd, + at::Generator *generator); + static at::Tensor &geometric_(at::Tensor &self, double p, + at::Generator *generator); + static at::Tensor &diag_out(at::Tensor &out, const at::Tensor &self, + int64_t diagonal); + static at::Tensor diag(const at::Tensor &self, int64_t diagonal); + static at::Tensor &cross_out(at::Tensor &out, const at::Tensor 
&self, + const at::Tensor &other, + c10::optional dim); + static at::Tensor cross(const at::Tensor &self, const at::Tensor &other, + c10::optional dim); + static at::Tensor &triu_out(at::Tensor &out, const at::Tensor &self, + int64_t diagonal); + static at::Tensor triu(const at::Tensor &self, int64_t diagonal); + static at::Tensor &tril_out(at::Tensor &out, const at::Tensor &self, + int64_t diagonal); + static at::Tensor tril(const at::Tensor &self, int64_t diagonal); + static at::Tensor tril_indices(int64_t row, int64_t col, int64_t offset, + const at::TensorOptions &options); + static at::Tensor triu_indices(int64_t row, int64_t col, int64_t offset, + const at::TensorOptions &options); + static at::Tensor trace(const at::Tensor &self); + static at::Tensor &ne_out(at::Tensor &out, const at::Tensor &self, + at::Scalar other); + static at::Tensor ne(const at::Tensor &self, at::Scalar other); + static at::Tensor &ne_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other); + static at::Tensor ne(const at::Tensor &self, const at::Tensor &other); + static at::Tensor &eq_out(at::Tensor &out, const at::Tensor &self, + at::Scalar other); + static at::Tensor eq(const at::Tensor &self, at::Scalar other); + static at::Tensor &eq_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other); + static at::Tensor eq(const at::Tensor &self, const at::Tensor &other); + static at::Tensor &ge_out(at::Tensor &out, const at::Tensor &self, + at::Scalar other); + static at::Tensor ge(const at::Tensor &self, at::Scalar other); + static at::Tensor &ge_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other); + static at::Tensor ge(const at::Tensor &self, const at::Tensor &other); + static at::Tensor &le_out(at::Tensor &out, const at::Tensor &self, + at::Scalar other); + static at::Tensor le(const at::Tensor &self, at::Scalar other); + static at::Tensor &le_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other); + static at::Tensor le(const at::Tensor &self, const at::Tensor &other); + static at::Tensor &gt_out(at::Tensor &out, const at::Tensor &self, + at::Scalar other); + static at::Tensor gt(const at::Tensor &self, at::Scalar other); + static at::Tensor &gt_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other); + static at::Tensor gt(const at::Tensor &self, const at::Tensor &other); + static at::Tensor &lt_out(at::Tensor &out, const at::Tensor &self, + at::Scalar other); + static at::Tensor lt(const at::Tensor &self, at::Scalar other); + static at::Tensor &lt_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other); + static at::Tensor lt(const at::Tensor &self, const at::Tensor &other); + static at::Tensor &take_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &index); + static at::Tensor take(const at::Tensor &self, const at::Tensor &index); + static at::Tensor &index_select_out(at::Tensor &out, const at::Tensor &self, + int64_t dim, const at::Tensor &index); + static at::Tensor index_select(const at::Tensor &self, int64_t dim, + const at::Tensor &index); + static at::Tensor &masked_select_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &mask); + static at::Tensor masked_select(const at::Tensor &self, + const at::Tensor &mask); + static at::Tensor &nonzero_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor nonzero(const at::Tensor &self); + static std::vector nonzero_numpy(const at::Tensor &self); + static at::Tensor &gather_out(at::Tensor &out, const at::Tensor &self, + int64_t
dim, const at::Tensor &index, + bool sparse_grad); + static at::Tensor gather(const at::Tensor &self, int64_t dim, + const at::Tensor &index, bool sparse_grad); + static at::Tensor _gather_sparse_backward(const at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &grad); + static at::Tensor &addcmul_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &tensor1, + const at::Tensor &tensor2, at::Scalar value); + static at::Tensor addcmul(const at::Tensor &self, const at::Tensor &tensor1, + const at::Tensor &tensor2, at::Scalar value); + static at::Tensor &addcmul_(at::Tensor &self, const at::Tensor &tensor1, + const at::Tensor &tensor2, at::Scalar value); + static at::Tensor &addcdiv_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &tensor1, + const at::Tensor &tensor2, at::Scalar value); + static at::Tensor addcdiv(const at::Tensor &self, const at::Tensor &tensor1, + const at::Tensor &tensor2, at::Scalar value); + static std::tuple + lstsq_out(at::Tensor &X, at::Tensor &qr, const at::Tensor &self, + const at::Tensor &A); + static std::tuple lstsq(const at::Tensor &self, + const at::Tensor &A); + static std::tuple + triangular_solve_out(at::Tensor &X, at::Tensor &M, const at::Tensor &self, + const at::Tensor &A, bool upper, bool transpose, + bool unitriangular); + static std::tuple + triangular_solve(const at::Tensor &self, const at::Tensor &A, bool upper, + bool transpose, bool unitriangular); + static std::tuple + _triangular_solve_helper(const at::Tensor &self, const at::Tensor &A, + bool upper, bool transpose, bool unitriangular); + static std::tuple + symeig_out(at::Tensor &e, at::Tensor &V, const at::Tensor &self, + bool eigenvectors, bool upper); + static std::tuple + symeig(const at::Tensor &self, bool eigenvectors, bool upper); + static std::tuple + _symeig_helper(const at::Tensor &self, bool eigenvectors, bool upper); + static std::tuple eig_out(at::Tensor &e, + at::Tensor &v, + const at::Tensor &self, + bool eigenvectors); + static std::tuple eig(const at::Tensor &self, + bool eigenvectors); + static std::tuple + svd_out(at::Tensor &U, at::Tensor &S, at::Tensor &V, const at::Tensor &self, + bool some, bool compute_uv); + static std::tuple + svd(const at::Tensor &self, bool some, bool compute_uv); + static std::tuple + _svd_helper(const at::Tensor &self, bool some, bool compute_uv); + static at::Tensor &cholesky_out(at::Tensor &out, const at::Tensor &self, + bool upper); + static at::Tensor cholesky(const at::Tensor &self, bool upper); + static at::Tensor _cholesky_helper(const at::Tensor &self, bool upper); + static at::Tensor &cholesky_solve_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &input2, bool upper); + static at::Tensor cholesky_solve(const at::Tensor &self, + const at::Tensor &input2, bool upper); + static at::Tensor _cholesky_solve_helper(const at::Tensor &self, + const at::Tensor &A, bool upper); + static std::tuple solve(const at::Tensor &self, + const at::Tensor &A); + static std::tuple + solve_out(at::Tensor &solution, at::Tensor &lu, const at::Tensor &self, + const at::Tensor &A); + static std::tuple + _solve_helper(const at::Tensor &self, const at::Tensor &A); + static at::Tensor &cholesky_inverse_out(at::Tensor &out, + const at::Tensor &self, bool upper); + static at::Tensor cholesky_inverse(const at::Tensor &self, bool upper); + static std::tuple + qr_out(at::Tensor &Q, at::Tensor &R, const at::Tensor &self, bool some); + static std::tuple qr(const at::Tensor &self, + bool some); + static 
std::tuple _qr_helper(const at::Tensor &self, + bool some); + static std::tuple + geqrf_out(at::Tensor &a, at::Tensor &tau, const at::Tensor &self); + static std::tuple geqrf(const at::Tensor &self); + static at::Tensor &orgqr_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &input2); + static at::Tensor orgqr(const at::Tensor &self, const at::Tensor &input2); + static at::Tensor &ormqr_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &input2, + const at::Tensor &input3, bool left, + bool transpose); + static at::Tensor ormqr(const at::Tensor &self, const at::Tensor &input2, + const at::Tensor &input3, bool left, bool transpose); + static std::tuple + _lu_with_info(const at::Tensor &self, bool pivot, bool check_errors); + static at::Tensor &lu_solve_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &LU_data, + const at::Tensor &LU_pivots); + static at::Tensor lu_solve(const at::Tensor &self, const at::Tensor &LU_data, + const at::Tensor &LU_pivots); + static at::Tensor _lu_solve_helper(const at::Tensor &self, + const at::Tensor &LU_data, + const at::Tensor &LU_pivots); + static at::Tensor &multinomial_out(at::Tensor &out, const at::Tensor &self, + int64_t num_samples, bool replacement, + at::Generator *generator); + static at::Tensor multinomial(const at::Tensor &self, int64_t num_samples, + bool replacement, at::Generator *generator); + static std::tuple + _multinomial_alias_setup(const at::Tensor &probs); + static at::Tensor _multinomial_alias_draw(const at::Tensor &J, + const at::Tensor &q, + int64_t num_samples, + at::Generator *generator); + static at::Tensor &lgamma_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor lgamma(const at::Tensor &self); + static at::Tensor &digamma_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor digamma(const at::Tensor &self); + static at::Tensor &polygamma_out(at::Tensor &out, int64_t n, + const at::Tensor &self); + static at::Tensor polygamma(int64_t n, const at::Tensor &self); + static at::Tensor erfinv(const at::Tensor &self); + static at::Tensor &erfinv_(at::Tensor &self); + static at::Tensor &erfinv_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor sign(const at::Tensor &self); + static at::Tensor &sign_(at::Tensor &self); + static at::Tensor &sign_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor dist(const at::Tensor &self, const at::Tensor &other, + at::Scalar p); + static at::Tensor &atan2_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other); + static at::Tensor atan2(const at::Tensor &self, const at::Tensor &other); + static at::Tensor &lerp_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &end, at::Scalar weight); + static at::Tensor &lerp_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &end, const at::Tensor &weight); + static at::Tensor lerp(const at::Tensor &self, const at::Tensor &end, + at::Scalar weight); + static at::Tensor lerp(const at::Tensor &self, const at::Tensor &end, + const at::Tensor &weight); + static at::Tensor &histc_out(at::Tensor &out, const at::Tensor &self, + int64_t bins, at::Scalar min, at::Scalar max); + static at::Tensor histc(const at::Tensor &self, int64_t bins, at::Scalar min, + at::Scalar max); + static at::Tensor &fmod_out(at::Tensor &out, const at::Tensor &self, + at::Scalar other); + static at::Tensor fmod(const at::Tensor &self, at::Scalar other); + static at::Tensor &fmod_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other); + 
static at::Tensor fmod(const at::Tensor &self, const at::Tensor &other); + static at::Tensor &remainder_out(at::Tensor &out, const at::Tensor &self, + at::Scalar other); + static at::Tensor remainder(const at::Tensor &self, at::Scalar other); + static at::Tensor &remainder_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other); + static at::Tensor remainder(const at::Tensor &self, const at::Tensor &other); + static at::Tensor &min_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other); + static at::Tensor min(const at::Tensor &self, const at::Tensor &other); + static at::Tensor min(const at::Tensor &self); + static at::Tensor &max_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &other); + static at::Tensor max(const at::Tensor &self, const at::Tensor &other); + static at::Tensor max(const at::Tensor &self); + static at::Tensor median(const at::Tensor &self); + static std::tuple + sort_out(at::Tensor &values, at::Tensor &indices, const at::Tensor &self, + int64_t dim, bool descending); + static std::tuple sort(const at::Tensor &self, + int64_t dim, bool descending); + static at::Tensor argsort(const at::Tensor &self, int64_t dim, + bool descending); + static std::tuple + topk_out(at::Tensor &values, at::Tensor &indices, const at::Tensor &self, + int64_t k, int64_t dim, bool largest, bool sorted); + static std::tuple topk(const at::Tensor &self, + int64_t k, int64_t dim, + bool largest, bool sorted); + static at::Tensor all(const at::Tensor &self); + static at::Tensor any(const at::Tensor &self); + static at::Tensor &renorm_out(at::Tensor &out, const at::Tensor &self, + at::Scalar p, int64_t dim, at::Scalar maxnorm); + static at::Tensor renorm(const at::Tensor &self, at::Scalar p, int64_t dim, + at::Scalar maxnorm); + static at::Tensor unfold(const at::Tensor &self, int64_t dimension, + int64_t size, int64_t step); + static bool equal(const at::Tensor &self, const at::Tensor &other); + static at::Tensor &pow_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &exponent); + static at::Tensor pow(const at::Tensor &self, const at::Tensor &exponent); + static at::Tensor &pow_out(at::Tensor &out, at::Scalar self, + const at::Tensor &exponent); + static at::Tensor pow(at::Scalar self, const at::Tensor &exponent); + static at::Tensor &normal_out(at::Tensor &out, const at::Tensor &mean, + double std, at::Generator *generator); + static at::Tensor normal(const at::Tensor &mean, double std, + at::Generator *generator); + static at::Tensor &normal_out(at::Tensor &out, double mean, + const at::Tensor &std, + at::Generator *generator); + static at::Tensor normal(double mean, const at::Tensor &std, + at::Generator *generator); + static at::Tensor &normal_out(at::Tensor &out, const at::Tensor &mean, + const at::Tensor &std, + at::Generator *generator); + static at::Tensor normal(const at::Tensor &mean, const at::Tensor &std, + at::Generator *generator); + static at::Tensor normal(double mean, double std, at::IntArrayRef size, + at::Generator *generator, + const at::TensorOptions &options); + static at::Tensor &normal_out(at::Tensor &out, double mean, double std, + at::IntArrayRef size, at::Generator *generator); + static at::Tensor alias(const at::Tensor &self); + static at::Tensor _addr(const at::Tensor &self, const at::Tensor &vec1, + const at::Tensor &vec2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor &_addr_(at::Tensor &self, const at::Tensor &vec1, + const at::Tensor &vec2, at::Scalar beta, + at::Scalar alpha); + static 
at::Tensor &_addr_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &vec1, const at::Tensor &vec2, + at::Scalar beta, at::Scalar alpha); + static at::Tensor &_index_copy_(at::Tensor &self, int64_t dim, + const at::Tensor &index, + const at::Tensor &source); + static at::Tensor _cumsum(const at::Tensor &self, int64_t dim); + static at::Tensor &_cumsum_out(at::Tensor &out, const at::Tensor &self, + int64_t dim); + static at::Tensor _cumprod(const at::Tensor &self, int64_t dim); + static at::Tensor &_cumprod_out(at::Tensor &out, const at::Tensor &self, + int64_t dim); + static at::Tensor _var(const at::Tensor &self, bool unbiased); + static at::Tensor _std(const at::Tensor &self, bool unbiased); + static at::Tensor &_addmm_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &mat1, const at::Tensor &mat2, + at::Scalar beta, at::Scalar alpha); + static at::Tensor _addmm(const at::Tensor &self, const at::Tensor &mat1, + const at::Tensor &mat2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor &_addmm_(at::Tensor &self, const at::Tensor &mat1, + const at::Tensor &mat2, at::Scalar beta, + at::Scalar alpha); + static at::Tensor _cat(at::TensorList tensors, int64_t dim); + static at::Tensor &_cat_out(at::Tensor &out, at::TensorList tensors, + int64_t dim); + static std::tuple _mode(const at::Tensor &self, + int64_t dim, bool keepdim); + static std::tuple + _mode_out(at::Tensor &values, at::Tensor &indices, const at::Tensor &self, + int64_t dim, bool keepdim); + static std::tuple _max(const at::Tensor &self, + int64_t dim, bool keepdim); + static std::tuple + _max_out(at::Tensor &max, at::Tensor &max_indices, const at::Tensor &self, + int64_t dim, bool keepdim); + static std::tuple _min(const at::Tensor &self, + int64_t dim, bool keepdim); + static std::tuple + _min_out(at::Tensor &min, at::Tensor &min_indices, const at::Tensor &self, + int64_t dim, bool keepdim); + static at::Tensor &binary_cross_entropy_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &target, + const at::Tensor &weight, + int64_t reduction); + static at::Tensor binary_cross_entropy(const at::Tensor &self, + const at::Tensor &target, + const at::Tensor &weight, + int64_t reduction); + static at::Tensor &binary_cross_entropy_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &target, + const at::Tensor &weight, int64_t reduction); + static at::Tensor binary_cross_entropy_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, + const at::Tensor &weight, + int64_t reduction); + static at::Tensor &mse_loss_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &target, int64_t reduction); + static at::Tensor mse_loss(const at::Tensor &self, const at::Tensor &target, + int64_t reduction); + static at::Tensor &mse_loss_backward_out(at::Tensor &grad_input, + const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, + int64_t reduction); + static at::Tensor mse_loss_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, + int64_t reduction); + static at::Tensor &l1_loss_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &target, int64_t reduction); + static at::Tensor l1_loss(const at::Tensor &self, const at::Tensor &target, + int64_t reduction); + static at::Tensor &l1_loss_backward_out(at::Tensor &grad_input, + const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, + 
int64_t reduction); + static at::Tensor l1_loss_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, + int64_t reduction); + static at::Tensor &multi_margin_loss_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &target, + at::Scalar p, at::Scalar margin, + const at::Tensor &weight, + int64_t reduction); + static at::Tensor multi_margin_loss(const at::Tensor &self, + const at::Tensor &target, at::Scalar p, + at::Scalar margin, + const at::Tensor &weight, + int64_t reduction); + static at::Tensor &multi_margin_loss_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &target, at::Scalar p, + at::Scalar margin, const at::Tensor &weight, int64_t reduction); + static at::Tensor multi_margin_loss_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, + at::Scalar p, at::Scalar margin, + const at::Tensor &weight, + int64_t reduction); + static at::Tensor &multilabel_margin_loss_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &target, + int64_t reduction); + static at::Tensor multilabel_margin_loss(const at::Tensor &self, + const at::Tensor &target, + int64_t reduction); + static std::tuple + multilabel_margin_loss_forward_out(at::Tensor &output, at::Tensor &is_target, + const at::Tensor &self, + const at::Tensor &target, + int64_t reduction); + static std::tuple + multilabel_margin_loss_forward(const at::Tensor &self, + const at::Tensor &target, int64_t reduction); + static at::Tensor &multilabel_margin_loss_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &target, int64_t reduction, + const at::Tensor &is_target); + static at::Tensor multilabel_margin_loss_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &target, int64_t reduction, const at::Tensor &is_target); + static at::Tensor &nll_loss_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &target, + const at::Tensor &weight, int64_t reduction, + int64_t ignore_index); + static at::Tensor nll_loss(const at::Tensor &self, const at::Tensor &target, + const at::Tensor &weight, int64_t reduction, + int64_t ignore_index); + static std::tuple + nll_loss_forward_out(at::Tensor &output, at::Tensor &total_weight, + const at::Tensor &self, const at::Tensor &target, + const at::Tensor &weight, int64_t reduction, + int64_t ignore_index); + static std::tuple + nll_loss_forward(const at::Tensor &self, const at::Tensor &target, + const at::Tensor &weight, int64_t reduction, + int64_t ignore_index); + static at::Tensor & + nll_loss_backward_out(at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &target, + const at::Tensor &weight, int64_t reduction, + int64_t ignore_index, const at::Tensor &total_weight); + static at::Tensor nll_loss_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, + const at::Tensor &weight, + int64_t reduction, int64_t ignore_index, + const at::Tensor &total_weight); + static at::Tensor &nll_loss2d_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &target, + const at::Tensor &weight, int64_t reduction, + int64_t ignore_index); + static at::Tensor nll_loss2d(const at::Tensor &self, const at::Tensor &target, + const at::Tensor &weight, int64_t reduction, + int64_t ignore_index); + static std::tuple + nll_loss2d_forward_out(at::Tensor &output, at::Tensor 
&total_weight, + const at::Tensor &self, const at::Tensor &target, + const at::Tensor &weight, int64_t reduction, + int64_t ignore_index); + static std::tuple + nll_loss2d_forward(const at::Tensor &self, const at::Tensor &target, + const at::Tensor &weight, int64_t reduction, + int64_t ignore_index); + static at::Tensor & + nll_loss2d_backward_out(at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &target, + const at::Tensor &weight, int64_t reduction, + int64_t ignore_index, const at::Tensor &total_weight); + static at::Tensor nll_loss2d_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, + const at::Tensor &weight, + int64_t reduction, int64_t ignore_index, + const at::Tensor &total_weight); + static at::Tensor &smooth_l1_loss_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &target, + int64_t reduction); + static at::Tensor smooth_l1_loss(const at::Tensor &self, + const at::Tensor &target, int64_t reduction); + static at::Tensor &smooth_l1_loss_backward_out(at::Tensor &grad_input, + const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, + int64_t reduction); + static at::Tensor smooth_l1_loss_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, + int64_t reduction); + static at::Tensor &soft_margin_loss_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &target, + int64_t reduction); + static at::Tensor soft_margin_loss(const at::Tensor &self, + const at::Tensor &target, + int64_t reduction); + static at::Tensor &soft_margin_loss_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &target, int64_t reduction); + static at::Tensor soft_margin_loss_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, + int64_t reduction); + static at::Tensor &elu_out(at::Tensor &out, const at::Tensor &self, + at::Scalar alpha, at::Scalar scale, + at::Scalar input_scale); + static at::Tensor elu(const at::Tensor &self, at::Scalar alpha, + at::Scalar scale, at::Scalar input_scale); + static at::Tensor &elu_backward_out(at::Tensor &grad_input, + const at::Tensor &grad_output, + at::Scalar alpha, at::Scalar scale, + at::Scalar input_scale, + const at::Tensor &output); + static at::Tensor elu_backward(const at::Tensor &grad_output, + at::Scalar alpha, at::Scalar scale, + at::Scalar input_scale, + const at::Tensor &output); + static at::Tensor &elu_(at::Tensor &self, at::Scalar alpha, at::Scalar scale, + at::Scalar input_scale); + static at::Tensor &glu_out(at::Tensor &out, const at::Tensor &self, + int64_t dim); + static at::Tensor glu(const at::Tensor &self, int64_t dim); + static at::Tensor &glu_backward_out(at::Tensor &grad_input, + const at::Tensor &grad_output, + const at::Tensor &self, int64_t dim); + static at::Tensor glu_backward(const at::Tensor &grad_output, + const at::Tensor &self, int64_t dim); + static at::Tensor &hardtanh_out(at::Tensor &out, const at::Tensor &self, + at::Scalar min_val, at::Scalar max_val); + static at::Tensor hardtanh(const at::Tensor &self, at::Scalar min_val, + at::Scalar max_val); + static at::Tensor &hardtanh_backward_out(at::Tensor &grad_input, + const at::Tensor &grad_output, + const at::Tensor &self, + at::Scalar min_val, + at::Scalar max_val); + static at::Tensor hardtanh_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::Scalar min_val, at::Scalar max_val); + 
static at::Tensor &hardtanh_(at::Tensor &self, at::Scalar min_val, + at::Scalar max_val); + static at::Tensor &leaky_relu_out(at::Tensor &out, const at::Tensor &self, + at::Scalar negative_slope); + static at::Tensor leaky_relu(const at::Tensor &self, + at::Scalar negative_slope); + static at::Tensor &leaky_relu_backward_out(at::Tensor &grad_input, + const at::Tensor &grad_output, + const at::Tensor &self, + at::Scalar negative_slope); + static at::Tensor leaky_relu_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::Scalar negative_slope); + static at::Tensor &leaky_relu_(at::Tensor &self, at::Scalar negative_slope); + static at::Tensor &log_sigmoid_out(at::Tensor &out, const at::Tensor &self); + static at::Tensor log_sigmoid(const at::Tensor &self); + static std::tuple + log_sigmoid_forward_out(at::Tensor &output, at::Tensor &buffer, + const at::Tensor &self); + static std::tuple + log_sigmoid_forward(const at::Tensor &self); + static at::Tensor &log_sigmoid_backward_out(at::Tensor &grad_input, + const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &buffer); + static at::Tensor log_sigmoid_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &buffer); + static at::Tensor &rrelu_with_noise_out(at::Tensor &out, + const at::Tensor &self, + const at::Tensor &noise, + at::Scalar lower, at::Scalar upper, + bool training, + at::Generator *generator); + static at::Tensor rrelu_with_noise(const at::Tensor &self, + const at::Tensor &noise, at::Scalar lower, + at::Scalar upper, bool training, + at::Generator *generator); + static at::Tensor &rrelu_with_noise_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &noise, at::Scalar lower, + at::Scalar upper, bool training); + static at::Tensor rrelu_with_noise_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &noise, + at::Scalar lower, + at::Scalar upper, bool training); + static at::Tensor &rrelu_with_noise_(at::Tensor &self, + const at::Tensor &noise, + at::Scalar lower, at::Scalar upper, + bool training, at::Generator *generator); + static at::Tensor &softplus_out(at::Tensor &out, const at::Tensor &self, + at::Scalar beta, at::Scalar threshold); + static at::Tensor softplus(const at::Tensor &self, at::Scalar beta, + at::Scalar threshold); + static at::Tensor & + softplus_backward_out(at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::Scalar beta, + at::Scalar threshold, const at::Tensor &output); + static at::Tensor softplus_backward(const at::Tensor &grad_output, + const at::Tensor &self, at::Scalar beta, + at::Scalar threshold, + const at::Tensor &output); + static at::Tensor &softshrink_out(at::Tensor &out, const at::Tensor &self, + at::Scalar lambd); + static at::Tensor softshrink(const at::Tensor &self, at::Scalar lambd); + static at::Tensor &softshrink_backward_out(at::Tensor &grad_input, + const at::Tensor &grad_output, + const at::Tensor &self, + at::Scalar lambd); + static at::Tensor softshrink_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::Scalar lambd); + static at::Tensor &adaptive_avg_pool2d_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef output_size); + static at::Tensor adaptive_avg_pool2d(const at::Tensor &self, + at::IntArrayRef output_size); + static at::Tensor mkldnn_adaptive_avg_pool2d(const at::Tensor &self, + at::IntArrayRef output_size); + static at::Tensor 
_adaptive_avg_pool2d(const at::Tensor &self, + at::IntArrayRef output_size); + static at::Tensor _adaptive_avg_pool2d_backward(const at::Tensor &grad_output, + const at::Tensor &self); + static at::Tensor &adaptive_avg_pool3d_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef output_size); + static at::Tensor adaptive_avg_pool3d(const at::Tensor &self, + at::IntArrayRef output_size); + static at::Tensor & + adaptive_avg_pool3d_backward_out(at::Tensor &grad_input, + const at::Tensor &grad_output, + const at::Tensor &self); + static at::Tensor adaptive_avg_pool3d_backward(const at::Tensor &grad_output, + const at::Tensor &self); + static std::tuple + adaptive_max_pool2d_out(at::Tensor &out, at::Tensor &indices, + const at::Tensor &self, at::IntArrayRef output_size); + static std::tuple + adaptive_max_pool2d(const at::Tensor &self, at::IntArrayRef output_size); + static at::Tensor &adaptive_max_pool2d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &indices); + static at::Tensor adaptive_max_pool2d_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &indices); + static std::tuple + adaptive_max_pool3d_out(at::Tensor &out, at::Tensor &indices, + const at::Tensor &self, at::IntArrayRef output_size); + static std::tuple + adaptive_max_pool3d(const at::Tensor &self, at::IntArrayRef output_size); + static at::Tensor &adaptive_max_pool3d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &indices); + static at::Tensor adaptive_max_pool3d_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &indices); + static at::Tensor &avg_pool2d_out(at::Tensor &out, const at::Tensor &self, + at::IntArrayRef kernel_size, + at::IntArrayRef stride, + at::IntArrayRef padding, bool ceil_mode, + bool count_include_pad, + c10::optional divisor_override); + static at::Tensor avg_pool2d(const at::Tensor &self, + at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + bool ceil_mode, bool count_include_pad, + c10::optional divisor_override); + static at::Tensor & + avg_pool2d_backward_out(at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + bool ceil_mode, bool count_include_pad, + c10::optional divisor_override); + static at::Tensor + avg_pool2d_backward(const at::Tensor &grad_output, const at::Tensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, bool ceil_mode, + bool count_include_pad, + c10::optional divisor_override); + static at::Tensor &avg_pool3d_out(at::Tensor &out, const at::Tensor &self, + at::IntArrayRef kernel_size, + at::IntArrayRef stride, + at::IntArrayRef padding, bool ceil_mode, + bool count_include_pad, + c10::optional divisor_override); + static at::Tensor avg_pool3d(const at::Tensor &self, + at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + bool ceil_mode, bool count_include_pad, + c10::optional divisor_override); + static at::Tensor & + avg_pool3d_backward_out(at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + bool ceil_mode, bool count_include_pad, + c10::optional divisor_override); + static at::Tensor + avg_pool3d_backward(const at::Tensor &grad_output, const at::Tensor &self, 
+ at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, bool ceil_mode, + bool count_include_pad, + c10::optional divisor_override); + static std::tuple + fractional_max_pool2d_out(at::Tensor &output, at::Tensor &indices, + const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef output_size, + const at::Tensor &random_samples); + static std::tuple + fractional_max_pool2d(const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef output_size, + const at::Tensor &random_samples); + static at::Tensor &fractional_max_pool2d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef output_size, const at::Tensor &indices); + static at::Tensor fractional_max_pool2d_backward( + const at::Tensor &grad_output, const at::Tensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef output_size, + const at::Tensor &indices); + static std::tuple + fractional_max_pool3d_out(at::Tensor &output, at::Tensor &indices, + const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef output_size, + const at::Tensor &random_samples); + static std::tuple + fractional_max_pool3d(const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef output_size, + const at::Tensor &random_samples); + static at::Tensor &fractional_max_pool3d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef output_size, const at::Tensor &indices); + static at::Tensor fractional_max_pool3d_backward( + const at::Tensor &grad_output, const at::Tensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef output_size, + const at::Tensor &indices); + static std::tuple max_pool2d_with_indices_out( + at::Tensor &out, at::Tensor &indices, const at::Tensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); + static std::tuple + max_pool2d_with_indices(const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, bool ceil_mode); + static at::Tensor &max_pool2d_with_indices_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + bool ceil_mode, const at::Tensor &indices); + static at::Tensor max_pool2d_with_indices_backward( + const at::Tensor &grad_output, const at::Tensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, + const at::Tensor &indices); + static std::tuple max_pool3d_with_indices_out( + at::Tensor &out, at::Tensor &indices, const at::Tensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); + static std::tuple + max_pool3d_with_indices(const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, bool ceil_mode); + static at::Tensor &max_pool3d_with_indices_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + bool ceil_mode, const at::Tensor &indices); + static at::Tensor max_pool3d_with_indices_backward( + const 
at::Tensor &grad_output, const at::Tensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, + const at::Tensor &indices); + static at::Tensor &max_unpool2d_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &indices, + at::IntArrayRef output_size); + static at::Tensor max_unpool2d(const at::Tensor &self, + const at::Tensor &indices, + at::IntArrayRef output_size); + static at::Tensor &max_unpool2d_backward_out(at::Tensor &grad_input, + const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &indices, + at::IntArrayRef output_size); + static at::Tensor max_unpool2d_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &indices, + at::IntArrayRef output_size); + static at::Tensor &max_unpool3d_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &indices, + at::IntArrayRef output_size, + at::IntArrayRef stride, + at::IntArrayRef padding); + static at::Tensor max_unpool3d(const at::Tensor &self, + const at::Tensor &indices, + at::IntArrayRef output_size, + at::IntArrayRef stride, + at::IntArrayRef padding); + static at::Tensor &max_unpool3d_backward_out(at::Tensor &grad_input, + const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &indices, + at::IntArrayRef output_size, + at::IntArrayRef stride, + at::IntArrayRef padding); + static at::Tensor + max_unpool3d_backward(const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &indices, at::IntArrayRef output_size, + at::IntArrayRef stride, at::IntArrayRef padding); + static at::Tensor &reflection_pad1d_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef padding); + static at::Tensor reflection_pad1d(const at::Tensor &self, + at::IntArrayRef padding); + static at::Tensor &reflection_pad1d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef padding); + static at::Tensor reflection_pad1d_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::IntArrayRef padding); + static at::Tensor &reflection_pad2d_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef padding); + static at::Tensor reflection_pad2d(const at::Tensor &self, + at::IntArrayRef padding); + static at::Tensor &reflection_pad2d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef padding); + static at::Tensor reflection_pad2d_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::IntArrayRef padding); + static at::Tensor &replication_pad1d_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef padding); + static at::Tensor replication_pad1d(const at::Tensor &self, + at::IntArrayRef padding); + static at::Tensor &replication_pad1d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef padding); + static at::Tensor replication_pad1d_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::IntArrayRef padding); + static at::Tensor &replication_pad2d_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef padding); + static at::Tensor replication_pad2d(const at::Tensor &self, + at::IntArrayRef padding); + static at::Tensor &replication_pad2d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef padding); + static at::Tensor replication_pad2d_backward(const at::Tensor 
&grad_output, + const at::Tensor &self, + at::IntArrayRef padding); + static at::Tensor &replication_pad3d_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef padding); + static at::Tensor replication_pad3d(const at::Tensor &self, + at::IntArrayRef padding); + static at::Tensor &replication_pad3d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + const at::Tensor &self, at::IntArrayRef padding); + static at::Tensor replication_pad3d_backward(const at::Tensor &grad_output, + const at::Tensor &self, + at::IntArrayRef padding); + static at::Tensor &upsample_linear1d_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef output_size, + bool align_corners); + static at::Tensor upsample_linear1d(const at::Tensor &self, + at::IntArrayRef output_size, + bool align_corners); + static at::Tensor &upsample_linear1d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + at::IntArrayRef output_size, at::IntArrayRef input_size, + bool align_corners); + static at::Tensor upsample_linear1d_backward(const at::Tensor &grad_output, + at::IntArrayRef output_size, + at::IntArrayRef input_size, + bool align_corners); + static at::Tensor &upsample_bilinear2d_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef output_size, + bool align_corners); + static at::Tensor upsample_bilinear2d(const at::Tensor &self, + at::IntArrayRef output_size, + bool align_corners); + static at::Tensor &upsample_bilinear2d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + at::IntArrayRef output_size, at::IntArrayRef input_size, + bool align_corners); + static at::Tensor upsample_bilinear2d_backward(const at::Tensor &grad_output, + at::IntArrayRef output_size, + at::IntArrayRef input_size, + bool align_corners); + static at::Tensor &upsample_bicubic2d_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef output_size, + bool align_corners); + static at::Tensor upsample_bicubic2d(const at::Tensor &self, + at::IntArrayRef output_size, + bool align_corners); + static at::Tensor &upsample_bicubic2d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + at::IntArrayRef output_size, at::IntArrayRef input_size, + bool align_corners); + static at::Tensor upsample_bicubic2d_backward(const at::Tensor &grad_output, + at::IntArrayRef output_size, + at::IntArrayRef input_size, + bool align_corners); + static at::Tensor &upsample_trilinear3d_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef output_size, + bool align_corners); + static at::Tensor upsample_trilinear3d(const at::Tensor &self, + at::IntArrayRef output_size, + bool align_corners); + static at::Tensor &upsample_trilinear3d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + at::IntArrayRef output_size, at::IntArrayRef input_size, + bool align_corners); + static at::Tensor upsample_trilinear3d_backward(const at::Tensor &grad_output, + at::IntArrayRef output_size, + at::IntArrayRef input_size, + bool align_corners); + static at::Tensor &upsample_nearest1d_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef output_size); + static at::Tensor upsample_nearest1d(const at::Tensor &self, + at::IntArrayRef output_size); + static at::Tensor &upsample_nearest1d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + at::IntArrayRef output_size, at::IntArrayRef input_size); + static at::Tensor upsample_nearest1d_backward(const at::Tensor &grad_output, + at::IntArrayRef output_size, + at::IntArrayRef 
input_size); + static at::Tensor &upsample_nearest2d_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef output_size); + static at::Tensor upsample_nearest2d(const at::Tensor &self, + at::IntArrayRef output_size); + static at::Tensor &upsample_nearest2d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + at::IntArrayRef output_size, at::IntArrayRef input_size); + static at::Tensor upsample_nearest2d_backward(const at::Tensor &grad_output, + at::IntArrayRef output_size, + at::IntArrayRef input_size); + static at::Tensor &upsample_nearest3d_out(at::Tensor &out, + const at::Tensor &self, + at::IntArrayRef output_size); + static at::Tensor upsample_nearest3d(const at::Tensor &self, + at::IntArrayRef output_size); + static at::Tensor &upsample_nearest3d_backward_out( + at::Tensor &grad_input, const at::Tensor &grad_output, + at::IntArrayRef output_size, at::IntArrayRef input_size); + static at::Tensor upsample_nearest3d_backward(const at::Tensor &grad_output, + at::IntArrayRef output_size, + at::IntArrayRef input_size); + static at::Tensor &sigmoid_backward_out(at::Tensor &grad_input, + const at::Tensor &grad_output, + const at::Tensor &output); + static at::Tensor sigmoid_backward(const at::Tensor &grad_output, + const at::Tensor &output); + static at::Tensor &tanh_backward_out(at::Tensor &grad_input, + const at::Tensor &grad_output, + const at::Tensor &output); + static at::Tensor tanh_backward(const at::Tensor &grad_output, + const at::Tensor &output); + static at::Tensor &slow_conv_transpose2d_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef output_padding, at::IntArrayRef dilation); + static at::Tensor + slow_conv_transpose2d(const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef output_padding, + at::IntArrayRef dilation); + static std::tuple + slow_conv_transpose2d_backward_out( + at::Tensor &grad_input, at::Tensor &grad_weight, at::Tensor &grad_bias, + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef output_padding, at::IntArrayRef dilation, + const at::Tensor &columns, const at::Tensor &ones); + static std::tuple + slow_conv_transpose2d_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef output_padding, at::IntArrayRef dilation, + const at::Tensor &columns, const at::Tensor &ones, + std::array output_mask); + static at::Tensor &slow_conv_transpose3d_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef output_padding, at::IntArrayRef dilation); + static at::Tensor + slow_conv_transpose3d(const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef output_padding, + at::IntArrayRef dilation); + static std::tuple + slow_conv_transpose3d_backward_out( + at::Tensor &grad_input, at::Tensor &grad_weight, at::Tensor &grad_bias, + const at::Tensor &grad_output, const at::Tensor 
&self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef output_padding, at::IntArrayRef dilation, + const at::Tensor &finput, const at::Tensor &fgrad_input); + static std::tuple + slow_conv_transpose3d_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef output_padding, at::IntArrayRef dilation, + const at::Tensor &finput, const at::Tensor &fgrad_input, + std::array output_mask); + static at::Tensor &thnn_conv2d_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &weight, + at::IntArrayRef kernel_size, + const at::Tensor &bias, + at::IntArrayRef stride, + at::IntArrayRef padding); + static at::Tensor thnn_conv2d(const at::Tensor &self, + const at::Tensor &weight, + at::IntArrayRef kernel_size, + const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding); + static std::tuple + thnn_conv2d_forward_out(at::Tensor &output, at::Tensor &finput, + at::Tensor &fgrad_input, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding); + static std::tuple + thnn_conv2d_forward(const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding); + static std::tuple + thnn_conv2d_backward_out(at::Tensor &grad_input, at::Tensor &grad_weight, + at::Tensor &grad_bias, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, const at::Tensor &finput, + const at::Tensor &fgrad_input); + static std::tuple + thnn_conv2d_backward(const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + const at::Tensor &finput, const at::Tensor &fgrad_input, + std::array output_mask); + static at::Tensor & + thnn_conv_depthwise2d_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation); + static at::Tensor + thnn_conv_depthwise2d(const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation); + static at::Tensor &thnn_conv_depthwise2d_forward_out( + at::Tensor &out, const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation); + static at::Tensor thnn_conv_depthwise2d_forward(const at::Tensor &self, + const at::Tensor &weight, + at::IntArrayRef kernel_size, + const at::Tensor &bias, + at::IntArrayRef stride, + at::IntArrayRef padding, + at::IntArrayRef dilation); + static std::tuple + thnn_conv_depthwise2d_backward_out( + at::Tensor &grad_input, at::Tensor &grad_weight, + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation); + static std::tuple thnn_conv_depthwise2d_backward( + const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight, 
at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + std::array output_mask); + static at::Tensor &thnn_conv3d_out(at::Tensor &out, const at::Tensor &self, + const at::Tensor &weight, + at::IntArrayRef kernel_size, + const at::Tensor &bias, + at::IntArrayRef stride, + at::IntArrayRef padding); + static at::Tensor thnn_conv3d(const at::Tensor &self, + const at::Tensor &weight, + at::IntArrayRef kernel_size, + const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding); + static std::tuple + thnn_conv3d_forward_out(at::Tensor &output, at::Tensor &finput, + at::Tensor &fgrad_input, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + const at::Tensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding); + static std::tuple + thnn_conv3d_forward(const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding); + static std::tuple + thnn_conv3d_backward_out(at::Tensor &grad_input, at::Tensor &grad_weight, + at::Tensor &grad_bias, const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, const at::Tensor &finput, + const at::Tensor &fgrad_input); + static std::tuple + thnn_conv3d_backward(const at::Tensor &grad_output, const at::Tensor &self, + const at::Tensor &weight, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + const at::Tensor &finput, const at::Tensor &fgrad_input, + std::array output_mask); + static at::Tensor + slow_conv_dilated2d(const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation); + static std::tuple + slow_conv_dilated2d_backward(const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, + std::array output_mask); + static at::Tensor + slow_conv_dilated3d(const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, const at::Tensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation); + static std::tuple + slow_conv_dilated3d_backward(const at::Tensor &grad_output, + const at::Tensor &self, const at::Tensor &weight, + at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, + std::array output_mask); + static at::Tensor &col2im_out(at::Tensor &out, const at::Tensor &self, + at::IntArrayRef output_size, + at::IntArrayRef kernel_size, + at::IntArrayRef dilation, + at::IntArrayRef padding, + at::IntArrayRef stride); + static at::Tensor col2im(const at::Tensor &self, at::IntArrayRef output_size, + at::IntArrayRef kernel_size, + at::IntArrayRef dilation, at::IntArrayRef padding, + at::IntArrayRef stride); + static at::Tensor & + col2im_backward_out(at::Tensor &grad_input, const at::Tensor &grad_output, + at::IntArrayRef kernel_size, at::IntArrayRef dilation, + at::IntArrayRef padding, at::IntArrayRef stride); + static at::Tensor col2im_backward(const at::Tensor &grad_output, + at::IntArrayRef kernel_size, + at::IntArrayRef dilation, + at::IntArrayRef padding, + at::IntArrayRef stride); + static at::Tensor &im2col_out(at::Tensor &out, const at::Tensor &self, + at::IntArrayRef kernel_size, + 
at::IntArrayRef dilation, + at::IntArrayRef padding, + at::IntArrayRef stride); + static at::Tensor im2col(const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef dilation, at::IntArrayRef padding, + at::IntArrayRef stride); + static at::Tensor & + im2col_backward_out(at::Tensor &grad_input, const at::Tensor &grad_output, + at::IntArrayRef input_size, at::IntArrayRef kernel_size, + at::IntArrayRef dilation, at::IntArrayRef padding, + at::IntArrayRef stride); + static at::Tensor + im2col_backward(const at::Tensor &grad_output, at::IntArrayRef input_size, + at::IntArrayRef kernel_size, at::IntArrayRef dilation, + at::IntArrayRef padding, at::IntArrayRef stride); +}; + +void RegisterAtenTypeFunctions(); + +} // namespace torch_mlir diff --git a/frontends/pytorch/csrc/device.cpp b/frontends/pytorch/csrc/device.cpp new file mode 100644 index 000000000..2f0afca2b --- /dev/null +++ b/frontends/pytorch/csrc/device.cpp @@ -0,0 +1,67 @@ +//===- device.cpp -----------------------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. +// +//===----------------------------------------------------------------------===// + +// Structured similarly to code from git@github.com:pytorch/xla.git + +#include "device.h" + +namespace torch_mlir { +namespace { + +std::string DeviceTypeToString(DeviceType hw_type) { + switch (hw_type) { + case DeviceType::CPU: + return "CPU"; + case DeviceType::MLIR: + return "MLIR"; + } + return ""; +} + +void ParseDevice(const std::string &device_spec, Device *device) { + if (device_spec.empty()) { + return ParseDevice(std::string("mlir:0"), device); + } + + if (device_spec[0] == ':') { + return ParseDevice(std::string("mlir") + device_spec, device); + } + + auto pos = device_spec.find(':'); + auto devtype = device_spec.substr(0, pos); + + // TODO error check + + device->ordinal = + std::stoi(device_spec.substr(pos + 1, device_spec.size() - pos - 1)); + if (devtype == "MLIR") { + device->hw_type = DeviceType::MLIR; + } else if (devtype == "CPU") { + device->hw_type = DeviceType::CPU; + } else { + // TODO, error + device->hw_type = DeviceType::MLIR; + } +} + +} // namespace + +Device::Device(const std::string &device_spec) { + ParseDevice(device_spec, this); +} + +std::string Device::ToString() const { + return DeviceTypeToString(hw_type) + std::string(":") + + std::to_string(ordinal); +} + +const Device *GetDefaultDevice() { + static const Device *default_device = new Device(""); + return default_device; +} + +} // namespace torch_mlir diff --git a/frontends/pytorch/csrc/device.h b/frontends/pytorch/csrc/device.h new file mode 100644 index 000000000..f43ae75f6 --- /dev/null +++ b/frontends/pytorch/csrc/device.h @@ -0,0 +1,59 @@ +//===- device.h -------------------------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. +// +//===----------------------------------------------------------------------===// + +// Structured similarly to code from git@github.com:pytorch/xla.git + +#pragma once + +#include +#include + +namespace torch_mlir { + +enum class DeviceType { CPU, MLIR }; + +/// Model a pytorch device, which determines the location of a buffer in +/// pytorch. 
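+///
+/// A device is identified by a spec string of the form "<type>:<ordinal>"
+/// (for example "MLIR:0" or "CPU:0"); an empty spec defaults to the MLIR
+/// device with ordinal 0, and a bare ":N" defaults the type to MLIR (see
+/// ParseDevice in device.cpp above).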
+struct Device { + Device() = default; + explicit Device(const std::string &device_spec); + Device(DeviceType hw_type, int ordinal) + : hw_type(hw_type), ordinal(ordinal) {} + + bool operator==(const Device &other) const { return compare(other) == 0; } + + bool operator!=(const Device &other) const { return compare(other) != 0; } + + bool operator<(const Device &rhs) const { return compare(rhs) < 0; } + + int compare(const Device &rhs) const { + if (hw_type != rhs.hw_type) { + return hw_type < rhs.hw_type ? -1 : +1; + } + return ordinal < rhs.ordinal ? -1 : (ordinal > rhs.ordinal ? +1 : 0); + } + + std::string ToString() const; + + friend std::ostream &operator<<(std::ostream &os, const Device &device) { + os << device.ToString(); + return os; + } + + size_t hash() const { return std::hash{}(ToString()); } + + DeviceType hw_type = DeviceType::CPU; + int ordinal = 0; +}; + +const Device *GetDefaultDevice(); + +static inline const Device &GetDeviceOrDefault(const Device *device) { + return device != nullptr ? *device : *GetDefaultDevice(); +} + +} // namespace torch_mlir diff --git a/frontends/pytorch/csrc/init_python_bindings.cpp b/frontends/pytorch/csrc/init_python_bindings.cpp new file mode 100644 index 000000000..9092bc998 --- /dev/null +++ b/frontends/pytorch/csrc/init_python_bindings.cpp @@ -0,0 +1,211 @@ +//===- init_python_bindings.cpp ---------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. +// +//===----------------------------------------------------------------------===// + +// This file implements Python bindings to the MLIR/NPCOMP ATen dialect. +// Roughly speaking, it enables something like this: +// +// dev = torch_mlir.mlir_device() +// t0 = torch.randn((4,4), device=dev) +// t1 = torch.randn((4,4), device=dev) +// t2 = t0 + t1 +// t2_mlir = torch_mlir.get_mlir( t2 ) +// t2_cpu = t2.to('cpu') +// +// In this case t2_cpu contains the result of the computation, and t2_mlir +// contains the mlir description of the computation. 
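+//
+// The entry points defined below are registered on the low-level
+// `_torch_mlir` module. As a rough sketch only (a thin Python wrapper such
+// as `torch_mlir` that re-exports these bindings is assumed here and is not
+// part of this file), they map onto the workflow above roughly as:
+//
+//   import _torch_mlir
+//   ir_text = _torch_mlir._get_mlir([t2])      # MLIR for the recorded graph
+//   print(_torch_mlir._op_report(ir_text))     # per-layer op report
+//   print(_torch_mlir.lower_to_std(ir_text))   # lower ATen ops to std dialect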
+ +#include "llvm/Support/Debug.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +#include "npcomp/Dialect/ATen/ATenDialect.h" +#include "npcomp/Dialect/ATen/ATenOpReport.h" +#include "npcomp/Dialect/ATen/ATenPasses.h" +#include "npcomp/Dialect/ATen/LivenessReport.h" + +// Then ATen headers with workarounds +#include "ATen/ArrayRef.h" +namespace at { +template using ArrayRef = c10::ArrayRef; +} +#include "ATen/SmallVector.h" +namespace at { +template using SmallVector = c10::SmallVector; +} +#include + +// other headers + +#include "aten_mlir_bridge.h" +#include "aten_mlir_type.h" +#include "init_python_bindings.h" +#include "mlir_gen.h" + +#include + +using namespace mlir; + +namespace llvm { +extern bool DebugFlag; +} + +namespace torch_mlir { +namespace { + +mlir::OwningModuleRef LoadModule(mlir::MLIRContext &context, std::string mlir) { + + mlir::OwningModuleRef module; + + std::unique_ptr membuf = + llvm::MemoryBuffer::getMemBuffer(mlir); + + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(membuf), llvm::SMLoc()); + module = mlir::parseSourceFile(sourceMgr, &context); + + if (!module) { + llvm::errs() << "Error can't parse mlir module\n"; + return nullptr; + } + if (failed(mlir::verify(*module))) { + llvm::errs() << "Error verifying MLIR module\n"; + return nullptr; + } + if (!module) + return nullptr; + return module; +} + +void InitModuleBindings(py::module m) { + m.def("_initialize_aten_bindings", + []() { ATenMLIRType::InitializeAtenBindings(); }); + m.def("_set_default_device", []() {}); + + m.def("_get_mlir", [](std::vector &ts) -> std::string { + if (ts.size() == 0) + return std::string(); + + mlir::MLIRContext context; + + // gather IR for all the tensors + std::vector recorded_ir; + for (auto &t : ts) + if (c10::optional at = bridge::TryGetMLIRTensor(t)) + recorded_ir.push_back(at->GetIrValue()); + + // generate MLIR from IR + auto mlir_gen = MLIRGen(context).genModule(recorded_ir); + mlir::OwningModuleRef module = std::move(std::get<0>(mlir_gen)); + + mlir::PassManager pm(module->getContext()); + + pm.addPass(mlir::createCSEPass()); + pm.addPass(mlir::NPCOMP::aten::createATenLayerNamePass()); + if (failed(pm.run(*module))) { + llvm::errs() << "ATenLayerNamePass failed"; + return ""; + } + + // dump MLIR to string and return + std::string s; + llvm::raw_string_ostream ss(s); + module->print(ss); + return ss.str(); + }); + + m.def( + "_op_report", + [](std::string mlir) -> std::string { + mlir::MLIRContext context; + auto module = LoadModule(context, mlir); + mlir::PassManager pm(module->getContext()); + + // our pass + std::string report; + pm.addPass(mlir::NPCOMP::aten::createATenLayerNamePass()); + pm.addPass(mlir::NPCOMP::aten::createATenOpReportPass(report)); + + if (failed(pm.run(*module))) { + llvm::errs() << "ATenOpReportPass failed"; + return ""; + } + return report; + }, + "run ATenOpReportPass"); + + m.def( + "_liveness_report", + [](std::string mlir) -> std::string { + mlir::MLIRContext context; + auto module = LoadModule(context, mlir); + + mlir::PassManager pm(module->getContext()); + + pm.addPass(mlir::NPCOMP::aten::createATenLayerNamePass()); + if (failed(pm.run(*module))) { + llvm::errs() << "ATen generate liveness report 
failed"; + return ""; + } + + auto mOp = module.get(); + auto liveness = mlir::NPCOMP::aten::LivenessReport(mOp); + std::string report = liveness.emitJSONReport(); + return report; + }, + "generate liveness report"); + + // TODO: Could this be implemented with MLIR python bindings? + m.def( + "lower_to_std", + [](std::string mlir) -> std::string { + mlir::MLIRContext context; + auto module = LoadModule(context, mlir); + + PassManager pm0(module->getContext()); + pm0.addPass(mlir::NPCOMP::aten::createATenLoweringPass()); + pm0.addPass(mlir::NPCOMP::aten::createReturnEliminationPass()); + pm0.addPass(mlir::createCSEPass()); + + if (failed(pm0.run(*module))) { + llvm::errs() << "aten to loops conversion failed "; + return ""; + } + + // dump MLIR to string and return + std::string s; + llvm::raw_string_ostream ss(s); + ss << "# Lowered to Std\n"; + module->print(ss); + return ss.str(); + }, + "lower aten to std dialect"); + + m.def( + "set_debug", + [](bool b, std::string type) -> void { + llvm::setCurrentDebugType(type.c_str()); + llvm::DebugFlag = b; + }, + "enable/disable debug messages"); +} + +} // namespace + +void InitBindings(py::module m) { InitModuleBindings(m); } + +} // namespace torch_mlir + +PYBIND11_MODULE(_torch_mlir, m) { torch_mlir::InitBindings(m); } diff --git a/frontends/pytorch/csrc/init_python_bindings.h b/frontends/pytorch/csrc/init_python_bindings.h new file mode 100644 index 000000000..7d6bfe3af --- /dev/null +++ b/frontends/pytorch/csrc/init_python_bindings.h @@ -0,0 +1,20 @@ +//===- init_python_bindings.h -----------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. +// +//===----------------------------------------------------------------------===// + +#ifndef INIT_PYTHON_BINDINGS_H +#define INIT_PYTHON_BINDINGS_H + +#include "torch/csrc/jit/pybind.h" + +namespace torch_mlir { + +// Initialize bindings for torch_mlir functions +void InitBindings(py::module m); + +} // namespace torch_mlir + +#endif diff --git a/frontends/pytorch/csrc/ir.cpp b/frontends/pytorch/csrc/ir.cpp new file mode 100644 index 000000000..5e4190630 --- /dev/null +++ b/frontends/pytorch/csrc/ir.cpp @@ -0,0 +1,1151 @@ +//===- ir.cpp ---------------------------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. 
+// +//===----------------------------------------------------------------------===// + +#include "npcomp/Dialect/ATen/ATenDialect.h" + +#include "llvm/Support/Debug.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" + +#include "ir.h" + +#include + +#define DEBUG_TYPE "torch_mlir" + +using namespace mlir; + +namespace torch_mlir { +namespace ir { + +void RegisterAtenIR() { + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); +} + +std::vector Value::sizes() const { return node->sizes(index); } + +std::vector Value::strides() const { return node->strides(index); } + +Node::Node(OpKind op) : op_(std::move(op)) {} + +Node::Node(OpKind op, OpList operands, std::vector sizes) + : op_(std::move(op)), operands_(std::move(operands)) { + for (auto &oper : operands) + operands_.push_back(oper); + sizes_[0] = sizes; +} + +Node::Node(OpKind op, OpList operands, at::IntArrayRef sizes) + : op_(std::move(op)), operands_(std::move(operands)) { + for (auto &oper : operands) + operands_.push_back(oper); + for (auto &size : sizes) + sizes_[0].push_back(size); +} + +std::vector Node::strides(std::vector sz) const { + auto dim = sz.size(); + std::vector ret(dim); + int64_t n = 1; + for (int i = dim - 1; i >= 0; i--) { + ret[i] = n; + n = n * sz[i]; + } + return ret; +} + +mlir::Operation * +Node::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + std::cout << "unsupported node type in Node::genMLIR" << op() << std::endl; + assert(0); +} + +mlir::Operation * +ConstantNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + // the type of the mlir value + mlir::Type mlirTy; + + // the attribuite attached to the mlir value + std::vector attrs; + auto typeId = mlir::Identifier::get("type", &context); + auto valueId = mlir::Identifier::get("value", &context); + + if (scalar) { + if (scalar->isIntegral(false)) { + mlirTy = mlir::IntegerType::get(32, &context); + attrs.emplace_back(typeId, mlir::StringAttr::get("i32", &context)); + attrs.emplace_back(valueId, + mlir::IntegerAttr::get(mlirTy, scalar->to())); + } else if (scalar->isFloatingPoint()) { + mlirTy = mlir::FloatType::getF32(&context); + attrs.emplace_back(typeId, mlir::StringAttr::get("f32", &context)); + attrs.emplace_back(valueId, + mlir::FloatAttr::get(mlirTy, scalar->to())); + } else if (scalar->isBoolean()) { + mlirTy = mlir::IntegerType::get(1, &context); + attrs.emplace_back(typeId, mlir::StringAttr::get("bool", &context)); + attrs.emplace_back( + valueId, mlir::IntegerAttr::get(mlirTy, (int)scalar->to())); + } else { + assert(0 && "unhandled scalar type in ir::ConstantNode"); + } + } else if (array.size() > 0) { + auto iTy = mlir::IntegerType::get(32, &context); + mlirTy = mlir::NPCOMP::aten::ATenListType::get(iTy); + auto vecTy = + mlir::VectorType::get(llvm::ArrayRef(array.size()), iTy); + attrs.emplace_back(typeId, mlir::StringAttr::get("List[i32]", &context)); + std::vector values; + for (auto a : array) + 
values.push_back((int32_t)a); + attrs.emplace_back( + valueId, DenseElementsAttr::get(vecTy, ArrayRef(values))); + } else if (bool_) { + mlirTy = mlir::IntegerType::get(1, &context); + attrs.emplace_back(typeId, mlir::StringAttr::get("bool", &context)); + attrs.emplace_back(valueId, mlir::IntegerAttr::get(mlirTy, (int)*bool_)); + } else if (int_) { + mlirTy = mlir::IntegerType::get(32, &context); + attrs.emplace_back(typeId, mlir::StringAttr::get("i32", &context)); + attrs.emplace_back(valueId, mlir::IntegerAttr::get(mlirTy, *int_)); + } else if (double_) { + mlirTy = mlir::FloatType::getF64(&context); + attrs.emplace_back(typeId, mlir::StringAttr::get("f64", &context)); + attrs.emplace_back(valueId, mlir::FloatAttr::get(mlirTy, *double_)); + } else if (float_) { + mlirTy = mlir::FloatType::getF32(&context); + attrs.emplace_back(typeId, mlir::StringAttr::get("f32", &context)); + attrs.emplace_back(valueId, mlir::FloatAttr::get(mlirTy, *float_)); + } else { + auto iTy = mlir::IntegerType::get(32, &context); + mlirTy = mlir::NPCOMP::aten::ATenListType::get(iTy); + auto vecTy = + mlir::VectorType::get(llvm::ArrayRef(array.size()), iTy); + attrs.emplace_back(typeId, mlir::StringAttr::get("List[i32]", &context)); + std::vector values; + for (auto a : array) + values.push_back((int32_t)a); + attrs.emplace_back( + valueId, DenseElementsAttr::get(vecTy, ArrayRef(values))); + } + // else { + // assert(0 && "unhandled type in ir::ConstantNode"); + // } + return builder->create( + loc, ArrayRef{mlirTy}, ArrayRef{}, attrs); +} + +mlir::Operation *AdaptiveAvgPool2dNode::genMLIR( + std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::_adaptive_avg_pool2d")); + + mlir::Value self = symbolTable[operand(0)]; + mlir::Type selfTy = self.getType(); + auto elemTy = ((mlir::ShapedType *)&selfTy)->getElementType(); + + mlir::Type mlirTy = mlir::RankedTensorType::get(sizes(), elemTy); + + return builder->create( + loc, mlirTy, self, symbolTable[operand(1)]); +} + +mlir::Operation *AdaptiveAvgPool2dBackwardNode::genMLIR( + std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::_adaptive_avg_pool2d_backward")); + + mlir::Value self = symbolTable[operand(1)]; + mlir::Type selfTy = self.getType(); + auto elemTy = ((mlir::ShapedType *)&selfTy)->getElementType(); + + mlir::Type mlirTy = mlir::RankedTensorType::get(sizes(), elemTy); + + return builder->create( + loc, mlirTy, symbolTable[operand(1)], self); +} + +mlir::Operation * +AddNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + assert(op() == ir::OpKind::Get("aten::add")); + + auto loc = mlir::UnknownLoc::get(&context); + + auto arg0 = symbolTable[operand(0)]; + auto arg1 = symbolTable[operand(1)]; + auto arg2 = symbolTable[operand(2)]; + auto retTy = arg0.getType(); + + return builder->create(loc, retTy, arg0, arg1, + arg2); +} + +mlir::Operation * +AddInPlaceNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + assert(op() == ir::OpKind::Get("aten::add_")); + + auto loc = mlir::UnknownLoc::get(&context); + + auto arg0 = symbolTable[operand(0)]; + auto arg1 = symbolTable[operand(1)]; + auto arg2 = symbolTable[operand(2)]; + auto retTy = arg0.getType(); + + return builder->create(loc, retTy, arg0, arg1, + arg2); +} + +mlir::Operation * 
+AddmmNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::addmm")); + + mlir::Type tensorTy = symbolTable[operand(0)].getType(); + auto elemTy = ((mlir::ShapedType *)&tensorTy)->getElementType(); + mlir::Type retTy = mlir::RankedTensorType::get(sizes(), elemTy); + + auto arg0 = symbolTable[operand(0)]; + auto arg1 = symbolTable[operand(1)]; + auto arg2 = symbolTable[operand(2)]; + auto arg3 = symbolTable[operand(3)]; + auto arg4 = symbolTable[operand(4)]; + + return builder->create(loc, retTy, arg0, arg1, + arg2, arg3, arg4); +} + +mlir::Operation * +AsStridedNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::as_strided")); + + mlir::Value self = symbolTable[operand(0)]; + mlir::Type selfTy = self.getType(); + auto elemTy = ((mlir::ShapedType *)&selfTy)->getElementType(); + + mlir::Type retTy = mlir::RankedTensorType::get(sizes(), elemTy); + + return builder->create( + loc, retTy, self, symbolTable[operand(1)], symbolTable[operand(2)]); +} + +std::vector AsStridedNode::sizes() const { + + auto input_size = operand(0).sizes(); + + // XXX + // std::cout << "TODO: handle stride!\n"; + + LLVM_DEBUG(llvm::dbgs() << "as strided input size: "); + for (int64_t n : input_size) + LLVM_DEBUG(llvm::dbgs() << n << " "); + LLVM_DEBUG(llvm::dbgs() << "\n"); + + LLVM_DEBUG(llvm::dbgs() << "view size: "); + for (int64_t n : size) + LLVM_DEBUG(llvm::dbgs() << n << " "); + LLVM_DEBUG(llvm::dbgs() << "\n"); + + std::vector output_size; + output_size.resize(size.size()); + + int64_t numel = 1; + for (int64_t n : input_size) + numel *= n; + + int64_t numel_view = 1; + for (int i = size.size() - 1; i >= 0; i--) { + int64_t n = size[i]; + if (n == -1) + n = numel / numel_view; + else if (n <= 0) + assert(n && "unhandled size in AsStridedNode::sizes()"); + output_size[i] = n; + numel_view *= n; + } + + LLVM_DEBUG(llvm::dbgs() << "output size: "); + for (int64_t n : output_size) + LLVM_DEBUG(llvm::dbgs() << n << " "); + LLVM_DEBUG(llvm::dbgs() << "\n"); + + assert(numel == numel_view && "bad size in AsStridedNode::sizes()"); + return output_size; +} + +mlir::Operation * +BatchNormNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::native_batch_norm")); + + mlir::Type tensorTy = symbolTable[operand(0)].getType(); + auto elemTy = ((mlir::ShapedType *)&tensorTy)->getElementType(); + mlir::Type mlirTy = mlir::RankedTensorType::get(sizes(), elemTy); + + std::vector attrs; + std::vector mlirOperands; + + for (auto &op : operands()) + mlirOperands.push_back(symbolTable[op]); + + return builder->create( + loc, + ArrayRef( + std::vector{mlirTy, symbolTable[operand(2)].getType(), + symbolTable[operand(3)].getType()}), + ArrayRef(mlirOperands), attrs); +} + +mlir::Operation *BatchNormBackwardNode::genMLIR( + std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::native_batch_norm_backward")); + + mlir::TensorType tensorTy = + symbolTable[operand(0)].getType().cast(); + mlir::Type elemTy = tensorTy.getElementType(); + mlir::Type ret0Ty = mlir::RankedTensorType::get(sizes(0), elemTy); + mlir::Type ret1Ty = 
mlir::RankedTensorType::get(sizes(1), elemTy); + mlir::Type ret2Ty = mlir::RankedTensorType::get(sizes(2), elemTy); + + std::vector attrs; + std::vector mlirOperands; + + for (auto &op : operands()) + mlirOperands.push_back(symbolTable[op]); + + return builder->create( + loc, + ArrayRef(std::vector{ret0Ty, ret1Ty, ret2Ty}), + ArrayRef(mlirOperands), attrs); +} + +std::vector BatchNormBackwardNode::sizes(size_t i) const { + if (i == 0) + return operand(0).sizes(); + if (i == 1) + return {operand(1).sizes()[1]}; + if (i == 2) + return {operand(1).sizes()[1]}; + + assert(0 && "bad operand index"); +} + +mlir::Operation * +Conv2dNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::_convolution")); + + mlir::Type tensorTy = symbolTable[operand(0)].getType(); + auto elemTy = ((mlir::ShapedType *)&tensorTy)->getElementType(); + mlir::Type mlirTy = mlir::RankedTensorType::get(sizes(), elemTy); + + std::vector attrs; + std::vector mlirOperands; + + for (auto &op : operands()) + mlirOperands.push_back(symbolTable[op]); + + return builder->create( + loc, ArrayRef{mlirTy}, ArrayRef(mlirOperands), + attrs); +} + +std::vector Conv2dNode::sizes() const { + auto isize = operand(0).sizes(); + auto wsize = operand(1).sizes(); + int64_t osize0 = isize[0]; + int64_t osize1 = wsize[0]; + int64_t osize2 = 1 + ((isize[2] - wsize[2] + 2 * padding[0]) / stride[0]); + int64_t osize3 = 1 + ((isize[3] - wsize[3] + 2 * padding[1]) / stride[1]); + + std::vector osize{osize0, osize1, osize2, osize3}; + return osize; +} + +mlir::Operation *Conv2dBackwardNode::genMLIR( + std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::_convolution_backward")); + + mlir::Type tensorTy = symbolTable[operand(0)].getType(); + auto elemTy = ((mlir::ShapedType *)&tensorTy)->getElementType(); + mlir::Type retTy0 = mlir::RankedTensorType::get(sizes(0), elemTy); + mlir::Type retTy1 = mlir::RankedTensorType::get(sizes(1), elemTy); + mlir::Type retTy2 = mlir::RankedTensorType::get(sizes(2), elemTy); + + std::vector attrs; + std::vector mlirOperands; + + for (auto &op : operands()) + mlirOperands.push_back(symbolTable[op]); + + return builder->create( + loc, ArrayRef{retTy0, retTy1, retTy2}, + ArrayRef(mlirOperands), attrs); +} + +std::vector Conv2dBackwardNode::sizes(size_t index) const { + if (index == 0) + return operand(1).sizes(); + if (index == 1) + return operand(2).sizes(); + else if (index == 2) + return {operand(2).sizes()[0]}; + else + assert(0 && "bad index"); +} + +mlir::Operation * +DivNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + assert(op() == ir::OpKind::Get("aten::div")); + + auto loc = mlir::UnknownLoc::get(&context); + + auto arg0 = symbolTable[operand(0)]; + auto arg1 = symbolTable[operand(1)]; + auto retTy = arg0.getType(); + + return builder->create(loc, retTy, arg0, arg1); +} + +mlir::Operation * +DivInPlaceNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + assert(op() == ir::OpKind::Get("aten::div_")); + + auto loc = mlir::UnknownLoc::get(&context); + + auto arg0 = symbolTable[operand(0)]; + auto arg1 = symbolTable[operand(1)]; + auto retTy = arg0.getType(); + + return builder->create(loc, retTy, arg0, + arg1); +} + +mlir::Operation * +ExpandNode::genMLIR(std::unique_ptr 
&builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::expand")); + + mlir::Value input = symbolTable[operand(0)]; + mlir::Type elemTy = input.getType().cast().getElementType(); + + mlir::Type retTy = mlir::RankedTensorType::get(sizes(), elemTy); + + auto size = symbolTable[operand(1)]; + auto implicit = symbolTable[operand(2)]; + + return builder->create(loc, retTy, input, size, + implicit); +} + +mlir::Operation * +GatherNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::gather")); + + mlir::Value input = symbolTable[operand(0)]; + mlir::Type elemTy = input.getType().cast().getElementType(); + + mlir::Type retTy = mlir::RankedTensorType::get(sizes(), elemTy); + + auto dim = symbolTable[operand(1)]; + auto index = symbolTable[operand(2)]; + auto sparse_grad = symbolTable[operand(3)]; + + return builder->create(loc, retTy, input, dim, + index, sparse_grad); +} + +mlir::Operation * +HardtanhNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + assert(op() == ir::OpKind::Get("aten::hardtanh")); + + auto loc = mlir::UnknownLoc::get(&context); + + auto arg0 = symbolTable[operand(0)]; + auto arg1 = symbolTable[operand(1)]; + auto arg2 = symbolTable[operand(2)]; + auto retTy = arg0.getType(); + + return builder->create(loc, retTy, arg0, arg1, + arg2); +} + +mlir::Operation *HardtanhInPlaceNode::genMLIR( + std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) { + assert(op() == ir::OpKind::Get("aten::hardtanh_")); + + auto loc = mlir::UnknownLoc::get(&context); + + auto arg0 = symbolTable[operand(0)]; + auto arg1 = symbolTable[operand(1)]; + auto arg2 = symbolTable[operand(2)]; + auto retTy = arg0.getType(); + + return builder->create(loc, retTy, arg0, + arg1, arg2); +} + +mlir::Operation *HardtanhBackwardNode::genMLIR( + std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) { + assert(op() == ir::OpKind::Get("aten::hardtanh_backward")); + + auto loc = mlir::UnknownLoc::get(&context); + + auto arg0 = symbolTable[operand(0)]; + auto arg1 = symbolTable[operand(1)]; + auto arg2 = symbolTable[operand(2)]; + auto arg3 = symbolTable[operand(3)]; + auto retTy = arg0.getType(); + + return builder->create( + loc, retTy, arg0, arg1, arg2, arg3); +} + +mlir::Operation * +LogSoftmaxNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::_log_softmax")); + + mlir::Value self = symbolTable[operand(0)]; + mlir::Type selfTy = self.getType(); + auto elemTy = ((mlir::ShapedType *)&selfTy)->getElementType(); + + mlir::Type retTy = mlir::RankedTensorType::get(sizes(), elemTy); + + auto dim = symbolTable[operand(1)]; + auto half_to_float = symbolTable[operand(2)]; + + return builder->create(loc, retTy, self, + dim, half_to_float); +} + +mlir::Operation *LogSoftmaxBackwardNode::genMLIR( + std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::_log_softmax_backward_data")); + + mlir::Value arg0 = symbolTable[operand(0)]; + mlir::Value arg1 = symbolTable[operand(1)]; + mlir::Value arg2 = symbolTable[operand(2)]; + mlir::Value arg3 = symbolTable[operand(3)]; + 
+ mlir::Type retTy = arg1.getType(); + + return builder->create( + loc, retTy, arg0, arg1, arg2, arg3); +} + +mlir::Operation *MaxPool2dWithIndicesNode::genMLIR( + std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::max_pool2d_with_indices")); + + mlir::Type tensorTy = symbolTable[operand(0)].getType(); + auto elemTy = ((mlir::ShapedType *)&tensorTy)->getElementType(); + mlir::Type retTy = mlir::RankedTensorType::get(sizes(0), elemTy); + mlir::Type idxTy = mlir::RankedTensorType::get( + sizes(0), mlir::IntegerType::get(64, &context)); + + std::vector attrs; + std::vector mlirOperands; + + for (auto &op : operands()) + mlirOperands.push_back(symbolTable[op]); + + return builder->create( + loc, ArrayRef{retTy, idxTy}, + ArrayRef(mlirOperands), attrs); +} + +std::vector MaxPool2dWithIndicesNode::sizes(size_t index) const { + auto isize = operand(0).sizes(); + int64_t osize0 = isize[0]; + int64_t osize1 = isize[1]; + // stride can be empty. the default is kernel_size + int64_t stride0 = stride.size() == 2 ? stride[0] : kernel_size[0]; + int64_t stride1 = stride.size() == 2 ? stride[1] : kernel_size[1]; + int64_t osize2 = 1 + ((isize[2] - kernel_size[0] + 2 * padding[0]) / stride0); + int64_t osize3 = 1 + ((isize[3] - kernel_size[1] + 2 * padding[1]) / stride1); + + std::vector osize{osize0, osize1, osize2, osize3}; + return osize; +} + +mlir::Operation *MaxPool2dWithIndicesBackwardNode::genMLIR( + std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::max_pool2d_with_indices_backward")); + + mlir::Type retTy = symbolTable[operand(1)].getType(); + + std::vector attrs; + std::vector mlirOperands; + + for (auto &op : operands()) + mlirOperands.push_back(symbolTable[op]); + + return builder->create( + loc, ArrayRef{retTy}, ArrayRef(mlirOperands), + attrs); +} + +mlir::Operation * +MeanNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::mean")); + + mlir::Value self = symbolTable[operand(0)]; + mlir::Type selfTy = self.getType(); + auto elemTy = ((mlir::ShapedType *)&selfTy)->getElementType(); + + mlir::Type retTy = mlir::RankedTensorType::get(sizes(), elemTy); + + return builder->create(loc, retTy, self); +} + +std::vector MeanNode::sizes() const { + + std::vector input_size = operand(0).sizes(); + std::vector output_dims; + std::vector result; + + if (dim.size() == 0) + return {1}; + + for (int64_t n : input_size) { + output_dims.push_back(n); + } + for (int64_t d : dim) { + if (d < 0) + d += output_dims.size(); + + if (keepdim) + output_dims[d] = 1; + else + output_dims[d] = 0; + } + for (int64_t n : output_dims) { + if (n > 0) + result.push_back(n); + } + return result; +} + +mlir::Operation * +MMNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::mm")); + + mlir::Type tensorTy = symbolTable[operand(0)].getType(); + auto elemTy = ((mlir::ShapedType *)&tensorTy)->getElementType(); + mlir::Type retTy = mlir::RankedTensorType::get(sizes(), elemTy); + + auto arg0 = symbolTable[operand(0)]; + auto arg1 = symbolTable[operand(1)]; + + return builder->create(loc, retTy, arg0, arg1); +} + +mlir::Operation * 
+MulNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + assert(op() == ir::OpKind::Get("aten::mul")); + + auto loc = mlir::UnknownLoc::get(&context); + + auto arg0 = symbolTable[operand(0)]; + auto arg1 = symbolTable[operand(1)]; + auto retTy = arg0.getType(); + + return builder->create(loc, retTy, arg0, arg1); +} + +mlir::Operation * +MulInPlaceNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + assert(op() == ir::OpKind::Get("aten::mul_")); + + auto loc = mlir::UnknownLoc::get(&context); + + auto arg0 = symbolTable[operand(0)]; + auto arg1 = symbolTable[operand(1)]; + auto retTy = arg0.getType(); + + return builder->create(loc, retTy, arg0, + arg1); +} + +mlir::Operation * +NegNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + assert(op() == ir::OpKind::Get("aten::neg")); + + auto arg0 = symbolTable[operand(0)]; + return builder->create(loc, arg0.getType(), arg0); +} + +mlir::Operation *NllLoss2dForwardNode::genMLIR( + std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) { + assert(op() == ir::OpKind::Get("aten::nll_loss2d_forward")); + + auto loc = mlir::UnknownLoc::get(&context); + + auto input = symbolTable[operand(0)]; + + mlir::TensorType tensorTy = input.getType().cast(); + mlir::Type elemTy = tensorTy.getElementType(); + mlir::Type retTy = mlir::RankedTensorType::get(1, elemTy); + + std::vector attrs; + std::vector mlirOperands; + + for (auto &op : operands()) + mlirOperands.push_back(symbolTable[op]); + + return builder->create( + loc, ArrayRef{retTy, retTy}, + ArrayRef(mlirOperands), attrs); +} + +mlir::Operation *NllLoss2dBackwardNode::genMLIR( + std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) { + assert(op() == ir::OpKind::Get("aten::nll_loss2d_backward")); + + auto loc = mlir::UnknownLoc::get(&context); + + auto input = symbolTable[operand(1)]; + + mlir::Type retTy = input.getType(); + + std::vector attrs; + std::vector mlirOperands; + + for (auto &op : operands()) + mlirOperands.push_back(symbolTable[op]); + + return builder->create( + loc, ArrayRef{retTy}, ArrayRef(mlirOperands), + attrs); +} + +mlir::Operation *NllLossForwardNode::genMLIR( + std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) { + assert(op() == ir::OpKind::Get("aten::nll_loss_forward")); + + auto loc = mlir::UnknownLoc::get(&context); + + auto input = symbolTable[operand(0)]; + + mlir::TensorType tensorTy = input.getType().cast(); + mlir::Type elemTy = tensorTy.getElementType(); + mlir::Type retTy = mlir::RankedTensorType::get(1, elemTy); + + std::vector attrs; + std::vector mlirOperands; + + for (auto &op : operands()) + mlirOperands.push_back(symbolTable[op]); + + return builder->create( + loc, ArrayRef{retTy, retTy}, + ArrayRef(mlirOperands), attrs); +} + +mlir::Operation *NllLossBackwardNode::genMLIR( + std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) { + assert(op() == ir::OpKind::Get("aten::nll_loss_backward")); + + auto loc = mlir::UnknownLoc::get(&context); + + auto input = symbolTable[operand(1)]; + + mlir::Type retTy = input.getType(); + + std::vector attrs; + std::vector mlirOperands; + + for (auto &op : operands()) + mlirOperands.push_back(symbolTable[op]); + + return builder->create( + loc, ArrayRef{retTy}, ArrayRef(mlirOperands), + attrs); +} + +mlir::Operation * +SumNode::genMLIR(std::unique_ptr 
&builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::sum")); + + mlir::Value self = symbolTable[operand(0)]; + mlir::Type selfTy = self.getType(); + auto elemTy = ((mlir::ShapedType *)&selfTy)->getElementType(); + + mlir::Type retTy = mlir::RankedTensorType::get(sizes(), elemTy); + + auto dim = symbolTable[operand(1)]; + auto keepdim = symbolTable[operand(2)]; + + return builder->create(loc, retTy, self, dim, + keepdim); +} + +std::vector SumNode::sizes() const { + + std::vector input_size = operand(0).sizes(); + std::vector output_dims; + std::vector result; + + for (int64_t n : input_size) { + output_dims.push_back(n); + } + for (int64_t d : dim) { + if (d < 0) + d += output_dims.size(); + + if (keepdim) + output_dims[d] = 1; + else + output_dims[d] = 0; + } + for (int64_t n : output_dims) { + if (n > 0) + result.push_back(n); + } + + return result; +} + +mlir::Operation * +ReLUNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::relu")); + + auto input = symbolTable[operand(0)]; + return builder->create(loc, input.getType(), + input); +} + +mlir::Operation * +ReLUInPlaceNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::relu_")); + + auto input = symbolTable[operand(0)]; + return builder->create(loc, input.getType(), + input); +} + +mlir::Operation * +SizeNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::size")); + + mlir::Value self = symbolTable[operand(0)]; + mlir::Type tTy = self.getType().cast(); + mlir::Type retTy = mlir::IntegerType::get(32, &context); + std::vector attrs; + auto typeId = mlir::Identifier::get("type", &context); + auto valueId = mlir::Identifier::get("value", &context); + attrs.emplace_back(typeId, mlir::StringAttr::get("i32", &context)); + attrs.emplace_back(valueId, mlir::IntegerAttr::get(retTy, sizes()[dim])); + return builder->create( + loc, ArrayRef{retTy}, ArrayRef{}, attrs); +} + +mlir::Operation * +SqueezeNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::squeeze")); + + mlir::Value self = symbolTable[operand(0)]; + mlir::Type elemTy = self.getType().cast().getElementType(); + mlir::Type retTy = mlir::RankedTensorType::get(sizes(), elemTy); + + return builder->create( + loc, retTy, self, symbolTable[operand(1)]); +} + +std::vector SqueezeNode::sizes() const { + std::vector input_size = operand(0).sizes(); + std::vector output_size; + + int input_dim = input_size.size(); + int arg_dim = dim; + assert(arg_dim <= input_dim + 1); + assert(arg_dim >= -input_dim - 1); + + if (arg_dim < 0) + arg_dim = arg_dim + input_dim + 1; + + int i = 1; + for (int64_t n : input_size) { + if (i++ == dim && n == 1) + continue; + output_size.push_back(n); + } + return output_size; +} + +mlir::Operation * +SubNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + assert(op() == ir::OpKind::Get("aten::sub")); + + auto loc = mlir::UnknownLoc::get(&context); + + auto arg0 = symbolTable[operand(0)]; + auto arg1 = 
symbolTable[operand(1)]; + auto arg2 = symbolTable[operand(2)]; + auto retTy = arg0.getType(); + + return builder->create(loc, retTy, arg0, arg1, + arg2); +} + +mlir::Operation * +SubInPlaceNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + assert(op() == ir::OpKind::Get("aten::sub_")); + + auto loc = mlir::UnknownLoc::get(&context); + + auto arg0 = symbolTable[operand(0)]; + auto arg1 = symbolTable[operand(1)]; + auto arg2 = symbolTable[operand(2)]; + auto retTy = arg0.getType(); + + return builder->create(loc, retTy, arg0, arg1, + arg2); +} + +mlir::Operation *ThresholdBackwardNode::genMLIR( + std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::threshold_backward")); + + auto arg0 = symbolTable[operand(0)]; + auto arg1 = symbolTable[operand(1)]; + auto arg2 = symbolTable[operand(2)]; + + return builder->create( + loc, arg0.getType(), arg0, arg1, arg2); +} + +mlir::Operation * +TransposeNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::t")); + + mlir::Value self = symbolTable[operand(0)]; + mlir::Type elemTy = self.getType().cast().getElementType(); + mlir::Type mlirTy = mlir::RankedTensorType::get(sizes(), elemTy); + + return builder->create(loc, mlirTy, self); +} + +mlir::Operation * +UnsqueezeNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::unsqueeze")); + + mlir::Value self = symbolTable[operand(0)]; + mlir::Type elemTy = self.getType().cast().getElementType(); + mlir::Type retTy = mlir::RankedTensorType::get(sizes(), elemTy); + + return builder->create( + loc, retTy, self, symbolTable[operand(1)]); +} + +std::vector UnsqueezeNode::sizes() const { + std::vector input_size = operand(0).sizes(); + std::vector output_size; + + int input_dim = input_size.size(); + int arg_dim = dim; + assert(arg_dim <= input_dim + 1); + assert(arg_dim >= -input_dim - 1); + + if (arg_dim < 0) + arg_dim = arg_dim + input_dim + 1; + + int i = 1; + for (int64_t n : input_size) { + if (i++ == dim) + output_size.push_back(1); + output_size.push_back(n); + } + return output_size; +} + +mlir::Operation * +ViewNode::genMLIR(std::unique_ptr &builder, + mlir::MLIRContext &context, + std::map &symbolTable) { + auto loc = mlir::UnknownLoc::get(&context); + + assert(op() == ir::OpKind::Get("aten::view")); + + mlir::Value self = symbolTable[operand(0)]; + mlir::Type elemTy = self.getType().cast().getElementType(); + mlir::Type retTy = mlir::RankedTensorType::get(sizes(), elemTy); + + return builder->create(loc, retTy, self, + symbolTable[operand(1)]); +} + +std::vector ViewNode::sizes() const { + + auto input_size = operand(0).sizes(); + +#if 0 + std::cout << "view input size: "; + for (int64_t n : input_size) + std::cout << n << " "; + std::cout << std::endl; + + std::cout << "view size: "; + for (int64_t n : view_size) + std::cout << n << " "; + std::cout << std::endl; +#endif + + std::vector output_size; + output_size.resize(view_size.size()); + + int64_t numel = 1; + for (int64_t n : input_size) + numel *= n; + + int64_t numel_view = 1; + for (int i = view_size.size() - 1; i >= 0; i--) { + int64_t n = view_size[i]; + if (n == -1) + n = numel / numel_view; + else if (n <= 0) + assert(n && 
"unhandled size in ViewNode::sizes()"); + output_size[i] = n; + numel_view *= n; + } + + assert(numel == numel_view && "bad size in ViewNode::sizes()"); + return output_size; +} + +} // namespace ir +} // namespace torch_mlir diff --git a/frontends/pytorch/csrc/ir.h b/frontends/pytorch/csrc/ir.h new file mode 100644 index 000000000..7c1cac876 --- /dev/null +++ b/frontends/pytorch/csrc/ir.h @@ -0,0 +1,920 @@ +//===- ir.h -----------------------------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. +// +//===----------------------------------------------------------------------===// + +#pragma once + +// This file defines an intermediate IR generated from a pytorch model. +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +class OpBuilder; +class Value; +class Operation; +class MLIRContext; +} // namespace mlir + +#include +#include + +#include +#include +#include +#include + +namespace torch_mlir { +namespace ir { + +class Node; + +void RegisterAtenIR(); + +using NodePtr = std::shared_ptr; + +struct Value { + Value() = default; + Value(NodePtr node, size_t index = 0) : node(std::move(node)), index(index) {} + + operator bool() const { return node != nullptr; } + + bool operator==(const Value &rhs) const { + return node == rhs.node && index == rhs.index; + } + + bool operator<(const Value &rhs) const { + if (node == rhs.node) + return index < rhs.index; + return node < rhs.node; + } + + std::vector sizes() const; + std::vector strides() const; + + NodePtr node; + size_t index = 0; +}; + +struct OpKind { + OpKind() = default; + explicit OpKind(c10::Symbol op) : op(std::move(op)) {} + + bool operator==(const OpKind &rhs) const { return op == rhs.op; } + bool operator!=(const OpKind &rhs) const { return !operator==(rhs); } + bool operator<(const OpKind &rhs) const { + return c10::unique_t(op) < c10::unique_t(rhs.op); + } + + // size_t hash() const; + + std::string ToString() const { return op.toQualString(); } + + static OpKind Get(const std::string &name) { + return OpKind(c10::Symbol::fromQualString(name)); + } + + c10::Symbol op; +}; + +inline std::ostream &operator<<(std::ostream &stream, const OpKind &op) { + stream << op.ToString(); + return stream; +} + +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, + const OpKind &op) { + stream << op.ToString(); + return stream; +} + +using OpList = std::vector; + +class Node { + +public: + Node(OpKind op); + Node(OpKind op, OpList operands, std::vector sizes); + Node(OpKind op, OpList operands, at::IntArrayRef sizes); + + const OpKind &op() const { return op_; } + + virtual std::vector sizes() const { return sizes_[0]; } + virtual std::vector sizes(size_t i) const { return sizes_[0]; } + + virtual std::vector strides() const { return strides(sizes()); } + virtual std::vector strides(size_t i) const { + return strides(sizes(i)); + } + + OpList &operands() { return operands_; } + Value operand(size_t i) const { return operands_.at(i); } + + virtual mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable); + +private: + std::vector strides(std::vector sz) const; + + OpKind op_; + OpList operands_; + std::array, 3> sizes_; + // std::array, 3> strides_; +}; + +class ConstantNode : public Node { +public: + ConstantNode(at::Scalar scalar) + : Node(OpKind::Get("aten::constant")), scalar(scalar) {} + + ConstantNode(at::IntArrayRef array) + : Node(OpKind::Get("aten::constant")), 
array(array.begin(), array.end()) { + } + + ConstantNode(bool bool_) + : Node(OpKind::Get("aten::constant")), bool_(bool_) {} + + ConstantNode(int int_) : Node(OpKind::Get("aten::constant")), int_(int_) {} + + ConstantNode(int64_t int_) + : Node(OpKind::Get("aten::constant")), int_(int_) {} + + ConstantNode(float float_) + : Node(OpKind::Get("aten::constant")), float_(float_) {} + + ConstantNode(double double_) + : Node(OpKind::Get("aten::constant")), double_(double_) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + + std::vector sizes() const override { return {1}; } + std::vector sizes(size_t i) const override { return sizes(); } + +private: + c10::optional scalar; + std::vector array; + c10::optional bool_; + c10::optional int_; + c10::optional float_; + c10::optional double_; +}; + +class AdaptiveAvgPool2dNode : public Node { +public: + AdaptiveAvgPool2dNode(Value input, at::IntArrayRef kernel_size) + : Node(OpKind::Get("aten::_adaptive_avg_pool2d"), + OpList{input, + ir::Value(std::make_shared(kernel_size))}, + std::vector{input.sizes()[0], input.sizes()[1], + kernel_size[0], kernel_size[1]}) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class AdaptiveAvgPool2dBackwardNode : public Node { +public: + AdaptiveAvgPool2dBackwardNode(Value grad_output, Value self) + : Node(OpKind::Get("aten::_adaptive_avg_pool2d_backward"), + OpList{grad_output, self}, self.sizes()) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class AddNode : public Node { +public: + AddNode(Value rhs, Value lhs, Value alpha) + : Node(OpKind::Get("aten::add"), OpList{rhs, lhs, alpha}, rhs.sizes()){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class AddInPlaceNode : public Node { +public: + AddInPlaceNode(Value self, Value other, Value alpha) + : Node(OpKind::Get("aten::add_"), OpList{self, other, alpha}, + self.sizes()){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class AddmmNode : public Node { +public: + AddmmNode(Value input, Value mat1, Value mat2, Value beta, Value alpha) + : Node(OpKind::Get("aten::addmm"), OpList{input, mat1, mat2, beta, alpha}, + std::vector{mat1.sizes()[0], mat2.sizes()[1]}){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class AsStridedNode : public Node { +public: + AsStridedNode(Value input, at::IntArrayRef size, at::IntArrayRef stride, + c10::optional storage_offset) + : Node(OpKind::Get("aten::as_strided"), + OpList{input, ir::Value(std::make_shared(size)), + ir::Value(std::make_shared(stride))}, + input.sizes()), + size(size.begin(), size.end()), stride(stride.begin(), stride.end()), + storage_offset(storage_offset) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + + std::vector sizes() const override; + std::vector sizes(size_t i) const override { return sizes(); } + + std::vector strides() const override { return stride; } + std::vector strides(size_t i) const override { return strides(); } + + std::vector size; + std::vector stride; + c10::optional storage_offset; +}; + +class BatchNormNode : public Node { +public: + 
BatchNormNode(Value input, Value weight, Value bias, Value running_mean, + Value running_var, bool training, double momentum, double eps) + : Node(OpKind::Get("aten::native_batch_norm"), + OpList{ + input, weight, bias, running_mean, running_var, + ir::Value(std::make_shared(training)), + ir::Value(std::make_shared((float)momentum)), + ir::Value(std::make_shared((float)eps))}, + input.sizes()), + training(training), momentum(momentum), eps(eps) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + +private: + bool training; + double momentum; + double eps; +}; + +class BatchNormBackwardNode : public Node { +public: + BatchNormBackwardNode(Value grad_out, Value input, Value weight, + Value running_mean, Value running_var, Value save_mean, + Value save_invstd, bool train, double eps, + std::array output_mask) + : Node(OpKind::Get("aten::native_batch_norm_backward"), + OpList{grad_out, input, weight, running_mean, running_var, + save_mean, save_invstd, + ir::Value(std::make_shared(train)), + ir::Value(std::make_shared((float)eps))}, + input.sizes()), + train(train), eps(eps), output_mask(output_mask) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + + std::vector sizes() const override { + assert(0 && "Cannot call sizes() for multiple outputs"); + } + std::vector sizes(size_t i) const override; + +private: + bool train; + double eps; + std::array output_mask; +}; + +class Conv2dNode : public Node { +public: + Conv2dNode(Value input, Value weight, Value bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, + at::IntArrayRef output_padding, int64_t groups) + : Node(OpKind::Get("aten::_convolution"), + OpList{ + input, weight, bias, + ir::Value(std::make_shared(stride)), + ir::Value(std::make_shared(padding)), + ir::Value(std::make_shared(dilation)), + ir::Value(std::make_shared(transposed)), + ir::Value(std::make_shared(output_padding)), + ir::Value(std::make_shared(groups))}, + input.sizes()), + stride(stride.begin(), stride.end()), + padding(padding.begin(), padding.end()), + dilation(dilation.begin(), dilation.end()), transposed(transposed), + output_padding(output_padding.begin(), output_padding.end()), + groups(groups), has_bias(true) {} + + Conv2dNode(Value input, Value weight, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, + at::IntArrayRef output_padding, int64_t groups) + : Node(OpKind::Get("aten::_convolution"), + OpList{ + input, weight, + ir::Value(std::make_shared(stride)), + ir::Value(std::make_shared(padding)), + ir::Value(std::make_shared(dilation)), + ir::Value(std::make_shared(transposed)), + ir::Value(std::make_shared(output_padding)), + ir::Value(std::make_shared(groups))}, + input.sizes()), + stride(stride.begin(), stride.end()), + padding(padding.begin(), padding.end()), + dilation(dilation.begin(), dilation.end()), transposed(transposed), + output_padding(output_padding.begin(), output_padding.end()), + groups(groups), has_bias(false) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + + std::vector sizes() const override; + std::vector sizes(size_t i) const override { return sizes(); } + +private: + std::vector stride; + std::vector padding; + std::vector dilation; + bool transposed; + std::vector output_padding; + int64_t groups; + bool has_bias; +}; + +class 
Conv2dBackwardNode : public Node { +public: + Conv2dBackwardNode(Value grad_output, Value input, Value weight, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, bool transposed, + at::IntArrayRef output_padding, int64_t groups) + : Node(OpKind::Get("aten::_convolution_backward"), + OpList{ + grad_output, input, weight, + ir::Value(std::make_shared(stride)), + ir::Value(std::make_shared(padding)), + ir::Value(std::make_shared(dilation)), + ir::Value(std::make_shared(transposed)), + ir::Value(std::make_shared(output_padding)), + ir::Value(std::make_shared(groups))}, + input.sizes()), + stride(stride.begin(), stride.end()), + padding(padding.begin(), padding.end()), + dilation(dilation.begin(), dilation.end()), transposed(transposed), + output_padding(output_padding.begin(), output_padding.end()), + groups(groups) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + + std::vector sizes() const override { + assert(0 && "Cannot call sizes() for multiple outputs"); + } + std::vector sizes(size_t i) const override; + +private: + std::vector stride; + std::vector padding; + std::vector dilation; + bool transposed; + std::vector output_padding; + int64_t groups; +}; + +class DivNode : public Node { +public: + DivNode(Value rhs, Value lhs) + : Node(OpKind::Get("aten::div"), OpList{rhs, lhs}, rhs.sizes()){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class DivInPlaceNode : public Node { +public: + DivInPlaceNode(Value self, Value other) + : Node(OpKind::Get("aten::div_"), OpList{self, other}, self.sizes()){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class ExpandNode : public Node { +public: + ExpandNode(Value input, at::IntArrayRef size, bool implicit) + : Node(OpKind::Get("aten::expand"), + OpList{input, ir::Value(std::make_shared(size)), + ir::Value(std::make_shared(implicit))}, + input.sizes()), + output_size(size.begin(), size.end()), implicit(implicit) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + + std::vector sizes() const override { return output_size; } + std::vector sizes(size_t i) const override { return sizes(); } + +private: + std::vector output_size; + bool implicit; +}; + +class GatherNode : public Node { +public: + GatherNode(Value input, int64_t dim, Value index, bool sparse_grad) + : Node(OpKind::Get("aten::gather"), + OpList{input, ir::Value(std::make_shared(dim)), + index, + ir::Value(std::make_shared(sparse_grad))}, + input.sizes()) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class HardtanhNode : public Node { +public: + HardtanhNode(Value self, Value min_val, Value max_val) + : Node(OpKind::Get("aten::hardtanh"), OpList{self, min_val, max_val}, + self.sizes()){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class HardtanhInPlaceNode : public Node { +public: + HardtanhInPlaceNode(Value self, Value min_val, Value max_val) + : Node(OpKind::Get("aten::hardtanh_"), OpList{self, min_val, max_val}, + self.sizes()){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class HardtanhBackwardNode : 
public Node { +public: + HardtanhBackwardNode(Value grad_output, Value self, Value min_val, + Value max_val) + : Node(OpKind::Get("aten::hardtanh_backward"), + OpList{grad_output, self, min_val, max_val}, self.sizes()){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class LogSoftmaxNode : public Node { +public: + LogSoftmaxNode(Value input, int64_t dim, bool half_to_float) + : Node(OpKind::Get("aten::_log_softmax"), + OpList{ + input, ir::Value(std::make_shared(dim)), + ir::Value(std::make_shared(half_to_float))}, + input.sizes()), + dim(dim), half_to_float(half_to_float) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + +private: + int64_t dim; + bool half_to_float; +}; + +class LogSoftmaxBackwardNode : public Node { +public: + LogSoftmaxBackwardNode(Value grad_output, Value output, int64_t dim, + Value input) + : Node(OpKind::Get("aten::_log_softmax_backward_data"), + OpList{grad_output, output, + ir::Value(std::make_shared(dim)), input}, + input.sizes()), + dim(dim) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + +private: + int64_t dim; +}; + +class MaxPool2dWithIndicesNode : public Node { +public: + MaxPool2dWithIndicesNode(Value input, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, bool ceil_mode) + : Node(OpKind::Get("aten::max_pool2d_with_indices"), + OpList{input, + ir::Value(std::make_shared(kernel_size)), + ir::Value(std::make_shared(stride)), + ir::Value(std::make_shared(padding)), + ir::Value(std::make_shared(dilation)), + ir::Value(std::make_shared(ceil_mode))}, + input.sizes()), + kernel_size(kernel_size.begin(), kernel_size.end()), + stride(stride.begin(), stride.end()), + padding(padding.begin(), padding.end()), + dilation(dilation.begin(), dilation.end()), ceil_mode(ceil_mode){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + + std::vector sizes() const override { + assert(0 && "Cannot call sizes() for multiple outputs"); + } + std::vector sizes(size_t i) const override; + +private: + std::vector kernel_size; + std::vector stride; + std::vector padding; + std::vector dilation; + bool ceil_mode; +}; + +class MaxPool2dWithIndicesBackwardNode : public Node { +public: + MaxPool2dWithIndicesBackwardNode(Value grad_output, Value input, + at::IntArrayRef kernel_size, + at::IntArrayRef stride, + at::IntArrayRef padding, + at::IntArrayRef dilation, bool ceil_mode, + Value indices) + : Node(OpKind::Get("aten::max_pool2d_with_indices_backward"), + OpList{grad_output, input, + ir::Value(std::make_shared(kernel_size)), + ir::Value(std::make_shared(stride)), + ir::Value(std::make_shared(padding)), + ir::Value(std::make_shared(dilation)), + ir::Value(std::make_shared(ceil_mode)), + indices}, + input.sizes()), + kernel_size(kernel_size.begin(), kernel_size.end()), + stride(stride.begin(), stride.end()), + padding(padding.begin(), padding.end()), + dilation(dilation.begin(), dilation.end()), ceil_mode(ceil_mode){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + +private: + std::vector kernel_size; + std::vector stride; + std::vector padding; + std::vector dilation; + bool ceil_mode; +}; + +class MeanNode : public Node { +public: + MeanNode(Value 
input, at::IntArrayRef dim, bool keepdim, + c10::optional dtype) + : Node(OpKind::Get("aten::mean"), + OpList{input, ir::Value(std::make_shared(dim)), + ir::Value(std::make_shared(keepdim))}, + input.sizes()), + dim(dim.begin(), dim.end()), keepdim(keepdim), dtype(dtype) {} + + MeanNode(Value input, c10::optional dtype) + : Node(OpKind::Get("aten::mean"), OpList{input}, input.sizes()), + dtype(dtype) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + + std::vector sizes() const override; + std::vector sizes(size_t i) const override { return sizes(); } + +private: + std::vector dim; + bool keepdim; + c10::optional dtype; +}; + +class MMNode : public Node { +public: + MMNode(Value input, Value mat2) + : Node(OpKind::Get("aten::mm"), OpList{input, mat2}, + std::vector{input.sizes()[0], mat2.sizes()[1]}){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class MulNode : public Node { +public: + MulNode(Value rhs, Value lhs) + : Node(OpKind::Get("aten::mul"), OpList{rhs, lhs}, rhs.sizes()){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class MulInPlaceNode : public Node { +public: + MulInPlaceNode(Value self, Value other) + : Node(OpKind::Get("aten::mul_"), OpList{self, other}, self.sizes()){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class NegNode : public Node { +public: + NegNode(Value input) + : Node(OpKind::Get("aten::neg"), OpList{input}, input.sizes()){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class NllLoss2dForwardNode : public Node { +public: + NllLoss2dForwardNode(Value self, Value target, Value weight, + int64_t reduction, int64_t ignore_index) + : Node( + OpKind::Get("aten::nll_loss2d_forward"), + OpList{self, target, weight, + ir::Value(std::make_shared(reduction)), + ir::Value(std::make_shared(ignore_index))}, + 1 /*target.sizes()*/), + reduction(reduction), ignore_index(ignore_index) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + +private: + int64_t reduction; + int64_t ignore_index; +}; + +class NllLoss2dBackwardNode : public Node { +public: + NllLoss2dBackwardNode(Value grad_output, Value self, Value target, + Value weight, int64_t reduction, int64_t ignore_index, + Value total_weight) + : Node(OpKind::Get("aten::nll_loss2d_backward"), + OpList{grad_output, self, target, weight, + ir::Value(std::make_shared(reduction)), + ir::Value(std::make_shared(ignore_index)), + total_weight}, + self.sizes()), + reduction(reduction), ignore_index(ignore_index) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + +private: + int64_t reduction; + int64_t ignore_index; +}; + +class NllLossForwardNode : public Node { +public: + NllLossForwardNode(Value self, Value target, Value weight, int64_t reduction, + int64_t ignore_index) + : Node( + OpKind::Get("aten::nll_loss_forward"), + OpList{self, target, weight, + ir::Value(std::make_shared(reduction)), + ir::Value(std::make_shared(ignore_index))}, + 1 /*target.sizes()*/), + reduction(reduction), ignore_index(ignore_index) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, 
mlir::MLIRContext &context, + std::map &symbolTable) override; + +private: + int64_t reduction; + int64_t ignore_index; +}; + +class NllLossBackwardNode : public Node { +public: + NllLossBackwardNode(Value grad_output, Value self, Value target, Value weight, + int64_t reduction, int64_t ignore_index, + Value total_weight) + : Node(OpKind::Get("aten::nll_loss_backward"), + OpList{grad_output, self, target, weight, + ir::Value(std::make_shared(reduction)), + ir::Value(std::make_shared(ignore_index)), + total_weight}, + self.sizes()), + reduction(reduction), ignore_index(ignore_index) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + +private: + int64_t reduction; + int64_t ignore_index; +}; + +class SumNode : public Node { +public: + SumNode(Value input, at::IntArrayRef dim, bool keepdim, + c10::optional dtype) + : Node(OpKind::Get("aten::sum"), + OpList{input, ir::Value(std::make_shared(dim)), + ir::Value(std::make_shared(keepdim))}, + input.sizes()), + dim(dim.begin(), dim.end()), keepdim(keepdim), dtype(dtype) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + + std::vector sizes() const override; + std::vector sizes(size_t i) const override { return sizes(); } + +private: + std::vector dim; + bool keepdim; + c10::optional dtype; +}; + +class ReLUNode : public Node { +public: + ReLUNode(Value input) + : Node(OpKind::Get("aten::relu"), OpList{input}, input.sizes()){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class ReLUInPlaceNode : public Node { +public: + ReLUInPlaceNode(Value input) + : Node(OpKind::Get("aten::relu_"), OpList{input}, input.sizes()){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class ThresholdBackwardNode : public Node { +public: + ThresholdBackwardNode(Value grad_output, Value input, Value threshold) + : Node(OpKind::Get("aten::threshold_backward"), + OpList{grad_output, input, threshold}, input.sizes()){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class TransposeNode : public Node { +public: + TransposeNode(Value input) + : Node(OpKind::Get("aten::t"), OpList{input}, + std::vector{input.sizes()[1], input.sizes()[0]}){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class SizeNode : public Node { +public: + SizeNode(Value input, int64_t dim) + : Node(OpKind::Get("aten::size"), + OpList{input, ir::Value(std::make_shared(dim))}, + 1), + dim(dim) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + +private: + int64_t dim; +}; + +class SqueezeNode : public Node { +public: + SqueezeNode(Value input, int64_t dim) + : Node(OpKind::Get("aten::squeeze"), + OpList{input, ir::Value(std::make_shared(dim))}, + input.sizes()), + dim(dim) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + + std::vector sizes() const override; + std::vector sizes(size_t i) const override { return sizes(); } + +private: + int64_t dim; +}; + +class SubNode : public Node { +public: + SubNode(Value rhs, Value lhs, Value alpha) + : Node(OpKind::Get("aten::sub"), OpList{rhs, lhs, 
alpha}, rhs.sizes()){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class SubInPlaceNode : public Node { +public: + SubInPlaceNode(Value self, Value other, Value alpha) + : Node(OpKind::Get("aten::sub_"), OpList{self, other, alpha}, + self.sizes()){}; + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; +}; + +class UnsqueezeNode : public Node { +public: + UnsqueezeNode(Value input, int64_t dim) + : Node(OpKind::Get("aten::unsqueeze"), + OpList{input, ir::Value(std::make_shared(dim))}, + input.sizes()), + dim(dim) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + + std::vector sizes() const override; + std::vector sizes(size_t i) const override { return sizes(); } + +private: + int64_t dim; +}; + +class ViewNode : public Node { +public: + ViewNode(Value input, at::IntArrayRef size) + : Node(OpKind::Get("aten::view"), + OpList{input, ir::Value(std::make_shared(size))}, + input.sizes()), + view_size(size.begin(), size.end()) {} + + mlir::Operation * + genMLIR(std::unique_ptr &builder, mlir::MLIRContext &context, + std::map &symbolTable) override; + + std::vector sizes() const override; + std::vector sizes(size_t i) const override { return sizes(); } + +private: + std::vector view_size; +}; + +class TorchDataNode : public Node { + +public: + TorchDataNode(at::Tensor tensor) + : Node(ir::OpKind::Get("aten::torch_data"), {}, tensor.sizes()), + tensor_(std::move(tensor)) {} + + at::Tensor tensor() { return tensor_; } + +private: + at::Tensor tensor_; +}; + +} // namespace ir +} // namespace torch_mlir diff --git a/frontends/pytorch/csrc/jit.cpp b/frontends/pytorch/csrc/jit.cpp new file mode 100644 index 000000000..801cafbb9 --- /dev/null +++ b/frontends/pytorch/csrc/jit.cpp @@ -0,0 +1,333 @@ +//===- jit.cpp --------------------------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. +// +//===----------------------------------------------------------------------===// + +// This file drives the generation and lowering of MLIR, followed by JIT +// compiling the resulting LLVM dialect. 
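+//
+// At a high level the flow implemented below is:
+//   1. LowerATenDialect(): run CSE, lower ATen ops to function calls
+//      (createATenLoweringPass / createReturnEliminationPass), then lower
+//      affine and SCF constructs to the standard/CFG dialects.
+//   2. LowerStdDialect(): lower std to the LLVM dialect, emitting C wrappers
+//      so the JITed entry point can be invoked through a void** argument pack.
+//   3. LowerAndRun(): create an mlir::ExecutionEngine (loading libaten_ops.so
+//      for the lowered runtime calls), marshal the at::Tensor arguments into
+//      memref descriptors, call the "_mlir_ciface_graph" entry point, and
+//      return the result tensor.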
+ +#include "npcomp/Dialect/ATen/ATenDialect.h" +#include "npcomp/Dialect/ATen/ATenPasses.h" + +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/JitRunner.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/Support/Debug.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" + +#include + +#include "ATen/ArrayRef.h" +namespace at { +template using ArrayRef = c10::ArrayRef; +} +#include "ATen/Tensor.h" +#include + +#include "jit.h" +#include "mlir_gen.h" +#include "tensor.h" +#include "torch_util.h" + +#define DEBUG_TYPE "torch_mlir" + +using namespace mlir; + +namespace torch_mlir { + +namespace { + +int LowerATenDialect(mlir::ModuleOp module) { + PassManager pm0(module.getContext()); + pm0.addPass(mlir::createCSEPass()); + + // Lower to function calls. + pm0.addPass(mlir::NPCOMP::aten::createATenLoweringPass()); + pm0.addPass(mlir::NPCOMP::aten::createReturnEliminationPass()); + + if (failed(pm0.run(module))) { + llvm::errs() << "aten to loops conversion failed "; + return 1; + } + + PassManager pm1(module.getContext()); + pm1.addPass(mlir::createLowerAffinePass()); + pm1.addPass(mlir::createLowerToCFGPass()); + pm1.addPass(mlir::createCSEPass()); + + if (failed(pm1.run(module))) { + llvm::errs() << "loops to std conversion failed "; + return 1; + } + + return 0; +} + +int LowerStdDialect(mlir::ModuleOp module) { + PassManager pm(module.getContext()); + + struct LowerToLLVMOptions options; + options.emitCWrappers = true; + LLVM_DEBUG(module.print(llvm::outs())); + + pm.addPass(mlir::createLowerToLLVMPass(options)); + pm.addPass(mlir::createCSEPass()); + + LLVM_DEBUG(module.print(llvm::outs())); + + if (failed(pm.run(module))) { + llvm::errs() << "std to llvm conversion failed "; + return 1; + } + + if (!module) + return 1; + return 0; +} + +template struct llvm_tensor_t { + T *d; + T *aligned; + size_t offset; + size_t shape[N]; + size_t stride[N]; +}; + +template void *setupArg(at::Tensor &t) { + llvm_tensor_t *arg = new llvm_tensor_t; + llvm_tensor_t **arg_storage = new llvm_tensor_t *; + *arg_storage = arg; + arg->d = arg->aligned = (T *)t.data_ptr(); + arg->offset = 0; + assert(t.dim() == N); + for (int j = 0; j < N; j++) { + arg->shape[j] = t.sizes()[j]; + arg->stride[j] = t.stride(j); + } + return (void *)arg_storage; +} + +at::Tensor LowerAndRun(mlir::ModuleOp module, + std::vector &arguments, const ir::Value &v, + mlir::MLIRContext &context) { + + LowerATenDialect(module); + LowerStdDialect(module); + + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + Optional jitCodeGenOptLevel = + llvm::CodeGenOpt::Level::Aggressive; + std::string libpath; + if (const char *path = std::getenv("TEST_BUILD_PATH")) { + libpath = path; + } + + std::vector sharedLibs{libpath + + "/frontends/pytorch/lib/libaten_ops.so"}; + llvm::errs() << "Loading " << sharedLibs[0] << 
"\n"; + + llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr); + + llvm::SmallVector libs(sharedLibs.begin(), + sharedLibs.end()); + auto expectedEngine = mlir::ExecutionEngine::create( + module, {}, jitCodeGenOptLevel, libs, false, false, false); + assert(expectedEngine && "no engine, cannot fly"); + + llvm::StringRef entryPoint("_mlir_ciface_graph"); + auto engine = std::move(*expectedEngine); + auto expectedFPtr = engine->lookup(entryPoint); + assert(expectedFPtr && "entryPoint missing"); + + void (*fptr)(void **) = *expectedFPtr; + + // this array holds pointers to the function arguments + void **args = (void **)malloc((arguments.size() + 1) * sizeof(void *)); + + // allocate and setup the function arguments + for (int i = 0, e = arguments.size(); i < e; i++) { + at::Tensor &t = arguments[i]; + auto dtype = t.dtype(); + int dim = t.dim(); + if (dim == 4) { + if (dtype == at::kFloat) + args[i] = setupArg(t); + else if (dtype == at::kLong) + args[i] = setupArg(t); + else + assert(0); + } else if (dim == 3) { + if (dtype == at::kFloat) + args[i] = setupArg(t); + else if (dtype == at::kLong) + args[i] = setupArg(t); + else + assert(0); + } else if (dim == 2) { + if (dtype == at::kFloat) + args[i] = setupArg(t); + else if (dtype == at::kLong) + args[i] = setupArg(t); + else + assert(0); + } else if (dim == 1) { + if (dtype == at::kFloat) + args[i] = setupArg(t); + else if (dtype == at::kLong) + args[i] = setupArg(t); + else + assert(0); + } else { + assert(0 && "unhandled dim"); + } + } + + // allocate the result tensors + // TODO: num results > 1 + at::Tensor result = util::Zeros(v.sizes(), at::kFloat); + if (result.dim() == 4) { + args[arguments.size()] = setupArg(result); + } else if (result.dim() == 3) { + args[arguments.size()] = setupArg(result); + } else if (result.dim() == 2) { + args[arguments.size()] = setupArg(result); + } else if (result.dim() == 1) { + args[arguments.size()] = setupArg(result); + } else { + assert(0 && "unhandled dim"); + } + + // call the JITed function + fptr(args); + + // free pointers to the results + // TODO: num results > 1 + if (result.dim() == 4) { + auto arg_storage = + static_cast **>(args[arguments.size()]); + auto arg = *arg_storage; + delete arg; + delete arg_storage; + } else if (result.dim() == 3) { + auto arg_storage = + static_cast **>(args[arguments.size()]); + auto arg = *arg_storage; + delete arg; + delete arg_storage; + } else if (result.dim() == 2) { + auto arg_storage = + static_cast **>(args[arguments.size()]); + auto arg = *arg_storage; + delete arg; + delete arg_storage; + } else if (result.dim() == 1) { + auto arg_storage = + static_cast **>(args[arguments.size()]); + auto arg = *arg_storage; + delete arg; + delete arg_storage; + } else { + assert(0 && "unhandled dim"); + } + + // free pointers to the arguments + for (int i = 0, e = arguments.size(); i < e; i++) { + at::Tensor &t = arguments[i]; + int dim = t.dim(); + if (dim == 4) { + auto arg_storage = static_cast **>(args[i]); + auto arg = *arg_storage; + delete arg; + delete arg_storage; + } else if (dim == 3) { + auto arg_storage = static_cast **>(args[i]); + auto arg = *arg_storage; + delete arg; + delete arg_storage; + } else if (dim == 2) { + auto arg_storage = static_cast **>(args[i]); + auto arg = *arg_storage; + delete arg; + delete arg_storage; + } else if (dim == 1) { + auto arg_storage = static_cast **>(args[i]); + auto arg = *arg_storage; + delete arg; + delete arg_storage; + } else { + assert(0 && "unhandled dim"); + } + } + + // free the array of void* ptrs + 
free(args); + + return result; +} + +at::Tensor JitAndRun(const ir::Value &v, mlir::MLIRContext &context) { + + // generate the MLIR + std::vector vs{v}; + auto mlir_gen = MLIRGen(context).genModule(vs); + mlir::OwningModuleRef module = std::move(std::get<0>(mlir_gen)); + std::vector arguments = std::move(std::get<1>(mlir_gen)); + + return LowerAndRun(module.get(), arguments, v, context); +} + +at::Tensor JitAndRun(const ir::Value &v) { + mlir::MLIRContext context; + return JitAndRun(v, context); +} + +at::Tensor Interpret(const ir::Value &v) { assert(0 && "unsupported"); } +} // anonymous namespace + +// FIXME: Why is this code here and not in tensor.cpp? +std::string MLIRTensor::GetMLIR() const { + + // generate the MLIR + mlir::MLIRContext context; + ir::Value ir_value = CurrentIrValue(); + if (!ir_value) + return ""; + + std::vector vs{ir_value}; + auto mlir_gen = MLIRGen(context).genModule(vs); + mlir::OwningModuleRef module = std::move(std::get<0>(mlir_gen)); + + std::string aten; + llvm::raw_string_ostream ss(aten); + module->print(ss); + return ss.str(); +} + +at::Tensor MLIRTensor::CompileAndRun() const { + return JitAndRun(CurrentIrValue()); +} + +} // namespace torch_mlir diff --git a/frontends/pytorch/csrc/jit.h b/frontends/pytorch/csrc/jit.h new file mode 100644 index 000000000..7da99ac26 --- /dev/null +++ b/frontends/pytorch/csrc/jit.h @@ -0,0 +1,16 @@ +//===- jit.h ----------------------------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace torch_mlir { +// namespace jit { + +// at::Tensor CompileAndRun(const MLIRTensor &tensor); +// at::Tensor JitAndRun(const ir::Value &v); +//} +} // namespace torch_mlir diff --git a/frontends/pytorch/csrc/mlir_gen.cpp b/frontends/pytorch/csrc/mlir_gen.cpp new file mode 100644 index 000000000..0e0df9871 --- /dev/null +++ b/frontends/pytorch/csrc/mlir_gen.cpp @@ -0,0 +1,207 @@ +//===- mlir_gen.cpp ---------------------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Verifier.h" + +#include "llvm/Support/Debug.h" + +#include "ATen/ArrayRef.h" +namespace at { +template using ArrayRef = c10::ArrayRef; +} +#include "ATen/Tensor.h" + +#include "ir.h" +#include "mlir_gen.h" + +#include +#include + +#define DEBUG_TYPE "torch_mlir" + +namespace torch_mlir { + +std::tuple> +MLIRGen::genModule(std::vector &v) { + // the module + module = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); + + auto fn = genFunction(v); + if (fn) { + module->push_back(fn); + if (failed(mlir::verify(*module))) { + emitError(mlir::UnknownLoc::get(&context), "module verification error"); + } + } + return std::make_tuple(std::move(module), arguments); +} + +mlir::Value MLIRGen::genValue(const ir::Value &v) { + + if (symbolTable.count(v)) + return symbolTable[v]; + + LLVM_DEBUG(llvm::dbgs() << "genValue node: " << v.node->op() << "\n"); + + ir::NodePtr node = v.node; + auto loc = mlir::UnknownLoc::get(&context); + + for (auto &operand : node->operands()) + genValue(operand); + + mlir::Value mlirValue = nullptr; + if (opTable.count(v.node)) { + mlirValue = opTable[v.node]->getResult(v.index); + } else { + mlir::Operation *mlirOp = node->genMLIR(builder, context, symbolTable); + opTable.insert({v.node, mlirOp}); + assert(mlirOp && "failed to generate mlir op"); + mlirValue = mlirOp->getResult(v.index); + } + + declareSymbol(v, mlirValue); + + return mlirValue; +} + +// generate function parameters for the IR rooted at v +void MLIRGen::genParameters(const ir::Value &v, std::set &visited) { + ir::NodePtr node = v.node; + if (visited.count(v)) + return; + visited.insert(v); + for (const ir::Value &operand : node->operands()) { + // if the operand is a leaf + if (operand.node->op() == ir::OpKind::Get("aten::torch_data")) { + parameters.push_back(operand); + } else { + genParameters(operand, visited); + } + } +} + +mlir::FuncOp MLIRGen::genFunction(std::vector &vs) { + + auto loc = mlir::UnknownLoc::get(&context); + + auto gen_tensor_ty = [&](const ir::Value &v) { + auto shape = v.sizes(); + auto tdn = dynamic_cast(v.node.get()); + mlir::Type elemTy; + if (tdn) { + auto dtype = tdn->tensor().dtype(); + if (dtype == at::kFloat) + elemTy = mlir::FloatType::getF32(&context); + else if (dtype == at::kDouble) + elemTy = mlir::FloatType::getF64(&context); + else if (dtype == at::kLong) + elemTy = mlir::IntegerType::get(64, &context); + else if (dtype == at::kInt) + elemTy = mlir::IntegerType::get(32, &context); + else if (dtype == at::kShort) + elemTy = mlir::IntegerType::get(16, &context); + else if (dtype == at::kChar || dtype == at::kByte) + elemTy = mlir::IntegerType::get(8, &context); + else { + std::cout << tdn->tensor().dtype() << "\n"; + assert(0 && "bad type"); + } + } else { + elemTy = mlir::FloatType::getF32(&context); + } + return mlir::RankedTensorType::get(shape, elemTy); + }; + + std::set visited; + for (auto &v : vs) + genParameters(v, visited); + + std::map parameter_map; + std::vector unique_parameters; + + for (const ir::Value &p : parameters) { + bool found = false; + for (const ir::Value &q : unique_parameters) { + if (p.node->op() == ir::OpKind::Get("aten::torch_data") && + q.node->op() == 
ir::OpKind::Get("aten::torch_data")) { + auto &ptd = *dynamic_cast(p.node.get()); + auto &qtd = *dynamic_cast(q.node.get()); + if (ptd.tensor().is_same(qtd.tensor())) { + found = true; + parameter_map.insert({p, q}); + break; + } + } + } + if (!found) { + unique_parameters.push_back(p); + } + } + + // collect the argument types and tensors + std::vector arg_types; + for (const ir::Value &p : unique_parameters) { + // tensor type for the function signature + arg_types.push_back(gen_tensor_ty(p)); + + // tensor itself for actually calling the graph + auto tdn = dynamic_cast(p.node.get()); + arguments.push_back(tdn->tensor()); + } + + // construct return type + std::vector ret_types; + for (auto &v : vs) + ret_types.push_back(gen_tensor_ty(v)); + + // create the function type and the function itself + auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); + auto function = + mlir::FuncOp::create(loc, "graph", func_type, /* attrs = */ {}); + + // entry + auto &entryBlock = *function.addEntryBlock(); + + // Declare all the function arguments in the symbol table. + for (const auto &i : + llvm::zip(unique_parameters, entryBlock.getArguments())) { + declareSymbol(std::get<0>(i), std::get<1>(i)); + } + // Declare all the duplicates from the original + // parameter list in the symbol table + for (auto &k_v : parameter_map) { + assert(symbolTable.count(k_v.second)); + declareSymbol(k_v.first, symbolTable[k_v.second]); + } + + builder = std::make_unique(function.getBody()); + + std::vector rets; + for (auto &v : vs) + rets.push_back(genValue(v)); + + builder->create(loc, rets); + return function; +} + +bool MLIRGen::declareSymbol(const ir::Value &irValue, mlir::Value mlirValue) { + if (symbolTable.count(irValue)) { + return false; + } + symbolTable.insert({irValue, mlirValue}); + return true; +} + +} // namespace torch_mlir diff --git a/frontends/pytorch/csrc/mlir_gen.h b/frontends/pytorch/csrc/mlir_gen.h new file mode 100644 index 000000000..584604593 --- /dev/null +++ b/frontends/pytorch/csrc/mlir_gen.h @@ -0,0 +1,45 @@ +//===- mlir_gen.h -----------------------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "mlir/IR/MLIRContext.h" + +#include "ir.h" + +namespace torch_mlir { + +/// This class generates MLIR from a pytorch graph +class MLIRGen { + +public: + MLIRGen(mlir::MLIRContext &context) : context(context){}; + + // Generate an MLIR model that computes the given outputs. 
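+  // The result is the new module paired with the at::Tensor data backing the
+  // generated function's arguments (one entry per unique aten::torch_data
+  // leaf reachable from the requested outputs, in parameter order).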
+ std::tuple> + genModule(std::vector &v); + +private: + mlir::Value genValue(const ir::Value &v); + + void genParameters(const ir::Value &v, std::set &visited); + + mlir::FuncOp genFunction(std::vector &v); + + bool declareSymbol(const ir::Value &irValue, mlir::Value mlirValue); + +private: + mlir::MLIRContext &context; + mlir::OwningModuleRef module; + std::unique_ptr builder; + std::map symbolTable; + std::map opTable; + std::vector parameters; + std::vector arguments; +}; + +} // namespace torch_mlir diff --git a/frontends/pytorch/csrc/tensor.cpp b/frontends/pytorch/csrc/tensor.cpp new file mode 100644 index 000000000..df9451c94 --- /dev/null +++ b/frontends/pytorch/csrc/tensor.cpp @@ -0,0 +1,613 @@ +//===- tensor.cpp -----------------------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/Debug.h" + +#include "ATen/ArrayRef.h" +namespace at { +template using ArrayRef = c10::ArrayRef; +} +#include "ATen/Tensor.h" + +#include "jit.h" +#include "tensor.h" + +#include + +#define DEBUG_TYPE "torch_mlir" + +namespace torch_mlir { + +MLIRTensor MLIRTensor::Create(const at::Tensor &tensor, const Device &device) { + assert(tensor.device().type() == at::kCPU); + MLIRTensor device_tensor(tensor, device); + return device_tensor; +} + +MLIRTensor +MLIRTensor::Create(ir::Value ir_value, const Device &device, + c10::optional logical_element_type) { + MLIRTensor device_tensor(std::move(ir_value), device, logical_element_type); + return device_tensor; +} + +MLIRTensor::MLIRTensor(const at::Tensor &tensor, const Device &device) + : data_(std::make_shared(tensor, device)) {} + +MLIRTensor::MLIRTensor(ir::Value ir_value, const Device &device, + c10::optional logical_element_type) + : data_(std::make_shared(std::move(ir_value), device, + logical_element_type)) {} + +MLIRTensor::Data *MLIRTensor::data() const { + assert(data_ != nullptr && "Trying to access null data"); + return data_.get(); +} + +at::ScalarType MLIRTensor::dtype() const { + return data()->logical_element_type ? 
*data()->logical_element_type + : at::ScalarType::Float; +} + +const Device &MLIRTensor::GetDevice() const { return data()->device; } + +uint64_t MLIRTensor::GetNextTensorId() { + static std::atomic *id_generator = new std::atomic(1); + return id_generator->fetch_add(1); +} + +void MLIRTensor::SetTensorData(at::Tensor tensor_data) { + data()->tensor_data = std::move(tensor_data); +} + +ir::Value MLIRTensor::GetIrValue() const { + ir::Value ir_value = CurrentIrValue(); + if (ir_value) { + return ir_value; + } + c10::optional tensor_data = CurrentTensorData(); + if (tensor_data) { + at::Tensor tensor = *tensor_data; + if (!tensor.dim()) { + auto dtype = tensor.dtype(); + if (dtype == at::kFloat) { + auto d = tensor.data_ptr(); + return ir::Value(std::make_shared(d[0])); + } else if (dtype == at::kDouble) { + auto d = tensor.data_ptr(); + return ir::Value(std::make_shared(d[0])); + } else if (dtype == at::kLong) { + auto d = tensor.data_ptr(); + return ir::Value(std::make_shared(d[0])); + } else if (dtype == at::kInt) { + auto d = tensor.data_ptr(); + return ir::Value(std::make_shared(d[0])); + } else if (dtype == at::kShort) { + auto d = tensor.data_ptr(); + return ir::Value(std::make_shared(d[0])); + } else if (dtype == at::kChar || dtype == at::kByte) { + auto d = tensor.data_ptr(); + return ir::Value(std::make_shared(d[0])); + } + // fall through to TorchDataNode below + } + return ir::Value(std::make_shared(*tensor_data)); + } + assert(0 && "Could not create ir value from leaf tensor"); + return ir::Value(); +} + +ir::Value MLIRTensor::CurrentIrValue() const { return data()->ir_value; } + +void MLIRTensor::SetIrValue(ir::Value ir_value) { + data()->generation += 1; + data()->ir_value = std::move(ir_value); +} + +c10::optional MLIRTensor::CurrentTensorData() const { + return data()->tensor_data; +} + +void MLIRTensor::SetTensor(at::Tensor tensor) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + SetTensorData(tensor); + data()->generation += 1; +} + +at::Tensor MLIRTensor::ToTensor() const { + c10::optional tensor_data = CurrentTensorData(); + if (!tensor_data) + tensor_data = CompileAndRun(); + assert(tensor_data); + return *tensor_data; +} + +void MLIRTensor::ShallowCopyTo(MLIRTensor *dest) const { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + + auto data = CurrentTensorData(); + if (data) + dest->SetTensor(*data); + else + dest->SetIrValue(CurrentIrValue()); + + dest->SetScalarType(dtype()); + assert(GetDevice() == dest->GetDevice()); +} + +void MLIRTensor::SetScalarType( + c10::optional logical_element_type) { + data()->logical_element_type = logical_element_type; +} + +std::vector MLIRTensor::sizes() const { + if (data()->ir_value) { + return data()->ir_value.sizes(); + } + assert(data()->tensor_data && "tensor has no shape information"); + if (data()->tensor_data) { + auto s = data()->tensor_data->sizes(); + return {s.begin(), s.end()}; + } + return {}; +} + +std::vector MLIRTensor::strides() const { + if (data()->ir_value) { + return data()->ir_value.strides(); + } + assert(data()->tensor_data && "tensor has no shape information"); + if (data()->tensor_data) { + auto s = data()->tensor_data->strides(); + return {s.begin(), s.end()}; + } + return {}; +} + +MLIRTensor MLIRTensor::CreateFrom(ir::Value ir_value) const { + return Create(std::move(ir_value), GetDevice(), dtype()); +} + +//////////////////////////////////////////// +// aten tensor methods +//////////////////////////////////////////// + +MLIRTensor 
MLIRTensor::_adaptive_avg_pool2d(const MLIRTensor &self, + at::IntArrayRef output_size) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + self.GetIrValue(), output_size); + return self.CreateFrom(node); +} + +MLIRTensor +MLIRTensor::_adaptive_avg_pool2d_backward(const MLIRTensor &grad_output, + const MLIRTensor &self) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = + std::make_shared( + grad_output.GetIrValue(), self.GetIrValue()); + return self.CreateFrom(node); +} + +MLIRTensor MLIRTensor::add(const MLIRTensor &self, const MLIRTensor &other, + at::Scalar alpha) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + self.GetIrValue(), other.GetIrValue(), + ir::Value(std::make_shared(alpha))); + return self.CreateFrom(node); +} + +MLIRTensor MLIRTensor::add_(MLIRTensor &self, const MLIRTensor &other, + at::Scalar alpha) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + self.GetIrValue(), other.GetIrValue(), + ir::Value(std::make_shared(alpha))); + return self.CreateFrom(node); +} + +MLIRTensor MLIRTensor::addmm(const MLIRTensor &input, const MLIRTensor &mat1, + const MLIRTensor &mat2, at::Scalar beta, + at::Scalar alpha) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + input.GetIrValue(), mat1.GetIrValue(), mat2.GetIrValue(), + ir::Value(std::make_shared(beta)), + ir::Value(std::make_shared(alpha))); + return input.CreateFrom(node); +} + +MLIRTensor MLIRTensor::as_strided(const MLIRTensor &input, at::IntArrayRef size, + at::IntArrayRef stride, + c10::optional storage_offset) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + input.GetIrValue(), size, stride, storage_offset); + return input.CreateFrom(node); +} + +MLIRTensor MLIRTensor::clone(const MLIRTensor &input) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + return MLIRTensor::Create(std::move(input.ToTensor()), input.GetDevice()); +} + +MLIRTensor MLIRTensor::convolution( + const MLIRTensor &input, const MLIRTensor &weight, const MLIRTensor &bias, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + bool transposed, at::IntArrayRef output_padding, int64_t groups) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + input.GetIrValue(), weight.GetIrValue(), bias.GetIrValue(), stride, + padding, dilation, transposed, output_padding, groups); + return input.CreateFrom(node); +} + +std::tuple MLIRTensor::convolution_backward( + const MLIRTensor &grad_output, const MLIRTensor &input, + const MLIRTensor &weight, at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, + int64_t groups, std::array output_mask) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + grad_output.GetIrValue(), input.GetIrValue(), weight.GetIrValue(), stride, + padding, dilation, transposed, output_padding, groups /*, output_mask*/); + auto result0 = input.CreateFrom(ir::Value(node, 0)); + auto result1 = input.CreateFrom(ir::Value(node, 1)); + auto result2 = input.CreateFrom(ir::Value(node, 2)); + return std::make_tuple(result0, result1, result2); +} + +void 
MLIRTensor::copy_(MLIRTensor &self, MLIRTensor &src) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + src.ShallowCopyTo(&self); +} + +MLIRTensor MLIRTensor::div(const MLIRTensor &self, at::Scalar other) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + self.GetIrValue(), ir::Value(std::make_shared(other))); + return self.CreateFrom(node); +} + +MLIRTensor MLIRTensor::div(const MLIRTensor &self, const MLIRTensor &other) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = + std::make_shared(self.GetIrValue(), other.GetIrValue()); + return self.CreateFrom(node); +} + +MLIRTensor MLIRTensor::div_(MLIRTensor &self, const MLIRTensor &other) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + self.GetIrValue(), other.GetIrValue()); + return self.CreateFrom(node); +} + +MLIRTensor MLIRTensor::expand(const MLIRTensor &self, at::IntArrayRef size, + bool implicit) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = + std::make_shared(self.GetIrValue(), size, implicit); + return self.CreateFrom(node); +} + +MLIRTensor MLIRTensor::gather(const MLIRTensor &self, int64_t dim, + const MLIRTensor &index, bool sparse_grad) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + self.GetIrValue(), dim, index.GetIrValue(), sparse_grad); + return self.CreateFrom(node); +} + +MLIRTensor MLIRTensor::hardtanh(const MLIRTensor &self, at::Scalar min_val, + at::Scalar max_val) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + self.GetIrValue(), ir::Value(std::make_shared(min_val)), + ir::Value(std::make_shared(max_val))); + return self.CreateFrom(node); +} + +MLIRTensor MLIRTensor::hardtanh_(MLIRTensor &self, at::Scalar min_val, + at::Scalar max_val) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + self.GetIrValue(), ir::Value(std::make_shared(min_val)), + ir::Value(std::make_shared(max_val))); + return self.CreateFrom(node); +} + +MLIRTensor MLIRTensor::hardtanh_backward(const MLIRTensor &grad_output, + const MLIRTensor &self, + at::Scalar min_val, + at::Scalar max_val) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + grad_output.GetIrValue(), self.GetIrValue(), + ir::Value(std::make_shared(min_val)), + ir::Value(std::make_shared(max_val))); + return self.CreateFrom(node); +} + +MLIRTensor MLIRTensor::_log_softmax(const MLIRTensor &input, int64_t dim, + bool half_to_float) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + input.GetIrValue(), dim, half_to_float); + return input.CreateFrom(node); +} + +MLIRTensor MLIRTensor::_log_softmax_backward_data(const MLIRTensor &grad_output, + const MLIRTensor &output, + int64_t dim, + const MLIRTensor &input) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + grad_output.GetIrValue(), output.GetIrValue(), dim, input.GetIrValue()); + return input.CreateFrom(node); +} + +std::tuple MLIRTensor::max_pool2d_with_indices( + const MLIRTensor &input, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + bool ceil_mode) { + 
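+  // This op produces two results (the pooled output and the index tensor):
+  // one IR node is built and each result value is wrapped separately below.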
LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = + std::make_shared( + input.GetIrValue(), kernel_size, stride, padding, dilation, + ceil_mode); + auto result0 = input.CreateFrom(ir::Value(node, 0)); + auto result1 = input.CreateFrom(ir::Value(node, 1)); + return std::make_tuple(result0, result1); +} + +MLIRTensor MLIRTensor::max_pool2d_with_indices_backward( + const MLIRTensor &grad_output, const MLIRTensor &input, + at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, + const MLIRTensor &indices) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = + std::make_shared( + grad_output.GetIrValue(), input.GetIrValue(), kernel_size, stride, + padding, dilation, ceil_mode, indices.GetIrValue()); + return input.CreateFrom(node); +} + +MLIRTensor MLIRTensor::mean(const MLIRTensor &input, + c10::optional dtype) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = + std::make_shared(input.GetIrValue(), dtype); + return input.CreateFrom(node); +} + +MLIRTensor MLIRTensor::mean(const MLIRTensor &input, at::IntArrayRef dim, + bool keepdim, c10::optional dtype) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = + std::make_shared(input.GetIrValue(), dim, keepdim, dtype); + return input.CreateFrom(node); +} + +MLIRTensor MLIRTensor::mm(const MLIRTensor &input, const MLIRTensor &mat1) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = + std::make_shared(input.GetIrValue(), mat1.GetIrValue()); + return input.CreateFrom(node); +} + +MLIRTensor MLIRTensor::mul(const MLIRTensor &self, const MLIRTensor &other) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = + std::make_shared(self.GetIrValue(), other.GetIrValue()); + return self.CreateFrom(node); +} + +MLIRTensor MLIRTensor::mul_(MLIRTensor &self, const MLIRTensor &other) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + self.GetIrValue(), other.GetIrValue()); + return self.CreateFrom(node); +} + +std::tuple MLIRTensor::native_batch_norm( + const MLIRTensor &self, const MLIRTensor &weight, const MLIRTensor &bias, + const MLIRTensor &running_mean, const MLIRTensor &running_var, + bool training, double momentum, double eps) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + self.GetIrValue(), weight.GetIrValue(), bias.GetIrValue(), + running_mean.GetIrValue(), running_var.GetIrValue(), training, momentum, + eps); + auto result0 = self.CreateFrom(ir::Value(node, 0)); + auto result1 = self.CreateFrom(ir::Value(node, 1)); + auto result2 = self.CreateFrom(ir::Value(node, 2)); + return std::make_tuple(result0, result1, result2); +} + +std::tuple +MLIRTensor::native_batch_norm_backward( + const MLIRTensor &grad_out, const MLIRTensor &input, + const MLIRTensor &weight, const MLIRTensor &running_mean, + const MLIRTensor &running_var, const MLIRTensor &save_mean, + const MLIRTensor &save_invstd, bool train, double eps, + std::array output_mask) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + grad_out.GetIrValue(), input.GetIrValue(), weight.GetIrValue(), + running_mean.GetIrValue(), running_var.GetIrValue(), + save_mean.GetIrValue(), save_invstd.GetIrValue(), train, eps, 
+ output_mask); + auto result0 = input.CreateFrom(ir::Value(node, 0)); + auto result1 = input.CreateFrom(ir::Value(node, 1)); + auto result2 = input.CreateFrom(ir::Value(node, 2)); + return std::make_tuple(result0, result1, result2); +} + +MLIRTensor MLIRTensor::neg(const MLIRTensor &input) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = + std::make_shared(input.GetIrValue()); + return input.CreateFrom(node); +} + +std::tuple +MLIRTensor::nll_loss2d_forward(const MLIRTensor &self, const MLIRTensor &target, + const MLIRTensor &weight, int64_t reduction, + int64_t ignore_index) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + self.GetIrValue(), target.GetIrValue(), weight.GetIrValue(), reduction, + ignore_index); + auto result0 = self.CreateFrom(ir::Value(node, 0)); + auto result1 = self.CreateFrom(ir::Value(node, 1)); + return std::make_tuple(result0, result1); +} + +MLIRTensor MLIRTensor::nll_loss2d_backward( + const MLIRTensor &grad_output, const MLIRTensor &self, + const MLIRTensor &target, const MLIRTensor &weight, int64_t reduction, + int64_t ignore_index, const MLIRTensor &total_weight) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + grad_output.GetIrValue(), self.GetIrValue(), target.GetIrValue(), + weight.GetIrValue(), reduction, ignore_index, total_weight.GetIrValue()); + return self.CreateFrom(node); +} + +std::tuple +MLIRTensor::nll_loss_forward(const MLIRTensor &self, const MLIRTensor &target, + const MLIRTensor &weight, int64_t reduction, + int64_t ignore_index) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + self.GetIrValue(), target.GetIrValue(), weight.GetIrValue(), reduction, + ignore_index); + auto result0 = self.CreateFrom(ir::Value(node, 0)); + auto result1 = self.CreateFrom(ir::Value(node, 1)); + return std::make_tuple(result0, result1); +} + +MLIRTensor MLIRTensor::nll_loss_backward( + const MLIRTensor &grad_output, const MLIRTensor &self, + const MLIRTensor &target, const MLIRTensor &weight, int64_t reduction, + int64_t ignore_index, const MLIRTensor &total_weight) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + grad_output.GetIrValue(), self.GetIrValue(), target.GetIrValue(), + weight.GetIrValue(), reduction, ignore_index, total_weight.GetIrValue()); + return self.CreateFrom(node); +} + +MLIRTensor MLIRTensor::sum(const MLIRTensor &input, at::IntArrayRef dim, + bool keepdim, c10::optional dtype) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = + std::make_shared(input.GetIrValue(), dim, keepdim, dtype); + return input.CreateFrom(node); +} + +MLIRTensor MLIRTensor::relu(const MLIRTensor &input) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = + std::make_shared(input.GetIrValue()); + return input.CreateFrom(node); +} + +MLIRTensor MLIRTensor::relu_(MLIRTensor &input) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = + std::make_shared(input.GetIrValue()); + return input.CreateFrom(node); +} + +MLIRTensor MLIRTensor::size(const MLIRTensor &input, int64_t dim) { + assert(0); + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = + std::make_shared(input.GetIrValue(), dim); + return input.CreateFrom(node); +} 
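+
+// Illustrative sketch only (the tensor names and `device` below are
+// placeholders, not part of this file): each static method here records an
+// ir::Node and returns a new lazy MLIRTensor, so a small graph can be staged
+// and then compiled on demand, roughly as:
+//
+//   MLIRTensor a = MLIRTensor::Create(lhs_data, device);  // leaf tensor
+//   MLIRTensor b = MLIRTensor::Create(rhs_data, device);  // leaf tensor
+//   MLIRTensor c = MLIRTensor::mm(a, b);                  // aten::mm node
+//   MLIRTensor d = MLIRTensor::relu(c);                   // aten::relu node
+//   at::Tensor out = d.ToTensor(); // no cached data, so CompileAndRun() JITs
+//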
+ +MLIRTensor MLIRTensor::squeeze(const MLIRTensor &input, int64_t dim) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = + std::make_shared(input.GetIrValue(), dim); + return input.CreateFrom(node); +} + +MLIRTensor MLIRTensor::sub(const MLIRTensor &self, const MLIRTensor &other, + at::Scalar alpha) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + self.GetIrValue(), other.GetIrValue(), + ir::Value(std::make_shared(alpha))); + return self.CreateFrom(node); +} + +MLIRTensor MLIRTensor::sub_(MLIRTensor &self, const MLIRTensor &other, + at::Scalar alpha) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + self.GetIrValue(), other.GetIrValue(), + ir::Value(std::make_shared(alpha))); + return self.CreateFrom(node); +} + +MLIRTensor MLIRTensor::t(const MLIRTensor &input) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = + std::make_shared(input.GetIrValue()); + return input.CreateFrom(node); +} + +MLIRTensor MLIRTensor::threshold_backward(const MLIRTensor &grad_output, + const MLIRTensor &input, + at::Scalar threshold) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = std::make_shared( + grad_output.GetIrValue(), input.GetIrValue(), + ir::Value(std::make_shared(threshold))); + return input.CreateFrom(node); +} + +MLIRTensor MLIRTensor::to(MLIRTensor &input, c10::optional device, + c10::optional scalar_type) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + if (!device) { + device = input.GetDevice(); + } + if (!scalar_type) { + scalar_type = input.dtype(); + } + + MLIRTensor new_tensor = Create(input.ToTensor(), *device); + + if (input.dtype() != *scalar_type) { + new_tensor.SetScalarType(*scalar_type); + } + return new_tensor; +} + +MLIRTensor MLIRTensor::unsqueeze(const MLIRTensor &input, int64_t dim) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = + std::make_shared(input.GetIrValue(), dim); + return input.CreateFrom(node); +} + +MLIRTensor MLIRTensor::view(const MLIRTensor &input, at::IntArrayRef size) { + LLVM_DEBUG(llvm::dbgs() << "MLIRTensor::" << __func__ << "\n"); + std::shared_ptr node = + std::make_shared(input.GetIrValue(), size); + return input.CreateFrom(node); +} + +} // namespace torch_mlir diff --git a/frontends/pytorch/csrc/tensor.h b/frontends/pytorch/csrc/tensor.h new file mode 100644 index 000000000..244c524e8 --- /dev/null +++ b/frontends/pytorch/csrc/tensor.h @@ -0,0 +1,275 @@ +//===- tensor.h -------------------------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. 
+// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "device.h" +#include "ir.h" + +#include + +#include +#include + +namespace torch_mlir { + +class MLIRTensor { + struct Data; + +public: + static MLIRTensor Create(const at::Tensor &tensor, const Device &device); + static MLIRTensor Create(ir::Value ir_value, const Device &device, + c10::optional logical_element_type); + + MLIRTensor() = default; + + bool is_null() const { return data_ptr() == nullptr; } + + void ShallowCopyTo(MLIRTensor *dest) const; + + void SetTensor(at::Tensor tensor); + void SetIrValue(ir::Value ir_value); + + at::ScalarType dtype() const; + + // Set logical_element_type which is visible to upstream PyTorch. + void SetScalarType(c10::optional logical_element_type); + + std::vector sizes() const; + std::vector strides() const; + + at::Tensor ToTensor() const; + + const Device &GetDevice() const; + + size_t generation() const { return data()->generation; } + + std::string GetMLIR() const; + + // Retrieves the IR Node representing this MLIRTensor. One will be created if + // missing. Note that although this is a const API, it actually changes the + // internal state of the object. + ir::Value GetIrValue() const; + + at::Tensor CompileAndRun() const; + + uint64_t id() const { return data()->unique_id; } + +private: + struct Data { + Data(at::Tensor tensor_data, const Device &device) + : logical_element_type(tensor_data.scalar_type()), + tensor_data(std::move(tensor_data)), device(device), + unique_id(GetNextTensorId()) {} + + Data(ir::Value ir_value, const Device &device, + c10::optional logical_element_type) + : logical_element_type(logical_element_type), + ir_value(std::move(ir_value)), device(device), + unique_id(GetNextTensorId()) {} + + ~Data(){}; + + c10::optional logical_element_type; + c10::optional tensor_data; + ir::Value ir_value; + + const Device device; + const uint64_t unique_id = 0; + size_t generation = 1; + }; + + MLIRTensor(const at::Tensor &tensor, const Device &device); + + MLIRTensor(ir::Value ir_value, const Device &device, + c10::optional logical_element_type = c10::nullopt); + + void SetTensorData(at::Tensor tensor_data); + + c10::optional CurrentTensorData() const; + + // Retrieves the current IR Node, or nullptr in case no active IR Node is + // available. + ir::Value CurrentIrValue() const; + + Data *data() const; + + std::shared_ptr data_ptr() const { return data_; } + + MLIRTensor CreateFrom(ir::Value ir_value) const; + + static uint64_t GetNextTensorId(); + + std::shared_ptr data_; + + ////////////////////////////////////////////////////////////////////////////// + // ATEN operators follows here, listed in alphabetical order. 
+ ////////////////////////////////////////////////////////////////////////////// +public: + static MLIRTensor _adaptive_avg_pool2d(const MLIRTensor &self, + at::IntArrayRef output_size); + + static MLIRTensor _adaptive_avg_pool2d_backward(const MLIRTensor &grad_output, + const MLIRTensor &self); + + static MLIRTensor add(const MLIRTensor &input, const MLIRTensor &other, + at::Scalar alpha); + + static MLIRTensor add_(MLIRTensor &input, const MLIRTensor &other, + at::Scalar alpha); + + static MLIRTensor addmm(const MLIRTensor &input, const MLIRTensor &mat1, + const MLIRTensor &mat2, at::Scalar beta, + at::Scalar alpha); + + static MLIRTensor as_strided(const MLIRTensor &self, at::IntArrayRef size, + at::IntArrayRef stride, + c10::optional storage_offset); + + static MLIRTensor clone(const MLIRTensor &self); + + static MLIRTensor convolution(const MLIRTensor &input, + const MLIRTensor &weight, + const MLIRTensor &bias, at::IntArrayRef stride, + at::IntArrayRef padding, + at::IntArrayRef dilation, bool transposed, + at::IntArrayRef output_padding, int64_t groups); + + static std::tuple + convolution_backward(const MLIRTensor &grad_output, const MLIRTensor &input, + const MLIRTensor &weight, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, + bool transposed, at::IntArrayRef output_padding, + int64_t groups, std::array output_mask); + + static void copy_(MLIRTensor &input, MLIRTensor &src); + + static MLIRTensor div(const MLIRTensor &self, at::Scalar other); + + static MLIRTensor div(const MLIRTensor &self, const MLIRTensor &other); + + static MLIRTensor div_(MLIRTensor &self, const MLIRTensor &other); + + static MLIRTensor expand(const MLIRTensor &self, at::IntArrayRef size, + bool implicit); + + static MLIRTensor gather(const MLIRTensor &self, int64_t dim, + const MLIRTensor &index, bool sparse_grad); + + static MLIRTensor hardtanh(const MLIRTensor &self, at::Scalar min_val, + at::Scalar max_val); + + static MLIRTensor hardtanh_(MLIRTensor &self, at::Scalar min_val, + at::Scalar max_val); + + static MLIRTensor hardtanh_backward(const MLIRTensor &grad_output, + const MLIRTensor &self, + at::Scalar min_val, at::Scalar max_val); + + static MLIRTensor _log_softmax(const MLIRTensor &input, int64_t dim, + bool half_to_float); + + static MLIRTensor _log_softmax_backward_data(const MLIRTensor &grad_output, + const MLIRTensor &output, + int64_t dim, + const MLIRTensor &self); + + static std::tuple + max_pool2d_with_indices(const MLIRTensor &input, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, bool ceil_mode); + + static MLIRTensor max_pool2d_with_indices_backward( + const MLIRTensor &grad_output, const MLIRTensor &self, + at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, + const MLIRTensor &indices); + + static MLIRTensor mean(const MLIRTensor &input, + c10::optional dtype); + + static MLIRTensor mean(const MLIRTensor &input, at::IntArrayRef dim, + bool keepdim, c10::optional dtype); + + static MLIRTensor mm(const MLIRTensor &input, const MLIRTensor &mat1); + + static MLIRTensor mul(const MLIRTensor &self, const MLIRTensor &other); + + static MLIRTensor mul_(MLIRTensor &self, const MLIRTensor &other); + + static std::tuple + native_batch_norm(const MLIRTensor &input, const MLIRTensor &weight, + const MLIRTensor &bias, const MLIRTensor &running_mean, + const MLIRTensor &running_var, bool training, + double momentum, double eps); + + static 
std::tuple + native_batch_norm_backward(const MLIRTensor &grad_out, + const MLIRTensor &input, const MLIRTensor &weight, + const MLIRTensor &running_mean, + const MLIRTensor &running_var, + const MLIRTensor &save_mean, + const MLIRTensor &save_invstd, bool train, + double eps, std::array output_mask); + + static MLIRTensor neg(const MLIRTensor &input); + + static std::tuple + nll_loss2d_forward(const MLIRTensor &self, const MLIRTensor &target, + const MLIRTensor &weight, int64_t reduction, + int64_t ignore_index); + + static MLIRTensor nll_loss2d_backward(const MLIRTensor &grad_output, + const MLIRTensor &self, + const MLIRTensor &target, + const MLIRTensor &weight, + int64_t reduction, int64_t ignore_index, + const MLIRTensor &total_weight); + + static std::tuple + nll_loss_forward(const MLIRTensor &self, const MLIRTensor &target, + const MLIRTensor &weight, int64_t reduction, + int64_t ignore_index); + + static MLIRTensor nll_loss_backward(const MLIRTensor &grad_output, + const MLIRTensor &self, + const MLIRTensor &target, + const MLIRTensor &weight, + int64_t reduction, int64_t ignore_index, + const MLIRTensor &total_weight); + + static MLIRTensor size(const MLIRTensor &self, int64_t dim); + + static MLIRTensor squeeze(const MLIRTensor &self, int64_t dim); + + static MLIRTensor sub(const MLIRTensor &input, const MLIRTensor &other, + at::Scalar alpha); + + static MLIRTensor sub_(MLIRTensor &input, const MLIRTensor &other, + at::Scalar alpha); + + static MLIRTensor sum(const MLIRTensor &self, at::IntArrayRef dim, + bool keepdim, c10::optional dtype); + + static MLIRTensor relu(const MLIRTensor &input); + + static MLIRTensor relu_(MLIRTensor &input); + + static MLIRTensor t(const MLIRTensor &input); + + static MLIRTensor threshold_backward(const MLIRTensor &grad_output, + const MLIRTensor &self, + at::Scalar threshold); + + static MLIRTensor to(MLIRTensor &input, c10::optional device, + c10::optional scalar_type); + + static MLIRTensor unsqueeze(const MLIRTensor &self, int64_t dim); + + static MLIRTensor view(const MLIRTensor &input, at::IntArrayRef size); +}; +} // namespace torch_mlir diff --git a/frontends/pytorch/csrc/tensor_impl.cpp b/frontends/pytorch/csrc/tensor_impl.cpp new file mode 100644 index 000000000..aca4d2343 --- /dev/null +++ b/frontends/pytorch/csrc/tensor_impl.cpp @@ -0,0 +1,156 @@ +//===- tensor_impl.cpp ------------------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. 
+// +//===----------------------------------------------------------------------===// + +#include "tensor_impl.h" +#include "aten_mlir_bridge.h" + +#include +#include + +namespace torch_mlir { +namespace { + +thread_local c10::Device g_current_device(at::DeviceType::XLA, 0); + +struct MLIRGuardImpl : public c10::impl::DeviceGuardImplInterface { + at::DeviceType type() const override { return at::DeviceType::XLA; } + + c10::Device exchangeDevice(c10::Device device) const override { + std::swap(g_current_device, device); + return device; + } + + c10::Device getDevice() const override { return g_current_device; } + + void setDevice(c10::Device device) const override { + g_current_device = device; + } + + void uncheckedSetDevice(c10::Device device) const noexcept override { + g_current_device = device; + } + + c10::Stream getStream(c10::Device device) const noexcept override { + return c10::Stream(c10::Stream::DEFAULT, device); + } + + c10::Stream exchangeStream(c10::Stream s) const noexcept override { + return c10::Stream(c10::Stream::DEFAULT, g_current_device); + } + + c10::DeviceIndex deviceCount() const noexcept override { return 0; } +}; + +C10_REGISTER_GUARD_IMPL(XLA, MLIRGuardImpl); + +} // namespace + +MLIRTensorImpl::MLIRTensorImpl(MLIRTensor tensor) + : c10::TensorImpl(c10::XLATensorId(), GetTypeMeta(tensor), + bridge::MLIRDeviceToAtenDevice(tensor.GetDevice())), + tensor_(std::move(tensor)) {} + +c10::intrusive_ptr MLIRTensorImpl::shallow_copy_and_detach( + const c10::VariableVersion &version_counter, + bool allow_tensor_metadata_change) const { + // std::cout << "MLIRTensorImpl::" << __func__ << std::endl; + auto impl = c10::make_intrusive(tensor_); + copy_tensor_metadata( + /*src_impl=*/this, + /*dest_impl=*/impl.get(), + /*version_counter=*/version_counter, + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + return impl; +} + +void MLIRTensorImpl::shallow_copy_from( + const c10::intrusive_ptr &impl) { + // std::cout << "MLIRTensorImpl::" << __func__ << std::endl; + MLIRTensorImpl *tensor_impl = dynamic_cast(impl.get()); + copy_tensor_metadata( + /*src_impl=*/tensor_impl, + /*dest_impl=*/this, + /*version_counter=*/version_counter(), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); + tensor_impl->tensor_.ShallowCopyTo(&tensor_); + generation_ = 0; +} + +at::IntArrayRef MLIRTensorImpl::sizes() const { + const_cast(this)->SetupSizeProperties(); + return c10::TensorImpl::sizes(); +} + +at::IntArrayRef MLIRTensorImpl::strides() const { + const_cast(this)->SetupSizeProperties(); + return c10::TensorImpl::strides(); +} + +int64_t MLIRTensorImpl::dim() const { + const_cast(this)->SetupSizeProperties(); + return c10::TensorImpl::dim(); +} + +int64_t MLIRTensorImpl::numel() const { + const_cast(this)->SetupSizeProperties(); + return c10::TensorImpl::numel(); +} + +bool MLIRTensorImpl::is_contiguous(at::MemoryFormat memory_format) const { + // Only check that the storage is already contiguous. + assert(is_contiguous_ && "Non-contiguous storage for MLIR tensor"); + return true; +} + +int64_t MLIRTensorImpl::size(int64_t d) const { + const_cast(this)->SetupSizeProperties(); + return c10::TensorImpl::size(d); +} + +void MLIRTensorImpl::SetupSizeProperties() { + size_t generation = tensor_.generation(); + if (generation != generation_) { + // Fill up the basic dimension data members which the base class + // implementation uses in its APIs. 
+ auto sizes = tensor_.sizes(); + auto strides = tensor_.strides(); + + strides_.clear(); + sizes_.clear(); + numel_ = 1; + + for (auto t : llvm::zip(sizes, strides)) { + auto size = std::get<0>(t); + sizes_.push_back(size); + strides_.push_back(std::get<1>(t)); + numel_ *= size; + } + + generation_ = generation; + } +} + +caffe2::TypeMeta MLIRTensorImpl::GetTypeMeta(const MLIRTensor &tensor) { + return c10::scalarTypeToTypeMeta(tensor.dtype()); +} + +c10::Device MLIRTensorImpl::GetCurrentAtenDevice() { return g_current_device; } + +c10::Device MLIRTensorImpl::SetCurrentAtenDevice(c10::Device device) { + std::swap(g_current_device, device); + return device; +} + +void MLIRTensorImpl::AtenInitialize() {} + +const at::Storage &MLIRTensorImpl::storage() const { + assert(0 && "MLIR tensors do not have storage"); +} + +bool MLIRTensorImpl::has_storage() const { return false; } + +} // namespace torch_mlir diff --git a/frontends/pytorch/csrc/tensor_impl.h b/frontends/pytorch/csrc/tensor_impl.h new file mode 100644 index 000000000..bed3985f6 --- /dev/null +++ b/frontends/pytorch/csrc/tensor_impl.h @@ -0,0 +1,60 @@ +//===- tensor_impl.h --------------------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "tensor.h" + +#include +#include +#include + +namespace torch_mlir { + +class MLIRTensorImpl : public c10::TensorImpl { +public: + explicit MLIRTensorImpl(MLIRTensor tensor); + + MLIRTensor &tensor() { return tensor_; } + + c10::intrusive_ptr + shallow_copy_and_detach(const c10::VariableVersion &version_counter, + bool allow_tensor_metadata_change) const override; + + void shallow_copy_from(const c10::intrusive_ptr &impl) override; + + at::IntArrayRef sizes() const override; + + at::IntArrayRef strides() const override; + + int64_t dim() const override; + + int64_t numel() const override; + + bool is_contiguous(at::MemoryFormat memory_format) const override; + + int64_t size(int64_t d) const override; + + static c10::Device GetCurrentAtenDevice(); + + static c10::Device SetCurrentAtenDevice(c10::Device device); + + static void AtenInitialize(); + + const at::Storage &storage() const override; + + bool has_storage() const override; + +private: + static caffe2::TypeMeta GetTypeMeta(const MLIRTensor &tensor); + + void SetupSizeProperties(); + + MLIRTensor tensor_; + size_t generation_ = 0; +}; +} // namespace torch_mlir diff --git a/frontends/pytorch/csrc/torch_util.cpp b/frontends/pytorch/csrc/torch_util.cpp new file mode 100644 index 000000000..a11b607ef --- /dev/null +++ b/frontends/pytorch/csrc/torch_util.cpp @@ -0,0 +1,44 @@ +//===- torch_util.cpp -------------------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. +// +//===----------------------------------------------------------------------===// + +#include "torch_util.h" + +#include +#include + +namespace torch_mlir { +namespace util { + +at::Tensor Zeros(at::IntArrayRef sizes, at::ScalarType type) { + return at::zeros(sizes, type); +} + +at::Tensor CopyTensor(const at::Tensor &ref) { + return ref.to(ref.options(), /*non_blocking=*/false, /*copy=*/true); +} + +// Same as above, with an additional cast. 
+at::Tensor CopyTensor(const at::Tensor &ref, at::ScalarType dest_type) {
+  return ref.to(ref.options().dtype(dest_type), /*non_blocking=*/false,
+                /*copy=*/true);
+}
+
+at::ScalarType GetScalarType(at::Scalar scalar) {
+  if (scalar.isFloatingPoint()) {
+    return at::kDouble;
+  } else if (scalar.isIntegral(/*includeBool=*/false)) {
+    return at::kLong;
+  } else if (scalar.isBoolean()) {
+    return at::kBool;
+  } else if (scalar.isComplex()) {
+    return at::kComplexDouble;
+  }
+  assert(0 && "Unknown type for scalar");
+}
+
+} // namespace util
+} // namespace torch_mlir
diff --git a/frontends/pytorch/csrc/torch_util.h b/frontends/pytorch/csrc/torch_util.h
new file mode 100644
index 000000000..7696603e3
--- /dev/null
+++ b/frontends/pytorch/csrc/torch_util.h
@@ -0,0 +1,34 @@
+//===- torch_util.h ---------------------------------------------*- C++ -*-===//
+//
+// This file is licensed under a pytorch-style license
+// See frontends/pytorch/LICENSE for license information.
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace torch_mlir {
+namespace util {
+
+at::Tensor Zeros(at::IntArrayRef sizes, at::ScalarType type);
+
+// Makes a deep copy of an ATen tensor.
+at::Tensor CopyTensor(const at::Tensor &ref);
+
+// Same as above, with an additional cast.
+at::Tensor CopyTensor(const at::Tensor &ref, at::ScalarType dest_type);
+
+// Returns the at::ScalarType corresponding to an at::Scalar.
+at::ScalarType GetScalarType(at::Scalar scalar);
+
+template <typename T>
+T OptionalOr(const c10::optional<T> &value, T defval) {
+  return value ? static_cast<T>(*value) : defval;
+}
+
+} // namespace util
+} // namespace torch_mlir
diff --git a/frontends/pytorch/lib/CMakeLists.txt b/frontends/pytorch/lib/CMakeLists.txt
new file mode 100644
index 000000000..a105ffac1
--- /dev/null
+++ b/frontends/pytorch/lib/CMakeLists.txt
@@ -0,0 +1,10 @@
+include_directories(
+  ${TORCH_INCLUDE_DIRS}
+  )
+add_library(aten_ops SHARED
+  aten_ops.cpp
+  )
+
+target_link_libraries(aten_ops
+  ${TORCH_LIBRARIES}
+  )
diff --git a/frontends/pytorch/lib/aten_ops.cpp b/frontends/pytorch/lib/aten_ops.cpp
new file mode 100644
index 000000000..571e64159
--- /dev/null
+++ b/frontends/pytorch/lib/aten_ops.cpp
@@ -0,0 +1,772 @@
+//===- aten_ops.cpp ---------------------------------------------*- C++ -*-===//
+//
+// This file is licensed under a pytorch-style license
+// See frontends/pytorch/LICENSE for license information.
+//
+//===----------------------------------------------------------------------===//
+
+// This file implements C libraries that are targeted by MLIR code generation
+// from the ATen dialect. This library is intended to support a functional
+// proof of concept rather than to be optimized for high performance. Most of
+// the functions are implemented by calling back into the torch libraries.
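+//
+// A rough sketch of the calling convention, for orientation only: the
+// "<rank><type>" suffixes used below (e.g. "2F32" for a rank-2 float32
+// operand) are assumed to encode the template arguments of tensor_t<T, N>,
+// and each tensor_t argument mirrors a rank-N memref descriptor (data
+// pointer, aligned pointer, offset, shape[N], stride[N]). Lowered code is
+// expected to reach these functions through their "_mlir_ciface_" symbols,
+// along the lines of:
+//
+//   tensor_t<float, 2> a, b, r;  // descriptors filled in by generated code
+//   _mlir_ciface_add_2F32_2F32_2F32_out(&a, &b, /*alpha=*/1, &r);
+//
+// Results are always written through the trailing out-parameters.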
+ +#include +#include +#include +#include +#include + +#include +#include + +#include "nnpack.h" +#include + +namespace { + +template struct tensor_t { + T *d; + T *aligned; + size_t offset; + size_t shape[N]; + size_t stride[N]; + + size_t index(size_t n, size_t channel, size_t row, size_t col) const { + size_t channels = shape[1]; + size_t height = shape[2]; + size_t width = shape[3]; + return n * height * width * channels + channel * height * width + + row * width + col; + } + + tensor_t() { + d = aligned = nullptr; + offset = 0; + for (int i = 0; i < N; i++) + shape[i] = stride[i] = 0; + } +}; + +template +std::vector translate_shape(tensor_t *t) { + std::vector shape; + for (int i = 0; i < N; i++) { + shape.push_back(t->shape[i]); + // std::cout << i << " shape " << t->shape[i] << std::endl; + } + return shape; +} + +template +std::vector translate_stride(tensor_t *t) { + std::vector stride; + for (int i = 0; i < N; i++) { + stride.push_back(t->stride[i]); + // std::cout << i << " stride " << t->stride[i] << std::endl; + } + return stride; +} + +template void dumpTensor(std::ostream &o, tensor_t *t) { + o << "Shape:"; + for (int i = 0; i < N; i++) + o << t->shape[i] << " "; + o << "Stride:"; + for (int i = 0; i < N; i++) + o << t->stride[i] << " "; + o << "\n"; +} + +template +at::Tensor to_torch(tensor_t *t, + const at::TensorOptions &options = at::TensorOptions()) { + // std::cout << "to_torch\n"; + return torch::from_blob((void *)t->d, translate_shape(t), translate_stride(t), + options); +} + +template +void mm_out(tensor_t *a, tensor_t *b, tensor_t *r); + +template +void add_out(tensor_t *a, tensor_t *b, T alpha, tensor_t *r) { + at::Tensor torch_a = to_torch(a); + at::Tensor torch_b = to_torch(b); + at::Tensor result = at::native::add(torch_a, torch_b, alpha).clone(); + + memcpy(r->d, result.data_ptr(), result.numel() * sizeof(T)); +} + +template +void addmm_out(tensor_t *a, tensor_t *b, tensor_t *c, + int32_t alpha, int32_t beta, tensor_t *r) { + at::Tensor torch_a = to_torch(a); + at::Tensor torch_b = to_torch(b); + at::Tensor torch_c = to_torch(c); + at::Tensor result = + at::native::addmm(torch_a, torch_b, torch_c, alpha, beta).clone(); + + memcpy(r->d, result.data_ptr(), result.numel() * sizeof(T)); +} + +template +void as_strided_out(tensor_t *a, + /*size*/ int32_t sz0, int32_t sz1, int32_t sz2, int32_t sz3, + /*stride*/ int32_t sd0, int32_t sd1, int32_t sd2, + int32_t sd3, int32_t offset, tensor_t *r) { + at::Tensor input = to_torch(a); + + std::vector size; + std::vector stride; + c10::optional storage_offset; + + if (offset != 0) + storage_offset = offset; + if (N > 0) { + size.push_back(sz0); + stride.push_back(sd0); + } + if (N > 1) { + size.push_back(sz1); + stride.push_back(sd1); + } + if (N > 2) { + size.push_back(sz2); + stride.push_back(sd2); + } + if (N > 3) { + size.push_back(sz3); + stride.push_back(sd3); + } + + std::vector sizeRef{size}; + std::vector strideRef{stride}; + + // for (int i = 0; id, result.data_ptr(), result.numel() * sizeof(T)); +} + +// FIXME: stride, padding, dilaection, output_padding should be IntArrayRef +template +void conv2d_out(tensor_t *t, tensor_t *weight, tensor_t *bias, + int32_t stride, int32_t pad, int32_t dilation, + tensor_t *r) { + at::Tensor torch_t = to_torch(t); + at::Tensor torch_w = to_torch(weight); + at::Tensor torch_b = to_torch(bias); + int64_t groups = 1; + + at::Tensor result = at::native::conv2d(torch_t, torch_w, torch_b, stride, pad, + dilation, groups) + .clone(); + + memcpy(r->d, result.data_ptr(), result.numel() 
* sizeof(T)); +} + +template +void conv2d_backward_out(tensor_t *grad_output, tensor_t *input, + tensor_t *weight, int32_t stride, int32_t pad, + int32_t dilation, tensor_t *r0, + tensor_t *r1, tensor_t *r2) { + const at::Tensor &arg_grad = to_torch(grad_output); + const at::Tensor &arg_input = to_torch(input); + const at::Tensor &arg_weight = to_torch(weight); + + std::vector p{pad, pad}; + std::vector s{stride, stride}; + std::vector d{dilation, dilation}; + + std::array output_mask{true, true, true}; + + std::tuple grads = + at::native::mkldnn_convolution_backward(arg_input, arg_grad, arg_weight, + p, s, d, 1, output_mask); + + auto result0 = std::get<0>(grads); + auto result1 = std::get<1>(grads); + auto result2 = std::get<2>(grads); + + memcpy(r0->d, result0.data_ptr(), result0.numel() * sizeof(T)); + memcpy(r1->d, result1.data_ptr(), result1.numel() * sizeof(T)); + memcpy(r2->d, result2.data_ptr(), result2.numel() * sizeof(T)); +} + +template +void log_softmax_out(tensor_t *t, int32_t dim, bool half_to_float, + tensor_t *r) { + at::Tensor input = to_torch(t); + at::Tensor result = at::native::log_softmax_cpu(input, dim, half_to_float); + memcpy(r->d, result.data_ptr(), result.numel() * sizeof(T)); +} + +template +void log_softmax_backward_data_out(tensor_t *a, tensor_t *b, + int32_t c, tensor_t *d, + tensor_t *r) { + at::Tensor inputA = to_torch(a); + at::Tensor inputB = to_torch(b); + at::Tensor inputD = to_torch(d); + + at::Tensor result = + at::native::log_softmax_backward_cpu(inputA, inputB, c, inputD); + memcpy(r->d, result.data_ptr(), result.numel() * sizeof(T)); +} + +template +void max_pool2d_with_indices_out(tensor_t *t, int32_t c, int32_t d, + int32_t e, int32_t f, bool ceil_mode, + tensor_t *r0, tensor_t *r1) { + at::Tensor input = to_torch(t); + + std::vector kernel{c, c}; + std::vector stride{d, d}; + std::vector padding{e, e}; + std::vector dilation{f, f}; + + auto result = at::native::max_pool2d_with_indices_cpu( + input, kernel, stride, padding, dilation, ceil_mode); + at::Tensor outTensor = std::get<0>(result); + at::Tensor idxTensor = std::get<1>(result); + memcpy(r0->d, outTensor.data_ptr(), outTensor.numel() * sizeof(T)); + memcpy(r1->d, idxTensor.data_ptr(), idxTensor.numel() * sizeof(T)); +} + +template +void max_pool2d_with_indices_backward_out(tensor_t *a, tensor_t *b, + int32_t c, int32_t d, int32_t e, + int32_t f, bool g, + tensor_t *h, + tensor_t *r) { + const at::Tensor &inputA = to_torch(a); + const at::Tensor &inputB = to_torch(b); + at::TensorOptions options(at::ScalarType::Long); + const at::Tensor &inputH = to_torch(h, options); + + std::vector kernel{c, c}; + std::vector stride{d, d}; + std::vector padding{e, e}; + std::vector dilation{f, f}; + + at::Tensor result = at::native::max_pool2d_with_indices_backward_cpu( + inputA, inputB, kernel, stride, padding, dilation, g, inputH); + memcpy(r->d, result.data_ptr(), result.numel() * sizeof(T)); +} + +template +void mm_out(tensor_t *a, tensor_t *b, tensor_t *r) { + at::Tensor inputA = to_torch(a); + at::Tensor inputB = to_torch(b); + + at::Tensor result = inputA.matmul(inputB); + memcpy(r->d, result.data_ptr(), result.numel() * sizeof(T)); +} + +template +void mul_out(tensor_t *a, tensor_t *b, tensor_t *r) { + at::Tensor inputA = to_torch(a); + at::Tensor inputB = to_torch(b); + + at::Tensor result = at::native::mul(inputA, inputB); + memcpy(r->d, result.data_ptr(), result.numel() * sizeof(T)); +} + +template +void relu_out(tensor_t *a, tensor_t *r) { + at::Tensor inputA = to_torch(a); + + at::Tensor 
result = at::native::relu(inputA); + memcpy(r->d, result.data_ptr(), result.numel() * sizeof(T)); +} + +template void t_out(tensor_t *a, tensor_t *r) { + size_t h = a->shape[0]; + size_t w = a->shape[1]; + + for (size_t i = 0; i < h; i++) + for (size_t j = 0; j < w; j++) + r->d[j * h + i] = a->d[i * w + j]; +} + +template +void threshold_backward_out(tensor_t *a, tensor_t *b, int32_t c, + tensor_t *r) { + at::Tensor inputA = to_torch(a); + at::Tensor inputB = to_torch(b); + + at::Tensor result = at::native::threshold_backward(inputA, inputB, c); + memcpy(r->d, result.data_ptr(), result.numel() * sizeof(T)); +} + +template +void view_out(tensor_t *a, int32_t b, int32_t c, int32_t d, int32_t e, + tensor_t *r) { + tensor_t result; + size_t numel = 1; + for (size_t d = 0; d < M; d++) + numel *= a->shape[d]; + + if (N == 1) + c = d = e = 1; + if (N == 2) + d = e = 1; + if (N == 3) + e = 1; + + int inferred = 0; + if (b == -1) + inferred++; + if (c == -1) + inferred++; + if (d == -1) + inferred++; + if (e == -1) + inferred++; + assert(inferred <= 1 && + "aten.view Error: only one dimension can be inferred"); + + if (b == -1) + b = numel / (c * d * e); + if (c == -1) + c = numel / (b * d * e); + if (d == -1) + d = numel / (b * c * e); + if (e == -1) + e = numel / (b * c * d); + + if (N > 0) + r->shape[0] = b; + if (N > 1) + r->shape[1] = c; + if (N > 2) + r->shape[2] = d; + if (N > 3) + r->shape[3] = e; + + memcpy(r->d, a->d, numel * sizeof(T)); +} + +} // namespace + +extern "C" { + +// add_out + +void _mlir_ciface_add_1F32_1F32_1F32_out(tensor_t *a, + tensor_t *b, int32_t i, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + add_out(a, b, i, r); +} + +void _mlir_ciface_add_2F32_2F32_2F32_out(tensor_t *a, + tensor_t *b, int32_t i, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + add_out(a, b, i, r); +} + +void _mlir_ciface_add_3F32_3F32_3F32_out(tensor_t *a, + tensor_t *b, int32_t i, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + add_out(a, b, i, r); +} + +void _mlir_ciface_add_4F32_4F32_4F32_out(tensor_t *a, + tensor_t *b, int32_t i, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + add_out(a, b, i, r); +} + +// addmm_out + +void _mlir_ciface_addmm_2F32_1F32_2F32_2F32_out(tensor_t *a, + tensor_t *b, + tensor_t *c, + int32_t alpha, int32_t beta, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + addmm_out(a, b, c, alpha, beta, r); +} + +// as_strided_out + +void _mlir_ciface_as_strided_1F32_1F32_out(tensor_t *a, + /*size*/ int32_t sz0, int32_t sz1, + int32_t sz2, int32_t sz3, + /*stride*/ int32_t sd0, int32_t sd1, + int32_t sd2, int32_t sd3, + int32_t offset, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + as_strided_out(a, sz0, sz1, sz2, sz3, sd0, sd1, sd2, sd3, offset, + r); +} + +void _mlir_ciface_as_strided_4F32_2F32_out(tensor_t *a, + /*size*/ int32_t sz0, int32_t sz1, + int32_t sz2, int32_t sz3, + /*stride*/ int32_t sd0, int32_t sd1, + int32_t sd2, int32_t sd3, + int32_t offset, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + // std::cout << sz0 << " " + // << sz1 << " " + // << sz2 << " " + // << sz3 << "\n"; + // std::cout << sd0 << " " + // << sd1 << " " + // << sd2 << " " + // << sd3 << "\n"; + as_strided_out(a, sz0, sz1, sz2, sz3, sd0, sd1, sd2, sd3, offset, + r); +} + +// conv2d_out + +void _mlir_ciface_conv2d_4F32_4F32_4F32_1F32_out( + tensor_t *t, tensor_t *weight, tensor_t *bias, + int32_t stride, int32_t padding, int32_t dilation, 
tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + conv2d_out(t, weight, bias, stride, padding, dilation, r); +} + +void _mlir_ciface_conv2d_relu_4F32_4F32_4F32_1F32_out( + tensor_t *t, tensor_t *weight, tensor_t *bias, + int32_t stride, int32_t padding, int32_t dilation, tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + conv2d_out(t, weight, bias, stride, padding, dilation, r); + relu_out(r, r); +} + +// conv2d_backward_out + +void _mlir_ciface_conv2d_backward_4F32_4F32_1F32_4F32_4F32_4F32_out( + tensor_t *grad_output, tensor_t *t, + tensor_t *weight, int32_t stride, int32_t padding, + int32_t dilation, tensor_t *r0, tensor_t *r1, + tensor_t *r2) { + // std::cout << "aten_ops " << __func__ << "\n"; + conv2d_backward_out(grad_output, t, weight, stride, padding, dilation, + r0, r1, r2); +} + +// div +float *div_0F32_0F32_0F32(float *a, float *b) { + // std::cout << "aten_ops " << __func__ << "\n"; + float *ret = (float *)malloc(sizeof(float)); + *ret = *a / *b; + return ret; +} + +// log_softmax_out + +void _mlir_ciface_log_softmax_1F32_1F32_out(tensor_t *t, int32_t dim, + bool half_to_float, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + log_softmax_out(t, dim, half_to_float, r); +} +void _mlir_ciface_log_softmax_2F32_2F32_out(tensor_t *t, int32_t dim, + bool half_to_float, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + log_softmax_out(t, dim, half_to_float, r); +} +void _mlir_ciface_log_softmax_3F32_3F32_out(tensor_t *t, int32_t dim, + bool half_to_float, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + log_softmax_out(t, dim, half_to_float, r); +} +void _mlir_ciface_log_softmax_4F32_4F32_out(tensor_t *t, int32_t dim, + bool half_to_float, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + log_softmax_out(t, dim, half_to_float, r); +} + +// log_softmax_backward_data_out + +void _mlir_ciface_log_softmax_backward_data_2F32_2F32_2F32_2F32_out( + tensor_t *a, tensor_t *b, int32_t c, + tensor_t *d, tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + log_softmax_backward_data_out(a, b, c, d, r); +} + +void _mlir_ciface_log_softmax_backward_data_4F32_4F32_4F32_4F32_out( + tensor_t *a, tensor_t *b, int32_t c, + tensor_t *d, tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + log_softmax_backward_data_out(a, b, c, d, r); +} + +// max_pool2d_out + +void _mlir_ciface_max_pool2d_with_indices_4F32_4I64_4F32_out( + tensor_t *t, int32_t kernel, int32_t pad, int32_t stride, + int32_t dilation, bool ceil_mode, tensor_t *r0, + tensor_t *r1) { + // std::cout << "aten_ops " << __func__ << "\n"; + max_pool2d_with_indices_out(t, kernel, pad, stride, dilation, + ceil_mode, r0, r1); +} + +// max_pool2d backward_out + +void _mlir_ciface_max_pool2d_with_indices_backward_4F32_4F32_4F32_4I64_out( + tensor_t *a, tensor_t *b, int32_t c, int32_t d, + int32_t e, int32_t f, bool g, tensor_t *h, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + max_pool2d_with_indices_backward_out(a, b, c, d, e, f, g, h, r); +} + +// mm_out + +void _mlir_ciface_mm_2F32_2F32_2F32_out(tensor_t *a, + tensor_t *b, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + mm_out(a, b, r); +} + +// mul_out + +void _mlir_ciface_mul_1F32_1F32_1F32_out(tensor_t *a, + tensor_t *b, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + mul_out(a, b, r); +} + +void _mlir_ciface_mul_2F32_2F32_2F32_out(tensor_t *a, + tensor_t *b, + tensor_t *r) { + // 
std::cout << "aten_ops " << __func__ << "\n"; + mul_out(a, b, r); +} + +void _mlir_ciface_mul_3F32_3F32_3F32_out(tensor_t *a, + tensor_t *b, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + mul_out(a, b, r); +} + +void _mlir_ciface_mul_4F32_4F32_4F32_out(tensor_t *a, + tensor_t *b, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + mul_out(a, b, r); +} + +// nll_loss2d_forward_out + +void _mlir_ciface_nll_loss2d_forward_1F32_1F32_4F32_3I64_1F32_out( + tensor_t *a, tensor_t *b, tensor_t *c, + int64_t d, int64_t e, tensor_t *r0, tensor_t *r1) { + // std::cout << "aten_ops " << __func__ << "\n"; + using T = float; + at::Tensor inputA = to_torch(a); + at::TensorOptions options(at::ScalarType::Long); + at::Tensor inputB = to_torch(b, options); + at::Tensor inputC = to_torch(c); + + std::tuple result = + at::CPUType::nll_loss2d_forward(inputA, inputB, inputC, d, e); + + at::Tensor result0 = std::get<0>(result); + at::Tensor result1 = std::get<1>(result); + memcpy(r0->d, result0.data_ptr(), result0.numel() * sizeof(T)); + memcpy(r1->d, result1.data_ptr(), result1.numel() * sizeof(T)); +} + +// nll_loss2d_backward_out + +void _mlir_ciface_nll_loss2d_backward_4F32_1F32_4F32_3I64_1F32_1F32_out( + tensor_t *a, tensor_t *b, tensor_t *c, + tensor_t *d, int32_t e, int32_t f, tensor_t *g, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + using T = float; + at::Tensor inputA = to_torch(a); + at::Tensor inputB = to_torch(b); + at::TensorOptions options(at::ScalarType::Long); + at::Tensor inputC = to_torch(c, options); + at::Tensor inputD = to_torch(d); + at::Tensor inputG = to_torch(g); + + at::Tensor result = at::CPUType::nll_loss2d_backward(inputA, inputB, inputC, + inputD, e, f, inputG); + memcpy(r->d, result.data_ptr(), result.numel() * sizeof(T)); +} + +void _mlir_ciface_nll_loss_backward_2F32_1F32_2F32_1I64_1F32_1F32_out( + tensor_t *a, tensor_t *b, tensor_t *c, + tensor_t *d, int32_t e, int32_t f, tensor_t *g, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + using T = float; + at::Tensor inputA = to_torch(a); + at::Tensor inputB = to_torch(b); + at::TensorOptions options(at::ScalarType::Long); + at::Tensor inputC = to_torch(c, options); + at::Tensor inputD = to_torch(d); + at::Tensor inputG = to_torch(g); + + at::Tensor result = at::CPUType::nll_loss_backward(inputA, inputB, inputC, + inputD, e, f, inputG); + + memcpy(r->d, result.data_ptr(), result.numel() * sizeof(T)); +} + +// nll_loss_forward_out + +void _mlir_ciface_nll_loss_forward_1F32_1F32_2F32_1I64_1F32_out( + tensor_t *a, tensor_t *b, tensor_t *c, + int64_t d, int64_t e, tensor_t *r0, tensor_t *r1) { + // std::cout << "aten_ops " << __func__ << "\n"; + using T = float; + at::Tensor inputA = to_torch(a); + at::TensorOptions options(at::ScalarType::Long); + at::Tensor inputB = to_torch(b, options); + at::Tensor inputC = to_torch(c); + + std::tuple result = + at::CPUType::nll_loss_forward(inputA, inputB, inputC, d, e); + + at::Tensor result0 = std::get<0>(result); + at::Tensor result1 = std::get<1>(result); + + memcpy(r0->d, result0.data_ptr(), result0.numel() * sizeof(T)); + memcpy(r1->d, result1.data_ptr(), result1.numel() * sizeof(T)); +} + +// relu_out + +void _mlir_ciface_relu_1F32_1F32_out(tensor_t *a, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + relu_out(a, r); +} + +void _mlir_ciface_relu_2F32_2F32_out(tensor_t *a, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + relu_out(a, r); +} + +void 
_mlir_ciface_relu_3F32_3F32_out(tensor_t *a, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + relu_out(a, r); +} + +void _mlir_ciface_relu_4F32_4F32_out(tensor_t *a, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + relu_out(a, r); +} + +// t_out + +void _mlir_ciface_t_2F32_2F32_out(tensor_t *a, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + t_out(a, r); +} + +// threshold_backward_out + +void _mlir_ciface_threshold_backward_1F32_1F32_1F32_out(tensor_t *a, + tensor_t *b, + int32_t c, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + threshold_backward_out(a, b, c, r); +} + +void _mlir_ciface_threshold_backward_2F32_2F32_2F32_out(tensor_t *a, + tensor_t *b, + int32_t c, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + threshold_backward_out(a, b, c, r); +} + +void _mlir_ciface_threshold_backward_3F32_3F32_3F32_out(tensor_t *a, + tensor_t *b, + int32_t c, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + threshold_backward_out(a, b, c, r); +} + +void _mlir_ciface_threshold_backward_4F32_4F32_4F32_out(tensor_t *a, + tensor_t *b, + int32_t c, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + threshold_backward_out(a, b, c, r); +} + +// view_out + +void _mlir_ciface_view_1F32_4F32_out(tensor_t *a, int32_t b, + int32_t c, int32_t d, int32_t e, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + view_out(a, b, c, d, e, r); +} + +void _mlir_ciface_view_1F32_3F32_out(tensor_t *a, int32_t b, + int32_t c, int32_t d, int32_t e, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + view_out(a, b, c, d, e, r); +} + +void _mlir_ciface_view_1F32_2F32_out(tensor_t *a, int32_t b, + int32_t c, int32_t d, int32_t e, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + view_out(a, b, c, d, e, r); +} + +void _mlir_ciface_view_2F32_4F32_out(tensor_t *a, int32_t b, + int32_t c, int32_t d, int32_t e, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + view_out(a, b, c, d, e, r); +} + +void _mlir_ciface_view_4F32_1F32_out(tensor_t *a, int32_t b, + int32_t c, int32_t d, int32_t e, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + view_out(a, b, c, d, e, r); +} + +void _mlir_ciface_view_4F32_2F32_out(tensor_t *a, int32_t b, + int32_t c, int32_t d, int32_t e, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + view_out(a, b, c, d, e, r); +} + +void _mlir_ciface_view_4F32_3F32_out(tensor_t *a, int32_t b, + int32_t c, int32_t d, int32_t e, + tensor_t *r) { + // std::cout << "aten_ops " << __func__ << "\n"; + view_out(a, b, c, d, e, r); +} +} diff --git a/frontends/pytorch/test/CMakeLists.txt b/frontends/pytorch/test/CMakeLists.txt new file mode 100644 index 000000000..62b21f96f --- /dev/null +++ b/frontends/pytorch/test/CMakeLists.txt @@ -0,0 +1,21 @@ +configure_lit_site_cfg( + ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in + ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py + MAIN_CONFIG + ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py +) + +set(TEST_DEPENDS + FileCheck count not + _torch_mlir + aten_ops + ) + +add_lit_testsuite(check-frontends-pytorch "Running the frontends-pytorch regression tests" + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${TEST_DEPENDS} + ) +set_target_properties(check-frontends-pytorch PROPERTIES FOLDER "Tests") + +add_lit_testsuites(TORCH_MLIR ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TEST_DEPENDS}) +add_dependencies(check-all check-frontends-pytorch) diff --git 
a/frontends/pytorch/test/lit.cfg.py b/frontends/pytorch/test/lit.cfg.py new file mode 100644 index 000000000..c6316c10f --- /dev/null +++ b/frontends/pytorch/test/lit.cfg.py @@ -0,0 +1,71 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import os +import platform +import re +import subprocess +import tempfile + +import lit.formats +import lit.util + +from lit.llvm import llvm_config +from lit.llvm.subst import ToolSubst +from lit.llvm.subst import FindTool + +# Configuration file for the 'lit' test runner. + +# name: The name of this test suite. +config.name = 'FRONTENDS_PYTORCH' + +config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) +config.environment['PYTHONPATH'] = "{}:{}".format( + os.path.join(config.npcomp_obj_root, "python"), + # path to our python hooks + os.path.join(config.npcomp_obj_root, "frontends", "pytorch", "csrc")) + +if 'TEST_SRC_PATH' in os.environ: + config.environment['TEST_SRC_PATH'] = os.environ['TEST_SRC_PATH'] + +# path to our python operation library +config.environment['TEST_BUILD_PATH'] = os.path.join(config.npcomp_obj_root) + +# suffixes: A list of file extensions to treat as test files. +config.suffixes = ['.py'] + +# test_source_root: The root path where tests are located. +config.test_source_root = os.path.dirname(__file__) + +# test_exec_root: The root path where tests should be run. +config.test_exec_root = os.path.join(config.npcomp_obj_root, 'test') + +config.substitutions.append(('%PATH%', config.environment['PATH'])) +config.substitutions.append(('%shlibext', config.llvm_shlib_ext)) + +llvm_config.with_system_environment( + ['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP']) + +llvm_config.use_default_substitutions() + +# excludes: A list of directories to exclude from the testsuite. The 'Inputs' +# subdirectories contain auxiliary inputs for various tests in their parent +# directories. +config.excludes = ['lit.cfg.py', 'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt'] + +# test_source_root: The root path where tests are located. +config.test_source_root = os.path.dirname(__file__) + +# test_exec_root: The root path where tests should be run. +config.test_exec_root = os.path.join(config.npcomp_obj_root, 'test') +config.npcomp_tools_dir = os.path.join(config.npcomp_obj_root, 'bin') + +# Tweak the PATH to include the tools dir. +llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) + +tool_dirs = [config.npcomp_tools_dir, config.llvm_tools_dir] +tools = [ +] + +llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/frontends/pytorch/test/lit.site.cfg.py.in b/frontends/pytorch/test/lit.site.cfg.py.in new file mode 100644 index 000000000..646c17208 --- /dev/null +++ b/frontends/pytorch/test/lit.site.cfg.py.in @@ -0,0 +1,53 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. 
+ +@LIT_SITE_CFG_IN_HEADER@ + +import sys + +config.host_triple = "@LLVM_HOST_TRIPLE@" +config.target_triple = "@TARGET_TRIPLE@" +config.llvm_src_root = "@LLVM_SOURCE_DIR@" +config.llvm_obj_root = "@LLVM_BINARY_DIR@" +config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" +config.llvm_lib_dir = "@LLVM_LIBRARY_DIR@" +config.llvm_shlib_dir = "@SHLIBDIR@" +config.llvm_shlib_ext = "@SHLIBEXT@" +config.llvm_exe_ext = "@EXEEXT@" +config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" +config.python_executable = "@PYTHON_EXECUTABLE@" +config.gold_executable = "@GOLD_EXECUTABLE@" +config.ld64_executable = "@LD64_EXECUTABLE@" +config.enable_shared = @ENABLE_SHARED@ +config.enable_assertions = @ENABLE_ASSERTIONS@ +config.targets_to_build = "@TARGETS_TO_BUILD@" +config.native_target = "@LLVM_NATIVE_ARCH@" +config.llvm_bindings = "@LLVM_BINDINGS@".split(' ') +config.host_os = "@HOST_OS@" +config.host_cc = "@HOST_CC@" +config.host_cxx = "@HOST_CXX@" +# Note: ldflags can contain double-quoted paths, so must use single quotes here. +config.host_ldflags = '@HOST_LDFLAGS@' +config.llvm_use_sanitizer = "@LLVM_USE_SANITIZER@" +config.llvm_host_triple = '@LLVM_HOST_TRIPLE@' +config.host_arch = "@HOST_ARCH@" +config.npcomp_src_root = "@CMAKE_SOURCE_DIR@" +config.npcomp_obj_root = "@CMAKE_BINARY_DIR@" + +# Support substitution of the tools_dir with user parameters. This is +# used when we can't determine the tool dir at configuration time. +try: + config.llvm_tools_dir = config.llvm_tools_dir % lit_config.params + config.llvm_shlib_dir = config.llvm_shlib_dir % lit_config.params +except KeyError: + e = sys.exc_info()[1] + key, = e.args + lit_config.fatal("unable to find %r parameter, use '--param=%s=VALUE'" % (key,key)) + + +import lit.llvm +lit.llvm.initialize(lit_config, config) + +# Let the main config do the real work. +lit_config.load_config(config, "@CMAKE_SOURCE_DIR@/frontends/pytorch/test/lit.cfg.py") diff --git a/frontends/pytorch/test/test_export_ResA.py b/frontends/pytorch/test/test_export_ResA.py new file mode 100644 index 000000000..1f24f1565 --- /dev/null +++ b/frontends/pytorch/test/test_export_ResA.py @@ -0,0 +1,78 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import unittest +from unittest import TestCase + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import npcomp.frontends.pytorch as torch_mlir + +import inspect + +# RUN: python %s | FileCheck %s + +class ResA(nn.Module): + def __init__(self, channels): + C = int(channels) + C2 = int(channels/2) + super(ResA, self).__init__() + self.model = nn.Sequential(# A1 + nn.BatchNorm2d(C), + nn.ReLU(), + nn.Conv2d(C,C2,1,stride=1,padding=0,dilation=1,groups=1,bias=True), + # B1 + nn.BatchNorm2d(C2), + nn.ReLU(), + nn.Conv2d(C2,C2,3,stride=1,padding=1,dilation=1,groups=1,bias=True), + # C1 + nn.BatchNorm2d(C2), + nn.ReLU(), + nn.Conv2d(C2,C,1,stride=1,padding=0,dilation=1,groups=1,bias=True)) + def forward(self, x): + res = self.model.forward(x) + return x + res + +# Prints `str` prefixed by the current test function name so we can use it in +# Filecheck label directives. +# This is achieved by inspecting the stack and getting the parent name. +def printWithCurrentFunctionName(s): + # stack[1] is the caller, i.e. "_test_model" + # stack[2] is the caller's caller, e.g. 
"test_conv_1" + print(inspect.stack()[2][3], s) + +class TestMLIRExport(unittest.TestCase): + def setUp(self): + pass + + def _test_model(self, model, model_args): + result = model(model_args) + + mlir = torch_mlir.get_mlir(result) + printWithCurrentFunctionName (mlir) + return True + + def test_ResA_16(self): + dev = torch_mlir.mlir_device() + model = ResA(16).to(dev) + passed = self._test_model(model, torch.ones((1,16,128,128), device=dev)) + # CHECK-LABEL: test_ResA_16 + # CHECK: [[V0:%[a-zA-Z0-9]+]], %{{.*}}, %{{.*}} = "aten.native_batch_norm"({{.*}}) {layer_name = "L0-native_batch_norm-0"} + # CHECK: [[V1:%[a-zA-Z0-9]+]] = "aten.relu"([[V0]]) {layer_name = "L1-relu-0"} + # CHECK: [[V2:%[a-zA-Z0-9]+]] = "aten.convolution_overrideable"([[V1]], {{.*}}) {layer_name = "L2-convolution_overrideable-0"} + # CHECK: [[V3:%[a-zA-Z0-9_]+]], %{{.*}}, %{{.*}} = "aten.native_batch_norm"([[V2]]{{.*}}) {layer_name = "L3-native_batch_norm-1"} + # CHECK: [[V4:%[a-zA-Z0-9]+]] = "aten.relu"([[V3]]) {layer_name = "L4-relu-1"} + # CHECK: [[V5:%[a-zA-Z0-9]+]] = "aten.convolution_overrideable"([[V4]],{{.*}}) {layer_name = "L5-convolution_overrideable-1"} + # CHECK: [[V6:%[a-zA-Z0-9_]+]], %{{.*}}, %{{.*}} = "aten.native_batch_norm"([[V5]],{{.*}}) {layer_name = "L6-native_batch_norm-2"} + # CHECK: [[V7:%[a-zA-Z0-9]+]] = "aten.relu"([[V6]]) {layer_name = "L7-relu-2"} + # CHECK: [[V8:%[a-zA-Z0-9]+]] = "aten.convolution_overrideable"([[V7]],{{.*}}) {layer_name = "L8-convolution_overrideable-2"} + # CHECK: {{.*}} = "aten.add"(%arg0, [[V8]], {{.*}}) {layer_name = "L9-add-0"} + self.assertTrue(passed) + +verbose = False +if __name__ == '__main__': + verbose = True + unittest.main() diff --git a/frontends/pytorch/test/test_export_add3.py b/frontends/pytorch/test/test_export_add3.py new file mode 100644 index 000000000..2bff4c038 --- /dev/null +++ b/frontends/pytorch/test/test_export_add3.py @@ -0,0 +1,26 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import npcomp.frontends.pytorch as torch_mlir + +# RUN: python %s | FileCheck %s + +dev = torch_mlir.mlir_device() +t0 = torch.randn((1,2,3,4), device=dev) +t1 = torch.randn((1,2,3,4), device=dev) +t2 = torch.randn((1,2,3,4), device=dev) + +t3 = t0 + t1 + t2 + +# +# Generate and check the MLIR for the result tensor +# +t3_mlir = torch_mlir.get_mlir( t3 ) + +# CHECK-LABEL: test_export_add3 +# CHECK: %1 = "aten.add"(%arg0, %arg1, %0) {layer_name = "L0-add-0"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>, i32) -> tensor<1x2x3x4xf32> +# CHECK: %2 = "aten.add"(%1, %arg2, %0) {layer_name = "L1-add-1"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>, i32) -> tensor<1x2x3x4xf32> +print("test_export_add3") +print(t3_mlir) diff --git a/frontends/pytorch/test/test_export_batchnorm.py b/frontends/pytorch/test/test_export_batchnorm.py new file mode 100644 index 000000000..97f02b8aa --- /dev/null +++ b/frontends/pytorch/test/test_export_batchnorm.py @@ -0,0 +1,19 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. 
+ +import torch +import npcomp.frontends.pytorch as torch_mlir + +# RUN: python %s | FileCheck %s + +dev = torch_mlir.mlir_device() + +model = torch.nn.BatchNorm2d(123).to(dev) +result = model(torch.ones(42,123,4,5).to(dev)) + +# CHECK-LABEL: test_export_batchnorm +# CHECK: aten.native_batch_norm +mlir = torch_mlir.get_mlir( result ) +print("test_export_batchnorm") +print(mlir) diff --git a/frontends/pytorch/test/test_export_conv2d_back.py b/frontends/pytorch/test/test_export_conv2d_back.py new file mode 100644 index 000000000..74eed1394 --- /dev/null +++ b/frontends/pytorch/test/test_export_conv2d_back.py @@ -0,0 +1,49 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import npcomp.frontends.pytorch as torch_mlir + +# RUN: python %s | FileCheck %s + +dev = torch_mlir.mlir_device() + +N = 3 +Cin = 16 +Cout = 4 +w = 10 +h = 10 + +model = torch.nn.Conv2d(Cin, Cout, (3,3)) +ref_model = torch.nn.Conv2d(Cin, Cout, (3,3)) + +ref_model.weight.data = model.weight.clone() +ref_model.bias.data = model.bias.clone() + +model = model.to(dev) + +softmax = torch.nn.LogSoftmax(dim=1) +loss = torch.nn.NLLLoss() + +tensor = torch.randn(N, Cin, h, w, device=dev) +result = model(tensor) + +# CHECK-LABEL: test_export_conv2d +# CHECK: aten.convolution_overrideable +print("test_export_conv2d") +print(torch_mlir.get_mlir( result )) + +target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, Cout) +ref_target = target.clone() +target = target.to(dev) + +test_loss = loss( softmax(result), target ) +test_loss.backward() + +# CHECK-LABEL: test_export_conv2d_back +# CHECK: aten.convolution_overrideable +# CHECK: aten._log_softmax +# CHECK: aten.nll_loss2d_forward +print("test_export_conv2d_back") +print(torch_mlir.get_mlir( test_loss )) diff --git a/frontends/pytorch/test/test_export_multi_out.py b/frontends/pytorch/test/test_export_multi_out.py new file mode 100644 index 000000000..864c0a8f9 --- /dev/null +++ b/frontends/pytorch/test/test_export_multi_out.py @@ -0,0 +1,24 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import npcomp.frontends.pytorch as torch_mlir + +# RUN: python %s | FileCheck %s + +dev = torch_mlir.mlir_device() + +t0 = torch.randn(4, device=dev) +t1 = torch.randn(4, device=dev) +t2 = torch.randn(4, device=dev) + +t4 = t0 + t1 + t2 +t5 = t4 + t1 +t6 = t5 + t4 + +# CHECK-LABEL: test_multi_out +# CHECK: return %2, %3, %4 : tensor<4xf32>, tensor<4xf32>, tensor<4xf32> +mlir = torch_mlir.get_mlir([t4, t5, t6]) +print ("test_multi_out") +print (mlir) diff --git a/frontends/pytorch/test/test_export_resnet18.py b/frontends/pytorch/test/test_export_resnet18.py new file mode 100644 index 000000000..9264f6267 --- /dev/null +++ b/frontends/pytorch/test/test_export_resnet18.py @@ -0,0 +1,25 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. 
+ +import torch +import npcomp.frontends.pytorch as torch_mlir +import torchvision.models as models + +# RUN: python %s | FileCheck %s + +dev = torch_mlir.mlir_device() + +model = models.resnet18().to(dev) +model.training = False + +tensor = torch.randn(32,3,32,32).to(dev) +result = model(tensor) + +mlir = torch_mlir.get_mlir( result ) + +# for now we just check the output shape +# CHECK-LABEL: test_export_resnet18 +# CHECK: return %{{.*}} : tensor<32x1000xf32> +print("test_export_resnet18") +print(mlir) diff --git a/frontends/pytorch/test/test_export_vgg11.py b/frontends/pytorch/test/test_export_vgg11.py new file mode 100644 index 000000000..2544c7a2f --- /dev/null +++ b/frontends/pytorch/test/test_export_vgg11.py @@ -0,0 +1,24 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import npcomp.frontends.pytorch as torch_mlir +import torchvision.models as models + +# RUN: python %s | FileCheck %s + +dev = torch_mlir.mlir_device() + +model = models.vgg11_bn().to(dev) +model.training = False + +result = model(torch.ones(32,3,32,32).to(dev)) + +mlir = torch_mlir.get_mlir( result ) + +# for now we just check the output shape +# CHECK-LABEL: test_export_vgg11 +# CHECK: return %{{.*}} : tensor<32x1000xf32> +print("test_export_vgg11") +print(mlir) \ No newline at end of file diff --git a/frontends/pytorch/test/test_jit_add2.py b/frontends/pytorch/test/test_jit_add2.py new file mode 100644 index 000000000..93f1b2e9a --- /dev/null +++ b/frontends/pytorch/test/test_jit_add2.py @@ -0,0 +1,27 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import npcomp.frontends.pytorch as torch_mlir +import npcomp.frontends.pytorch.test as test + +# RUN: python %s | FileCheck %s + +dev = torch_mlir.mlir_device() +t0 = torch.randn((4,4), device=dev) +t1 = torch.randn((4,4), device=dev) + +t2 = t0 + t1 + +# +# Check the result tensor against the CPU +# +t0_cpu = t0.to('cpu') +t1_cpu = t1.to('cpu') +t2_cpu = t2.to('cpu') + +print (t0_cpu, " +\n", t1_cpu, " =\n", t2_cpu) + +# CHECK: PASS! add2 check +test.compare(t2, t0_cpu + t1_cpu, "add2") diff --git a/frontends/pytorch/test/test_jit_add3.py b/frontends/pytorch/test/test_jit_add3.py new file mode 100644 index 000000000..b0c286011 --- /dev/null +++ b/frontends/pytorch/test/test_jit_add3.py @@ -0,0 +1,29 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import npcomp.frontends.pytorch as torch_mlir +import npcomp.frontends.pytorch.test as test + +# RUN: python %s | FileCheck %s + +dev = torch_mlir.mlir_device() +t0 = torch.randn((1,2,3,4), device=dev) +t1 = torch.randn((1,2,3,4), device=dev) +t2 = torch.randn((1,2,3,4), device=dev) + +t3 = t0 + t1 + t2 + +# +# Check the result tensor against the CPU +# +t0_cpu = t0.to('cpu') +t1_cpu = t1.to('cpu') +t2_cpu = t2.to('cpu') +t3_cpu = t3.to('cpu') + +print (t0_cpu, " +\n", t1_cpu, " +\n", t2_cpu, " =\n", t3_cpu) + +# CHECK: PASS! 
+test.compare(t3, t0_cpu + t1_cpu + t2_cpu, "add3") diff --git a/frontends/pytorch/test/test_jit_add_views.py b/frontends/pytorch/test/test_jit_add_views.py new file mode 100644 index 000000000..6022fdcde --- /dev/null +++ b/frontends/pytorch/test/test_jit_add_views.py @@ -0,0 +1,42 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import npcomp.frontends.pytorch as torch_mlir +import npcomp.frontends.pytorch.test as test + +# RUN: python %s | FileCheck %s + +dev = torch_mlir.mlir_device() +t0 = torch.randn((4,16,4), device=dev) +t1 = torch.randn((4,16,4), device=dev) + +t3 = torch.randn((4,64), device=dev) +t4 = torch.randn((4,64), device=dev) + +t2 = t0 + t1 +t5 = t3 + t4 + +t6 = t5.view((4,4,4,4)) +t7 = t2.view((4,4,4,4)) + +t8 = t6 + t7 + +t0_cpu = t0.to('cpu') +t1_cpu = t1.to('cpu') + +# CHECK: PASS! add_views_0 check +test.compare(t2, t0_cpu + t1_cpu, "add_views_0") + +t3_cpu = t3.to('cpu') +t4_cpu = t4.to('cpu') + +# CHECK: PASS! add_views_1 check +test.compare(t5, t3_cpu + t4_cpu, "add_views_1") + +t6_cpu = t6.to('cpu') +t7_cpu = t7.to('cpu') + +# CHECK: PASS! add_views_2 check +test.compare(t8, t6_cpu + t7_cpu, "add_views_2") diff --git a/frontends/pytorch/test/test_jit_as_stride.py b/frontends/pytorch/test/test_jit_as_stride.py new file mode 100644 index 000000000..979608c22 --- /dev/null +++ b/frontends/pytorch/test/test_jit_as_stride.py @@ -0,0 +1,43 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import npcomp.frontends.pytorch as torch_mlir +import npcomp.frontends.pytorch.test as test + +# RUN: python %s | FileCheck %s + +dev = torch_mlir.mlir_device() + +x = torch.rand((3,64,8,8), device=dev) +y = x*x +print (y.stride()) + +dim = [64,24,24] +dim = [4,4,4] +N = 2; +count = dim[0]*dim[1]*dim[2] +sizes = (N,dim[0],dim[1],dim[2]) +strides = (1,dim[1]*dim[2],dim[2],1) +print(count) +t0 = torch.randn((N,count), device=dev) +t0_like = torch.randn((N,count)) + + +t1 = t0.as_strided(sizes, strides) +t1_ref = t0.to('cpu').as_strided(sizes, strides) +t1_like = t0_like.as_strided(sizes, strides) + +t1_ref = t1_ref.clone() + +# check that the IR has recorded the +# stride properly before invoking JIT +# CHECK: PASS! stride check +test.compare_eq(t1.stride(), t1_like.stride(), "stride") + +# CHECK: PASS! as_stride check +test.compare(t1_ref, t1, "as_stride") + +# CHECK: PASS! as_stride stride check +test.compare_eq(t1_ref.stride(), t1.to("cpu").stride(), "as_stride stride") diff --git a/frontends/pytorch/test/test_jit_conv2d.py b/frontends/pytorch/test/test_jit_conv2d.py new file mode 100644 index 000000000..cd8b70d56 --- /dev/null +++ b/frontends/pytorch/test/test_jit_conv2d.py @@ -0,0 +1,17 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import npcomp.frontends.pytorch as torch_mlir +import npcomp.frontends.pytorch.test as test + +# RUN: python %s | FileCheck %s + +model = torch.nn.Conv2d(2,16,7,stride=[2,2], padding=[3,3], + dilation=1, groups=1, bias=True) + +tensor = torch.randn((1,2,128,128)) + +# CHECK: PASS! 
fwd check +test.check_ref(model, tensor) diff --git a/frontends/pytorch/test/test_jit_conv2d_back.py b/frontends/pytorch/test/test_jit_conv2d_back.py new file mode 100644 index 000000000..4848be813 --- /dev/null +++ b/frontends/pytorch/test/test_jit_conv2d_back.py @@ -0,0 +1,46 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import npcomp.frontends.pytorch as torch_mlir +import npcomp.frontends.pytorch.test as test + +# RUN: python %s | FileCheck %s + +dev = torch_mlir.mlir_device() + +N = 3 +Cin = 16 +Cout = 4 +w = 10 +h = 10 + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(Cin, Cout, (3,3)) + + def forward(self, x): + x = self.conv1(x) + output = F.log_softmax(x, dim=1) + return output + +model = Net() +tensor = torch.randn(N, Cin, h, w) + +# CHECK: PASS! fwd check +fwd_path = test.check_ref(model, tensor) + +loss = torch.nn.NLLLoss() +target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, Cout) + +# CHECK: PASS! back check +test.check_back(fwd_path, target, loss) + +# CHECK: PASS! weight_grad check +test.compare(model.conv1.weight.grad, fwd_path[0].conv1.weight.grad, "weight_grad") +# CHECK: PASS! bias_grad check +test.compare(model.conv1.bias.grad, fwd_path[0].conv1.bias.grad, "bias_grad") diff --git a/frontends/pytorch/test/test_jit_lenet_back.py b/frontends/pytorch/test/test_jit_lenet_back.py new file mode 100644 index 000000000..ee0a4ffc5 --- /dev/null +++ b/frontends/pytorch/test/test_jit_lenet_back.py @@ -0,0 +1,60 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + +import npcomp.frontends.pytorch as torch_mlir +import npcomp.frontends.pytorch.test as test + +# RUN: python %s | FileCheck %s + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=True) + self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=True) + #self.maxpool2d = nn.MaxPool2d(2,2) + self.fc1 = nn.Linear(9216*4, 128, bias=True) + self.fc2 = nn.Linear(128, 10, bias=True) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + #x = self.maxpool2d(x) + x = x.view((64,9216*4)) + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + +def main(): + model = Net() + tensor = torch.randn((64, 1, 28, 28), requires_grad=True) + + # CHECK: PASS! fwd check + fwd_path = test.check_fwd(model, tensor) + + target = torch.ones((64), dtype=torch.long) + loss = F.nll_loss + + # CHECK: PASS! back check + test.check_back(fwd_path, target, loss) + + # CHECK: PASS! weight_grad check + test.compare(model.conv2.weight.grad, + fwd_path[0].conv2.weight.grad, "weight_grad") + # CHECK: PASS! bias_grad check + test.compare(model.conv2.bias.grad, + fwd_path[0].conv2.bias.grad, "bias_grad") + # CHECK: PASS! 
fc1_weight_grad check + test.compare(model.fc1.weight.grad, + fwd_path[0].fc1.weight.grad, "fc1_weight_grad") + +if __name__ == '__main__': + main() diff --git a/frontends/pytorch/test/test_jit_lenet_fwd.py b/frontends/pytorch/test/test_jit_lenet_fwd.py new file mode 100644 index 000000000..87f853678 --- /dev/null +++ b/frontends/pytorch/test/test_jit_lenet_fwd.py @@ -0,0 +1,53 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +from __future__ import print_function +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +from torch.optim.lr_scheduler import StepLR + +import npcomp.frontends.pytorch as torch_mlir +import npcomp.frontends.pytorch.test as test + +# RUN: python %s | FileCheck %s + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.maxpool2d = nn.MaxPool2d(2,2) + #self.dropout1 = nn.Dropout2d(0.25) + #self.dropout2 = nn.Dropout2d(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = self.maxpool2d(x) + #x = self.dropout1(x) + x = x.view((4,9216)) + x = self.fc1(x) + x = F.relu(x) + #x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +def main(): + model = Net() + tensor = torch.randn((4, 1, 28, 28)) + + # CHECK: PASS! fwd check + fwd_path = test.check_fwd(model, tensor) + +if __name__ == '__main__': + main() diff --git a/frontends/pytorch/test/test_jit_linear.py b/frontends/pytorch/test/test_jit_linear.py new file mode 100644 index 000000000..fafbf072f --- /dev/null +++ b/frontends/pytorch/test/test_jit_linear.py @@ -0,0 +1,17 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import npcomp.frontends.pytorch as torch_mlir +import npcomp.frontends.pytorch.test as test + +# RUN: python %s | FileCheck %s + +dev = torch_mlir.mlir_device() + +model = torch.nn.Linear(1024,16).to(dev) +tensor = torch.randn(4,1024).to(dev) + +# CHECK: PASS! fwd check +fwd_path = test.check_fwd(model, tensor) diff --git a/frontends/pytorch/test/test_jit_logsoftmax.py b/frontends/pytorch/test/test_jit_logsoftmax.py new file mode 100644 index 000000000..5d902d731 --- /dev/null +++ b/frontends/pytorch/test/test_jit_logsoftmax.py @@ -0,0 +1,15 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import npcomp.frontends.pytorch as torch_mlir +import npcomp.frontends.pytorch.test as test + +# RUN: python %s | FileCheck %s + +model = torch.nn.LogSoftmax(dim=0) +tensor = torch.ones(1,2,3,4) + +# CHECK: PASS! fwd check +fwd_path = test.check_fwd(model, tensor) diff --git a/frontends/pytorch/test/test_jit_maxpool.py b/frontends/pytorch/test/test_jit_maxpool.py new file mode 100644 index 000000000..0426ef198 --- /dev/null +++ b/frontends/pytorch/test/test_jit_maxpool.py @@ -0,0 +1,18 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. 
+ +import torch +import npcomp.frontends.pytorch as torch_mlir +import npcomp.frontends.pytorch.test as test + +# RUN: python %s | FileCheck %s + +model = torch.nn.MaxPool2d(kernel_size=(3,3), stride=(2,2), padding=(1,1), + dilation=1, return_indices=False, ceil_mode=False) + +tensor = torch.randn(1,32,16,16) + +# CHECK: PASS! fwd check +fwd_path = test.check_fwd(model, tensor) + diff --git a/frontends/pytorch/test/test_jit_mlp_back.py b/frontends/pytorch/test/test_jit_mlp_back.py new file mode 100644 index 000000000..9f25b9983 --- /dev/null +++ b/frontends/pytorch/test/test_jit_mlp_back.py @@ -0,0 +1,49 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +from __future__ import print_function +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +from torch.optim.lr_scheduler import StepLR + +import npcomp.frontends.pytorch as torch_mlir +import npcomp.frontends.pytorch.test as test + +# RUN: python %s | FileCheck %s + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.fc1 = nn.Linear(28*28, 50) + self.fc2 = nn.Linear(50, 50) + self.fc3 = nn.Linear(50, 10) + + def forward(self, x): + x = x.view(-1, 28*28) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + return F.log_softmax(self.fc3(x), dim=1) + +def main(): + device = torch_mlir.mlir_device() + model = Net() + tensor = torch.randn((64, 1, 28, 28),requires_grad=True) + # CHECK: PASS! fwd check + fwd_path = test.check_ref(model, tensor) + + target = torch.ones((64), dtype=torch.long) + loss = F.nll_loss + + # CHECK: PASS! back check + test.check_back(fwd_path, target, loss) + + # CHECK: PASS! fc1_weight_grad check + test.compare(model.fc1.weight.grad, fwd_path[0].fc1.weight.grad, "fc1_weight_grad") + +if __name__ == '__main__': + main() diff --git a/frontends/pytorch/test/test_jit_mm.py b/frontends/pytorch/test/test_jit_mm.py new file mode 100644 index 000000000..5f2d9ddd2 --- /dev/null +++ b/frontends/pytorch/test/test_jit_mm.py @@ -0,0 +1,32 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import npcomp.frontends.pytorch as torch_mlir +import npcomp.frontends.pytorch.test as test + +# RUN: python %s | FileCheck %s + +dev = torch_mlir.mlir_device() + +t0 = torch.randn((3,13), device=dev) +t1 = torch.randn((13,5), device=dev) +print(t0.to('cpu'), t1.to('cpu')) +print(torch.mm(t0.to('cpu'), t1.to('cpu'))) + +t2 = torch.mm(t0, t1) + +# +# Check the result tensor against the CPU +# +t0_cpu = t0.to('cpu') +t1_cpu = t1.to('cpu') +t2_cpu = t2.to('cpu') + +print (t0_cpu, " *\n", t1_cpu, " =\n", t2_cpu) + +ref_tensor = torch.mm(t0_cpu, t1_cpu) +# CHECK: PASS! mm check +test.compare(t2, ref_tensor, "mm") + diff --git a/frontends/pytorch/test/test_jit_mul2.py b/frontends/pytorch/test/test_jit_mul2.py new file mode 100644 index 000000000..c2e4e9125 --- /dev/null +++ b/frontends/pytorch/test/test_jit_mul2.py @@ -0,0 +1,26 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. 
+ +import torch +import npcomp.frontends.pytorch as torch_mlir +import npcomp.frontends.pytorch.test as test + +# RUN: python %s | FileCheck %s + +dev = torch_mlir.mlir_device() +t0 = torch.randn((4,4), device=dev) +t1 = torch.randn((4,4), device=dev) + +t2 = t0 * t1 +# +# Check the result tensor against the CPU +# +t0_cpu = t0.to('cpu') +t1_cpu = t1.to('cpu') +t2_cpu = t2.to('cpu') + +print (t0_cpu, " *\n", t1_cpu, " =\n", t2_cpu) + +# CHECK: PASS! mul2 check +test.compare(t2, t0_cpu * t1_cpu, "mul2") diff --git a/frontends/pytorch/test/test_jit_nllloss.py b/frontends/pytorch/test/test_jit_nllloss.py new file mode 100644 index 000000000..c375a7577 --- /dev/null +++ b/frontends/pytorch/test/test_jit_nllloss.py @@ -0,0 +1,21 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import npcomp.frontends.pytorch as torch_mlir +import npcomp.frontends.pytorch.test as test + +# RUN: python %s | FileCheck %s + +model = torch.nn.LogSoftmax(dim=1) +tensor = torch.randn(3,5,requires_grad=True) + +# CHECK: PASS! fwd check +fwd_path = test.check_fwd(model, tensor) + +target = torch.tensor([1, 0, 4]) +loss = torch.nn.NLLLoss() + +# CHECK: PASS! back check +test.check_back(fwd_path, target, loss) diff --git a/frontends/pytorch/test/test_jit_relu.py b/frontends/pytorch/test/test_jit_relu.py new file mode 100644 index 000000000..eb3f4b784 --- /dev/null +++ b/frontends/pytorch/test/test_jit_relu.py @@ -0,0 +1,15 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import npcomp.frontends.pytorch as torch_mlir +import npcomp.frontends.pytorch.test as test + +# RUN: python %s | FileCheck %s + +model = torch.nn.ReLU() +tensor = torch.randn(10) + +# CHECK: PASS! fwd check +fwd_path = test.check_ref(model, tensor) diff --git a/frontends/pytorch/test/test_jit_t.py b/frontends/pytorch/test/test_jit_t.py new file mode 100644 index 000000000..ab38cedbe --- /dev/null +++ b/frontends/pytorch/test/test_jit_t.py @@ -0,0 +1,18 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import npcomp.frontends.pytorch as torch_mlir +import npcomp.frontends.pytorch.test as test + +# RUN: python %s | FileCheck %s + +dev = torch_mlir.mlir_device() + +tensor = torch.randn(2,3).to(dev) +result = tensor.t() + +ref_result = tensor.to('cpu').t() +# CHECK: PASS! transpose check +test.compare(ref_result, result, "transpose") diff --git a/frontends/pytorch/test/test_op_report_conv2d.py b/frontends/pytorch/test/test_op_report_conv2d.py new file mode 100644 index 000000000..701e480a2 --- /dev/null +++ b/frontends/pytorch/test/test_op_report_conv2d.py @@ -0,0 +1,31 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. 
+ +import torch +import npcomp.frontends.pytorch as torch_mlir + +# RUN: python %s | FileCheck %s + +dev = torch_mlir.mlir_device() + +model = torch.nn.Conv2d(2,16,7,stride=[2,2], padding=[3,3], dilation=1, groups=1, bias=True).to(dev) + +tensor = torch.randn((1,2,128,128), device=dev) +result = model(tensor) + +mlir = torch_mlir.get_mlir( result ) +report = torch_mlir.op_report(mlir) + +# CHECK-LABEL: "L0-convolution_overrideable-0" +# CHECK-NEXT: "activation_in": 32768 +# CHECK-NEXT: "activation_out": 65536 +# CHECK-NEXT: "ops:+": 65536 +# CHECK-NEXT: "ops:MAC": 6422528 +# CHECK-NEXT: "parameters_in": 1584 +# CHECK-NEXT: "reads": 34352 +# CHECK-NEXT: "writes": 65536 +for k,v in report.items(): + print("\"{}\"".format(k)) + for k,v in v.items(): + print("\"{}\": {}".format(k,v)) diff --git a/frontends/pytorch/test/test_op_report_vgg_style_lenet.py b/frontends/pytorch/test/test_op_report_vgg_style_lenet.py new file mode 100644 index 000000000..32d64f8f4 --- /dev/null +++ b/frontends/pytorch/test/test_op_report_vgg_style_lenet.py @@ -0,0 +1,108 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +from __future__ import print_function +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +from torch.optim.lr_scheduler import StepLR + +import npcomp.frontends.pytorch as torch_mlir +import json + +# RUN: python %s | FileCheck %s + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 8, 3, padding=1) + self.conv2 = nn.Conv2d(8, 16, 3, padding=0) + self.maxpool1 = nn.MaxPool2d(2,2) + self.maxpool2 = nn.MaxPool2d(2,2) + self.fc1 = nn.Linear(576, 128) + self.fc2 = nn.Linear(128, 64) + self.fc3 = nn.Linear(64, 8) + + def forward(self, x): + x = self.conv1(x) + print(x.shape) + x = F.relu(x) + print(x.shape) + x = self.maxpool1(x) + print(x.shape) + + x = self.conv2(x) + print(x.shape) + x = F.relu(x) + print(x.shape) + x = self.maxpool2(x) + print(x.shape) + x = x.view(8, 6*6*16) + + x = self.fc1(x) + x = F.relu(x) + + x = self.fc2(x) + x = F.relu(x) + + x = self.fc3(x) + output = F.log_softmax(x, dim=1) + + return output + +def main(): + + test_status = "PASS!" + + # CHECK-LABEL: test_op_report_vgg_style_lenet + # CHECK: PASS! + print("test_op_report_vgg_style_lenet") + + device = torch_mlir.mlir_device() + + model = Net().to(device) + ref_tensor = torch.randn((8, 1, 30, 30)) + tensor = ref_tensor.clone().to(device) + + result = model(tensor) + target = torch.ones((8), dtype=torch.long).to(device) + loss = F.nll_loss(result, target) + loss.backward() + + mlir0 = torch_mlir.get_mlir(model.conv1.weight.grad) + print(mlir0) + report = torch_mlir.op_report(mlir0) + print(report) + + report_dict = report + expected = 32 + if (len(report_dict) != expected): + print("### ERROR: Expecting",expected,"items in the report, but got ",len(report_dict)) + test_status = "FAIL!" + + # Every item should have a read and a write + for key, value in report_dict.items(): + if not 'reads' in value: + print(f"### ERROR: {key} does not contain the required reads field") + test_status = "FAIL!" + if not 'writes' in value: + print(f"### ERROR: {key} does not contain the required writes field") + test_status = "FAIL!" + if "convolution" in key: + if not 'ops:MAC' in value: + print(f"### ERROR: convolution {key} does not contain the required MAC field") + test_status = "FAIL!" 
+ if "mm" in key: + if not 'ops:MAC' in value: + print(f"### ERROR: mm {key} does not contain the required MAC field") + test_status = "FAIL!" + + + print(test_status) + +if __name__ == '__main__': + main() diff --git a/frontends/pytorch/utils/gen_aten_dialect.py b/frontends/pytorch/utils/gen_aten_dialect.py new file mode 100644 index 000000000..2dc158445 --- /dev/null +++ b/frontends/pytorch/utils/gen_aten_dialect.py @@ -0,0 +1,1244 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +# Structured similarly to code from git@github.com:pytorch/xla.git + +from __future__ import print_function + +import argparse +import collections +import lark +import os +import re +import string +import sys + +#### +# This file parses the C++ signatures exported by pytorch and generates +# appropriate MLIR operations in a tablegen file. It also generates some of +# the more boilerplate parts of the pytorch integration. This may need to be +# run if pytorch versions change. Primarily this reads information from +# pytorch through RegistrationDeclarations.h and Functions.h. It also allows +# some local overrides (specified in aten_mlir_type.h). +# It generates: aten_mlir_type_defaults.{.cpp,.h} and ATenOps.td, which will need +# to be moved to their appropriate places. + +# To run: +# python3 gen_aten_dialect.py --output_folder=. \ +# ../csrc/aten_mlir_type.h \ +# ${TORCH_INSTALL_PREFIX}/include/ATen/RegistrationDeclarations.h \ +# ${TORCH_INSTALL_PREFIX}/include/ATen/Functions.h + + +def namedtuple_with_defaults(typename, field_names, default_values=()): + ntuple = collections.namedtuple(typename, field_names) + ntuple.__new__.__defaults__ = (None,) * len(ntuple._fields) + if isinstance(default_values, collections.Mapping): + prototype = ntuple(**default_values) + else: + prototype = ntuple(*default_values) + ntuple.__new__.__defaults__ = tuple(prototype) + return ntuple + + +class ArgTemplate(string.Template): + idpattern = r'[a-z0-9_]+' + + +FuncDef = namedtuple_with_defaults('FuncDef', 'cpp_sig, aten_sig') + +FuncGen = namedtuple_with_defaults( + 'FuncGen', + 'tree, xtree, rwxtree, func, xfunc, code, sig, rwsig, cppsig, funsig, mapsig, aten_sig' +) + +FuncOpts = namedtuple_with_defaults( + 'FuncOpts', + 'ref_param, device_param, wparams, outfn_template, outfn_name, shape_check_indices' +) + +_GRAMMAR = r""" + start: type fnname "(" params ")" + rtype: "(" rparams ")" + | TNAME + rparams: rparam + | rparam "," rparams + rparam: type param_name + type: CONST? core_type refspec? + fnname: CNAME + refspec: REF + | PTR + core_type: template + | TNAME + template: TNAME "<" typelist ">" + typelist: type + | type "," typelist + REF: "&" + PTR: "*" + CONST: "const" + TNAME: /[a-zA-Z0-9_:]+/ + HEXNUMBER: /0x[0-9a-fA-F]+/ + params: param + | param "," params + param: type param_name param_defval? 
+ param_name: CNAME + + param_defval: "=" init_value + init_value: "true" + | "false" + | "{}" + | NUMBER + | SIGNED_NUMBER + | HEXNUMBER + | ESCAPED_STRING + + %import common.CNAME -> CNAME + %import common.NUMBER -> NUMBER + %import common.SIGNED_NUMBER -> SIGNED_NUMBER + %import common.ESCAPED_STRING -> ESCAPED_STRING + %import common.WS + %ignore WS + """ + +_PARSER = lark.Lark(_GRAMMAR, parser='lalr', propagate_positions=True) + +_XPARSER = lark.Lark(_GRAMMAR, + parser='lalr', + propagate_positions=True, + keep_all_tokens=True) + +_TD_BLACKLIST = set([ + 'clone', + 'to', + 'copy_', + 'copy', + 'copy_from', + '_copy_from', + '_unsafe_view', +]) + +_TD_NO_OPSTATS_LIST = set([ + '_log_softmax', + '_log_softmax_backward_data', +]) + +_FN_BLACKLIST = set([ + 'numel', + 'ones', + 'ones_like', + 'result_type', + # 'zero_', + 'zeros', + 'zeros_like', +]) + +_FN_NO_DEBUG_ENTRY_LIST = set([ + 'empty', + 'fill_', + 'zero_', +]) + +_FN_BLACKLIST_REGEX = [ + # ATEN functions + r'[^(]*cudnn', + # XLA/TPU functions +] + +_FN_OUT = { + 'add_out': + FuncOpts(), + 'arange_out(Tensor, Scalar, Scalar, Scalar) -> Tensor': + FuncOpts(outfn_template=ArgTemplate( + 'ATenMLIRType::arange($1, $2, $3, $0.options())')), + 'bitwise_not_out': + FuncOpts(), + 'clamp_out': + FuncOpts(), + 'div_out': + FuncOpts(), + 'gather_out': + FuncOpts(), + 'kthvalue_out': + FuncOpts(), + 'index_select_out': + FuncOpts(), + 'log_out': + FuncOpts(), + 'topk_out': + FuncOpts(), +} +_FN_OUT = {} + +# List of tuples with the regex match first, and the corresponding FuncOpts() +# second. +_FN_OUT_REGEX = [] + +_FN_REMAP = { + '_th_eq(Tensor, Scalar) -> Tensor': + FuncOpts(outfn_name='ATenMLIRType::eq'), + '_th_eq(Tensor, Tensor) -> Tensor': + FuncOpts(outfn_name='ATenMLIRType::eq'), + '_th_ge(Tensor, Scalar) -> Tensor': + FuncOpts(outfn_name='ATenMLIRType::ge'), + '_th_ge(Tensor, Tensor) -> Tensor': + FuncOpts(outfn_name='ATenMLIRType::ge'), + '_th_gt(Tensor, Scalar) -> Tensor': + FuncOpts(outfn_name='ATenMLIRType::gt'), + '_th_gt(Tensor, Tensor) -> Tensor': + FuncOpts(outfn_name='ATenMLIRType::gt'), + '_th_le(Tensor, Scalar) -> Tensor': + FuncOpts(outfn_name='ATenMLIRType::le'), + '_th_le(Tensor, Tensor) -> Tensor': + FuncOpts(outfn_name='ATenMLIRType::le'), + '_th_lt(Tensor, Scalar) -> Tensor': + FuncOpts(outfn_name='ATenMLIRType::lt'), + '_th_lt(Tensor, Tensor) -> Tensor': + FuncOpts(outfn_name='ATenMLIRType::lt'), + '_th_ne(Tensor, Scalar) -> Tensor': + FuncOpts(outfn_name='ATenMLIRType::ne'), + '_th_ne(Tensor, Tensor) -> Tensor': + FuncOpts(outfn_name='ATenMLIRType::ne'), + 's__th_and(Tensor, Tensor) -> Tensor': + FuncOpts(outfn_name='ATenMLIRType::__and__', + shape_check_indices=((0, 1),)), + 's__th_or(Tensor, Tensor) -> Tensor': + FuncOpts(outfn_name='ATenMLIRType::__or__', + shape_check_indices=((0, 1),)), + 's__th_xor(Tensor, Tensor) -> Tensor': + FuncOpts(outfn_name='ATenMLIRType::__xor__', + shape_check_indices=((0, 1),)), + # '_s_where(Tensor, Tensor, Tensor) -> Tensor': + # FuncOpts( + # outfn_name='ATenMLIRType::where', + # shape_check_indices=( + # (0, 1), + # (0, 2), + # )), + 's__th_eq(Tensor, Tensor) -> Tensor': + FuncOpts(outfn_name='ATenMLIRType::eq', shape_check_indices=((0, 1),)), +} + +_TYPE_NSMAP = { + 'Tensor': 'at::Tensor', + 'TensorList': 'at::TensorList', + 'Scalar': 'at::Scalar', + 'Storage': 'at::Storage', + 'IntList': 'at::IntList', + 'IntArrayRef': 'at::IntArrayRef', + 'Generator': 'at::Generator', + 'ScalarType': 'at::ScalarType', + 'TensorOptions': 'at::TensorOptions', + 'SparseTensorRef': 
'at::SparseTensorRef', + 'Device': 'c10::Device', + 'optional': 'c10::optional', + 'MemoryFormat': 'at::MemoryFormat', + 'QScheme': 'at::QScheme', + 'ConstQuantizerPtr': 'at::ConstQuantizerPtr', + 'Dimname': 'at::Dimname', # namedtensor-only + 'DimnameList': 'at::DimnameList', # namedtensor-only +} + +_H_HEADER = """// Autogenerated file by {gen}. Do not edit directly! + +#include + +namespace torch_mlir {{ + +class ATenMLIRTypeDefault {{ + public: +{hfuncs} +}}; + +void RegisterAtenTypeFunctions(); + +}} // namespace torch_mlir +""" + +_CPP_HEADER = """// Autogenerated file by {gen}. Do not edit directly! +#include "aten_mlir_type_default.h" + +#include +#include +#include +#include + +#include "aten_mlir_bridge.h" +#include "aten_mlir_type.h" + +namespace torch_mlir {{ + +{funcs} + +{regs} +}} // namespace torch_mlir +""" + +_torch_mlir_FUNCTIONS = {} + +_CTOR_FUNCTIONS = { + 'empty': '.device(at::DeviceType::CPU)', + 'linspace': '.device(at::DeviceType::CPU)', + 'logspace': '.device(at::DeviceType::CPU)', + 'rand': '.device(at::DeviceType::CPU)', + 'rand_like': '.device(at::DeviceType::CPU)', + 'randn': '.device(at::DeviceType::CPU)', + 'randn_like': '.device(at::DeviceType::CPU)', + 'randint': '.device(at::DeviceType::CPU)', + 'randint_like': '.device(at::DeviceType::CPU)', + 'randperm': '.device(at::DeviceType::CPU)', + 'scalar_tensor': '.device(at::DeviceType::CPU)', +} + +_FUNCTION_OPTIONS = { + 'slice(Tensor, int64_t, int64_t, int64_t, int64_t) -> Tensor': + FuncOpts(wparams=['self']), +} + +_RESULT_NAME = 'x_result' + + +class Context(object): + + def __init__(self, functions): + with open(functions, 'r') as ff: + self.functions_data = ff.read() + + def get_function(self, name): + if self.functions_data.find(' {}('.format(name)) >= 0: + return 'at::{}'.format(name) + + +class StringEmit(object): + + def __init__(self, sref): + self.sref = sref + self.sval = '' + self.pos = -1 + + def __repr__(self): + return self.sval + + def advance(self, t): + start = t.column - 1 + end = t.end_column - 1 + pos = self.pos if self.pos >= 0 else start + if start > pos: + self.sval += self.sref[pos:start] + self.sval += t.value + self.pos = end + + def skip(self, t): + self.pos = last_match(t) if self.pos >= 0 else -1 + + def append(self, s): + self.sval += s + self.pos = -1 + + +class TensorFetcher(object): + + def __init__(self, var_name): + self.var_name = var_name + self.tvar_name = '{}_tensors'.format(self.var_name) + self.tensors = [] + self.writeable = [] + + def add(self, name, writeable): + if writeable: + self.writeable.append(len(self.tensors)) + self.tensors.append(name) + return '{}[{}]'.format(self.var_name, len(self.tensors) - 1) + + def generate_fetches(self): + code = '' + code += ' std::vector {} = {{{}}};\n'.format( + self.tvar_name, ', '.join(self.tensors)) + code += (' auto {} = bridge::MLIRCreateTensorList({});\n').format( + self.var_name, self.tvar_name) + return code + + def generate_updates(self): + assert (0) + code = '' + if self.writeable: + ivar_name = '{}_update_indices'.format(self.var_name) + code += ' std::vector {} = {{{}}};\n'.format( + ivar_name, ', '.join(str(x) for x in self.writeable)) + code += ' bridge::XlaUpdateTensors({}, {}, {});\n'.format( + self.tvar_name, self.var_name, ivar_name) + return code + + +def list_get(l, n): + return l[n] if n < len(l) else None + + +def is_blacklisted_fn(fname, mapsig): + if fname in _FN_BLACKLIST or mapsig in _FN_BLACKLIST: + return True + for frx in _FN_BLACKLIST_REGEX: + if re.match(frx, fname) or re.match(frx, mapsig): 
+ return True + return False + + +def get_outfn_options(fname, mapsig): + for name in [fname, mapsig]: + fnopts = _FN_OUT.get(name, None) + if fnopts is not None: + return fnopts + for frx, fnopts in _FN_OUT_REGEX: + if re.match(frx, fname) or re.match(frx, mapsig): + return fnopts + + +def get_remapfn_options(fname, mapsig): + for name in [fname, mapsig]: + fnopts = _FN_REMAP.get(name, None) + if fnopts is not None: + return fnopts + + +def is_write_param(fnopts, pname, defval): + if fnopts and fnopts.wparams: + if pname in fnopts.wparams: + return True + return defval + + +def first_match(t): + if isinstance(t, lark.lexer.Token): + return t.column - 1 + assert isinstance(t, lark.tree.Tree) + return first_match(t.children[0]) + + +def last_match(t): + if isinstance(t, lark.lexer.Token): + return t.end_column - 1 + assert isinstance(t, lark.tree.Tree) + return last_match(t.children[-1]) + + +def for_every_token(t, fn): + if isinstance(t, lark.lexer.Token): + fn(t) + else: + assert isinstance(t, lark.tree.Tree) + for c in t.children: + for_every_token(c, fn) + + +def emit_string(t, emit, emit_fn): + status = emit_fn(t) + if status > 0: + + def do_emit(tok): + emit.advance(tok) + + for_every_token(t, do_emit) + elif status == 0: + if isinstance(t, lark.lexer.Token): + emit.advance(t) + else: + assert isinstance(t, lark.tree.Tree) + for c in t.children: + emit_string(c, emit, emit_fn) + else: + emit.skip(t) + + +def typed_child(t, n, ttype): + assert isinstance(t, lark.tree.Tree) + assert n < len(t.children) + c = t.children[n] + assert isinstance(c, lark.tree.Tree) + assert c.data == ttype, t.pretty() + return c + + +def rewrite_sig(tree, orig_sig, emit_fn=lambda x: 0): + emit = StringEmit(orig_sig) + emit_string(tree, emit, emit_fn) + return str(emit) + + +def rewrite_signature(sig, tmap): + + def rewrite(t): + if t.type == 'TNAME': + new_type = tmap.get(t.value, None) + if new_type is not None: + t.value = new_type + + def emit_fn(t): + if isinstance(t, lark.lexer.Token): + return 0 + return -1 if t.data == 'param_defval' else 0 + + xtree = _XPARSER.parse(sig) + for_every_token(xtree, rewrite) + return rewrite_sig(xtree, sig, emit_fn=emit_fn) + + +def create_stdfunc_sig(tree, orig_sig): + + def emit_fn(t): + if isinstance(t, lark.lexer.Token): + return 0 + return -1 if t.data == 'param_name' else 0 + + emit = StringEmit(orig_sig) + # Emit full function return type. + emit_string(typed_child(tree, 0, 'type'), emit, emit_fn) + emit.append('(') + # Emit parameter list w/out parameter names. + emit_string(typed_child(tree, 3, 'params'), emit, emit_fn) + emit.append(')') + return str(emit) + + +def create_map_sig(tree, orig_sig): + + def emit_fn(t): + if isinstance(t, lark.lexer.Token): + return -1 if t.type in ['CONST', 'REF', 'PTR'] else 0 + return -1 if t.data in ['param_name', 'param_defval'] else 0 + + emit = StringEmit(orig_sig) + # Emit full function return type. + emit_string(typed_child(tree, 1, 'fnname'), emit, emit_fn) + emit.append('(') + # Emit parameter list w/out parameter names. 
+ emit_string(typed_child(tree, 3, 'params'), emit, emit_fn) + emit.append(') -> ') + emit_string(typed_child(tree, 0, 'type'), emit, emit_fn) + return str(emit) + + +def type_core(t): + assert isinstance(t, lark.tree.Tree) + for c in t.children: + if isinstance(c, lark.tree.Tree) and c.data == 'core_type': + c = c.children[0] + if isinstance(c, lark.lexer.Token): + return c.value + assert isinstance(c, lark.tree.Tree) and c.data == 'template' + return c.children[0].value + raise RuntimeError('Not a type tree: {}'.format(t)) + + +def type_is_const(t): + assert isinstance(t, lark.tree.Tree) + c = t.children[0] + return isinstance(c, lark.lexer.Token) and c.value == 'const' + + +def type_is_refptr(t, kind): + assert isinstance(t, lark.tree.Tree) + c = t.children[-1] + if not isinstance(c, lark.tree.Tree) or c.data != 'refspec': + return False + c = c.children[0] + return isinstance(c, lark.lexer.Token) and c.value == kind + + +def extract_list(t, l): + assert isinstance(t, lark.tree.Tree) + l.append(t.children[0]) + if len(t.children) == 2: + c = t.children[1] + if isinstance(c, lark.tree.Tree) and c.data == t.data: + extract_list(c, l) + return l + + +def tuple_type_list(t): + assert isinstance(t, lark.tree.Tree) + c = t.children[0] + assert isinstance(c, lark.tree.Tree) and c.data == 'core_type' + c = c.children[0] + assert isinstance(c, lark.tree.Tree) and c.data == 'template' + types = [] + return extract_list(c.children[1], types) + + +def get_function_name(t): + assert isinstance(t, lark.tree.Tree) + fname = t.children[1] + assert isinstance(fname, lark.tree.Tree) + assert fname.data == 'fnname' + return fname.children[0].value + + +def get_function_signature(t, orig_sig, namefn): + emit = StringEmit(orig_sig) + # Emit full function return type. + emit_string(typed_child(t, 0, 'type'), emit, lambda t: 0) + fnname = typed_child(t, 1, 'fnname').children[0] + xfname = namefn(fnname.value) + emit.append(' {}('.format(xfname)) + # Emit parameter list w/out parameter names. + emit_string(typed_child(t, 3, 'params'), emit, lambda t: 0) + emit.append(')') + return str(emit), fnname.value, xfname + + +def get_parameters(t): + assert isinstance(t, lark.tree.Tree) + c = t.children[2] + assert isinstance(c, lark.tree.Tree) + assert c.data == 'params' + params = [] + extract_list(c, params) + return params + + +def get_rparameters(t): + assert isinstance(t, lark.tree.Tree) + params = [] + print(len(t.children)) + # c = t.children[3] + # assert isinstance(c, lark.tree.Tree) + # assert c.data == 'rparams' + + # extract_list(c, params) + return params + + +def param_name(t): + assert isinstance(t, lark.tree.Tree) + c = t.children[1] + assert isinstance(c, lark.tree.Tree) + assert c.data == 'param_name' + token = c.children[0] + assert isinstance(token, lark.lexer.Token) + return token.value + + +def param_type(t): + assert isinstance(t, lark.tree.Tree) + c = t.children[0] + assert isinstance(c, lark.tree.Tree) + return c + + +def get_optional(fnopts, name, defval=None): + if fnopts is None or not hasattr(fnopts, name): + return defval + return getattr(fnopts, name, defval) or defval + + +def get_return_value(rtype, rname, param, var, ref_param, fnopts): + crtype = type_core(rtype) + if type_is_const(rtype) or type_is_refptr(rtype, '&'): + # If the return type is a const or a reference, return the matching + # parameter. In these cases we operated on XLA tensors data (the ATEN one), + # but the returned references are the input parameters. 
+ assert param + return param_name(param) + elif crtype != 'Tensor': + return rname + else: + # If instead the return type is a value Tensor, we create a new one by + # wrapping the proper local variable which has been created by calling + # into the CPU tensor implementation. + return 'bridge::CreateMLIRTensor({}, bridge::GetMLIRDevice({}))'.format( + rname, get_optional(fnopts, 'device_param', param_name(ref_param))) + + +def get_reference_param(params, fnopts=None): + # The reference parameter is the Tensor object which we use to extract the + # result Tensor device, if any. + ref_param = None + other = None + for p in params: + ptype = param_type(p) + cptype = type_core(ptype) + pname = param_name(p) + if get_optional(fnopts, 'ref_param') == pname: + return p + if not other and (cptype == 'TensorOptions' or cptype == 'TensorList'): + other = p + if cptype != 'Tensor': + continue + if not ref_param and (pname == 'self' or type_is_const(ptype)): + ref_param = p + other = p + return ref_param or other + + +def get_tuple_return(rtype, rtype_str, rname, params, param_vars, ref_param, + fnopts): + types = tuple_type_list(rtype) + retstr = '{}('.format(rtype_str) + for i, ttype in enumerate(types): + if i > 0: + retstr += ', ' + tuple_var = 'std::get<{}>({})'.format(i, rname) + retstr += get_return_value(ttype, tuple_var, list_get(params, i), + list_get(param_vars, i), ref_param, fnopts) + return retstr + ')' + + +def get_return_type_str(t, orig_sig): + assert isinstance(t, lark.tree.Tree) + fname = t.children[1] + assert isinstance(fname, lark.tree.Tree) + assert fname.data == 'fnname' + token = fname.children[0] + assert isinstance(token, lark.lexer.Token) + return orig_sig[0:token.column - 2] + + +def generate_entry_debug_code(t, fname, params, fname_ns='aten'): + code = '' + if fname in _FN_NO_DEBUG_ENTRY_LIST: + return code + code += ' std::cout << "{}::{}" << std::endl;\n'.format(fname_ns, fname) + # Emits debug code for a given intercepted ATEN type function. For now we use + # a counter which will show up in the metrics reports. + # VLOG info. 
Use the following to see debug output: + # export TF_CPP_VMODULE=aten_mlir_type_default=3 + #code += ' TF_VLOG(3) << "XLA {} :"'.format(fname) + #for p in params: + # ptype = param_type(p) + # cptype = type_core(ptype) + # pname = param_name(p) + # if cptype == 'Tensor': + # code += ' << " {}=" << {}.toString()'.format(pname, pname) + #code += ';\n' + return code + + +def generate_exit_debug_code(t, fname, rname, params, param_vars): + code = '' + return code + + +def generate_return_stmt(t, rtype_str, fname, rname, params, param_vars, + ref_param, fnopts): + assert isinstance(t, lark.tree.Tree) + rtype = t.children[0] + ctype = type_core(rtype) + if ctype == 'std::tuple': + retstr = get_tuple_return(rtype, rtype_str, rname, params, param_vars, + ref_param, fnopts) + elif ctype == 'std::vector': + #retstr = 'bridge::CreateXlaTensors({}, bridge::GetXlaDevice({}))'.format( + # rname, get_optional(fnopts, 'device_param', param_name(ref_param))) + retstr = rname + elif ctype == 'Tensor': + retstr = get_return_value(rtype, rname, params[0], param_vars[0], ref_param, + fnopts) + elif ctype == 'void' and not type_is_refptr(rtype, '*'): + return '' + else: + retstr = rname + return ' return {};\n'.format(retstr) + + +def generate_result_assignment(t, rname): + assert isinstance(t, lark.tree.Tree) + rtype = t.children[0] + ctype = type_core(rtype) + if ctype == 'void' and not type_is_refptr(rtype, '*'): + return '' + return 'auto&& {} = '.format(rname) + + +def get_handling_function(ctx, fname, the_ref_param, param_vars): + function = _torch_mlir_FUNCTIONS.get(fname, None) or ctx.get_function(fname) + if function: + code = '{}({})'.format(function, ', '.join(param_vars)) + else: + other_params = list(param_vars) + other_params.remove(the_ref_param) + code = '{}.{}({})'.format(the_ref_param, fname, ', '.join(other_params)) + return code + + +def rewrite_tensor_options(fname, pname): + rw = _CTOR_FUNCTIONS.get(fname, None) + if rw is None: + return '', pname + xname = 'o_{}'.format(pname) + code = ' at::TensorOptions {} = {}{};\n'.format(xname, pname, rw) + return code, xname + + +def get_param_names(params): + param_vars = [] + for p in params: + pname = param_name(p) + param_vars.append(pname) + return param_vars + + +def expand_fn_template(tmpl, param_vars): + mdict = {} + for i, pname in enumerate(param_vars): + mdict[str(i)] = pname + return tmpl.substitute(mdict) + + +def create_call(fname, param_vars): + return '{}({})'.format(fname, ', '.join(param_vars)) + + +def generate_shape_checks(param_vars, shape_check_indices, fname): + code = '' + #for i, j in shape_check_indices: + # code += (' XLA_CHECK({}.sizes() == {}.sizes()) << "Operand shapes must be ' + # 'identical for {}, mismatch for arguments {} and {}";\n').format( + # param_vars[i], param_vars[j], fname, i + 1, j + 1) + return code + + +def generate_aten_remap(ctx, fname, sig, params, fnopts): + code = '{} {{\n'.format(sig) + + param_vars = get_param_names(params) + if fnopts.outfn_template is not None: + fcall = expand_fn_template(fnopts.outfn_template, param_vars) + else: + assert fnopts.outfn_name + fcall = create_call(fnopts.outfn_name, param_vars) + + if fnopts.shape_check_indices is not None: + code += generate_shape_checks(param_vars, fnopts.shape_check_indices, fname) + code += ' return {};\n'.format(fcall) + code += '}' + return code + + +def generate_outfn_result_copy(dest, src): + return ' {}.unsafeGetTensorImpl()->shallow_copy_from({}.getIntrusivePtr());\n'.format( + dest, src) + + +def generate_aten_out(ctx, tree, rwxtree, 
fname, sig, rwsig, params, fnopts): + rtype = tree.children[0] + num_outputs = None + if type_core(rtype) == 'std::tuple': + num_outputs = len(tuple_type_list(rtype)) + + code = '{} {{\n'.format(sig) + code += generate_entry_debug_code(tree, fname, params) + + param_vars = get_param_names(params) + if fnopts.outfn_template is not None: + fcall = expand_fn_template(fnopts.outfn_template, param_vars) + else: + m = re.match(r'(.*)_out$', fname) + assert m is not None, fname + out_count = num_outputs if num_outputs is not None else 1 + fcall = create_call('ATenMLIRType::{}'.format(m.group(1)), + param_vars[out_count:]) + + tmp_result = '{}_tmp'.format(fname) + code += ' auto {} = {};\n'.format(tmp_result, fcall) + if num_outputs is None: + code += generate_outfn_result_copy(param_vars[0], tmp_result) + code += generate_exit_debug_code(tree, fname, param_vars[0], params, + param_vars) + code += ' return {};\n'.format(param_vars[0]) + else: + for i in range(0, num_outputs): + code += generate_outfn_result_copy( + param_vars[i], 'std::get<{}>({})'.format(i, tmp_result)) + code += generate_exit_debug_code(tree, fname, param_vars[0:num_outputs], + params, param_vars) + code += ' return {}('.format(get_return_type_str(rwxtree, rwsig)) + for i in range(0, num_outputs): + if i > 0: + code += ', ' + code += param_vars[i] + code += ');\n' + code += '}' + return code + + +def generate_aten_to_mlir(ctx, tree, rwxtree, fname, sig, rwsig, params, + fnopts): + ref_param = get_reference_param(params, fnopts=fnopts) + + code = '{} {{\n'.format(sig) + code += generate_entry_debug_code(tree, fname, params) + the_ref_param = param_name(ref_param) if ref_param else None + tfetcher = TensorFetcher('mlirtens') + param_vars = [] + for p in params: + ptype = param_type(p) + cptype = type_core(ptype) + pname = param_name(p) + if cptype == 'TensorList': + #xname = 'l_{}'.format(pname) + #code += (' auto {} = bridge::XlaCreateTensorList({});\n').format( + # xname, pname) + xname = pname + param_vars.append(xname) + elif cptype == 'TensorOptions': + gcode, xname = rewrite_tensor_options(fname, pname) + code += gcode + param_vars.append(xname) + elif cptype != 'Tensor': + param_vars.append(pname) + elif type_is_const(ptype): + xname = tfetcher.add(pname, is_write_param(fnopts, pname, False)) + param_vars.append(xname) + else: + xname = tfetcher.add(pname, is_write_param(fnopts, pname, True)) + param_vars.append(xname) + if p == ref_param and not get_optional(fnopts, 'ref_param'): + the_ref_param = param_vars[-1] + code += tfetcher.generate_fetches() + result_assign = generate_result_assignment(tree, _RESULT_NAME) + code += ' {}{};\n'.format( + result_assign, get_handling_function(ctx, fname, the_ref_param, + param_vars)) + #code += tfetcher.generate_updates() + if result_assign: + code += (' static_cast({}); // Avoid warnings in case not ' + 'used\n'.format(_RESULT_NAME)) + code += generate_exit_debug_code(tree, fname, + _RESULT_NAME if result_assign else None, + params, param_vars) + code += generate_return_stmt(tree, get_return_type_str(rwxtree, rwsig), fname, + _RESULT_NAME if result_assign else None, params, + param_vars, ref_param, fnopts) + code += '}' + return code + + +def get_mlir_wrapper(fndef, ctx): + tree = _PARSER.parse(fndef.cpp_sig) + xtree = _XPARSER.parse(fndef.cpp_sig) + mapsig = create_map_sig(xtree, fndef.cpp_sig) + rwsig = rewrite_signature(fndef.cpp_sig, _TYPE_NSMAP) + rwxtree = _XPARSER.parse(rwsig) + params = get_parameters(tree) + fnopts = _FUNCTION_OPTIONS.get(mapsig, None) + + def gen_fnname(x): 
+ return 'ATenMLIRTypeDefault::{}'.format(x) + + sig, fname, xfname = get_function_signature(rwxtree, rwsig, gen_fnname) + if not is_blacklisted_fn(fname, mapsig): + ofnopts = get_outfn_options(fname, mapsig) + rfnopts = get_remapfn_options(fname, mapsig) + if ofnopts is not None: + #print ("gen_aten_out:", fname) + code = generate_aten_out(ctx, tree, rwxtree, fname, sig, rwsig, params, + ofnopts) + elif rfnopts is not None: + #print ("gen_aten_remap", fname) + code = generate_aten_remap(ctx, fname, sig, params, rfnopts) + else: + code = generate_aten_to_mlir(ctx, tree, rwxtree, fname, sig, rwsig, + params, fnopts) + else: + code = None + return FuncGen(tree=tree, + xtree=xtree, + rwxtree=rwxtree, + func=fname, + xfunc=xfname, + code=code, + sig=fndef.cpp_sig, + rwsig=rwsig, + cppsig=sig, + mapsig=mapsig, + funsig=create_stdfunc_sig(rwxtree, rwsig), + aten_sig=fndef.aten_sig) + + +def is_tensor_api(fndef): + fndef = fndef.replace('at::', '') + fndef = fndef.replace('c10::Device', 'Device') + m = re.search(r'\bTensor\b', fndef) + return m is not None, fndef + + +def extract_functions(path): + functions = [] + errors = [] + for line in open(path, 'r'): + m = re.match(r'\s*([^\s].*); //\s+(.*)', line) + if not m: + continue + fndef = m.group(1) + try: + _XPARSER.parse(fndef) + functions.append(FuncDef(cpp_sig=fndef, aten_sig=m.group(2))) + except Exception as e: + if is_tensor_api(fndef)[0]: + errors.append((fndef, str(e))) + print('Error parsing "{}": {}'.format(fndef, e), file=sys.stderr) + return functions, errors + + +def get_mapsig_key(mapsig): + # PyTorch generates std::tuple<> without space among the tuple types, + # which would require special understanding in the string rewriter. + # Since we are using this as simple key, we can just string the spaces. + return mapsig.replace(' ', '') + + +def parse_local_overrides(path): + functions = [] + fndef = None + for line in open(path, 'r'): + line = line.strip() + if not fndef: + m = re.match(r'static\s+(.*);', line) + if m: + functions.append(m.group(1)) + continue + m = re.match(r'static\s+(.*)', line) + if m: + fndef = m.group(1) + else: + fndef = '{} {}'.format(fndef, line) + if fndef.endswith(';'): + functions.append(fndef[:-1]) + fndef = None + assert fndef is None + + overrides = {} + for fndef in functions: + # Discard static XLA type functions which are not ATEN. 
+ is_tensor, fndef = is_tensor_api(fndef) + if is_tensor: + xtree = _XPARSER.parse(fndef) + mapsig_key = get_mapsig_key(create_map_sig(xtree, fndef)) + overrides[mapsig_key] = fndef + return overrides + + +def get_dialect_name(func): + name = '' + upper = True + cs = list(func) + for c in cs: + if c == '_': + upper = True + elif upper: + name += str(c).upper() + upper = False + else: + name += c + if cs[-1] == "_": + name += "Under" + return name + + +def generate_td_functions(fgens, overrides): + code = '' + overridden = set() + + code += "#ifdef ATEN_OP_DEFS\n" + code += "#else\n" + code += "#define ATEN_OP_DEFS\n\n" + + for fgen in fgens: + mapsig_key = get_mapsig_key(fgen.mapsig) + if mapsig_key in overrides: + overridden.add(mapsig_key) + if fgen.func in _TD_BLACKLIST: + continue + + rtype = fgen.tree.children[0] + num_outputs = 1 + if type_core(rtype) == 'std::tuple': + num_outputs = len(tuple_type_list(rtype)) + #print(num_outputs, rtype) + + dialect_name = get_dialect_name(fgen.func) + #print ('"{}"'.format(dialect_name)) + code += 'def aten_{}Op: aten_Op<"{}"'.format(dialect_name, fgen.func) + code += ', [NoSideEffect' + if not fgen.func in _TD_NO_OPSTATS_LIST: + code += ', StatisticsOpInterface' + code += ']>,\n' + code += ' Results<(outs' + # foreach output + # rparams = get_rparameters(fgen.tree) + # for p in rparams: + # pname = param_name(p) + # ptype = param_type(p) + # cptype = type_core(ptype) + # print(pname) + code += ' AnyTensor' + for i in range(num_outputs - 1): + code += ', AnyTensor' + code += ')> {\n' + code += ' let arguments = (\n' + params = get_parameters(fgen.tree) + for p in params: + pname = param_name(p) + ptype = param_type(p) + cptype = type_core(ptype) + if (cptype == 'Tensor'): + td_type = "AnyTensor" + elif (cptype == 'Scalar' or cptype == 'int64_t' or cptype == 'double' or + cptype == 'bool'): + td_type = "AnyScalar" + elif (cptype == 'c10::optional' or cptype == 'std::array'): + continue + elif (cptype == 'IntArrayRef'): + td_type = "AnyType" + else: + print('unhandled type', cptype) + td_type = "AnyType" + if p == params[0]: + code += ' ins {}:${}'.format(td_type, pname) + else: + code += ',\n {}:${}'.format(td_type, pname) + code += '\n );\n' + code += ' let summary = "aten {} operator";\n'.format(fgen.func) + code += ' let description = [{\n' + code += ' {}Op\n'.format(dialect_name) + code += ' aten {} operator\n'.format(fgen.func) + code += ' }];\n' + if not fgen.func in _TD_NO_OPSTATS_LIST: + code += ' let extraClassDeclaration = [{\n' + code += ' std::map getStatistics();\n' + code += ' }];\n' + code += '}\n\n' + + code += "#endif\n" + return code, overridden + + +def generate_registrations(fgens, overrides): + code = 'void RegisterAtenTypeFunctions() {\n' + code += ' static auto dispatch = torch::RegisterOperators()\n' + overridden = set() + for fgen in fgens: + mapsig_key = get_mapsig_key(fgen.mapsig) + if mapsig_key in overrides: + override_fn = 'ATenMLIRType::{}'.format(fgen.func) + overridden.add(mapsig_key) + else: + override_fn = fgen.xfunc if fgen.code else None + if override_fn: + code += ( + ' .op(torch::RegisterOperators::options().schema("{}")\n ' + '.impl_unboxedOnlyKernel<{}, &{}>(at::TensorTypeId::XLATensorId)\n' + ' .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))\n'.format( + fgen.aten_sig, fgen.funsig, override_fn, override_fn, + fgen.aten_sig)) + return code + ';\n}\n', overridden + + +def generate_functions(fgens): + code = '' + for fgen in fgens: + if fgen.code: + code += '{}\n\n'.format(fgen.code) + return code + + 
+def generate_class_functions(fgens): + code = '' + for fgen in fgens: + if fgen.code: + code += ' static {};\n'.format(fgen.rwsig) + return code + + +def gen_output_file(args, name): + if not args.output_folder: + return sys.stdout + return open(os.path.join(args.output_folder, name), 'w') + + +def gen_h_output_file(args): + return gen_output_file(args, 'aten_mlir_type_default.h') + + +def gen_cpp_output_file(args): + return gen_output_file(args, 'aten_mlir_type_default.cpp') + + +def gen_td_output_file(args): + return gen_output_file(args, 'ATenOps.td') + + +def check_overrides(overrides, overridden): + misses = 0 + for mapsig, cpp_sig in overrides.items(): + mapsig_key = get_mapsig_key(mapsig) + if not mapsig_key in overridden: + misses += 1 + print('ATenMLIRType function missed override: {}; // {}'.format( + cpp_sig, mapsig), + file=sys.stderr) + return misses == 0 + + +def generate(args): + fndefs, errors = extract_functions(args.typedef) + print('Extracted {} functions ({} errors) from {}'.format( + len(fndefs), len(errors), args.typedef), + file=sys.stderr) + assert len(errors) == 0 + + overrides = parse_local_overrides(args.overridetype) + print('{} function overrides in {}'.format(len(overrides), args.overridetype), + file=sys.stderr) + + fgens = [] + ctx = Context(args.functions) + for ts in fndefs: + try: + fgen = get_mlir_wrapper(ts, ctx) + if fgen: + fgens.append(fgen) + except Exception as e: + print('Failed to generate wrapper for {}: {}'.format(ts, e), + file=sys.stderr) + print('Generated {} wrappers for {}'.format(len(fgens), args.typedef), + file=sys.stderr) + + functions = generate_functions(fgens) + hfunctions = generate_class_functions(fgens) + + tdfunctions, overridden = generate_td_functions(fgens, overrides) + assert check_overrides(overrides, overridden) + #print(tdfunctions) + + regs, overridden = generate_registrations(fgens, overrides) + #print (len(overrides), len(overridden)) + assert check_overrides(overrides, overridden) + # Create output files ... + print(_H_HEADER.format(gen=os.path.basename(sys.argv[0]), hfuncs=hfunctions), + file=gen_h_output_file(args)) + print(_CPP_HEADER.format(gen=os.path.basename(sys.argv[0]), + funcs=functions, + regs=regs), + file=gen_cpp_output_file(args)) + + with gen_td_output_file(args) as f: + f.write(tdfunctions) + + +if __name__ == '__main__': + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument('--output_folder', type=str) + arg_parser.add_argument('overridetype', + type=str, + metavar='OVERRIDE_TYPE_FILE', + help='The path to the overrides file') + arg_parser.add_argument('typedef', + type=str, + metavar='TYPE_DEFAULT_FILE', + help='The path to the TypeDefault.h file') + arg_parser.add_argument('functions', + type=str, + metavar='FUNCTIONS_FILE', + help='The path to the Functions.h file') + args, files = arg_parser.parse_known_args() + generate(args) diff --git a/python/npcomp/frontends/__init__.py b/python/npcomp/frontends/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/npcomp/frontends/pytorch/__init__.py b/python/npcomp/frontends/pytorch/__init__.py new file mode 100644 index 000000000..f60a7a72b --- /dev/null +++ b/python/npcomp/frontends/pytorch/__init__.py @@ -0,0 +1,45 @@ +# -*- Python -*- +# This file is licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch +import _torch_mlir +from _torch_mlir import _get_mlir +from _torch_mlir import _op_report +from _torch_mlir import _liveness_report +from _torch_mlir import set_debug +from _torch_mlir import lower_to_std + +import json + +_torch_mlir._initialize_aten_bindings() +_torch_mlir.set_debug(False, "") + + +def get_mlir(t): + if not isinstance(t, list): + t = [t] + return _get_mlir(t) + + +def op_report(mlir): + return json.loads(_op_report(mlir)) + + +def liveness_report(mlir): + return json.loads(_liveness_report(mlir)) + + +def get_mlir_supported_devices(devkind=None): + # TODO: define our own device and stop hijacking the xla device. + return ["xla:0"] + + +def mlir_device(devkind=None): + devices = get_mlir_supported_devices(devkind=devkind) + device = devices[0] + return torch.device(device) + + +__all__ = ['get_mlir', 'mlir_device', 'op_report', 'liveness_report'] diff --git a/python/npcomp/frontends/pytorch/core/__init__.py b/python/npcomp/frontends/pytorch/core/__init__.py new file mode 100644 index 000000000..a03d4c390 --- /dev/null +++ b/python/npcomp/frontends/pytorch/core/__init__.py @@ -0,0 +1,4 @@ +# -*- Python -*- +# This file is licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/python/npcomp/frontends/pytorch/core/aten_mlir_model.py b/python/npcomp/frontends/pytorch/core/aten_mlir_model.py new file mode 100644 index 000000000..fad270c11 --- /dev/null +++ b/python/npcomp/frontends/pytorch/core/aten_mlir_model.py @@ -0,0 +1,7 @@ +# -*- Python -*- +# This file is licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch +from npcomp.frontends.pytorch import * diff --git a/python/npcomp/frontends/pytorch/test/__init__.py b/python/npcomp/frontends/pytorch/test/__init__.py new file mode 100644 index 000000000..4e3a2c2a9 --- /dev/null +++ b/python/npcomp/frontends/pytorch/test/__init__.py @@ -0,0 +1,6 @@ +# -*- Python -*- +# This file is licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .test_infrastructure import * diff --git a/python/npcomp/frontends/pytorch/test/test_infrastructure.py b/python/npcomp/frontends/pytorch/test/test_infrastructure.py new file mode 100644 index 000000000..472aa7413 --- /dev/null +++ b/python/npcomp/frontends/pytorch/test/test_infrastructure.py @@ -0,0 +1,52 @@ +# -*- Python -*- +# This file is licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import npcomp.frontends.pytorch as torch_mlir +import copy + + +def compare(a, b, test): + print("Computing:" + test) + err = (a.to('cpu') - b.to('cpu')).abs().max() + if (err <= 1e-5): + print("PASS! " + test + " check") + else: + print("FAILED " + test + " check") + + +def compare_eq(a, b, test): + print("Computing:" + test) + if (a == b): + print("PASS! 
" + test + " check") + else: + print("FAILED " + test + " check") + + +def check_fwd(model, tensor): + device = torch_mlir.mlir_device() + result = model(tensor) + device_model = copy.deepcopy(model).to(device) + device_tensor = tensor.clone().to(device) + device_result = device_model(device_tensor) + + compare(result, device_result, "fwd") + return (device_model, device_result, result) + + +def check_ref(model, tensor): + return check_fwd(model, tensor) + + +def check_back(fwd_path, target, lossmodel): + device = torch_mlir.mlir_device() + (device_model, device_result, result) = fwd_path + device_target = target.clone().to(device) + ref_loss = lossmodel(result, target) + ref_loss.backward() + device_loss = lossmodel(device_result, device_target) + device_loss.backward() + + compare(ref_loss, device_loss, "back") + return (device_model, device_result)