mirror of https://github.com/llvm/torch-mlir
76 lines
2.7 KiB
C++
76 lines
2.7 KiB
C++
//===- dynamic_ir.cpp -----------------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
// This file is adapted from pytorch/pytorch
|
|
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/dynamic_ir.cpp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "dynamic_ir.h"
|
|
|
|
namespace torch {
|
|
namespace lazy {
|
|
|
|
// Base constructor shared by all scalar dimension nodes.
//
// Every dimension node produces exactly one output. The op kind's own hash is
// folded into the caller-supplied `hash_seed` before it reaches TorchMlirNode.
DimensionNode::DimensionNode(OpKind op, OpList operands, hash_t hash_seed)
    : TorchMlirNode(op, operands,
                    /*num_outputs=*/1,
                    /*hash_seed=*/HashCombine(op.hash(), hash_seed)) {}
|
|
|
|
// Short textual label identifying this node kind.
std::string DimensionNode::ToString() const {
  const char *label = "DimensionNode";
  return std::string(label);
}
|
|
|
|
SizeNode::SizeNode(Value input, size_t dim)
|
|
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::size")}, {input},
|
|
MHash(dim)),
|
|
dim_(dim) {}
|
|
|
|
int64_t SizeNode::getStaticValue() const {
|
|
return dynamic_cast<const TorchMlirNode *>(operand(0).node)
|
|
->shape(0)
|
|
.size(dim_);
|
|
}
|
|
|
|
// Short textual label identifying this node kind.
std::string SizeNode::ToString() const {
  return std::string{"SizeNode"};
}
|
|
|
|
// Builds an `aten::add` node over two dimension values `a` and `b`.
// No hash seed is passed, so DimensionNode's default seed applies (the default
// is declared in the header, not visible here).
SizeAdd::SizeAdd(Value a, Value b)
    : DimensionNode(OpKind(c10::Symbol::fromQualString("aten::add")),
                    {a, b}) {}
|
|
|
|
int64_t SizeAdd::getStaticValue() const {
|
|
return dynamic_cast<const DimensionNode *>(operand(0).node)
|
|
->getStaticValue() +
|
|
dynamic_cast<const DimensionNode *>(operand(1).node)->getStaticValue();
|
|
}
|
|
|
|
// Short textual label identifying this node kind.
std::string SizeAdd::ToString() const {
  const char *label = "SizeAdd";
  return std::string(label);
}
|
|
|
|
// Builds an `aten::mul` node over two dimension values `a` and `b`.
// No hash seed is passed, so DimensionNode's default seed applies (the default
// is declared in the header, not visible here).
SizeMul::SizeMul(Value a, Value b)
    : DimensionNode(OpKind(c10::Symbol::fromQualString("aten::mul")),
                    {a, b}) {}
|
|
|
|
int64_t SizeMul::getStaticValue() const {
|
|
return dynamic_cast<const DimensionNode *>(operand(0).node)
|
|
->getStaticValue() *
|
|
dynamic_cast<const DimensionNode *>(operand(1).node)->getStaticValue();
|
|
}
|
|
|
|
// Short textual label identifying this node kind.
std::string SizeMul::ToString() const {
  return std::string{"SizeMul"};
}
|
|
|
|
// Builds an `aten::div` node dividing dimension value `a` by `b`.
// No hash seed is passed, so DimensionNode's default seed applies (the default
// is declared in the header, not visible here).
SizeDiv::SizeDiv(Value a, Value b)
    : DimensionNode(OpKind(c10::Symbol::fromQualString("aten::div")),
                    {a, b}) {}
|
|
|
|
int64_t SizeDiv::getStaticValue() const {
|
|
TORCH_CHECK(
|
|
dynamic_cast<const DimensionNode *>(operand(1).node)->getStaticValue() !=
|
|
0,
|
|
"Can't divide a dimension by zero");
|
|
return dynamic_cast<const DimensionNode *>(operand(0).node)
|
|
->getStaticValue() /
|
|
dynamic_cast<const DimensionNode *>(operand(1).node)->getStaticValue();
|
|
}
|
|
|
|
// Short textual label identifying this node kind.
std::string SizeDiv::ToString() const {
  const char *label = "SizeDiv";
  return std::string(label);
}
|
|
|
|
} // namespace lazy
|
|
} // namespace torch
|