Bump llvm-project to 0524a09cc7e1a0797982feacf505825231efbee7

- rename OwningRewritePatternList -> RewritePatternSet (see the migration sketch below)
  - also rename `insert` to `add`
- RewritePatternSet now holds an MLIRContext
- the memref dialect has been split out of std
Sean Silva 2021-03-23 14:16:23 -07:00
parent 4591884d06
commit 99178a167d
30 changed files with 165 additions and 146 deletions
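
For downstream code tracking this bump, the migration is mechanical. The sketch below shows the new pattern-population idiom; `MyPass` and `MyPattern` are hypothetical names used for illustration, not symbols from this diff.

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

void MyPass::runOnOperation() {
  mlir::MLIRContext *context = &getContext();
  // RewritePatternSet is constructed over its context, so populate
  // functions no longer take a separate MLIRContext parameter; they can
  // recover it with patterns.getContext().
  mlir::RewritePatternSet patterns(context);
  // `add` is the new spelling of `insert` (which remains as a synonym).
  patterns.add<MyPattern>(context);
  (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

Ops that moved out of std are created under the memref namespace (e.g. memref::AllocOp, memref::DimOp) after including mlir/Dialect/MemRef/IR/MemRef.h, and conversion targets that relied on std being legal generally need target.addLegalDialect<memref::MemRefDialect>() as well, as several hunks below show.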

@@ -1 +1 @@
-Subproject commit e31c77b1827fa4dd3511f21af11cfab18ecf6d38
+Subproject commit 0524a09cc7e1a0797982feacf505825231efbee7

@@ -33,8 +33,8 @@ with mb.capture_function("conv2d_fwd", [tensor]) as f:
 # NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 # CHECK-LABEL: func @conv2d_fwd(
 # CHECK-SAME: %[[VAL_0:.*]]: !numpy.ndarray<[3,16,10,10]:f32>) -> !numpy.ndarray<[3,4,8,8]:f32> {
-# CHECK: %[[VAL_1:.*]] = constant opaque<"", "0xDEADBEEF"> : tensor<4x16x3x3xf32>
-# CHECK: %[[VAL_2:.*]] = constant opaque<"", "0xDEADBEEF"> : tensor<4xf32>
+# CHECK: %[[VAL_1:.*]] = constant opaque<"_", "0xDEADBEEF"> : tensor<4x16x3x3xf32>
+# CHECK: %[[VAL_2:.*]] = constant opaque<"_", "0xDEADBEEF"> : tensor<4xf32>
 # CHECK: %[[VAL_3:.*]] = constant 1 : i64
 # CHECK: %[[VAL_4:.*]] = constant 1 : i64
 # CHECK: %[[VAL_5:.*]] = constant 0 : i64

@@ -14,15 +14,14 @@
 namespace mlir {
 class MLIRContext;
-class OwningRewritePatternList;
+class RewritePatternSet;

 namespace NPCOMP {

 /// Populates patterns for converting core ATen ops to TCF. These patterns
 /// cover core arithmetic ops that are on the order of 1:1 representationally.
 /// More advanced patterns are managed elsewhere.
-void populateCoreATenToTCFPatterns(MLIRContext *context,
-                                   OwningRewritePatternList &patterns);
+void populateCoreATenToTCFPatterns(RewritePatternSet &patterns);

 } // namespace NPCOMP
 } // namespace mlir

@@ -19,8 +19,7 @@ namespace NPCOMP {
 // Conversion patterns
 //===----------------------------------------------------------------------===//
-void populateBasicpyToStdPrimitiveOpPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns);
+void populateBasicpyToStdPrimitiveOpPatterns(RewritePatternSet &patterns);

 } // namespace NPCOMP
 } // namespace mlir

@@ -23,8 +23,7 @@ class ATenDialect;

 namespace mlir {
-void populateATenToStdPatterns(MLIRContext *context,
-                               OwningRewritePatternList &patterns);
+void populateATenToStdPatterns(RewritePatternSet &patterns);
 } // namespace mlir

@@ -27,8 +27,8 @@ public:
   void runOnOperation() override {
     FuncOp funcOp = getOperation();
     MLIRContext *context = &getContext();
-    OwningRewritePatternList patterns;
-    populateCoreATenToTCFPatterns(context, patterns);
+    RewritePatternSet patterns(context);
+    populateCoreATenToTCFPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
   }
 };

@@ -149,12 +149,11 @@ class ConvertATenConv2d : public OpRewritePattern<aten::Conv2dOp> {
 } // namespace

-void mlir::NPCOMP::populateCoreATenToTCFPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns) {
-  patterns.insert<ConvertATenAdd>(context);
-  patterns.insert<ConvertBinaryElementwise<aten::MulOp, tcf::MulOp>>(context);
-  patterns.insert<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(
-      context);
-  patterns.insert<ConvertBinaryElementwise<aten::MmOp, tcf::MatmulOp>>(context);
-  patterns.insert<ConvertATenConv2d>(context);
+void mlir::NPCOMP::populateCoreATenToTCFPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ConvertATenAdd>(context);
+  patterns.add<ConvertBinaryElementwise<aten::MulOp, tcf::MulOp>>(context);
+  patterns.add<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(context);
+  patterns.add<ConvertBinaryElementwise<aten::MmOp, tcf::MatmulOp>>(context);
+  patterns.add<ConvertATenConv2d>(context);
 }

@@ -28,8 +28,8 @@ public:
   FrozenRewritePatternList getPatterns() {
     auto *context = &getContext();
-    OwningRewritePatternList patterns;
-    populateBasicpyToStdPrimitiveOpPatterns(context, patterns);
+    RewritePatternSet patterns(context);
+    populateBasicpyToStdPrimitiveOpPatterns(patterns);
     return std::move(patterns);
   }
 };

@@ -242,8 +242,9 @@ public:
 } // namespace

 void mlir::NPCOMP::populateBasicpyToStdPrimitiveOpPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns) {
-  patterns.insert<NumericBinaryExpr>(context);
-  patterns.insert<NumericCompare>(context);
-  patterns.insert<NumericToI1>(context);
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<NumericBinaryExpr>(context);
+  patterns.add<NumericCompare>(context);
+  patterns.add<NumericToI1>(context);
 }

@@ -52,9 +52,9 @@ class ConvertNumpyToTCF : public ConvertNumpyToTCFBase<ConvertNumpyToTCF> {
     FuncOp func = getOperation();
     MLIRContext *context = &getContext();
-    OwningRewritePatternList patterns;
-    patterns.insert<ConvertBinaryBuiltinUfuncCallOp<tcf::AddOp>>(context,
-                                                                 "numpy.add");
+    RewritePatternSet patterns(context);
+    patterns.add<ConvertBinaryBuiltinUfuncCallOp<tcf::AddOp>>(context,
+                                                              "numpy.add");
     (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
   }
 };

@@ -15,5 +15,6 @@ add_npcomp_conversion_library(NPCOMPTCFToLinalg
   MLIRPass
   MLIRTransforms
   MLIRShape
+  MLIRMemRef
   NPCOMPTCFDialect
 )

@@ -14,6 +14,8 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Traits.h"
+// TODO: Remove when memref.dim is split into tensor.dim for the tensor case.
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "npcomp/Dialect/TCF/IR/TCFOps.h"
@@ -27,8 +29,8 @@ static SmallVector<Value, 6> bypassResultShapes(Operation *op,
                                                 OpBuilder &builder) {
   if (auto matmul = dyn_cast<tcf::MatmulOp>(op)) {
-    auto lhsRows = builder.create<DimOp>(op->getLoc(), matmul.lhs(), 0);
-    auto rhsCols = builder.create<DimOp>(op->getLoc(), matmul.rhs(), 1);
+    auto lhsRows = builder.create<memref::DimOp>(op->getLoc(), matmul.lhs(), 0);
+    auto rhsCols = builder.create<memref::DimOp>(op->getLoc(), matmul.rhs(), 1);
     auto shape = builder.create<tensor::FromElementsOp>(
         op->getLoc(), ValueRange({lhsRows, rhsCols}));
     return {shape};
@@ -49,12 +51,18 @@ static SmallVector<Value, 6> bypassResultShapes(Operation *op,
     auto dilationWidth = dilation;
     auto paddingHeight = padding;
     auto paddingWidth = padding;
-    auto batch = builder.create<DimOp>(op->getLoc(), conv2dNCHW.in(), 0);
-    auto height = builder.create<DimOp>(op->getLoc(), conv2dNCHW.in(), 2);
-    auto width = builder.create<DimOp>(op->getLoc(), conv2dNCHW.in(), 3);
-    auto filterOutChannels = builder.create<DimOp>(op->getLoc(), conv2dNCHW.filter(), 0);
-    auto filterHeight = builder.create<DimOp>(op->getLoc(), conv2dNCHW.filter(), 2);
-    auto filterWidth = builder.create<DimOp>(op->getLoc(), conv2dNCHW.filter(), 3);
+    auto batch =
+        builder.create<memref::DimOp>(op->getLoc(), conv2dNCHW.in(), 0);
+    auto height =
+        builder.create<memref::DimOp>(op->getLoc(), conv2dNCHW.in(), 2);
+    auto width =
+        builder.create<memref::DimOp>(op->getLoc(), conv2dNCHW.in(), 3);
+    auto filterOutChannels =
+        builder.create<memref::DimOp>(op->getLoc(), conv2dNCHW.filter(), 0);
+    auto filterHeight =
+        builder.create<memref::DimOp>(op->getLoc(), conv2dNCHW.filter(), 2);
+    auto filterWidth =
+        builder.create<memref::DimOp>(op->getLoc(), conv2dNCHW.filter(), 3);
     // Output height
     auto twicePaddingHeight = builder.create<MulIOp>(op->getLoc(), paddingHeight, cI2);
     auto heightPlusTwicePadding = builder.create<SubIOp>(op->getLoc(), height, twicePaddingHeight);
@@ -91,8 +99,8 @@ public:
   LogicalResult matchAndRewrite(tcf::MatmulOp op,
                                 PatternRewriter &rewriter) const override {
     // Create the constraints, and the assuming region.
-    Value lhsK = rewriter.create<DimOp>(op.getLoc(), op.lhs(), 1);
-    Value rhsK = rewriter.create<DimOp>(op.getLoc(), op.rhs(), 0);
+    Value lhsK = rewriter.create<memref::DimOp>(op.getLoc(), op.lhs(), 1);
+    Value rhsK = rewriter.create<memref::DimOp>(op.getLoc(), op.rhs(), 0);
     Value matchingK =
         rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq, lhsK, rhsK);
     Value witness = rewriter.create<shape::CstrRequireOp>(
@@ -130,12 +138,15 @@ public:
   LogicalResult matchAndRewrite(tcf::ConvNCHWOp op,
                                 PatternRewriter &rewriter) const override {
     // Create the constraints, and the assuming region.
-    Value inputCin = rewriter.create<DimOp>(op.getLoc(), op.in(), 1);
-    Value inputH = rewriter.create<DimOp>(op.getLoc(), op.in(), 2);
-    Value inputW = rewriter.create<DimOp>(op.getLoc(), op.in(), 3);
-    Value filterCin = rewriter.create<DimOp>(op.getLoc(), op.filter(), 1);
-    Value filterKH = rewriter.create<DimOp>(op.getLoc(), op.filter(), 2);
-    Value filterKW = rewriter.create<DimOp>(op.getLoc(), op.filter(), 3);
+    Value inputCin = rewriter.create<memref::DimOp>(op.getLoc(), op.in(), 1);
+    Value inputH = rewriter.create<memref::DimOp>(op.getLoc(), op.in(), 2);
+    Value inputW = rewriter.create<memref::DimOp>(op.getLoc(), op.in(), 3);
+    Value filterCin =
+        rewriter.create<memref::DimOp>(op.getLoc(), op.filter(), 1);
+    Value filterKH =
+        rewriter.create<memref::DimOp>(op.getLoc(), op.filter(), 2);
+    Value filterKW =
+        rewriter.create<memref::DimOp>(op.getLoc(), op.filter(), 3);
     Value matchingCin =
         rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq, inputCin, filterCin);
     Value validFilterH =
@@ -190,9 +201,9 @@ public:
   FrozenRewritePatternList getPatterns() {
     MLIRContext *context = &getContext();
-    OwningRewritePatternList patterns;
-    patterns.insert<ConvertMatmul>(context);
-    patterns.insert<ConvertConvNCHW>(context);
+    RewritePatternSet patterns(context);
+    patterns.add<ConvertMatmul>(context);
+    patterns.add<ConvertConvNCHW>(context);
     return std::move(patterns);
   }
 };

@@ -145,12 +145,12 @@ public:
   FrozenRewritePatternList getPatterns() {
     MLIRContext *context = &getContext();
-    OwningRewritePatternList patterns;
-    patterns.insert<ConvertUnaryElementwise<tcf::ExpOp>,
-                    ConvertUnaryElementwise<tcf::TanhOp>>(context);
-    patterns.insert<ConvertBinaryElementwise<tcf::AddOp>,
-                    ConvertBinaryElementwise<tcf::MaxOp>,
-                    ConvertBinaryElementwise<tcf::MulOp>>(context);
+    RewritePatternSet patterns(context);
+    patterns.add<ConvertUnaryElementwise<tcf::ExpOp>,
+                 ConvertUnaryElementwise<tcf::TanhOp>>(context);
+    patterns.add<ConvertBinaryElementwise<tcf::AddOp>,
+                 ConvertBinaryElementwise<tcf::MaxOp>,
+                 ConvertBinaryElementwise<tcf::MulOp>>(context);
     return std::move(patterns);
   }
 };

@@ -36,7 +36,7 @@ public:
     // NOTE: We are keeping this pass around, even though it currently does
     // nothing, in order to avoid having to reintroduce the same
     // boilerplate.
-    OwningRewritePatternList patterns;
+    RewritePatternSet patterns(getOperation().getContext());
     return std::move(patterns);
   }
 };

@@ -16,6 +16,8 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/EDSC/Builders.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/EDSC/Builders.h"
@@ -85,7 +87,8 @@ static Value memRefTypeCast(PatternRewriter &builder, Value val) {
   if (auto memrefTy = type.dyn_cast<MemRefType>()) {
     MemRefType newType = getShapeErasedMemRefType(memrefTy);
-    return builder.create<MemRefCastOp>(val.getLoc(), val, newType).getResult();
+    return builder.create<memref::CastOp>(val.getLoc(), val, newType)
+        .getResult();
   }
   if (auto tensorTy = type.dyn_cast<TensorType>()) {
     auto memRefType = mlir::MemRefType::get(tensorTy.getShape(),
@@ -224,7 +227,7 @@ public:
     MemRefType memRefResultTy = mlir::MemRefType::get(
         tensorResultTy.getShape(), tensorResultTy.getElementType(), {}, 0);
-    Value result = rewriter.create<AllocOp>(loc, memRefResultTy);
+    Value result = rewriter.create<memref::AllocOp>(loc, memRefResultTy);
     Value lhs = memRefTypeCast(rewriter, operands[0]);
     Value rhs = memRefTypeCast(rewriter, operands[1]);
     using namespace edsc;
@@ -232,7 +235,7 @@ public:
     ScopedContext scope(rewriter, loc);
     Value zero = intrinsics::std_constant_index(0);
     MemRefBoundsCapture vRes(result), vLHS(lhs), vRHS(rhs);
-    StdIndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
+    MemRefIndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
     Value M(vRes.ub(0));
     if (vRes.rank() == 1) {
       affineLoopNestBuilder({zero}, {M}, 1, [&](ValueRange ivs) {
@@ -320,7 +323,8 @@ LogicalResult rewriteWithVoidFunctionCallExplicit(
       // assume memRefResultTy has known shape, so we don't need any
       // dynamic dimensions for the alloc.
       assert(memRefResultTy.hasStaticShape());
-      Value allocVal = rewriter.create<AllocOp>(op->getLoc(), memRefResultTy);
+      Value allocVal =
+          rewriter.create<memref::AllocOp>(op->getLoc(), memRefResultTy);
       Value castVal = memRefTypeCast(rewriter, allocVal);
       newOps.push_back(castVal);
       newResults.push_back(allocVal);
@@ -867,9 +871,9 @@ struct ATenLoweringPass : public ATenLoweringBase<ATenLoweringPass> {
           return type;
         });

-    OwningRewritePatternList acapPatterns;
     auto module = getOperation();
     auto context = module.getContext();
+    RewritePatternSet acapPatterns(context);

    // c++ patterns
    acapPatterns.insert<
@@ -885,16 +889,15 @@ struct ATenLoweringPass : public ATenLoweringBase<ATenLoweringPass> {
         NllLoss2dBackwardOpConversion, LogSoftmaxOpConversion,
         LogSoftmaxBackwardDataOpConversion, DivOpConversion>(context);
-    mlir::populateFuncOpTypeConversionPattern(acapPatterns, context,
-                                              typeConverter);
+    mlir::populateFuncOpTypeConversionPattern(acapPatterns, typeConverter);

     // tablegen patterns
-    populateATenToStdPatterns(context, acapPatterns);
+    populateATenToStdPatterns(acapPatterns);

     // Perform acap specific lowering.
     ConversionTarget target(getContext());
     target.addLegalDialect<LLVM::LLVMDialect, StandardOpsDialect,
-                           scf::SCFDialect>();
+                           scf::SCFDialect, memref::MemRefDialect>();
     target.addLegalOp<AffineForOp, AffineApplyOp, AffineYieldOp>();
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
       return typeConverter.isSignatureLegal(op.getType());
@@ -20,8 +20,7 @@ namespace {
 } // namespace

 namespace mlir {
-void populateATenToStdPatterns(MLIRContext *context,
-                               OwningRewritePatternList &patterns) {
-  populateWithGenerated(context, patterns);
+void populateATenToStdPatterns(RewritePatternSet &patterns) {
+  populateWithGenerated(patterns);
 }
 } // namespace mlir

@@ -467,12 +467,12 @@ class ATenRecognizeKernelsPass
   }

   void runOnOperation() override {
-    auto &context = getContext();
-    KernelCallTransformer transformer(context);
+    MLIRContext *context = &getContext();
+    KernelCallTransformer transformer(*context);
     transformer.addDialectOps<ATenDialect>();
-    OwningRewritePatternList patterns;
-    patterns.insert<RecognizeOpPattern>(&context, transformer);
+    RewritePatternSet patterns(context);
+    patterns.add<RecognizeOpPattern>(context, transformer);
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
       signalPassFailure();

@@ -15,6 +15,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"

+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Pass/Pass.h"
@@ -73,7 +74,7 @@ public:
       if (!v.getType().isa<MemRefType>())
         llvm_unreachable("function returns non-memref");
       if (!valueMap.count(v)) {
-        valueMap[v] = builder->create<AllocOp>(
+        valueMap[v] = builder->create<memref::AllocOp>(
             op->getLoc(), v.getType().cast<MemRefType>());
       }
       v.replaceAllUsesWith(valueMap[v]);
@@ -86,7 +87,7 @@ public:
         auto fn = module.lookupSymbol<FuncOp>(callOp.callee());
         if (fn && fn.use_empty())
           erasedOps.insert(fn);
-      } else if (isa<AllocOp>(op)) {
+      } else if (isa<memref::AllocOp>(op)) {
         Value v = op->getResult(0);
         if (valueMap.count(v)) {
           v.replaceAllUsesWith(valueMap[v]);

@@ -400,9 +400,9 @@ public:
 } // namespace

-void UnknownCastOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<ElideIdentityUnknownCast>(context);
+void UnknownCastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *context) {
+  patterns.add<ElideIdentityUnknownCast>(context);
 }

 #define GET_OP_CLASSES

@@ -81,9 +81,9 @@ public:
 };
 } // namespace

-void CopyToTensorOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<ElideCreateRedundantArrayFromTensor>(context);
+void CopyToTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  patterns.add<ElideCreateRedundantArrayFromTensor>(context);
 }

 #define GET_OP_CLASSES

@@ -9,6 +9,7 @@
 #include "PassDetail.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -49,7 +50,7 @@ static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
           builder.create<tensor::ExtractOp>(op.getLoc(), pad.upperExpansion(),
                                             ValueRange({dimIndex}));
       auto operandDim =
-          builder.create<DimOp>(op.getLoc(), pad.operand(), i);
+          builder.create<memref::DimOp>(op.getLoc(), pad.operand(), i);
       auto totalExpansion =
           builder.create<AddIOp>(op.getLoc(), lowerExpansion, upperExpansion);
       auto outDim =
@@ -117,7 +118,8 @@ public:
     SmallVector<Value, 6> inputDimRequiresBroadcasting;
     for (int i = 0, e = inputType.getRank(); i < e; i++) {
       // Calculate the relevant extents.
-      Value inputExtent = rewriter.create<DimOp>(op.getLoc(), op.operand(), i);
+      Value inputExtent =
+          rewriter.create<memref::DimOp>(op.getLoc(), op.operand(), i);
       inputDimRequiresBroadcasting.push_back(
           rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::ne, inputExtent,
                                   outputExtents[rankDiff + i]));
@@ -152,10 +154,10 @@ public:
                                        inductionVariables[rankDiff + i]);
         inputIndices.push_back(select);
       }
-      Value load =
-          rewriter.create<LoadOp>(op.getLoc(), inputMemref, inputIndices);
-      rewriter.create<StoreOp>(op.getLoc(), load, resultMemref,
-                               inductionVariables);
+      Value load = rewriter.create<memref::LoadOp>(op.getLoc(), inputMemref,
+                                                   inputIndices);
+      rewriter.create<memref::StoreOp>(op.getLoc(), load, resultMemref,
+                                       inductionVariables);
     }
     rewriter.replaceOp(op, resultMemref);
     return success();
@@ -202,16 +204,16 @@ public:
       auto offset =
           rewriter.create<tensor::ExtractOp>(op.getLoc(), op.lowerExpansion(),
                                              ValueRange({dimIndex}));
-      auto size = rewriter.create<DimOp>(op.getLoc(), op.operand(), i);
+      auto size = rewriter.create<memref::DimOp>(op.getLoc(), op.operand(), i);
       auto stride = c1;
       offsets.push_back(offset);
       sizes.push_back(size);
       strides.push_back(stride);
     }
     rewriter.create<linalg::FillOp>(op.getLoc(), results[0], op.fillVal());
-    auto unpadded =
-        rewriter.create<SubViewOp>(op.getLoc(), results[0], ValueRange(offsets),
-                                   ValueRange(sizes), ValueRange(strides));
+    auto unpadded = rewriter.create<memref::SubViewOp>(
+        op.getLoc(), results[0], ValueRange(offsets), ValueRange(sizes),
+        ValueRange(strides));
     auto inputMemref = operands[0];
     rewriter.create<linalg::CopyOp>(op.getLoc(), inputMemref, unpadded);
     rewriter.replaceOp(op, results);
@@ -234,7 +236,7 @@ class TCPBufferizePass : public TCPBufferizeBase<TCPBufferizePass> {
     BufferizeTypeConverter typeConverter;

-    OwningRewritePatternList patterns;
+    RewritePatternSet patterns(context);

     ConversionTarget target(*context);
@@ -243,17 +245,18 @@ class TCPBufferizePass : public TCPBufferizeBase<TCPBufferizePass> {
     // we can just open-code the extents for the alloc.
     target.addLegalOp<refback::AllocMemRefOp>();

-    patterns.insert<LowerBroadcastToToLoopsPattern>(typeConverter, context);
+    patterns.add<LowerBroadcastToToLoopsPattern>(typeConverter, context);
     target.addIllegalOp<tcp::BroadcastToOp>();
-    patterns.insert<BufferizeSplattedOp>(typeConverter, context);
+    patterns.add<BufferizeSplattedOp>(typeConverter, context);
     target.addIllegalOp<tcp::SplattedOp>();
-    patterns.insert<BufferizePadOp>(typeConverter, context);
+    patterns.add<BufferizePadOp>(typeConverter, context);
     target.addIllegalOp<tcp::PadOp>();

     target.addLegalDialect<linalg::LinalgDialect>();
     target.addLegalDialect<StandardOpsDialect>();
     target.addLegalDialect<scf::SCFDialect>();
     target.addLegalDialect<tensor::TensorDialect>();
+    target.addLegalDialect<memref::MemRefDialect>();

     if (failed(applyPartialConversion(func, target, std::move(patterns))))
       return signalPassFailure();

@@ -58,7 +58,8 @@ Type TorchDialect::parseType(DialectAsmParser &parser) const {
   StringRef keyword;
   if (parser.parseKeyword(&keyword))
     return Type();
-  if (Type type = generatedTypeParser(getContext(), parser, keyword))
+  Type type;
+  if (generatedTypeParser(getContext(), parser, keyword, type).hasValue())
     return type;

   parser.emitError(parser.getNameLoc(), "invalid 'torch' type: `")

@@ -73,10 +73,10 @@ class PrepareForGlobalizeObjectGraphPass
     SymbolTable symbolTable(getOperation());
     MLIRContext *context = &getContext();

-    OwningRewritePatternList patterns;
-    patterns.insert<ConvertPrimCallMethodToCall>(context, symbolTable);
+    RewritePatternSet patterns(context);
+    patterns.add<ConvertPrimCallMethodToCall>(context, symbolTable);
     CallIndirectOp::getCanonicalizationPatterns(patterns, context);
-    patterns.insert<EraseUnusedConstantOp>(context);
+    patterns.add<EraseUnusedConstantOp>(context);

     // Use applyPatternsAndFoldGreedily because the CallIndirectOp folding
     // makes the ConstantOp unused, which does not work with the visitation
@@ -99,7 +99,7 @@ class PrepareForGlobalizeObjectGraphPass
     target.addIllegalOp<CallIndirectOp>();
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

-    OwningRewritePatternList dummyPatterns;
+    RewritePatternSet dummyPatterns(context);

     if (failed(applyFullConversion(getOperation(), target,
                                    std::move(dummyPatterns)))) {

@@ -200,7 +200,7 @@ static LLVMFuncOp createCompilerRuntimeFuncDecl(StringRef name, Type type,
 }

 static void populateCompilerRuntimePatterns(ModuleOp module,
-                                            OwningRewritePatternList &patterns,
+                                            RewritePatternSet &patterns,
                                             LLVMTypeConverter &typeConverter) {
   auto *context = module.getContext();
   OpBuilder builder(module.getBodyRegion());
@@ -212,7 +212,7 @@ static void populateCompilerRuntimePatterns(ModuleOp module,
                            /*isVarArg=*/false);
     LLVMFuncOp abortIfFunc = createCompilerRuntimeFuncDecl(
         "abort_if", abortIfFuncTy, builder, module.getLoc());
-    patterns.insert<AbortIfOpCompilerRuntimeLowering>(abortIfFunc);
+    patterns.add<AbortIfOpCompilerRuntimeLowering>(abortIfFunc);
   }
 }
@@ -701,16 +701,16 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
     LLVMTypeConverter converter(context);

-    OwningRewritePatternList patterns;
+    RewritePatternSet patterns(context);
     LLVMConversionTarget target(*context);
     populateCompilerRuntimePatterns(module, patterns, converter);
     target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
     populateStdToLLVMConversionPatterns(converter, patterns);
-    patterns.insert<LowerModuleMetadata>(context);
+    patterns.add<LowerModuleMetadata>(context);

     // TODO: Move these "std to std" legalizations to their own pass if we grow
     // lots of these patterns.
-    populateExpandTanhPattern(patterns, context);
+    populateExpandTanhPattern(patterns);

     if (failed(applyFullConversion(module, target, std::move(patterns)))) {
       return signalPassFailure();

@@ -9,6 +9,7 @@
 #include "PassDetail.h"
 #include "npcomp/RefBackend/RefBackend.h"

+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Verifier.h"
@@ -353,8 +354,8 @@ public:
       for (auto newAndOldArg :
            llvm::zip(newEntry.getArguments(), oldEntry.getArguments())) {
         std::tie(newArg, oldArg) = newAndOldArg;
-        auto memref = rewriter.create<MemRefCastOp>(op.getLoc(), newArg,
-                                                    oldArg.getType());
+        auto memref = rewriter.create<memref::CastOp>(op.getLoc(), newArg,
+                                                      oldArg.getType());
         rewriter.replaceUsesOfBlockArgument(oldArg, memref);
       }
     });
@@ -390,23 +391,24 @@ static LogicalResult doDialectConversion(ModuleOp module) {
       [](OpBuilder &builder, UnrankedMemRefType type, ValueRange inputs,
          Location loc) -> Value {
        assert(inputs.size() == 1);
-        return builder.create<MemRefCastOp>(
+        return builder.create<memref::CastOp>(
            loc, inputs[0], getABIMemrefType(inputs[0].getType()));
      });

-  OwningRewritePatternList patterns;
+  RewritePatternSet patterns(context);
   ConversionTarget target(*context);
   target.addLegalDialect<refbackrt::RefbackrtDialect>();
   target.addLegalDialect<StandardOpsDialect>();
+  target.addLegalDialect<memref::MemRefDialect>();

-  patterns.insert<FuncOpSignatureConversion>(typeConverter, context);
+  patterns.add<FuncOpSignatureConversion>(typeConverter, context);
   target.addDynamicallyLegalOp<FuncOp>(
       [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
-  patterns.insert<RewriteReturnOp>(typeConverter, context);
+  patterns.add<RewriteReturnOp>(typeConverter, context);
   target.addDynamicallyLegalOp<ReturnOp>(
       [&](ReturnOp op) { return typeConverter.isLegal(op); });
-  patterns.insert<LowerAssertOp>(context);
+  patterns.add<LowerAssertOp>(context);
   target.addIllegalOp<AssertOp>();

   return applyPartialConversion(module, target, std::move(patterns));

@@ -32,6 +32,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
@@ -105,7 +106,8 @@ public:
         dynamicExtents.push_back(extent);
       }
     }
-    rewriter.replaceOpWithNewOp<AllocOp>(op, memrefType, dynamicExtents);
+    rewriter.replaceOpWithNewOp<memref::AllocOp>(op, memrefType,
+                                                 dynamicExtents);
     return success();
   }
 };
@@ -118,12 +120,12 @@ class LowerAllocMemRefOps
   void runOnOperation() override {
     auto func = getOperation();
     auto *context = &getContext();
-    OwningRewritePatternList patterns;
-    patterns.insert<LowerAllocMemRefOp>(context);
+    RewritePatternSet patterns(context);
+    patterns.add<LowerAllocMemRefOp>(context);
     ConversionTarget target(*context);
     target.addIllegalOp<refback::AllocMemRefOp>();
     target.addLegalOp<tensor::ExtractOp>();
-    target.addLegalOp<AllocOp>();
+    target.addLegalOp<memref::AllocOp>();
     target.addLegalOp<ConstantOp>();
     if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
       return signalPassFailure();
@@ -173,7 +175,7 @@ struct RestrictedCanonicalizer
     }

     // Collect all canonicalization patterns from ops in the included dialects.
-    OwningRewritePatternList patterns;
+    RewritePatternSet patterns(context);
     for (AbstractOperation *op : context->getRegisteredOperations())
       if (dialectsToCanonicalize.count(&op->dialect))
         op->getCanonicalizationPatterns(patterns, context);
@@ -235,7 +237,7 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
   // rather than a single mega dialect conversion pass.
   //
   // This means that intermediate steps have source/target materializations
-  // (tensor_load / tensor_to_memref) in the IR.
+  // (memref.tensor_load / memref.buffer_cast) in the IR.

   // Run tensor constant bufferization.
   // This pass has to run on a module op, and so does the final

@@ -6,13 +6,13 @@
 // CHECK: %[[C0F32:.*]] = constant 0.000000e+00 : f32
 // CHECK: %[[C0:.*]] = constant 0 : index
 // CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[LHSK:.*]] = dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[RHSK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[LHSK:.*]] = memref.dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[RHSK:.*]] = memref.dim %[[RHS]], %[[C0]] : tensor<?x?xf32>
 // CHECK: %[[KEQUAL:.*]] = cmpi eq, %[[LHSK]], %[[RHSK]] : index
 // CHECK: %[[WINESS:.*]] = shape.cstr_require %[[KEQUAL]], "mismatching contracting dimension for matmul"
 // CHECK: %[[RET:.*]] = shape.assuming %[[WINESS]] -> (tensor<?x?xf32>) {
-// CHECK: %[[LHSROWS:.*]] = dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[RHSCOLS:.*]] = dim %[[RHS]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[LHSROWS:.*]] = memref.dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[RHSCOLS:.*]] = memref.dim %[[RHS]], %[[C1]] : tensor<?x?xf32>
 // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[LHSROWS]], %[[RHSCOLS]] : tensor<2xindex>
 // CHECK: %[[INIT_TENSOR:.*]] = tcp.splatted %[[C0F32]], %[[SHAPE]] : (f32, tensor<2xindex>) -> tensor<?x?xf32>
 // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[INIT_TENSOR]] : tensor<?x?xf32>) -> tensor<?x?xf32>
@@ -32,12 +32,12 @@ func @tcf_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[C0:.*]] = constant 0 : index
 // CHECK: %[[C2:.*]] = constant 2 : index
 // CHECK: %[[C3:.*]] = constant 3 : index
-// CHECK: %[[CHANNELS:.*]] = dim %[[IN]], %[[C1]] : tensor<?x?x?x?xf32>
-// CHECK: %[[HEIGHT:.*]] = dim %[[IN]], %[[C2]] : tensor<?x?x?x?xf32>
-// CHECK: %[[WIDTH:.*]] = dim %[[IN]], %[[C3]] : tensor<?x?x?x?xf32>
-// CHECK: %[[FILTERCHANNELS:.*]] = dim %[[FILTER]], %[[C1]] : tensor<?x?x?x?xf32>
-// CHECK: %[[FILTERHEIGHT:.*]] = dim %[[FILTER]], %[[C2]] : tensor<?x?x?x?xf32>
-// CHECK: %[[FILTERWIDTH:.*]] = dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK: %[[CHANNELS:.*]] = memref.dim %[[IN]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK: %[[HEIGHT:.*]] = memref.dim %[[IN]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[WIDTH:.*]] = memref.dim %[[IN]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK: %[[FILTERCHANNELS:.*]] = memref.dim %[[FILTER]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK: %[[FILTERHEIGHT:.*]] = memref.dim %[[FILTER]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[FILTERWIDTH:.*]] = memref.dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
 // CHECK: %[[CMPCHANNELS:.*]] = cmpi eq, %[[CHANNELS]], %[[FILTERCHANNELS]] : index
 // CHECK: %[[CMPHEIGHT:.*]] = cmpi uge, %[[HEIGHT]], %[[FILTERHEIGHT]] : index
 // CHECK: %[[CMPWIDTH:.*]] = cmpi uge, %[[WIDTH]], %[[FILTERWIDTH]] : index
@@ -46,12 +46,12 @@ func @tcf_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[CSTRWIDTH:.*]] = shape.cstr_require %[[CMPWIDTH]], "input width must be greater than or equal to filter KW-dimension"
 // CHECK: %[[WITNESS:.*]] = shape.assuming_all %[[CSTRCHANNELS]], %[[CSTRHEIGHT]], %[[CSTRWIDTH]]
 // CHECK: %[[RET:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?x?x?x?xf32>) {
-// CHECK: %[[BATCH:.*]] = dim %[[IN]], %[[C0]] : tensor<?x?x?x?xf32>
-// CHECK: %[[HEIGHT:.*]] = dim %[[IN]], %[[C2]] : tensor<?x?x?x?xf32>
-// CHECK: %[[WIDTH:.*]] = dim %[[IN]], %[[C3]] : tensor<?x?x?x?xf32>
-// CHECK: %[[OUTCHANNELS:.*]] = dim %[[FILTER]], %[[C0]] : tensor<?x?x?x?xf32>
-// CHECK: %[[FILTERHEIGHT:.*]] = dim %[[FILTER]], %[[C2]] : tensor<?x?x?x?xf32>
-// CHECK: %[[FILTERWIDTH:.*]] = dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK: %[[BATCH:.*]] = memref.dim %[[IN]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK: %[[HEIGHT:.*]] = memref.dim %[[IN]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[WIDTH:.*]] = memref.dim %[[IN]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK: %[[OUTCHANNELS:.*]] = memref.dim %[[FILTER]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK: %[[FILTERHEIGHT:.*]] = memref.dim %[[FILTER]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[FILTERWIDTH:.*]] = memref.dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
 // CHECK: %[[FILTERHEIGHTM1:.*]] = subi %[[FILTERHEIGHT]], %[[C1]] : index
 // CHECK: %[[HEIGHTV0:.*]] = subi %[[HEIGHT]], %[[FILTERHEIGHTM1]] : index
 // CHECK: %[[HEIGHTV0M1:.*]] = subi %[[HEIGHTV0]], %[[C1]] : index

@@ -19,7 +19,7 @@ func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) -> tensor<?
 // CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>) -> tensor<?x?xf32> {
 // CHECK: %[[RESULT:.*]] = refback.alloc_memref %[[SHAPE]] : memref<?x?xf32>
 // CHECK: linalg.fill(%[[RESULT]], %[[SPLAT_VAL]]) : memref<?x?xf32>, f32
-// CHECK: %[[RESULT_TENSOR:.*]] = tensor_load %[[RESULT]] : memref<?x?xf32>
+// CHECK: %[[RESULT_TENSOR:.*]] = memref.tensor_load %[[RESULT]] : memref<?x?xf32>
 // CHECK: return %[[RESULT_TENSOR]] : tensor<?x?xf32>
 func @tcp_splatted(%arg0: f32, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
   %0 = tcp.splatted %arg0, %arg1 : (f32, tensor<?xindex>) -> tensor<?x?xf32>
@@ -31,14 +31,14 @@ func @tcp_splatted(%arg0: f32, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
 // CHECK-SAME: %[[LOWER_EXPANSION:[a-zA-Z0-9]+]]: tensor<?xindex>,
 // CHECK-SAME: %[[UPPER_EXPANSION:[a-zA-Z0-9]+]]: tensor<?xindex>,
 // CHECK-SAME: %[[FILL_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<?xf32> {
-// CHECK: %[[TENSOR_MREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32>
-// CHECK: %[[LOWER_EXPANSION_MREF:.*]] = tensor_to_memref %[[LOWER_EXPANSION]] : memref<?xindex>
-// CHECK: %[[UPPER_EXPANSION_MREF:.*]] = tensor_to_memref %[[UPPER_EXPANSION]] : memref<?xindex>
+// CHECK: %[[TENSOR_MREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<?xf32>
+// CHECK: %[[LOWER_EXPANSION_MREF:.*]] = memref.buffer_cast %[[LOWER_EXPANSION]] : memref<?xindex>
+// CHECK: %[[UPPER_EXPANSION_MREF:.*]] = memref.buffer_cast %[[UPPER_EXPANSION]] : memref<?xindex>
 // CHECK: %[[C0:.*]] = constant 0 : index
 // CHECK: %[[LOWER_EXTENT_D1:.*]] = tensor.extract %[[LOWER_EXPANSION]][%[[C0]]] : tensor<?xindex>
 // CHECK: %[[UPPER_EXTENT_D1:.*]] = tensor.extract %[[UPPER_EXPANSION]][%[[C0]]] : tensor<?xindex>
 // CHECK: %[[C0_0:.*]] = constant 0 : index
-// CHECK: %[[D1:.*]] = dim %[[TENSOR]], %[[C0_0]] : tensor<?xf32>
+// CHECK: %[[D1:.*]] = memref.dim %[[TENSOR]], %[[C0_0]] : tensor<?xf32>
 // CHECK: %[[D1_EXPANSION:.*]] = addi %[[LOWER_EXTENT_D1]], %[[UPPER_EXTENT_D1]] : index
 // CHECK: %[[D1_OUT:.*]] = addi %[[D1_EXPANSION]], %[[D1]] : index
 // CHECK: %[[D1_OUT_TENSOR:.*]] = tensor.from_elements %[[D1_OUT]] : tensor<1xindex>
@@ -47,11 +47,11 @@ func @tcp_splatted(%arg0: f32, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
 // CHECK: %[[C0_1:.*]] = constant 0 : index
 // CHECK: %[[LOWER_EXTENT_D1_1:.*]] = tensor.extract %[[LOWER_EXPANSION]][%[[C0_1]]] : tensor<?xindex>
 // CHECK: %[[C0_2:.*]] = constant 0 : index
-// CHECK: %[[D1_1:.*]] = dim %[[TENSOR]], %[[C0_2]] : tensor<?xf32>
+// CHECK: %[[D1_1:.*]] = memref.dim %[[TENSOR]], %[[C0_2]] : tensor<?xf32>
 // CHECK: linalg.fill(%[[D1_OUT_MREF]], %[[FILL_VAL]]) : memref<?xf32>, f32
-// CHECK: %[[SUBVIEW:.*]] = subview %[[D1_OUT_MREF]][%[[LOWER_EXTENT_D1_1]]] [%[[D1_1]]] [%[[C1]]] : memref<?xf32> to memref<?xf32, #map>
+// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[D1_OUT_MREF]][%[[LOWER_EXTENT_D1_1]]] [%[[D1_1]]] [%[[C1]]] : memref<?xf32> to memref<?xf32, #map>
 // CHECK: linalg.copy(%0, %[[SUBVIEW]]) : memref<?xf32>, memref<?xf32, #map>
-// CHECK: %[[RESULT_TENSOR:.*]] = tensor_load %[[D1_OUT_MREF]] : memref<?xf32>
+// CHECK: %[[RESULT_TENSOR:.*]] = memref.tensor_load %[[D1_OUT_MREF]] : memref<?xf32>
 // CHECK: return %[[RESULT_TENSOR]] : tensor<?xf32>
 func @tcp_pad(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>, %arg2: tensor<?xindex>, %arg3: f32) -> tensor<?xf32> {
   %0 = tcp.pad %arg0, %arg1, %arg2, %arg3 : (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>, f32) -> tensor<?xf32>

@@ -37,11 +37,11 @@ func @identity(%arg0: memref<?xf32>) -> memref<?xf32> {

 // CHECK-LABEL: func @use_of_arg(%arg0: memref<*xf32>)
 func @use_of_arg(%arg0: memref<?xf32>) {
-  // CHECK-NEXT: %[[MEMREF:.*]] = memref_cast %arg0 : memref<*xf32> to memref<?xf32>
+  // CHECK-NEXT: %[[MEMREF:.*]] = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
   %c0 = constant 0 : index
-  %0 = dim %arg0, %c0 : memref<?xf32>
+  %0 = memref.dim %arg0, %c0 : memref<?xf32>
   // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
-  // CHECK-NEXT: dim %[[MEMREF]], %[[C0]] : memref<?xf32>
+  // CHECK-NEXT: memref.dim %[[MEMREF]], %[[C0]] : memref<?xf32>
   return
 }
@@ -49,12 +49,12 @@ func @use_of_arg(%arg0: memref<?xf32>) {

 // CHECK-LABEL: func @multiple_blocks(%arg0: memref<*xf32>) -> memref<*xf32>
 func @multiple_blocks(%arg0: memref<?xf32>) -> memref<?xf32> {
-  // CHECK-NEXT: %[[INMEMREF:.*]] = memref_cast %arg0 : memref<*xf32> to memref<?xf32>
+  // CHECK-NEXT: %[[INMEMREF:.*]] = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
   // CHECK-NEXT: br ^bb1(%[[INMEMREF]] : memref<?xf32>)
   br ^bb1(%arg0: memref<?xf32>)
 // CHECK-NEXT: ^bb1(%[[BBARG:.*]]: memref<?xf32>):
 ^bb1(%bbarg: memref<?xf32>):
-  // CHECK-NEXT: %[[OUTMEMREF:.*]] = memref_cast %[[BBARG]] : memref<?xf32> to memref<*xf32>
+  // CHECK-NEXT: %[[OUTMEMREF:.*]] = memref.cast %[[BBARG]] : memref<?xf32> to memref<*xf32>
   // CHECK-NEXT: return %[[OUTMEMREF]] : memref<*xf32>
   return %bbarg : memref<?xf32>
 }

@@ -17,7 +17,6 @@
 #include "mlir/InitAllPasses.h"
 #include "mlir/Parser.h"
 #include "mlir/Pass/PassManager.h"
-#include "mlir/Target/LLVMIR.h"
 #include "npcomp-c/InitLLVM.h"
 #include "npcomp/InitAll.h"
 #include "npcomp/RefBackend/JITHelpers/JITModule.h"