mirror of https://github.com/llvm/torch-mlir
Bump llvm-project to 0524a09cc7e1a0797982feacf505825231efbee7
- renames of OwningRewritePatternList -> RewritePatternSet
- also `insert` to `add`
- RewritePatternSet holds a context now
- memref dialect split from std
pull/197/head
parent
4591884d06
commit
99178a167d
|
@ -1 +1 @@
|
|||
Subproject commit e31c77b1827fa4dd3511f21af11cfab18ecf6d38
|
||||
Subproject commit 0524a09cc7e1a0797982feacf505825231efbee7
|
|
@ -33,8 +33,8 @@ with mb.capture_function("conv2d_fwd", [tensor]) as f:
|
|||
# NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
|
||||
# CHECK-LABEL: func @conv2d_fwd(
|
||||
# CHECK-SAME: %[[VAL_0:.*]]: !numpy.ndarray<[3,16,10,10]:f32>) -> !numpy.ndarray<[3,4,8,8]:f32> {
|
||||
# CHECK: %[[VAL_1:.*]] = constant opaque<"", "0xDEADBEEF"> : tensor<4x16x3x3xf32>
|
||||
# CHECK: %[[VAL_2:.*]] = constant opaque<"", "0xDEADBEEF"> : tensor<4xf32>
|
||||
# CHECK: %[[VAL_1:.*]] = constant opaque<"_", "0xDEADBEEF"> : tensor<4x16x3x3xf32>
|
||||
# CHECK: %[[VAL_2:.*]] = constant opaque<"_", "0xDEADBEEF"> : tensor<4xf32>
|
||||
# CHECK: %[[VAL_3:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_4:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_5:.*]] = constant 0 : i64
|
||||
|
|
|
@ -14,15 +14,14 @@
|
|||
namespace mlir {
|
||||
|
||||
class MLIRContext;
|
||||
class OwningRewritePatternList;
|
||||
class RewritePatternSet;
|
||||
|
||||
namespace NPCOMP {
|
||||
|
||||
/// Populates patterns for converting core ATen ops to TCF. These patterns
|
||||
/// cover core arithmetic ops that are on the order of 1:1 representationally.
|
||||
/// More advanced patterns are managed elsewhere.
|
||||
void populateCoreATenToTCFPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList &patterns);
|
||||
void populateCoreATenToTCFPatterns(RewritePatternSet &patterns);
|
||||
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
|
|
@ -19,8 +19,7 @@ namespace NPCOMP {
|
|||
// Conversion patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void populateBasicpyToStdPrimitiveOpPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList &patterns);
|
||||
void populateBasicpyToStdPrimitiveOpPatterns(RewritePatternSet &patterns);
|
||||
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
|
|
@ -23,8 +23,7 @@ class ATenDialect;
|
|||
|
||||
namespace mlir {
|
||||
|
||||
void populateATenToStdPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList &patterns);
|
||||
void populateATenToStdPatterns(RewritePatternSet &patterns);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -27,8 +27,8 @@ public:
|
|||
void runOnOperation() override {
|
||||
FuncOp funcOp = getOperation();
|
||||
MLIRContext *context = &getContext();
|
||||
OwningRewritePatternList patterns;
|
||||
populateCoreATenToTCFPatterns(context, patterns);
|
||||
RewritePatternSet patterns(context);
|
||||
populateCoreATenToTCFPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
|
|
@ -149,12 +149,11 @@ class ConvertATenConv2d : public OpRewritePattern<aten::Conv2dOp> {
|
|||
|
||||
} // namespace
|
||||
|
||||
void mlir::NPCOMP::populateCoreATenToTCFPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList &patterns) {
|
||||
patterns.insert<ConvertATenAdd>(context);
|
||||
patterns.insert<ConvertBinaryElementwise<aten::MulOp, tcf::MulOp>>(context);
|
||||
patterns.insert<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(
|
||||
context);
|
||||
patterns.insert<ConvertBinaryElementwise<aten::MmOp, tcf::MatmulOp>>(context);
|
||||
patterns.insert<ConvertATenConv2d>(context);
|
||||
void mlir::NPCOMP::populateCoreATenToTCFPatterns(RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<ConvertATenAdd>(context);
|
||||
patterns.add<ConvertBinaryElementwise<aten::MulOp, tcf::MulOp>>(context);
|
||||
patterns.add<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(context);
|
||||
patterns.add<ConvertBinaryElementwise<aten::MmOp, tcf::MatmulOp>>(context);
|
||||
patterns.add<ConvertATenConv2d>(context);
|
||||
}
|
||||
|
|
|
@ -28,8 +28,8 @@ public:
|
|||
|
||||
FrozenRewritePatternList getPatterns() {
|
||||
auto *context = &getContext();
|
||||
OwningRewritePatternList patterns;
|
||||
populateBasicpyToStdPrimitiveOpPatterns(context, patterns);
|
||||
RewritePatternSet patterns(context);
|
||||
populateBasicpyToStdPrimitiveOpPatterns(patterns);
|
||||
return std::move(patterns);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -242,8 +242,9 @@ public:
|
|||
} // namespace
|
||||
|
||||
void mlir::NPCOMP::populateBasicpyToStdPrimitiveOpPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList &patterns) {
|
||||
patterns.insert<NumericBinaryExpr>(context);
|
||||
patterns.insert<NumericCompare>(context);
|
||||
patterns.insert<NumericToI1>(context);
|
||||
RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<NumericBinaryExpr>(context);
|
||||
patterns.add<NumericCompare>(context);
|
||||
patterns.add<NumericToI1>(context);
|
||||
}
|
||||
|
|
|
@ -52,9 +52,9 @@ class ConvertNumpyToTCF : public ConvertNumpyToTCFBase<ConvertNumpyToTCF> {
|
|||
FuncOp func = getOperation();
|
||||
MLIRContext *context = &getContext();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<ConvertBinaryBuiltinUfuncCallOp<tcf::AddOp>>(context,
|
||||
"numpy.add");
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.add<ConvertBinaryBuiltinUfuncCallOp<tcf::AddOp>>(context,
|
||||
"numpy.add");
|
||||
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
|
|
@ -15,5 +15,6 @@ add_npcomp_conversion_library(NPCOMPTCFToLinalg
|
|||
MLIRPass
|
||||
MLIRTransforms
|
||||
MLIRShape
|
||||
MLIRMemRef
|
||||
NPCOMPTCFDialect
|
||||
)
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
// TODO: Remove when memref.dim is split into tensor.dim for the tensor case.
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "npcomp/Dialect/TCF/IR/TCFOps.h"
|
||||
|
@ -27,8 +29,8 @@ static SmallVector<Value, 6> bypassResultShapes(Operation *op,
|
|||
OpBuilder &builder) {
|
||||
|
||||
if (auto matmul = dyn_cast<tcf::MatmulOp>(op)) {
|
||||
auto lhsRows = builder.create<DimOp>(op->getLoc(), matmul.lhs(), 0);
|
||||
auto rhsCols = builder.create<DimOp>(op->getLoc(), matmul.rhs(), 1);
|
||||
auto lhsRows = builder.create<memref::DimOp>(op->getLoc(), matmul.lhs(), 0);
|
||||
auto rhsCols = builder.create<memref::DimOp>(op->getLoc(), matmul.rhs(), 1);
|
||||
auto shape = builder.create<tensor::FromElementsOp>(
|
||||
op->getLoc(), ValueRange({lhsRows, rhsCols}));
|
||||
return {shape};
|
||||
|
@ -49,12 +51,18 @@ static SmallVector<Value, 6> bypassResultShapes(Operation *op,
|
|||
auto dilationWidth = dilation;
|
||||
auto paddingHeight = padding;
|
||||
auto paddingWidth = padding;
|
||||
auto batch = builder.create<DimOp>(op->getLoc(), conv2dNCHW.in(), 0);
|
||||
auto height = builder.create<DimOp>(op->getLoc(), conv2dNCHW.in(), 2);
|
||||
auto width = builder.create<DimOp>(op->getLoc(), conv2dNCHW.in(), 3);
|
||||
auto filterOutChannels = builder.create<DimOp>(op->getLoc(), conv2dNCHW.filter(), 0);
|
||||
auto filterHeight = builder.create<DimOp>(op->getLoc(), conv2dNCHW.filter(), 2);
|
||||
auto filterWidth = builder.create<DimOp>(op->getLoc(), conv2dNCHW.filter(), 3);
|
||||
auto batch =
|
||||
builder.create<memref::DimOp>(op->getLoc(), conv2dNCHW.in(), 0);
|
||||
auto height =
|
||||
builder.create<memref::DimOp>(op->getLoc(), conv2dNCHW.in(), 2);
|
||||
auto width =
|
||||
builder.create<memref::DimOp>(op->getLoc(), conv2dNCHW.in(), 3);
|
||||
auto filterOutChannels =
|
||||
builder.create<memref::DimOp>(op->getLoc(), conv2dNCHW.filter(), 0);
|
||||
auto filterHeight =
|
||||
builder.create<memref::DimOp>(op->getLoc(), conv2dNCHW.filter(), 2);
|
||||
auto filterWidth =
|
||||
builder.create<memref::DimOp>(op->getLoc(), conv2dNCHW.filter(), 3);
|
||||
// Output height
|
||||
auto twicePaddingHeight = builder.create<MulIOp>(op->getLoc(), paddingHeight, cI2);
|
||||
auto heightPlusTwicePadding = builder.create<SubIOp>(op->getLoc(), height, twicePaddingHeight);
|
||||
|
@ -91,8 +99,8 @@ public:
|
|||
LogicalResult matchAndRewrite(tcf::MatmulOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Create the constraints, and the assuming region.
|
||||
Value lhsK = rewriter.create<DimOp>(op.getLoc(), op.lhs(), 1);
|
||||
Value rhsK = rewriter.create<DimOp>(op.getLoc(), op.rhs(), 0);
|
||||
Value lhsK = rewriter.create<memref::DimOp>(op.getLoc(), op.lhs(), 1);
|
||||
Value rhsK = rewriter.create<memref::DimOp>(op.getLoc(), op.rhs(), 0);
|
||||
Value matchingK =
|
||||
rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq, lhsK, rhsK);
|
||||
Value witness = rewriter.create<shape::CstrRequireOp>(
|
||||
|
@ -130,12 +138,15 @@ public:
|
|||
LogicalResult matchAndRewrite(tcf::ConvNCHWOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Create the constraints, and the assuming region.
|
||||
Value inputCin = rewriter.create<DimOp>(op.getLoc(), op.in(), 1);
|
||||
Value inputH = rewriter.create<DimOp>(op.getLoc(), op.in(), 2);
|
||||
Value inputW = rewriter.create<DimOp>(op.getLoc(), op.in(), 3);
|
||||
Value filterCin = rewriter.create<DimOp>(op.getLoc(), op.filter(), 1);
|
||||
Value filterKH = rewriter.create<DimOp>(op.getLoc(), op.filter(), 2);
|
||||
Value filterKW = rewriter.create<DimOp>(op.getLoc(), op.filter(), 3);
|
||||
Value inputCin = rewriter.create<memref::DimOp>(op.getLoc(), op.in(), 1);
|
||||
Value inputH = rewriter.create<memref::DimOp>(op.getLoc(), op.in(), 2);
|
||||
Value inputW = rewriter.create<memref::DimOp>(op.getLoc(), op.in(), 3);
|
||||
Value filterCin =
|
||||
rewriter.create<memref::DimOp>(op.getLoc(), op.filter(), 1);
|
||||
Value filterKH =
|
||||
rewriter.create<memref::DimOp>(op.getLoc(), op.filter(), 2);
|
||||
Value filterKW =
|
||||
rewriter.create<memref::DimOp>(op.getLoc(), op.filter(), 3);
|
||||
Value matchingCin =
|
||||
rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq, inputCin, filterCin);
|
||||
Value validFilterH =
|
||||
|
@ -190,9 +201,9 @@ public:
|
|||
|
||||
FrozenRewritePatternList getPatterns() {
|
||||
MLIRContext *context = &getContext();
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<ConvertMatmul>(context);
|
||||
patterns.insert<ConvertConvNCHW>(context);
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.add<ConvertMatmul>(context);
|
||||
patterns.add<ConvertConvNCHW>(context);
|
||||
return std::move(patterns);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -145,12 +145,12 @@ public:
|
|||
|
||||
FrozenRewritePatternList getPatterns() {
|
||||
MLIRContext *context = &getContext();
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<ConvertUnaryElementwise<tcf::ExpOp>,
|
||||
ConvertUnaryElementwise<tcf::TanhOp>>(context);
|
||||
patterns.insert<ConvertBinaryElementwise<tcf::AddOp>,
|
||||
ConvertBinaryElementwise<tcf::MaxOp>,
|
||||
ConvertBinaryElementwise<tcf::MulOp>>(context);
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.add<ConvertUnaryElementwise<tcf::ExpOp>,
|
||||
ConvertUnaryElementwise<tcf::TanhOp>>(context);
|
||||
patterns.add<ConvertBinaryElementwise<tcf::AddOp>,
|
||||
ConvertBinaryElementwise<tcf::MaxOp>,
|
||||
ConvertBinaryElementwise<tcf::MulOp>>(context);
|
||||
return std::move(patterns);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -36,7 +36,7 @@ public:
|
|||
// NOTE: We are keeping this pass around, even though it currently does
|
||||
// nothing, in order to avoid having to reintroduce the same
|
||||
// boilerplate.
|
||||
OwningRewritePatternList patterns;
|
||||
RewritePatternSet patterns(getOperation().getContext());
|
||||
return std::move(patterns);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/EDSC/Builders.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/EDSC/Builders.h"
|
||||
|
@ -85,7 +87,8 @@ static Value memRefTypeCast(PatternRewriter &builder, Value val) {
|
|||
|
||||
if (auto memrefTy = type.dyn_cast<MemRefType>()) {
|
||||
MemRefType newType = getShapeErasedMemRefType(memrefTy);
|
||||
return builder.create<MemRefCastOp>(val.getLoc(), val, newType).getResult();
|
||||
return builder.create<memref::CastOp>(val.getLoc(), val, newType)
|
||||
.getResult();
|
||||
}
|
||||
if (auto tensorTy = type.dyn_cast<TensorType>()) {
|
||||
auto memRefType = mlir::MemRefType::get(tensorTy.getShape(),
|
||||
|
@ -224,7 +227,7 @@ public:
|
|||
MemRefType memRefResultTy = mlir::MemRefType::get(
|
||||
tensorResultTy.getShape(), tensorResultTy.getElementType(), {}, 0);
|
||||
|
||||
Value result = rewriter.create<AllocOp>(loc, memRefResultTy);
|
||||
Value result = rewriter.create<memref::AllocOp>(loc, memRefResultTy);
|
||||
Value lhs = memRefTypeCast(rewriter, operands[0]);
|
||||
Value rhs = memRefTypeCast(rewriter, operands[1]);
|
||||
using namespace edsc;
|
||||
|
@ -232,7 +235,7 @@ public:
|
|||
ScopedContext scope(rewriter, loc);
|
||||
Value zero = intrinsics::std_constant_index(0);
|
||||
MemRefBoundsCapture vRes(result), vLHS(lhs), vRHS(rhs);
|
||||
StdIndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
|
||||
MemRefIndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
|
||||
Value M(vRes.ub(0));
|
||||
if (vRes.rank() == 1) {
|
||||
affineLoopNestBuilder({zero}, {M}, 1, [&](ValueRange ivs) {
|
||||
|
@ -320,7 +323,8 @@ LogicalResult rewriteWithVoidFunctionCallExplicit(
|
|||
// assume memRefResultTy has known shape, so we don't need any
|
||||
// dynamic dimensions for the alloc.
|
||||
assert(memRefResultTy.hasStaticShape());
|
||||
Value allocVal = rewriter.create<AllocOp>(op->getLoc(), memRefResultTy);
|
||||
Value allocVal =
|
||||
rewriter.create<memref::AllocOp>(op->getLoc(), memRefResultTy);
|
||||
Value castVal = memRefTypeCast(rewriter, allocVal);
|
||||
newOps.push_back(castVal);
|
||||
newResults.push_back(allocVal);
|
||||
|
@ -867,9 +871,9 @@ struct ATenLoweringPass : public ATenLoweringBase<ATenLoweringPass> {
|
|||
return type;
|
||||
});
|
||||
|
||||
OwningRewritePatternList acapPatterns;
|
||||
auto module = getOperation();
|
||||
auto context = module.getContext();
|
||||
RewritePatternSet acapPatterns(context);
|
||||
|
||||
// c++ patterns
|
||||
acapPatterns.insert<
|
||||
|
@ -885,16 +889,15 @@ struct ATenLoweringPass : public ATenLoweringBase<ATenLoweringPass> {
|
|||
NllLoss2dBackwardOpConversion, LogSoftmaxOpConversion,
|
||||
LogSoftmaxBackwardDataOpConversion, DivOpConversion>(context);
|
||||
|
||||
mlir::populateFuncOpTypeConversionPattern(acapPatterns, context,
|
||||
typeConverter);
|
||||
mlir::populateFuncOpTypeConversionPattern(acapPatterns, typeConverter);
|
||||
|
||||
// tablegen patterns
|
||||
populateATenToStdPatterns(context, acapPatterns);
|
||||
populateATenToStdPatterns(acapPatterns);
|
||||
|
||||
// Perform acap specific lowering.
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<LLVM::LLVMDialect, StandardOpsDialect,
|
||||
scf::SCFDialect>();
|
||||
scf::SCFDialect, memref::MemRefDialect>();
|
||||
target.addLegalOp<AffineForOp, AffineApplyOp, AffineYieldOp>();
|
||||
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
||||
return typeConverter.isSignatureLegal(op.getType());
|
||||
|
|
|
@ -20,8 +20,7 @@ namespace {
|
|||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
void populateATenToStdPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList &patterns) {
|
||||
populateWithGenerated(context, patterns);
|
||||
void populateATenToStdPatterns(RewritePatternSet &patterns) {
|
||||
populateWithGenerated(patterns);
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -467,12 +467,12 @@ class ATenRecognizeKernelsPass
|
|||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto &context = getContext();
|
||||
KernelCallTransformer transformer(context);
|
||||
MLIRContext *context = &getContext();
|
||||
KernelCallTransformer transformer(*context);
|
||||
transformer.addDialectOps<ATenDialect>();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<RecognizeOpPattern>(&context, transformer);
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.add<RecognizeOpPattern>(context, transformer);
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
signalPassFailure();
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
|
@ -73,7 +74,7 @@ public:
|
|||
if (!v.getType().isa<MemRefType>())
|
||||
llvm_unreachable("function returns non-memref");
|
||||
if (!valueMap.count(v)) {
|
||||
valueMap[v] = builder->create<AllocOp>(
|
||||
valueMap[v] = builder->create<memref::AllocOp>(
|
||||
op->getLoc(), v.getType().cast<MemRefType>());
|
||||
}
|
||||
v.replaceAllUsesWith(valueMap[v]);
|
||||
|
@ -86,7 +87,7 @@ public:
|
|||
auto fn = module.lookupSymbol<FuncOp>(callOp.callee());
|
||||
if (fn && fn.use_empty())
|
||||
erasedOps.insert(fn);
|
||||
} else if (isa<AllocOp>(op)) {
|
||||
} else if (isa<memref::AllocOp>(op)) {
|
||||
Value v = op->getResult(0);
|
||||
if (valueMap.count(v)) {
|
||||
v.replaceAllUsesWith(valueMap[v]);
|
||||
|
|
|
@ -400,9 +400,9 @@ public:
|
|||
|
||||
} // namespace
|
||||
|
||||
void UnknownCastOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context) {
|
||||
patterns.insert<ElideIdentityUnknownCast>(context);
|
||||
void UnknownCastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add<ElideIdentityUnknownCast>(context);
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
|
|
@ -81,9 +81,9 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
void CopyToTensorOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context) {
|
||||
patterns.insert<ElideCreateRedundantArrayFromTensor>(context);
|
||||
void CopyToTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add<ElideCreateRedundantArrayFromTensor>(context);
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
@ -49,7 +50,7 @@ static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
|
|||
builder.create<tensor::ExtractOp>(op.getLoc(), pad.upperExpansion(),
|
||||
ValueRange({dimIndex}));
|
||||
auto operandDim =
|
||||
builder.create<DimOp>(op.getLoc(), pad.operand(), i);
|
||||
builder.create<memref::DimOp>(op.getLoc(), pad.operand(), i);
|
||||
auto totalExpansion =
|
||||
builder.create<AddIOp>(op.getLoc(), lowerExpansion, upperExpansion);
|
||||
auto outDim =
|
||||
|
@ -117,7 +118,8 @@ public:
|
|||
SmallVector<Value, 6> inputDimRequiresBroadcasting;
|
||||
for (int i = 0, e = inputType.getRank(); i < e; i++) {
|
||||
// Calculate the relevant extents.
|
||||
Value inputExtent = rewriter.create<DimOp>(op.getLoc(), op.operand(), i);
|
||||
Value inputExtent =
|
||||
rewriter.create<memref::DimOp>(op.getLoc(), op.operand(), i);
|
||||
inputDimRequiresBroadcasting.push_back(
|
||||
rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::ne, inputExtent,
|
||||
outputExtents[rankDiff + i]));
|
||||
|
@ -152,10 +154,10 @@ public:
|
|||
inductionVariables[rankDiff + i]);
|
||||
inputIndices.push_back(select);
|
||||
}
|
||||
Value load =
|
||||
rewriter.create<LoadOp>(op.getLoc(), inputMemref, inputIndices);
|
||||
rewriter.create<StoreOp>(op.getLoc(), load, resultMemref,
|
||||
inductionVariables);
|
||||
Value load = rewriter.create<memref::LoadOp>(op.getLoc(), inputMemref,
|
||||
inputIndices);
|
||||
rewriter.create<memref::StoreOp>(op.getLoc(), load, resultMemref,
|
||||
inductionVariables);
|
||||
}
|
||||
rewriter.replaceOp(op, resultMemref);
|
||||
return success();
|
||||
|
@ -202,16 +204,16 @@ public:
|
|||
auto offset =
|
||||
rewriter.create<tensor::ExtractOp>(op.getLoc(), op.lowerExpansion(),
|
||||
ValueRange({dimIndex}));
|
||||
auto size = rewriter.create<DimOp>(op.getLoc(), op.operand(), i);
|
||||
auto size = rewriter.create<memref::DimOp>(op.getLoc(), op.operand(), i);
|
||||
auto stride = c1;
|
||||
offsets.push_back(offset);
|
||||
sizes.push_back(size);
|
||||
strides.push_back(stride);
|
||||
}
|
||||
rewriter.create<linalg::FillOp>(op.getLoc(), results[0], op.fillVal());
|
||||
auto unpadded =
|
||||
rewriter.create<SubViewOp>(op.getLoc(), results[0], ValueRange(offsets),
|
||||
ValueRange(sizes), ValueRange(strides));
|
||||
auto unpadded = rewriter.create<memref::SubViewOp>(
|
||||
op.getLoc(), results[0], ValueRange(offsets), ValueRange(sizes),
|
||||
ValueRange(strides));
|
||||
auto inputMemref = operands[0];
|
||||
rewriter.create<linalg::CopyOp>(op.getLoc(), inputMemref, unpadded);
|
||||
rewriter.replaceOp(op, results);
|
||||
|
@ -234,7 +236,7 @@ class TCPBufferizePass : public TCPBufferizeBase<TCPBufferizePass> {
|
|||
|
||||
BufferizeTypeConverter typeConverter;
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
RewritePatternSet patterns(context);
|
||||
|
||||
ConversionTarget target(*context);
|
||||
|
||||
|
@ -243,17 +245,18 @@ class TCPBufferizePass : public TCPBufferizeBase<TCPBufferizePass> {
|
|||
// we can just open-code the extents for the alloc.
|
||||
target.addLegalOp<refback::AllocMemRefOp>();
|
||||
|
||||
patterns.insert<LowerBroadcastToToLoopsPattern>(typeConverter, context);
|
||||
patterns.add<LowerBroadcastToToLoopsPattern>(typeConverter, context);
|
||||
target.addIllegalOp<tcp::BroadcastToOp>();
|
||||
patterns.insert<BufferizeSplattedOp>(typeConverter, context);
|
||||
patterns.add<BufferizeSplattedOp>(typeConverter, context);
|
||||
target.addIllegalOp<tcp::SplattedOp>();
|
||||
patterns.insert<BufferizePadOp>(typeConverter, context);
|
||||
patterns.add<BufferizePadOp>(typeConverter, context);
|
||||
target.addIllegalOp<tcp::PadOp>();
|
||||
|
||||
target.addLegalDialect<linalg::LinalgDialect>();
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
target.addLegalDialect<scf::SCFDialect>();
|
||||
target.addLegalDialect<tensor::TensorDialect>();
|
||||
target.addLegalDialect<memref::MemRefDialect>();
|
||||
|
||||
if (failed(applyPartialConversion(func, target, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
|
|
@ -58,7 +58,8 @@ Type TorchDialect::parseType(DialectAsmParser &parser) const {
|
|||
StringRef keyword;
|
||||
if (parser.parseKeyword(&keyword))
|
||||
return Type();
|
||||
if (Type type = generatedTypeParser(getContext(), parser, keyword))
|
||||
Type type;
|
||||
if (generatedTypeParser(getContext(), parser, keyword, type).hasValue())
|
||||
return type;
|
||||
|
||||
parser.emitError(parser.getNameLoc(), "invalid 'torch' type: `")
|
||||
|
|
|
@ -73,10 +73,10 @@ class PrepareForGlobalizeObjectGraphPass
|
|||
SymbolTable symbolTable(getOperation());
|
||||
|
||||
MLIRContext *context = &getContext();
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<ConvertPrimCallMethodToCall>(context, symbolTable);
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.add<ConvertPrimCallMethodToCall>(context, symbolTable);
|
||||
CallIndirectOp::getCanonicalizationPatterns(patterns, context);
|
||||
patterns.insert<EraseUnusedConstantOp>(context);
|
||||
patterns.add<EraseUnusedConstantOp>(context);
|
||||
|
||||
// Use applyPatternsAndFoldGreedily because the CallIndirectOp folding
|
||||
// makes the ConstantOp unused, which does not work with the visitation
|
||||
|
@ -99,7 +99,7 @@ class PrepareForGlobalizeObjectGraphPass
|
|||
target.addIllegalOp<CallIndirectOp>();
|
||||
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
|
||||
|
||||
OwningRewritePatternList dummyPatterns;
|
||||
RewritePatternSet dummyPatterns(context);
|
||||
|
||||
if (failed(applyFullConversion(getOperation(), target,
|
||||
std::move(dummyPatterns)))) {
|
||||
|
|
|
@ -200,7 +200,7 @@ static LLVMFuncOp createCompilerRuntimeFuncDecl(StringRef name, Type type,
|
|||
}
|
||||
|
||||
static void populateCompilerRuntimePatterns(ModuleOp module,
|
||||
OwningRewritePatternList &patterns,
|
||||
RewritePatternSet &patterns,
|
||||
LLVMTypeConverter &typeConverter) {
|
||||
auto *context = module.getContext();
|
||||
OpBuilder builder(module.getBodyRegion());
|
||||
|
@ -212,7 +212,7 @@ static void populateCompilerRuntimePatterns(ModuleOp module,
|
|||
/*isVarArg=*/false);
|
||||
LLVMFuncOp abortIfFunc = createCompilerRuntimeFuncDecl(
|
||||
"abort_if", abortIfFuncTy, builder, module.getLoc());
|
||||
patterns.insert<AbortIfOpCompilerRuntimeLowering>(abortIfFunc);
|
||||
patterns.add<AbortIfOpCompilerRuntimeLowering>(abortIfFunc);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -701,16 +701,16 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
|
|||
|
||||
LLVMTypeConverter converter(context);
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
RewritePatternSet patterns(context);
|
||||
LLVMConversionTarget target(*context);
|
||||
populateCompilerRuntimePatterns(module, patterns, converter);
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
patterns.insert<LowerModuleMetadata>(context);
|
||||
patterns.add<LowerModuleMetadata>(context);
|
||||
|
||||
// TODO: Move these "std to std" legalizations to their own pass if we grow
|
||||
// lots of these patterns.
|
||||
populateExpandTanhPattern(patterns, context);
|
||||
populateExpandTanhPattern(patterns);
|
||||
|
||||
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
|
||||
return signalPassFailure();
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "PassDetail.h"
|
||||
#include "npcomp/RefBackend/RefBackend.h"
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
|
@ -353,8 +354,8 @@ public:
|
|||
for (auto newAndOldArg :
|
||||
llvm::zip(newEntry.getArguments(), oldEntry.getArguments())) {
|
||||
std::tie(newArg, oldArg) = newAndOldArg;
|
||||
auto memref = rewriter.create<MemRefCastOp>(op.getLoc(), newArg,
|
||||
oldArg.getType());
|
||||
auto memref = rewriter.create<memref::CastOp>(op.getLoc(), newArg,
|
||||
oldArg.getType());
|
||||
rewriter.replaceUsesOfBlockArgument(oldArg, memref);
|
||||
}
|
||||
});
|
||||
|
@ -390,23 +391,24 @@ static LogicalResult doDialectConversion(ModuleOp module) {
|
|||
[](OpBuilder &builder, UnrankedMemRefType type, ValueRange inputs,
|
||||
Location loc) -> Value {
|
||||
assert(inputs.size() == 1);
|
||||
return builder.create<MemRefCastOp>(
|
||||
return builder.create<memref::CastOp>(
|
||||
loc, inputs[0], getABIMemrefType(inputs[0].getType()));
|
||||
});
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
RewritePatternSet patterns(context);
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<refbackrt::RefbackrtDialect>();
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
target.addLegalDialect<memref::MemRefDialect>();
|
||||
|
||||
patterns.insert<FuncOpSignatureConversion>(typeConverter, context);
|
||||
patterns.add<FuncOpSignatureConversion>(typeConverter, context);
|
||||
target.addDynamicallyLegalOp<FuncOp>(
|
||||
[&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
|
||||
patterns.insert<RewriteReturnOp>(typeConverter, context);
|
||||
patterns.add<RewriteReturnOp>(typeConverter, context);
|
||||
target.addDynamicallyLegalOp<ReturnOp>(
|
||||
[&](ReturnOp op) { return typeConverter.isLegal(op); });
|
||||
|
||||
patterns.insert<LowerAssertOp>(context);
|
||||
patterns.add<LowerAssertOp>(context);
|
||||
target.addIllegalOp<AssertOp>();
|
||||
|
||||
return applyPartialConversion(module, target, std::move(patterns));
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||
|
@ -105,7 +106,8 @@ public:
|
|||
dynamicExtents.push_back(extent);
|
||||
}
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<AllocOp>(op, memrefType, dynamicExtents);
|
||||
rewriter.replaceOpWithNewOp<memref::AllocOp>(op, memrefType,
|
||||
dynamicExtents);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -118,12 +120,12 @@ class LowerAllocMemRefOps
|
|||
void runOnOperation() override {
|
||||
auto func = getOperation();
|
||||
auto *context = &getContext();
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<LowerAllocMemRefOp>(context);
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.add<LowerAllocMemRefOp>(context);
|
||||
ConversionTarget target(*context);
|
||||
target.addIllegalOp<refback::AllocMemRefOp>();
|
||||
target.addLegalOp<tensor::ExtractOp>();
|
||||
target.addLegalOp<AllocOp>();
|
||||
target.addLegalOp<memref::AllocOp>();
|
||||
target.addLegalOp<ConstantOp>();
|
||||
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
||||
return signalPassFailure();
|
||||
|
@ -173,7 +175,7 @@ struct RestrictedCanonicalizer
|
|||
}
|
||||
|
||||
// Collect all canonicalization patterns from ops in the included dialects.
|
||||
OwningRewritePatternList patterns;
|
||||
RewritePatternSet patterns(context);
|
||||
for (AbstractOperation *op : context->getRegisteredOperations())
|
||||
if (dialectsToCanonicalize.count(&op->dialect))
|
||||
op->getCanonicalizationPatterns(patterns, context);
|
||||
|
@ -235,7 +237,7 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
|
|||
// rather than a single mega dialect conversion pass.
|
||||
//
|
||||
// This means that intermediate steps have source/target materializations
|
||||
// (tensor_load / tensor_to_memref) in the IR.
|
||||
// (memref.tensor_load / memref.buffer_cast) in the IR.
|
||||
|
||||
// Run tensor constant bufferization.
|
||||
// This pass has to run on a module op, and so does the final
|
||||
|
|
|
@ -6,13 +6,13 @@
|
|||
// CHECK: %[[C0F32:.*]] = constant 0.000000e+00 : f32
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[LHSK:.*]] = dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[RHSK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[LHSK:.*]] = memref.dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[RHSK:.*]] = memref.dim %[[RHS]], %[[C0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[KEQUAL:.*]] = cmpi eq, %[[LHSK]], %[[RHSK]] : index
|
||||
// CHECK: %[[WINESS:.*]] = shape.cstr_require %[[KEQUAL]], "mismatching contracting dimension for matmul"
|
||||
// CHECK: %[[RET:.*]] = shape.assuming %[[WINESS]] -> (tensor<?x?xf32>) {
|
||||
// CHECK: %[[LHSROWS:.*]] = dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[RHSCOLS:.*]] = dim %[[RHS]], %[[C1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[LHSROWS:.*]] = memref.dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[RHSCOLS:.*]] = memref.dim %[[RHS]], %[[C1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[LHSROWS]], %[[RHSCOLS]] : tensor<2xindex>
|
||||
// CHECK: %[[INIT_TENSOR:.*]] = tcp.splatted %[[C0F32]], %[[SHAPE]] : (f32, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[INIT_TENSOR]] : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
|
@ -32,12 +32,12 @@ func @tcf_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf
|
|||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK: %[[C3:.*]] = constant 3 : index
|
||||
// CHECK: %[[CHANNELS:.*]] = dim %[[IN]], %[[C1]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[HEIGHT:.*]] = dim %[[IN]], %[[C2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[WIDTH:.*]] = dim %[[IN]], %[[C3]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[FILTERCHANNELS:.*]] = dim %[[FILTER]], %[[C1]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[FILTERHEIGHT:.*]] = dim %[[FILTER]], %[[C2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[FILTERWIDTH:.*]] = dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[CHANNELS:.*]] = memref.dim %[[IN]], %[[C1]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[HEIGHT:.*]] = memref.dim %[[IN]], %[[C2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[WIDTH:.*]] = memref.dim %[[IN]], %[[C3]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[FILTERCHANNELS:.*]] = memref.dim %[[FILTER]], %[[C1]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[FILTERHEIGHT:.*]] = memref.dim %[[FILTER]], %[[C2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[FILTERWIDTH:.*]] = memref.dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[CMPCHANNELS:.*]] = cmpi eq, %[[CHANNELS]], %[[FILTERCHANNELS]] : index
|
||||
// CHECK: %[[CMPHEIGHT:.*]] = cmpi uge, %[[HEIGHT]], %[[FILTERHEIGHT]] : index
|
||||
// CHECK: %[[CMPWIDTH:.*]] = cmpi uge, %[[WIDTH]], %[[FILTERWIDTH]] : index
|
||||
|
@ -46,12 +46,12 @@ func @tcf_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf
|
|||
// CHECK: %[[CSTRWIDTH:.*]] = shape.cstr_require %[[CMPWIDTH]], "input width must be greater than or equal to filter KW-dimension"
|
||||
// CHECK: %[[WITNESS:.*]] = shape.assuming_all %[[CSTRCHANNELS]], %[[CSTRHEIGHT]], %[[CSTRWIDTH]]
|
||||
// CHECK: %[[RET:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?x?x?x?xf32>) {
|
||||
// CHECK: %[[BATCH:.*]] = dim %[[IN]], %[[C0]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[HEIGHT:.*]] = dim %[[IN]], %[[C2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[WIDTH:.*]] = dim %[[IN]], %[[C3]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[OUTCHANNELS:.*]] = dim %[[FILTER]], %[[C0]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[FILTERHEIGHT:.*]] = dim %[[FILTER]], %[[C2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[FILTERWIDTH:.*]] = dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[BATCH:.*]] = memref.dim %[[IN]], %[[C0]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[HEIGHT:.*]] = memref.dim %[[IN]], %[[C2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[WIDTH:.*]] = memref.dim %[[IN]], %[[C3]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[OUTCHANNELS:.*]] = memref.dim %[[FILTER]], %[[C0]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[FILTERHEIGHT:.*]] = memref.dim %[[FILTER]], %[[C2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[FILTERWIDTH:.*]] = memref.dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[FILTERHEIGHTM1:.*]] = subi %[[FILTERHEIGHT]], %[[C1]] : index
|
||||
// CHECK: %[[HEIGHTV0:.*]] = subi %[[HEIGHT]], %[[FILTERHEIGHTM1]] : index
|
||||
// CHECK: %[[HEIGHTV0M1:.*]] = subi %[[HEIGHTV0]], %[[C1]] : index
|
||||
|
|
|
@ -19,7 +19,7 @@ func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) -> tensor<?
|
|||
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>) -> tensor<?x?xf32> {
|
||||
// CHECK: %[[RESULT:.*]] = refback.alloc_memref %[[SHAPE]] : memref<?x?xf32>
|
||||
// CHECK: linalg.fill(%[[RESULT]], %[[SPLAT_VAL]]) : memref<?x?xf32>, f32
|
||||
// CHECK: %[[RESULT_TENSOR:.*]] = tensor_load %[[RESULT]] : memref<?x?xf32>
|
||||
// CHECK: %[[RESULT_TENSOR:.*]] = memref.tensor_load %[[RESULT]] : memref<?x?xf32>
|
||||
// CHECK: return %[[RESULT_TENSOR]] : tensor<?x?xf32>
|
||||
func @tcp_splatted(%arg0: f32, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
|
||||
%0 = tcp.splatted %arg0, %arg1 : (f32, tensor<?xindex>) -> tensor<?x?xf32>
|
||||
|
@ -31,14 +31,14 @@ func @tcp_splatted(%arg0: f32, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
|
|||
// CHECK-SAME: %[[LOWER_EXPANSION:[a-zA-Z0-9]+]]: tensor<?xindex>,
|
||||
// CHECK-SAME: %[[UPPER_EXPANSION:[a-zA-Z0-9]+]]: tensor<?xindex>,
|
||||
// CHECK-SAME: %[[FILL_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<?xf32> {
|
||||
// CHECK: %[[TENSOR_MREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32>
|
||||
// CHECK: %[[LOWER_EXPANSION_MREF:.*]] = tensor_to_memref %[[LOWER_EXPANSION]] : memref<?xindex>
|
||||
// CHECK: %[[UPPER_EXPANSION_MREF:.*]] = tensor_to_memref %[[UPPER_EXPANSION]] : memref<?xindex>
|
||||
// CHECK: %[[TENSOR_MREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<?xf32>
|
||||
// CHECK: %[[LOWER_EXPANSION_MREF:.*]] = memref.buffer_cast %[[LOWER_EXPANSION]] : memref<?xindex>
|
||||
// CHECK: %[[UPPER_EXPANSION_MREF:.*]] = memref.buffer_cast %[[UPPER_EXPANSION]] : memref<?xindex>
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[LOWER_EXTENT_D1:.*]] = tensor.extract %[[LOWER_EXPANSION]][%[[C0]]] : tensor<?xindex>
|
||||
// CHECK: %[[UPPER_EXTENT_D1:.*]] = tensor.extract %[[UPPER_EXPANSION]][%[[C0]]] : tensor<?xindex>
|
||||
// CHECK: %[[C0_0:.*]] = constant 0 : index
|
||||
// CHECK: %[[D1:.*]] = dim %[[TENSOR]], %[[C0_0]] : tensor<?xf32>
|
||||
// CHECK: %[[D1:.*]] = memref.dim %[[TENSOR]], %[[C0_0]] : tensor<?xf32>
|
||||
// CHECK: %[[D1_EXPANSION:.*]] = addi %[[LOWER_EXTENT_D1]], %[[UPPER_EXTENT_D1]] : index
|
||||
// CHECK: %[[D1_OUT:.*]] = addi %[[D1_EXPANSION]], %[[D1]] : index
|
||||
// CHECK: %[[D1_OUT_TENSOR:.*]] = tensor.from_elements %[[D1_OUT]] : tensor<1xindex>
|
||||
|
@ -47,11 +47,11 @@ func @tcp_splatted(%arg0: f32, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
|
|||
// CHECK: %[[C0_1:.*]] = constant 0 : index
|
||||
// CHECK: %[[LOWER_EXTENT_D1_1:.*]] = tensor.extract %[[LOWER_EXPANSION]][%[[C0_1]]] : tensor<?xindex>
|
||||
// CHECK: %[[C0_2:.*]] = constant 0 : index
|
||||
// CHECK: %[[D1_1:.*]] = dim %[[TENSOR]], %[[C0_2]] : tensor<?xf32>
|
||||
// CHECK: %[[D1_1:.*]] = memref.dim %[[TENSOR]], %[[C0_2]] : tensor<?xf32>
|
||||
// CHECK: linalg.fill(%[[D1_OUT_MREF]], %[[FILL_VAL]]) : memref<?xf32>, f32
|
||||
// CHECK: %[[SUBVIEW:.*]] = subview %[[D1_OUT_MREF]][%[[LOWER_EXTENT_D1_1]]] [%[[D1_1]]] [%[[C1]]] : memref<?xf32> to memref<?xf32, #map>
|
||||
// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[D1_OUT_MREF]][%[[LOWER_EXTENT_D1_1]]] [%[[D1_1]]] [%[[C1]]] : memref<?xf32> to memref<?xf32, #map>
|
||||
// CHECK: linalg.copy(%0, %[[SUBVIEW]]) : memref<?xf32>, memref<?xf32, #map>
|
||||
// CHECK: %[[RESULT_TENSOR:.*]] = tensor_load %[[D1_OUT_MREF]] : memref<?xf32>
|
||||
// CHECK: %[[RESULT_TENSOR:.*]] = memref.tensor_load %[[D1_OUT_MREF]] : memref<?xf32>
|
||||
// CHECK: return %[[RESULT_TENSOR]] : tensor<?xf32>
|
||||
func @tcp_pad(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>, %arg2: tensor<?xindex>, %arg3: f32) -> tensor<?xf32> {
|
||||
%0 = tcp.pad %arg0, %arg1, %arg2, %arg3 : (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>, f32) -> tensor<?xf32>
|
||||
|
|
|
@ -37,11 +37,11 @@ func @identity(%arg0: memref<?xf32>) -> memref<?xf32> {
|
|||
|
||||
// CHECK-LABEL: func @use_of_arg(%arg0: memref<*xf32>)
|
||||
func @use_of_arg(%arg0: memref<?xf32>) {
|
||||
// CHECK-NEXT: %[[MEMREF:.*]] = memref_cast %arg0 : memref<*xf32> to memref<?xf32>
|
||||
// CHECK-NEXT: %[[MEMREF:.*]] = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
|
||||
%c0 = constant 0 : index
|
||||
%0 = dim %arg0, %c0 : memref<?xf32>
|
||||
%0 = memref.dim %arg0, %c0 : memref<?xf32>
|
||||
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK-NEXT: dim %[[MEMREF]], %[[C0]] : memref<?xf32>
|
||||
// CHECK-NEXT: memref.dim %[[MEMREF]], %[[C0]] : memref<?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -49,12 +49,12 @@ func @use_of_arg(%arg0: memref<?xf32>) {
|
|||
|
||||
// CHECK-LABEL: func @multiple_blocks(%arg0: memref<*xf32>) -> memref<*xf32>
|
||||
func @multiple_blocks(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||
// CHECK-NEXT: %[[INMEMREF:.*]] = memref_cast %arg0 : memref<*xf32> to memref<?xf32>
|
||||
// CHECK-NEXT: %[[INMEMREF:.*]] = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
|
||||
// CHECK-NEXT: br ^bb1(%[[INMEMREF]] : memref<?xf32>)
|
||||
br ^bb1(%arg0: memref<?xf32>)
|
||||
// CHECK-NEXT: ^bb1(%[[BBARG:.*]]: memref<?xf32>):
|
||||
^bb1(%bbarg: memref<?xf32>):
|
||||
// CHECK-NEXT: %[[OUTMEMREF:.*]] = memref_cast %[[BBARG]] : memref<?xf32> to memref<*xf32>
|
||||
// CHECK-NEXT: %[[OUTMEMREF:.*]] = memref.cast %[[BBARG]] : memref<?xf32> to memref<*xf32>
|
||||
// CHECK-NEXT: return %[[OUTMEMREF]] : memref<*xf32>
|
||||
return %bbarg : memref<?xf32>
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
#include "mlir/InitAllPasses.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Target/LLVMIR.h"
|
||||
#include "npcomp-c/InitLLVM.h"
|
||||
#include "npcomp/InitAll.h"
|
||||
#include "npcomp/RefBackend/JITHelpers/JITModule.h"
|
||||
|
|
Loading…
Reference in New Issue