diff --git a/externals/llvm-project b/externals/llvm-project
index 26ee89477..2b4807ba0 160000
--- a/externals/llvm-project
+++ b/externals/llvm-project
@@ -1 +1 @@
-Subproject commit 26ee8947702d79ce2cab8e577f713685a5ca4a55
+Subproject commit 2b4807ba044230ed6243f5c3a1329a9344de758d
diff --git a/externals/mlir-hlo b/externals/mlir-hlo
index 4805d8498..ac26bdba7 160000
--- a/externals/mlir-hlo
+++ b/externals/mlir-hlo
@@ -1 +1 @@
-Subproject commit 4805d8498dfb81566076f56f52273b426c1cc5bf
+Subproject commit ac26bdba7a5edfe6060ba5be528b9d20c987297d
diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
index 39cb1eacc..c3ab1d474 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -116,6 +116,10 @@ void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op,
   rewriter.replaceOp(op, result->getResults());
 }
 
+// Get accumulator type for AvgPool2dOp.
+LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
+                                  TypeAttr &accType);
+
 } // namespace tosa
 } // namespace mlir
diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp
index d2ce90a51..ce0d1f371 100644
--- a/lib/Conversion/TorchToStablehlo/Reduction.cpp
+++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp
@@ -98,9 +98,6 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
     initIndex = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
   }
 
-  DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
-      RankedTensorType::get({}, rewriter.getI64Type()), dim);
-
   auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
       op->getLoc(), inputShapeVec);
   auto indexTensor = rewriter.create<stablehlo::DynamicIotaOp>(
@@ -115,7 +112,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
           initValue,
           initIndex,
       },
-      dimensions);
+      rewriter.getI64TensorAttr(dim));
 
   Block &block = stablehloReduceOp.getBody().emplaceBlock();
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 356b9e777..fc1efa364 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3985,10 +3985,25 @@ public:
       return rewriter.notifyMatchFailure(
           op, "Failed to process inputs for pooling");
 
-    auto pooledOutput =
-        rewriter
-            .create<TosaOpT>(op->getLoc(), outputTy, input, kernel, stride, pad)
-            .getResult();
+    Value pooledOutput;
+    static_assert(std::is_same<TosaOpT, tosa::MaxPool2dOp>::value ||
+                      std::is_same<TosaOpT, tosa::AvgPool2dOp>::value,
+                  "Expected either tosa::MaxPool2dOp or tosa::AvgPool2dOp");
+    if constexpr (std::is_same<TosaOpT, tosa::MaxPool2dOp>::value) {
+      pooledOutput = rewriter
+                         .create<TosaOpT>(op->getLoc(), outputTy, input, kernel,
+                                          stride, pad)
+                         .getResult();
+    } else if constexpr (std::is_same<TosaOpT, tosa::AvgPool2dOp>::value) {
+      TypeAttr accType;
+      if (failed(tosa::getAvgPool2dAccType(rewriter, input, accType)))
+        return rewriter.notifyMatchFailure(
+            op, "Failed to get accumulator type for pooling");
+      pooledOutput = rewriter
+                         .create<TosaOpT>(op->getLoc(), outputTy, input, kernel,
+                                          stride, pad, accType)
+                         .getResult();
+    }
 
     auto transposedOutput =
         ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingOutputToChw(
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index c4f8d2b0b..1d026a62d 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -169,7 +169,6 @@ std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
       .getResult();
 }
 
-
 // Templated function to create a constant op for given type and shape.
 // T: storage C type.
 // Default template creates a constant tensor in T.
