mirror of https://github.com/llvm/torch-mlir
Bump llvm-project to a085c23aa3c8f91866d7f4588d4f683407dc775d. (#250)
* Added additional *ToLLVM conversion patterns (they were disaggregated from standard).
* Misc renames.
* Spelling change on ConvNCHW op, and it now expects strides and dilations attributes.
pull/253/head
parent 89d4931324
commit 2ecbcbf8c7
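Before the hunks, a minimal sketch of the renamed convolution builder this bump switches to (illustrative only, not code from the patch; loc, resultType, input, filter, and initTensor are placeholder names): linalg::ConvNCHWOp becomes linalg::Conv2DNchwOp, and its builder now takes explicit strides and dilations vector attributes.

    // Hypothetical call site; the patch below passes unit strides/dilations.
    auto strides = rewriter.getI64VectorAttr({1, 1});
    auto dilations = rewriter.getI64VectorAttr({1, 1});
    auto conv = rewriter.create<linalg::Conv2DNchwOp>(
        loc, TypeRange(resultType), ValueRange({input, filter}),
        ValueRange(initTensor), strides, dilations);
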
@@ -1 +1 @@
-Subproject commit 7c35aae35b2c386b59af58c56ed36908f3d68371
+Subproject commit a085c23aa3c8f91866d7f4588d4f683407dc775d

@@ -14,7 +14,6 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Traits.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "npcomp/Dialect/TCF/IR/TCFOps.h"

@@ -34,13 +33,18 @@ static SmallVector<Value, 6> bypassResultShapes(Operation *op,
         op->getLoc(), ValueRange({lhsRows, rhsCols}));
     return {shape};
   }
-  // TODO: This only supports the NCHW data format. Consider other formats and lower ranks.
+  // TODO: This only supports the NCHW data format. Consider other formats and
+  // lower ranks.
   if (auto conv2dNCHW = dyn_cast<tcf::ConvNCHWOp>(op)) {
     // TODO: Replace hard-coded stride/dilation/padding constant-ops.
-    // TODO: Consider migrating this SSA shape-computing graph to a complex op or use the `mlir-linalg-ods-gen` approach and define a `*.tc` spec file.
-    auto cI0 = builder.create<ConstantOp>(op->getLoc(), builder.getIntegerAttr(builder.getIndexType(), 0));
-    auto cI1 = builder.create<ConstantOp>(op->getLoc(), builder.getIntegerAttr(builder.getIndexType(), 1));
-    auto cI2 = builder.create<ConstantOp>(op->getLoc(), builder.getIntegerAttr(builder.getIndexType(), 2));
+    // TODO: Consider migrating this SSA shape-computing graph to a complex op
+    // or use the `mlir-linalg-ods-gen` approach and define a `*.tc` spec file.
+    auto cI0 = builder.create<ConstantOp>(
+        op->getLoc(), builder.getIntegerAttr(builder.getIndexType(), 0));
+    auto cI1 = builder.create<ConstantOp>(
+        op->getLoc(), builder.getIntegerAttr(builder.getIndexType(), 1));
+    auto cI2 = builder.create<ConstantOp>(
+        op->getLoc(), builder.getIntegerAttr(builder.getIndexType(), 2));
     auto stride = cI1;
     auto dilation = cI1;
     auto padding = cI0;
@@ -63,22 +67,37 @@ static SmallVector<Value, 6> bypassResultShapes(Operation *op,
     auto filterWidth =
         builder.create<tensor::DimOp>(op->getLoc(), conv2dNCHW.filter(), 3);
     // Output height
-    auto twicePaddingHeight = builder.create<MulIOp>(op->getLoc(), paddingHeight, cI2);
-    auto heightPlusTwicePadding = builder.create<SubIOp>(op->getLoc(), height, twicePaddingHeight);
-    auto filterHeightMinusOne = builder.create<SubIOp>(op->getLoc(), filterHeight, cI1);
-    auto dilationFilterHeight = builder.create<MulIOp>(op->getLoc(), dilationHeight, filterHeightMinusOne);
-    auto outHeightUnstridedPlusOne = builder.create<SubIOp>(op->getLoc(), heightPlusTwicePadding, dilationFilterHeight);
-    auto outHeightUnstrided = builder.create<SubIOp>(op->getLoc(), outHeightUnstridedPlusOne, cI1);
-    auto outHeightMinusOne = builder.create<UnsignedDivIOp>(op->getLoc(), outHeightUnstrided, strideHeight);
-    auto outHeight = builder.create<AddIOp>(op->getLoc(), outHeightMinusOne, cI1);
+    auto twicePaddingHeight =
+        builder.create<MulIOp>(op->getLoc(), paddingHeight, cI2);
+    auto heightPlusTwicePadding =
+        builder.create<SubIOp>(op->getLoc(), height, twicePaddingHeight);
+    auto filterHeightMinusOne =
+        builder.create<SubIOp>(op->getLoc(), filterHeight, cI1);
+    auto dilationFilterHeight = builder.create<MulIOp>(
+        op->getLoc(), dilationHeight, filterHeightMinusOne);
+    auto outHeightUnstridedPlusOne = builder.create<SubIOp>(
+        op->getLoc(), heightPlusTwicePadding, dilationFilterHeight);
+    auto outHeightUnstrided =
+        builder.create<SubIOp>(op->getLoc(), outHeightUnstridedPlusOne, cI1);
+    auto outHeightMinusOne = builder.create<UnsignedDivIOp>(
+        op->getLoc(), outHeightUnstrided, strideHeight);
+    auto outHeight =
+        builder.create<AddIOp>(op->getLoc(), outHeightMinusOne, cI1);
     // Output width
-    auto twicePaddingWidth = builder.create<MulIOp>(op->getLoc(), paddingWidth, cI2);
-    auto widthPlusTwicePadding = builder.create<SubIOp>(op->getLoc(), width, twicePaddingWidth);
-    auto filterWidthMinusOne = builder.create<SubIOp>(op->getLoc(), filterWidth, cI1);
-    auto dilationFilterWidth = builder.create<MulIOp>(op->getLoc(), dilationWidth, filterWidthMinusOne);
-    auto outWidthUnstridedPlusOne = builder.create<SubIOp>(op->getLoc(), widthPlusTwicePadding, dilationFilterWidth);
-    auto outWidthUnstrided = builder.create<SubIOp>(op->getLoc(), outWidthUnstridedPlusOne, cI1);
-    auto outWidthMinusOne = builder.create<UnsignedDivIOp>(op->getLoc(), outWidthUnstrided, strideWidth);
+    auto twicePaddingWidth =
+        builder.create<MulIOp>(op->getLoc(), paddingWidth, cI2);
+    auto widthPlusTwicePadding =
+        builder.create<SubIOp>(op->getLoc(), width, twicePaddingWidth);
+    auto filterWidthMinusOne =
+        builder.create<SubIOp>(op->getLoc(), filterWidth, cI1);
+    auto dilationFilterWidth = builder.create<MulIOp>(
+        op->getLoc(), dilationWidth, filterWidthMinusOne);
+    auto outWidthUnstridedPlusOne = builder.create<SubIOp>(
+        op->getLoc(), widthPlusTwicePadding, dilationFilterWidth);
+    auto outWidthUnstrided =
+        builder.create<SubIOp>(op->getLoc(), outWidthUnstridedPlusOne, cI1);
+    auto outWidthMinusOne = builder.create<UnsignedDivIOp>(
+        op->getLoc(), outWidthUnstrided, strideWidth);
     auto outWidth = builder.create<AddIOp>(op->getLoc(), outWidthMinusOne, cI1);
     // Output shape
     auto shape = builder.create<tensor::FromElementsOp>(
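For reference (editorial note, not part of the patch): the index arithmetic emitted above computes the convolution output size. Written as plain arithmetic it is

    outHeight = (height - 2 * padding - dilation * (filterHeight - 1) - 1) / stride + 1
    outWidth  = (width  - 2 * padding - dilation * (filterWidth  - 1) - 1) / stride + 1

Padding, stride, and dilation remain hard-coded to 0, 1, and 1 here, so the subtraction of the padding term (where the conventional formula adds it) has no effect in this revision.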
|
@@ -146,20 +165,23 @@ public:
         rewriter.create<tensor::DimOp>(op.getLoc(), op.filter(), 2);
     Value filterKW =
         rewriter.create<tensor::DimOp>(op.getLoc(), op.filter(), 3);
-    Value matchingCin =
-        rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq, inputCin, filterCin);
-    Value validFilterH =
-        rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::uge, inputH, filterKH);
-    Value validFilterW =
-        rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::uge, inputW, filterKW);
+    Value matchingCin = rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq,
+                                                inputCin, filterCin);
+    Value validFilterH = rewriter.create<CmpIOp>(
+        op.getLoc(), CmpIPredicate::uge, inputH, filterKH);
+    Value validFilterW = rewriter.create<CmpIOp>(
+        op.getLoc(), CmpIPredicate::uge, inputW, filterKW);
     Value witnessCin = rewriter.create<shape::CstrRequireOp>(
         op.getLoc(), matchingCin, "input and filter in-channels must be equal");
     Value witnessFilterH = rewriter.create<shape::CstrRequireOp>(
-        op.getLoc(), validFilterH, "input height must be greater than or equal to filter KH-dimension");
+        op.getLoc(), validFilterH,
+        "input height must be greater than or equal to filter KH-dimension");
     Value witnessFilterW = rewriter.create<shape::CstrRequireOp>(
-        op.getLoc(), validFilterW, "input width must be greater than or equal to filter KW-dimension");
+        op.getLoc(), validFilterW,
+        "input width must be greater than or equal to filter KW-dimension");
     Value assumingAll = rewriter.create<shape::AssumingAllOp>(
-        op.getLoc(), witnessCin.getType(), ValueRange({witnessCin, witnessFilterH, witnessFilterW}));
+        op.getLoc(), witnessCin.getType(),
+        ValueRange({witnessCin, witnessFilterH, witnessFilterW}));
     auto assuming = rewriter.create<shape::AssumingOp>(
         op.getLoc(), ArrayRef<Type>{op.getType()}, assumingAll);
 
@@ -173,11 +195,17 @@ public:
     Value initTensor =
         rewriter.create<tcp::SplattedOp>(op.getLoc(), op.getType(), c0, shape);
 
+    // Unit strides and dilations.
+    auto strides = rewriter.getI64VectorAttr({1, 1});
+    auto dilations = rewriter.getI64VectorAttr({1, 1});
+
     // Create the ConvNCHW.
-    auto conv2dNCHW = rewriter.create<linalg::ConvNCHWOp>(
+    auto conv2dNCHW = rewriter.create<linalg::Conv2DNchwOp>(
         op.getLoc(), TypeRange(op.getType()),
-        ValueRange({op.in(), op.filter()}), ValueRange(initTensor));
-    rewriter.create<shape::AssumingYieldOp>(op.getLoc(), conv2dNCHW.getResults());
+        ValueRange({op.in(), op.filter()}), ValueRange(initTensor), strides,
+        dilations);
+    rewriter.create<shape::AssumingYieldOp>(op.getLoc(),
+                                            conv2dNCHW.getResults());
 
     // Finally, replace with the results of the shape.assuming
     rewriter.replaceOp(op, assuming.getResults());

@@ -236,6 +236,9 @@ public:
       return rewriter.notifyMatchFailure(op, "only support stride [1, 1]");
     if (!isConstantIntListMatching(dilation, expects))
       return rewriter.notifyMatchFailure(op, "only support dilation [1, 1]");
+    // Unit strides and dilations.
+    auto linalgStrides = rewriter.getI64VectorAttr({1, 1});
+    auto linalgDilations = rewriter.getI64VectorAttr({1, 1});
 
     if (!op.bias().getType().isa<Torch::NoneType>())
       return rewriter.notifyMatchFailure(op, "only support None bias");
@@ -288,9 +291,9 @@ public:
 
     Value conv2d =
         rewriter
-            .create<linalg::ConvNCHWOp>(loc, ranked4DTensorType,
-                                        ValueRange{paddedInput, weight},
-                                        ValueRange{initTensor0})
+            .create<linalg::Conv2DNchwOp>(
+                loc, ranked4DTensorType, ValueRange{paddedInput, weight},
+                ValueRange{initTensor0}, linalgStrides, linalgDilations)
             .getResult(0);
     Type newResultType = getTypeConverter()->convertType(op.getType());
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv2d);

@@ -19,6 +19,8 @@ add_npcomp_library(NPCOMPRefBackend
     MLIRIR
     MLIRLinalg
     MLIRLinalgToLLVM
+    MLIRMathToLLVM
+    MLIRMemRefToLLVM
     MLIRSCFToStandard
     MLIRSCFTransforms
     MLIRShapeToStandard

@@ -9,7 +9,11 @@
 #include "PassDetail.h"
 #include "npcomp/RefBackend/RefBackend.h"
 
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -702,6 +706,8 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
     populateCompilerRuntimePatterns(module, patterns, converter);
     target.addLegalOp<ModuleOp>();
     populateStdToLLVMConversionPatterns(converter, patterns);
+    populateMathToLLVMConversionPatterns(converter, patterns);
+    populateMemRefToLLVMConversionPatterns(converter, patterns);
     patterns.add<LowerModuleMetadata>(context);
 
     // TODO: Move these "std to std" legalizations to their own pass if we grow
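Editorial note: the two added populate* calls reflect the disaggregation mentioned in the commit message; math and memref lowerings no longer arrive via the standard-to-LLVM patterns and must be registered explicitly. A minimal sketch of a lowering pass wired up this way (assumed boilerplate inside a pass body, not code from this repository):

    // Assumed surrounding context: MLIRContext *context and ModuleOp module.
    LLVMTypeConverter converter(context);
    RewritePatternSet patterns(context);
    LLVMConversionTarget target(*context);
    target.addLegalOp<ModuleOp>();
    populateStdToLLVMConversionPatterns(converter, patterns);
    populateMathToLLVMConversionPatterns(converter, patterns);
    populateMemRefToLLVMConversionPatterns(converter, patterns);
    if (failed(applyFullConversion(module, target, std::move(patterns))))
      signalPassFailure();
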
@@ -207,7 +207,7 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
   pm.addNestedPass<FuncOp>(createConvertElementwiseToLinalgPass());
 
   if (options.optimize) {
-    pm.addNestedPass<FuncOp>(createLinalgFusionOfTensorOpsPass());
+    pm.addNestedPass<FuncOp>(createLinalgElementwiseOpFusionPass());
     pm.addNestedPass<FuncOp>(createCanonicalizerPass());
     pm.addNestedPass<FuncOp>(createCSEPass());
   }

@@ -62,7 +62,7 @@ func @tcf_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf
 // CHECK: %[[OUTWIDTH:.*]] = addi %[[WIDTHV0M1]], %[[C1]] : index
 // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[BATCH]], %[[OUTCHANNELS]], %[[OUTHEIGHT]], %[[OUTWIDTH]] : tensor<4xindex>
 // CHECK: %[[INIT_TENSOR:.*]] = tcp.splatted %[[C0F32]], %[[SHAPE]] : (f32, tensor<4xindex>) -> tensor<?x?x?x?xf32>
-// CHECK: %[[CONVNCHW:.*]] = linalg.conv_2d_nchw ins(%[[IN]], %[[FILTER]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%[[INIT_TENSOR]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: %[[CONVNCHW:.*]] = linalg.conv_2d_nchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[IN]], %[[FILTER]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%[[INIT_TENSOR]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
 // CHECK: shape.assuming_yield %[[CONVNCHW]] : tensor<?x?x?x?xf32>
 // CHECK: }
 // CHECK: return %[[RET:.*]] : tensor<?x?x?x?xf32>

@@ -42,7 +42,7 @@ llvm_add_library(
   ${_OBJECTS}
   LINK_LIBS PUBLIC
     # Public dependencies on the MLIR public API and impl shared libraries.
-    MLIRPublicAPI
+    MLIRPythonCAPI
     MLIR
     ${_DEPS}
 )