mirror of https://github.com/llvm/torch-mlir
[RefE2E] Use upstream shape constraint conversion pass.
Now that we upstreamed our pass, we can remove it. The final pass that landed upstream doesn't do the shape.assuming canonicalization to legalize that op away, so added a restricted-canonicalizer pass that allowed to run just shape dialect canonicalizations, which deletes the shape.assuming. The pass ended up kind of ugly. See the TODO's on it for some potential cleaner directions.pull/61/head
parent
6ea37cfed6
commit
16c26ef57e
|
@ -25,8 +25,6 @@ void registerE2EPasses();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createBypassShapesPass();
|
std::unique_ptr<OperationPass<FuncOp>> createBypassShapesPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLowerShapeConstraintsPass();
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLowerShapedResultsToMemrefPass();
|
std::unique_ptr<OperationPass<FuncOp>> createLowerShapedResultsToMemrefPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLowerStdToMemrefPass();
|
std::unique_ptr<OperationPass<FuncOp>> createLowerStdToMemrefPass();
|
||||||
|
@ -42,6 +40,8 @@ std::unique_ptr<OperationPass<FuncOp>> createLowerAllocMemRefOpsPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createLowerToLLVMPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createLowerToLLVMPass();
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> createRestrictedCanonicalizerPass();
|
||||||
|
|
||||||
struct E2ELoweringPipelineOptions
|
struct E2ELoweringPipelineOptions
|
||||||
: public PassPipelineOptions<E2ELoweringPipelineOptions> {
|
: public PassPipelineOptions<E2ELoweringPipelineOptions> {
|
||||||
// If this option is true, then perform optimizations.
|
// If this option is true, then perform optimizations.
|
||||||
|
|
|
@ -16,11 +16,6 @@ def BypassShapes : Pass<"bypass-shapes", "FuncOp"> {
|
||||||
let constructor = "mlir::NPCOMP::createBypassShapesPass()";
|
let constructor = "mlir::NPCOMP::createBypassShapesPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
def LowerShapeConstraints : Pass<"lower-shape-constraints", "FuncOp"> {
|
|
||||||
let summary = "Lower shape dialect constructs related to constraints";
|
|
||||||
let constructor = "mlir::NPCOMP::createLowerShapeConstraintsPass()";
|
|
||||||
}
|
|
||||||
|
|
||||||
def LowerShapedResultsToMemref : Pass<"lower-shaped-results-to-memref", "FuncOp"> {
|
def LowerShapedResultsToMemref : Pass<"lower-shaped-results-to-memref", "FuncOp"> {
|
||||||
let summary = "Lower tcp.shaped_results regions";
|
let summary = "Lower tcp.shaped_results regions";
|
||||||
let constructor = "mlir::NPCOMP::createLowerShapedResultsToMemrefPass()";
|
let constructor = "mlir::NPCOMP::createLowerShapedResultsToMemrefPass()";
|
||||||
|
@ -83,4 +78,33 @@ def LowerToLLVM : Pass<"e2e-lower-to-llvm", "ModuleOp"> {
|
||||||
let constructor = "mlir::NPCOMP::createLowerToLLVMPass();";
|
let constructor = "mlir::NPCOMP::createLowerToLLVMPass();";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Move this pass to upstream.
|
||||||
|
// TODO: This pass will still do "folding" on all ops.
|
||||||
|
// The applyPatternsAndFoldGreedily driver will need to be changed to restrict
|
||||||
|
// folding to the specified dialects as well.
|
||||||
|
// Perhaps a better design is having a pass that uses the conversion framework.
|
||||||
|
// The the pass constructor would take a set of op names, and it would
|
||||||
|
// set up a conversion target that makes all those ops illegal, and uses
|
||||||
|
// the canonicalization patterns from those ops to legalize them.
|
||||||
|
def RestrictedCanonicalizer : Pass<"restricted-canonicalize"> {
|
||||||
|
let summary = "Canonicalize operations";
|
||||||
|
let description = [{
|
||||||
|
This pass is the same as the regular `canonicalize` pass, but it only
|
||||||
|
applies a restricted set of patterns.
|
||||||
|
|
||||||
|
This is useful when a particular canonicalization is actually needed for
|
||||||
|
correctness of a lowering flow. For such cases, running a restricted set of
|
||||||
|
canonicalizations makes it clearer which passes are needed for correctness
|
||||||
|
and which passes are "just optimizations". This helps when debugging
|
||||||
|
miscompiles and other situations where the compiler is not behaving as
|
||||||
|
expected.
|
||||||
|
}];
|
||||||
|
let constructor = "mlir::NPCOMP::createRestrictedCanonicalizerPass()";
|
||||||
|
let options = [
|
||||||
|
ListOption<"includedDialects", "included-dialects", "std::string",
|
||||||
|
"Which dialects should be canonicalized",
|
||||||
|
"llvm::cl::MiscFlags::CommaSeparated">
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
#endif // NPCOMP_E2E_PASSES
|
#endif // NPCOMP_E2E_PASSES
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
add_mlir_library(NPCOMPE2E
|
add_mlir_library(NPCOMPE2E
|
||||||
BypassShapes.cpp
|
BypassShapes.cpp
|
||||||
E2E.cpp
|
E2E.cpp
|
||||||
LowerShapeConstraints.cpp
|
|
||||||
LowerToLLVM.cpp
|
LowerToLLVM.cpp
|
||||||
LowerToNpcomprtABI.cpp
|
LowerToNpcomprtABI.cpp
|
||||||
TensorToMemref/LowerConstantTensorsToMemref.cpp
|
TensorToMemref/LowerConstantTensorsToMemref.cpp
|
||||||
|
|
|
@ -107,6 +107,58 @@ mlir::NPCOMP::createLowerAllocMemRefOpsPass() {
|
||||||
return std::make_unique<LowerAllocMemRefOps>();
|
return std::make_unique<LowerAllocMemRefOps>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// RestrictedCanonicalizer
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct RestrictedCanonicalizer
|
||||||
|
: public RestrictedCanonicalizerBase<RestrictedCanonicalizer> {
|
||||||
|
void runOnOperation() override {
|
||||||
|
auto *context = &getContext();
|
||||||
|
|
||||||
|
// Find the dialects from their names.
|
||||||
|
DenseSet<StringRef> neededDialects;
|
||||||
|
for (const std::string &dialectName : includedDialects)
|
||||||
|
neededDialects.insert(dialectName);
|
||||||
|
DenseSet<Dialect *> dialectsToCanonicalize;
|
||||||
|
for (Dialect *dialect : context->getLoadedDialects()) {
|
||||||
|
if (neededDialects.count(dialect->getNamespace())) {
|
||||||
|
dialectsToCanonicalize.insert(dialect);
|
||||||
|
// Erase the dialect so that we can report an error below for any
|
||||||
|
// dialect names that are not loaded.
|
||||||
|
neededDialects.erase(dialect->getNamespace());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Report a helpful error if a dialect is not found.
|
||||||
|
auto missingDialects = llvm::to_vector<6>(neededDialects);
|
||||||
|
if (!missingDialects.empty()) {
|
||||||
|
llvm::sort(missingDialects);
|
||||||
|
std::string buf;
|
||||||
|
llvm::raw_string_ostream os(buf);
|
||||||
|
llvm::interleaveComma(missingDialects, os);
|
||||||
|
llvm::report_fatal_error("restricted-canonicalize: unknown dialects: " +
|
||||||
|
os.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect all canonicalization patterns from ops in the included dialects.
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
for (AbstractOperation *op : context->getRegisteredOperations())
|
||||||
|
if (dialectsToCanonicalize.count(&op->dialect))
|
||||||
|
op->getCanonicalizationPatterns(patterns, context);
|
||||||
|
|
||||||
|
Operation *op = getOperation();
|
||||||
|
applyPatternsAndFoldGreedily(op->getRegions(), patterns);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> mlir::NPCOMP::createRestrictedCanonicalizerPass() {
|
||||||
|
return std::make_unique<RestrictedCanonicalizer>();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// createE2ELoweringPipeline
|
// createE2ELoweringPipeline
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -148,9 +200,19 @@ void mlir::NPCOMP::createE2ELoweringPipeline(
|
||||||
pm.addPass(createBypassShapesPass());
|
pm.addPass(createBypassShapesPass());
|
||||||
|
|
||||||
// Lower shape constraints before we enter tensor->memref conversion.
|
// Lower shape constraints before we enter tensor->memref conversion.
|
||||||
// That is, we expand witnesses + shape.assuming + shape.cstr_* ops to
|
// That is, we expand shape.cstr_* ops to eager error handling code.
|
||||||
// eager error handling code that doesn't have witnesses or shape.assuming.
|
pm.addPass(createConvertShapeConstraintsPass());
|
||||||
pm.addPass(createLowerShapeConstraintsPass());
|
// Run shape canonicalizations. In particular, this erases shape.assuming,
|
||||||
|
// now that we have converted shape constraints.
|
||||||
|
// TODO: This is kind of ugly. Either we use pass options or a constructor
|
||||||
|
// that takes C++ data structures. The former makes the pass usable on the
|
||||||
|
// command line (including reproducers), the latter makes the pass more
|
||||||
|
// convenient.
|
||||||
|
std::unique_ptr<Pass> shapeCanonicalizer =
|
||||||
|
createRestrictedCanonicalizerPass();
|
||||||
|
if (failed(shapeCanonicalizer->initializeOptions("included-dialects=shape")))
|
||||||
|
llvm::report_fatal_error("couldn't initialize restricted-canonicalize");
|
||||||
|
pm.addPass(std::move(shapeCanonicalizer));
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
// Lower the `tensor` type to `memref`.
|
// Lower the `tensor` type to `memref`.
|
||||||
|
|
|
@ -1,207 +0,0 @@
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
//
|
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#include "PassDetail.h"
|
|
||||||
#include "npcomp/E2E/E2E.h"
|
|
||||||
|
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
|
||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
|
||||||
#include "mlir/Pass/PassRegistry.h"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
using namespace mlir::NPCOMP;
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class LowerCstrBroadcastableOp
|
|
||||||
: public OpRewritePattern<shape::CstrBroadcastableOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
// A shape.cstr_* op should be the result of lowering a !shape.shape; it
|
|
||||||
// should not itself ever consume or produce a !shape.shape.
|
|
||||||
//
|
|
||||||
// There is no way to "sink" a !shape.shape type, because one cannot inspect
|
|
||||||
// if it is an error. The only way to use it safely is to lower the op that
|
|
||||||
// produced the value to a set of constraints and then use the witness to
|
|
||||||
// guard a shape.assuming.
|
|
||||||
//
|
|
||||||
// Consider for example what we do when lowering TCF to TCP: we need to do a
|
|
||||||
// shape calculation for the broadcasting. But we create the
|
|
||||||
// shape.cstr_broadcastable and use its witness to guard a `shape.assuming {
|
|
||||||
// ... shape.broadcast ...}`. There's never any need to create a
|
|
||||||
// !shape.shape.
|
|
||||||
//
|
|
||||||
// The use of !shape.shape should be restricted to contexts like
|
|
||||||
// declarations of shape transfer functions, with automatic utilities to
|
|
||||||
// lower !shape.shape types to corresponding constraints + shape.assuming +
|
|
||||||
// tensors. In this (npcomp e2e) lowering flow, we don't have any such
|
|
||||||
// "declarative shape transfer functions" or utilities to expand them to
|
|
||||||
// constraints. So !shape.shape should never exist in our IR.
|
|
||||||
//
|
|
||||||
// Historically, we used !shape.shape type for everything, and
|
|
||||||
// shape.to_extent_tensor would abort in case of an error. But that's not a
|
|
||||||
// useful semantics for lowering, since the error is defined to happen as
|
|
||||||
// part of the shape.to_extent_tensor op, which requires materializing an
|
|
||||||
// "is error" bit in the IR and carrying it around everywhere that the
|
|
||||||
// original !shape.shape value was being used. In practice, nobody respects
|
|
||||||
// that, which opens us up to miscompilations. That is, the lowering
|
|
||||||
// strategy is either "not emit errors at all" or "emit errors as part of
|
|
||||||
// lowering e.g. the shape.broadcast op itself" (which technically puts the
|
|
||||||
// errors in some random location in the IR that is not the
|
|
||||||
// shape.to_extent_tensor op). E.g. the following code would miscompile with
|
|
||||||
// either of those ways that these ops get lowered in practice:
|
|
||||||
// ```
|
|
||||||
// %shape = shape.broadcast %lhs, %rhs : !shape.shape
|
|
||||||
// if %cond:
|
|
||||||
// shape.to_extent_tensor(%shape)
|
|
||||||
// ```
|
|
||||||
// It's not possible to correctly compile this code without significant
|
|
||||||
// contortions (such as carrying an "is error" bit). And to boot, we
|
|
||||||
// shouldn't be getting into that situation in the first place! But the
|
|
||||||
// `shape.to_extent_tensor : !shape.shape -> tensor<?xindex>` abstraction
|
|
||||||
// opens up that possibility.
|
|
||||||
//
|
|
||||||
// shape.to_extent_tensor should not really be a thing, since it creates
|
|
||||||
// these ill-defined situations about where errors are observed. A
|
|
||||||
// !shape.shape type should only exist (for this compilation flow) as part
|
|
||||||
// of a utility, something like "I want to do this shape calculation on
|
|
||||||
// !shape.shape type, create IR that uses tensor<?xindex> and witnesses to
|
|
||||||
// implement it, on the assumption that the error can be
|
|
||||||
// observed anywhere inside the shape calculation".
|
|
||||||
//
|
|
||||||
// !shape.shape type would still be useful for lowerings that actually
|
|
||||||
// result in a runtime type that carries an "is error" bit inside it, though
|
|
||||||
// TBD if such use cases arise.
|
|
||||||
if (op.getType().isa<shape::ShapeType>() ||
|
|
||||||
op.lhs().getType().isa<shape::ShapeType>() ||
|
|
||||||
op.rhs().getType().isa<shape::ShapeType>()) {
|
|
||||||
return op.emitError() << "Error shapes should not exist at this point";
|
|
||||||
}
|
|
||||||
|
|
||||||
auto loc = op.getLoc();
|
|
||||||
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
|
||||||
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
|
|
||||||
|
|
||||||
// Find smaller and greater rank and extent tensor.
|
|
||||||
Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
|
|
||||||
Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
|
|
||||||
Value lhsSmaller =
|
|
||||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
|
|
||||||
Type indexTy = rewriter.getIndexType();
|
|
||||||
Type extentTensorTy = op.lhs().getType();
|
|
||||||
auto ifOp = rewriter.create<scf::IfOp>(
|
|
||||||
loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
|
|
||||||
lhsSmaller,
|
|
||||||
[&](OpBuilder &b, Location loc) {
|
|
||||||
b.create<scf::YieldOp>(
|
|
||||||
loc, ValueRange{lhsRank, op.lhs(), rhsRank, op.rhs()});
|
|
||||||
},
|
|
||||||
[&](OpBuilder &b, Location loc) {
|
|
||||||
b.create<scf::YieldOp>(
|
|
||||||
loc, ValueRange{rhsRank, op.rhs(), lhsRank, op.lhs()});
|
|
||||||
});
|
|
||||||
Value lesserRank = ifOp.getResult(0);
|
|
||||||
Value lesserRankOperand = ifOp.getResult(1);
|
|
||||||
Value greaterRank = ifOp.getResult(2);
|
|
||||||
Value greaterRankOperand = ifOp.getResult(3);
|
|
||||||
|
|
||||||
Value rankDiff =
|
|
||||||
rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
|
|
||||||
|
|
||||||
// Compare the shapes extent by extent, and emit errors for
|
|
||||||
// non-broadcast-compatible shapes.
|
|
||||||
// Two extents are broadcast-compatible if
|
|
||||||
// 1. they are both equal, or
|
|
||||||
// 2. at least one of them is 1.
|
|
||||||
|
|
||||||
rewriter.create<scf::ForOp>(
|
|
||||||
loc, rankDiff, greaterRank, one, llvm::None,
|
|
||||||
[&](OpBuilder &b, Location loc, Value iv, ValueRange) {
|
|
||||||
Value greaterRankOperandExtent = b.create<ExtractElementOp>(
|
|
||||||
loc, greaterRankOperand, ValueRange{iv});
|
|
||||||
Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
|
|
||||||
Value lesserRankOperandExtent = b.create<ExtractElementOp>(
|
|
||||||
loc, lesserRankOperand, ValueRange{ivShifted});
|
|
||||||
|
|
||||||
Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
|
|
||||||
loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
|
|
||||||
Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
|
|
||||||
loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
|
|
||||||
Value extentsAgree =
|
|
||||||
b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
|
|
||||||
lesserRankOperandExtent);
|
|
||||||
auto broadcastIsValid =
|
|
||||||
b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
|
|
||||||
b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
|
|
||||||
lesserRankOperandExtentIsOne));
|
|
||||||
b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
|
|
||||||
b.create<scf::YieldOp>(loc);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Now that we have emitted all the assertions, the witness is trivially
|
|
||||||
// satisfied.
|
|
||||||
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class LowerCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
LogicalResult matchAndRewrite(shape::CstrRequireOp op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
rewriter.create<AssertOp>(op.getLoc(), op.pred(), op.msgAttr());
|
|
||||||
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
// This pass eliminates shape constraints from the program.
|
|
||||||
//
|
|
||||||
// After this pass finishes, there are no !shape.witness types in the program,
|
|
||||||
// no shape.assuming, no shape.cstr_*.
|
|
||||||
//
|
|
||||||
// TODO: This should move to upstream ShapeToStandard conversions.
|
|
||||||
class LowerShapeConstraints
|
|
||||||
: public LowerShapeConstraintsBase<LowerShapeConstraints> {
|
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
||||||
registry.insert<scf::SCFDialect>();
|
|
||||||
}
|
|
||||||
|
|
||||||
void runOnOperation() override {
|
|
||||||
auto func = getOperation();
|
|
||||||
auto *context = &getContext();
|
|
||||||
|
|
||||||
OwningRewritePatternList patterns;
|
|
||||||
patterns.insert<LowerCstrBroadcastableOp>(context);
|
|
||||||
patterns.insert<LowerCstrRequireOp>(context);
|
|
||||||
// Add in the canonicalization patterns for shape.assuming so that it gets
|
|
||||||
// inlined when its witness becomes a true constant witness.
|
|
||||||
shape::AssumingOp::getCanonicalizationPatterns(patterns, context);
|
|
||||||
|
|
||||||
if (failed(applyPatternsAndFoldGreedily(func, patterns)))
|
|
||||||
return signalPassFailure();
|
|
||||||
|
|
||||||
// TODO: Check that there are no remaining !shape.witness, shape.assuming,
|
|
||||||
// shape.cstr_* ops, etc.
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>>
|
|
||||||
mlir::NPCOMP::createLowerShapeConstraintsPass() {
|
|
||||||
return std::make_unique<LowerShapeConstraints>();
|
|
||||||
}
|
|
|
@ -1,58 +0,0 @@
|
||||||
// RUN: npcomp-opt -lower-shape-constraints <%s | FileCheck %s
|
|
||||||
|
|
||||||
func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
|
|
||||||
%witness = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
|
|
||||||
return %witness : !shape.witness
|
|
||||||
}
|
|
||||||
// There's not very much useful to check here other than pasting the output.
|
|
||||||
// CHECK-LABEL: func @cstr_broadcastable(
|
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?xindex>,
|
|
||||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?xindex>) -> !shape.witness {
|
|
||||||
// CHECK: %[[VAL_2:.*]] = constant 0 : index
|
|
||||||
// CHECK: %[[VAL_3:.*]] = constant 1 : index
|
|
||||||
// CHECK: %[[VAL_4:.*]] = shape.const_witness true
|
|
||||||
// CHECK: %[[VAL_5:.*]] = dim %[[VAL_0]], %[[VAL_2]] : tensor<?xindex>
|
|
||||||
// CHECK: %[[VAL_6:.*]] = dim %[[VAL_1]], %[[VAL_2]] : tensor<?xindex>
|
|
||||||
// CHECK: %[[VAL_7:.*]] = cmpi "ule", %[[VAL_5]], %[[VAL_6]] : index
|
|
||||||
// CHECK: %[[VAL_8:.*]]:4 = scf.if %[[VAL_7]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
|
|
||||||
// CHECK: scf.yield %[[VAL_5]], %[[VAL_0]], %[[VAL_6]], %[[VAL_1]] : index, tensor<?xindex>, index, tensor<?xindex>
|
|
||||||
// CHECK: } else {
|
|
||||||
// CHECK: scf.yield %[[VAL_6]], %[[VAL_1]], %[[VAL_5]], %[[VAL_0]] : index, tensor<?xindex>, index, tensor<?xindex>
|
|
||||||
// CHECK: }
|
|
||||||
// CHECK: %[[VAL_9:.*]] = subi %[[VAL_10:.*]]#2, %[[VAL_10]]#0 : index
|
|
||||||
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]]#2 step %[[VAL_3]] {
|
|
||||||
// CHECK: %[[VAL_12:.*]] = extract_element %[[VAL_10]]#3{{\[}}%[[VAL_11]]] : tensor<?xindex>
|
|
||||||
// CHECK: %[[VAL_13:.*]] = subi %[[VAL_11]], %[[VAL_9]] : index
|
|
||||||
// CHECK: %[[VAL_14:.*]] = extract_element %[[VAL_10]]#1{{\[}}%[[VAL_13]]] : tensor<?xindex>
|
|
||||||
// CHECK: %[[VAL_15:.*]] = cmpi "eq", %[[VAL_12]], %[[VAL_3]] : index
|
|
||||||
// CHECK: %[[VAL_16:.*]] = cmpi "eq", %[[VAL_14]], %[[VAL_3]] : index
|
|
||||||
// CHECK: %[[VAL_17:.*]] = cmpi "eq", %[[VAL_12]], %[[VAL_14]] : index
|
|
||||||
// CHECK: %[[VAL_18:.*]] = or %[[VAL_15]], %[[VAL_16]] : i1
|
|
||||||
// CHECK: %[[VAL_19:.*]] = or %[[VAL_17]], %[[VAL_18]] : i1
|
|
||||||
// CHECK: assert %[[VAL_19]], "invalid broadcast"
|
|
||||||
// CHECK: }
|
|
||||||
// CHECK: return %[[VAL_4]] : !shape.witness
|
|
||||||
// CHECK: }
|
|
||||||
|
|
||||||
// Check that `shape.assuming` is eliminated after we create the error handling code.
|
|
||||||
// CHECK-LABEL: func @assuming
|
|
||||||
func @assuming(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> tensor<2xf32> {
|
|
||||||
%witness = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
|
|
||||||
// CHECK-NOT: shape.assuming
|
|
||||||
// CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<2xf32>
|
|
||||||
%0 = shape.assuming %witness -> tensor<2xf32> {
|
|
||||||
%c = constant dense<0.0> : tensor<2xf32>
|
|
||||||
shape.assuming_yield %c : tensor<2xf32>
|
|
||||||
}
|
|
||||||
// CHECK: return %[[CST]]
|
|
||||||
return %0 : tensor<2xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @cstr_require
|
|
||||||
func @cstr_require(%arg0: i1) -> !shape.witness {
|
|
||||||
// CHECK: %[[RET:.*]] = shape.const_witness true
|
|
||||||
// CHECK: assert %arg0, "msg"
|
|
||||||
// CHECK: return %[[RET]]
|
|
||||||
%witness = shape.cstr_require %arg0, "msg"
|
|
||||||
return %witness : !shape.witness
|
|
||||||
}
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
// RUN: npcomp-opt -restricted-canonicalize=included-dialects=std <%s -split-input-file \
|
||||||
|
// RUN: | FileCheck %s --check-prefix=STDONLY --dump-input=fail
|
||||||
|
// RUN: npcomp-opt -restricted-canonicalize=included-dialects=shape <%s -split-input-file \
|
||||||
|
// RUN: | FileCheck %s --check-prefix=SHAPEONLY --dump-input=fail
|
||||||
|
// RUN: npcomp-opt -restricted-canonicalize=included-dialects=std,shape <%s -split-input-file \
|
||||||
|
// RUN: | FileCheck %s --check-prefix=STDANDSHAPE --dump-input=fail
|
||||||
|
// RUN: not --crash npcomp-opt -restricted-canonicalize=included-dialects=notreal2,notreal1 <%s -split-input-file 2>&1 \
|
||||||
|
// RUN: | FileCheck %s --check-prefix=ERROR --dump-input=fail
|
||||||
|
|
||||||
|
// ERROR: restricted-canonicalize: unknown dialects: notreal1, notreal2
|
||||||
|
|
||||||
|
// STDONLY-LABEL: func @mixed_dialects
|
||||||
|
// SHAPEONLY-LABEL: func @mixed_dialects
|
||||||
|
// STDANDSHAPE-LABEL: func @mixed_dialects
|
||||||
|
func @mixed_dialects(%arg0: i32) -> i32 {
|
||||||
|
|
||||||
|
// Do we canonicalize away the shape.assuming?
|
||||||
|
// STDONLY: shape.assuming
|
||||||
|
// SHAPEOONLY-NOT: shape.assuming
|
||||||
|
// STDANDSHAPE-NOT: shape.assuming
|
||||||
|
%w = shape.const_witness true
|
||||||
|
%0 = shape.assuming %w -> (i32) {
|
||||||
|
%c0 = constant 0 : i32
|
||||||
|
shape.assuming_yield %c0 : i32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do we canonicalize away the std.br?
|
||||||
|
// STDONLY-NOT: br
|
||||||
|
// SHAPEOONLY: br
|
||||||
|
// STDANDSHAPE-NOT: br
|
||||||
|
br ^bb1
|
||||||
|
^bb1:
|
||||||
|
return %0 : i32
|
||||||
|
}
|
Loading…
Reference in New Issue