mirror of https://github.com/llvm/torch-mlir
Add support for scalar type propagation
The main changes are:
- Added `ValueKnowledge.scalarType` to track scalar type information.
- Added `ValueKnowledge.kind` to indicate the value kind.
- Modified the meet and join helper functions. `ValueKnowledge` has slightly
  more complicated state now, so the meet and join functions need to look at
  the `kind` field in addition to the type fields.
parent 0fb7a03ac9
commit 9f7264a7a4
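To make the new lattice state concrete before diving into the diff, here is a minimal standalone C++ sketch (illustrative only, not part of the patch; the real `ValueKnowledge` lives in RefineTypes.cpp and uses MLIR `Type` values rather than the placeholder strings used here):

    #include <cassert>
    #include <string>

    // Stand-in for torch_upstream::TypeKind (see the TorchUpstream.h hunk below).
    enum class TypeKind { AnyType, TensorType, NumberType, IntType, FloatType, NoneType };

    struct ValueKnowledge {
      std::string dtype;      // tensor element type; meaningful only for TensorType
      std::string scalarType; // scalar type; meaningful only for scalar kinds
      TypeKind kind = TypeKind::AnyType; // AnyType means "kind unknown"
    };

    int main() {
      // A value known to be a Torch scalar int: `kind` records the value kind
      // and `scalarType` the scalar's type, while `dtype` (a tensor-only fact)
      // stays unset.
      ValueKnowledge k;
      k.kind = TypeKind::IntType;
      k.scalarType = "!torch.int";
      assert(k.dtype.empty());
      return 0;
    }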
@@ -7398,6 +7398,29 @@ def Torch_PrimTolistOp : Torch_Op<"prim.tolist", [
   let assemblyFormat = "`(` $operands `)` attr-dict `:` qualified(type($operands)) `->` qualified(type($results))";
 }
 
+def Torch_PrimAbsScalarOp : Torch_Op<"prim.abs.Scalar", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `prim::abs.Scalar : (Scalar) -> (Scalar)`";
+  let arguments = (ins
+    AnyTorchScalarType:$a
+  );
+  let results = (outs
+    AnyTorchScalarType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult PrimAbsScalarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void PrimAbsScalarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
     HasValueSemantics,
     AllowsTypeRefinement,
@@ -29,6 +29,55 @@ namespace mlir {
 namespace torch {
 namespace torch_upstream {
 
+//===----------------------------------------------------------------------===//
+// TypeKind-related enum code is copied from
+// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/jit_type_base.h
+//===----------------------------------------------------------------------===//
+#define C10_FORALL_TYPES(_)                                                    \
+  _(AnyType)                                                                   \
+  _(EnumType)                                                                  \
+  _(AnyEnumType)                                                               \
+  _(TensorType)                                                                \
+  _(StorageType)                                                               \
+  _(TupleType)                                                                 \
+  _(ListType)                                                                  \
+  _(DictType)                                                                  \
+  _(NumberType)                                                                \
+  _(FloatType)                                                                 \
+  _(ComplexType)                                                               \
+  _(FutureType)                                                                \
+  _(RRefType)                                                                  \
+  _(IntType)                                                                   \
+  _(NoneType)                                                                  \
+  _(StringType)                                                                \
+  _(GeneratorType)                                                             \
+  _(QuantizerType)                                                             \
+  _(BoolType)                                                                  \
+  _(OptionalType)                                                              \
+  _(VarType)                                                                   \
+  _(DeviceObjType)                                                             \
+  _(StreamObjType)                                                             \
+  _(FunctionType)                                                              \
+  _(ClassType)                                                                 \
+  _(PyObjectType)                                                              \
+  _(CapsuleType)                                                               \
+  _(InterfaceType)                                                             \
+  _(QSchemeType)                                                               \
+  _(LayoutType)                                                                \
+  _(ScalarTypeType)                                                            \
+  _(AnyListType)                                                               \
+  _(AnyTupleType)                                                              \
+  _(AnyClassType)                                                              \
+  _(SymIntType)                                                                \
+  _(UnionType)                                                                 \
+  _(DynamicType)
+
+enum class TypeKind {
+#define DEFINE_TYPE(T) T,
+  C10_FORALL_TYPES(DEFINE_TYPE)
+#undef DEFINE_TYPE
+};
+
 //===----------------------------------------------------------------------===//
 // ScalarType enum related code are copied from c10/core/ScalarType.h
 //===----------------------------------------------------------------------===//
@@ -92,14 +92,27 @@ static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype,
   return Type();
 }
 
-static Type joinElementTypes(Type lhs, Type rhs) {
-  if (lhs == rhs)
-    return lhs;
-  return Type();
+// Get the kind enum for `ValueKnowledge.kind`.
+static torch_upstream::TypeKind getTypeKind(Type type) {
+  if (type.isa<NumberType>())
+    return torch_upstream::TypeKind::NumberType;
+  if (type.isa<IntType>())
+    return torch_upstream::TypeKind::IntType;
+  if (type.isa<Torch::FloatType>())
+    return torch_upstream::TypeKind::FloatType;
+  if (type.isa<BaseTensorType>())
+    return torch_upstream::TypeKind::TensorType;
+  if (type.isa<Torch::NoneType>())
+    return torch_upstream::TypeKind::NoneType;
+  // Skip the Torch::OptionalType on purpose because optional knowledge is
+  // tracked separately. See comments for `ValueKnowledge.kind` field.
+  return torch_upstream::TypeKind::AnyType;
 }
 
 /// Returns the dtype that assumes information from both `lhs` and `rhs`.
-/// Returns `None` if the types are contradictory.
+/// Returns `None` if the types are contradictory. Note this can only be used
+/// on the `dtype` from tensors and can't be used on other types like scalar
+/// types.
 static Optional<Type> meetElementTypes(Type lhs, Type rhs) {
   if (!lhs)
     return rhs;
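As a side note on `meetElementTypes` above: its three-way behavior (unknown meets anything, agreement is kept, disagreement is a contradiction) can be sketched in isolation. This is a simplified model using strings in place of MLIR `Type`, with an empty string playing the role of a null `Type`:

    #include <cassert>
    #include <optional>
    #include <string>

    // Simplified model of meetElementTypes: applies only to tensor dtypes.
    std::optional<std::string> meetElementTypes(const std::string &lhs,
                                                const std::string &rhs) {
      if (lhs.empty())
        return rhs; // unknown dtype contributes no constraint
      if (rhs.empty())
        return lhs;
      if (lhs == rhs)
        return lhs; // both sides agree
      return std::nullopt; // two different concrete dtypes are contradictory
    }

    int main() {
      assert(*meetElementTypes("", "f32") == "f32");
      assert(*meetElementTypes("f32", "f32") == "f32");
      assert(!meetElementTypes("f32", "f64").has_value());
      return 0;
    }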
@@ -148,27 +161,44 @@ namespace {
 // This class could also be called "dataflow facts", "lattice value", etc.
 struct ValueKnowledge {
   ValueKnowledge() = delete;
-  ValueKnowledge(Type dtype, OptionalKnowledge optionalKnowledge)
-      : dtype(dtype), optional(optionalKnowledge) {}
+  ValueKnowledge(Type dtype, Type scalarType,
+                 OptionalKnowledge optionalKnowledge,
+                 torch_upstream::TypeKind kind)
+      : dtype(dtype), scalarType(scalarType), kind(kind),
+        optional(optionalKnowledge) {}
 
   // Get the static knowledge intrinsic to `type`.
   static ValueKnowledge getKnowledgeFromType(Type type) {
     ValueKnowledge result = getPessimisticValueState(type.getContext());
-    if (auto tensorType = type.dyn_cast<BaseTensorType>()) {
-      result.dtype = tensorType.getOptionalDtype();
-      result.optional = OptionalKnowledge::notNone;
-    } else if (auto optionalType = type.dyn_cast<Torch::NoneType>()) {
-      result.optional = OptionalKnowledge::isNone;
-    } else if (!type.isa<OptionalType>()) {
-      result.optional = OptionalKnowledge::notNone;
-    }
-    return result;
+    result.kind = getTypeKind(type);
+    switch (result.kind) {
+    case torch_upstream::TypeKind::TensorType:
+      result.dtype = type.cast<BaseTensorType>().getOptionalDtype();
+      result.optional = OptionalKnowledge::notNone;
+      return result;
+    case torch_upstream::TypeKind::NumberType:
+    case torch_upstream::TypeKind::IntType:
+    case torch_upstream::TypeKind::FloatType:
+      result.scalarType = type;
+      result.optional = OptionalKnowledge::notNone;
+      return result;
+    case torch_upstream::TypeKind::NoneType:
+      result.optional = OptionalKnowledge::isNone;
+      return result;
+    default:
+      if (type.isa<OptionalType>())
+        return result;
+      // All other types that are not optional type.
+      result.optional = OptionalKnowledge::notNone;
+      return result;
+    }
   }
 
   // Return a pessimistic/conservative value state without assuming any knowledge
   // about the IR.
   static ValueKnowledge getPessimisticValueState(MLIRContext *context) {
-    return ValueKnowledge(Type(), OptionalKnowledge::unKnown);
+    return ValueKnowledge(Type(), Type(), OptionalKnowledge::unKnown,
+                          torch_upstream::TypeKind::AnyType);
   }
   // Return a pessimistic/conservative value state only using knowledge already
   // recorded in the IR.
@@ -177,7 +207,19 @@ struct ValueKnowledge {
   }
 
   static ValueKnowledge getNotNonePessimisticValueState(MLIRContext *context) {
-    return ValueKnowledge(Type(), OptionalKnowledge::notNone);
+    return ValueKnowledge(Type(), Type(), OptionalKnowledge::notNone,
+                          torch_upstream::TypeKind::AnyType);
+  }
+
+  static ValueKnowledge getTensorPessimisticValueState(MLIRContext *context) {
+    return ValueKnowledge(Type(), Type(), OptionalKnowledge::notNone,
+                          torch_upstream::TypeKind::TensorType);
+  }
+
+  static ValueKnowledge getScalarPessimisticValueState(MLIRContext *context) {
+    return ValueKnowledge(Type(), NumberType::get(context),
+                          OptionalKnowledge::notNone,
+                          torch_upstream::TypeKind::NumberType);
   }
 
   bool operator==(const ValueKnowledge &rhs) const {
@@ -185,6 +227,25 @@ struct ValueKnowledge {
            std::make_tuple(rhs.dtype, rhs.optional);
   }
 
+  // Return true if the `refinedType` has more concrete type info than `type`.
+  static bool hasStrictlyMoreRefinedTypeInfo(const ValueKnowledge &refinedType,
+                                             const ValueKnowledge &type) {
+    if (type.kind == torch_upstream::TypeKind::AnyType &&
+        refinedType.kind != torch_upstream::TypeKind::AnyType)
+      return true;
+
+    // If both are tensors but `type` doesn't have concrete dtype info.
+    if (refinedType.kind == torch_upstream::TypeKind::TensorType &&
+        type.kind == torch_upstream::TypeKind::TensorType) {
+      return refinedType.dtype && !type.dtype;
+    }
+
+    if (refinedType.scalarType && type.scalarType)
+      return isValidSubtype(refinedType.scalarType, type.scalarType);
+
+    return false;
+  }
+
   // Given two pieces of static knowledge, intersect the facts that are known in
   // both knowledges. This always produces knowledge that has less (or equal)
   // facts than both the lhs and rhs.
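The predicate just added drives both join and meet below. A self-contained sketch of its three rules, using plain tags instead of MLIR types and a hand-rolled `isValidSubtype` that encodes only the `!torch.int`/`!torch.float` <: `!torch.number` relation assumed here:

    #include <cassert>

    enum class TypeKind { AnyType, TensorType, NumberType, IntType, FloatType };

    struct K {
      bool hasDtype = false;               // for tensors: is the dtype known?
      TypeKind scalar = TypeKind::AnyType; // scalar type, AnyType when absent
      TypeKind kind = TypeKind::AnyType;
    };

    // Only the subtype facts needed here: int and float are subtypes of number.
    bool isValidSubtype(TypeKind sub, TypeKind super) {
      return super == TypeKind::NumberType &&
             (sub == TypeKind::IntType || sub == TypeKind::FloatType);
    }

    bool hasStrictlyMoreRefinedTypeInfo(const K &refined, const K &type) {
      // Rule 1: knowing any kind at all refines "kind unknown".
      if (type.kind == TypeKind::AnyType && refined.kind != TypeKind::AnyType)
        return true;
      // Rule 2: among tensors, a known dtype refines an unknown one.
      if (refined.kind == TypeKind::TensorType && type.kind == TypeKind::TensorType)
        return refined.hasDtype && !type.hasDtype;
      // Rule 3: among scalars, a subtype refines its supertype.
      if (refined.scalar != TypeKind::AnyType && type.scalar != TypeKind::AnyType)
        return isValidSubtype(refined.scalar, type.scalar);
      return false;
    }

    int main() {
      K number{false, TypeKind::NumberType, TypeKind::NumberType};
      K intScalar{false, TypeKind::IntType, TypeKind::IntType};
      assert(hasStrictlyMoreRefinedTypeInfo(intScalar, number)); // int refines number
      assert(!hasStrictlyMoreRefinedTypeInfo(number, intScalar));
      return 0;
    }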
@@ -200,40 +261,79 @@ struct ValueKnowledge {
     // Mental model: All conditions are checking how to change from the safe "no
     // knowledge" default-initialized state to a state with more knowledge
     // consistent with lhs and rhs.
-    ValueKnowledge result = getPessimisticValueState(nullptr);
+    ValueKnowledge result = joinTypes(lhs, rhs);
+
     result.optional = joinOptionalKnowledge(lhs.optional, rhs.optional);
-    result.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
 
     return result;
   }
 
+  static ValueKnowledge joinTypes(const ValueKnowledge &lhs,
+                                  const ValueKnowledge &rhs) {
+    if (hasStrictlyMoreRefinedTypeInfo(lhs, rhs))
+      return rhs;
+    if (hasStrictlyMoreRefinedTypeInfo(rhs, lhs))
+      return lhs;
+    if (lhs == rhs)
+      return lhs;
+    return getPessimisticValueState(nullptr);
+  }
+
   // Given two pieces of static knowledge, calculate new knowledge that assumes
   // the facts from both.
   // If the two pieces of knowledge are contradictory, None is returned.
   static Optional<ValueKnowledge> meet(const ValueKnowledge &lhs,
                                        const ValueKnowledge &rhs) {
-    ValueKnowledge result = getPessimisticValueState(nullptr);
+    Optional<ValueKnowledge> knowledge = meetTypes(lhs, rhs);
+
+    if (!knowledge.hasValue())
+      return None;
+    ValueKnowledge result = knowledge.getValue();
 
     Optional<OptionalKnowledge> optional =
         meetOptionalKnowledge(lhs.optional, rhs.optional);
     if (!optional.hasValue())
       return None;
     result.optional = optional.getValue();
 
-    Optional<Type> dtype = meetElementTypes(lhs.dtype, rhs.dtype);
-    if (!dtype.hasValue())
-      return None;
-    result.dtype = dtype.getValue();
-
     return result;
   }
 
+  static Optional<ValueKnowledge> meetTypes(const ValueKnowledge &lhs,
+                                            const ValueKnowledge &rhs) {
+    if (hasStrictlyMoreRefinedTypeInfo(lhs, rhs))
+      return lhs;
+    if (hasStrictlyMoreRefinedTypeInfo(rhs, lhs))
+      return rhs;
+    if (lhs == rhs)
+      return lhs;
+    return None;
+  }
+
   // The dtype of a tensor.
-  // This is equal to nullptr if we don't know that it is a specific concrete
-  // type.
+  // This is equal to nullptr for the following cases:
+  // 1. it is unknown whether the value is a tensor or not, i.e. the `kind`
+  //    field is torch_upstream::TypeKind::AnyType.
+  // 2. the value is a tensor type but the dtype is unknown.
+  // 3. the value is not a tensor type.
   Type dtype;
+
+  // The type of a scalar.
+  // This is equal to nullptr for the following cases:
+  // 1. it is unknown whether the value is a scalar or not, i.e. the `kind`
+  //    field is torch_upstream::TypeKind::AnyType.
+  // 2. the value is not a scalar type.
+  Type scalarType;
+
+  // The type kind. If it's torch_upstream::TypeKind::AnyType, all the type
+  // fields are nullptr. Note that `kind` never equals
+  // torch_upstream::TypeKind::OptionalType because optional knowledge is
+  // tracked separately through the `optional` field.
+  torch_upstream::TypeKind kind;
+
   // What is known about an optional value.
+  // When equal to OptionalKnowledge::notNone, the type info is kept in type
+  // fields like `dtype` and `scalarType`.
+  // When equal to OptionalKnowledge::isNone or OptionalKnowledge::unKnown, the
+  // other type fields are currently nullptr. It might be worth considering
+  // tracking wrapped type info for OptionalKnowledge::unKnown in the future.
   OptionalKnowledge optional;
 };
 
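Taken together, the new `joinTypes`/`meetTypes` treat refinement as a partial order: join keeps only what both sides guarantee, while meet combines facts or reports a contradiction. A deliberately tiny model (a totally ordered Any < Number < Int chain, so the contradiction case never fires here; in the real lattice, Int and Float are incomparable and meet would return None):

    #include <cassert>
    #include <optional>

    enum class S { Any = 0, Number = 1, Int = 2 }; // refinement: Any < Number < Int

    bool strictlyMoreRefined(S a, S b) {
      return static_cast<int>(a) > static_cast<int>(b);
    }

    // Join: keep only facts both sides agree on (used at control-flow merges).
    S join(S lhs, S rhs) {
      if (strictlyMoreRefined(lhs, rhs))
        return rhs;
      if (strictlyMoreRefined(rhs, lhs))
        return lhs;
      return lhs; // equal
    }

    // Meet: combine all facts; incomparable states would be a contradiction.
    std::optional<S> meet(S lhs, S rhs) {
      if (strictlyMoreRefined(lhs, rhs))
        return lhs;
      if (strictlyMoreRefined(rhs, lhs))
        return rhs;
      if (lhs == rhs)
        return lhs;
      return std::nullopt;
    }

    int main() {
      assert(join(S::Int, S::Number) == S::Number); // join weakens to the common fact
      assert(*meet(S::Number, S::Int) == S::Int);   // meet strengthens to the subtype
      return 0;
    }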
@@ -500,11 +600,9 @@ ChangeResult TypeAnalyzer::visitOperation(
           AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, AtenZero_Op,
           AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
           ValsemVariantAtenCopyOp, ValsemVariantAtenZeroOp,
-          AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp>(op)) {
-    ValueKnowledge knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
-    knowledge.dtype = operands[0]->getValue().dtype;
-    return incorporateKnowledge(op->getResult(0), knowledge);
+          AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
+          PrimAbsScalarOp>(op)) {
+    return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
 
   // Dtype is always float32, except for bfloat16, float64 and nullptr.
@@ -512,7 +610,7 @@ ChangeResult TypeAnalyzer::visitOperation(
           AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenRsqrtOp,
           AtenErfOp>(op)) {
     ValueKnowledge knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     Type dtype = operands[0]->getValue().dtype;
     if (dtype) {
       knowledge.dtype = Float32Type::get(op->getContext());

@@ -526,7 +624,7 @@ ChangeResult TypeAnalyzer::visitOperation(
   if (isa<AtenNllLossBackwardOp, AtenMaxPool2dWithIndicesBackwardOp>(op)) {
     auto self = operands[1]->getValue();
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = self.dtype;
     return incorporateKnowledge(op->getResult(0), knowledge);
   }

@@ -536,7 +634,7 @@ ChangeResult TypeAnalyzer::visitOperation(
           AtenLeScalarOp, AtenNeScalarOp, AtenAnyOp, AtenAllOp, AtenEqTensorOp,
           AtenGtTensorOp, AtenLtTensorOp>(op)) {
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = IntegerType::get(op->getContext(), 1);
     return incorporateKnowledge(op->getResult(0), knowledge);
   }

@@ -544,7 +642,7 @@ ChangeResult TypeAnalyzer::visitOperation(
   // Dtype is always si64.
   if (isa<AtenBincountOp>(op)) {
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype =
         IntegerType::get(op->getContext(), 64, IntegerType::Signed);
     return incorporateKnowledge(op->getResult(0), knowledge);

@@ -554,7 +652,7 @@ ChangeResult TypeAnalyzer::visitOperation(
   if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
           AtenConvolutionOverrideableOp>(op)) {
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
         op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()});
     return incorporateKnowledge(op->getResult(0), knowledge);
@@ -565,7 +663,7 @@ ChangeResult TypeAnalyzer::visitOperation(
           Aten__And__TensorOp, AtenMinimumOp, AtenMaximumOp,
           AtenBitwiseAndTensorOp, AtenThresholdBackwardOp>(op)) {
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = getPromotedResultType(
         op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()},
         getRankIsNonZeroArray(op->getOperands()));

@@ -575,7 +673,7 @@ ChangeResult TypeAnalyzer::visitOperation(
   // Promote three dtypes.
   if (isa<AtenAddmmOp, AtenLerpTensorOp, AtenAddcmulOp, AtenAddcdivOp>(op)) {
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
         op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue(),
                            &operands[2]->getValue()});

@@ -593,7 +691,7 @@ ChangeResult TypeAnalyzer::visitOperation(
     auto lhs = operands[0]->getValue();
     Value scalar = op->getOperand(1);
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(getContext());
+        ValueKnowledge::getTensorPessimisticValueState(getContext());
     knowledge.dtype = getPromotedResultType(&lhs, scalar.getType());
     return incorporateKnowledge(op->getResult(0), knowledge);
   }

@@ -601,7 +699,7 @@ ChangeResult TypeAnalyzer::visitOperation(
   // Promote 2nd and 3rd operands.
   if (isa<AtenWhereSelfOp>(op)) {
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(getContext());
+        ValueKnowledge::getTensorPessimisticValueState(getContext());
     knowledge.dtype = getPromotedResultType(
         getContext(), {&operands[1]->getValue(), &operands[2]->getValue()},
         getRankIsNonZeroArray(op->getOperands().slice(1, 2)));

@@ -613,7 +711,7 @@ ChangeResult TypeAnalyzer::visitOperation(
     Value lhsScalar = op->getOperand(1);
     Value rhsScalar = op->getOperand(2);
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(getContext());
+        ValueKnowledge::getTensorPessimisticValueState(getContext());
     knowledge.dtype =
         getPromotedResultType({lhsScalar.getType(), rhsScalar.getType()});
     return incorporateKnowledge(op->getResult(0), knowledge);

@@ -624,7 +722,7 @@ ChangeResult TypeAnalyzer::visitOperation(
     auto lhs = operands[1]->getValue();
     Value scalar = op->getOperand(2);
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(getContext());
+        ValueKnowledge::getTensorPessimisticValueState(getContext());
     knowledge.dtype = getPromotedResultType(&lhs, scalar.getType());
     return incorporateKnowledge(op->getResult(0), knowledge);
   }

@@ -634,7 +732,7 @@ ChangeResult TypeAnalyzer::visitOperation(
     auto rhs = operands[2]->getValue();
     Value scalar = op->getOperand(1);
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(getContext());
+        ValueKnowledge::getTensorPessimisticValueState(getContext());
     knowledge.dtype = getPromotedResultType(&rhs, scalar.getType());
     return incorporateKnowledge(op->getResult(0), knowledge);
   }
@@ -643,10 +741,10 @@ ChangeResult TypeAnalyzer::visitOperation(
   if (isa<AtenNllLossForwardOp>(op)) {
     auto self = operands[0]->getValue();
     auto result0Knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     result0Knowledge.dtype = self.dtype;
     auto result1Knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     result1Knowledge.dtype = self.dtype;
     auto changed = incorporateKnowledge(op->getResult(0), result0Knowledge);
     changed |= incorporateKnowledge(op->getResult(1), result1Knowledge);

@@ -657,13 +755,13 @@ ChangeResult TypeAnalyzer::visitOperation(
   if (isa<AtenNativeLayerNormOp, AtenNativeBatchNormOp>(op)) {
     auto self = operands[0]->getValue();
     auto result0Knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     result0Knowledge.dtype = self.dtype;
     auto result1Knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     result1Knowledge.dtype = self.dtype;
     auto result2Knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     result2Knowledge.dtype = self.dtype;
     auto changed = incorporateKnowledge(op->getResult(0), result0Knowledge);
     changed |= incorporateKnowledge(op->getResult(1), result1Knowledge);

@@ -674,10 +772,10 @@ ChangeResult TypeAnalyzer::visitOperation(
   if (isa<AtenMaxPool2dWithIndicesOp>(op)) {
     auto self = operands[0]->getValue();
     auto result0Knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     result0Knowledge.dtype = self.dtype;
     auto result1Knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     result1Knowledge.dtype =
         IntegerType::get(op->getContext(), 64, IntegerType::Signed);
     ;
@@ -700,7 +798,7 @@ ChangeResult TypeAnalyzer::visitOperation(
     Type defaultDtype = operands[0]->getValue().dtype;
     Type dtype = getDtypeOrDefault(sum.getContext(), sum.dtype(), defaultDtype);
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = dtype;
     return incorporateKnowledge(op->getResult(0), knowledge);
   }

@@ -816,7 +914,7 @@ ChangeResult TypeAnalyzer::visitOperation(
 
   if (auto shapeAsTensor = dyn_cast<Aten_ShapeAsTensorOp>(op)) {
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype =
         IntegerType::get(op->getContext(), 64, IntegerType::Signed);
     return incorporateKnowledge(shapeAsTensor.getResult(), knowledge);

@@ -824,7 +922,7 @@ ChangeResult TypeAnalyzer::visitOperation(
 
   if (auto embedding = dyn_cast<AtenEmbeddingOp>(op)) {
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = Float32Type::get(op->getContext());
     return incorporateKnowledge(embedding.getResult(), knowledge);
   }

@@ -864,7 +962,7 @@ TypeAnalyzer::incorporateKnowledge(Value v, const ValueKnowledge &knowledge) {
 ChangeResult TypeAnalyzer::visitAtenLinearOp(
     AtenLinearOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   auto input = operands[0]->getValue();
   auto weight = operands[1]->getValue();
   auto bias = operands[2]->getValue();
@@ -889,7 +987,7 @@ ChangeResult TypeAnalyzer::visitAtenArangeLikeOpHelper(
     Operation *op, llvm::Optional<Value> start, Value end,
     llvm::Optional<Value> step, Value dtype) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   int64_t dtypeInt;
   if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) {
     knowledge.dtype = getTypeForDTypeInteger(op->getContext(), dtypeInt);

@@ -930,7 +1028,8 @@ ChangeResult TypeAnalyzer::visitAtenArangeOp(AtenArangeOp op) {
 ChangeResult TypeAnalyzer::visitReductionAlongAllDimsOp(
     Operation *op, Type dtype,
     ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
-  auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
+  auto knowledge =
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   knowledge.dtype = dtype;
   return incorporateKnowledge(op->getResult(0), knowledge);
 }

@@ -941,7 +1040,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
     Operation *op, Value dim, Value keepdim, Type dtype,
     ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   knowledge.dtype = dtype;
   return incorporateKnowledge(op->getResult(0), knowledge);
 }

@@ -951,7 +1050,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntOp(
     ArrayRef<LatticeElement<ValueKnowledge> *> operands, int resNum) {
   assert(dim.getType().isa<Torch::IntType>() && "dim must be int type");
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   knowledge.dtype = dtype;
   return incorporateKnowledge(op->getResult(resNum), knowledge);
 }

@@ -959,7 +1058,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntOp(
 template <typename OpTy>
 ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op.getContext());
   Value t = op.t();
   Value dtype = op.dtype();
   fillInDTypeGivenDTypeAndDataType(knowledge, dtype, t.getType());
@@ -969,15 +1068,17 @@ ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
 ChangeResult TypeAnalyzer::visitBinaryScalarOp(
     Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
-  knowledge.dtype = getPromotedResultType(
+      ValueKnowledge::getScalarPessimisticValueState(op->getContext());
+  Type resultType = getPromotedResultType(
       {op->getOperand(0).getType(), op->getOperand(1).getType()});
+  knowledge.scalarType = resultType;
+  knowledge.kind = getTypeKind(resultType);
   return incorporateKnowledge(op->getResult(0), knowledge);
 }
 
 ChangeResult TypeAnalyzer::visitAtenTensorOp(AtenTensorOp op) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op.getContext());
   Value data = op.data();
   Value dtype = op.dtype();
   Type type = data.getType();

@@ -993,7 +1094,7 @@ ChangeResult
 TypeAnalyzer::visitConstantTensorAllocOp(OpTy op,
                                          llvm::Optional<Type> dataType) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   if (!dataType)
     dataType = Torch::FloatType::get(op->getContext());
   fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), dataType.getValue());

@@ -1005,7 +1106,7 @@ ChangeResult TypeAnalyzer::visitConstantTensorAllocLikeOp(
     OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto input = operands[0]->getValue();
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   fillInDTypeGivenDTypeIntAndInputDType(knowledge, op.dtype(), input.dtype);
   return incorporateKnowledge(op.getResult(), knowledge);
 }

@@ -1015,7 +1116,7 @@ ChangeResult TypeAnalyzer::visitConstantTensorNewLikeOp(
     OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto input = operands[0]->getValue();
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   fillInDTypeGivenDTypeIntAndInputDType(knowledge, op.dtype(), input.dtype);
   return incorporateKnowledge(op.getResult(), knowledge);
 }
@@ -1025,7 +1126,7 @@ template <typename OpTy>
 ChangeResult TypeAnalyzer::visitAtenToDtypeLikeOp(
     OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   Value dtype = op.dtype();
   int64_t dtypeInt;
   if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))

@@ -1038,7 +1139,7 @@ template <typename OpTy>
 ChangeResult TypeAnalyzer::visitTypeConversionOp(
     OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   Value other = op.other();
   BaseTensorType type = other.getType().cast<BaseTensorType>();
   if (type.hasDtype())

@@ -1053,7 +1154,7 @@ ChangeResult TypeAnalyzer::visitAtenCatOp(
     AtenCatOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto tensorList = op.tensors();
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   auto listConstruct = tensorList.getDefiningOp<PrimListConstructOp>();
   if (!listConstruct)
     return incorporateKnowledge(op.getResult(), knowledge);

@@ -1073,7 +1174,7 @@ ChangeResult TypeAnalyzer::visitAtenCatOp(
 
 ChangeResult TypeAnalyzer::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   // The resulting type from converting a Scalar into a Tensor is different
   // if the scalar is part of a tensor operation (such as AtenMulScalar) or
   // not. In the former case, the type promotion rules are captured by the

@@ -1098,7 +1199,7 @@ ChangeResult TypeAnalyzer::visitAtenSoftmaxLikeOp(
   auto input = operands[0]->getValue();
   auto dtype = op.dtype();
   ValueKnowledge knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, input.dtype);
   return incorporateKnowledge(op.getResult(), knowledge);
 }

@@ -1109,7 +1210,7 @@ ChangeResult TypeAnalyzer::visitAten_SoftmaxLikeOp(
     OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto input = operands[0]->getValue();
   ValueKnowledge knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   bool halfToFloat;
   if (matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat))) {
     knowledge.dtype =
@@ -1154,6 +1255,16 @@ static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
       else
         return containedType;
     }
+  } else if (auto scalarType = v.getType().dyn_cast<NumberType>()) {
+    LatticeElement<ValueKnowledge> *latticeElement =
+        analyzer.lookupLatticeElement(v);
+    if (!latticeElement)
+      return nullptr;
+    const ValueKnowledge &knowledge = latticeElement->getValue();
+    if (knowledge.kind == torch_upstream::TypeKind::IntType)
+      return Torch::IntType::get(v.getContext());
+    if (knowledge.kind == torch_upstream::TypeKind::FloatType)
+      return Torch::FloatType::get(v.getContext());
   }
   return nullptr;
 }
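This is the piece that finally pays off for scalars: when the lattice proves that a `!torch.number` value is actually an int or a float, the pass can rewrite it to the narrower type (see the `@propagate_scalar_type` test at the end). A simplified, illustrative mapping of that decision (`mostRefinedScalarType` is a hypothetical helper, not a function in the patch):

    #include <cassert>
    #include <string>

    enum class TypeKind { AnyType, IntType, FloatType };

    // Simplified mirror of the NumberType branch above: map proven kinds to the
    // most refined static Torch type, or report that no refinement is possible.
    std::string mostRefinedScalarType(TypeKind kind) {
      if (kind == TypeKind::IntType)
        return "!torch.int";
      if (kind == TypeKind::FloatType)
        return "!torch.float";
      return ""; // empty = keep the original !torch.number
    }

    int main() {
      assert(mostRefinedScalarType(TypeKind::IntType) == "!torch.int");
      assert(mostRefinedScalarType(TypeKind::AnyType).empty());
      return 0;
    }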
@@ -1217,7 +1328,7 @@ void optimize(func::FuncOp func, TypeAnalyzer &analyzer) {
       return b.create<TensorStaticInfoCastOp>(loc, newType, v);
     };
     createStaticInfoUpCast = createStaticInfoDownCast;
-  } else if (originalType.isa<OptionalType>()) {
+  } else if (originalType.isa<OptionalType, NumberType>()) {
     createStaticInfoDownCast = [&](Location loc, Type newType,
                                    Value v) -> Value {
       return b.create<PrimUncheckedCastOp>(loc, newType, v);
@@ -543,6 +543,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
          traits=["DeclareOpInterfaceMethods<CastOpInterface>"])
     emit("prim::Print : (...) -> ()")
     emit("prim::tolist : (...) -> (...)")
+    emit("prim::abs.Scalar : (Scalar) -> (Scalar)")
 
     # ==========================================================================
     # `quantized::` namespace.
@@ -6,6 +6,7 @@
 // Code for testing transfer functions for new ops (which is most changes)
 // should go in refine-types-ops.mlir.
 
+// -----
 // CHECK-LABEL: func @basic(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
 // CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>

@@ -16,6 +17,7 @@ func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
   return %1 : !torch.vtensor
 }
 
+// -----
 // CHECK-LABEL: func @keep_existing_shape_information(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
 // CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32>

@@ -25,6 +27,7 @@ func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vt
   return %1 : !torch.vtensor<[2],f32>
 }
 
+// -----
 // CHECK-LABEL: func @propagate_through_multiple_ops(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
 // CHECK: %[[TANH0:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>

@@ -39,6 +42,7 @@ func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vte
   return %3 : !torch.vtensor
 }
 
+// -----
 // Check rewriting logic in case of mixes of users that do/don't allow type
 // refinement.
 // CHECK-LABEL: func @mixed_allowing_not_allowing_type_refinement(

@@ -53,6 +57,7 @@ func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>)
   return %1, %1 : !torch.vtensor, !torch.vtensor
 }
 
+// -----
 // CHECK-LABEL: func @type_promotion$same_category_different_width(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si32>,
 // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?],unk> {

@@ -66,6 +71,7 @@ func @type_promotion$same_category_different_width(%arg0: !torch.vtensor<[?],si3
   return %0 : !torch.vtensor<[?],unk>
 }
 
+// -----
 // CHECK-LABEL: func @type_promotion$different_category(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si64>,
 // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],unk> {

@@ -79,6 +85,7 @@ func @type_promotion$different_category(%arg0: !torch.vtensor<[?],si64>, %arg1:
   return %0 : !torch.vtensor<[?],unk>
 }
 
+// -----
 // CHECK-LABEL: func @type_promotion$same_category_zero_rank_wider(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
 // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f64>) -> !torch.vtensor<[?],unk> {

@@ -92,6 +99,7 @@ func @type_promotion$same_category_zero_rank_wider(%arg0: !torch.vtensor<[?],f32
   return %0 : !torch.vtensor<[?],unk>
 }
 
+// -----
 // CHECK-LABEL: func @type_promotion$zero_rank_higher_category(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si64>,
 // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {

@@ -105,6 +113,7 @@ func @type_promotion$zero_rank_higher_category(%arg0: !torch.vtensor<[?],si64>,
   return %0 : !torch.vtensor<[?],unk>
 }
 
+// -----
 // CHECK-LABEL: func @type_promotion$alpha_wider(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
 // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {

@@ -118,6 +127,7 @@ func @type_promotion$alpha_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.v
   return %0 : !torch.vtensor<[?],unk>
 }
 
+// -----
 // CHECK-LABEL: func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
 // CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
 // CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {

@@ -134,6 +144,7 @@ func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.
   return %result : !torch.vtensor<[2],f32>
 }
 
+// -----
 // CHECK-LABEL: func @torch.overwrite.tensor.contents$static_overwrites_dynamic(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2],f32>,
 // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {

@@ -151,6 +162,7 @@ func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.
   return %result : !torch.vtensor<[?],f32>
 }
 
+// -----
 // CHECK-LABEL: func @bf16_result_type(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> {
 // CHECK: %[[SQRT:.*]] = torch.aten.sqrt %[[ARG0]] : !torch.vtensor<*,bf16> -> !torch.vtensor<[2],bf16>

@@ -159,3 +171,16 @@ func @bf16_result_type(%arg0: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16
   %1 = torch.aten.sqrt %arg0 : !torch.vtensor<*,bf16> -> !torch.vtensor<[2], bf16>
   return %1 : !torch.vtensor<[2],bf16>
 }
+
+// -----
+// CHECK-LABEL: func @propagate_scalar_type(
+// CHECK-SAME: %[[INT:.*]]: !torch.int) -> !torch.number {
+// CHECK: %[[NUM:.*]] = torch.derefine %[[INT]] : !torch.int to !torch.number
+// CHECK: %[[ABS:.*]] = torch.prim.abs.Scalar %[[INT]] : !torch.int -> !torch.int
+// CHECK: %[[RET:.*]] = torch.derefine %[[ABS]] : !torch.int to !torch.number
+// CHECK: return %[[RET]] : !torch.number
+func @propagate_scalar_type(%arg0: !torch.int) -> !torch.number {
+  %num = torch.derefine %arg0 : !torch.int to !torch.number
+  %1 = torch.prim.abs.Scalar %num: !torch.number -> !torch.number
+  return %1 : !torch.number
+}