From 9f7264a7a4e346fd70b5d2e896b3d9f0102c8e16 Mon Sep 17 00:00:00 2001
From: Yi Zhang
Date: Tue, 3 May 2022 11:59:49 -0400
Subject: [PATCH] Add support for scalar type propagation

The main changes are:
- Added `ValueKnowledge.scalarType` to track scalar type information.
- Added `ValueKnowledge.kind` to indicate the value kind.
- Modified the meet and join helper functions. `ValueKnowledge` now carries
  slightly more complicated state, so the meet and join functions need to
  look at the `kind` field in addition to just the type fields.
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  23 ++
 .../Dialect/Torch/Utils/TorchUpstream.h       |  49 ++++
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  | 257 +++++++++++++-----
 .../jit_ir/build_tools/torch_ods_gen.py       |   1 +
 test/Dialect/Torch/refine-types.mlir          |  25 ++
 5 files changed, 282 insertions(+), 73 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index fc0d2f639..212487a57 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7398,6 +7398,29 @@ def Torch_PrimTolistOp : Torch_Op<"prim.tolist", [
   let assemblyFormat = "`(` $operands `)` attr-dict `:` qualified(type($operands)) `->` qualified(type($results))";
 }
 
+def Torch_PrimAbsScalarOp : Torch_Op<"prim.abs.Scalar", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `prim::abs.Scalar : (Scalar) -> (Scalar)`";
+  let arguments = (ins
+    AnyTorchScalarType:$a
+  );
+  let results = (outs
+    AnyTorchScalarType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult PrimAbsScalarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void PrimAbsScalarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
     HasValueSemantics,
     AllowsTypeRefinement,
diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
index 124ebdf92..0d2d75b78 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
@@ -29,6 +29,55 @@ namespace mlir {
 namespace torch {
 namespace torch_upstream {
 
+//===----------------------------------------------------------------------===//
+// TypeKind enum related code is copied from
+// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/jit_type_base.h
+//===----------------------------------------------------------------------===//
+#define C10_FORALL_TYPES(_) \
+  _(AnyType)                \
+  _(EnumType)               \
+  _(AnyEnumType)            \
+  _(TensorType)             \
+  _(StorageType)            \
+  _(TupleType)              \
+  _(ListType)               \
+  _(DictType)               \
+  _(NumberType)             \
+  _(FloatType)              \
+  _(ComplexType)            \
+  _(FutureType)             \
+  _(RRefType)               \
+  _(IntType)                \
+  _(NoneType)               \
+  _(StringType)             \
+  _(GeneratorType)          \
+  _(QuantizerType)          \
+  _(BoolType)               \
+  _(OptionalType)           \
+  _(VarType)                \
+  _(DeviceObjType)          \
+  _(StreamObjType)          \
+  _(FunctionType)           \
+  _(ClassType)              \
+  _(PyObjectType)           \
+  _(CapsuleType)            \
+  _(InterfaceType)          \
+  _(QSchemeType)            \
+  _(LayoutType)             \
+  _(ScalarTypeType)         \
+  _(AnyListType)            \
+  _(AnyTupleType)           \
+  _(AnyClassType)           \
+  _(SymIntType)             \
+  _(UnionType)              \
+  _(DynamicType)
+
+enum class TypeKind {
+#define DEFINE_TYPE(T) T,
+  C10_FORALL_TYPES(DEFINE_TYPE)
+#undef DEFINE_TYPE
+};
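+
+// Illustrative note: C10_FORALL_TYPES is an X-macro, so the enum above
+// expands to one enumerator per entry, beginning:
+//   enum class TypeKind { AnyType, EnumType, AnyEnumType, TensorType, ... };
+// Only a few of these kinds (TensorType, NumberType, IntType, FloatType,
+// NoneType) are distinguished by the type analysis below; everything else
+// is treated as AnyType.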
 //===----------------------------------------------------------------------===//
 // ScalarType enum related code is copied from c10/core/ScalarType.h
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index f11f25817..0fea565b1 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -92,14 +92,27 @@ static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype,
   return Type();
 }
 
-static Type joinElementTypes(Type lhs, Type rhs) {
-  if (lhs == rhs)
-    return lhs;
-  return Type();
+// Get the kind enum for `ValueKnowledge.kind`.
+static torch_upstream::TypeKind getTypeKind(Type type) {
+  if (type.isa<NumberType>())
+    return torch_upstream::TypeKind::NumberType;
+  if (type.isa<IntType>())
+    return torch_upstream::TypeKind::IntType;
+  if (type.isa<Torch::FloatType>())
+    return torch_upstream::TypeKind::FloatType;
+  if (type.isa<BaseTensorType>())
+    return torch_upstream::TypeKind::TensorType;
+  if (type.isa<Torch::NoneType>())
+    return torch_upstream::TypeKind::NoneType;
+  // Skip the Torch::OptionalType on purpose because optional knowledge is
+  // tracked separately. See comments for the `ValueKnowledge.kind` field.
+  return torch_upstream::TypeKind::AnyType;
 }
 
 /// Returns the dtype that assumes information from both `lhs` and `rhs`.
-/// Returns `None` if the types are contradictory.
+/// Returns `None` if the types are contradictory. Note this can only be used
+/// on the `dtype` from tensors and can't be used on other types like scalar
+/// types.
 static Optional<Type> meetElementTypes(Type lhs, Type rhs) {
   if (!lhs)
     return rhs;
@@ -148,27 +161,44 @@ namespace {
 // This class could also be called "dataflow facts", "lattice value", etc.
 struct ValueKnowledge {
   ValueKnowledge() = delete;
-  ValueKnowledge(Type dtype, OptionalKnowledge optionalKnowledge)
-      : dtype(dtype), optional(optionalKnowledge) {}
+  ValueKnowledge(Type dtype, Type scalarType,
+                 OptionalKnowledge optionalKnowledge,
+                 torch_upstream::TypeKind kind)
+      : dtype(dtype), scalarType(scalarType), kind(kind),
+        optional(optionalKnowledge) {}
 
   // Get the static knowledge intrinsic to `type`.
   static ValueKnowledge getKnowledgeFromType(Type type) {
     ValueKnowledge result = getPessimisticValueState(type.getContext());
-    if (auto tensorType = type.dyn_cast<BaseTensorType>()) {
-      result.dtype = tensorType.getOptionalDtype();
+    result.kind = getTypeKind(type);
+    switch (result.kind) {
+    case torch_upstream::TypeKind::TensorType:
+      result.dtype = type.cast<BaseTensorType>().getOptionalDtype();
       result.optional = OptionalKnowledge::notNone;
-    } else if (auto optionalType = type.dyn_cast<Torch::NoneType>()) {
+      return result;
+    case torch_upstream::TypeKind::NumberType:
+    case torch_upstream::TypeKind::IntType:
+    case torch_upstream::TypeKind::FloatType:
+      result.scalarType = type;
+      result.optional = OptionalKnowledge::notNone;
+      return result;
+    case torch_upstream::TypeKind::NoneType:
       result.optional = OptionalKnowledge::isNone;
-    } else if (!type.isa<OptionalType>()) {
+      return result;
+    default:
+      if (type.isa<OptionalType>())
+        return result;
+      // All other types that are not optional type.
       result.optional = OptionalKnowledge::notNone;
+      return result;
     }
-    return result;
   }
 
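+  // For example: `!torch.vtensor<[2],f32>` maps to {kind = TensorType,
+  // dtype = f32, optional = notNone}; `!torch.int` maps to {kind = IntType,
+  // scalarType = !torch.int, optional = notNone}; for a `!torch.optional`
+  // value the kind stays AnyType and `optional` stays unKnown.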
   // Return a pessimistic/conservative value state without assuming any
   // knowledge about the IR.
   static ValueKnowledge getPessimisticValueState(MLIRContext *context) {
-    return ValueKnowledge(Type(), OptionalKnowledge::unKnown);
+    return ValueKnowledge(Type(), Type(), OptionalKnowledge::unKnown,
+                          torch_upstream::TypeKind::AnyType);
   }
   // Return a pessimistic/conservative value state only using knowledge
   // already recorded in the IR.
@@ -177,7 +207,19 @@ struct ValueKnowledge {
   }
 
   static ValueKnowledge getNotNonePessimisticValueState(MLIRContext *context) {
-    return ValueKnowledge(Type(), OptionalKnowledge::notNone);
+    return ValueKnowledge(Type(), Type(), OptionalKnowledge::notNone,
+                          torch_upstream::TypeKind::AnyType);
+  }
+
+  static ValueKnowledge getTensorPessimisticValueState(MLIRContext *context) {
+    return ValueKnowledge(Type(), Type(), OptionalKnowledge::notNone,
+                          torch_upstream::TypeKind::TensorType);
+  }
+
+  static ValueKnowledge getScalarPessimisticValueState(MLIRContext *context) {
+    return ValueKnowledge(Type(), NumberType::get(context),
+                          OptionalKnowledge::notNone,
+                          torch_upstream::TypeKind::NumberType);
   }
 
   bool operator==(const ValueKnowledge &rhs) const {
@@ -185,6 +227,25 @@ struct ValueKnowledge {
            std::make_tuple(rhs.dtype, rhs.optional);
   }
 
+  // Return true if the `refinedType` has more concrete type info than `type`.
+  static bool hasStrictlyMoreRefinedTypeInfo(const ValueKnowledge &refinedType,
+                                             const ValueKnowledge &type) {
+    if (type.kind == torch_upstream::TypeKind::AnyType &&
+        refinedType.kind != torch_upstream::TypeKind::AnyType)
+      return true;
+
+    // If both are tensors but `type` doesn't have concrete dtype info.
+    if (refinedType.kind == torch_upstream::TypeKind::TensorType &&
+        type.kind == torch_upstream::TypeKind::TensorType) {
+      return refinedType.dtype && !type.dtype;
+    }
+
+    if (refinedType.scalarType && type.scalarType)
+      return isValidSubtype(refinedType.scalarType, type.scalarType);
+
+    return false;
+  }
+
   // Given two pieces of static knowledge, intersect the facts that are known in
   // both knowledges. This always produces knowledge that has less (or equal)
   // facts than both the lhs and rhs.
    // Mental model: All conditions are checking how to change from the safe "no
    // knowledge" default-initialized state to a state with more knowledge
    // consistent with lhs and rhs.
-    ValueKnowledge result = getPessimisticValueState(nullptr);
-
+    ValueKnowledge result = joinTypes(lhs, rhs);
     result.optional = joinOptionalKnowledge(lhs.optional, rhs.optional);
-    result.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
-
     return result;
   }
 
+  static ValueKnowledge joinTypes(const ValueKnowledge &lhs,
+                                  const ValueKnowledge &rhs) {
+    if (hasStrictlyMoreRefinedTypeInfo(lhs, rhs))
+      return rhs;
+    if (hasStrictlyMoreRefinedTypeInfo(rhs, lhs))
+      return lhs;
+    if (lhs == rhs)
+      return lhs;
+    return getPessimisticValueState(nullptr);
+  }
+
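+  // For example: joining a state known to be `!torch.int` with one only
+  // known to be `!torch.number` keeps the `!torch.number` state, since
+  // `!torch.int` is a valid subtype of `!torch.number`; joining two
+  // unrelated states falls back to the pessimistic AnyType state.
+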
  // Given two pieces of static knowledge, calculate new knowledge that assumes
  // the facts from both.
  // If the two pieces of knowledge are contradictory, None is returned.
   static Optional<ValueKnowledge> meet(const ValueKnowledge &lhs,
                                        const ValueKnowledge &rhs) {
-    ValueKnowledge result = getPessimisticValueState(nullptr);
+    Optional<ValueKnowledge> knowledge = meetTypes(lhs, rhs);
+
+    if (!knowledge.hasValue())
+      return None;
+    ValueKnowledge result = knowledge.getValue();
 
     Optional<OptionalKnowledge> optional =
         meetOptionalKnowledge(lhs.optional, rhs.optional);
     if (!optional.hasValue())
       return None;
     result.optional = optional.getValue();
-
-    Optional<Type> dtype = meetElementTypes(lhs.dtype, rhs.dtype);
-    if (!dtype.hasValue())
-      return None;
-    result.dtype = dtype.getValue();
-
     return result;
   }
 
+  static Optional<ValueKnowledge> meetTypes(const ValueKnowledge &lhs,
+                                            const ValueKnowledge &rhs) {
+    if (hasStrictlyMoreRefinedTypeInfo(lhs, rhs))
+      return lhs;
+    if (hasStrictlyMoreRefinedTypeInfo(rhs, lhs))
+      return rhs;
+    if (lhs == rhs)
+      return lhs;
+    return None;
+  }
+
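+  // For example: meeting an AnyType state with an IntType state keeps the
+  // IntType state, while meeting `!torch.int` with `!torch.float` has no
+  // common refinement, so `meetTypes` returns None and `meet` reports a
+  // contradiction.
+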
   // The dtype of a tensor.
-  // This is equal to nullptr if we don't know that it is a specific concrete
-  // type.
+  // This is equal to nullptr for the following cases:
+  // 1. it is unknown whether the value is a tensor or not, i.e. the `kind`
+  //    field is torch_upstream::TypeKind::AnyType.
+  // 2. the value is a tensor type but the dtype is unknown.
+  // 3. the value is not a tensor type.
   Type dtype;
+
+  // The type of a scalar.
+  // This is equal to nullptr for the following cases:
+  // 1. it is unknown whether the value is a scalar or not, i.e. the `kind`
+  //    field is torch_upstream::TypeKind::AnyType.
+  // 2. the value is not a scalar type.
+  Type scalarType;
+
+  // The type kind. If it's torch_upstream::TypeKind::AnyType, all the type
+  // fields are nullptr. Note that `kind` never equals
+  // torch_upstream::TypeKind::OptionalType because optional knowledge is
+  // tracked separately through the `optional` field.
+  torch_upstream::TypeKind kind;
+
   // What is known about an optional value.
+  // When equal to OptionalKnowledge::notNone, the type info is kept in type
+  // fields like `dtype` and `scalarType`.
+  // When equal to OptionalKnowledge::isNone or OptionalKnowledge::unKnown, the
+  // other type fields are currently nullptr. It might be worth considering
+  // tracking the wrapped type info for OptionalKnowledge::unKnown in the
+  // future.
   OptionalKnowledge optional;
 };
 
@@ -500,11 +600,9 @@ ChangeResult TypeAnalyzer::visitOperation(
         AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, AtenZero_Op,
         AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
         ValsemVariantAtenCopyOp, ValsemVariantAtenZeroOp,
-        AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp>(op)) {
-    ValueKnowledge knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
-    knowledge.dtype = operands[0]->getValue().dtype;
-    return incorporateKnowledge(op->getResult(0), knowledge);
+        AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
+        PrimAbsScalarOp>(op)) {
+    return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
 
   // Dtype is always float32, except for bfloat16, float64 and nullptr.
@@ -512,7 +610,7 @@ ChangeResult TypeAnalyzer::visitOperation(
           AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenRsqrtOp,
           AtenErfOp>(op)) {
     ValueKnowledge knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     Type dtype = operands[0]->getValue().dtype;
     if (dtype) {
       knowledge.dtype = Float32Type::get(op->getContext());
@@ -526,7 +624,7 @@ ChangeResult TypeAnalyzer::visitOperation(
   if (isa(op)) {
     auto self = operands[1]->getValue();
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = self.dtype;
     return incorporateKnowledge(op->getResult(0), knowledge);
   }
@@ -536,7 +634,7 @@ ChangeResult TypeAnalyzer::visitOperation(
           AtenLeScalarOp, AtenNeScalarOp, AtenAnyOp, AtenAllOp, AtenEqTensorOp,
           AtenGtTensorOp, AtenLtTensorOp>(op)) {
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = IntegerType::get(op->getContext(), 1);
     return incorporateKnowledge(op->getResult(0), knowledge);
   }
@@ -544,7 +642,7 @@ ChangeResult TypeAnalyzer::visitOperation(
   // Dtype is always si64.
   if (isa(op)) {
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype =
         IntegerType::get(op->getContext(), 64, IntegerType::Signed);
     return incorporateKnowledge(op->getResult(0), knowledge);
@@ -554,7 +652,7 @@ ChangeResult TypeAnalyzer::visitOperation(
   if (isa(op)) {
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
         op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()});
     return incorporateKnowledge(op->getResult(0), knowledge);
@@ -565,7 +663,7 @@ ChangeResult TypeAnalyzer::visitOperation(
           Aten__And__TensorOp, AtenMinimumOp, AtenMaximumOp,
           AtenBitwiseAndTensorOp, AtenThresholdBackwardOp>(op)) {
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = getPromotedResultType(
         op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()},
         getRankIsNonZeroArray(op->getOperands()));
@@ -575,7 +673,7 @@ ChangeResult TypeAnalyzer::visitOperation(
   // Promote three dtypes.
   if (isa(op)) {
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
         op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue(),
                            &operands[2]->getValue()});
@@ -593,7 +691,7 @@ ChangeResult TypeAnalyzer::visitOperation(
     auto lhs = operands[0]->getValue();
     Value scalar = op->getOperand(1);
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(getContext());
+        ValueKnowledge::getTensorPessimisticValueState(getContext());
     knowledge.dtype = getPromotedResultType(&lhs, scalar.getType());
     return incorporateKnowledge(op->getResult(0), knowledge);
   }
@@ -601,7 +699,7 @@ ChangeResult TypeAnalyzer::visitOperation(
 
   // Promote 2nd and 3rd operands.
   if (isa(op)) {
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(getContext());
+        ValueKnowledge::getTensorPessimisticValueState(getContext());
     knowledge.dtype = getPromotedResultType(
         getContext(), {&operands[1]->getValue(), &operands[2]->getValue()},
         getRankIsNonZeroArray(op->getOperands().slice(1, 2)));
@@ -613,7 +711,7 @@ ChangeResult TypeAnalyzer::visitOperation(
     Value lhsScalar = op->getOperand(1);
     Value rhsScalar = op->getOperand(2);
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(getContext());
+        ValueKnowledge::getTensorPessimisticValueState(getContext());
     knowledge.dtype =
         getPromotedResultType({lhsScalar.getType(), rhsScalar.getType()});
     return incorporateKnowledge(op->getResult(0), knowledge);
@@ -624,7 +722,7 @@ ChangeResult TypeAnalyzer::visitOperation(
     auto lhs = operands[1]->getValue();
     Value scalar = op->getOperand(2);
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(getContext());
+        ValueKnowledge::getTensorPessimisticValueState(getContext());
     knowledge.dtype = getPromotedResultType(&lhs, scalar.getType());
     return incorporateKnowledge(op->getResult(0), knowledge);
   }
@@ -634,7 +732,7 @@ ChangeResult TypeAnalyzer::visitOperation(
     auto rhs = operands[2]->getValue();
     Value scalar = op->getOperand(1);
    auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(getContext());
+        ValueKnowledge::getTensorPessimisticValueState(getContext());
     knowledge.dtype = getPromotedResultType(&rhs, scalar.getType());
     return incorporateKnowledge(op->getResult(0), knowledge);
   }
@@ -643,10 +741,10 @@ ChangeResult TypeAnalyzer::visitOperation(
   if (isa(op)) {
     auto self = operands[0]->getValue();
     auto result0Knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     result0Knowledge.dtype = self.dtype;
     auto result1Knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     result1Knowledge.dtype = self.dtype;
     auto changed = incorporateKnowledge(op->getResult(0), result0Knowledge);
     changed |= incorporateKnowledge(op->getResult(1), result1Knowledge);
@@ -657,13 +755,13 @@ ChangeResult TypeAnalyzer::visitOperation(
   if (isa(op)) {
     auto self = operands[0]->getValue();
     auto result0Knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     result0Knowledge.dtype = self.dtype;
     auto result1Knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     result1Knowledge.dtype = self.dtype;
     auto result2Knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     result2Knowledge.dtype = self.dtype;
     auto changed = incorporateKnowledge(op->getResult(0), result0Knowledge);
     changed |= incorporateKnowledge(op->getResult(1), result1Knowledge);
@@ -674,10 +772,10 @@ ChangeResult TypeAnalyzer::visitOperation(
   if (isa(op)) {
     auto self = operands[0]->getValue();
     auto result0Knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     result0Knowledge.dtype = self.dtype;
     auto result1Knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     result1Knowledge.dtype =
         IntegerType::get(op->getContext(), 64, IntegerType::Signed);
     ;
@@ -700,7 +798,7 @@ ChangeResult TypeAnalyzer::visitOperation(
     Type defaultDtype = operands[0]->getValue().dtype;
     Type dtype = getDtypeOrDefault(sum.getContext(), sum.dtype(), defaultDtype);
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = dtype;
     return incorporateKnowledge(op->getResult(0), knowledge);
   }
@@ -816,7 +914,7 @@ ChangeResult TypeAnalyzer::visitOperation(
   if (auto shapeAsTensor = dyn_cast<Aten_ShapeAsTensorOp>(op)) {
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype =
         IntegerType::get(op->getContext(), 64, IntegerType::Signed);
     return incorporateKnowledge(shapeAsTensor.getResult(), knowledge);
@@ -824,7 +922,7 @@ ChangeResult TypeAnalyzer::visitOperation(
   if (auto embedding = dyn_cast<AtenEmbeddingOp>(op)) {
     auto knowledge =
-        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = Float32Type::get(op->getContext());
     return incorporateKnowledge(embedding.getResult(), knowledge);
   }
@@ -864,7 +962,7 @@ TypeAnalyzer::incorporateKnowledge(Value v, const ValueKnowledge &knowledge) {
 ChangeResult TypeAnalyzer::visitAtenLinearOp(
     AtenLinearOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   auto input = operands[0]->getValue();
   auto weight = operands[1]->getValue();
   auto bias = operands[2]->getValue();
@@ -889,7 +987,7 @@ ChangeResult TypeAnalyzer::visitAtenArangeLikeOpHelper(
     Operation *op, llvm::Optional<Value> start, Value end,
     llvm::Optional<Value> step, Value dtype) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   int64_t dtypeInt;
   if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) {
     knowledge.dtype = getTypeForDTypeInteger(op->getContext(), dtypeInt);
@@ -930,7 +1028,8 @@ ChangeResult TypeAnalyzer::visitAtenArangeOp(AtenArangeOp op) {
 ChangeResult TypeAnalyzer::visitReductionAlongAllDimsOp(
     Operation *op, Type dtype,
     ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
-  auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
+  auto knowledge =
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   knowledge.dtype = dtype;
   return incorporateKnowledge(op->getResult(0), knowledge);
 }
@@ -941,7 +1040,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
     Operation *op, Value dim, Value keepdim, Type dtype,
     ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   knowledge.dtype = dtype;
   return incorporateKnowledge(op->getResult(0), knowledge);
 }
@@ -951,7 +1050,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntOp(
     ArrayRef<LatticeElement<ValueKnowledge> *> operands, int resNum) {
   assert(dim.getType().isa<Torch::IntType>() && "dim must be int type");
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   knowledge.dtype = dtype;
   return incorporateKnowledge(op->getResult(resNum), knowledge);
 }
@@ -959,7 +1058,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntOp(
 
 template <typename OpTy>
 ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op.getContext());
   Value t = op.t();
   Value dtype = op.dtype();
   fillInDTypeGivenDTypeAndDataType(knowledge, dtype, t.getType());
@@ -969,15 +1068,17 @@ ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
 ChangeResult TypeAnalyzer::visitBinaryScalarOp(
     Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
-  knowledge.dtype = getPromotedResultType(
+      ValueKnowledge::getScalarPessimisticValueState(op->getContext());
+  Type resultType = getPromotedResultType(
       {op->getOperand(0).getType(), op->getOperand(1).getType()});
+  knowledge.scalarType = resultType;
+  knowledge.kind = getTypeKind(resultType);
   return incorporateKnowledge(op->getResult(0), knowledge);
 }
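+// For example: if the two scalar operands are a !torch.int and a
+// !torch.float, getPromotedResultType picks !torch.float, so the result's
+// lattice state becomes {kind = FloatType, scalarType = !torch.float}.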
 
 ChangeResult TypeAnalyzer::visitAtenTensorOp(AtenTensorOp op) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op.getContext());
   Value data = op.data();
   Value dtype = op.dtype();
   Type type = data.getType();
@@ -993,7 +1094,7 @@ ChangeResult TypeAnalyzer::visitConstantTensorAllocOp(
     OpTy op, llvm::Optional<Type> dataType) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   if (!dataType)
     dataType = Torch::FloatType::get(op->getContext());
   fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), dataType.getValue());
@@ -1005,7 +1106,7 @@ ChangeResult TypeAnalyzer::visitConstantTensorAllocLikeOp(
     OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto input = operands[0]->getValue();
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   fillInDTypeGivenDTypeIntAndInputDType(knowledge, op.dtype(), input.dtype);
   return incorporateKnowledge(op.getResult(), knowledge);
 }
@@ -1015,7 +1116,7 @@ ChangeResult TypeAnalyzer::visitConstantTensorNewLikeOp(
     OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto input = operands[0]->getValue();
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   fillInDTypeGivenDTypeIntAndInputDType(knowledge, op.dtype(), input.dtype);
   return incorporateKnowledge(op.getResult(), knowledge);
 }
@@ -1025,7 +1126,7 @@ template <typename OpTy>
 ChangeResult TypeAnalyzer::visitAtenToDtypeLikeOp(
     OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   Value dtype = op.dtype();
   int64_t dtypeInt;
   if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
@@ -1038,7 +1139,7 @@ template <typename OpTy>
 ChangeResult TypeAnalyzer::visitTypeConversionOp(
     OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   Value other = op.other();
   BaseTensorType type = other.getType().cast<BaseTensorType>();
   if (type.hasDtype())
@@ -1053,7 +1154,7 @@ ChangeResult TypeAnalyzer::visitAtenCatOp(
     AtenCatOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto tensorList = op.tensors();
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   auto listConstruct = tensorList.getDefiningOp<PrimListConstructOp>();
   if (!listConstruct)
     return incorporateKnowledge(op.getResult(), knowledge);
@@ -1073,7 +1174,7 @@ ChangeResult TypeAnalyzer::visitAtenCatOp(
 
 ChangeResult TypeAnalyzer::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   // The resulting type from converting a Scalar into a Tensor is different
   // if the scalar is part of a tensor operation (such as AtenMulScalar) or
   // not. In the former case, the type promotion rules are captured by the
@@ -1098,7 +1199,7 @@ ChangeResult TypeAnalyzer::visitAtenSoftmaxLikeOp(
   auto input = operands[0]->getValue();
   auto dtype = op.dtype();
   ValueKnowledge knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, input.dtype);
   return incorporateKnowledge(op.getResult(), knowledge);
 }
@@ -1109,7 +1210,7 @@ ChangeResult TypeAnalyzer::visitAten_SoftmaxLikeOp(
     OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto input = operands[0]->getValue();
   ValueKnowledge knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   bool halfToFloat;
   if (matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat))) {
     knowledge.dtype =
@@ -1154,6 +1255,16 @@ static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
     else
       return containedType;
     }
+  } else if (auto scalarType = v.getType().dyn_cast<NumberType>()) {
+    LatticeElement<ValueKnowledge> *latticeElement =
+        analyzer.lookupLatticeElement(v);
+    if (!latticeElement)
+      return nullptr;
+    const ValueKnowledge &knowledge = latticeElement->getValue();
+    if (knowledge.kind == torch_upstream::TypeKind::IntType)
+      return Torch::IntType::get(v.getContext());
+    if (knowledge.kind == torch_upstream::TypeKind::FloatType)
+      return Torch::FloatType::get(v.getContext());
   }
   return nullptr;
 }
@@ -1217,7 +1328,7 @@ void optimize(func::FuncOp func, TypeAnalyzer &analyzer) {
        return b.create<TensorStaticInfoCastOp>(loc, newType, v);
      };
      createStaticInfoUpCast = createStaticInfoDownCast;
-    } else if (originalType.isa<Torch::OptionalType>()) {
+    } else if (originalType.isa<Torch::OptionalType, Torch::NumberType>()) {
      createStaticInfoDownCast = [&](Location loc, Type newType,
                                     Value v) -> Value {
        return b.create<PrimUncheckedCastOp>(loc, newType, v);
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 39dc6b5b8..b10992641 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -543,6 +543,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
          traits=["DeclareOpInterfaceMethods"])
     emit("prim::Print : (...) -> ()")
     emit("prim::tolist : (...) -> (...)")
+    emit("prim::abs.Scalar : (Scalar) -> (Scalar)")
 
     # ==========================================================================
     # `quantized::` namespace.
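The single registry line above is what produces the Torch_PrimAbsScalarOp
definition in GeneratedTorchOps.td shown at the top of this patch:
GeneratedTorchOps.td is generated rather than edited by hand, so (assuming
the standard torch-mlir workflow) the accompanying .td change comes from
rerunning the update script:

    ./build_tools/update_torch_ods.sh
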
diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir
index a681e03f5..85be750e6 100644
--- a/test/Dialect/Torch/refine-types.mlir
+++ b/test/Dialect/Torch/refine-types.mlir
@@ -6,6 +6,7 @@
 // Code for testing transfer functions for new ops (which is most changes)
 // should go in refine-types-ops.mlir.
 
+// -----
 // CHECK-LABEL: func @basic(
 // CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
 // CHECK:         %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
@@ -16,6 +17,7 @@ func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
   return %1 : !torch.vtensor
 }
 
+// -----
 // CHECK-LABEL: func @keep_existing_shape_information(
 // CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
 // CHECK:         %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32>
@@ -25,6 +27,7 @@ func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vt
   return %1 : !torch.vtensor<[2],f32>
 }
 
+// -----
 // CHECK-LABEL: func @propagate_through_multiple_ops(
 // CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
 // CHECK:         %[[TANH0:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
@@ -39,6 +42,7 @@ func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vte
   return %3 : !torch.vtensor
 }
 
+// -----
 // Check rewriting logic in case of mixes of users that do/don't allow type
 // refinement.
 // CHECK-LABEL: func @mixed_allowing_not_allowing_type_refinement(
@@ -53,6 +57,7 @@ func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>)
   return %1, %1 : !torch.vtensor, !torch.vtensor
 }
 
+// -----
 // CHECK-LABEL: func @type_promotion$same_category_different_width(
 // CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?],si32>,
 // CHECK-SAME:    %[[ARG1:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?],unk> {
@@ -66,6 +71,7 @@ func @type_promotion$same_category_different_width(%arg0: !torch.vtensor<[?],si3
   return %0 : !torch.vtensor<[?],unk>
 }
 
+// -----
 // CHECK-LABEL: func @type_promotion$different_category(
 // CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?],si64>,
 // CHECK-SAME:    %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],unk> {
@@ -79,6 +85,7 @@ func @type_promotion$different_category(%arg0: !torch.vtensor<[?],si64>, %arg1:
   return %0 : !torch.vtensor<[?],unk>
 }
 
+// -----
 // CHECK-LABEL: func @type_promotion$same_category_zero_rank_wider(
 // CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
 // CHECK-SAME:    %[[ARG1:.*]]: !torch.vtensor<[],f64>) -> !torch.vtensor<[?],unk> {
@@ -92,6 +99,7 @@ func @type_promotion$same_category_zero_rank_wider(%arg0: !torch.vtensor<[?],f32
   return %0 : !torch.vtensor<[?],unk>
 }
 
+// -----
 // CHECK-LABEL: func @type_promotion$zero_rank_higher_category(
 // CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?],si64>,
 // CHECK-SAME:    %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
@@ -105,6 +113,7 @@ func @type_promotion$zero_rank_higher_category(%arg0: !torch.vtensor<[?],si64>,
   return %0 : !torch.vtensor<[?],unk>
 }
 
+// -----
 // CHECK-LABEL: func @type_promotion$alpha_wider(
 // CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
 // CHECK-SAME:    %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
@@ -118,6 +127,7 @@ func @type_promotion$alpha_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.v
   return %0 : !torch.vtensor<[?],unk>
 }
 
+// -----
 // CHECK-LABEL: func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
 // CHECK-SAME:    %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
 // CHECK-SAME:    %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
@@ -134,6 +144,7 @@ func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.
   return %result : !torch.vtensor<[2],f32>
 }
 
+// -----
 // CHECK-LABEL: func @torch.overwrite.tensor.contents$static_overwrites_dynamic(
 // CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[2],f32>,
 // CHECK-SAME:    %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
@@ -151,6 +162,7 @@ func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.
   return %result : !torch.vtensor<[?],f32>
 }
 
+// -----
 // CHECK-LABEL: func @bf16_result_type(
 // CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> {
 // CHECK:         %[[SQRT:.*]] = torch.aten.sqrt %[[ARG0]] : !torch.vtensor<*,bf16> -> !torch.vtensor<[2],bf16>
@@ -159,3 +171,16 @@ func @bf16_result_type(%arg0: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16
   %1 = torch.aten.sqrt %arg0 : !torch.vtensor<*,bf16> -> !torch.vtensor<[2], bf16>
   return %1 : !torch.vtensor<[2],bf16>
 }
+
+// -----
+// CHECK-LABEL: func @propagate_scalar_type(
+// CHECK-SAME:      %[[INT:.*]]: !torch.int) -> !torch.number {
+// CHECK:           %[[NUM:.*]] = torch.derefine %[[INT]] : !torch.int to !torch.number
+// CHECK:           %[[ABS:.*]] = torch.prim.abs.Scalar %[[INT]] : !torch.int -> !torch.int
+// CHECK:           %[[RET:.*]] = torch.derefine %[[ABS]] : !torch.int to !torch.number
+// CHECK:           return %[[RET]] : !torch.number
+func @propagate_scalar_type(%arg0: !torch.int) -> !torch.number {
+  %num = torch.derefine %arg0 : !torch.int to !torch.number
+  %1 = torch.prim.abs.Scalar %num: !torch.number -> !torch.number
+  return %1 : !torch.number
+}
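
To exercise the new transfer function end to end (a sketch, assuming the
repository's usual FileCheck setup and the torch-refine-types pass name used
by this test file):

    torch-mlir-opt -torch-refine-types -split-input-file \
        test/Dialect/Torch/refine-types.mlir | FileCheck test/Dialect/Torch/refine-types.mlir

On the last test above, the analysis proves the operand of
torch.prim.abs.Scalar is actually a !torch.int, refines the op to produce
!torch.int, and inserts torch.derefine casts so the function still returns
the declared !torch.number.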