diff --git a/psh/commands.py b/psh/commands.py index 164fd58..6bf481d 100644 --- a/psh/commands.py +++ b/psh/commands.py @@ -1,5 +1,7 @@ from io import StringIO +from psh.tree import TreeNode + class BaseCommand(object): """Commands can be used to chain the execution of multiple programs diff --git a/psh/example_cmd.py b/psh/example_cmd.py index e54e408..e91de76 100644 --- a/psh/example_cmd.py +++ b/psh/example_cmd.py @@ -1,5 +1,7 @@ from psh.commands import BaseCommand, register_cmd +from psh.tree import TreeNode + @register_cmd("example") class Example(BaseCommand): @@ -8,8 +10,8 @@ class Example(BaseCommand): def call(self, *args, **kwargs): def output_generator(): - yield b'example' - yield b'command' + yield TreeNode(b'example') + yield TreeNode(b'command') return output_generator @@ -26,7 +28,8 @@ class Echo(BaseCommand): input_generator = self.get_input_generator() def output_generator(): for args in self.args: - yield args.encode("utf-8") - for line in input_generator: - yield line + yield TreeNode(args.encode("utf-8")) + for node in input_generator: + line = node.data + yield TreeNode(line) return output_generator diff --git a/psh/formatters.py b/psh/formatters.py index 3e50c47..436d75c 100644 --- a/psh/formatters.py +++ b/psh/formatters.py @@ -7,6 +7,7 @@ class Printer(BaseCommand): def call(self): input_generator = self.get_input_generator() - for line in input_generator: + for node in input_generator: + line = node.data print(str(line.decode('utf-8'))) return None diff --git a/psh/raw_commands.py b/psh/raw_commands.py index 1d61444..e6787c1 100644 --- a/psh/raw_commands.py +++ b/psh/raw_commands.py @@ -2,6 +2,7 @@ import shlex from psh.formatters import Printer from psh.commands import BaseCommand +from psh.tree import TreeNode class RawCommand(BaseCommand): @@ -20,11 +21,12 @@ class RawCommand(BaseCommand): p = subprocess.Popen(shlex.split(self.cmd), stdin=subprocess.PIPE, stdout=subprocess.PIPE) def make_output_generator(): input_str = b"" - for line in input_generator: + for node in input_generator: + line = node.data input_str += line + b'\n' outs, errs = p.communicate(input_str) if outs: - yield outs + yield TreeNode(outs) return make_output_generator except: diff --git a/test/utils.py b/test/utils.py index be7696c..5228c55 100644 --- a/test/utils.py +++ b/test/utils.py @@ -14,7 +14,7 @@ class TestFormatter(BaseCommand): def call(self): input_generator = self.get_input_generator() for line in input_generator: - self.buffer.write(line.decode('utf-8')) + self.buffer.write(line.data.decode('utf-8')) return None def get_data(self):