@@ -243,8 +242,7 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter,
 }
 
 static LogicalResult checkValidityOfCast(Type src, Type dest) {
-  if ((src == dest) ||
-      (src.isInteger(64) && dest.isInteger(32)) ||
+  if ((src == dest) || (src.isInteger(64) && dest.isInteger(32)) ||
       (src.isInteger(64) && dest.isInteger(8)) ||
       (src.isInteger(64) && dest.isInteger(1)) ||
       (src.isInteger(64) && dest.isF32()) ||
@@ -256,18 +254,14 @@ static LogicalResult checkValidityOfCast(Type src, Type dest) {
       (src.isInteger(8) && dest.isInteger(1)) ||
       (src.isInteger(8) && dest.isBF16()) ||
       (src.isInteger(1) && dest.isInteger(64)) ||
-      (src.isInteger(1) && dest.isF32()) ||
-      (src.isF32() && dest.isF64()) ||
-      (src.isF32() && dest.isBF16()) ||
-      (src.isF64() && dest.isF32()) ||
-      (src.isF64() && dest.isBF16()) ||
-      (src.isF32() && dest.isInteger(8)) ||
+      (src.isInteger(1) && dest.isF32()) || (src.isF32() && dest.isF64()) ||
+      (src.isF32() && dest.isBF16()) || (src.isF64() && dest.isF32()) ||
+      (src.isF64() && dest.isBF16()) || (src.isF32() && dest.isInteger(8)) ||
       (src.isF32() && dest.isInteger(64)) ||
       (src.isF32() && dest.isInteger(1)) ||
       (src.isBF16() && dest.isInteger(8)) ||
       (src.isBF16() && dest.isInteger(16)) ||
-      (src.isBF16() && dest.isInteger(32)) ||
-      (src.isBF16() && dest.isF32())) {
+      (src.isBF16() && dest.isInteger(32)) || (src.isBF16() && dest.isF32())) {
     return success();
   }
   return failure();
@@ -341,5 +335,27 @@ template std::optional<Value> getConstTensor<int64_t>(PatternRewriter &,
                                                       Operation *,
                                                       ArrayRef<int64_t> vec,
                                                       ArrayRef<int64_t> shape);
+
+LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
+                                  TypeAttr &accType) {
+  auto inputTy = llvm::dyn_cast<ShapedType>(input.getType());
+  if (!inputTy)
+    return failure();
+  auto inputETy = inputTy.getElementType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
+    inputETy = quantType.getStorageType();
+
+  // Tosa supports FP16 and FP32 accumulator type for FP16 input. When the time
+  // FP16 is supported, the accumulator type can be selected based on trade-off
+  // between performance and accuracy. Set to FP32 by default.
+  accType = inputETy.isa<FloatType>()
+                ? mlir::TypeAttr::get(rewriter.getF32Type())
+                : mlir::TypeAttr::get(rewriter.getIntegerType(32));
+
+  return success();
+}
+
 } // namespace tosa
 } // namespace mlir
diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
index 7cc699c04..8f310da08 100644
--- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
+++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
@@ -239,13 +240,13 @@ static LogicalResult adjustCallingConventions(func::FuncOp func,
   typeConverter.addConversion([](Type type) { return type; });
   typeConverter.addConversion(
       [](Torch::TupleType type,
-         SmallVectorImpl<Type> &types) -> Optional<LogicalResult> {
+         SmallVectorImpl<Type> &types) -> LogicalResult {
         llvm::append_range(types, type.getContainedTypes());
         return success();
       });
   typeConverter.addConversion(
       [](Torch::NoneType type,
-         SmallVectorImpl<Type> &types) -> Optional<LogicalResult> {
+         SmallVectorImpl<Type> &types) -> LogicalResult {
         return success();
       });
diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
index da53dcf7a..8a5c218e4 100644
--- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
@@ -29,14 +29,15 @@ setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
   target.addLegalOp<TorchConversion::ToBuiltinTensorOp,
                     TorchConversion::FromBuiltinTensorOp>();
   typeConverter.addConversion(
-      [](Torch::ValueTensorType type) -> Optional<Type> {
+      [](Torch::ValueTensorType type) -> std::optional<Type> {
         return type.toBuiltinTensor();
       });
   typeConverter.addTargetMaterialization([](OpBuilder &builder, TensorType type,
                                             ValueRange inputs,
                                             Location loc) -> Value {
     assert(inputs.size() == 1);
-    assert(inputs[0].getType().isa<Torch::BaseTensorType>());
+    if (!inputs[0].getType().isa<Torch::BaseTensorType>())
+      return {};
     return builder.create<ToBuiltinTensorOp>(loc, inputs[0]);
   });
   auto sourceMaterialization = [](OpBuilder &builder,
@@ -53,12 +54,12 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target,
                                          TypeConverter &typeConverter) {
   target.addLegalOp<TorchConversion::ToI1Op, TorchConversion::FromI1Op>();
-  typeConverter.addConversion([](Torch::BoolType type) -> Optional<Type> {
+  typeConverter.addConversion([](Torch::BoolType type) -> std::optional<Type> {
     return IntegerType::get(type.getContext(), 1);
   });
   typeConverter.addTargetMaterialization([](OpBuilder &builder,
                                             IntegerType type, ValueRange inputs,
-                                            Location loc) -> Optional<Value> {
+                                            Location loc) -> std::optional<Value> {
     // Other builtin integer types could be handled by other materializers.
     if (!(type.getWidth() == 1 && type.isSignless()))
       return std::nullopt;
@@ -79,12 +80,12 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
                                          TypeConverter &typeConverter) {
   target.addLegalOp<TorchConversion::ToI64Op, TorchConversion::FromI64Op>();
-  typeConverter.addConversion([](Torch::IntType type) -> Optional<Type> {
+  typeConverter.addConversion([](Torch::IntType type) -> std::optional<Type> {
     return IntegerType::get(type.getContext(), 64);
   });
   typeConverter.addTargetMaterialization([](OpBuilder &builder,
                                             IntegerType type, ValueRange inputs,
-                                            Location loc) -> Optional<Value> {
+                                            Location loc) -> std::optional<Value> {
     // Other builtin integer types could be handled by other materializers.
     if (!(type.getWidth() == 64 && type.isSignless()))
       return std::nullopt;
@@ -108,12 +109,12 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target,
                                            TypeConverter &typeConverter) {
   target.addLegalOp<TorchConversion::ToF64Op, TorchConversion::FromF64Op>();
-  typeConverter.addConversion([](Torch::FloatType type) -> Optional<Type> {
+  typeConverter.addConversion([](Torch::FloatType type) -> std::optional<Type> {
     return Float64Type::get(type.getContext());
   });
   typeConverter.addTargetMaterialization([](OpBuilder &builder,
                                             Float64Type type, ValueRange inputs,
-                                            Location loc) -> Optional<Value> {
+                                            Location loc) -> std::optional<Value> {
     assert(inputs.size() == 1);
     assert(inputs[0].getType().isa<Torch::FloatType>());
     return builder.create<ToF64Op>(loc, inputs[0]).getResult();
@@ -132,12 +133,12 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target,
                                                TypeConverter &typeConverter) {
   target.addLegalOp<TorchConversion::GeneratorToI64Op,
                     TorchConversion::I64ToGeneratorOp>();
-  typeConverter.addConversion([](Torch::GeneratorType type) -> Optional<Type> {
+  typeConverter.addConversion([](Torch::GeneratorType type) -> std::optional<Type> {
     return IntegerType::get(type.getContext(), 64);
   });
   typeConverter.addTargetMaterialization([](OpBuilder &builder,
                                             IntegerType type, ValueRange inputs,
-                                            Location loc) -> Optional<Value> {
+                                            Location loc) -> std::optional<Value> {
     // Other builtin integer types could be handled by other materializers.
     if (!(type.getWidth() == 64 && type.isSignless()))
       return std::nullopt;
diff --git a/python/torch_mlir/_dynamo_fx_importer.py b/python/torch_mlir/_dynamo_fx_importer.py
index 5755b5118..84219cf84 100644
--- a/python/torch_mlir/_dynamo_fx_importer.py
+++ b/python/torch_mlir/_dynamo_fx_importer.py
@@ -437,6 +437,6 @@ def import_fx_graph_as_func(g: torch.fx.Graph, func_name: str) -> ir.Module:
     # The reason is that the supported subset only involves stateless
     # fx.Graph's, so the state held on the fx.GraphModule is not necessary.
     _verify_fx_graph_conforms_to_subset(g)
-    with ir.Context() as context:
+    with ir.Context() as context, ir.Location.unknown(context=context):
         torch_dialect.register_dialect(context)
         return _FXGraphImporter(g, func_name).import_graph()
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 338ac731c..df8c14890 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -799,7 +799,7 @@ func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
 // CHECK:           %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK:           %[[VAL_11:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
 // CHECK:           %[[VAL_12:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_11]]) : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32>
-// CHECK:           %[[VAL_13:.*]] = "tosa.avg_pool2d"(%[[VAL_12]]) <{kernel = array<i64: 7, 7>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}> : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32>
+// CHECK:           %[[VAL_13:.*]] = "tosa.avg_pool2d"(%[[VAL_12]]) <{acc_type = f32, kernel = array<i64: 7, 7>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}> : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32>
 // CHECK:           %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
 // CHECK:           %[[VAL_15:.*]] = "tosa.transpose"(%[[VAL_13]], %[[VAL_14]]) : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32>
 // CHECK:           %[[VAL_16:.*]] = tensor.cast %[[VAL_15]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32>