diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 2b807b252..b5f5d6964 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5650,6 +5650,30 @@ def Torch_AtenFullLikeOp : Torch_Op<"aten.full_like", [
   }];
 }
 
+def Torch_AtenAddOp : Torch_Op<"aten.add", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::add : (Scalar, Scalar) -> (Scalar)`";
+  let arguments = (ins
+    AnyTorchScalarType:$a,
+    AnyTorchScalarType:$b
+  );
+  let results = (outs
+    AnyTorchScalarType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAddOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenAddOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index 64702c76b..61872ae9f 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -30,11 +30,18 @@ torch_upstream::ScalarType getScalarTypeForType(Type type);
 Type getTypeForScalarType(
     MLIRContext *context, torch_upstream::ScalarType dtypeInt,
     mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);
+
+Type getTorchTypeForScalarType(MLIRContext *context,
+                               torch_upstream::ScalarType dtypeInt);
+
 Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
                               Type dtype);
 // Helper to convert a tensor to a specific scalar type.
 Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
                            Type dtype);
+
+bool isBuiltInType(Type type);
+
 // Helper funtion to get rank of `Base tensor type`.
 // -1 is returned if the tensorRank can't be determined.
 int getTensorRank(Value tensor);
diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
index ccdc5be66..e00a996b8 100644
--- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
+++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
@@ -108,8 +108,9 @@ public:
         result.returnOp = returnOp;
       } else {
         return rewriter.notifyMatchFailure(
-            copyToNonValueTensor,
-            "unsupported op encountered during abstract analysis");
+            copyToNonValueTensor, "unsupported op `" +
+                                      user->getName().getStringRef() +
+                                      "` encountered during abstract analysis");
       }
     }
     return result;
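Illustrative sketch, not part of the patch (`%a`, `%b`, and `%sum` are placeholder names): with `hasCustomAssemblyFormat` backed by `parseDefaultTorchOp`/`printDefaultTorchOp`, the generated `torch.aten.add` round-trips with explicit operand and result types, e.g.

    %sum = torch.aten.add %a, %b : !torch.float, !torch.int -> !torch.number

The same syntax is exercised by the refine-types test added at the end of this change.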
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 0fea565b1..5659c687f 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -114,6 +114,10 @@ static torch_upstream::TypeKind getTypeKind(Type type) {
 /// on the `dtype` from tensors and can't be used on other types like scalar
 /// types.
 static Optional<Type> meetElementTypes(Type lhs, Type rhs) {
+  auto isNullOrBuiltIn = [](Type type) { return !type || isBuiltInType(type); };
+  assert(isNullOrBuiltIn(lhs) && "`lhs` must be a builtin type");
+  assert(isNullOrBuiltIn(rhs) && "`rhs` must be a builtin type");
+
   if (!lhs)
     return rhs;
   if (!rhs)
@@ -167,6 +171,14 @@ struct ValueKnowledge {
       : dtype(dtype), scalarType(scalarType), kind(kind),
         optional(optionalKnowledge) {}
 
+  void setScalarType(Type type) {
+    bool isValidScalarType = type.isa<NumberType, IntType, Torch::FloatType>();
+    assert(isValidScalarType &&
+           "scalarType can only be one of NumberType, IntType and FloatType");
+    scalarType = type;
+    kind = getTypeKind(type);
+  }
+
   // Get the static knowledge intrinsic to `type`.
   static ValueKnowledge getKnowledgeFromType(Type type) {
     ValueKnowledge result = getPessimisticValueState(type.getContext());
@@ -420,7 +432,7 @@ private:
 // This is the type rule used for deciding dtype for:
 // 1. A new tensor created from given data.
 // 2. The scalar type for type promotion when a scalar is an operand of a tensor
-// and scalar binary operation.
+// operation (such as AtenMulScalarOp, AtenAddScalarOp, etc.).
 // If the data is floating-point, the `dtype` is inferred to be the
 // default dtype, see `torch.get_default_dtype`.
 static Type getDefaultDtypeForTorchScalar(Type type) {
@@ -438,12 +450,29 @@ static Type getDefaultDtypeForTorchScalar(Type type) {
       "getDefaultDtypeForTorchScalar called on an unsupported type");
 }
 
+// This is the type rule used for deciding the builtin type for:
+// 1. The dtype of the result tensor when converting a Scalar into a Tensor,
+//    as in PrimNumToTensorScalarOp.
+// 2. The scalar type for type promotion when a scalar is an operand of a
+//    scalar-only operation like AtenAddOp.
+static Type getBuiltInTypeForTorchScalar(Type type) {
+  MLIRContext *context = type.getContext();
+  if (type.isa<Torch::FloatType>())
+    return Float64Type::get(context);
+  if (type.isa<Torch::IntType>())
+    return IntegerType::get(context, 64, IntegerType::Signed);
+  if (type.isa<Torch::BoolType>())
+    return IntegerType::get(context, 1);
+  llvm_unreachable(
+      "getBuiltInTypeForTorchScalar called on an unsupported type");
+}
+
 static torch_upstream::ResultTypeState
 updateResultTypeState(Type scalarType,
                       const torch_upstream::ResultTypeState &inState) {
+  assert(isBuiltInType(scalarType) && "scalarType must be builtin type");
   torch_upstream::ResultTypeState new_state = inState;
-  torch_upstream::ScalarType current =
-      getScalarTypeForType(getDefaultDtypeForTorchScalar(scalarType));
+  torch_upstream::ScalarType current = getScalarTypeForType(scalarType);
   new_state.wrappedResult =
       promote_skip_undefined(inState.wrappedResult, current);
   return new_state;
@@ -481,15 +510,22 @@ updateResultTypeState(ValueKnowledge *tensor, Optional<bool> rankIsNonZero,
   return new_state;
 }
 
-static Type getPromotedResultType(ArrayRef<Type> scalarTypes) {
+// Type promotion helper for operators where only scalar operands participate
+// in type promotion, such as AtenAddOp.
+//
+// \return The promoted result type, which is a Torch dialect type.
+static Type getPromotedResultScalarType(ArrayRef<Type> scalarTypes) {
   torch_upstream::ResultTypeState state = {};
-  for (const Type &scalarType : scalarTypes)
-    state = updateResultTypeState(scalarType, state);
-  return getTypeForScalarType(scalarTypes[0].getContext(), result_type(state));
+  for (const Type &scalarType : scalarTypes) {
+    state =
+        updateResultTypeState(getBuiltInTypeForTorchScalar(scalarType), state);
+  }
+  return getTorchTypeForScalarType(scalarTypes[0].getContext(),
+                                   result_type(state));
 }
 
 // Returns most generic type Type() if the tensor dtype is unknown.
-static Type getPromotedResultType(ValueKnowledge *tensor, Type scalarType) {
+static Type getPromotedResultDType(ValueKnowledge *tensor, Type scalarType) {
   if (!tensor->dtype)
     return Type();
   torch_upstream::ResultTypeState state = {};
@@ -497,7 +533,8 @@ static Type getPromotedResultType(ValueKnowledge *tensor, Type scalarType) {
   // wrappedResult which is a lower priority than both dimResult and zeroResult.
   state = updateResultTypeState(tensor, /*rankIsNonZero=*/None, state,
                                 /*skipRankCheck=*/true);
-  state = updateResultTypeState(scalarType, state);
+  state =
+      updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state);
   return getTypeForScalarType(scalarType.getContext(), result_type(state));
 }
 
@@ -550,8 +587,7 @@ getPromotedResultTypeAssumingNonZeroRank(MLIRContext *context,
 static void fillInDTypeGivenDTypeIntAndInputDType(ValueKnowledge &knowledge,
                                                   Value dtype,
                                                   Type inputDType) {
-  assert(isa<BuiltinDialect>(inputDType.getDialect()) &&
-         "`inputDType` must be a builtin type");
+  assert(isBuiltInType(inputDType) && "`inputDType` must be a builtin type");
   int64_t dtypeInt;
   if (dtype.getType().isa<Torch::NoneType>())
     knowledge.dtype = inputDType;
@@ -692,7 +728,7 @@ ChangeResult TypeAnalyzer::visitOperation(
     Value scalar = op->getOperand(1);
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(getContext());
-    knowledge.dtype = getPromotedResultType(&lhs, scalar.getType());
+    knowledge.dtype = getPromotedResultDType(&lhs, scalar.getType());
     return incorporateKnowledge(op->getResult(0), knowledge);
   }
 
@@ -712,8 +748,8 @@ ChangeResult TypeAnalyzer::visitOperation(
     Value rhsScalar = op->getOperand(2);
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(getContext());
-    knowledge.dtype =
-        getPromotedResultType({lhsScalar.getType(), rhsScalar.getType()});
+    knowledge.dtype = getDefaultDtypeForTorchScalar(getPromotedResultScalarType(
+        {lhsScalar.getType(), rhsScalar.getType()}));
     return incorporateKnowledge(op->getResult(0), knowledge);
   }
 
@@ -723,7 +759,7 @@ ChangeResult TypeAnalyzer::visitOperation(
     Value scalar = op->getOperand(2);
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(getContext());
-    knowledge.dtype = getPromotedResultType(&lhs, scalar.getType());
+    knowledge.dtype = getPromotedResultDType(&lhs, scalar.getType());
     return incorporateKnowledge(op->getResult(0), knowledge);
   }
 
@@ -733,7 +769,7 @@ ChangeResult TypeAnalyzer::visitOperation(
     Value scalar = op->getOperand(1);
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(getContext());
-    knowledge.dtype = getPromotedResultType(&rhs, scalar.getType());
+    knowledge.dtype = getPromotedResultDType(&rhs, scalar.getType());
     return incorporateKnowledge(op->getResult(0), knowledge);
   }
 
@@ -942,7 +978,7 @@
     return visitNumToTensorOp(numToTensorOp);
   }
 
-  if (isa(op)) {
+  if (isa(op)) {
     return visitBinaryScalarOp(op, operands);
   }
 
@@ -1069,10 +1105,9 @@ ChangeResult TypeAnalyzer::visitBinaryScalarOp(
     Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto knowledge =
       ValueKnowledge::getScalarPessimisticValueState(op->getContext());
-  Type resultType = getPromotedResultType(
+  Type resultType = getPromotedResultScalarType(
       {op->getOperand(0).getType(), op->getOperand(1).getType()});
-  knowledge.scalarType = resultType;
-  knowledge.kind = getTypeKind(resultType);
+  knowledge.setScalarType(resultType);
   return incorporateKnowledge(op->getResult(0), knowledge);
 }
 
@@ -1183,12 +1218,7 @@ ChangeResult TypeAnalyzer::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
   // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h.
   // `NumToTensor` falls in the latter case.
   Type type = op.a().getType();
-  if (type.isa<Torch::FloatType>())
-    knowledge.dtype = Float64Type::get(op.getContext());
-  else if (type.isa<Torch::IntType>())
-    knowledge.dtype =
-        IntegerType::get(op.getContext(), 64, IntegerType::Signed);
-
+  knowledge.dtype = getBuiltInTypeForTorchScalar(type);
   return incorporateKnowledge(op.getResult(), knowledge);
 }
 
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index af4f7c552..7e9f9a947 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "mlir/IR/BuiltinDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 
 using namespace mlir;
@@ -78,6 +79,19 @@ Type Torch::getTypeForScalarType(
   }
 }
 
+Type Torch::getTorchTypeForScalarType(MLIRContext *context,
+                                      torch_upstream::ScalarType dtypeInt) {
+  switch (dtypeInt) {
+  case torch_upstream::ScalarType::Double:
+    return Torch::FloatType::get(context);
+  case torch_upstream::ScalarType::Long:
+    return Torch::IntType::get(context);
+  default:
+    llvm::report_fatal_error(
+        "Unsupported scalar type to Torch type conversion");
+  }
+}
+
 Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
                                      Type dtype) {
   int intType = (int)getScalarTypeForType(dtype);
@@ -100,6 +114,10 @@ Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
   return converted;
 }
 
+bool Torch::isBuiltInType(Type type) {
+  return isa<BuiltinDialect>(type.getDialect());
+}
+
 int Torch::getTensorRank(Value tensor) {
   int tensorRank = -1;
   BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index f9964e8db..f7039eb95 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -518,6 +518,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True)
     emit("aten::_set_item.t : (t[], int, t) -> (t[])")
     emit("aten::div : (Scalar, Scalar) -> (float)")
+    emit("aten::add : (Scalar, Scalar) -> (Scalar)")
+
     emit("aten::eq.device : (Device, Device) -> (bool)")
     emit("aten::ceil.float : (float) -> (int)", has_folder=True)
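Illustrative sketch, not part of the patch (function and value names are placeholders): the promotion rule added in RefineTypes.cpp above also covers the integer-only case. Two `!torch.int` operands map to builtin si64 via getBuiltInTypeForTorchScalar, promote to ScalarType::Long, and map back via getTorchTypeForScalarType, so the result below is expected to refine from !torch.number to !torch.int; the mixed float/int case is covered by the test that follows.

    func @scalar_add_int(%a: !torch.int, %b: !torch.int) -> !torch.number {
      %r = torch.aten.add %a, %b : !torch.int, !torch.int -> !torch.number
      return %r : !torch.number
    }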
diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir
index 85be750e6..c089a7a11 100644
--- a/test/Dialect/Torch/refine-types.mlir
+++ b/test/Dialect/Torch/refine-types.mlir
@@ -127,6 +127,18 @@ func @type_promotion$alpha_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.v
   return %0 : !torch.vtensor<[?],unk>
 }
 
+// -----
+// CHECK-LABEL:   func @type_promotion_scalar_operation(
+// CHECK-SAME:                                          %[[FLOAT:.*]]: !torch.float,
+// CHECK-SAME:                                          %[[INT:.*]]: !torch.int) -> !torch.number {
+// CHECK:           %[[ADD:.*]] = torch.aten.add %[[FLOAT]], %[[INT]] : !torch.float, !torch.int -> !torch.float
+// CHECK:           %[[RET:.*]] = torch.derefine %[[ADD]] : !torch.float to !torch.number
+// CHECK:           return %[[RET]] : !torch.number
+func @type_promotion_scalar_operation(%float: !torch.float, %int: !torch.int) -> !torch.number {
+  %ret = torch.aten.add %float, %int : !torch.float, !torch.int -> !torch.number
+  return %ret : !torch.number
+}
+
 // -----
 // CHECK-LABEL:   func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
 // CHECK-SAME:                                                                    %[[STATIC:.*]]: !torch.vtensor<[2],f32>,