torch-mlir/projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp

75 lines
2.6 KiB
C++
Raw Normal View History

//===- dynamic_ir.cpp -----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/dynamic_ir.cpp
//===----------------------------------------------------------------------===//
#include "dynamic_ir.h"
namespace torch {
namespace lazy {
DimensionNode::DimensionNode(OpKind op, OpList operands, hash_t hash_seed)
: TorchMlirNode(
op, operands, /*num_outputs=*/1,
/* hash_seed */ HashCombine(op.hash(), hash_seed)) {}
std::string DimensionNode::ToString() const { return "DimensionNode"; }
SizeNode::SizeNode(Value input, size_t dim)
: DimensionNode(
OpKind{c10::Symbol::fromQualString("aten::size")}, {input},
MHash(dim)),
dim_(dim){};
int64_t SizeNode::getStaticValue() const {
return dynamic_cast<const TorchMlirNode*>(operand(0).node)
->shape(0)
.size(dim_);
}
std::string SizeNode::ToString() const { return "SizeNode"; }
SizeAdd::SizeAdd(Value a, Value b)
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::add")}, {a, b}){};
int64_t SizeAdd::getStaticValue() const {
return dynamic_cast<const DimensionNode*>(operand(0).node)->getStaticValue() +
dynamic_cast<const DimensionNode*>(operand(1).node)->getStaticValue();
}
std::string SizeAdd::ToString() const { return "SizeAdd"; }
SizeMul::SizeMul(Value a, Value b)
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::mul")}, {a, b}){};
int64_t SizeMul::getStaticValue() const {
return dynamic_cast<const DimensionNode*>(operand(0).node)->getStaticValue() *
dynamic_cast<const DimensionNode*>(operand(1).node)->getStaticValue();
}
std::string SizeMul::ToString() const { return "SizeMul"; }
SizeDiv::SizeDiv(Value a, Value b)
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::div")}, {a, b}){};
int64_t SizeDiv::getStaticValue() const {
TORCH_CHECK(
dynamic_cast<const DimensionNode*>(operand(1).node)->getStaticValue() !=
0,
"Can't divide a dimension by zero");
return dynamic_cast<const DimensionNode*>(operand(0).node)->getStaticValue() /
dynamic_cast<const DimensionNode*>(operand(1).node)->getStaticValue();
}
std::string SizeDiv::ToString() const { return "SizeDiv"; }
} // namespace lazy
} // namespace torch