diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c38d0dbbd..a5f95b538 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -256,51 +256,6 @@ def Torch_AtenLeakyRelu_Op : Torch_Op<"aten.leaky_relu_", [ }]; } -def Torch_AtenLogOp : Torch_Op<"aten.log", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::log : (Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self - ); - let results = (outs - AnyTorchOptionalTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenLogOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenLogOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - -def Torch_AtenLog_Op : Torch_Op<"aten.log_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::log_ : (Tensor) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self - ); - let results = (outs - AnyTorchOptionalNonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenLog_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenLog_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - def Torch_AtenSeluOp : Torch_Op<"aten.selu", [ AllowsTypeRefinement, HasValueSemantics, @@ -4085,6 +4040,52 @@ def Torch_AtenNe_ScalarOp : Torch_Op<"aten.ne_.Scalar", [ }]; } +def Torch_AtenLogOp : Torch_Op<"aten.log", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::log : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLogOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenLogOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; + let hasFolder = 1; +} + +def Torch_AtenLog_Op : Torch_Op<"aten.log_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::log_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLog_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenLog_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenFloorOp : Torch_Op<"aten.floor", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 9d4687596..ea1f22f4e 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1241,16 +1241,20 @@ llvm::SmallVector getFoldValueAtIndexInt(llvm::ArrayRef attrs, llvm::SmallVector splattrs; for (auto attr : attrs) { - bool isunsigned = false; + // Note that i1 is neither signed nor unsigned. + // But we should trait i1 as unsigned, otherwise that + // APInt(1,1).getSExtValue() return allOnes 64-bit integer. + // So here only distinguish signed integer. + bool isSigned = false; if (auto dense = dyn_cast(attr)) { - isunsigned = dyn_cast(dense.getElementType()).isUnsigned(); + isSigned = dyn_cast(dense.getElementType()).isSigned(); if (dense.isSplat()) { splattrs.push_back(dense.getSplatValue()); } else { splattrs.push_back(dense.getValues()[idx]); } } else if (auto intattr = dyn_cast(attr)) { - isunsigned = cast(intattr.getType()).isUnsigned(); + isSigned = cast(intattr.getType()).isSigned(); splattrs.push_back(intattr.getValue()); } else { return {}; @@ -1258,10 +1262,10 @@ llvm::SmallVector getFoldValueAtIndexInt(llvm::ArrayRef attrs, auto &apint = splattrs.back(); if (apint.getBitWidth() < bitwidth) { - if (isunsigned) { - apint = apint.zextOrTrunc(bitwidth); - } else { + if (isSigned) { apint = apint.sextOrTrunc(bitwidth); + } else { + apint = apint.zextOrTrunc(bitwidth); } } } @@ -1795,6 +1799,81 @@ OpFoldResult AtenNeScalarOp::fold(FoldAdaptor adaptor) { return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); } +//===----------------------------------------------------------------------===// +// AtenLogOp +//===----------------------------------------------------------------------===// + +using UnaryPromoteFpOperator = std::function; +using UnaryPromoteIntOperator = std::function; + +static OpFoldResult unaryPromoteFolder(DenseElementsAttr operand, + ValueTensorType resultTy, + UnaryPromoteFpOperator fpFolder, + UnaryPromoteIntOperator intFolder) { + constexpr int64_t kMaxFold = 16; + if (!resultTy.hasDtype() || !resultTy.hasSizes()) + return nullptr; + if (!isa(resultTy.getDtype())) + return nullptr; + + auto fpTy = dyn_cast(operand.getType().getElementType()); + auto intTy = dyn_cast(operand.getType().getElementType()); + if (!fpTy && !intTy) + return nullptr; + + auto resultBTy = resultTy.toBuiltinTensor().clone(resultTy.getDtype()); + bool splat = operand.isSplat(); + bool withinMaxFold = + resultBTy.hasStaticShape() && resultBTy.getNumElements() <= kMaxFold; + if (!splat && !withinMaxFold) + return nullptr; + + const int64_t numValues = splat ? 1 : resultBTy.getNumElements(); + + llvm::SmallVector operands = {operand}; + llvm::SmallVector folded; + for (int i = 0, s = numValues; i < s; ++i) { + double fold = 0.0; + if (fpTy) { + auto inputs = getFoldValueAtIndexFp(operands, i); + fold = fpFolder(inputs[0]); + } + if (intTy) { + auto inputs = + getFoldValueAtIndexInt(operands, intTy.getIntOrFloatBitWidth(), i); + fold = intFolder(inputs[0], intTy.isSigned()); + } + + APFloat val(fold); + bool unused; + val.convert( + cast(resultBTy.getElementType()).getFloatSemantics(), + APFloat::rmNearestTiesToEven, &unused); + folded.push_back(val); + } + return DenseElementsAttr::get(resultBTy, folded); +} + +OpFoldResult AtenLogOp::fold(FoldAdaptor adaptor) { + auto self = dyn_cast_or_null(adaptor.getSelf()); + auto resultType = dyn_cast(getType()); + if (!self || !resultType) + return nullptr; + + // Note that i1 is neither signed nor unsigned. + // But we should trait i1 as unsigned, otherwise that + // APInt(1,1).getSExtValue() return allOnes 64-bit integer. + auto intFold = [](APInt a, bool isSigned) -> double { + if (isSigned) + return std::log(a.getSExtValue()); + else + return std::log(a.getZExtValue()); + }; + auto fpFold = [](double a) -> double { return std::log(a); }; + + return unaryPromoteFolder(self, resultType, fpFold, intFold); +} + //===----------------------------------------------------------------------===// // AtenFloorOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index e5b219e55..5bd9341c1 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -268,7 +268,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): "aten::relu : (Tensor) -> (Tensor)", "aten::relu6 : (Tensor) -> (Tensor)", "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)", - "aten::log : (Tensor) -> (Tensor)", "aten::selu : (Tensor) -> (Tensor)", "aten::sigmoid : (Tensor) -> (Tensor)", "aten::sinh : (Tensor) -> (Tensor)", @@ -356,6 +355,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit_with_mutating_variants("aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::log : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True) diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 1823393f2..093b30bec 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2888,3 +2888,36 @@ func.func @aten_tensor_tensor_ne() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4 %fpBool = torch.aten.ne.Scalar %fpTensor, %fpScalar : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],i1> return %intBool, %uintBool, %fpBool : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1> } + +// ----- + +// CHECK-LABEL: @aten_log$fold_splat_i1 +func.func @aten_log$fold_splat_i1() -> !torch.vtensor<[4], f32> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32> + // CHECK: return %[[RET]] : !torch.vtensor<[4],f32> + %cst = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4], i1> + %result = torch.aten.log %cst : !torch.vtensor<[4], i1> -> !torch.vtensor<[4], f32> + return %result : !torch.vtensor<[4], f32> +} + +// ----- + +// CHECK-LABEL: @aten_log$fold_splat_si32 +func.func @aten_log$fold_splat_si32() -> !torch.vtensor<[4], f32> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<1.09861231> : tensor<4xf32>) : !torch.vtensor<[4],f32> + // CHECK: return %[[RET]] : !torch.vtensor<[4],f32> + %cst = torch.vtensor.literal(dense<3> : tensor<4xsi32>) : !torch.vtensor<[4], si32> + %result = torch.aten.log %cst : !torch.vtensor<[4], si32> -> !torch.vtensor<[4], f32> + return %result : !torch.vtensor<[4], f32> +} + +// ----- + +// CHECK-LABEL: @aten_log$fold_splat_f32 +func.func @aten_log$fold_splat_f32() -> !torch.vtensor<[4], f32> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<1.09861231> : tensor<4xf32>) : !torch.vtensor<[4],f32> + // CHECK: return %[[RET]] : !torch.vtensor<[4],f32> + %cst = torch.vtensor.literal(dense<3.0> : tensor<4xf32>) : !torch.vtensor<[4], f32> + %result = torch.aten.log %cst : !torch.vtensor<[4], f32> -> !torch.vtensor<[4], f32> + return %result : !torch.vtensor<[4], f32> +}