mirror of https://github.com/llvm/torch-mlir
Add support for scalar type propagation
The main changes are: - Added `ValueKnowledge.scalarType` to track scalar type information. - Added `ValueKnowledge.kind` to indicate the value kind. - Modified the meet and join helper functions. The ValueKnowledge has slightly more complicated state now so the meet and join function need to look at the `kind` field in addition to just the type field.pull/822/head snapshot-20220504.431
parent
0fb7a03ac9
commit
9f7264a7a4
|
@ -7398,6 +7398,29 @@ def Torch_PrimTolistOp : Torch_Op<"prim.tolist", [
|
|||
let assemblyFormat = "`(` $operands `)` attr-dict `:` qualified(type($operands)) `->` qualified(type($results))";
|
||||
}
|
||||
|
||||
def Torch_PrimAbsScalarOp : Torch_Op<"prim.abs.Scalar", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `prim::abs.Scalar : (Scalar) -> (Scalar)`";
|
||||
let arguments = (ins
|
||||
AnyTorchScalarType:$a
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchScalarType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult PrimAbsScalarOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void PrimAbsScalarOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
|
||||
HasValueSemantics,
|
||||
AllowsTypeRefinement,
|
||||
|
|
|
@ -29,6 +29,55 @@ namespace mlir {
|
|||
namespace torch {
|
||||
namespace torch_upstream {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeKind related enum related code are copied from
|
||||
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/jit_type_base.h
|
||||
//===----------------------------------------------------------------------===//
|
||||
#define C10_FORALL_TYPES(_) \
|
||||
_(AnyType) \
|
||||
_(EnumType) \
|
||||
_(AnyEnumType) \
|
||||
_(TensorType) \
|
||||
_(StorageType) \
|
||||
_(TupleType) \
|
||||
_(ListType) \
|
||||
_(DictType) \
|
||||
_(NumberType) \
|
||||
_(FloatType) \
|
||||
_(ComplexType) \
|
||||
_(FutureType) \
|
||||
_(RRefType) \
|
||||
_(IntType) \
|
||||
_(NoneType) \
|
||||
_(StringType) \
|
||||
_(GeneratorType) \
|
||||
_(QuantizerType) \
|
||||
_(BoolType) \
|
||||
_(OptionalType) \
|
||||
_(VarType) \
|
||||
_(DeviceObjType) \
|
||||
_(StreamObjType) \
|
||||
_(FunctionType) \
|
||||
_(ClassType) \
|
||||
_(PyObjectType) \
|
||||
_(CapsuleType) \
|
||||
_(InterfaceType) \
|
||||
_(QSchemeType) \
|
||||
_(LayoutType) \
|
||||
_(ScalarTypeType) \
|
||||
_(AnyListType) \
|
||||
_(AnyTupleType) \
|
||||
_(AnyClassType) \
|
||||
_(SymIntType) \
|
||||
_(UnionType) \
|
||||
_(DynamicType)
|
||||
|
||||
enum class TypeKind {
|
||||
#define DEFINE_TYPE(T) T,
|
||||
C10_FORALL_TYPES(DEFINE_TYPE)
|
||||
#undef DEFINE_TYPE
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ScalarType enum related code are copied from c10/core/ScalarType.h
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -92,14 +92,27 @@ static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype,
|
|||
return Type();
|
||||
}
|
||||
|
||||
static Type joinElementTypes(Type lhs, Type rhs) {
|
||||
if (lhs == rhs)
|
||||
return lhs;
|
||||
return Type();
|
||||
// Get the kind enum for `ValueKnowledge.kind`.
|
||||
static torch_upstream::TypeKind getTypeKind(Type type) {
|
||||
if (type.isa<NumberType>())
|
||||
return torch_upstream::TypeKind::NumberType;
|
||||
if (type.isa<IntType>())
|
||||
return torch_upstream::TypeKind::IntType;
|
||||
if (type.isa<Torch::FloatType>())
|
||||
return torch_upstream::TypeKind::FloatType;
|
||||
if (type.isa<BaseTensorType>())
|
||||
return torch_upstream::TypeKind::TensorType;
|
||||
if (type.isa<Torch::NoneType>())
|
||||
return torch_upstream::TypeKind::NoneType;
|
||||
// Skip the Torch::OptionalType on purpose because optional knowledge is
|
||||
// tracked separately. See comments for `ValueKnowledge.kind` field.
|
||||
return torch_upstream::TypeKind::AnyType;
|
||||
}
|
||||
|
||||
/// Returns the dtype that assumes information from both `lhs` and `rhs`.
|
||||
/// Returns `None` if the types are contradictory.
|
||||
/// Returns `None` if the types are contradictory. Note this can only be used
|
||||
/// on the `dtype` from tensors and can't be used on other types like scalar
|
||||
/// types.
|
||||
static Optional<Type> meetElementTypes(Type lhs, Type rhs) {
|
||||
if (!lhs)
|
||||
return rhs;
|
||||
|
@ -148,27 +161,44 @@ namespace {
|
|||
// This class could also be called "dataflow facts", "lattice value", etc.
|
||||
struct ValueKnowledge {
|
||||
ValueKnowledge() = delete;
|
||||
ValueKnowledge(Type dtype, OptionalKnowledge optionalKnowledge)
|
||||
: dtype(dtype), optional(optionalKnowledge) {}
|
||||
ValueKnowledge(Type dtype, Type scalarType,
|
||||
OptionalKnowledge optionalKnowledge,
|
||||
torch_upstream::TypeKind kind)
|
||||
: dtype(dtype), scalarType(scalarType), kind(kind),
|
||||
optional(optionalKnowledge) {}
|
||||
|
||||
// Get the static knowledge intrinsic to `type`.
|
||||
static ValueKnowledge getKnowledgeFromType(Type type) {
|
||||
ValueKnowledge result = getPessimisticValueState(type.getContext());
|
||||
if (auto tensorType = type.dyn_cast<BaseTensorType>()) {
|
||||
result.dtype = tensorType.getOptionalDtype();
|
||||
result.kind = getTypeKind(type);
|
||||
switch (result.kind) {
|
||||
case torch_upstream::TypeKind::TensorType:
|
||||
result.dtype = type.cast<BaseTensorType>().getOptionalDtype();
|
||||
result.optional = OptionalKnowledge::notNone;
|
||||
} else if (auto optionalType = type.dyn_cast<Torch::NoneType>()) {
|
||||
return result;
|
||||
case torch_upstream::TypeKind::NumberType:
|
||||
case torch_upstream::TypeKind::IntType:
|
||||
case torch_upstream::TypeKind::FloatType:
|
||||
result.scalarType = type;
|
||||
result.optional = OptionalKnowledge::notNone;
|
||||
return result;
|
||||
case torch_upstream::TypeKind::NoneType:
|
||||
result.optional = OptionalKnowledge::isNone;
|
||||
} else if (!type.isa<OptionalType>()) {
|
||||
return result;
|
||||
default:
|
||||
if (type.isa<OptionalType>())
|
||||
return result;
|
||||
// All other types that are not optional type.
|
||||
result.optional = OptionalKnowledge::notNone;
|
||||
return result;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Return a pessimistic/conservative value state without assuming any knowlege
|
||||
// about the IR.
|
||||
static ValueKnowledge getPessimisticValueState(MLIRContext *context) {
|
||||
return ValueKnowledge(Type(), OptionalKnowledge::unKnown);
|
||||
return ValueKnowledge(Type(), Type(), OptionalKnowledge::unKnown,
|
||||
torch_upstream::TypeKind::AnyType);
|
||||
}
|
||||
// Return a pessimistic/conservative value state only using knowlege already
|
||||
// recorded in the IR.
|
||||
|
@ -177,7 +207,19 @@ struct ValueKnowledge {
|
|||
}
|
||||
|
||||
static ValueKnowledge getNotNonePessimisticValueState(MLIRContext *context) {
|
||||
return ValueKnowledge(Type(), OptionalKnowledge::notNone);
|
||||
return ValueKnowledge(Type(), Type(), OptionalKnowledge::notNone,
|
||||
torch_upstream::TypeKind::AnyType);
|
||||
}
|
||||
|
||||
static ValueKnowledge getTensorPessimisticValueState(MLIRContext *context) {
|
||||
return ValueKnowledge(Type(), Type(), OptionalKnowledge::notNone,
|
||||
torch_upstream::TypeKind::TensorType);
|
||||
}
|
||||
|
||||
static ValueKnowledge getScalarPessimisticValueState(MLIRContext *context) {
|
||||
return ValueKnowledge(Type(), NumberType::get(context),
|
||||
OptionalKnowledge::notNone,
|
||||
torch_upstream::TypeKind::NumberType);
|
||||
}
|
||||
|
||||
bool operator==(const ValueKnowledge &rhs) const {
|
||||
|
@ -185,6 +227,25 @@ struct ValueKnowledge {
|
|||
std::make_tuple(rhs.dtype, rhs.optional);
|
||||
}
|
||||
|
||||
// Return true if the `refinedType` has more concrete type info than `type`.
|
||||
static bool hasStrictlyMoreRefinedTypeInfo(const ValueKnowledge &refinedType,
|
||||
const ValueKnowledge &type) {
|
||||
if (type.kind == torch_upstream::TypeKind::AnyType &&
|
||||
refinedType.kind != torch_upstream::TypeKind::AnyType)
|
||||
return true;
|
||||
|
||||
// If both are tensors but `type` doesn't have concrete dtype info.
|
||||
if (refinedType.kind == torch_upstream::TypeKind::TensorType &&
|
||||
type.kind == torch_upstream::TypeKind::TensorType) {
|
||||
return refinedType.dtype && !type.dtype;
|
||||
}
|
||||
|
||||
if (refinedType.scalarType && type.scalarType)
|
||||
return isValidSubtype(refinedType.scalarType, type.scalarType);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// Given two pieces of static knowledge, intersect the facts that are known in
|
||||
// both knowledges. This always produces knowledge that has less (or equal)
|
||||
// facts than both the lhs and rhs.
|
||||
|
@ -200,40 +261,79 @@ struct ValueKnowledge {
|
|||
// Mental model: All conditions are checking how to change from the safe "no
|
||||
// knowledge" default-initialized state to a state with more knowledge
|
||||
// consistent with lhs and rhs.
|
||||
ValueKnowledge result = getPessimisticValueState(nullptr);
|
||||
|
||||
ValueKnowledge result = joinTypes(lhs, rhs);
|
||||
result.optional = joinOptionalKnowledge(lhs.optional, rhs.optional);
|
||||
result.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static ValueKnowledge joinTypes(const ValueKnowledge &lhs,
|
||||
const ValueKnowledge &rhs) {
|
||||
if (hasStrictlyMoreRefinedTypeInfo(lhs, rhs))
|
||||
return rhs;
|
||||
if (hasStrictlyMoreRefinedTypeInfo(rhs, lhs))
|
||||
return lhs;
|
||||
if (lhs == rhs)
|
||||
return lhs;
|
||||
return getPessimisticValueState(nullptr);
|
||||
}
|
||||
|
||||
// Given two pieces of static knowledge, calculate new knowledge that assumes
|
||||
// the facts from both.
|
||||
// If the two pieces of knowledge are contradictory, None is returned.
|
||||
static Optional<ValueKnowledge> meet(const ValueKnowledge &lhs,
|
||||
const ValueKnowledge &rhs) {
|
||||
ValueKnowledge result = getPessimisticValueState(nullptr);
|
||||
Optional<ValueKnowledge> knowledge = meetTypes(lhs, rhs);
|
||||
|
||||
if (!knowledge.hasValue())
|
||||
return None;
|
||||
ValueKnowledge result = knowledge.getValue();
|
||||
|
||||
Optional<OptionalKnowledge> optional =
|
||||
meetOptionalKnowledge(lhs.optional, rhs.optional);
|
||||
if (!optional.hasValue())
|
||||
return None;
|
||||
result.optional = optional.getValue();
|
||||
|
||||
Optional<Type> dtype = meetElementTypes(lhs.dtype, rhs.dtype);
|
||||
if (!dtype.hasValue())
|
||||
return None;
|
||||
result.dtype = dtype.getValue();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static Optional<ValueKnowledge> meetTypes(const ValueKnowledge &lhs,
|
||||
const ValueKnowledge &rhs) {
|
||||
if (hasStrictlyMoreRefinedTypeInfo(lhs, rhs))
|
||||
return lhs;
|
||||
if (hasStrictlyMoreRefinedTypeInfo(rhs, lhs))
|
||||
return rhs;
|
||||
if (lhs == rhs)
|
||||
return lhs;
|
||||
return None;
|
||||
}
|
||||
|
||||
// The dtype of a tensor.
|
||||
// This is equal to nullptr if we don't know that it is a specific concrete
|
||||
// type.
|
||||
// This is equal to nullptr for the follow cases:
|
||||
// 1. it is unknown whether the value is a tensor or not, ie the `kind` field
|
||||
// is torch_upstream::TypeKind::AnyType.
|
||||
// 2. the value is a tensor type but the dtype is unknown.
|
||||
// 3. the value is not a tensor type.
|
||||
Type dtype;
|
||||
|
||||
// The type of a scalar.
|
||||
// This is equal to nullptr for the follow cases:
|
||||
// 1. it is unknown whether the value is a scalar or not, ie the `kind` field
|
||||
// is torch_upstream::TypeKind::AnyType.
|
||||
// 2. the value is not a scalar type.
|
||||
Type scalarType;
|
||||
|
||||
// The type kind. If it's torch_upstream::TypeKind::AnyType,
|
||||
// all the type fields are nullptr. Note that the `kind` never equals to
|
||||
// torch_upstream::TypeKind::OptionalType because optional knowledge is
|
||||
// tracked separately through the `optional` field.
|
||||
torch_upstream::TypeKind kind;
|
||||
|
||||
// What is known about an optional value.
|
||||
// When equal to OptionalKnowledge::notNone, the type info is kept in type
|
||||
// fields like `dtype`, `scalarType`.
|
||||
// When equal to OptionalKnowledge::isNone or OptionalKnowledge::unKnown, the
|
||||
// other type fields are currently nullptr. It might worth considering
|
||||
// tracking wrapped type info when OptionalKnowledge::unKnown in the future.
|
||||
OptionalKnowledge optional;
|
||||
};
|
||||
|
||||
|
@ -500,11 +600,9 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, AtenZero_Op,
|
||||
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
|
||||
ValsemVariantAtenCopyOp, ValsemVariantAtenZeroOp,
|
||||
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp>(op)) {
|
||||
ValueKnowledge knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.dtype = operands[0]->getValue().dtype;
|
||||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
|
||||
PrimAbsScalarOp>(op)) {
|
||||
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
||||
}
|
||||
|
||||
// Dtype is always float32, except for bfloat16, float64 and nullptr.
|
||||
|
@ -512,7 +610,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenRsqrtOp,
|
||||
AtenErfOp>(op)) {
|
||||
ValueKnowledge knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
Type dtype = operands[0]->getValue().dtype;
|
||||
if (dtype) {
|
||||
knowledge.dtype = Float32Type::get(op->getContext());
|
||||
|
@ -526,7 +624,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
if (isa<AtenNllLossBackwardOp, AtenMaxPool2dWithIndicesBackwardOp>(op)) {
|
||||
auto self = operands[1]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
knowledge.dtype = self.dtype;
|
||||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
}
|
||||
|
@ -536,7 +634,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
AtenLeScalarOp, AtenNeScalarOp, AtenAnyOp, AtenAllOp, AtenEqTensorOp,
|
||||
AtenGtTensorOp, AtenLtTensorOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
knowledge.dtype = IntegerType::get(op->getContext(), 1);
|
||||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
}
|
||||
|
@ -544,7 +642,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
// Dtype is always si64.
|
||||
if (isa<AtenBincountOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
knowledge.dtype =
|
||||
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
|
||||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
|
@ -554,7 +652,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
|
||||
AtenConvolutionOverrideableOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
|
||||
op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()});
|
||||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
|
@ -565,7 +663,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
Aten__And__TensorOp, AtenMinimumOp, AtenMaximumOp,
|
||||
AtenBitwiseAndTensorOp, AtenThresholdBackwardOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
knowledge.dtype = getPromotedResultType(
|
||||
op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()},
|
||||
getRankIsNonZeroArray(op->getOperands()));
|
||||
|
@ -575,7 +673,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
// Promote three dtypes.
|
||||
if (isa<AtenAddmmOp, AtenLerpTensorOp, AtenAddcmulOp, AtenAddcdivOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
|
||||
op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue(),
|
||||
&operands[2]->getValue()});
|
||||
|
@ -593,7 +691,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
auto lhs = operands[0]->getValue();
|
||||
Value scalar = op->getOperand(1);
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(getContext());
|
||||
knowledge.dtype = getPromotedResultType(&lhs, scalar.getType());
|
||||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
}
|
||||
|
@ -601,7 +699,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
// Promote 2nd and 3rd operands.
|
||||
if (isa<AtenWhereSelfOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(getContext());
|
||||
knowledge.dtype = getPromotedResultType(
|
||||
getContext(), {&operands[1]->getValue(), &operands[2]->getValue()},
|
||||
getRankIsNonZeroArray(op->getOperands().slice(1, 2)));
|
||||
|
@ -613,7 +711,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
Value lhsScalar = op->getOperand(1);
|
||||
Value rhsScalar = op->getOperand(2);
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(getContext());
|
||||
knowledge.dtype =
|
||||
getPromotedResultType({lhsScalar.getType(), rhsScalar.getType()});
|
||||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
|
@ -624,7 +722,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
auto lhs = operands[1]->getValue();
|
||||
Value scalar = op->getOperand(2);
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(getContext());
|
||||
knowledge.dtype = getPromotedResultType(&lhs, scalar.getType());
|
||||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
}
|
||||
|
@ -634,7 +732,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
auto rhs = operands[2]->getValue();
|
||||
Value scalar = op->getOperand(1);
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(getContext());
|
||||
knowledge.dtype = getPromotedResultType(&rhs, scalar.getType());
|
||||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
}
|
||||
|
@ -643,10 +741,10 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
if (isa<AtenNllLossForwardOp>(op)) {
|
||||
auto self = operands[0]->getValue();
|
||||
auto result0Knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
result0Knowledge.dtype = self.dtype;
|
||||
auto result1Knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
result1Knowledge.dtype = self.dtype;
|
||||
auto changed = incorporateKnowledge(op->getResult(0), result0Knowledge);
|
||||
changed |= incorporateKnowledge(op->getResult(1), result1Knowledge);
|
||||
|
@ -657,13 +755,13 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
if (isa<AtenNativeLayerNormOp, AtenNativeBatchNormOp>(op)) {
|
||||
auto self = operands[0]->getValue();
|
||||
auto result0Knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
result0Knowledge.dtype = self.dtype;
|
||||
auto result1Knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
result1Knowledge.dtype = self.dtype;
|
||||
auto result2Knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
result2Knowledge.dtype = self.dtype;
|
||||
auto changed = incorporateKnowledge(op->getResult(0), result0Knowledge);
|
||||
changed |= incorporateKnowledge(op->getResult(1), result1Knowledge);
|
||||
|
@ -674,10 +772,10 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
if (isa<AtenMaxPool2dWithIndicesOp>(op)) {
|
||||
auto self = operands[0]->getValue();
|
||||
auto result0Knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
result0Knowledge.dtype = self.dtype;
|
||||
auto result1Knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
result1Knowledge.dtype =
|
||||
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
|
||||
;
|
||||
|
@ -700,7 +798,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
Type defaultDtype = operands[0]->getValue().dtype;
|
||||
Type dtype = getDtypeOrDefault(sum.getContext(), sum.dtype(), defaultDtype);
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
knowledge.dtype = dtype;
|
||||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
}
|
||||
|
@ -816,7 +914,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
|
||||
if (auto shapeAsTensor = dyn_cast<Aten_ShapeAsTensorOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
knowledge.dtype =
|
||||
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
|
||||
return incorporateKnowledge(shapeAsTensor.getResult(), knowledge);
|
||||
|
@ -824,7 +922,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
|
||||
if (auto embedding = dyn_cast<AtenEmbeddingOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
knowledge.dtype = Float32Type::get(op->getContext());
|
||||
return incorporateKnowledge(embedding.getResult(), knowledge);
|
||||
}
|
||||
|
@ -864,7 +962,7 @@ TypeAnalyzer::incorporateKnowledge(Value v, const ValueKnowledge &knowledge) {
|
|||
ChangeResult TypeAnalyzer::visitAtenLinearOp(
|
||||
AtenLinearOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
auto input = operands[0]->getValue();
|
||||
auto weight = operands[1]->getValue();
|
||||
auto bias = operands[2]->getValue();
|
||||
|
@ -889,7 +987,7 @@ ChangeResult TypeAnalyzer::visitAtenArangeLikeOpHelper(
|
|||
Operation *op, llvm::Optional<Value> start, Value end,
|
||||
llvm::Optional<Value> step, Value dtype) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
int64_t dtypeInt;
|
||||
if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) {
|
||||
knowledge.dtype = getTypeForDTypeInteger(op->getContext(), dtypeInt);
|
||||
|
@ -930,7 +1028,8 @@ ChangeResult TypeAnalyzer::visitAtenArangeOp(AtenArangeOp op) {
|
|||
ChangeResult TypeAnalyzer::visitReductionAlongAllDimsOp(
|
||||
Operation *op, Type dtype,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto knowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
knowledge.dtype = dtype;
|
||||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
}
|
||||
|
@ -941,7 +1040,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
|
|||
Operation *op, Value dim, Value keepdim, Type dtype,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
knowledge.dtype = dtype;
|
||||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
}
|
||||
|
@ -951,7 +1050,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntOp(
|
|||
ArrayRef<LatticeElement<ValueKnowledge> *> operands, int resNum) {
|
||||
assert(dim.getType().isa<Torch::IntType>() && "dim must be int type");
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
knowledge.dtype = dtype;
|
||||
return incorporateKnowledge(op->getResult(resNum), knowledge);
|
||||
}
|
||||
|
@ -959,7 +1058,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntOp(
|
|||
template <typename OpTy>
|
||||
ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op.getContext());
|
||||
Value t = op.t();
|
||||
Value dtype = op.dtype();
|
||||
fillInDTypeGivenDTypeAndDataType(knowledge, dtype, t.getType());
|
||||
|
@ -969,15 +1068,17 @@ ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
|
|||
ChangeResult TypeAnalyzer::visitBinaryScalarOp(
|
||||
Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.dtype = getPromotedResultType(
|
||||
ValueKnowledge::getScalarPessimisticValueState(op->getContext());
|
||||
Type resultType = getPromotedResultType(
|
||||
{op->getOperand(0).getType(), op->getOperand(1).getType()});
|
||||
knowledge.scalarType = resultType;
|
||||
knowledge.kind = getTypeKind(resultType);
|
||||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
}
|
||||
|
||||
ChangeResult TypeAnalyzer::visitAtenTensorOp(AtenTensorOp op) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op.getContext());
|
||||
Value data = op.data();
|
||||
Value dtype = op.dtype();
|
||||
Type type = data.getType();
|
||||
|
@ -993,7 +1094,7 @@ ChangeResult
|
|||
TypeAnalyzer::visitConstantTensorAllocOp(OpTy op,
|
||||
llvm::Optional<Type> dataType) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
if (!dataType)
|
||||
dataType = Torch::FloatType::get(op->getContext());
|
||||
fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), dataType.getValue());
|
||||
|
@ -1005,7 +1106,7 @@ ChangeResult TypeAnalyzer::visitConstantTensorAllocLikeOp(
|
|||
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
fillInDTypeGivenDTypeIntAndInputDType(knowledge, op.dtype(), input.dtype);
|
||||
return incorporateKnowledge(op.getResult(), knowledge);
|
||||
}
|
||||
|
@ -1015,7 +1116,7 @@ ChangeResult TypeAnalyzer::visitConstantTensorNewLikeOp(
|
|||
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
fillInDTypeGivenDTypeIntAndInputDType(knowledge, op.dtype(), input.dtype);
|
||||
return incorporateKnowledge(op.getResult(), knowledge);
|
||||
}
|
||||
|
@ -1025,7 +1126,7 @@ template <typename OpTy>
|
|||
ChangeResult TypeAnalyzer::visitAtenToDtypeLikeOp(
|
||||
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
Value dtype = op.dtype();
|
||||
int64_t dtypeInt;
|
||||
if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
||||
|
@ -1038,7 +1139,7 @@ template <typename OpTy>
|
|||
ChangeResult TypeAnalyzer::visitTypeConversionOp(
|
||||
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
Value other = op.other();
|
||||
BaseTensorType type = other.getType().cast<BaseTensorType>();
|
||||
if (type.hasDtype())
|
||||
|
@ -1053,7 +1154,7 @@ ChangeResult TypeAnalyzer::visitAtenCatOp(
|
|||
AtenCatOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto tensorList = op.tensors();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
auto listConstruct = tensorList.getDefiningOp<PrimListConstructOp>();
|
||||
if (!listConstruct)
|
||||
return incorporateKnowledge(op.getResult(), knowledge);
|
||||
|
@ -1073,7 +1174,7 @@ ChangeResult TypeAnalyzer::visitAtenCatOp(
|
|||
|
||||
ChangeResult TypeAnalyzer::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
// The resulting type from converting a Scalar into a Tensor is different
|
||||
// if the scalar is part of a tensor operation (such as AtenMulScalar) or
|
||||
// not. In the former case, the type promotion rules are captured by the
|
||||
|
@ -1098,7 +1199,7 @@ ChangeResult TypeAnalyzer::visitAtenSoftmaxLikeOp(
|
|||
auto input = operands[0]->getValue();
|
||||
auto dtype = op.dtype();
|
||||
ValueKnowledge knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, input.dtype);
|
||||
return incorporateKnowledge(op.getResult(), knowledge);
|
||||
}
|
||||
|
@ -1109,7 +1210,7 @@ ChangeResult TypeAnalyzer::visitAten_SoftmaxLikeOp(
|
|||
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
ValueKnowledge knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
bool halfToFloat;
|
||||
if (matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat))) {
|
||||
knowledge.dtype =
|
||||
|
@ -1154,6 +1255,16 @@ static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
|
|||
else
|
||||
return containedType;
|
||||
}
|
||||
} else if (auto scalarType = v.getType().dyn_cast<NumberType>()) {
|
||||
LatticeElement<ValueKnowledge> *latticeElement =
|
||||
analyzer.lookupLatticeElement(v);
|
||||
if (!latticeElement)
|
||||
return nullptr;
|
||||
const ValueKnowledge &knowledge = latticeElement->getValue();
|
||||
if (knowledge.kind == torch_upstream::TypeKind::IntType)
|
||||
return Torch::IntType::get(v.getContext());
|
||||
if (knowledge.kind == torch_upstream::TypeKind::FloatType)
|
||||
return Torch::FloatType::get(v.getContext());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -1217,7 +1328,7 @@ void optimize(func::FuncOp func, TypeAnalyzer &analyzer) {
|
|||
return b.create<TensorStaticInfoCastOp>(loc, newType, v);
|
||||
};
|
||||
createStaticInfoUpCast = createStaticInfoDownCast;
|
||||
} else if (originalType.isa<OptionalType>()) {
|
||||
} else if (originalType.isa<OptionalType, NumberType>()) {
|
||||
createStaticInfoDownCast = [&](Location loc, Type newType,
|
||||
Value v) -> Value {
|
||||
return b.create<PrimUncheckedCastOp>(loc, newType, v);
|
||||
|
|
|
@ -543,6 +543,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
traits=["DeclareOpInterfaceMethods<CastOpInterface>"])
|
||||
emit("prim::Print : (...) -> ()")
|
||||
emit("prim::tolist : (...) -> (...)")
|
||||
emit("prim::abs.Scalar : (Scalar) -> (Scalar)")
|
||||
|
||||
# ==========================================================================
|
||||
# `quantized::` namespace.
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
// Code for testing transfer functions for new ops (which is most changes)
|
||||
// should go in refine-types-ops.mlir.
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||
|
@ -16,6 +17,7 @@ func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
|||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @keep_existing_shape_information(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
|
||||
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32>
|
||||
|
@ -25,6 +27,7 @@ func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vt
|
|||
return %1 : !torch.vtensor<[2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @propagate_through_multiple_ops(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||
// CHECK: %[[TANH0:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||
|
@ -39,6 +42,7 @@ func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vte
|
|||
return %3 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// Check rewriting logic in case of mixes of users that do/don't allow type
|
||||
// refinement.
|
||||
// CHECK-LABEL: func @mixed_allowing_not_allowing_type_refinement(
|
||||
|
@ -53,6 +57,7 @@ func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>)
|
|||
return %1, %1 : !torch.vtensor, !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @type_promotion$same_category_different_width(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?],unk> {
|
||||
|
@ -66,6 +71,7 @@ func @type_promotion$same_category_different_width(%arg0: !torch.vtensor<[?],si3
|
|||
return %0 : !torch.vtensor<[?],unk>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @type_promotion$different_category(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si64>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],unk> {
|
||||
|
@ -79,6 +85,7 @@ func @type_promotion$different_category(%arg0: !torch.vtensor<[?],si64>, %arg1:
|
|||
return %0 : !torch.vtensor<[?],unk>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @type_promotion$same_category_zero_rank_wider(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f64>) -> !torch.vtensor<[?],unk> {
|
||||
|
@ -92,6 +99,7 @@ func @type_promotion$same_category_zero_rank_wider(%arg0: !torch.vtensor<[?],f32
|
|||
return %0 : !torch.vtensor<[?],unk>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @type_promotion$zero_rank_higher_category(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si64>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
|
||||
|
@ -105,6 +113,7 @@ func @type_promotion$zero_rank_higher_category(%arg0: !torch.vtensor<[?],si64>,
|
|||
return %0 : !torch.vtensor<[?],unk>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @type_promotion$alpha_wider(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
|
||||
|
@ -118,6 +127,7 @@ func @type_promotion$alpha_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.v
|
|||
return %0 : !torch.vtensor<[?],unk>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
|
||||
// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
|
||||
// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
|
||||
|
@ -134,6 +144,7 @@ func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.
|
|||
return %result : !torch.vtensor<[2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.overwrite.tensor.contents$static_overwrites_dynamic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2],f32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
|
@ -151,6 +162,7 @@ func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.
|
|||
return %result : !torch.vtensor<[?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @bf16_result_type(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> {
|
||||
// CHECK: %[[SQRT:.*]] = torch.aten.sqrt %[[ARG0]] : !torch.vtensor<*,bf16> -> !torch.vtensor<[2],bf16>
|
||||
|
@ -159,3 +171,16 @@ func @bf16_result_type(%arg0: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16
|
|||
%1 = torch.aten.sqrt %arg0 : !torch.vtensor<*,bf16> -> !torch.vtensor<[2], bf16>
|
||||
return %1 : !torch.vtensor<[2],bf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @propagate_scalar_type(
|
||||
// CHECK-SAME: %[[INT:.*]]: !torch.int) -> !torch.number {
|
||||
// CHECK: %[[NUM:.*]] = torch.derefine %[[INT]] : !torch.int to !torch.number
|
||||
// CHECK: %[[ABS:.*]] = torch.prim.abs.Scalar %[[INT]] : !torch.int -> !torch.int
|
||||
// CHECK: %[[RET:.*]] = torch.derefine %[[ABS]] : !torch.int to !torch.number
|
||||
// CHECK: return %[[RET]] : !torch.number
|
||||
func @propagate_scalar_type(%arg0: !torch.int) -> !torch.number {
|
||||
%num = torch.derefine %arg0 : !torch.int to !torch.number
|
||||
%1 = torch.prim.abs.Scalar %num: !torch.number -> !torch.number
|
||||
return %1 : !torch.number
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue