Add support for scalar type propagation

The main changes are:
- Added `ValueKnowledge.scalarType` to track scalar type information.
- Added `ValueKnowledge.kind` to indicate the value kind.
- Modified the meet and join helper functions. The ValueKnowledge has
slightly more complicated state now so the meet and join function need
to look at the `kind` field in addition to just the type field.
pull/822/head snapshot-20220504.431
Yi Zhang 2022-05-03 11:59:49 -04:00
parent 0fb7a03ac9
commit 9f7264a7a4
5 changed files with 282 additions and 73 deletions

View File

@ -7398,6 +7398,29 @@ def Torch_PrimTolistOp : Torch_Op<"prim.tolist", [
let assemblyFormat = "`(` $operands `)` attr-dict `:` qualified(type($operands)) `->` qualified(type($results))"; let assemblyFormat = "`(` $operands `)` attr-dict `:` qualified(type($operands)) `->` qualified(type($results))";
} }
def Torch_PrimAbsScalarOp : Torch_Op<"prim.abs.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `prim::abs.Scalar : (Scalar) -> (Scalar)`";
let arguments = (ins
AnyTorchScalarType:$a
);
let results = (outs
AnyTorchScalarType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult PrimAbsScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void PrimAbsScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
HasValueSemantics, HasValueSemantics,
AllowsTypeRefinement, AllowsTypeRefinement,

View File

@ -29,6 +29,55 @@ namespace mlir {
namespace torch { namespace torch {
namespace torch_upstream { namespace torch_upstream {
//===----------------------------------------------------------------------===//
// TypeKind related enum related code are copied from
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/jit_type_base.h
//===----------------------------------------------------------------------===//
#define C10_FORALL_TYPES(_) \
_(AnyType) \
_(EnumType) \
_(AnyEnumType) \
_(TensorType) \
_(StorageType) \
_(TupleType) \
_(ListType) \
_(DictType) \
_(NumberType) \
_(FloatType) \
_(ComplexType) \
_(FutureType) \
_(RRefType) \
_(IntType) \
_(NoneType) \
_(StringType) \
_(GeneratorType) \
_(QuantizerType) \
_(BoolType) \
_(OptionalType) \
_(VarType) \
_(DeviceObjType) \
_(StreamObjType) \
_(FunctionType) \
_(ClassType) \
_(PyObjectType) \
_(CapsuleType) \
_(InterfaceType) \
_(QSchemeType) \
_(LayoutType) \
_(ScalarTypeType) \
_(AnyListType) \
_(AnyTupleType) \
_(AnyClassType) \
_(SymIntType) \
_(UnionType) \
_(DynamicType)
enum class TypeKind {
#define DEFINE_TYPE(T) T,
C10_FORALL_TYPES(DEFINE_TYPE)
#undef DEFINE_TYPE
};
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ScalarType enum related code are copied from c10/core/ScalarType.h // ScalarType enum related code are copied from c10/core/ScalarType.h
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -92,14 +92,27 @@ static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype,
return Type(); return Type();
} }
static Type joinElementTypes(Type lhs, Type rhs) { // Get the kind enum for `ValueKnowledge.kind`.
if (lhs == rhs) static torch_upstream::TypeKind getTypeKind(Type type) {
return lhs; if (type.isa<NumberType>())
return Type(); return torch_upstream::TypeKind::NumberType;
if (type.isa<IntType>())
return torch_upstream::TypeKind::IntType;
if (type.isa<Torch::FloatType>())
return torch_upstream::TypeKind::FloatType;
if (type.isa<BaseTensorType>())
return torch_upstream::TypeKind::TensorType;
if (type.isa<Torch::NoneType>())
return torch_upstream::TypeKind::NoneType;
// Skip the Torch::OptionalType on purpose because optional knowledge is
// tracked separately. See comments for `ValueKnowledge.kind` field.
return torch_upstream::TypeKind::AnyType;
} }
/// Returns the dtype that assumes information from both `lhs` and `rhs`. /// Returns the dtype that assumes information from both `lhs` and `rhs`.
/// Returns `None` if the types are contradictory. /// Returns `None` if the types are contradictory. Note this can only be used
/// on the `dtype` from tensors and can't be used on other types like scalar
/// types.
static Optional<Type> meetElementTypes(Type lhs, Type rhs) { static Optional<Type> meetElementTypes(Type lhs, Type rhs) {
if (!lhs) if (!lhs)
return rhs; return rhs;
@ -148,27 +161,44 @@ namespace {
// This class could also be called "dataflow facts", "lattice value", etc. // This class could also be called "dataflow facts", "lattice value", etc.
struct ValueKnowledge { struct ValueKnowledge {
ValueKnowledge() = delete; ValueKnowledge() = delete;
ValueKnowledge(Type dtype, OptionalKnowledge optionalKnowledge) ValueKnowledge(Type dtype, Type scalarType,
: dtype(dtype), optional(optionalKnowledge) {} OptionalKnowledge optionalKnowledge,
torch_upstream::TypeKind kind)
: dtype(dtype), scalarType(scalarType), kind(kind),
optional(optionalKnowledge) {}
// Get the static knowledge intrinsic to `type`. // Get the static knowledge intrinsic to `type`.
static ValueKnowledge getKnowledgeFromType(Type type) { static ValueKnowledge getKnowledgeFromType(Type type) {
ValueKnowledge result = getPessimisticValueState(type.getContext()); ValueKnowledge result = getPessimisticValueState(type.getContext());
if (auto tensorType = type.dyn_cast<BaseTensorType>()) { result.kind = getTypeKind(type);
result.dtype = tensorType.getOptionalDtype(); switch (result.kind) {
case torch_upstream::TypeKind::TensorType:
result.dtype = type.cast<BaseTensorType>().getOptionalDtype();
result.optional = OptionalKnowledge::notNone; result.optional = OptionalKnowledge::notNone;
} else if (auto optionalType = type.dyn_cast<Torch::NoneType>()) { return result;
case torch_upstream::TypeKind::NumberType:
case torch_upstream::TypeKind::IntType:
case torch_upstream::TypeKind::FloatType:
result.scalarType = type;
result.optional = OptionalKnowledge::notNone;
return result;
case torch_upstream::TypeKind::NoneType:
result.optional = OptionalKnowledge::isNone; result.optional = OptionalKnowledge::isNone;
} else if (!type.isa<OptionalType>()) { return result;
default:
if (type.isa<OptionalType>())
return result;
// All other types that are not optional type.
result.optional = OptionalKnowledge::notNone; result.optional = OptionalKnowledge::notNone;
return result;
} }
return result;
} }
// Return a pessimistic/conservative value state without assuming any knowlege // Return a pessimistic/conservative value state without assuming any knowlege
// about the IR. // about the IR.
static ValueKnowledge getPessimisticValueState(MLIRContext *context) { static ValueKnowledge getPessimisticValueState(MLIRContext *context) {
return ValueKnowledge(Type(), OptionalKnowledge::unKnown); return ValueKnowledge(Type(), Type(), OptionalKnowledge::unKnown,
torch_upstream::TypeKind::AnyType);
} }
// Return a pessimistic/conservative value state only using knowlege already // Return a pessimistic/conservative value state only using knowlege already
// recorded in the IR. // recorded in the IR.
@ -177,7 +207,19 @@ struct ValueKnowledge {
} }
static ValueKnowledge getNotNonePessimisticValueState(MLIRContext *context) { static ValueKnowledge getNotNonePessimisticValueState(MLIRContext *context) {
return ValueKnowledge(Type(), OptionalKnowledge::notNone); return ValueKnowledge(Type(), Type(), OptionalKnowledge::notNone,
torch_upstream::TypeKind::AnyType);
}
static ValueKnowledge getTensorPessimisticValueState(MLIRContext *context) {
return ValueKnowledge(Type(), Type(), OptionalKnowledge::notNone,
torch_upstream::TypeKind::TensorType);
}
static ValueKnowledge getScalarPessimisticValueState(MLIRContext *context) {
return ValueKnowledge(Type(), NumberType::get(context),
OptionalKnowledge::notNone,
torch_upstream::TypeKind::NumberType);
} }
bool operator==(const ValueKnowledge &rhs) const { bool operator==(const ValueKnowledge &rhs) const {
@ -185,6 +227,25 @@ struct ValueKnowledge {
std::make_tuple(rhs.dtype, rhs.optional); std::make_tuple(rhs.dtype, rhs.optional);
} }
// Return true if the `refinedType` has more concrete type info than `type`.
static bool hasStrictlyMoreRefinedTypeInfo(const ValueKnowledge &refinedType,
const ValueKnowledge &type) {
if (type.kind == torch_upstream::TypeKind::AnyType &&
refinedType.kind != torch_upstream::TypeKind::AnyType)
return true;
// If both are tensors but `type` doesn't have concrete dtype info.
if (refinedType.kind == torch_upstream::TypeKind::TensorType &&
type.kind == torch_upstream::TypeKind::TensorType) {
return refinedType.dtype && !type.dtype;
}
if (refinedType.scalarType && type.scalarType)
return isValidSubtype(refinedType.scalarType, type.scalarType);
return false;
}
// Given two pieces of static knowledge, intersect the facts that are known in // Given two pieces of static knowledge, intersect the facts that are known in
// both knowledges. This always produces knowledge that has less (or equal) // both knowledges. This always produces knowledge that has less (or equal)
// facts than both the lhs and rhs. // facts than both the lhs and rhs.
@ -200,40 +261,79 @@ struct ValueKnowledge {
// Mental model: All conditions are checking how to change from the safe "no // Mental model: All conditions are checking how to change from the safe "no
// knowledge" default-initialized state to a state with more knowledge // knowledge" default-initialized state to a state with more knowledge
// consistent with lhs and rhs. // consistent with lhs and rhs.
ValueKnowledge result = getPessimisticValueState(nullptr); ValueKnowledge result = joinTypes(lhs, rhs);
result.optional = joinOptionalKnowledge(lhs.optional, rhs.optional); result.optional = joinOptionalKnowledge(lhs.optional, rhs.optional);
result.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
return result; return result;
} }
static ValueKnowledge joinTypes(const ValueKnowledge &lhs,
const ValueKnowledge &rhs) {
if (hasStrictlyMoreRefinedTypeInfo(lhs, rhs))
return rhs;
if (hasStrictlyMoreRefinedTypeInfo(rhs, lhs))
return lhs;
if (lhs == rhs)
return lhs;
return getPessimisticValueState(nullptr);
}
// Given two pieces of static knowledge, calculate new knowledge that assumes // Given two pieces of static knowledge, calculate new knowledge that assumes
// the facts from both. // the facts from both.
// If the two pieces of knowledge are contradictory, None is returned. // If the two pieces of knowledge are contradictory, None is returned.
static Optional<ValueKnowledge> meet(const ValueKnowledge &lhs, static Optional<ValueKnowledge> meet(const ValueKnowledge &lhs,
const ValueKnowledge &rhs) { const ValueKnowledge &rhs) {
ValueKnowledge result = getPessimisticValueState(nullptr); Optional<ValueKnowledge> knowledge = meetTypes(lhs, rhs);
if (!knowledge.hasValue())
return None;
ValueKnowledge result = knowledge.getValue();
Optional<OptionalKnowledge> optional = Optional<OptionalKnowledge> optional =
meetOptionalKnowledge(lhs.optional, rhs.optional); meetOptionalKnowledge(lhs.optional, rhs.optional);
if (!optional.hasValue()) if (!optional.hasValue())
return None; return None;
result.optional = optional.getValue(); result.optional = optional.getValue();
Optional<Type> dtype = meetElementTypes(lhs.dtype, rhs.dtype);
if (!dtype.hasValue())
return None;
result.dtype = dtype.getValue();
return result; return result;
} }
static Optional<ValueKnowledge> meetTypes(const ValueKnowledge &lhs,
const ValueKnowledge &rhs) {
if (hasStrictlyMoreRefinedTypeInfo(lhs, rhs))
return lhs;
if (hasStrictlyMoreRefinedTypeInfo(rhs, lhs))
return rhs;
if (lhs == rhs)
return lhs;
return None;
}
// The dtype of a tensor. // The dtype of a tensor.
// This is equal to nullptr if we don't know that it is a specific concrete // This is equal to nullptr for the follow cases:
// type. // 1. it is unknown whether the value is a tensor or not, ie the `kind` field
// is torch_upstream::TypeKind::AnyType.
// 2. the value is a tensor type but the dtype is unknown.
// 3. the value is not a tensor type.
Type dtype; Type dtype;
// The type of a scalar.
// This is equal to nullptr for the follow cases:
// 1. it is unknown whether the value is a scalar or not, ie the `kind` field
// is torch_upstream::TypeKind::AnyType.
// 2. the value is not a scalar type.
Type scalarType;
// The type kind. If it's torch_upstream::TypeKind::AnyType,
// all the type fields are nullptr. Note that the `kind` never equals to
// torch_upstream::TypeKind::OptionalType because optional knowledge is
// tracked separately through the `optional` field.
torch_upstream::TypeKind kind;
// What is known about an optional value. // What is known about an optional value.
// When equal to OptionalKnowledge::notNone, the type info is kept in type
// fields like `dtype`, `scalarType`.
// When equal to OptionalKnowledge::isNone or OptionalKnowledge::unKnown, the
// other type fields are currently nullptr. It might worth considering
// tracking wrapped type info when OptionalKnowledge::unKnown in the future.
OptionalKnowledge optional; OptionalKnowledge optional;
}; };
@ -500,11 +600,9 @@ ChangeResult TypeAnalyzer::visitOperation(
AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, AtenZero_Op, AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, AtenZero_Op,
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
ValsemVariantAtenCopyOp, ValsemVariantAtenZeroOp, ValsemVariantAtenCopyOp, ValsemVariantAtenZeroOp,
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp>(op)) { AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
ValueKnowledge knowledge = PrimAbsScalarOp>(op)) {
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
knowledge.dtype = operands[0]->getValue().dtype;
return incorporateKnowledge(op->getResult(0), knowledge);
} }
// Dtype is always float32, except for bfloat16, float64 and nullptr. // Dtype is always float32, except for bfloat16, float64 and nullptr.
@ -512,7 +610,7 @@ ChangeResult TypeAnalyzer::visitOperation(
AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenRsqrtOp, AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenRsqrtOp,
AtenErfOp>(op)) { AtenErfOp>(op)) {
ValueKnowledge knowledge = ValueKnowledge knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
Type dtype = operands[0]->getValue().dtype; Type dtype = operands[0]->getValue().dtype;
if (dtype) { if (dtype) {
knowledge.dtype = Float32Type::get(op->getContext()); knowledge.dtype = Float32Type::get(op->getContext());
@ -526,7 +624,7 @@ ChangeResult TypeAnalyzer::visitOperation(
if (isa<AtenNllLossBackwardOp, AtenMaxPool2dWithIndicesBackwardOp>(op)) { if (isa<AtenNllLossBackwardOp, AtenMaxPool2dWithIndicesBackwardOp>(op)) {
auto self = operands[1]->getValue(); auto self = operands[1]->getValue();
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = self.dtype; knowledge.dtype = self.dtype;
return incorporateKnowledge(op->getResult(0), knowledge); return incorporateKnowledge(op->getResult(0), knowledge);
} }
@ -536,7 +634,7 @@ ChangeResult TypeAnalyzer::visitOperation(
AtenLeScalarOp, AtenNeScalarOp, AtenAnyOp, AtenAllOp, AtenEqTensorOp, AtenLeScalarOp, AtenNeScalarOp, AtenAnyOp, AtenAllOp, AtenEqTensorOp,
AtenGtTensorOp, AtenLtTensorOp>(op)) { AtenGtTensorOp, AtenLtTensorOp>(op)) {
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = IntegerType::get(op->getContext(), 1); knowledge.dtype = IntegerType::get(op->getContext(), 1);
return incorporateKnowledge(op->getResult(0), knowledge); return incorporateKnowledge(op->getResult(0), knowledge);
} }
@ -544,7 +642,7 @@ ChangeResult TypeAnalyzer::visitOperation(
// Dtype is always si64. // Dtype is always si64.
if (isa<AtenBincountOp>(op)) { if (isa<AtenBincountOp>(op)) {
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = knowledge.dtype =
IntegerType::get(op->getContext(), 64, IntegerType::Signed); IntegerType::get(op->getContext(), 64, IntegerType::Signed);
return incorporateKnowledge(op->getResult(0), knowledge); return incorporateKnowledge(op->getResult(0), knowledge);
@ -554,7 +652,7 @@ ChangeResult TypeAnalyzer::visitOperation(
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp, if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
AtenConvolutionOverrideableOp>(op)) { AtenConvolutionOverrideableOp>(op)) {
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank( knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()}); op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()});
return incorporateKnowledge(op->getResult(0), knowledge); return incorporateKnowledge(op->getResult(0), knowledge);
@ -565,7 +663,7 @@ ChangeResult TypeAnalyzer::visitOperation(
Aten__And__TensorOp, AtenMinimumOp, AtenMaximumOp, Aten__And__TensorOp, AtenMinimumOp, AtenMaximumOp,
AtenBitwiseAndTensorOp, AtenThresholdBackwardOp>(op)) { AtenBitwiseAndTensorOp, AtenThresholdBackwardOp>(op)) {
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = getPromotedResultType( knowledge.dtype = getPromotedResultType(
op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()}, op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()},
getRankIsNonZeroArray(op->getOperands())); getRankIsNonZeroArray(op->getOperands()));
@ -575,7 +673,7 @@ ChangeResult TypeAnalyzer::visitOperation(
// Promote three dtypes. // Promote three dtypes.
if (isa<AtenAddmmOp, AtenLerpTensorOp, AtenAddcmulOp, AtenAddcdivOp>(op)) { if (isa<AtenAddmmOp, AtenLerpTensorOp, AtenAddcmulOp, AtenAddcdivOp>(op)) {
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank( knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue(), op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue(),
&operands[2]->getValue()}); &operands[2]->getValue()});
@ -593,7 +691,7 @@ ChangeResult TypeAnalyzer::visitOperation(
auto lhs = operands[0]->getValue(); auto lhs = operands[0]->getValue();
Value scalar = op->getOperand(1); Value scalar = op->getOperand(1);
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(getContext()); ValueKnowledge::getTensorPessimisticValueState(getContext());
knowledge.dtype = getPromotedResultType(&lhs, scalar.getType()); knowledge.dtype = getPromotedResultType(&lhs, scalar.getType());
return incorporateKnowledge(op->getResult(0), knowledge); return incorporateKnowledge(op->getResult(0), knowledge);
} }
@ -601,7 +699,7 @@ ChangeResult TypeAnalyzer::visitOperation(
// Promote 2nd and 3rd operands. // Promote 2nd and 3rd operands.
if (isa<AtenWhereSelfOp>(op)) { if (isa<AtenWhereSelfOp>(op)) {
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(getContext()); ValueKnowledge::getTensorPessimisticValueState(getContext());
knowledge.dtype = getPromotedResultType( knowledge.dtype = getPromotedResultType(
getContext(), {&operands[1]->getValue(), &operands[2]->getValue()}, getContext(), {&operands[1]->getValue(), &operands[2]->getValue()},
getRankIsNonZeroArray(op->getOperands().slice(1, 2))); getRankIsNonZeroArray(op->getOperands().slice(1, 2)));
@ -613,7 +711,7 @@ ChangeResult TypeAnalyzer::visitOperation(
Value lhsScalar = op->getOperand(1); Value lhsScalar = op->getOperand(1);
Value rhsScalar = op->getOperand(2); Value rhsScalar = op->getOperand(2);
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(getContext()); ValueKnowledge::getTensorPessimisticValueState(getContext());
knowledge.dtype = knowledge.dtype =
getPromotedResultType({lhsScalar.getType(), rhsScalar.getType()}); getPromotedResultType({lhsScalar.getType(), rhsScalar.getType()});
return incorporateKnowledge(op->getResult(0), knowledge); return incorporateKnowledge(op->getResult(0), knowledge);
@ -624,7 +722,7 @@ ChangeResult TypeAnalyzer::visitOperation(
auto lhs = operands[1]->getValue(); auto lhs = operands[1]->getValue();
Value scalar = op->getOperand(2); Value scalar = op->getOperand(2);
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(getContext()); ValueKnowledge::getTensorPessimisticValueState(getContext());
knowledge.dtype = getPromotedResultType(&lhs, scalar.getType()); knowledge.dtype = getPromotedResultType(&lhs, scalar.getType());
return incorporateKnowledge(op->getResult(0), knowledge); return incorporateKnowledge(op->getResult(0), knowledge);
} }
@ -634,7 +732,7 @@ ChangeResult TypeAnalyzer::visitOperation(
auto rhs = operands[2]->getValue(); auto rhs = operands[2]->getValue();
Value scalar = op->getOperand(1); Value scalar = op->getOperand(1);
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(getContext()); ValueKnowledge::getTensorPessimisticValueState(getContext());
knowledge.dtype = getPromotedResultType(&rhs, scalar.getType()); knowledge.dtype = getPromotedResultType(&rhs, scalar.getType());
return incorporateKnowledge(op->getResult(0), knowledge); return incorporateKnowledge(op->getResult(0), knowledge);
} }
@ -643,10 +741,10 @@ ChangeResult TypeAnalyzer::visitOperation(
if (isa<AtenNllLossForwardOp>(op)) { if (isa<AtenNllLossForwardOp>(op)) {
auto self = operands[0]->getValue(); auto self = operands[0]->getValue();
auto result0Knowledge = auto result0Knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
result0Knowledge.dtype = self.dtype; result0Knowledge.dtype = self.dtype;
auto result1Knowledge = auto result1Knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
result1Knowledge.dtype = self.dtype; result1Knowledge.dtype = self.dtype;
auto changed = incorporateKnowledge(op->getResult(0), result0Knowledge); auto changed = incorporateKnowledge(op->getResult(0), result0Knowledge);
changed |= incorporateKnowledge(op->getResult(1), result1Knowledge); changed |= incorporateKnowledge(op->getResult(1), result1Knowledge);
@ -657,13 +755,13 @@ ChangeResult TypeAnalyzer::visitOperation(
if (isa<AtenNativeLayerNormOp, AtenNativeBatchNormOp>(op)) { if (isa<AtenNativeLayerNormOp, AtenNativeBatchNormOp>(op)) {
auto self = operands[0]->getValue(); auto self = operands[0]->getValue();
auto result0Knowledge = auto result0Knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
result0Knowledge.dtype = self.dtype; result0Knowledge.dtype = self.dtype;
auto result1Knowledge = auto result1Knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
result1Knowledge.dtype = self.dtype; result1Knowledge.dtype = self.dtype;
auto result2Knowledge = auto result2Knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
result2Knowledge.dtype = self.dtype; result2Knowledge.dtype = self.dtype;
auto changed = incorporateKnowledge(op->getResult(0), result0Knowledge); auto changed = incorporateKnowledge(op->getResult(0), result0Knowledge);
changed |= incorporateKnowledge(op->getResult(1), result1Knowledge); changed |= incorporateKnowledge(op->getResult(1), result1Knowledge);
@ -674,10 +772,10 @@ ChangeResult TypeAnalyzer::visitOperation(
if (isa<AtenMaxPool2dWithIndicesOp>(op)) { if (isa<AtenMaxPool2dWithIndicesOp>(op)) {
auto self = operands[0]->getValue(); auto self = operands[0]->getValue();
auto result0Knowledge = auto result0Knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
result0Knowledge.dtype = self.dtype; result0Knowledge.dtype = self.dtype;
auto result1Knowledge = auto result1Knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
result1Knowledge.dtype = result1Knowledge.dtype =
IntegerType::get(op->getContext(), 64, IntegerType::Signed); IntegerType::get(op->getContext(), 64, IntegerType::Signed);
; ;
@ -700,7 +798,7 @@ ChangeResult TypeAnalyzer::visitOperation(
Type defaultDtype = operands[0]->getValue().dtype; Type defaultDtype = operands[0]->getValue().dtype;
Type dtype = getDtypeOrDefault(sum.getContext(), sum.dtype(), defaultDtype); Type dtype = getDtypeOrDefault(sum.getContext(), sum.dtype(), defaultDtype);
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = dtype; knowledge.dtype = dtype;
return incorporateKnowledge(op->getResult(0), knowledge); return incorporateKnowledge(op->getResult(0), knowledge);
} }
@ -816,7 +914,7 @@ ChangeResult TypeAnalyzer::visitOperation(
if (auto shapeAsTensor = dyn_cast<Aten_ShapeAsTensorOp>(op)) { if (auto shapeAsTensor = dyn_cast<Aten_ShapeAsTensorOp>(op)) {
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = knowledge.dtype =
IntegerType::get(op->getContext(), 64, IntegerType::Signed); IntegerType::get(op->getContext(), 64, IntegerType::Signed);
return incorporateKnowledge(shapeAsTensor.getResult(), knowledge); return incorporateKnowledge(shapeAsTensor.getResult(), knowledge);
@ -824,7 +922,7 @@ ChangeResult TypeAnalyzer::visitOperation(
if (auto embedding = dyn_cast<AtenEmbeddingOp>(op)) { if (auto embedding = dyn_cast<AtenEmbeddingOp>(op)) {
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = Float32Type::get(op->getContext()); knowledge.dtype = Float32Type::get(op->getContext());
return incorporateKnowledge(embedding.getResult(), knowledge); return incorporateKnowledge(embedding.getResult(), knowledge);
} }
@ -864,7 +962,7 @@ TypeAnalyzer::incorporateKnowledge(Value v, const ValueKnowledge &knowledge) {
ChangeResult TypeAnalyzer::visitAtenLinearOp( ChangeResult TypeAnalyzer::visitAtenLinearOp(
AtenLinearOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) { AtenLinearOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
auto input = operands[0]->getValue(); auto input = operands[0]->getValue();
auto weight = operands[1]->getValue(); auto weight = operands[1]->getValue();
auto bias = operands[2]->getValue(); auto bias = operands[2]->getValue();
@ -889,7 +987,7 @@ ChangeResult TypeAnalyzer::visitAtenArangeLikeOpHelper(
Operation *op, llvm::Optional<Value> start, Value end, Operation *op, llvm::Optional<Value> start, Value end,
llvm::Optional<Value> step, Value dtype) { llvm::Optional<Value> step, Value dtype) {
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
int64_t dtypeInt; int64_t dtypeInt;
if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) { if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) {
knowledge.dtype = getTypeForDTypeInteger(op->getContext(), dtypeInt); knowledge.dtype = getTypeForDTypeInteger(op->getContext(), dtypeInt);
@ -930,7 +1028,8 @@ ChangeResult TypeAnalyzer::visitAtenArangeOp(AtenArangeOp op) {
ChangeResult TypeAnalyzer::visitReductionAlongAllDimsOp( ChangeResult TypeAnalyzer::visitReductionAlongAllDimsOp(
Operation *op, Type dtype, Operation *op, Type dtype,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) { ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext()); auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = dtype; knowledge.dtype = dtype;
return incorporateKnowledge(op->getResult(0), knowledge); return incorporateKnowledge(op->getResult(0), knowledge);
} }
@ -941,7 +1040,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
Operation *op, Value dim, Value keepdim, Type dtype, Operation *op, Value dim, Value keepdim, Type dtype,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) { ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = dtype; knowledge.dtype = dtype;
return incorporateKnowledge(op->getResult(0), knowledge); return incorporateKnowledge(op->getResult(0), knowledge);
} }
@ -951,7 +1050,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntOp(
ArrayRef<LatticeElement<ValueKnowledge> *> operands, int resNum) { ArrayRef<LatticeElement<ValueKnowledge> *> operands, int resNum) {
assert(dim.getType().isa<Torch::IntType>() && "dim must be int type"); assert(dim.getType().isa<Torch::IntType>() && "dim must be int type");
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = dtype; knowledge.dtype = dtype;
return incorporateKnowledge(op->getResult(resNum), knowledge); return incorporateKnowledge(op->getResult(resNum), knowledge);
} }
@ -959,7 +1058,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntOp(
template <typename OpTy> template <typename OpTy>
ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) { ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op.getContext()); ValueKnowledge::getTensorPessimisticValueState(op.getContext());
Value t = op.t(); Value t = op.t();
Value dtype = op.dtype(); Value dtype = op.dtype();
fillInDTypeGivenDTypeAndDataType(knowledge, dtype, t.getType()); fillInDTypeGivenDTypeAndDataType(knowledge, dtype, t.getType());
@ -969,15 +1068,17 @@ ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
ChangeResult TypeAnalyzer::visitBinaryScalarOp( ChangeResult TypeAnalyzer::visitBinaryScalarOp(
Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) { Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getScalarPessimisticValueState(op->getContext());
knowledge.dtype = getPromotedResultType( Type resultType = getPromotedResultType(
{op->getOperand(0).getType(), op->getOperand(1).getType()}); {op->getOperand(0).getType(), op->getOperand(1).getType()});
knowledge.scalarType = resultType;
knowledge.kind = getTypeKind(resultType);
return incorporateKnowledge(op->getResult(0), knowledge); return incorporateKnowledge(op->getResult(0), knowledge);
} }
ChangeResult TypeAnalyzer::visitAtenTensorOp(AtenTensorOp op) { ChangeResult TypeAnalyzer::visitAtenTensorOp(AtenTensorOp op) {
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op.getContext()); ValueKnowledge::getTensorPessimisticValueState(op.getContext());
Value data = op.data(); Value data = op.data();
Value dtype = op.dtype(); Value dtype = op.dtype();
Type type = data.getType(); Type type = data.getType();
@ -993,7 +1094,7 @@ ChangeResult
TypeAnalyzer::visitConstantTensorAllocOp(OpTy op, TypeAnalyzer::visitConstantTensorAllocOp(OpTy op,
llvm::Optional<Type> dataType) { llvm::Optional<Type> dataType) {
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
if (!dataType) if (!dataType)
dataType = Torch::FloatType::get(op->getContext()); dataType = Torch::FloatType::get(op->getContext());
fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), dataType.getValue()); fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), dataType.getValue());
@ -1005,7 +1106,7 @@ ChangeResult TypeAnalyzer::visitConstantTensorAllocLikeOp(
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) { OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue(); auto input = operands[0]->getValue();
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
fillInDTypeGivenDTypeIntAndInputDType(knowledge, op.dtype(), input.dtype); fillInDTypeGivenDTypeIntAndInputDType(knowledge, op.dtype(), input.dtype);
return incorporateKnowledge(op.getResult(), knowledge); return incorporateKnowledge(op.getResult(), knowledge);
} }
@ -1015,7 +1116,7 @@ ChangeResult TypeAnalyzer::visitConstantTensorNewLikeOp(
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) { OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue(); auto input = operands[0]->getValue();
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
fillInDTypeGivenDTypeIntAndInputDType(knowledge, op.dtype(), input.dtype); fillInDTypeGivenDTypeIntAndInputDType(knowledge, op.dtype(), input.dtype);
return incorporateKnowledge(op.getResult(), knowledge); return incorporateKnowledge(op.getResult(), knowledge);
} }
@ -1025,7 +1126,7 @@ template <typename OpTy>
ChangeResult TypeAnalyzer::visitAtenToDtypeLikeOp( ChangeResult TypeAnalyzer::visitAtenToDtypeLikeOp(
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) { OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
Value dtype = op.dtype(); Value dtype = op.dtype();
int64_t dtypeInt; int64_t dtypeInt;
if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
@ -1038,7 +1139,7 @@ template <typename OpTy>
ChangeResult TypeAnalyzer::visitTypeConversionOp( ChangeResult TypeAnalyzer::visitTypeConversionOp(
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) { OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
Value other = op.other(); Value other = op.other();
BaseTensorType type = other.getType().cast<BaseTensorType>(); BaseTensorType type = other.getType().cast<BaseTensorType>();
if (type.hasDtype()) if (type.hasDtype())
@ -1053,7 +1154,7 @@ ChangeResult TypeAnalyzer::visitAtenCatOp(
AtenCatOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) { AtenCatOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto tensorList = op.tensors(); auto tensorList = op.tensors();
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
auto listConstruct = tensorList.getDefiningOp<PrimListConstructOp>(); auto listConstruct = tensorList.getDefiningOp<PrimListConstructOp>();
if (!listConstruct) if (!listConstruct)
return incorporateKnowledge(op.getResult(), knowledge); return incorporateKnowledge(op.getResult(), knowledge);
@ -1073,7 +1174,7 @@ ChangeResult TypeAnalyzer::visitAtenCatOp(
ChangeResult TypeAnalyzer::visitNumToTensorOp(PrimNumToTensorScalarOp op) { ChangeResult TypeAnalyzer::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
// The resulting type from converting a Scalar into a Tensor is different // The resulting type from converting a Scalar into a Tensor is different
// if the scalar is part of a tensor operation (such as AtenMulScalar) or // if the scalar is part of a tensor operation (such as AtenMulScalar) or
// not. In the former case, the type promotion rules are captured by the // not. In the former case, the type promotion rules are captured by the
@ -1098,7 +1199,7 @@ ChangeResult TypeAnalyzer::visitAtenSoftmaxLikeOp(
auto input = operands[0]->getValue(); auto input = operands[0]->getValue();
auto dtype = op.dtype(); auto dtype = op.dtype();
ValueKnowledge knowledge = ValueKnowledge knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, input.dtype); fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, input.dtype);
return incorporateKnowledge(op.getResult(), knowledge); return incorporateKnowledge(op.getResult(), knowledge);
} }
@ -1109,7 +1210,7 @@ ChangeResult TypeAnalyzer::visitAten_SoftmaxLikeOp(
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) { OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue(); auto input = operands[0]->getValue();
ValueKnowledge knowledge = ValueKnowledge knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
bool halfToFloat; bool halfToFloat;
if (matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat))) { if (matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat))) {
knowledge.dtype = knowledge.dtype =
@ -1154,6 +1255,16 @@ static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
else else
return containedType; return containedType;
} }
} else if (auto scalarType = v.getType().dyn_cast<NumberType>()) {
LatticeElement<ValueKnowledge> *latticeElement =
analyzer.lookupLatticeElement(v);
if (!latticeElement)
return nullptr;
const ValueKnowledge &knowledge = latticeElement->getValue();
if (knowledge.kind == torch_upstream::TypeKind::IntType)
return Torch::IntType::get(v.getContext());
if (knowledge.kind == torch_upstream::TypeKind::FloatType)
return Torch::FloatType::get(v.getContext());
} }
return nullptr; return nullptr;
} }
@ -1217,7 +1328,7 @@ void optimize(func::FuncOp func, TypeAnalyzer &analyzer) {
return b.create<TensorStaticInfoCastOp>(loc, newType, v); return b.create<TensorStaticInfoCastOp>(loc, newType, v);
}; };
createStaticInfoUpCast = createStaticInfoDownCast; createStaticInfoUpCast = createStaticInfoDownCast;
} else if (originalType.isa<OptionalType>()) { } else if (originalType.isa<OptionalType, NumberType>()) {
createStaticInfoDownCast = [&](Location loc, Type newType, createStaticInfoDownCast = [&](Location loc, Type newType,
Value v) -> Value { Value v) -> Value {
return b.create<PrimUncheckedCastOp>(loc, newType, v); return b.create<PrimUncheckedCastOp>(loc, newType, v);

View File

@ -543,6 +543,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
traits=["DeclareOpInterfaceMethods<CastOpInterface>"]) traits=["DeclareOpInterfaceMethods<CastOpInterface>"])
emit("prim::Print : (...) -> ()") emit("prim::Print : (...) -> ()")
emit("prim::tolist : (...) -> (...)") emit("prim::tolist : (...) -> (...)")
emit("prim::abs.Scalar : (Scalar) -> (Scalar)")
# ========================================================================== # ==========================================================================
# `quantized::` namespace. # `quantized::` namespace.

View File

@ -6,6 +6,7 @@
// Code for testing transfer functions for new ops (which is most changes) // Code for testing transfer functions for new ops (which is most changes)
// should go in refine-types-ops.mlir. // should go in refine-types-ops.mlir.
// -----
// CHECK-LABEL: func @basic( // CHECK-LABEL: func @basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> // CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
@ -16,6 +17,7 @@ func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
return %1 : !torch.vtensor return %1 : !torch.vtensor
} }
// -----
// CHECK-LABEL: func @keep_existing_shape_information( // CHECK-LABEL: func @keep_existing_shape_information(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32> // CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32>
@ -25,6 +27,7 @@ func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vt
return %1 : !torch.vtensor<[2],f32> return %1 : !torch.vtensor<[2],f32>
} }
// -----
// CHECK-LABEL: func @propagate_through_multiple_ops( // CHECK-LABEL: func @propagate_through_multiple_ops(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
// CHECK: %[[TANH0:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> // CHECK: %[[TANH0:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
@ -39,6 +42,7 @@ func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vte
return %3 : !torch.vtensor return %3 : !torch.vtensor
} }
// -----
// Check rewriting logic in case of mixes of users that do/don't allow type // Check rewriting logic in case of mixes of users that do/don't allow type
// refinement. // refinement.
// CHECK-LABEL: func @mixed_allowing_not_allowing_type_refinement( // CHECK-LABEL: func @mixed_allowing_not_allowing_type_refinement(
@ -53,6 +57,7 @@ func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>)
return %1, %1 : !torch.vtensor, !torch.vtensor return %1, %1 : !torch.vtensor, !torch.vtensor
} }
// -----
// CHECK-LABEL: func @type_promotion$same_category_different_width( // CHECK-LABEL: func @type_promotion$same_category_different_width(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si32>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?],unk> { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?],unk> {
@ -66,6 +71,7 @@ func @type_promotion$same_category_different_width(%arg0: !torch.vtensor<[?],si3
return %0 : !torch.vtensor<[?],unk> return %0 : !torch.vtensor<[?],unk>
} }
// -----
// CHECK-LABEL: func @type_promotion$different_category( // CHECK-LABEL: func @type_promotion$different_category(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si64>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si64>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],unk> { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],unk> {
@ -79,6 +85,7 @@ func @type_promotion$different_category(%arg0: !torch.vtensor<[?],si64>, %arg1:
return %0 : !torch.vtensor<[?],unk> return %0 : !torch.vtensor<[?],unk>
} }
// -----
// CHECK-LABEL: func @type_promotion$same_category_zero_rank_wider( // CHECK-LABEL: func @type_promotion$same_category_zero_rank_wider(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f64>) -> !torch.vtensor<[?],unk> { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f64>) -> !torch.vtensor<[?],unk> {
@ -92,6 +99,7 @@ func @type_promotion$same_category_zero_rank_wider(%arg0: !torch.vtensor<[?],f32
return %0 : !torch.vtensor<[?],unk> return %0 : !torch.vtensor<[?],unk>
} }
// -----
// CHECK-LABEL: func @type_promotion$zero_rank_higher_category( // CHECK-LABEL: func @type_promotion$zero_rank_higher_category(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si64>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si64>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
@ -105,6 +113,7 @@ func @type_promotion$zero_rank_higher_category(%arg0: !torch.vtensor<[?],si64>,
return %0 : !torch.vtensor<[?],unk> return %0 : !torch.vtensor<[?],unk>
} }
// -----
// CHECK-LABEL: func @type_promotion$alpha_wider( // CHECK-LABEL: func @type_promotion$alpha_wider(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
@ -118,6 +127,7 @@ func @type_promotion$alpha_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.v
return %0 : !torch.vtensor<[?],unk> return %0 : !torch.vtensor<[?],unk>
} }
// -----
// CHECK-LABEL: func @torch.overwrite.tensor.contents$dynamic_overwrites_static( // CHECK-LABEL: func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>, // CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> { // CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
@ -134,6 +144,7 @@ func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.
return %result : !torch.vtensor<[2],f32> return %result : !torch.vtensor<[2],f32>
} }
// -----
// CHECK-LABEL: func @torch.overwrite.tensor.contents$static_overwrites_dynamic( // CHECK-LABEL: func @torch.overwrite.tensor.contents$static_overwrites_dynamic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2],f32>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
@ -151,6 +162,7 @@ func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.
return %result : !torch.vtensor<[?],f32> return %result : !torch.vtensor<[?],f32>
} }
// -----
// CHECK-LABEL: func @bf16_result_type( // CHECK-LABEL: func @bf16_result_type(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> {
// CHECK: %[[SQRT:.*]] = torch.aten.sqrt %[[ARG0]] : !torch.vtensor<*,bf16> -> !torch.vtensor<[2],bf16> // CHECK: %[[SQRT:.*]] = torch.aten.sqrt %[[ARG0]] : !torch.vtensor<*,bf16> -> !torch.vtensor<[2],bf16>
@ -159,3 +171,16 @@ func @bf16_result_type(%arg0: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16
%1 = torch.aten.sqrt %arg0 : !torch.vtensor<*,bf16> -> !torch.vtensor<[2], bf16> %1 = torch.aten.sqrt %arg0 : !torch.vtensor<*,bf16> -> !torch.vtensor<[2], bf16>
return %1 : !torch.vtensor<[2],bf16> return %1 : !torch.vtensor<[2],bf16>
} }
// -----
// CHECK-LABEL: func @propagate_scalar_type(
// CHECK-SAME: %[[INT:.*]]: !torch.int) -> !torch.number {
// CHECK: %[[NUM:.*]] = torch.derefine %[[INT]] : !torch.int to !torch.number
// CHECK: %[[ABS:.*]] = torch.prim.abs.Scalar %[[INT]] : !torch.int -> !torch.int
// CHECK: %[[RET:.*]] = torch.derefine %[[ABS]] : !torch.int to !torch.number
// CHECK: return %[[RET]] : !torch.number
func @propagate_scalar_type(%arg0: !torch.int) -> !torch.number {
%num = torch.derefine %arg0 : !torch.int to !torch.number
%1 = torch.prim.abs.Scalar %num: !torch.number -> !torch.number
return %1 : !torch.number
}