// Mirrored from https://github.com/llvm/torch-mlir
//===- backend_impl.cpp ---------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp
//===----------------------------------------------------------------------===//

#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/shape.h>

#include "backend_impl.h"
#include "ir_builder.h"
#include "mlir_lowering_context.h"
#include "ops/device_data.h"
#include "utils/debug.h"
#include "utils/exception.h"

namespace torch {
namespace lazy {

// Construct backend data for `device`/`shape` with a fresh, empty Info
// payload (no tensor or scalar attached yet); used for placeholder data.
TorchMlirBackendData::TorchMlirBackendData(BackendDevice device, Shape shape)
    : BackendData(device, shape),
      info_(std::make_shared<TorchMlirBackendData::Info>()) {
  PRINT_FUNCTION();
}
// Construct backend data that adopts an externally-provided Info payload
// (shared ownership; the payload is not copied).
TorchMlirBackendData::TorchMlirBackendData(
    BackendDevice device, Shape shape, std::shared_ptr<BackendData::Info> info)
    : BackendData(device, shape), info_(info) {
  PRINT_FUNCTION();
}
// Wrap a host-side scalar. The shape is built from the scalar's dtype with
// an empty dimension list (rank 0), and the scalar value is kept in Info.
TorchMlirBackendData::TorchMlirBackendData(const at::Scalar &scalar,
                                           BackendDevice device)
    : BackendData(device, Shape(scalar.type(), {})),
      info_(std::make_shared<TorchMlirBackendData::Info>(scalar)) {
  PRINT_FUNCTION();
}
// Wrap a host-side tensor together with its lazy shape; the tensor itself
// is retained inside the Info payload.
TorchMlirBackendData::TorchMlirBackendData(const at::Tensor &tensor,
                                           BackendDevice device, Shape shape)
    : BackendData(device, shape),
      info_(std::make_shared<TorchMlirBackendData::Info>(tensor)) {
  PRINT_FUNCTION();
}
// The opaque handle for this data is simply the object's own address,
// reinterpreted as the integer Handle type.
BackendData::Handle TorchMlirBackendData::GetHandle() {
  const auto self_address = reinterpret_cast<int64_t>(this);
  return self_address;
}
// Take over the Info payload of another TorchMlirBackendData. Aborts via
// TORCH_CHECK when `data` is not actually a TorchMlirBackendData.
void TorchMlirBackendData::Assign(const BackendData &data) {
  const auto *other = dynamic_cast<const TorchMlirBackendData *>(&data);
  TORCH_CHECK(other,
              "Invalid Backend Data Pointer. Expected TorchMlirBackendData.");
  info_ = other->info_;
}
// Data is considered "present" whenever an Info payload is attached.
bool TorchMlirBackendData::HasValue() const { return info_ != nullptr; }
// Expose a non-owning view of the Info payload; lifetime stays managed by
// the shared_ptr member `info_`.
BackendData::Info *TorchMlirBackendData::mlir_info() const {
  auto *raw_info = info_.get();
  return raw_info;
}
/**
 * Initialization/Teardown
 * */
// No backend-specific teardown work is needed before process exit.
void TorchMlirBackendImpl::PrepareToExit() const {}
/**
|
|
* IR Tracing
|
|
* */
|
|
|
|
const IrBuilder *TorchMlirBackendImpl::GetIrBuilder() const {
|
|
static const IrBuilder *builder = new TorchMlirIrBuilder();
|
|
return builder;
|
|
}
|
|
|
|
/**
|
|
* Data Transfer
|
|
* */
|
|
|
|
BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromTensor(
|
|
const at::Tensor &tensor, const Shape &shape,
|
|
const BackendDevice &device) const {
|
|
PRINT_FUNCTION();
|
|
return std::make_shared<TorchMlirBackendData>(tensor, device, shape);
|
|
}
|
|
|
|
// Package a host scalar as backend data for `device`.
BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromScalar(
    const at::Scalar &scalar, const BackendDevice &device) const {
  PRINT_FUNCTION();
  auto backend_data = std::make_shared<TorchMlirBackendData>(scalar, device);
  return backend_data;
}
// Create backend data with no payload attached yet — a placeholder that can
// later be filled (e.g. via Assign).
BackendDataPtr
TorchMlirBackendImpl::CreateDataPlaceholder(const BackendDevice &device,
                                            const Shape &shape) const {
  PRINT_FUNCTION();
  auto placeholder = std::make_shared<TorchMlirBackendData>(device, shape);
  return placeholder;
}
// Extract the backend data carried by an IR node. Only DeviceData nodes
// carry data; any other node kind yields nullptr.
BackendDataPtr
TorchMlirBackendImpl::GetComputationDataFromNode(const Node *node) const {
  PRINT_FUNCTION();
  if (const auto *device_data = dynamic_cast<const DeviceData *>(node)) {
    return device_data->data();
  }
  return nullptr;
}
// Recover the host tensor stored inside backend data. Aborts via TORCH_CHECK
// if `data` is not TorchMlirBackendData or its payload is not the expected
// Info type. `logical_scalar_type` is currently unused here.
at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
    const BackendDataPtr data,
    std::optional<at::ScalarType> logical_scalar_type) const {
  PRINT_FUNCTION();

  auto *backend_data = dynamic_cast<TorchMlirBackendData *>(data.get());
  TORCH_CHECK(backend_data,
              "Invalid Backend Data Pointer. Expected TorchMlirBackendData.");

  auto *payload =
      dynamic_cast<TorchMlirBackendData::Info *>(backend_data->mlir_info());
  TORCH_CHECK(
      payload,
      "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");

  return payload->tensor;
}
/**
|
|
* Lowering, Compilation, Execution
|
|
* */
|
|
|
|
std::unique_ptr<LoweringContext> TorchMlirBackendImpl::CreateLoweringContext(
|
|
const std::string &name, BackendDevice device,
|
|
c10::ArrayRef<const Node *> post_order,
|
|
Util::EmissionMap emit_status) const {
|
|
PRINT_FUNCTION();
|
|
return std::make_unique<TorchMlirLoweringContext>(
|
|
name, std::forward<BackendDevice>(device),
|
|
std::forward<c10::ArrayRef<const Node *>>(post_order),
|
|
std::forward<Util::EmissionMap>(emit_status));
|
|
}
|
|
|
|
std::unique_ptr<LoweringContext>
|
|
TorchMlirBackendImpl::CreateLoweringContext(const std::string &name,
|
|
BackendDevice device) const {
|
|
PRINT_FUNCTION();
|
|
return std::make_unique<TorchMlirLoweringContext>(
|
|
name, std::forward<BackendDevice>(device));
|
|
}
|
|
|
|
/**
|
|
* Device Configuration
|
|
* */
|
|
|
|
// Set or get the default device type.
|
|
// For backends used with virtual c10:: Devices, this configures what real
|
|
// device type the backend should use, and matters if the backend supports
|
|
// more than one type of real device.
|
|
|
|
// Specify which aten device should be used for eager fallback
|
|
// may change depending on current 'Default' DeviceType
|
|
at::DeviceType TorchMlirBackendImpl::EagerFallbackDeviceType() const {
|
|
PRINT_FUNCTION();
|
|
return at::DeviceType::CPU;
|
|
}
|
|
|
|
// Query all available backend devices
|
|
std::vector<BackendDevice> TorchMlirBackendImpl::GetBackendDevices() const {
|
|
PRINT_FUNCTION();
|
|
return {GetBackendDevice(c10::Device(c10::kLazy, 0)),
|
|
GetBackendDevice(c10::Device(c10::kCPU, 0))};
|
|
}
|
|
|
|
// Map a particular c10:: device to a concrete backend device
|
|
// Note:: c10:: devices may be virtual or concrete. xla:: and lazy:: are
|
|
// virtual devices, meaning they may map to a gpu, tpu, etc. behind the
|
|
// scenes. In the future, non-virtual c10:: devices may also use lazy tensors
|
|
// through a mode, in which case these APIs should still work, but should be
|
|
// identity mappings.
|
|
BackendDevice TorchMlirBackendImpl::GetBackendDevice(c10::Device device) const {
|
|
PRINT_FUNCTION();
|
|
return BackendDevice(GetDefaultDeviceType(), device.index());
|
|
}
|
|
|
|
// Returns the ordinal used when callers do not specify one explicitly.
int64_t TorchMlirBackendImpl::GetDefaultDeviceOrdinal() const {
  return default_device_ordinal;
}
// Overrides the ordinal reported by GetDefaultDeviceOrdinal().
void TorchMlirBackendImpl::SetDefaultDeviceOrdinal(int64_t ordinal) {
  default_device_ordinal = ordinal;
}

} // namespace lazy
} // namespace torch