mirror of https://github.com/llvm/torch-mlir
MT model compilation minor changes
This contains the following changes: - Fix optional knowledge propagation. The initial knowledge should always be NotNone for the operations we implemented. - Add Folder for `prim.dtype`pull/303/head
parent
5f3eb637c4
commit
73d553e168
|
@ -398,7 +398,7 @@ def emit_prim_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("prim::layout : (Tensor) -> (int)")
|
||||
emit("prim::TupleIndex : (Any, int) -> (Any)")
|
||||
emit("prim::device : (Tensor) -> (Device)")
|
||||
emit("prim::dtype : (Tensor) -> (int)")
|
||||
emit("prim::dtype : (Tensor) -> (int)", has_folder=True)
|
||||
emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)
|
||||
emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)")
|
||||
emit("prim::min.self_int : (int[]) -> (int)")
|
||||
|
|
|
@ -791,6 +791,40 @@ def Torch_AtenTriu_Op : Torch_Op<"aten.triu_", [
|
|||
let assemblyFormat = "$self `,` $diagonal attr-dict `:` type($self) `,` type($diagonal) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalTensorListType:$indices,
|
||||
AnyTorchTensorType:$values,
|
||||
Torch_BoolType:$accumulate
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate attr-dict `:` type($self) `,` type($indices) `,` type($values) `,` type($accumulate) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenIndexPut_Op : Torch_Op<"aten.index_put_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::index_put_ : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalTensorListType:$indices,
|
||||
AnyTorchTensorType:$values,
|
||||
Torch_BoolType:$accumulate
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate attr-dict `:` type($self) `,` type($indices) `,` type($values) `,` type($accumulate) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenLinearOp : Torch_Op<"aten.linear", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
@ -1404,22 +1438,6 @@ def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [
|
|||
let assemblyFormat = "$self `,` $indices attr-dict `:` type($self) `,` type($indices) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenIndexPut_Op : Torch_Op<"aten.index_put_", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::index_put_ : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalTensorListType:$indices,
|
||||
AnyTorchTensorType:$values,
|
||||
Torch_BoolType:$accumulate
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate attr-dict `:` type($self) `,` type($indices) `,` type($values) `,` type($accumulate) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -70,6 +70,7 @@ def Torch_PrimDtypeOp : Torch_Op<"prim.dtype", [
|
|||
Torch_IntType:$result
|
||||
);
|
||||
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_PrimTupleUnpackOp : Torch_Op<"prim.TupleUnpack", [
|
||||
|
|
|
@ -18,6 +18,20 @@ using namespace mlir;
|
|||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
|
||||
// see https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h#L28
|
||||
static int64_t getDtypeIntegerFromMlirType(Type dtype) {
|
||||
if (dtype.isa<Float32Type>())
|
||||
return 6;
|
||||
|
||||
if (auto integerType = dtype.dyn_cast<IntegerType>()) {
|
||||
if (integerType.isSignedInteger(64))
|
||||
return 4;
|
||||
if (integerType.isSignlessInteger(1))
|
||||
return 11;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -129,6 +143,10 @@ bool isValidSubtype(Type subtype, Type type) {
|
|||
type ==
|
||||
NonValueTensorType::getWithLeastStaticInformation(type.getContext()))
|
||||
return true;
|
||||
|
||||
if (subtype.isa<ValueTensorType>() && type.isa<ValueTensorType>() &&
|
||||
type == ValueTensorType::getWithLeastStaticInformation(type.getContext()))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -972,5 +990,18 @@ OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrimDtypeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
|
||||
BaseTensorType tensorType = a().getType().cast<BaseTensorType>();
|
||||
if (tensorType.hasDtype()) {
|
||||
int64_t dtypeInt = getDtypeIntegerFromMlirType(tensorType.getDtype());
|
||||
if (dtypeInt != -1)
|
||||
return getI64IntegerAttr(getContext(), dtypeInt);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
|
||||
|
|
|
@ -114,6 +114,10 @@ struct ValueKnowledge {
|
|||
return getKnowledgeFromType(value.getType());
|
||||
}
|
||||
|
||||
static ValueKnowledge getNotNonePessimisticValueState(MLIRContext *context) {
|
||||
return ValueKnowledge(false, {}, Type(), OptionalKnowledge::notNone);
|
||||
}
|
||||
|
||||
bool operator==(const ValueKnowledge &rhs) const {
|
||||
return std::make_tuple(hasSizes, sizes, dtype, optional) ==
|
||||
std::make_tuple(rhs.hasSizes, rhs.sizes, rhs.dtype, rhs.optional);
|
||||
|
@ -199,7 +203,7 @@ public:
|
|||
if (isa<AtenAnyOp, AtenAllOp>(op)) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.resize(1, 1);
|
||||
knowledge.dtype = IntegerType::get(op->getContext(), 1);
|
||||
|
@ -211,7 +215,7 @@ public:
|
|||
if (auto maskedSelect = dyn_cast<AtenMaskedSelectOp>(op)) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.resize(1, kUnknownSize);
|
||||
knowledge.dtype = input.dtype;
|
||||
|
@ -225,7 +229,7 @@ public:
|
|||
if (auto indexTensor = dyn_cast<AtenIndexTensorOp>(op)) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
if (input.hasSizes) {
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.resize(input.sizes.size(), kUnknownSize);
|
||||
|
@ -482,7 +486,8 @@ ChangeResult TypeAnalyzer::visitAtenMmOp(
|
|||
AtenMmOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto &lhs = operands[0]->getValue();
|
||||
auto &rhs = operands[1]->getValue();
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.hasSizes = true;
|
||||
// WARNING: We could be more precise here by calculating the output
|
||||
// shape as "(lhs.shape[0], rhs.shape[1])". However, that is really tricky
|
||||
|
@ -532,7 +537,8 @@ ChangeResult TypeAnalyzer::visitAtenLinearOp(
|
|||
|
||||
ChangeResult TypeAnalyzer::visitAtenConv2dOp(
|
||||
AtenConv2dOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.resize(4, kUnknownSize);
|
||||
// Running some experiments in PyTorch, the bias doesn't seem to
|
||||
|
@ -544,7 +550,8 @@ ChangeResult TypeAnalyzer::visitAtenConv2dOp(
|
|||
|
||||
ChangeResult TypeAnalyzer::visitAtenMaxPool2dOp(
|
||||
AtenMaxPool2dOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.resize(4, kUnknownSize);
|
||||
knowledge.dtype = operands[0]->getValue().dtype;
|
||||
|
@ -555,7 +562,8 @@ ChangeResult TypeAnalyzer::visitAtenAdaptiveAvgPool2dOp(
|
|||
AtenAdaptiveAvgPool2dOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
if (input.hasSizes) {
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.resize(input.sizes.size(), kUnknownSize);
|
||||
|
@ -573,7 +581,8 @@ ChangeResult TypeAnalyzer::visitBinaryBroadcastingOp(
|
|||
// tricky, so we defer that until we need it.
|
||||
auto lhs = operands[0]->getValue();
|
||||
auto rhs = operands[1]->getValue();
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
if (lhs.hasSizes && rhs.hasSizes) {
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.resize(std::max(lhs.sizes.size(), rhs.sizes.size()),
|
||||
|
@ -593,7 +602,8 @@ ChangeResult TypeAnalyzer::visitAtenLerpTensorOp(
|
|||
auto a = operands[0]->getValue();
|
||||
auto b = operands[1]->getValue();
|
||||
auto c = operands[1]->getValue();
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
if (a.hasSizes && b.hasSizes && c.hasSizes) {
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.resize(
|
||||
|
@ -611,7 +621,8 @@ ChangeResult TypeAnalyzer::visitAtenFlattenUsingIntsOp(
|
|||
int64_t startDim;
|
||||
int64_t endDim;
|
||||
auto operand = operands[0]->getValue();
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op.getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
|
||||
knowledge.dtype = operand.dtype;
|
||||
if (operand.hasSizes && operand.sizes.size() == 0) {
|
||||
// Rank 0 is special and flattens to rank 1 with size 1.
|
||||
|
@ -642,7 +653,8 @@ ChangeResult TypeAnalyzer::visitAtenFlattenUsingIntsOp(
|
|||
ChangeResult TypeAnalyzer::visitAtenUnsqueezeOp(
|
||||
AtenUnsqueezeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto operand = operands[0]->getValue();
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op.getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
|
||||
knowledge.dtype = operand.dtype;
|
||||
int64_t dim;
|
||||
if (operand.hasSizes && matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
|
||||
|
@ -666,7 +678,8 @@ ChangeResult TypeAnalyzer::visitAtenUnsqueezeOp(
|
|||
// Arange like ops returns a 1-D tensor of size ceil(end - start).
|
||||
ChangeResult TypeAnalyzer::visitAtenArangeLikeOpHelper(
|
||||
Operation *op, llvm::Optional<Value> start, Value end, Value dtype) {
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.sizes.resize(1, kUnknownSize);
|
||||
knowledge.hasSizes = true;
|
||||
int64_t dtypeInt;
|
||||
|
@ -716,7 +729,8 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
|
|||
Operation *op, Value dim, Value keepdim,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.dtype = input.dtype;
|
||||
llvm::SmallVector<int64_t> dimList;
|
||||
bool keepdimBool;
|
||||
|
@ -749,7 +763,8 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
|
|||
ChangeResult TypeAnalyzer::visitAtenAnyDimOp(
|
||||
AtenAnyDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.dtype = input.dtype;
|
||||
int64_t dim;
|
||||
bool keepdimBool;
|
||||
|
@ -779,7 +794,8 @@ template <typename OpTy>
|
|||
ChangeResult TypeAnalyzer::visitReshapeLikeOp(
|
||||
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op.getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
|
||||
knowledge.dtype = input.dtype;
|
||||
|
||||
fillInSizesGivenSizesList(knowledge, op.size());
|
||||
|
@ -790,7 +806,8 @@ ChangeResult TypeAnalyzer::visitAtenTransposeIntOp(
|
|||
AtenTransposeIntOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op.getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
|
||||
knowledge.dtype = input.dtype;
|
||||
knowledge.hasSizes = input.hasSizes;
|
||||
auto dim0 = op.dim0();
|
||||
|
@ -815,7 +832,8 @@ ChangeResult TypeAnalyzer::visitAtenTransposeIntOp(
|
|||
|
||||
template <typename OpTy>
|
||||
ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op.getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
|
||||
Value t = op.t();
|
||||
Value dtype = op.dtype();
|
||||
knowledge.hasSizes = true;
|
||||
|
@ -828,18 +846,18 @@ ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
|
|||
// `torch.aten.tensor` get a tensor from a list. Each layer of the list
|
||||
// corresponds to one dim of the tensor.
|
||||
ChangeResult TypeAnalyzer::visitAtenTensorOp(AtenTensorOp op) {
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op.getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
|
||||
Value data = op.data();
|
||||
Value dtype = op.dtype();
|
||||
Type type = data.getType();
|
||||
int64_t rank = 0;
|
||||
bool rankIsUnknown = false;
|
||||
while (auto listType = type.dyn_cast<ListType>()) {
|
||||
type = listType.getContainedType();
|
||||
rank++;
|
||||
}
|
||||
|
||||
if (!rankIsUnknown) {
|
||||
if (rank != 0) {
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.resize(rank, kUnknownSize);
|
||||
}
|
||||
|
@ -849,7 +867,8 @@ ChangeResult TypeAnalyzer::visitAtenTensorOp(AtenTensorOp op) {
|
|||
|
||||
template <typename OpTy>
|
||||
ChangeResult TypeAnalyzer::visitConstantTensorAllocOp(OpTy op) {
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
fillInSizesGivenSizesList(knowledge, op.size());
|
||||
fillInDTypeGivenDTypeAndDataType(op->getContext(), knowledge, op.dtype(),
|
||||
Torch::FloatType::get(op->getContext()));
|
||||
|
@ -861,7 +880,8 @@ template <typename OpTy>
|
|||
ChangeResult TypeAnalyzer::visitTypeConversionOp(
|
||||
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.hasSizes = input.hasSizes;
|
||||
knowledge.sizes = input.sizes;
|
||||
Value other = op.other();
|
||||
|
@ -879,7 +899,8 @@ ChangeResult TypeAnalyzer::visitSliceLikeOp(
|
|||
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands,
|
||||
SetDimSizeFn setDim) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.dtype = input.dtype;
|
||||
if (!input.hasSizes)
|
||||
return getLatticeElement(op.getResult()).join(knowledge);
|
||||
|
@ -912,7 +933,8 @@ ChangeResult TypeAnalyzer::visitAtenGatherOp(
|
|||
AtenGatherOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto index = operands[2]->getValue();
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.dtype = input.dtype;
|
||||
knowledge.hasSizes = index.hasSizes;
|
||||
knowledge.sizes = index.sizes;
|
||||
|
@ -927,7 +949,8 @@ ChangeResult TypeAnalyzer::visitExpandLikeOp(
|
|||
ArrayRef<LatticeElement<ValueKnowledge> *> operands,
|
||||
SetDimSizePerListItemFn setDim) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.dtype = input.dtype;
|
||||
if (!input.hasSizes)
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
|
@ -957,7 +980,8 @@ ChangeResult TypeAnalyzer::visitExpandLikeOp(
|
|||
ChangeResult TypeAnalyzer::visitAtenCatOp(
|
||||
AtenCatOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto tensorList = op.tensors();
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
auto listConstruct = tensorList.getDefiningOp<PrimListConstructOp>();
|
||||
if (!listConstruct)
|
||||
return getLatticeElement(op.getResult()).join(knowledge);
|
||||
|
@ -995,7 +1019,8 @@ ChangeResult TypeAnalyzer::visitAtenShapeAsTensorOp(
|
|||
Aten_ShapeAsTensorOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
if (input.hasSizes)
|
||||
knowledge.sizes.resize(1, input.sizes.size());
|
||||
else
|
||||
|
@ -1007,7 +1032,8 @@ ChangeResult TypeAnalyzer::visitAtenShapeAsTensorOp(
|
|||
|
||||
ChangeResult TypeAnalyzer::visitAtenEmbeddingOp(
|
||||
AtenEmbeddingOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
auto weight = operands[0]->getValue();
|
||||
auto indices = operands[1]->getValue();
|
||||
if (indices.hasSizes) {
|
||||
|
@ -1026,13 +1052,15 @@ ChangeResult TypeAnalyzer::visitAtenEmbeddingOp(
|
|||
|
||||
ChangeResult TypeAnalyzer::visitAtenBmmOp(
|
||||
AtenBmmOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
auto self = operands[0]->getValue();
|
||||
auto mat2 = operands[1]->getValue();
|
||||
knowledge.sizes.resize(3, kUnknownSize);
|
||||
knowledge.dtype = joinElementTypes(self.dtype, mat2.dtype);
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Transforms.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
|
|
@ -406,7 +406,6 @@ func @torch.aten.__getitem__.Dict_str(%k0 : !torch.str, %v0: !torch.tensor, %k1:
|
|||
// CHECK-LABEL: func @torch.aten.add.int() -> !torch.int {
|
||||
// CHECK: %[[CST9:.*]] = torch.constant.int 9
|
||||
// CHECK: return %[[CST9]] : !torch.int
|
||||
// CHECK: }
|
||||
func @torch.aten.add.int() -> !torch.int {
|
||||
%cst4 = torch.constant.int 4
|
||||
%cst5 = torch.constant.int 5
|
||||
|
@ -417,7 +416,6 @@ func @torch.aten.add.int() -> !torch.int {
|
|||
// CHECK-LABEL: func @torch.aten.sub.int() -> !torch.int {
|
||||
// CHECK: %[[CST1:.*]] = torch.constant.int 1
|
||||
// CHECK: return %[[CST1]] : !torch.int
|
||||
// CHECK: }
|
||||
func @torch.aten.sub.int() -> !torch.int {
|
||||
%cst6 = torch.constant.int 6
|
||||
%cst5 = torch.constant.int 5
|
||||
|
@ -428,7 +426,6 @@ func @torch.aten.sub.int() -> !torch.int {
|
|||
// CHECK-LABEL: func @torch.aten.mul.int() -> !torch.int {
|
||||
// CHECK: %[[CST30:.*]] = torch.constant.int 30
|
||||
// CHECK: return %[[CST30]] : !torch.int
|
||||
// CHECK: }
|
||||
func @torch.aten.mul.int() -> !torch.int {
|
||||
%cst6 = torch.constant.int 6
|
||||
%cst5 = torch.constant.int 5
|
||||
|
@ -439,7 +436,6 @@ func @torch.aten.mul.int() -> !torch.int {
|
|||
// CHECK-LABEL: func @torch.aten.mul.int$with_zero() -> !torch.int {
|
||||
// CHECK: %[[CST0:.*]] = torch.constant.int 0
|
||||
// CHECK: return %[[CST0]] : !torch.int
|
||||
// CHECK: }
|
||||
func @torch.aten.mul.int$with_zero() -> !torch.int {
|
||||
%cst6 = torch.constant.int 6
|
||||
%cst0 = torch.constant.int 0
|
||||
|
@ -450,7 +446,6 @@ func @torch.aten.mul.int$with_zero() -> !torch.int {
|
|||
// CHECK-LABEL: func @torch.aten.floordiv.int() -> !torch.int {
|
||||
// CHECK: %[[CST3:.*]] = torch.constant.int 3
|
||||
// CHECK: return %[[CST3]] : !torch.int
|
||||
// CHECK: }
|
||||
func @torch.aten.floordiv.int() -> !torch.int {
|
||||
%cst18 = torch.constant.int 18
|
||||
%cst5 = torch.constant.int 5
|
||||
|
@ -461,10 +456,36 @@ func @torch.aten.floordiv.int() -> !torch.int {
|
|||
// CHECK-LABEL: func @torch.aten.remainder.int() -> !torch.int {
|
||||
// CHECK: %[[CST3:.*]] = torch.constant.int 3
|
||||
// CHECK: return %[[CST3]] : !torch.int
|
||||
// CHECK: }
|
||||
func @torch.aten.remainder.int() -> !torch.int {
|
||||
%cst18 = torch.constant.int 18
|
||||
%cst5 = torch.constant.int 5
|
||||
%ret = torch.aten.remainder.int %cst18, %cst5: !torch.int, !torch.int -> !torch.int
|
||||
return %ret : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.prim.dtype$float(
|
||||
// CHECK-SAME: %[[T:.*]]: !torch.tensor<*,f32>) -> !torch.int {
|
||||
// CHECK: %[[CST:.*]] = torch.constant.int 6
|
||||
// CHECK: return %[[CST]] : !torch.int
|
||||
func @torch.prim.dtype$float(%t : !torch.tensor<*,f32>) -> !torch.int {
|
||||
%ret = torch.prim.dtype %t: !torch.tensor<*,f32> -> !torch.int
|
||||
return %ret : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.prim.dtype$bool(
|
||||
// CHECK-SAME: %[[T:.*]]: !torch.tensor<*,i1>) -> !torch.int {
|
||||
// CHECK: %[[CST:.*]] = torch.constant.int 11
|
||||
// CHECK: return %[[CST]] : !torch.int
|
||||
func @torch.prim.dtype$bool(%t : !torch.tensor<*,i1>) -> !torch.int {
|
||||
%ret = torch.prim.dtype %t: !torch.tensor<*,i1> -> !torch.int
|
||||
return %ret : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.prim.dtype$int64(
|
||||
// CHECK-SAME: %[[T:.*]]: !torch.tensor<*,si64>) -> !torch.int {
|
||||
// CHECK: %[[CST:.*]] = torch.constant.int 4
|
||||
// CHECK: return %[[CST]] : !torch.int
|
||||
func @torch.prim.dtype$int64(%t : !torch.tensor<*,si64>) -> !torch.int {
|
||||
%ret = torch.prim.dtype %t: !torch.tensor<*,si64> -> !torch.int
|
||||
return %ret : !torch.int
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue