torch-mlir/projects/ltc/csrc/base_lazy_backend/mlir_node.h

//===- mlir_node.h --------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_node.h
//===----------------------------------------------------------------------===//

#pragma once

#include <ATen/core/interned_strings.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/shape.h>

#include "mlir_lowering_context.h"
#include "utils/debug.h"
#include "utils/exception.h"
namespace torch {
namespace lazy {
class TORCH_API TorchMlirNode : public torch::lazy::Node {
2022-03-24 22:15:43 +08:00
public:
TorchMlirNode(OpKind op, OpList operands, std::vector<Shape> &&shapes,
size_t num_outputs, hash_t hash_seed = kHashSeed);
TorchMlirNode(OpKind op, OpList operands,
const std::function<Shape()> &shape_fn, size_t num_outputs,
hash_t hash_seed = kHashSeed);
TorchMlirNode(OpKind op, OpList operands, size_t num_outputs,
hash_t hash_seed = kHashSeed);
TorchMlirNode(OpKind op, Shape shape, size_t num_outputs,
hash_t hash_seed = kHashSeed);
// Adds a static hook that is run after every TorchMlirNode is constructed.
static void addConstructorHook(std::function<void(TorchMlirNode *)>);
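//
// Illustrative use (hypothetical logging hook, not part of this header):
//
//   TorchMlirNode::addConstructorHook([](TorchMlirNode *node) {
//     std::cout << "constructed: " << node->ToString() << std::endl;
//   });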
~TorchMlirNode() override = default;
hash_t hash() const override;
hash_t shapeHash() const override;

// Returns the operand node at the given index, downcast to a TorchMlirNode.
TorchMlirNode *mlir_node(int index) const;
// Lowers this node into JIT IR appended to `function`; subclasses override
// this to emit the ops they represent.
virtual TorchMlirOpVector Lower(TorchMlirFunction function,
                                TorchMlirLoweringContext *loctx) const;
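//
// A schematic override (illustrative; modeled on torch-mlir's generated IR
// classes, with the LowerTorchMlirBuiltin call shown approximately):
//
//   TorchMlirOpVector MyOp::Lower(TorchMlirFunction function,
//                                 TorchMlirLoweringContext *loctx) const {
//     std::vector<torch::jit::NamedValue> arguments;
//     arguments.emplace_back(loctx->GetOutputOp(operand(0)));
//     return LowerTorchMlirBuiltin(function, op().op, shapes(), arguments);
//   }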
private:
// The hash of the DAG WITH size info. Used for shape caching.
hash_t shape_hash_;

// The hash of the DAG used to look up the compiled graph by hash. If dynamic
// shapes are enabled, this is the DAG hash WITHOUT size info; otherwise it is
// the DAG hash WITH size info.
hash_t dag_hash_;
};

// TensorList represents an at::TensorList, which is a vector of Tensors but
// is also a first-class IValue and can be fed as a single input to a TS
// program. It is much easier to handle TensorLists in Lazy Tensor code if
// they are represented as a single Node, so that more than one TensorList and
// more than one Tensor can sit side-by-side as operands to an op.
//
// Note: shape is undefined for TensorList. We assert in some places that
// #shapes matches #outputs; this assertion stems from the fact that currently
// all IR nodes represent tensors (there is no type system for this IR).
// Because of this, TensorList is a bit of a hack.
//
// TODO(whc) once Shape() API is moved to Node base, also make it virtual, and
// then implement it as NotImplemented for TensorList, also fixing the assertion
// that would fail.
struct TORCH_API TorchMlirTensorList : public TorchMlirNode {
static OpKind ClassOpKind();
TorchMlirTensorList() = delete;
TorchMlirTensorList(OpList values);
torch::lazy::TorchMlirOpVector
Lower(TorchMlirFunction function,
TorchMlirLoweringContext *loctx) const override;
};
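
// Illustrative use (hypothetical `values` vector): bundling several lazy IR
// values into a single TensorList operand via torch::lazy::MakeNode:
//
//   std::vector<torch::lazy::Value> values = /* tensors to bundle */;
//   torch::lazy::NodePtr list =
//       torch::lazy::MakeNode<TorchMlirTensorList>(values);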

// TorchMlirOptionalTensorList is similar to TorchMlirTensorList, but it can
// also represent optional tensors, so the output type for this op is
// !torch.list<optional<vtensor>>.
struct TORCH_API TorchMlirOptionalTensorList : public TorchMlirNode {
static OpKind ClassOpKind();
TorchMlirOptionalTensorList() = delete;
TorchMlirOptionalTensorList(OpList values);
torch::lazy::TorchMlirOpVector
Lower(TorchMlirFunction function,
TorchMlirLoweringContext *loctx) const override;
};
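
// Illustrative use (hypothetical): an op whose schema takes `Tensor?[]`
// (for example aten::index.Tensor's `indices` argument) can bundle its
// optional operands the same way:
//
//   torch::lazy::NodePtr list =
//       torch::lazy::MakeNode<TorchMlirOptionalTensorList>(values);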
} // namespace lazy
} // namespace torch