torch-mlir/projects/ltc/csrc/base_lazy_backend/backend_impl.cpp

//===- backend_impl.cpp ---------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp
//===----------------------------------------------------------------------===//
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/shape.h>

#include "backend_impl.h"
#include "ir_builder.h"
#include "mlir_lowering_context.h"
#include "ops/device_data.h"
#include "utils/debug.h"
#include "utils/exception.h"

namespace torch {
namespace lazy {
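
// TorchMlirBackendData associates a (device, shape) pair with a shared Info
// payload. Depending on the constructor used, Info may be empty (a
// placeholder), or may capture the at::Scalar or at::Tensor the data was
// created from.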
TorchMlirBackendData::TorchMlirBackendData(BackendDevice device, Shape shape)
    : BackendData(device, shape),
      info_(std::make_shared<TorchMlirBackendData::Info>()) {
  PRINT_FUNCTION();
}

TorchMlirBackendData::TorchMlirBackendData(
    BackendDevice device, Shape shape, std::shared_ptr<BackendData::Info> info)
    : BackendData(device, shape), info_(info) {
  PRINT_FUNCTION();
}

TorchMlirBackendData::TorchMlirBackendData(
    const at::Scalar& scalar, BackendDevice device)
    : BackendData(device, Shape(scalar.type(), {})),
      info_(std::make_shared<TorchMlirBackendData::Info>(scalar)) {
  PRINT_FUNCTION();
}

TorchMlirBackendData::TorchMlirBackendData(
    const at::Tensor& tensor, BackendDevice device, Shape shape)
    : BackendData(device, shape),
      info_(std::make_shared<TorchMlirBackendData::Info>(tensor)) {
  PRINT_FUNCTION();
}
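
// The handle is simply this object's address reinterpreted as an integer,
// so it is unique for as long as the object is alive.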
BackendData::Handle TorchMlirBackendData::GetHandle() {
  return reinterpret_cast<int64_t>(this);
}
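
// Assign does not copy any storage; it shares the other data's Info payload
// with this object.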
void TorchMlirBackendData::Assign(const BackendData& data) {
  const TorchMlirBackendData* torch_mlir_data =
      dynamic_cast<const TorchMlirBackendData*>(&data);
  TORCH_CHECK(
      torch_mlir_data,
      "Invalid Backend Data Pointer. Expected TorchMlirBackendData.");

  info_ = torch_mlir_data->info_;
}

bool TorchMlirBackendData::HasValue() const { return bool(info_); }

BackendData::Info* TorchMlirBackendData::mlir_info() const {
  return info_.get();
}

/**
 * Initialization/Teardown
 * */
void TorchMlirBackendImpl::PrepareToExit() const {}

/**
 * IR Tracing
 * */
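// A single TorchMlirIrBuilder is shared process-wide: the function-local
// static below is constructed on first use and intentionally never freed.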
const IrBuilder* TorchMlirBackendImpl::GetIrBuilder() const {
  static const IrBuilder* builder = new TorchMlirIrBuilder();
  return builder;
}

/**
 * Data Transfer
 * */
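// Note that none of these factories performs an actual device transfer; they
// only wrap the host-side value (if any) in a TorchMlirBackendData, where it
// lives in the Info payload until the graph is compiled and run.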
BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromTensor(
    const at::Tensor& tensor, const Shape& shape,
    const BackendDevice& device) const {
  PRINT_FUNCTION();
  return std::make_shared<TorchMlirBackendData>(tensor, device, shape);
}

BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromScalar(
    const at::Scalar& scalar, const BackendDevice& device) const {
  PRINT_FUNCTION();
  return std::make_shared<TorchMlirBackendData>(scalar, device);
}

BackendDataPtr TorchMlirBackendImpl::CreateDataPlaceholder(
    const BackendDevice& device, const Shape& shape) const {
  PRINT_FUNCTION();
  return std::make_shared<TorchMlirBackendData>(device, shape);
}
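
// If the IR node is a DeviceData node, return the BackendData it carries;
// otherwise the node has no backing data and nullptr is returned.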
BackendDataPtr
TorchMlirBackendImpl::GetComputationDataFromNode(const Node* node) const {
  PRINT_FUNCTION();
  const auto* device_data_node = dynamic_cast<const DeviceData*>(node);
  if (!device_data_node) {
    return nullptr;
  }

  return device_data_node->data();
}
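
// Recovers the eager at::Tensor stashed in the data's Info payload. The
// checks below enforce that the data was created from a tensor.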
at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
    const BackendDataPtr data,
    c10::optional<at::ScalarType> logical_scalar_type) const {
  PRINT_FUNCTION();

  TorchMlirBackendData* torch_mlir_data =
      dynamic_cast<TorchMlirBackendData*>(data.get());
  TORCH_CHECK(
      torch_mlir_data,
      "Invalid Backend Data Pointer. Expected TorchMlirBackendData.");

  TorchMlirBackendData::Info* info =
      dynamic_cast<TorchMlirBackendData::Info*>(torch_mlir_data->mlir_info());
  TORCH_CHECK(
      info,
      "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");

  return info->tensor;
}
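
// Illustrative round trip (variable names are hypothetical): wrapping a host
// tensor and immediately unwrapping it yields back the same tensor:
//   BackendDataPtr d = impl.MakeComputationDataFromTensor(t, shape, dev);
//   at::Tensor t2 = impl.MakeTensorFromComputationData(d, c10::nullopt);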
/**
* Lowering, Compilation, Execution
* */
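// Both overloads build a TorchMlirLoweringContext, which manages the MLIR
// module that the traced graph (the post-order node list, when given) is
// lowered into.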
std::unique_ptr<LoweringContext> TorchMlirBackendImpl::CreateLoweringContext(
    const std::string& name, BackendDevice device,
    c10::ArrayRef<const Node*> post_order,
    Util::EmissionMap emit_status) const {
  PRINT_FUNCTION();
  return std::make_unique<TorchMlirLoweringContext>(
      name, std::move(device), std::move(post_order), std::move(emit_status));
}

std::unique_ptr<LoweringContext> TorchMlirBackendImpl::CreateLoweringContext(
    const std::string& name, BackendDevice device) const {
  PRINT_FUNCTION();
  return std::make_unique<TorchMlirLoweringContext>(name, std::move(device));
}

/**
 * Device Configuration
 * */
// Set or get the default device type.
// For backends used with virtual c10:: devices, this configures what real
// device type the backend should use, and matters if the backend supports
// more than one type of real device.

// Specify which ATen device should be used for eager fallback; this may
// change depending on the current 'Default' DeviceType.
at::DeviceType TorchMlirBackendImpl::EagerFallbackDeviceType() const {
  PRINT_FUNCTION();
  return at::DeviceType::CPU;
}

// Query all available backend devices.
std::vector<BackendDevice> TorchMlirBackendImpl::GetBackendDevices() const {
  PRINT_FUNCTION();
  return {
      GetBackendDevice(c10::Device(c10::kLazy, 0)),
      GetBackendDevice(c10::Device(c10::kCPU, 0))};
}

// Map a particular c10:: device to a concrete backend device.
// Note: c10:: devices may be virtual or concrete. xla:: and lazy:: are
// virtual devices, meaning they may map to a gpu, tpu, etc. behind the
// scenes. In the future, non-virtual c10:: devices may also use lazy tensors
// through a mode, in which case these APIs should still work, but should be
// identity mappings.
BackendDevice TorchMlirBackendImpl::GetBackendDevice(c10::Device device) const {
  PRINT_FUNCTION();
  return BackendDevice(GetDefaultDeviceType(), device.index());
}
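
// Accessors for the default device ordinal, i.e. the device index assumed
// when none is given explicitly.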
int64_t TorchMlirBackendImpl::GetDefaultDeviceOrdinal() const {
  return default_device_ordinal;
}

void TorchMlirBackendImpl::SetDefaultDeviceOrdinal(int64_t ordinal) {
  default_device_ordinal = ordinal;
}
} // namespace lazy
} // namespace torch