MT model compilation minor changes

This contains the following changes:
 - Fix optional knowledge propagation: the initial knowledge should
   always be `NotNone` for the operations we implemented.
 - Add a folder for `prim.dtype`.
Yi Zhang 2021-09-01 15:53:52 -04:00
parent 5f3eb637c4
commit 73d553e168
6 changed files with 151 additions and 52 deletions

@@ -398,7 +398,7 @@ def emit_prim_ops(torch_ir_dir: str, registry: Registry):
emit("prim::layout : (Tensor) -> (int)")
emit("prim::TupleIndex : (Any, int) -> (Any)")
emit("prim::device : (Tensor) -> (Device)")
emit("prim::dtype : (Tensor) -> (int)")
emit("prim::dtype : (Tensor) -> (int)", has_folder=True)
emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)
emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)")
emit("prim::min.self_int : (int[]) -> (int)")

@@ -791,6 +791,40 @@ def Torch_AtenTriu_Op : Torch_Op<"aten.triu_", [
let assemblyFormat = "$self `,` $diagonal attr-dict `:` type($self) `,` type($diagonal) `->` type($result)";
}
def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalTensorListType:$indices,
AnyTorchTensorType:$values,
Torch_BoolType:$accumulate
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate attr-dict `:` type($self) `,` type($indices) `,` type($values) `,` type($accumulate) `->` type($result)";
}
def Torch_AtenIndexPut_Op : Torch_Op<"aten.index_put_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::index_put_ : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalTensorListType:$indices,
AnyTorchTensorType:$values,
Torch_BoolType:$accumulate
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate attr-dict `:` type($self) `,` type($indices) `,` type($values) `,` type($accumulate) `->` type($result)";
}
def Torch_AtenLinearOp : Torch_Op<"aten.linear", [
AllowsTypeRefinement,
HasValueSemantics
@@ -1404,22 +1438,6 @@ def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [
let assemblyFormat = "$self `,` $indices attr-dict `:` type($self) `,` type($indices) `->` type($result)";
}
def Torch_AtenIndexPut_Op : Torch_Op<"aten.index_put_", [
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::index_put_ : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalTensorListType:$indices,
AnyTorchTensorType:$values,
Torch_BoolType:$accumulate
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate attr-dict `:` type($self) `,` type($indices) `,` type($values) `,` type($accumulate) `->` type($result)";
}
def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [
AllowsTypeRefinement,
HasValueSemantics

@@ -70,6 +70,7 @@ def Torch_PrimDtypeOp : Torch_Op<"prim.dtype", [
Torch_IntType:$result
);
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
let hasFolder = 1;
}
def Torch_PrimTupleUnpackOp : Torch_Op<"prim.TupleUnpack", [

@@ -18,6 +18,20 @@ using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
// see https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h#L28
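// Codes used below, per that enum: Float = 6, Long (signed 64-bit) = 4, Bool = 11.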
static int64_t getDtypeIntegerFromMlirType(Type dtype) {
if (dtype.isa<Float32Type>())
return 6;
if (auto integerType = dtype.dyn_cast<IntegerType>()) {
if (integerType.isSignedInteger(64))
return 4;
if (integerType.isSignlessInteger(1))
return 11;
}
return -1;
}
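// A hypothetical extension, not part of this change: the same table could be
// grown with further codes from ScalarType.h if more folders need them, e.g.
//   if (dtype.isa<Float64Type>()) return 7;   // torch.float64 (Double)
//   if (dtype.isa<Float16Type>()) return 5;   // torch.float16 (Half)
//   if (dtype.isa<BFloat16Type>()) return 15; // torch.bfloat16 (BFloat16)
//   if (integerType.isSignedInteger(32)) return 3; // torch.int32 (Int)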
//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//
@@ -129,6 +143,10 @@ bool isValidSubtype(Type subtype, Type type) {
type ==
NonValueTensorType::getWithLeastStaticInformation(type.getContext()))
return true;
if (subtype.isa<ValueTensorType>() && type.isa<ValueTensorType>() &&
type == ValueTensorType::getWithLeastStaticInformation(type.getContext()))
return true;
return false;
}
@@ -972,5 +990,18 @@ OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// PrimDtypeOp
//===----------------------------------------------------------------------===//
OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
BaseTensorType tensorType = a().getType().cast<BaseTensorType>();
if (tensorType.hasDtype()) {
int64_t dtypeInt = getDtypeIntegerFromMlirType(tensorType.getDtype());
if (dtypeInt != -1)
return getI64IntegerAttr(getContext(), dtypeInt);
}
return nullptr;
}
#define GET_OP_CLASSES
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
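The fold itself only fires when something runs the folding machinery. Below is a minimal sketch of how that typically happens; it assumes a standard MLIR pass pipeline and is not code from this patch (`runFoldersOnModule` is a hypothetical helper). An integer attribute returned by `fold()` is then materialized back into the IR as a `torch.constant.int`, which is what the tests at the end of this change check.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// Run the canonicalizer so PrimDtypeOp::fold (and the other folders) get a
// chance to fire on every function in the module.
static mlir::LogicalResult runFoldersOnModule(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  return pm.run(module);
}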

@@ -114,6 +114,10 @@ struct ValueKnowledge {
return getKnowledgeFromType(value.getType());
}
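// Starting state for ops whose result is always a real tensor (never None):
// sizes and dtype stay unknown, but the optional-ness is pinned to notNone.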
static ValueKnowledge getNotNonePessimisticValueState(MLIRContext *context) {
return ValueKnowledge(false, {}, Type(), OptionalKnowledge::notNone);
}
bool operator==(const ValueKnowledge &rhs) const {
return std::make_tuple(hasSizes, sizes, dtype, optional) ==
std::make_tuple(rhs.hasSizes, rhs.sizes, rhs.dtype, rhs.optional);
@@ -199,7 +203,7 @@ public:
if (isa<AtenAnyOp, AtenAllOp>(op)) {
auto input = operands[0]->getValue();
auto knowledge =
ValueKnowledge::getPessimisticValueState(op->getContext());
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.hasSizes = true;
knowledge.sizes.resize(1, 1);
knowledge.dtype = IntegerType::get(op->getContext(), 1);
@@ -211,7 +215,7 @@ public:
if (auto maskedSelect = dyn_cast<AtenMaskedSelectOp>(op)) {
auto input = operands[0]->getValue();
auto knowledge =
ValueKnowledge::getPessimisticValueState(op->getContext());
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.hasSizes = true;
knowledge.sizes.resize(1, kUnknownSize);
knowledge.dtype = input.dtype;
@@ -225,7 +229,7 @@ public:
if (auto indexTensor = dyn_cast<AtenIndexTensorOp>(op)) {
auto input = operands[0]->getValue();
auto knowledge =
ValueKnowledge::getPessimisticValueState(op->getContext());
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
if (input.hasSizes) {
knowledge.hasSizes = true;
knowledge.sizes.resize(input.sizes.size(), kUnknownSize);
@@ -482,7 +486,8 @@ ChangeResult TypeAnalyzer::visitAtenMmOp(
AtenMmOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto &lhs = operands[0]->getValue();
auto &rhs = operands[1]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.hasSizes = true;
// WARNING: We could be more precise here by calculating the output
// shape as "(lhs.shape[0], rhs.shape[1])". However, that is really tricky
@@ -532,7 +537,8 @@ ChangeResult TypeAnalyzer::visitAtenLinearOp(
ChangeResult TypeAnalyzer::visitAtenConv2dOp(
AtenConv2dOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.hasSizes = true;
knowledge.sizes.resize(4, kUnknownSize);
// Running some experiments in PyTorch, the bias doesn't seem to
@@ -544,7 +550,8 @@ ChangeResult TypeAnalyzer::visitAtenConv2dOp(
ChangeResult TypeAnalyzer::visitAtenMaxPool2dOp(
AtenMaxPool2dOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.hasSizes = true;
knowledge.sizes.resize(4, kUnknownSize);
knowledge.dtype = operands[0]->getValue().dtype;
@@ -555,7 +562,8 @@ ChangeResult TypeAnalyzer::visitAtenAdaptiveAvgPool2dOp(
AtenAdaptiveAvgPool2dOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
if (input.hasSizes) {
knowledge.hasSizes = true;
knowledge.sizes.resize(input.sizes.size(), kUnknownSize);
@@ -573,7 +581,8 @@ ChangeResult TypeAnalyzer::visitBinaryBroadcastingOp(
// tricky, so we defer that until we need it.
auto lhs = operands[0]->getValue();
auto rhs = operands[1]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
if (lhs.hasSizes && rhs.hasSizes) {
knowledge.hasSizes = true;
knowledge.sizes.resize(std::max(lhs.sizes.size(), rhs.sizes.size()),
@@ -593,7 +602,8 @@ ChangeResult TypeAnalyzer::visitAtenLerpTensorOp(
auto a = operands[0]->getValue();
auto b = operands[1]->getValue();
auto c = operands[2]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
if (a.hasSizes && b.hasSizes && c.hasSizes) {
knowledge.hasSizes = true;
knowledge.sizes.resize(
@@ -611,7 +621,8 @@ ChangeResult TypeAnalyzer::visitAtenFlattenUsingIntsOp(
int64_t startDim;
int64_t endDim;
auto operand = operands[0]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op.getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
knowledge.dtype = operand.dtype;
if (operand.hasSizes && operand.sizes.size() == 0) {
// Rank 0 is special and flattens to rank 1 with size 1.
@@ -642,7 +653,8 @@ ChangeResult TypeAnalyzer::visitAtenFlattenUsingIntsOp(
ChangeResult TypeAnalyzer::visitAtenUnsqueezeOp(
AtenUnsqueezeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto operand = operands[0]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op.getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
knowledge.dtype = operand.dtype;
int64_t dim;
if (operand.hasSizes && matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
@@ -666,7 +678,8 @@ ChangeResult TypeAnalyzer::visitAtenUnsqueezeOp(
// Arange-like ops return a 1-D tensor of size ceil(end - start).
ChangeResult TypeAnalyzer::visitAtenArangeLikeOpHelper(
Operation *op, llvm::Optional<Value> start, Value end, Value dtype) {
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.sizes.resize(1, kUnknownSize);
knowledge.hasSizes = true;
int64_t dtypeInt;
@@ -716,7 +729,8 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
Operation *op, Value dim, Value keepdim,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.dtype = input.dtype;
llvm::SmallVector<int64_t> dimList;
bool keepdimBool;
@@ -749,7 +763,8 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
ChangeResult TypeAnalyzer::visitAtenAnyDimOp(
AtenAnyDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.dtype = input.dtype;
int64_t dim;
bool keepdimBool;
@@ -779,7 +794,8 @@ template <typename OpTy>
ChangeResult TypeAnalyzer::visitReshapeLikeOp(
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op.getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
knowledge.dtype = input.dtype;
fillInSizesGivenSizesList(knowledge, op.size());
@@ -790,7 +806,8 @@ ChangeResult TypeAnalyzer::visitAtenTransposeIntOp(
AtenTransposeIntOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op.getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
knowledge.dtype = input.dtype;
knowledge.hasSizes = input.hasSizes;
auto dim0 = op.dim0();
@@ -815,7 +832,8 @@ ChangeResult TypeAnalyzer::visitAtenTransposeIntOp(
template <typename OpTy>
ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
auto knowledge = ValueKnowledge::getPessimisticValueState(op.getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
Value t = op.t();
Value dtype = op.dtype();
knowledge.hasSizes = true;
@@ -828,18 +846,18 @@ ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
// `torch.aten.tensor` gets a tensor from a list. Each layer of the list
// corresponds to one dim of the tensor.
ChangeResult TypeAnalyzer::visitAtenTensorOp(AtenTensorOp op) {
auto knowledge = ValueKnowledge::getPessimisticValueState(op.getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
Value data = op.data();
Value dtype = op.dtype();
Type type = data.getType();
int64_t rank = 0;
bool rankIsUnknown = false;
while (auto listType = type.dyn_cast<ListType>()) {
type = listType.getContainedType();
rank++;
}
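// A non-list data operand leaves rank at 0, in which case no size claim is
// made; otherwise the list nesting depth is the rank.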
if (!rankIsUnknown) {
if (rank != 0) {
knowledge.hasSizes = true;
knowledge.sizes.resize(rank, kUnknownSize);
}
@@ -849,7 +867,8 @@ ChangeResult TypeAnalyzer::visitAtenTensorOp(AtenTensorOp op) {
template <typename OpTy>
ChangeResult TypeAnalyzer::visitConstantTensorAllocOp(OpTy op) {
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
fillInSizesGivenSizesList(knowledge, op.size());
fillInDTypeGivenDTypeAndDataType(op->getContext(), knowledge, op.dtype(),
Torch::FloatType::get(op->getContext()));
@@ -861,7 +880,8 @@ template <typename OpTy>
ChangeResult TypeAnalyzer::visitTypeConversionOp(
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.hasSizes = input.hasSizes;
knowledge.sizes = input.sizes;
Value other = op.other();
@@ -879,7 +899,8 @@ ChangeResult TypeAnalyzer::visitSliceLikeOp(
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands,
SetDimSizeFn setDim) {
auto input = operands[0]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.dtype = input.dtype;
if (!input.hasSizes)
return getLatticeElement(op.getResult()).join(knowledge);
@@ -912,7 +933,8 @@ ChangeResult TypeAnalyzer::visitAtenGatherOp(
AtenGatherOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
auto index = operands[2]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.dtype = input.dtype;
knowledge.hasSizes = index.hasSizes;
knowledge.sizes = index.sizes;
@@ -927,7 +949,8 @@ ChangeResult TypeAnalyzer::visitExpandLikeOp(
ArrayRef<LatticeElement<ValueKnowledge> *> operands,
SetDimSizePerListItemFn setDim) {
auto input = operands[0]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.dtype = input.dtype;
if (!input.hasSizes)
return getLatticeElement(op->getResult(0)).join(knowledge);
@@ -957,7 +980,8 @@ ChangeResult TypeAnalyzer::visitExpandLikeOp(
ChangeResult TypeAnalyzer::visitAtenCatOp(
AtenCatOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto tensorList = op.tensors();
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
auto listConstruct = tensorList.getDefiningOp<PrimListConstructOp>();
if (!listConstruct)
return getLatticeElement(op.getResult()).join(knowledge);
@@ -995,7 +1019,8 @@ ChangeResult TypeAnalyzer::visitAtenShapeAsTensorOp(
Aten_ShapeAsTensorOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
if (input.hasSizes)
knowledge.sizes.resize(1, input.sizes.size());
else
@@ -1007,7 +1032,8 @@ ChangeResult TypeAnalyzer::visitAtenShapeAsTensorOp(
ChangeResult TypeAnalyzer::visitAtenEmbeddingOp(
AtenEmbeddingOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
auto weight = operands[0]->getValue();
auto indices = operands[1]->getValue();
if (indices.hasSizes) {
@@ -1026,13 +1052,15 @@ ChangeResult TypeAnalyzer::visitAtenEmbeddingOp(
ChangeResult TypeAnalyzer::visitAtenBmmOp(
AtenBmmOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
auto self = operands[0]->getValue();
auto mat2 = operands[1]->getValue();
knowledge.sizes.resize(3, kUnknownSize);
knowledge.dtype = joinElementTypes(self.dtype, mat2.dtype);
return getLatticeElement(op->getResult(0)).join(knowledge);
}
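// A sketch of the pattern the visitors above now share (SomeUnaryTensorOp is a
// hypothetical op name, not part of this change): start from a not-None
// pessimistic state, then add whatever shape/dtype facts the op guarantees
// before joining the lattice.
// ChangeResult TypeAnalyzer::visitSomeUnaryTensorOp(
//     SomeUnaryTensorOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
//   auto input = operands[0]->getValue();
//   auto knowledge =
//       ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
//   knowledge.dtype = input.dtype; // elementwise: result keeps the input dtype
//   if (input.hasSizes) {
//     knowledge.hasSizes = true;
//     knowledge.sizes.resize(input.sizes.size(), kUnknownSize); // rank preserved
//   }
//   return getLatticeElement(op->getResult(0)).join(knowledge);
// }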
// -----------------------------------------------------------------------------
// Transforms.
// -----------------------------------------------------------------------------

@@ -406,7 +406,6 @@ func @torch.aten.__getitem__.Dict_str(%k0 : !torch.str, %v0: !torch.tensor, %k1:
// CHECK-LABEL: func @torch.aten.add.int() -> !torch.int {
// CHECK: %[[CST9:.*]] = torch.constant.int 9
// CHECK: return %[[CST9]] : !torch.int
// CHECK: }
func @torch.aten.add.int() -> !torch.int {
%cst4 = torch.constant.int 4
%cst5 = torch.constant.int 5
@@ -417,7 +416,6 @@ func @torch.aten.add.int() -> !torch.int {
// CHECK-LABEL: func @torch.aten.sub.int() -> !torch.int {
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: return %[[CST1]] : !torch.int
// CHECK: }
func @torch.aten.sub.int() -> !torch.int {
%cst6 = torch.constant.int 6
%cst5 = torch.constant.int 5
@@ -428,7 +426,6 @@ func @torch.aten.sub.int() -> !torch.int {
// CHECK-LABEL: func @torch.aten.mul.int() -> !torch.int {
// CHECK: %[[CST30:.*]] = torch.constant.int 30
// CHECK: return %[[CST30]] : !torch.int
// CHECK: }
func @torch.aten.mul.int() -> !torch.int {
%cst6 = torch.constant.int 6
%cst5 = torch.constant.int 5
@@ -439,7 +436,6 @@ func @torch.aten.mul.int() -> !torch.int {
// CHECK-LABEL: func @torch.aten.mul.int$with_zero() -> !torch.int {
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: return %[[CST0]] : !torch.int
// CHECK: }
func @torch.aten.mul.int$with_zero() -> !torch.int {
%cst6 = torch.constant.int 6
%cst0 = torch.constant.int 0
@@ -450,7 +446,6 @@ func @torch.aten.mul.int$with_zero() -> !torch.int {
// CHECK-LABEL: func @torch.aten.floordiv.int() -> !torch.int {
// CHECK: %[[CST3:.*]] = torch.constant.int 3
// CHECK: return %[[CST3]] : !torch.int
// CHECK: }
func @torch.aten.floordiv.int() -> !torch.int {
%cst18 = torch.constant.int 18
%cst5 = torch.constant.int 5
@@ -461,10 +456,36 @@ func @torch.aten.floordiv.int() -> !torch.int {
// CHECK-LABEL: func @torch.aten.remainder.int() -> !torch.int {
// CHECK: %[[CST3:.*]] = torch.constant.int 3
// CHECK: return %[[CST3]] : !torch.int
// CHECK: }
func @torch.aten.remainder.int() -> !torch.int {
%cst18 = torch.constant.int 18
%cst5 = torch.constant.int 5
%ret = torch.aten.remainder.int %cst18, %cst5: !torch.int, !torch.int -> !torch.int
return %ret : !torch.int
}
// CHECK-LABEL: func @torch.prim.dtype$float(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<*,f32>) -> !torch.int {
// CHECK: %[[CST:.*]] = torch.constant.int 6
// CHECK: return %[[CST]] : !torch.int
func @torch.prim.dtype$float(%t : !torch.tensor<*,f32>) -> !torch.int {
%ret = torch.prim.dtype %t: !torch.tensor<*,f32> -> !torch.int
return %ret : !torch.int
}
// CHECK-LABEL: func @torch.prim.dtype$bool(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<*,i1>) -> !torch.int {
// CHECK: %[[CST:.*]] = torch.constant.int 11
// CHECK: return %[[CST]] : !torch.int
func @torch.prim.dtype$bool(%t : !torch.tensor<*,i1>) -> !torch.int {
%ret = torch.prim.dtype %t: !torch.tensor<*,i1> -> !torch.int
return %ret : !torch.int
}
// CHECK-LABEL: func @torch.prim.dtype$int64(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<*,si64>) -> !torch.int {
// CHECK: %[[CST:.*]] = torch.constant.int 4
// CHECK: return %[[CST]] : !torch.int
func @torch.prim.dtype$int64(%t : !torch.tensor<*,si64>) -> !torch.int {
%ret = torch.prim.dtype %t: !torch.tensor<*,si64> -> !torch.int
return %ret : !torch.int
}