mirror of https://github.com/llvm/torch-mlir
Introduce a Target class and use it to define generic 32 and 64bit variants.
parent
2bc7a77f98
commit
c3d4436397
|
@ -0,0 +1,30 @@
|
||||||
|
# RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
# Subset of constant tests which verify against a GenericTarget32.
|
||||||
|
|
||||||
|
from npcomp.compiler.frontend import *
|
||||||
|
from npcomp.compiler.target import *
|
||||||
|
|
||||||
|
|
||||||
|
def import_global(f):
|
||||||
|
fe = ImportFrontend(target_factory=GenericTarget32)
|
||||||
|
fe.import_global_function(f)
|
||||||
|
print("// -----")
|
||||||
|
print(fe.ir_module.to_asm())
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
# CHECK-LABEL: func @integer_constants
|
||||||
|
@import_global
|
||||||
|
def integer_constants():
|
||||||
|
# CHECK: %[[A:.*]] = constant 100 : i32
|
||||||
|
a = 100
|
||||||
|
return a
|
||||||
|
|
||||||
|
|
||||||
|
# CHECK-LABEL: func @float_constants
|
||||||
|
@import_global
|
||||||
|
def float_constants():
|
||||||
|
# CHECK: %[[A:.*]] = constant 2.200000e+00 : f32
|
||||||
|
a = 2.2
|
||||||
|
return a
|
|
@ -14,6 +14,7 @@ from npcomp.dialect import Numpy
|
||||||
|
|
||||||
from . import logging
|
from . import logging
|
||||||
from .importer import *
|
from .importer import *
|
||||||
|
from .target import *
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ImportFrontend",
|
"ImportFrontend",
|
||||||
|
@ -32,12 +33,21 @@ class AllDialectHelper(Numpy.DialectHelper, ScfDialectHelper):
|
||||||
|
|
||||||
class ImportFrontend:
|
class ImportFrontend:
|
||||||
"""Frontend for importing various entities into a Module."""
|
"""Frontend for importing various entities into a Module."""
|
||||||
|
__slots__ = [
|
||||||
|
"_ir_context",
|
||||||
|
"_ir_module",
|
||||||
|
"_helper",
|
||||||
|
"_target_factory",
|
||||||
|
]
|
||||||
|
|
||||||
def __init__(self, ir_context: ir.MLIRContext = None):
|
def __init__(self,
|
||||||
|
ir_context: ir.MLIRContext = None,
|
||||||
|
target_factory: TargetFactory = GenericTarget64):
|
||||||
self._ir_context = ir.MLIRContext() if not ir_context else ir_context
|
self._ir_context = ir.MLIRContext() if not ir_context else ir_context
|
||||||
self._ir_module = self._ir_context.new_module()
|
self._ir_module = self._ir_context.new_module()
|
||||||
self._helper = AllDialectHelper(self._ir_context,
|
self._helper = AllDialectHelper(self._ir_context,
|
||||||
ir.OpBuilder(self._ir_context))
|
ir.OpBuilder(self._ir_context))
|
||||||
|
self._target_factory = target_factory
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ir_context(self):
|
def ir_context(self):
|
||||||
|
@ -66,6 +76,7 @@ class ImportFrontend:
|
||||||
h = self.ir_h
|
h = self.ir_h
|
||||||
ir_c = self.ir_context
|
ir_c = self.ir_context
|
||||||
ir_m = self.ir_module
|
ir_m = self.ir_module
|
||||||
|
target = self._target_factory(h)
|
||||||
filename = inspect.getsourcefile(f)
|
filename = inspect.getsourcefile(f)
|
||||||
source_lines, start_lineno = inspect.getsourcelines(f)
|
source_lines, start_lineno = inspect.getsourcelines(f)
|
||||||
source = "".join(source_lines)
|
source = "".join(source_lines)
|
||||||
|
@ -94,7 +105,8 @@ class ImportFrontend:
|
||||||
fctx = FunctionContext(ir_c=ir_c,
|
fctx = FunctionContext(ir_c=ir_c,
|
||||||
ir_f=ir_f,
|
ir_f=ir_f,
|
||||||
ir_h=h,
|
ir_h=h,
|
||||||
filename_ident=filename_ident)
|
filename_ident=filename_ident,
|
||||||
|
target=target)
|
||||||
for f_arg, ir_arg in zip(f_params, ir_f.first_block.args):
|
for f_arg, ir_arg in zip(f_params, ir_f.first_block.args):
|
||||||
fctx.map_local_name(f_arg, ir_arg)
|
fctx.map_local_name(f_arg, ir_arg)
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ import sys
|
||||||
from _npcomp.mlir import ir
|
from _npcomp.mlir import ir
|
||||||
|
|
||||||
from . import logging
|
from . import logging
|
||||||
|
from .target import *
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"FunctionContext",
|
"FunctionContext",
|
||||||
|
@ -24,14 +25,16 @@ class FunctionContext:
|
||||||
"ir_c",
|
"ir_c",
|
||||||
"ir_f",
|
"ir_f",
|
||||||
"ir_h",
|
"ir_h",
|
||||||
|
"target",
|
||||||
"filename_ident",
|
"filename_ident",
|
||||||
"local_name_value_map",
|
"local_name_value_map",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, ir_c, ir_f, ir_h, filename_ident):
|
def __init__(self, ir_c, ir_f, ir_h, target, filename_ident):
|
||||||
self.ir_c = ir_c
|
self.ir_c = ir_c
|
||||||
self.ir_f = ir_f
|
self.ir_f = ir_f
|
||||||
self.ir_h = ir_h
|
self.ir_h = ir_h
|
||||||
|
self.target = target
|
||||||
self.filename_ident = filename_ident
|
self.filename_ident = filename_ident
|
||||||
self.local_name_value_map = dict()
|
self.local_name_value_map = dict()
|
||||||
|
|
||||||
|
@ -93,7 +96,8 @@ class FunctionDefImporter(BaseNodeVisitor):
|
||||||
if not self._last_was_return:
|
if not self._last_was_return:
|
||||||
# Add a default terminator.
|
# Add a default terminator.
|
||||||
none_value = ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result
|
none_value = ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result
|
||||||
none_cast = ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType, none_value).result
|
none_cast = ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
|
||||||
|
none_value).result
|
||||||
ir_h.return_op([none_cast])
|
ir_h.return_op([none_cast])
|
||||||
|
|
||||||
def visit_Assign(self, ast_node):
|
def visit_Assign(self, ast_node):
|
||||||
|
@ -137,6 +141,8 @@ class ExpressionImporter(BaseNodeVisitor):
|
||||||
def __init__(self, fctx):
|
def __init__(self, fctx):
|
||||||
super().__init__(fctx)
|
super().__init__(fctx)
|
||||||
self.value = None
|
self.value = None
|
||||||
|
self._int_type = fctx.target.impl_int_type
|
||||||
|
self._float_type = fctx.target.impl_float_type
|
||||||
|
|
||||||
def visit(self, node):
|
def visit(self, node):
|
||||||
super().visit(node)
|
super().visit(node)
|
||||||
|
@ -158,13 +164,11 @@ class ExpressionImporter(BaseNodeVisitor):
|
||||||
elif value is None:
|
elif value is None:
|
||||||
self.value = ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result
|
self.value = ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result
|
||||||
elif isinstance(value, int):
|
elif isinstance(value, int):
|
||||||
# TODO: Configurable type mapping
|
ir_type = self._int_type
|
||||||
ir_type = ir_h.i64_type
|
|
||||||
ir_attr = ir_c.integer_attr(ir_type, value)
|
ir_attr = ir_c.integer_attr(ir_type, value)
|
||||||
self.value = ir_h.constant_op(ir_type, ir_attr).result
|
self.value = ir_h.constant_op(ir_type, ir_attr).result
|
||||||
elif isinstance(value, float):
|
elif isinstance(value, float):
|
||||||
# TODO: Configurable type mapping
|
ir_type = self._float_type
|
||||||
ir_type = ir_h.f64_type
|
|
||||||
ir_attr = ir_c.float_attr(ir_type, value)
|
ir_attr = ir_c.float_attr(ir_type, value)
|
||||||
self.value = ir_h.constant_op(ir_type, ir_attr).result
|
self.value = ir_h.constant_op(ir_type, ir_attr).result
|
||||||
elif isinstance(value, str):
|
elif isinstance(value, str):
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
from typing import *
|
||||||
|
from _npcomp.mlir import ir
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"GenericTarget32",
|
||||||
|
"GenericTarget64",
|
||||||
|
"Target",
|
||||||
|
"TargetFactory",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class Target:
|
||||||
|
"""
|
||||||
|
Abstract class providing configuration and hooks for a specific compilation
|
||||||
|
target.
|
||||||
|
"""
|
||||||
|
__slots__ = [
|
||||||
|
"_mlir_helper",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, mlir_helper: ir.DialectHelper):
|
||||||
|
super().__init__()
|
||||||
|
self._mlir_helper = mlir_helper
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mlir_helper(self):
|
||||||
|
return self._mlir_helper
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mlir_context(self):
|
||||||
|
return self._mlir_helper.context
|
||||||
|
|
||||||
|
@property
|
||||||
|
def target_name(self) -> str:
|
||||||
|
return NotImplementedError()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def impl_int_type(self) -> ir.Type:
|
||||||
|
"""Gets the default int type for the backend for the Python 'int' type."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def impl_float_type(self) -> ir.Type:
|
||||||
|
"""Gets the implementation's type for the python 'float' type."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class GenericTarget64(Target):
|
||||||
|
"""A generic 64 bit target."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def target_name(self) -> str:
|
||||||
|
return "generic64"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def impl_int_type(self) -> ir.Type:
|
||||||
|
"""Gets the default int type for the backend for the Python 'int' type."""
|
||||||
|
return self.mlir_helper.i64_type
|
||||||
|
|
||||||
|
@property
|
||||||
|
def impl_float_type(self) -> ir.Type:
|
||||||
|
"""Gets the implementation's type for the python 'float' type."""
|
||||||
|
return self.mlir_helper.f64_type
|
||||||
|
|
||||||
|
|
||||||
|
class GenericTarget32(Target):
|
||||||
|
"""A generic 32 bit target (uses 32bit ints and floats)."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def target_name(self) -> str:
|
||||||
|
return "generic32"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def impl_int_type(self) -> ir.Type:
|
||||||
|
"""Gets the default int type for the backend for the Python 'int' type."""
|
||||||
|
return self.mlir_helper.i32_type
|
||||||
|
|
||||||
|
@property
|
||||||
|
def impl_float_type(self) -> ir.Type:
|
||||||
|
"""Gets the implementation's type for the python 'float' type."""
|
||||||
|
return self.mlir_helper.f32_type
|
||||||
|
|
||||||
|
|
||||||
|
# Factory for producing a target (matches the Target constructor).
|
||||||
|
TargetFactory = Callable[[ir.DialectHelper], Target]
|
Loading…
Reference in New Issue