mirror of https://github.com/llvm/torch-mlir
Add some anonymous namespaces.

This brings this code in line with the LLVM style guide and avoids potential ODR issues.
parent 889fe0d6c2
commit 98a38c3527
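For context on the rationale above, here is a minimal, hypothetical C++ sketch (not taken from this patch) of the ODR problem that anonymous namespaces prevent: if two .cpp files linked into the same binary each define a class with the same name at namespace scope but with different bodies, the One Definition Rule is violated and the linker may silently pick either definition. Giving each file-local class internal linkage via an anonymous namespace keeps the definitions from colliding. The class name LowerFooOp below is invented purely for illustration.

// Hypothetical example, not part of this commit. Without the anonymous
// namespace, a second translation unit defining a different class also
// named LowerFooOp (e.g. another pattern in another pass) would silently
// violate the One Definition Rule once both files are linked together.
namespace {
// Internal linkage: this LowerFooOp is visible only inside this .cpp file.
class LowerFooOp {
public:
  int lowered = 0;
};
} // namespace

int main() {
  LowerFooOp op; // usable within this translation unit only
  return op.lowered;
}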
@@ -63,6 +63,7 @@ using namespace mlir::NPCOMP;
 // ResolveShapeOfOps
 //===----------------------------------------------------------------------===//

+namespace {
 class ResolveShapeOfOpViaAllocMemRefOp : public OpRewritePattern<shape::ShapeOfOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -79,7 +80,9 @@ class ResolveShapeOfOpViaAllocMemRefOp : public OpRewritePattern<shape::ShapeOfO
     return failure();
   }
 };
+} // namespace

+namespace {
 class ResolveShapeOfOps : public ResolveShapeOfOpsBase<ResolveShapeOfOps> {
   void runOnOperation() {
     auto func = getOperation();
@@ -111,6 +114,7 @@ class ResolveShapeOfOps : public ResolveShapeOfOpsBase<ResolveShapeOfOps> {
     }
   }
 };
+} // namespace

 std::unique_ptr<OperationPass<FuncOp>>
 mlir::NPCOMP::createResolveShapeOfOpsPass() {
@@ -121,6 +125,7 @@ mlir::NPCOMP::createResolveShapeOfOpsPass() {
 // ResolveTensorLoadStoreOps
 //===----------------------------------------------------------------------===//

+namespace {
 class ReplaceTensorStoreWithCopyPattern
     : public OpRewritePattern<TensorStoreOp> {
 public:
@@ -136,7 +141,9 @@ public:
     return success();
   }
 };
+} // namespace

+namespace {
 class EraseUnusedTensorLoadOpPattern : public OpRewritePattern<TensorLoadOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -148,7 +155,9 @@ public:
     return success();
   }
 };
+} // namespace

+namespace {
 class ResolveTensorLoadStoreOps
     : public ResolveTensorLoadStoreOpsBase<ResolveTensorLoadStoreOps> {
   void runOnOperation() {
@@ -173,6 +182,7 @@ class ResolveTensorLoadStoreOps
     }
   }
 };
+} // namespace

 std::unique_ptr<OperationPass<FuncOp>>
 mlir::NPCOMP::createResolveTensorLoadStoreOpsPass() {
@@ -183,6 +193,7 @@ mlir::NPCOMP::createResolveTensorLoadStoreOpsPass() {
 // LowerLinalgLoopDimOps
 //===----------------------------------------------------------------------===//

+namespace {
 class LowerLinalgLoopDimOp : public OpRewritePattern<DimOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -196,7 +207,9 @@ public:
     return success();
   }
 };
+} // namespace

+namespace {
 class LowerLinalgLoopDimOps
     : public LowerLinalgLoopDimOpsBase<LowerLinalgLoopDimOps> {
   void runOnOperation() {
@@ -212,6 +225,7 @@ class LowerLinalgLoopDimOps
     }
   }
 };
+} // namespace

 std::unique_ptr<OperationPass<FuncOp>>
 mlir::NPCOMP::createLowerLinalgLoopDimOpsPass() {

@@ -28,6 +28,7 @@ using namespace mlir::NPCOMP;
 // TODO: Move this ABI-specific lowering to a separate pass that only does
 // that and make this pass require an invariant something like "a 'root'
 // set of tcp::ShapeFromExtentsOp exist".
+namespace {
 class LowerRootRankedShape : public OpRewritePattern<shape::ShapeOfOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -46,9 +47,11 @@ public:
     return success();
   }
 };
+} // namespace

 // This has to be a "conversion pattern" since the `operands` argument
 // gives access to the post-conversion operands from earlier ops.
+namespace {
 class LowerShapeBroadcastOp : public OpConversionPattern<shape::BroadcastOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -99,6 +102,7 @@ public:
     return success();
   }
 };
+} // namespace

 // Rewrite `get_extent(from_extents(x1,x2,x3), N) -> xN`
 //
@@ -107,6 +111,7 @@ public:
 // which isn't great)
 //
 // Also, we use OpConversionPattern to get post-rewrite operands as above.
+namespace {
 class LowerShapeGetExtentOp : public OpConversionPattern<tcp::GetExtentOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -122,6 +127,7 @@ public:
     return success();
   }
 };
+} // namespace

 // Basic invariant of this pass:
 // Every def of a !shape.shape type is replaced with a
@@ -150,6 +156,7 @@ public:
 // ranks of use-def cycles ahead of time or optimistically assume that
 // backedges will match the rank of forward edges, and somehow be robust
 // when that assumption fails.
+namespace {
 class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
   void runOnOperation() {
     auto func = getOperation();
@@ -172,6 +179,7 @@ class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
     }
   }
 };
+} // namespace

 std::unique_ptr<OperationPass<FuncOp>>
 mlir::NPCOMP::createLowerRankedShapesPass() {

@@ -39,6 +39,7 @@ static Value allocMemRefForTensor(OpBuilder &builder, Value tensor, Value shape,

 // TODO: Lower to linalg.indexed_generic instead and let linalg do the expansion
 // to loops?
+namespace {
 class LowerBroadcastToToLoopsPattern
     : public OpRewritePattern<tcp::BroadcastToOp> {
 public:
@@ -114,10 +115,12 @@ public:
     return success();
   }
 };
+} // namespace

 // TODO: This should be layered in better somewhere.
 // We currently only create DimOp's during LowerBroadcastToToLoopsPattern,
 // so for now just stuff it in here.
+namespace {
 class LowerDimOpToShape : public OpRewritePattern<DimOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -129,6 +132,7 @@ public:
     return success();
   }
 };
+} // namespace

 namespace {
 class LowerBroadcastToToLoops