Use JIT GraphExecutor for execution in example backend (#830)

* Update LazyShapeInference header

* Use JIT GraphExecutor for execution in example backend
Branch: pull/1125/head
Author: Henry Tu
Date: 2022-05-05 08:57:03 -04:00
Committed by: Henry Tu
Parent: 1bde00c73d
Commit: 406d1e7538
4 changed files with 42 additions and 23 deletions


@@ -76,25 +76,43 @@ public:
     // `arguments` maps 1:1 with the parameters in the generated MLIR. In this
     // function, we will generate a list of BackendData that corresponds to the
     // return values in the MLIR.
-    std::vector<torch::lazy::BackendDataPtr> results;
-    // "Borrow" some tensor data from arguments to reuse in return. This ensures
-    // that the tensor device is correctly configured.
-    TORCH_CHECK(arguments.size() > 0,
-                "Need at least one argument for example execution.");
-    const TorchMlirBackendData *torch_mlir_data =
-        dynamic_cast<const TorchMlirBackendData *>(arguments[0].get());
-    TORCH_CHECK(torch_mlir_data,
-                "Invalid Backend Data Pointer. Expected TorchMlirBackendData.");
-    // For this demo we aren't performing a legitimate execution, so we generate
-    // some dummy data to return based on the expected number of return values.
     auto mlir_computation = static_cast<TorchMlirComputation *>(&computation);
-    for (unsigned i = 0; i < mlir_computation->num_results(); i++) {
-      results.push_back(std::make_shared<TorchMlirBackendData>(
-          torch_mlir_data->mlir_info()->tensor, device,
-          torch_mlir_data->shape()));
+    // Vendor backend specific execution can be inserted here.
+    //
+    // We don't have a way to execute a computation based on the generated MLIR,
+    // so we'll fallback to the implementation used by the TS LTC backend.
+    //
+    // JIT Execution adopted from:
+    // https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp
+    torch::jit::GraphExecutor graph_executor(mlir_computation->graph(), "");
+    std::vector<torch::jit::IValue> stack;
+    for (const auto &argument : arguments) {
+      const auto mlir_data =
+          std::static_pointer_cast<TorchMlirBackendData>(argument);
+      if (mlir_data->mlir_info()->scalar.has_value()) {
+        stack.emplace_back(mlir_data->mlir_info()->scalar.value());
+      } else {
+        at::Tensor tensor = mlir_data->mlir_info()->tensor;
+        stack.emplace_back(tensor);
+      }
+    }
+    graph_executor.run(stack);
+    std::vector<torch::lazy::BackendDataPtr> results;
+    for (torch::jit::IValue component : stack) {
+      at::Tensor result = component.toTensor();
+      at::IntArrayRef result_sizes = result.sizes();
+      torch::lazy::Shape shape(
+          result.scalar_type(),
+          std::vector<int64_t>(result_sizes.begin(), result_sizes.end()));
+      results.push_back(
+          std::make_shared<TorchMlirBackendData>(result, device, shape));
     }
     std::cout << "Received " << arguments.size() << " arguments, and returned "
               << results.size() << " results during ExecuteCompile!"
               << std::endl;
     return results;
   }
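
The hunk above is the substance of the change: rather than fabricating dummy outputs, ExecuteComputation now feeds its arguments to a torch::jit::GraphExecutor built from the computation's JIT graph and turns whatever run() leaves on the stack back into BackendData. The standalone sketch below is not part of this commit; the tiny graph, the use of torch::jit::parseIR, and the input values are illustrative only, but it shows the same push-inputs, run, pop-outputs pattern in isolation:

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/torch.h>

#include <iostream>
#include <memory>
#include <vector>

int main() {
  // Stand-in for the JIT graph carried by TorchMlirComputation: out = a + b.
  auto graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(
      R"IR(
        graph(%a : Tensor, %b : Tensor):
          %alpha : int = prim::Constant[value=1]()
          %out : Tensor = aten::add(%a, %b, %alpha)
          return (%out)
      )IR",
      graph.get());

  torch::jit::GraphExecutor executor(graph, /*function_name=*/"");

  // Inputs are pushed in parameter order, mirroring the loop over `arguments`
  // in ExecuteComputation.
  std::vector<torch::jit::IValue> stack;
  stack.emplace_back(torch::ones({2, 3}));
  stack.emplace_back(torch::ones({2, 3}));

  // run() consumes the inputs and leaves the graph's outputs on the stack.
  executor.run(stack);
  std::cout << stack[0].toTensor() << std::endl; // 2x3 tensor filled with 2
  return 0;
}

Because run() reuses the stack for outputs, the backend can simply iterate over it afterwards to build the result list, which is exactly what the second loop in the hunk does.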


@@ -41,6 +41,8 @@ TORCH_API std::vector<Shape> compute_shape_constant_pad_nd(const at::Tensor & se
 TORCH_API std::vector<Shape> compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
 TORCH_API std::vector<Shape> compute_shape_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
 TORCH_API std::vector<Shape> compute_shape_conv2d(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups);
+TORCH_API std::vector<Shape> compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
+TORCH_API std::vector<Shape> compute_shape_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
 TORCH_API std::vector<Shape> compute_shape_div(const at::Tensor & self, const at::Scalar & other);
 TORCH_API std::vector<Shape> compute_shape_div_(at::Tensor & self, const at::Scalar & other);
 TORCH_API std::vector<Shape> compute_shape_dropout(const at::Tensor & input, double p, bool train);
@@ -65,6 +67,7 @@ TORCH_API std::vector<Shape> compute_shape_max_pool2d(const at::Tensor & self, a
 TORCH_API std::vector<Shape> compute_shape_mean(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
 TORCH_API std::vector<Shape> compute_shape_mul(const at::Tensor & self, const at::Scalar & other);
 TORCH_API std::vector<Shape> compute_shape_mul_(at::Tensor & self, const at::Scalar & other);
+TORCH_API std::vector<Shape> compute_shape_native_batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps);
 TORCH_API std::vector<Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps);
 TORCH_API std::vector<Shape> compute_shape_new_empty(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
 TORCH_API std::vector<Shape> compute_shape_new_ones(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
@@ -88,8 +91,6 @@ TORCH_API std::vector<Shape> compute_shape_type_as(const at::Tensor & self, cons
 TORCH_API std::vector<Shape> compute_shape_var(const at::Tensor & self, bool unbiased);
 TORCH_API std::vector<Shape> compute_shape_zero_(at::Tensor & self);
-TORCH_API std::vector<Shape> compute_shape_native_batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps);
 // clang-format on
 } // namespace lazy
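
The header hunks above only adjust declarations; the definitions are not part of this diff. As a purely illustrative sketch of what such a helper computes (assuming, since zero_ is an in-place op, that dtype and sizes are unchanged; the real definition lives elsewhere in the backend sources), the declared compute_shape_zero_ could simply report the input's metadata:

#include <ATen/ATen.h>
#include <torch/csrc/lazy/core/shape.h>

#include <vector>

// Illustrative only: one plausible definition for a declared helper.
std::vector<torch::lazy::Shape> compute_shape_zero_(at::Tensor& self) {
  // zero_ overwrites values in place, so the single result shape simply
  // mirrors the input's scalar type and sizes.
  return {torch::lazy::Shape(self.scalar_type(), self.sizes().vec())};
}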


@@ -270,8 +270,7 @@ TorchMlirComputation::TorchMlirComputation(
     const std::shared_ptr<torch::jit::Graph>& graph,
     InputOutputAliases input_output_aliases)
     : func_op_(std::move(func_op)), mlir_context_(std::move(mlir_context)),
-      graph_(graph), input_output_aliases_(input_output_aliases),
-      num_results_(graph_->outputs().size()) {
+      graph_(graph), input_output_aliases_(input_output_aliases) {
   for (torch::jit::Value* input : graph_->inputs()) {
     parameter_names_.push_back(input->debugName());
   }
@@ -298,7 +297,9 @@ const torch::lazy::Shape& TorchMlirComputation::result_shape() const {
   return result_shape_;
 }
-unsigned TorchMlirComputation::num_results() const { return num_results_; }
+std::shared_ptr<torch::jit::Graph> TorchMlirComputation::graph() const {
+  return graph_;
+}
 MlirOperation TorchMlirComputation::func_op() const { return func_op_; }


@@ -145,7 +145,7 @@ public:
   const torch::lazy::Shape& result_shape() const override;
-  unsigned num_results() const;
+  std::shared_ptr<torch::jit::Graph> graph() const;
   MlirOperation func_op() const;
@@ -160,7 +160,6 @@ private:
   MlirContext mlir_context_;
   std::shared_ptr<torch::jit::Graph> graph_;
   InputOutputAliases input_output_aliases_;
-  unsigned num_results_;
 };
 } // namespace lazy